19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/Support/Debug.h"
26 #define GEN_PASS_DEF_VECTORIZATION
27 #include "circt/Dialect/FIRRTL/Passes.h.inc"
31 using namespace circt;
32 using namespace firrtl;
34 #define DEBUG_TYPE "firrtl-vectorization"
// Rewrite pattern rooted on VectorCreateOp. When every operand of the vector
// is produced by the same scalar binary op `OpTy`, it gathers the lhs operands
// of those ops into one new VectorCreateOp and the rhs operands into another.
// It then replaces the original VectorCreateOp with a single `ResultOpType`
// applied to the two gathered vectors.
//
// Template parameters:
//   OpTy         - scalar binary op that must define every vector element
//                  (e.g. OrPrimOp); must provide getLhs()/getRhs().
//   ResultOpType - op built over the two gathered vectors
//                  (e.g. ElementwiseOrPrimOp).
42 template <
typename OpTy,
typename ResultOpType>
43 class VectorCreateToLogicElementwise :
public mlir::RewritePattern {
// Registers the pattern on VectorCreateOp's operation name with benefit 0.
45 VectorCreateToLogicElementwise(MLIRContext *context)
46 : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
// Attempts the rewrite on `op`. The root name registered above means `op` is
// a VectorCreateOp, so the unchecked cast below is expected to hold.
// NOTE(review): the early-return, the per-operand null check, and the final
// success()/failure() returns are not visible in this excerpt.
49 matchAndRewrite(Operation *op,
50 mlir::PatternRewriter &rewriter)
const override {
51 auto vectorCreateOp = cast<VectorCreateOp>(op);
52 FVectorType type = vectorCreateOp.getType();
// Only vectors whose widths are fully inferred and whose element type is
// UIntType are considered.
53 if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
// Per-element lhs/rhs operands of the matched `OpTy` ops, in vector order.
56 SmallVector<Value> lhs, rhs;
// Vectorize only if every operand is defined by an `OpTy`. The lambda records
// each matched op's lhs/rhs as it walks the operands.
60 if (llvm::all_of(op->getOperands(), [&](Value operand) {
// This `op` shadows the outer Operation*. It is the defining `OpTy` of the
// operand (null if the operand is not produced by an `OpTy`).
61 auto op = operand.getDefiningOp<OpTy>();
64 lhs.push_back(op.getLhs());
65 rhs.push_back(op.getRhs());
// Build one vector from all lhs operands and one from all rhs operands.
// `op` here is the outer VectorCreateOp again.
// NOTE(review): both new vectors reuse the *result* vector type. If an
// `OpTy` result can be wider than its inputs (e.g. operands of differing
// widths), the element types may not match; confirm against VectorCreateOp's
// verifier.
68 auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
69 op->getLoc(), vectorCreateOp.getType(), lhs);
70 auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
71 op->getLoc(), vectorCreateOp.getType(), rhs);
// Replace the original VectorCreateOp with a single ResultOpType over the
// two gathered vectors.
72 rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
// The FIRRTL vectorization pass. Its base class, VectorizationBase, is
// presumably the TableGen-generated pass base enabled by
// GEN_PASS_DEF_VECTORIZATION; confirm in Passes.td.
79 struct VectorizationPass
80 :
public circt::firrtl::impl::VectorizationBase<VectorizationPass> {
81 VectorizationPass() =
default;
// Applies the logic-op vectorization patterns to the current operation.
82 void runOnOperation()
override;
// Entry point of the pass. It emits a debug header naming the operation being
// processed, collects the Or/And/Xor vectorization patterns, and applies them
// greedily (with folding) to the current operation. The driver's result is
// discarded via a (void) cast, so non-convergence is not reported as a pass
// failure.
87 void VectorizationPass::runOnOperation() {
90 <<
"Module: '" << getOperation().
getName() <<
"'\n";);
// Register one pattern per supported logic op, each mapping the scalar primop
// to its elementwise counterpart: Or -> ElementwiseOr, And -> ElementwiseAnd,
// Xor -> ElementwiseXor.
92 RewritePatternSet
patterns(&getContext());
94 .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
95 VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
96 VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
// Freeze the pattern set (moving out of `patterns`), then run the greedy
// rewrite driver over the current operation.
98 mlir::FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
99 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
// Factory: returns a newly allocated VectorizationPass owned by the caller.
103 return std::make_unique<VectorizationPass>();
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.