20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
27 #define GEN_PASS_DEF_VECTORIZATION
28 #include "circt/Dialect/FIRRTL/Passes.h.inc"
32 using namespace circt;
33 using namespace firrtl;
35 #define DEBUG_TYPE "firrtl-vectorization"
43 template <
typename OpTy,
typename ResultOpType>
44 class VectorCreateToLogicElementwise :
public mlir::RewritePattern {
46 VectorCreateToLogicElementwise(MLIRContext *context)
47 : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
50 matchAndRewrite(Operation *op,
51 mlir::PatternRewriter &rewriter)
const override {
52 auto vectorCreateOp = cast<VectorCreateOp>(op);
53 FVectorType type = vectorCreateOp.getType();
54 if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
57 SmallVector<Value> lhs, rhs;
61 if (llvm::all_of(op->getOperands(), [&](Value operand) {
62 auto op = operand.getDefiningOp<OpTy>();
65 lhs.push_back(op.getLhs());
66 rhs.push_back(op.getRhs());
69 auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
70 op->getLoc(), vectorCreateOp.getType(), lhs);
71 auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
72 op->getLoc(), vectorCreateOp.getType(), rhs);
73 rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
80 struct VectorizationPass
81 :
public circt::firrtl::impl::VectorizationBase<VectorizationPass> {
82 VectorizationPass() =
default;
83 void runOnOperation()
override;
88 void VectorizationPass::runOnOperation() {
91 <<
"Module: '" << getOperation().
getName() <<
"'\n";);
93 RewritePatternSet
patterns(&getContext());
95 .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
96 VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
97 VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
99 mlir::FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
100 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
104 return std::make_unique<VectorizationPass>();
// Referenced declarations (documentation cross-reference residue):
//   std::unique_ptr<mlir::Pass> createVectorizationPass()
//   StringAttr getName(ArrayAttr names, size_t idx)
//     Return the name at the specified index of the ArrayAttr, or null if it
//     cannot be determined.
//   The InstanceGraph op interface; see InstanceGraphInterface.td for more
//   details.
//   llvm::raw_ostream &debugPassHeader(const mlir::Pass *pass, int width = 80)
//     Write a boilerplate header for a pass to the debug stream.