CIRCT 20.0.0git
Loading...
Searching...
No Matches
LowerVectorizations.cpp
Go to the documentation of this file.
1//===- LowerVectorizations.cpp --------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Vector/IR/VectorOps.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/Pass/Pass.h"
16
17#include "circt/Dialect/Arc/ArcPassesEnums.cpp.inc"
18
19namespace circt {
20namespace arc {
21#define GEN_PASS_DEF_LOWERVECTORIZATIONS
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
23} // namespace arc
24} // namespace circt
25
26using namespace mlir;
27using namespace circt;
28using namespace arc;
29
30/// Vectorizes the `arc.vectorize` boundary by packing the vector elements into
31/// an integer value. Returns the vectorized version of the op. May invalidate
32/// the passed operation.
33///
34/// Example:
35/// ```mlir
36/// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in3) : (i1, i1, i1, i1) -> (i1,
37/// i1) { ^bb0(%arg0: i1, %arg1: i1):
38/// %1 = comb.and %arg0, %arg1 : i1
39/// arc.vectorize.return %1 : i1
40/// }
41/// ```
42/// becomes
43/// ```mlir
44/// %0 = comb.concat %in0, %in1 : i1, i1
45/// %1 = comb.concat %in2, %in3 : i1, i1
46/// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
47/// ^bb0(%arg0: i1, %arg1: i1):
48/// %11 = comb.and %arg0, %arg1 : i1
49/// arc.vectorize.return %11 : i1
50/// }
51/// %3 = comb.extract %2 from 0 : (i2) -> i1
52/// %4 = comb.extract %2 from 1 : (i2) -> i1
53/// ```
54static VectorizeOp lowerBoundaryScalar(VectorizeOp op) {
55 ImplicitLocOpBuilder builder(op.getLoc(), op);
56 SmallVector<ValueRange> vectors;
57
58 for (ValueRange range : op.getInputs()) {
59 unsigned bw = range.front().getType().getIntOrFloatBitWidth();
60 vectors.push_back(builder
61 .create<comb::ConcatOp>(
62 builder.getIntegerType(bw * range.size()), range)
63 ->getResults());
64 }
65
66 unsigned width = op->getResult(0).getType().getIntOrFloatBitWidth();
67 VectorizeOp newOp = builder.create<VectorizeOp>(
68 builder.getIntegerType(width * op->getNumResults()), vectors);
69 newOp.getBody().takeBody(op.getBody());
70
71 for (OpResult res : op.getResults()) {
72 Value newRes = builder.create<comb::ExtractOp>(
73 newOp.getResult(0), width * res.getResultNumber(), width);
74 res.replaceAllUsesWith(newRes);
75 }
76
77 op->erase();
78 return newOp;
79}
80
81/// Vectorizes the `arc.vectorize` boundary by using the `vector` type and
82/// dialect for SIMD-based vectorization. Returns the vectorized version of the
83/// op. May invalidate the passed operation.
84///
85/// Example:
86/// ```mlir
87/// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in2) : (i64, i64, i32, i32) ->
88/// (i64, i64) { ^bb0(%arg0: i64, %arg1: i32):
89/// %c0_i32 = hw.constant 0 : i32
90/// %1 = comb.concat %c0_i32, %arg1 : i32, i32
91/// %2 = comb.and %arg0, %1 : i64
92/// arc.vectorize.return %2 : i64
93/// }
94/// ```
95/// becomes
96/// ```mlir
97/// %cst = arith.constant dense<0> : vector<2xi64>
98/// %0 = vector.insert %in0, %cst [0] : i64 into vector<2xi64>
99/// %1 = vector.insert %in1, %0 [1] : i64 into vector<2xi64>
100/// %2 = vector.broadcast %in2 : i32 to vector<2xi32>
101/// %3 = arc.vectorize (%1), (%2) : (vector<2xi64>, vector<2xi32>) ->
102/// vector<2xi64> { ^bb0(%arg0: i64, %arg1: i32):
103/// %c0_i32 = hw.constant 0 : i32
104/// %4 = comb.concat %c0_i32, %arg1 : i32, i32
105/// %5 = comb.and %arg0, %4 : i64
106/// arc.vectorize.return %5 : i64
107/// }
108/// %4 = vector.extract %3[0] : vector<2xi64>
109/// %5 = vector.extract %3[1] : vector<2xi64>
110/// ```
111static VectorizeOp lowerBoundaryVector(VectorizeOp op) {
112 ImplicitLocOpBuilder builder(op.getLoc(), op);
113 SmallVector<ValueRange> vectors;
114
115 for (ValueRange range : op.getInputs()) {
116 // Insert a broadcast operation if all elements of a vector are the same
117 // because it's a significantly cheaper instruction.
118 VectorType type = VectorType::get(SmallVector<int64_t>(1, range.size()),
119 range.front().getType());
120 if (llvm::all_equal(range)) {
121 vectors.push_back(
122 builder.create<vector::BroadcastOp>(type, range.front())
123 ->getResults());
124 continue;
125 }
126
127 // Otherwise do a gather.
128 ValueRange vector =
129 builder
130 .create<arith::ConstantOp>(DenseElementsAttr::get(
131 type, SmallVector<Attribute>(
132 range.size(),
133 builder.getIntegerAttr(type.getElementType(), 0))))
134 ->getResults();
135 for (auto [i, element] : llvm::enumerate(range))
136 vector = builder.create<vector::InsertOp>(element, vector.front(), i)
137 ->getResults();
138
139 vectors.push_back(vector);
140 }
141
142 VectorType resType = VectorType::get(
143 SmallVector<int64_t>(1, op->getNumResults()), op.getResult(0).getType());
144 VectorizeOp newOp = builder.create<VectorizeOp>(resType, vectors);
145 newOp.getBody().takeBody(op.getBody());
146
147 for (OpResult res : op.getResults())
148 res.replaceAllUsesWith(builder.create<vector::ExtractOp>(
149 newOp.getResult(0), res.getResultNumber()));
150
151 op->erase();
152 return newOp;
153}
154
155/// Vectorizes the boundary of the given `arc.vectorize` operation if it is not
156/// already vectorized. If the body of the `arc.vectorize` operation is already
157/// vectorized the same vectorization technique (SIMD or scalar) is chosen.
158/// Otherwise,
159/// * packs the vector in a scalar if it fits in a 64-bit integer or
160/// * uses the `vector` type and dialect for SIMD vectorization
161/// Returns the vectorized version of the op. May invalidate the passed
162/// operation.
163static VectorizeOp lowerBoundary(VectorizeOp op) {
164 // Nothing to do if it is already vectorized.
165 if (op.isBoundaryVectorized())
166 return op;
167
168 // If the body is already vectorized, we must use the same vectorization
169 // technique. Otherwise, we would produce invalid IR.
170 if (op.isBodyVectorized()) {
171 if (isa<VectorType>(op.getBody().front().getArgumentTypes().front()))
172 return lowerBoundaryVector(op);
173 return lowerBoundaryScalar(op);
174 }
175
176 // If the vector can fit in an i64 value, use scalar vectorization, otherwise
177 // use SIMD.
178 unsigned numLanes = op.getInputs().size();
179 unsigned maxLaneWidth = 0;
180 for (OperandRange range : op.getInputs())
181 maxLaneWidth =
182 std::max(maxLaneWidth, range.front().getType().getIntOrFloatBitWidth());
183
184 if ((numLanes * maxLaneWidth <= 64) &&
185 op->getResult(0).getType().getIntOrFloatBitWidth() *
186 op->getNumResults() <=
187 64)
188 return lowerBoundaryScalar(op);
189 return lowerBoundaryVector(op);
190}
191
192/// Vectorizes the body of the given `arc.vectorize` operation if it is not
193/// already vectorized. If the boundary of the `arc.vectorize` operation is
194/// already vectorized the same vectorization technique (SIMD or scalar) is
195/// chosen. Otherwise,
196/// * packs the vector in a scalar if it fits in a 64-bit integer or
197/// * uses the `vector` type and dialect for SIMD vectorization
198/// Returns the vectorized version of the op or failure. May invalidate the
199/// passed operation.
200static FailureOr<VectorizeOp> lowerBody(VectorizeOp op) {
201 if (op.isBodyVectorized())
202 return op;
203
204 return op->emitError("lowering body not yet supported");
205}
206
207/// Inlines the `arc.vectorize` operations body once both the boundary and body
208/// are vectorized.
209///
210/// Example:
211/// ```mlir
212/// %0 = comb.concat %in0, %in1 : i1, i1
213/// %1 = comb.concat %in2, %in2 : i1, i1
214/// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
215/// ^bb0(%arg0: i2, %arg1: i2):
216/// %12 = arith.andi %arg0, %arg1 : i2
217/// arc.vectorize.return %12 : i2
218/// }
219/// %3 = comb.extract %2 from 0 : (i2) -> i1
220/// %4 = comb.extract %2 from 1 : (i2) -> i1
221/// ```
222/// becomes
223/// ```mlir
224/// %0 = comb.concat %in0, %in1 : i1, i1
225/// %1 = comb.concat %in2, %in2 : i1, i1
226/// %2 = arith.andi %0, %1 : i2
227/// %3 = comb.extract %2 from 0 : (i2) -> i1
228/// %4 = comb.extract %2 from 1 : (i2) -> i1
229/// ```
230static LogicalResult inlineBody(VectorizeOp op) {
231 if (!(op.isBodyVectorized() && op.isBoundaryVectorized()))
232 return op->emitError(
233 "can only inline body if boundary and body are already vectorized");
234
235 Block &block = op.getBody().front();
236 for (auto [operand, arg] : llvm::zip(op.getInputs(), block.getArguments()))
237 arg.replaceAllUsesWith(operand.front());
238
239 Operation *terminator = block.getTerminator();
240 op->getResult(0).replaceAllUsesWith(terminator->getOperand(0));
241 terminator->erase();
242
243 op->getBlock()->getOperations().splice(op->getIterator(),
244 block.getOperations());
245 op->erase();
246
247 return success();
248}
249
250namespace {
251/// A pass to vectorize (parts of) an `arc.vectorize` operation.
252struct LowerVectorizationsPass
253 : public arc::impl::LowerVectorizationsBase<LowerVectorizationsPass> {
254 LowerVectorizationsPass() = default;
255 explicit LowerVectorizationsPass(LowerVectorizationsModeEnum mode)
256 : LowerVectorizationsPass() {
257 this->mode.setValue(mode);
258 }
259
260 void runOnOperation() override {
261 WalkResult result = getOperation().walk([&](VectorizeOp op) -> WalkResult {
262 switch (mode) {
263 case LowerVectorizationsModeEnum::Full:
264 if (auto newOp = lowerBody(lowerBoundary(op));
265 succeeded(newOp) && succeeded(inlineBody(*newOp)))
266 return WalkResult::advance();
267 return WalkResult::interrupt();
268 case LowerVectorizationsModeEnum::Boundary:
269 return lowerBoundary(op), WalkResult::advance();
270 case LowerVectorizationsModeEnum::Body:
271 return static_cast<LogicalResult>(lowerBody(op));
272 case LowerVectorizationsModeEnum::InlineBody:
273 return inlineBody(op);
274 }
275 llvm_unreachable("all enum cases must be handled above");
276 });
277 if (result.wasInterrupted())
278 signalPassFailure();
279 }
280};
281} // namespace
282
283std::unique_ptr<Pass>
284arc::createLowerVectorizationsPass(LowerVectorizationsModeEnum mode) {
285 return std::make_unique<LowerVectorizationsPass>(mode);
286}
static VectorizeOp lowerBoundary(VectorizeOp op)
Vectorizes the boundary of the given arc.vectorize operation if it is not already vectorized.
static FailureOr< VectorizeOp > lowerBody(VectorizeOp op)
Vectorizes the body of the given arc.vectorize operation if it is not already vectorized.
static VectorizeOp lowerBoundaryVector(VectorizeOp op)
Vectorizes the arc.vectorize boundary by using the vector type and dialect for SIMD-based vectorizati...
static LogicalResult inlineBody(VectorizeOp op)
Inlines the arc.vectorize operations body once both the boundary and body are vectorized.
static VectorizeOp lowerBoundaryScalar(VectorizeOp op)
Vectorizes the arc.vectorize boundary by packing the vector elements into an integer value.
std::unique_ptr< mlir::Pass > createLowerVectorizationsPass(LowerVectorizationsModeEnum mode=LowerVectorizationsModeEnum::Full)
Creates the pass that lowers `arc.vectorize` operations according to the given lowering mode.