21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "firrtl-vectorization"
27 using namespace circt;
28 using namespace firrtl;
// Rewrite pattern rooted at VectorCreateOp (see the RewritePattern constructor
// below; the `0` argument is the mlir::RewritePattern pattern benefit).
//
// Template parameters:
//   OpTy         - scalar binary op that must define *every* operand of the
//                  vector_create for the pattern to fire.
//   ResultOpType - elementwise op built to replace the vector_create; it is
//                  constructed from two vector values (lhs vector, rhs vector).
//
// Rewrite performed:
//   vector_create(OpTy(a0, b0), ..., OpTy(aN, bN))
//     ==> ResultOpType(vector_create(a0, ..., aN), vector_create(b0, ..., bN))
//
// NOTE(review): this listing omits several lines of the original block (e.g.
// what presumably is the early `return failure()` after the type check, the
// null check on the defining op inside the lambda, and the success/failure
// returns). Verify against the full source before relying on these comments.
37 template <
typename OpTy,
typename ResultOpType>
38 class VectorCreateToLogicElementwise :
public mlir::RewritePattern {
40 VectorCreateToLogicElementwise(MLIRContext *context)
41 : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
// Match/rewrite entry point. `op` is expected to be a VectorCreateOp because
// that is the root operation name this pattern is registered with (and it is
// cast<> to one immediately).
44 matchAndRewrite(Operation *op,
45 mlir::PatternRewriter &rewriter)
const override {
46 auto vectorCreateOp = cast<VectorCreateOp>(op);
47 FVectorType type = vectorCreateOp.getType();
// Only vectors that do not report an uninferred width and whose element type
// is a UInt are considered.
48 if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
// Left and right operands of each defining OpTy, in vector_create operand
// order. Both are locals of this call.
51 SmallVector<Value> lhs, rhs;
// Fire only if every vector_create operand is produced by an OpTy. The
// lambda appends to lhs/rhs as it goes, so they may hold a partial prefix if
// a later operand does not match.
55 if (llvm::all_of(op->getOperands(), [&](Value operand) {
// NOTE(review): this local `op` (an OpTy) shadows the outer
// `Operation *op` parameter; the `op->` uses after the lambda refer to the
// outer Operation*.
56 auto op = operand.getDefiningOp<OpTy>();
59 lhs.push_back(op.getLhs());
60 rhs.push_back(op.getRhs());
// Build one vector of all lhs operands and one of all rhs operands, reusing
// the original vector type for both. createOrFold may return a folded value
// instead of a newly created op.
// NOTE(review): FIRRTL bitwise ops can take operands narrower than their
// result; if OpTy's lhs/rhs widths differ from the result element type, these
// vector_creates would get operands that do not match `type`'s element type.
// Confirm whether VectorCreateOp's verifier rejects this and whether a width
// check is needed before rewriting.
63 auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
64 op->getLoc(), vectorCreateOp.getType(), lhs);
65 auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
66 op->getLoc(), vectorCreateOp.getType(), rhs);
// Replace the vector_create with a single elementwise op over the two
// operand vectors.
67 rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
// Pass that applies the VectorCreateToLogicElementwise rewrite patterns.
// VectorizationBase<VectorizationPass> is a CRTP base, presumably the
// TableGen-generated pass base class — TODO confirm.
// NOTE(review): the closing `};` of this struct is not present in this listing.
75 struct VectorizationPass :
public VectorizationBase<VectorizationPass> {
76 VectorizationPass() =
default;
// Pass entry point; defined out of line.
77 void runOnOperation()
override;
// Runs the vectorization rewrites on the operation returned by getOperation().
// Registers three VectorCreateToLogicElementwise instantiations
// (Or -> ElementwiseOr, And -> ElementwiseAnd, Xor -> ElementwiseXor), freezes
// the set, and applies it with MLIR's greedy pattern rewrite driver.
// NOTE(review): this listing omits the opening of the debug statement (only
// its tail starting at `<< "Module: '"` is shown), the `patterns` receiver of
// `.insert<...>`, the argument list passed to `.insert`, and the closing
// brace; verify against the full source.
82 void VectorizationPass::runOnOperation() {
85 <<
"Module: '" << getOperation().
getName() <<
"'\n";);
// Collect one rewrite pattern per scalar/elementwise op pair.
87 RewritePatternSet
patterns(&getContext());
89 .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
90 VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
91 VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
// Freeze the pattern set for the driver; `patterns` is moved from here and
// must not be used afterwards.
93 mlir::FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
// The LogicalResult from the greedy driver (which reports whether rewriting
// converged) is explicitly discarded with (void), so non-convergence is not
// reported by this pass.
94 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
// Pass factory body: allocates a new VectorizationPass and returns it owned by
// a std::unique_ptr.
// NOTE(review): the enclosing function header is not shown in this listing;
// the tooltip line that follows suggests it is
// `std::unique_ptr<mlir::Pass> createVectorizationPass()` — confirm.
98 return std::make_unique<VectorizationPass>();
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.