CIRCT  19.0.0git
Vectorization.cpp
Go to the documentation of this file.
1 //===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass performs vectorization for primitive operations, e.g.:
9 // vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10 // => elementwise_or a, b
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "circt/Support/Debug.h"
19 #include "circt/Support/LLVM.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
24 
25 namespace circt {
26 namespace firrtl {
27 #define GEN_PASS_DEF_VECTORIZATION
28 #include "circt/Dialect/FIRRTL/Passes.h.inc"
29 } // namespace firrtl
30 } // namespace circt
31 
32 using namespace circt;
33 using namespace firrtl;
34 
35 #define DEBUG_TYPE "firrtl-vectorization"
36 
37 //===----------------------------------------------------------------------===//
38 // Pass Infrastructure
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 template <typename OpTy, typename ResultOpType>
44 class VectorCreateToLogicElementwise : public mlir::RewritePattern {
45 public:
46  VectorCreateToLogicElementwise(MLIRContext *context)
47  : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
48 
49  LogicalResult
50  matchAndRewrite(Operation *op,
51  mlir::PatternRewriter &rewriter) const override {
52  auto vectorCreateOp = cast<VectorCreateOp>(op);
53  FVectorType type = vectorCreateOp.getType();
54  if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
55  return failure();
56 
57  SmallVector<Value> lhs, rhs;
58 
59  // Vectorize if all operands are `OpTy`. Currently there is no other
60  // condition so it could be too aggressive.
61  if (llvm::all_of(op->getOperands(), [&](Value operand) {
62  auto op = operand.getDefiningOp<OpTy>();
63  if (!op)
64  return false;
65  lhs.push_back(op.getLhs());
66  rhs.push_back(op.getRhs());
67  return true;
68  })) {
69  auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
70  op->getLoc(), vectorCreateOp.getType(), lhs);
71  auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
72  op->getLoc(), vectorCreateOp.getType(), rhs);
73  rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
74  return success();
75  }
76  return failure();
77  }
78 };
79 
80 struct VectorizationPass
81  : public circt::firrtl::impl::VectorizationBase<VectorizationPass> {
82  VectorizationPass() = default;
83  void runOnOperation() override;
84 };
85 
86 } // namespace
87 
88 void VectorizationPass::runOnOperation() {
89  LLVM_DEBUG(debugPassHeader(this)
90  << "\n"
91  << "Module: '" << getOperation().getName() << "'\n";);
92 
93  RewritePatternSet patterns(&getContext());
94  patterns
95  .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
96  VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
97  VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
98  &getContext());
99  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
100  (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
101 }
102 
103 std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
104  return std::make_unique<VectorizationPass>();
105 }
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31