CIRCT  20.0.0git
LowerVectorizations.cpp
Go to the documentation of this file.
1 //===- LowerVectorizations.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/Pass/Pass.h"
16 
17 #include "circt/Dialect/Arc/ArcPassesEnums.cpp.inc"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_LOWERVECTORIZATIONS
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace mlir;
27 using namespace circt;
28 using namespace arc;
29 
30 /// Vectorizes the `arc.vectorize` boundary by packing the vector elements into
31 /// an integer value. Returns the vectorized version of the op. May invalidate
32 /// the passed operation.
33 ///
34 /// Example:
35 /// ```mlir
36 /// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in3) : (i1, i1, i1, i1) -> (i1,
37 /// i1) { ^bb0(%arg0: i1, %arg1: i1):
38 /// %1 = comb.and %arg0, %arg1 : i1
39 /// arc.vectorize.return %1 : i1
40 /// }
41 /// ```
42 /// becomes
43 /// ```mlir
44 /// %0 = comb.concat %in0, %in1 : i1, i1
45 /// %1 = comb.concat %in2, %in3 : i1, i1
46 /// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
47 /// ^bb0(%arg0: i1, %arg1: i1):
48 /// %11 = comb.and %arg0, %arg1 : i1
49 /// arc.vectorize.return %11 : i1
50 /// }
51 /// %3 = comb.extract %2 from 0 : (i2) -> i1
52 /// %4 = comb.extract %2 from 1 : (i2) -> i1
53 /// ```
54 static VectorizeOp lowerBoundaryScalar(VectorizeOp op) {
55  ImplicitLocOpBuilder builder(op.getLoc(), op);
56  SmallVector<ValueRange> vectors;
57 
58  for (ValueRange range : op.getInputs()) {
59  unsigned bw = range.front().getType().getIntOrFloatBitWidth();
60  vectors.push_back(builder
61  .create<comb::ConcatOp>(
62  builder.getIntegerType(bw * range.size()), range)
63  ->getResults());
64  }
65 
66  unsigned width = op->getResult(0).getType().getIntOrFloatBitWidth();
67  VectorizeOp newOp = builder.create<VectorizeOp>(
68  builder.getIntegerType(width * op->getNumResults()), vectors);
69  newOp.getBody().takeBody(op.getBody());
70 
71  for (OpResult res : op.getResults()) {
72  Value newRes = builder.create<comb::ExtractOp>(
73  newOp.getResult(0), width * res.getResultNumber(), width);
74  res.replaceAllUsesWith(newRes);
75  }
76 
77  op->erase();
78  return newOp;
79 }
80 
81 /// Vectorizes the `arc.vectorize` boundary by using the `vector` type and
82 /// dialect for SIMD-based vectorization. Returns the vectorized version of the
83 /// op. May invalidate the passed operation.
84 ///
85 /// Example:
86 /// ```mlir
87 /// %0:2 = arc.vectorize (%in0, %in1), (%in2, %in2) : (i64, i64, i32, i32) ->
88 /// (i64, i64) { ^bb0(%arg0: i64, %arg1: i32):
89 /// %c0_i32 = hw.constant 0 : i32
90 /// %1 = comb.concat %c0_i32, %arg1 : i32, i32
91 /// %2 = comb.and %arg0, %1 : i64
92 /// arc.vectorize.return %2 : i64
93 /// }
94 /// ```
95 /// becomes
96 /// ```mlir
97 /// %cst = arith.constant dense<0> : vector<2xi64>
98 /// %0 = vector.insert %in0, %cst [0] : i64 into vector<2xi64>
99 /// %1 = vector.insert %in1, %0 [1] : i64 into vector<2xi64>
100 /// %2 = vector.broadcast %in2 : i32 to vector<2xi32>
101 /// %3 = arc.vectorize (%1), (%2) : (vector<2xi64>, vector<2xi32>) ->
102 /// vector<2xi64> { ^bb0(%arg0: i64, %arg1: i32):
103 /// %c0_i32 = hw.constant 0 : i32
104 /// %4 = comb.concat %c0_i32, %arg1 : i32, i32
105 /// %5 = comb.and %arg0, %4 : i64
106 /// arc.vectorize.return %5 : i64
107 /// }
108 /// %4 = vector.extract %3[0] : vector<2xi64>
109 /// %5 = vector.extract %3[1] : vector<2xi64>
110 /// ```
111 static VectorizeOp lowerBoundaryVector(VectorizeOp op) {
112  ImplicitLocOpBuilder builder(op.getLoc(), op);
113  SmallVector<ValueRange> vectors;
114 
115  for (ValueRange range : op.getInputs()) {
116  // Insert a broadcast operation if all elements of a vector are the same
117  // because it's a significantly cheaper instruction.
118  VectorType type = VectorType::get(SmallVector<int64_t>(1, range.size()),
119  range.front().getType());
120  if (llvm::all_equal(range)) {
121  vectors.push_back(
122  builder.create<vector::BroadcastOp>(type, range.front())
123  ->getResults());
124  continue;
125  }
126 
127  // Otherwise do a gather.
128  ValueRange vector =
129  builder
130  .create<arith::ConstantOp>(DenseElementsAttr::get(
131  type, SmallVector<Attribute>(
132  range.size(),
133  builder.getIntegerAttr(type.getElementType(), 0))))
134  ->getResults();
135  for (auto [i, element] : llvm::enumerate(range))
136  vector = builder.create<vector::InsertOp>(element, vector.front(), i)
137  ->getResults();
138 
139  vectors.push_back(vector);
140  }
141 
142  VectorType resType = VectorType::get(
143  SmallVector<int64_t>(1, op->getNumResults()), op.getResult(0).getType());
144  VectorizeOp newOp = builder.create<VectorizeOp>(resType, vectors);
145  newOp.getBody().takeBody(op.getBody());
146 
147  for (OpResult res : op.getResults())
148  res.replaceAllUsesWith(builder.create<vector::ExtractOp>(
149  newOp.getResult(0), res.getResultNumber()));
150 
151  op->erase();
152  return newOp;
153 }
154 
155 /// Vectorizes the boundary of the given `arc.vectorize` operation if it is not
156 /// already vectorized. If the body of the `arc.vectorize` operation is already
157 /// vectorized the same vectorization technique (SIMD or scalar) is chosen.
158 /// Otherwise,
159 /// * packs the vector in a scalar if it fits in a 64-bit integer or
160 /// * uses the `vector` type and dialect for SIMD vectorization
161 /// Returns the vectorized version of the op. May invalidate the passed
162 /// operation.
163 static VectorizeOp lowerBoundary(VectorizeOp op) {
164  // Nothing to do if it is already vectorized.
165  if (op.isBoundaryVectorized())
166  return op;
167 
168  // If the body is already vectorized, we must use the same vectorization
169  // technique. Otherwise, we would produce invalid IR.
170  if (op.isBodyVectorized()) {
171  if (isa<VectorType>(op.getBody().front().getArgumentTypes().front()))
172  return lowerBoundaryVector(op);
173  return lowerBoundaryScalar(op);
174  }
175 
176  // If the vector can fit in an i64 value, use scalar vectorization, otherwise
177  // use SIMD.
178  unsigned numLanes = op.getInputs().size();
179  unsigned maxLaneWidth = 0;
180  for (OperandRange range : op.getInputs())
181  maxLaneWidth =
182  std::max(maxLaneWidth, range.front().getType().getIntOrFloatBitWidth());
183 
184  if ((numLanes * maxLaneWidth <= 64) &&
185  op->getResult(0).getType().getIntOrFloatBitWidth() *
186  op->getNumResults() <=
187  64)
188  return lowerBoundaryScalar(op);
189  return lowerBoundaryVector(op);
190 }
191 
192 /// Vectorizes the body of the given `arc.vectorize` operation if it is not
193 /// already vectorized. If the boundary of the `arc.vectorize` operation is
194 /// already vectorized the same vectorization technique (SIMD or scalar) is
195 /// chosen. Otherwise,
196 /// * packs the vector in a scalar if it fits in a 64-bit integer or
197 /// * uses the `vector` type and dialect for SIMD vectorization
198 /// Returns the vectorized version of the op or failure. May invalidate the
199 /// passed operation.
200 static FailureOr<VectorizeOp> lowerBody(VectorizeOp op) {
201  if (op.isBodyVectorized())
202  return op;
203 
204  return op->emitError("lowering body not yet supported");
205 }
206 
207 /// Inlines the `arc.vectorize` operations body once both the boundary and body
208 /// are vectorized.
209 ///
210 /// Example:
211 /// ```mlir
212 /// %0 = comb.concat %in0, %in1 : i1, i1
213 /// %1 = comb.concat %in2, %in2 : i1, i1
214 /// %2 = arc.vectorize (%0), (%1) : (i2, i2) -> i2 {
215 /// ^bb0(%arg0: i2, %arg1: i2):
216 /// %12 = arith.andi %arg0, %arg1 : i2
217 /// arc.vectorize.return %12 : i2
218 /// }
219 /// %3 = comb.extract %2 from 0 : (i2) -> i1
220 /// %4 = comb.extract %2 from 1 : (i2) -> i1
221 /// ```
222 /// becomes
223 /// ```mlir
224 /// %0 = comb.concat %in0, %in1 : i1, i1
225 /// %1 = comb.concat %in2, %in2 : i1, i1
226 /// %2 = arith.andi %0, %1 : i2
227 /// %3 = comb.extract %2 from 0 : (i2) -> i1
228 /// %4 = comb.extract %2 from 1 : (i2) -> i1
229 /// ```
230 static LogicalResult inlineBody(VectorizeOp op) {
231  if (!(op.isBodyVectorized() && op.isBoundaryVectorized()))
232  return op->emitError(
233  "can only inline body if boundary and body are already vectorized");
234 
235  Block &block = op.getBody().front();
236  for (auto [operand, arg] : llvm::zip(op.getInputs(), block.getArguments()))
237  arg.replaceAllUsesWith(operand.front());
238 
239  Operation *terminator = block.getTerminator();
240  op->getResult(0).replaceAllUsesWith(terminator->getOperand(0));
241  terminator->erase();
242 
243  op->getBlock()->getOperations().splice(op->getIterator(),
244  block.getOperations());
245  op->erase();
246 
247  return success();
248 }
249 
250 namespace {
251 /// A pass to vectorize (parts of) an `arc.vectorize` operation.
252 struct LowerVectorizationsPass
253  : public arc::impl::LowerVectorizationsBase<LowerVectorizationsPass> {
254  LowerVectorizationsPass() = default;
255  explicit LowerVectorizationsPass(LowerVectorizationsModeEnum mode)
256  : LowerVectorizationsPass() {
257  this->mode.setValue(mode);
258  }
259 
260  void runOnOperation() override {
261  WalkResult result = getOperation().walk([&](VectorizeOp op) -> WalkResult {
262  switch (mode) {
263  case LowerVectorizationsModeEnum::Full:
264  if (auto newOp = lowerBody(lowerBoundary(op));
265  succeeded(newOp) && succeeded(inlineBody(*newOp)))
266  return WalkResult::advance();
267  return WalkResult::interrupt();
268  case LowerVectorizationsModeEnum::Boundary:
269  return lowerBoundary(op), WalkResult::advance();
270  case LowerVectorizationsModeEnum::Body:
271  return static_cast<LogicalResult>(lowerBody(op));
272  case LowerVectorizationsModeEnum::InlineBody:
273  return inlineBody(op);
274  }
275  llvm_unreachable("all enum cases must be handled above");
276  });
277  if (result.wasInterrupted())
278  signalPassFailure();
279  }
280 };
281 } // namespace
282 
283 std::unique_ptr<Pass>
284 arc::createLowerVectorizationsPass(LowerVectorizationsModeEnum mode) {
285  return std::make_unique<LowerVectorizationsPass>(mode);
286 }
int32_t width
Definition: FIRRTL.cpp:36
static VectorizeOp lowerBoundary(VectorizeOp op)
Vectorizes the boundary of the given arc.vectorize operation if it is not already vectorized.
static FailureOr< VectorizeOp > lowerBody(VectorizeOp op)
Vectorizes the body of the given arc.vectorize operation if it is not already vectorized.
static VectorizeOp lowerBoundaryVector(VectorizeOp op)
Vectorizes the arc.vectorize boundary by using the vector type and dialect for SIMD-based vectorizati...
static LogicalResult inlineBody(VectorizeOp op)
Inlines the arc.vectorize operations body once both the boundary and body are vectorized.
static VectorizeOp lowerBoundaryScalar(VectorizeOp op)
Vectorizes the arc.vectorize boundary by packing the vector elements into an integer value.
std::unique_ptr< mlir::Pass > createLowerVectorizationsPass(LowerVectorizationsModeEnum mode=LowerVectorizationsModeEnum::Full)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21