CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerVectorizations.cpp
Go to the documentation of this file.
1//===- LowerVectorizations.cpp --------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Vector/IR/VectorOps.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/Pass/Pass.h"
16
17#include "circt/Dialect/Arc/ArcPassesEnums.cpp.inc"
18
19namespace circt {
20namespace arc {
21#define GEN_PASS_DEF_LOWERVECTORIZATIONS
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
23} // namespace arc
24} // namespace circt
25
26using namespace mlir;
27using namespace circt;
28using namespace arc;
29
30/// Vectorizes the `arc.vectorize` boundary by packing the vector elements into
31/// an integer value. Returns the vectorized version of the op. May invalidate
32/// the passed operation.
33///
34/// Example:
35/// ```mlir
36/// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in3) : (i1, i1, i1, i1) -> (i1,
37/// i1) { ^bb0(%arg0: i1, %arg1: i1):
38/// %1 = comb.and %arg0, %arg1 : i1
39/// arc.vectorize.return %1 : i1
40/// }
41/// ```
42/// becomes
43/// ```mlir
44/// %0 = comb.concat %in0, %in1 : i1, i1
45/// %1 = comb.concat %in2, %in3 : i1, i1
46/// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
47/// ^bb0(%arg0: i1, %arg1: i1):
48/// %11 = comb.and %arg0, %arg1 : i1
49/// arc.vectorize.return %11 : i1
50/// }
51/// %3 = comb.extract %2 from 0 : (i2) -> i1
52/// %4 = comb.extract %2 from 1 : (i2) -> i1
53/// ```
54static VectorizeOp lowerBoundaryScalar(VectorizeOp op) {
55 ImplicitLocOpBuilder builder(op.getLoc(), op);
56 SmallVector<ValueRange> vectors;
57
58 for (ValueRange range : op.getInputs()) {
59 unsigned bw = range.front().getType().getIntOrFloatBitWidth();
60 vectors.push_back(
61 comb::ConcatOp::create(builder,
62 builder.getIntegerType(bw * range.size()), range)
63 ->getResults());
64 }
65
66 unsigned width = op->getResult(0).getType().getIntOrFloatBitWidth();
67 VectorizeOp newOp = VectorizeOp::create(
68 builder, builder.getIntegerType(width * op->getNumResults()), vectors);
69 newOp.getBody().takeBody(op.getBody());
70
71 for (OpResult res : op.getResults()) {
72 Value newRes = comb::ExtractOp::create(
73 builder, newOp.getResult(0), width * res.getResultNumber(), width);
74 res.replaceAllUsesWith(newRes);
75 }
76
77 op->erase();
78 return newOp;
79}
80
81/// Vectorizes the `arc.vectorize` boundary by using the `vector` type and
82/// dialect for SIMD-based vectorization. Returns the vectorized version of the
83/// op. May invalidate the passed operation.
84///
85/// Example:
86/// ```mlir
87/// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in2) : (i64, i64, i32, i32) ->
88/// (i64, i64) { ^bb0(%arg0: i64, %arg1: i32):
89/// %c0_i32 = hw.constant 0 : i32
90/// %1 = comb.concat %c0_i32, %arg1 : i32, i32
91/// %2 = comb.and %arg0, %1 : i64
92/// arc.vectorize.return %2 : i64
93/// }
94/// ```
95/// becomes
96/// ```mlir
97/// %cst = arith.constant dense<0> : vector<2xi64>
98/// %0 = vector.insert %in0, %cst [0] : i64 into vector<2xi64>
99/// %1 = vector.insert %in1, %0 [1] : i64 into vector<2xi64>
100/// %2 = vector.broadcast %in2 : i32 to vector<2xi32>
101/// %3 = arc.vectorize (%1), (%2) : (vector<2xi64>, vector<2xi32>) ->
102/// vector<2xi64> { ^bb0(%arg0: i64, %arg1: i32):
103/// %c0_i32 = hw.constant 0 : i32
104/// %4 = comb.concat %c0_i32, %arg1 : i32, i32
105/// %5 = comb.and %arg0, %4 : i64
106/// arc.vectorize.return %5 : i64
107/// }
108/// %4 = vector.extract %3[0] : vector<2xi64>
109/// %5 = vector.extract %3[1] : vector<2xi64>
110/// ```
111static VectorizeOp lowerBoundaryVector(VectorizeOp op) {
112 ImplicitLocOpBuilder builder(op.getLoc(), op);
113 SmallVector<ValueRange> vectors;
114
115 for (ValueRange range : op.getInputs()) {
116 // Insert a broadcast operation if all elements of a vector are the same
117 // because it's a significantly cheaper instruction.
118 VectorType type = VectorType::get(SmallVector<int64_t>(1, range.size()),
119 range.front().getType());
120 if (llvm::all_equal(range)) {
121 vectors.push_back(
122 vector::BroadcastOp::create(builder, type, range.front())
123 ->getResults());
124 continue;
125 }
126
127 // Otherwise do a gather.
128 ValueRange vector =
129 arith::ConstantOp::create(
130 builder,
131 DenseElementsAttr::get(
132 type, SmallVector<Attribute>(
133 range.size(),
134 builder.getIntegerAttr(type.getElementType(), 0))))
135 ->getResults();
136 for (auto [i, element] : llvm::enumerate(range))
137 vector = vector::InsertOp::create(builder, element, vector.front(), i)
138 ->getResults();
139
140 vectors.push_back(vector);
141 }
142
143 VectorType resType = VectorType::get(
144 SmallVector<int64_t>(1, op->getNumResults()), op.getResult(0).getType());
145 VectorizeOp newOp = VectorizeOp::create(builder, resType, vectors);
146 newOp.getBody().takeBody(op.getBody());
147
148 for (OpResult res : op.getResults())
149 res.replaceAllUsesWith(vector::ExtractOp::create(
150 builder, newOp.getResult(0), res.getResultNumber()));
151
152 op->erase();
153 return newOp;
154}
155
156/// Vectorizes the boundary of the given `arc.vectorize` operation if it is not
157/// already vectorized. If the body of the `arc.vectorize` operation is already
158/// vectorized the same vectorization technique (SIMD or scalar) is chosen.
159/// Otherwise,
160/// * packs the vector in a scalar if it fits in a 64-bit integer or
161/// * uses the `vector` type and dialect for SIMD vectorization
162/// Returns the vectorized version of the op. May invalidate the passed
163/// operation.
164static VectorizeOp lowerBoundary(VectorizeOp op) {
165 // Nothing to do if it is already vectorized.
166 if (op.isBoundaryVectorized())
167 return op;
168
169 // If the body is already vectorized, we must use the same vectorization
170 // technique. Otherwise, we would produce invalid IR.
171 if (op.isBodyVectorized()) {
172 if (isa<VectorType>(op.getBody().front().getArgumentTypes().front()))
173 return lowerBoundaryVector(op);
174 return lowerBoundaryScalar(op);
175 }
176
177 // If the vector can fit in an i64 value, use scalar vectorization, otherwise
178 // use SIMD.
179 unsigned numLanes = op.getInputs().size();
180 unsigned maxLaneWidth = 0;
181 for (OperandRange range : op.getInputs())
182 maxLaneWidth =
183 std::max(maxLaneWidth, range.front().getType().getIntOrFloatBitWidth());
184
185 if ((numLanes * maxLaneWidth <= 64) &&
186 op->getResult(0).getType().getIntOrFloatBitWidth() *
187 op->getNumResults() <=
188 64)
189 return lowerBoundaryScalar(op);
190 return lowerBoundaryVector(op);
191}
192
193/// Vectorizes the body of the given `arc.vectorize` operation if it is not
194/// already vectorized. If the boundary of the `arc.vectorize` operation is
195/// already vectorized the same vectorization technique (SIMD or scalar) is
196/// chosen. Otherwise,
197/// * packs the vector in a scalar if it fits in a 64-bit integer or
198/// * uses the `vector` type and dialect for SIMD vectorization
199/// Returns the vectorized version of the op or failure. May invalidate the
200/// passed operation.
201static FailureOr<VectorizeOp> lowerBody(VectorizeOp op) {
202 if (op.isBodyVectorized())
203 return op;
204
205 return op->emitError("lowering body not yet supported");
206}
207
208/// Inlines the `arc.vectorize` operations body once both the boundary and body
209/// are vectorized.
210///
211/// Example:
212/// ```mlir
213/// %0 = comb.concat %in0, %in1 : i1, i1
214/// %1 = comb.concat %in2, %in2 : i1, i1
215/// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
216/// ^bb0(%arg0: i2, %arg1: i2):
217/// %12 = arith.andi %arg0, %arg1 : i2
218/// arc.vectorize.return %12 : i2
219/// }
220/// %3 = comb.extract %2 from 0 : (i2) -> i1
221/// %4 = comb.extract %2 from 1 : (i2) -> i1
222/// ```
223/// becomes
224/// ```mlir
225/// %0 = comb.concat %in0, %in1 : i1, i1
226/// %1 = comb.concat %in2, %in2 : i1, i1
227/// %2 = arith.andi %0, %1 : i2
228/// %3 = comb.extract %2 from 0 : (i2) -> i1
229/// %4 = comb.extract %2 from 1 : (i2) -> i1
230/// ```
231static LogicalResult inlineBody(VectorizeOp op) {
232 if (!(op.isBodyVectorized() && op.isBoundaryVectorized()))
233 return op->emitError(
234 "can only inline body if boundary and body are already vectorized");
235
236 Block &block = op.getBody().front();
237 for (auto [operand, arg] : llvm::zip(op.getInputs(), block.getArguments()))
238 arg.replaceAllUsesWith(operand.front());
239
240 Operation *terminator = block.getTerminator();
241 op->getResult(0).replaceAllUsesWith(terminator->getOperand(0));
242 terminator->erase();
243
244 op->getBlock()->getOperations().splice(op->getIterator(),
245 block.getOperations());
246 op->erase();
247
248 return success();
249}
250
251namespace {
252/// A pass to vectorize (parts of) an `arc.vectorize` operation.
253struct LowerVectorizationsPass
254 : public arc::impl::LowerVectorizationsBase<LowerVectorizationsPass> {
255 LowerVectorizationsPass() = default;
256 explicit LowerVectorizationsPass(LowerVectorizationsModeEnum mode)
257 : LowerVectorizationsPass() {
258 this->mode.setValue(mode);
259 }
260
261 void runOnOperation() override {
262 WalkResult result = getOperation().walk([&](VectorizeOp op) -> WalkResult {
263 switch (mode) {
264 case LowerVectorizationsModeEnum::Full:
265 if (auto newOp = lowerBody(lowerBoundary(op));
266 succeeded(newOp) && succeeded(inlineBody(*newOp)))
267 return WalkResult::advance();
268 return WalkResult::interrupt();
269 case LowerVectorizationsModeEnum::Boundary:
270 return lowerBoundary(op), WalkResult::advance();
271 case LowerVectorizationsModeEnum::Body:
272 return static_cast<LogicalResult>(lowerBody(op));
273 case LowerVectorizationsModeEnum::InlineBody:
274 return inlineBody(op);
275 }
276 llvm_unreachable("all enum cases must be handled above");
277 });
278 if (result.wasInterrupted())
279 signalPassFailure();
280 }
281};
282} // namespace
283
284std::unique_ptr<Pass>
285arc::createLowerVectorizationsPass(LowerVectorizationsModeEnum mode) {
286 return std::make_unique<LowerVectorizationsPass>(mode);
287}
static VectorizeOp lowerBoundary(VectorizeOp op)
Vectorizes the boundary of the given arc.vectorize operation if it is not already vectorized.
static FailureOr< VectorizeOp > lowerBody(VectorizeOp op)
Vectorizes the body of the given arc.vectorize operation if it is not already vectorized.
static VectorizeOp lowerBoundaryVector(VectorizeOp op)
Vectorizes the arc.vectorize boundary by using the vector type and dialect for SIMD-based vectorizati...
static LogicalResult inlineBody(VectorizeOp op)
Inlines the arc.vectorize operations body once both the boundary and body are vectorized.
static VectorizeOp lowerBoundaryScalar(VectorizeOp op)
Vectorizes the arc.vectorize boundary by packing the vector elements into an integer value.
create(low_bit, result_type, input=None)
Definition comb.py:187
std::unique_ptr< mlir::Pass > createLowerVectorizationsPass(LowerVectorizationsModeEnum mode=LowerVectorizationsModeEnum::Full)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.