12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/Pass/Pass.h"
17 #include "circt/Dialect/Arc/ArcPassesEnums.cpp.inc"
21 #define GEN_PASS_DEF_LOWERVECTORIZATIONS
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
55 ImplicitLocOpBuilder builder(op.getLoc(), op);
56 SmallVector<ValueRange> vectors;
58 for (ValueRange range : op.getInputs()) {
59 unsigned bw = range.front().getType().getIntOrFloatBitWidth();
60 vectors.push_back(builder
61 .create<comb::ConcatOp>(
62 builder.getIntegerType(bw * range.size()), range)
66 unsigned width = op->getResult(0).getType().getIntOrFloatBitWidth();
67 VectorizeOp newOp = builder.create<VectorizeOp>(
68 builder.getIntegerType(width * op->getNumResults()), vectors);
69 newOp.getBody().takeBody(op.getBody());
71 for (OpResult res : op.getResults()) {
73 newOp.getResult(0), width * res.getResultNumber(), width);
74 res.replaceAllUsesWith(newRes);
112 ImplicitLocOpBuilder builder(op.getLoc(), op);
113 SmallVector<ValueRange> vectors;
115 for (ValueRange range : op.getInputs()) {
118 VectorType type =
VectorType::get(SmallVector<int64_t>(1, range.size()),
119 range.front().getType());
120 if (llvm::all_equal(range)) {
122 builder.create<vector::BroadcastOp>(type, range.front())
131 type, SmallVector<Attribute>(
133 builder.getIntegerAttr(type.getElementType(), 0))))
135 for (
auto [i, element] : llvm::enumerate(range))
136 vector = builder.create<vector::InsertOp>(element, vector.front(), i)
139 vectors.push_back(vector);
143 SmallVector<int64_t>(1, op->getNumResults()), op.getResult(0).getType());
144 VectorizeOp newOp = builder.create<VectorizeOp>(resType, vectors);
145 newOp.getBody().takeBody(op.getBody());
147 for (OpResult res : op.getResults())
148 res.replaceAllUsesWith(builder.create<vector::ExtractOp>(
149 newOp.getResult(0), res.getResultNumber()));
165 if (op.isBoundaryVectorized())
170 if (op.isBodyVectorized()) {
171 if (isa<VectorType>(op.getBody().front().getArgumentTypes().front()))
178 unsigned numLanes = op.getInputs().size();
179 unsigned maxLaneWidth = 0;
180 for (OperandRange range : op.getInputs())
182 std::max(maxLaneWidth, range.front().getType().getIntOrFloatBitWidth());
184 if ((numLanes * maxLaneWidth <= 64) &&
185 op->getResult(0).getType().getIntOrFloatBitWidth() *
186 op->getNumResults() <=
200 static FailureOr<VectorizeOp>
lowerBody(VectorizeOp op) {
201 if (op.isBodyVectorized())
204 return op->emitError(
"lowering body not yet supported");
231 if (!(op.isBodyVectorized() && op.isBoundaryVectorized()))
232 return op->emitError(
233 "can only inline body if boundary and body are already vectorized");
235 Block &block = op.getBody().front();
236 for (
auto [operand, arg] : llvm::zip(op.getInputs(), block.getArguments()))
237 arg.replaceAllUsesWith(operand.front());
239 Operation *terminator = block.getTerminator();
240 op->getResult(0).replaceAllUsesWith(terminator->getOperand(0));
243 op->getBlock()->getOperations().splice(op->getIterator(),
244 block.getOperations());
252 struct LowerVectorizationsPass
253 :
public arc::impl::LowerVectorizationsBase<LowerVectorizationsPass> {
254 LowerVectorizationsPass() =
default;
255 explicit LowerVectorizationsPass(LowerVectorizationsModeEnum mode)
256 : LowerVectorizationsPass() {
257 this->mode.setValue(mode);
260 void runOnOperation()
override {
261 WalkResult result = getOperation().walk([&](VectorizeOp op) -> WalkResult {
263 case LowerVectorizationsModeEnum::Full:
265 succeeded(newOp) && succeeded(
inlineBody(*newOp)))
266 return WalkResult::advance();
267 return WalkResult::interrupt();
268 case LowerVectorizationsModeEnum::Boundary:
270 case LowerVectorizationsModeEnum::Body:
271 return static_cast<LogicalResult
>(
lowerBody(op));
272 case LowerVectorizationsModeEnum::InlineBody:
275 llvm_unreachable(
"all enum cases must be handled above");
277 if (result.wasInterrupted())
283 std::unique_ptr<Pass>
285 return std::make_unique<LowerVectorizationsPass>(mode);
static VectorizeOp lowerBoundary(VectorizeOp op)
Vectorizes the boundary of the given arc.vectorize operation if it is not already vectorized.
static FailureOr< VectorizeOp > lowerBody(VectorizeOp op)
Vectorizes the body of the given arc.vectorize operation if it is not already vectorized.
static VectorizeOp lowerBoundaryVector(VectorizeOp op)
Vectorizes the arc.vectorize boundary by using the vector type and dialect for SIMD-based vectorization.
static LogicalResult inlineBody(VectorizeOp op)
Inlines the arc.vectorize operation's body once both the boundary and the body are vectorized.
static VectorizeOp lowerBoundaryScalar(VectorizeOp op)
Vectorizes the arc.vectorize boundary by packing the vector elements into an integer value.
std::unique_ptr< mlir::Pass > createLowerVectorizationsPass(LowerVectorizationsModeEnum mode=LowerVectorizationsModeEnum::Full)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.