20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "firrtl-vectorization"
26 using namespace circt;
27 using namespace firrtl;
/// Rewrite pattern rooted at `VectorCreateOp`. When every operand of the
/// vector-create is produced by an `OpTy` operation, the pattern builds one
/// vector from the left-hand inputs and one from the right-hand inputs and
/// replaces the original op with a single `ResultOpType` applied to the two
/// vectors.
///
/// \tparam OpTy         Per-element binary operation looked for on each operand
///                      (must expose `getLhs()` / `getRhs()`).
/// \tparam ResultOpType Operation created from the two operand vectors.
///
/// NOTE(review): this listing omits some original lines (the embedded line
/// numbers jump, e.g. 47 -> 50 and 55 -> 58), so the early-exit statements,
/// the end of the lambda and the closing braces are not visible here.
36 template <
typename OpTy,
typename ResultOpType>
37 class VectorCreateToLogicElementwise :
public mlir::RewritePattern {
// Registers the pattern on the VectorCreateOp operation name; the literal 0
// is passed as the second constructor argument (the pattern benefit in
// MLIR's RewritePattern constructor).
39 VectorCreateToLogicElementwise(MLIRContext *context)
40 : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
// Attempts the rewrite on `op`, using `rewriter` to create and replace ops.
43 matchAndRewrite(Operation *op,
44 mlir::PatternRewriter &rewriter)
const override {
// The pattern is rooted at VectorCreateOp, so `op` is cast to it directly.
45 auto vectorCreateOp = cast<VectorCreateOp>(op);
46 FVectorType type = vectorCreateOp.getType();
// Guard: the vector type must have inferred widths and a UInt element type.
// The statement taken when the guard fires is not in this listing
// (presumably an early failure return -- TODO confirm).
47 if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
// Left/right inputs of each operand's defining op, collected in operand order.
50 SmallVector<Value> lhs, rhs;
// The predicate both tests the operands and fills `lhs`/`rhs` as a side
// effect (captured by reference).
54 if (llvm::all_of(op->getOperands(), [&](Value operand) {
// Note: this local `op` shadows the outer `op` parameter inside the lambda.
55 auto op = operand.getDefiningOp<OpTy>();
58 lhs.push_back(op.getLhs());
59 rhs.push_back(op.getRhs());
// Build (or fold) the two operand vectors, reusing the original vector type
// and location.
// NOTE(review): no visible check that the lhs/rhs value types equal the
// vector's element type before reusing `vectorCreateOp.getType()` -- confirm.
62 auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
63 op->getLoc(), vectorCreateOp.getType(), lhs);
64 auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
65 op->getLoc(), vectorCreateOp.getType(), rhs);
// Replace the original vector-create with the vector-level operation.
66 rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
/// Pass type for the vectorization transformation. It derives from
/// `VectorizationBase<VectorizationPass>`, is default-constructible, and
/// overrides `runOnOperation()`, which is defined out of line.
/// NOTE(review): the closing `};` of this struct is not part of this listing.
74 struct VectorizationPass :
public VectorizationBase<VectorizationPass> {
75 VectorizationPass() =
default;
// Pass entry point; the definition follows the struct.
76 void runOnOperation()
override;
/// Runs the pass on the current operation: prints a debug banner, collects
/// the Or/And/Xor vectorization patterns into a pattern set, freezes it, and
/// applies it with the greedy pattern rewrite driver.
/// NOTE(review): this listing skips some original lines (embedded numbers
/// jump 86 -> 88 and 90 -> 92), so the receiver of `.insert` and its
/// constructor argument are not visible here.
81 void VectorizationPass::runOnOperation() {
// Debug-only banner naming the operation being processed (emitted under
// DEBUG_TYPE "firrtl-vectorization").
82 LLVM_DEBUG(
llvm::dbgs() <<
"===----- Running Vectorization "
83 "--------------------------------------===\n"
84 <<
"Module: '" << getOperation().
getName() <<
"'\n";);
// Pattern set bound to this pass's MLIR context.
86 RewritePatternSet
patterns(&getContext());
// One pattern instantiation per supported primitive: or, and, xor.
88 .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
89 VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
90 VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
// Freeze the set (moved from `patterns`) before handing it to the driver.
92 mlir::FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
// The driver's result is deliberately discarded via the (void) cast.
93 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
// Returns a newly constructed VectorizationPass owned by a std::unique_ptr.
// NOTE(review): the enclosing function signature is missing from this
// listing; presumably this is the body of createVectorizationPass() -- confirm.
97 return std::make_unique<VectorizationPass>();
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()