CIRCT  20.0.0git
Vectorization.cpp
Go to the documentation of this file.
1 //===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass performs vectorization for primitive operations, e.g.:
9 // vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10 // => elementwise_or a, b
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "circt/Support/Debug.h"
18 #include "circt/Support/LLVM.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/Support/Debug.h"
23 
24 namespace circt {
25 namespace firrtl {
26 #define GEN_PASS_DEF_VECTORIZATION
27 #include "circt/Dialect/FIRRTL/Passes.h.inc"
28 } // namespace firrtl
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace firrtl;
33 
34 #define DEBUG_TYPE "firrtl-vectorization"
35 
36 //===----------------------------------------------------------------------===//
37 // Pass Infrastructure
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
42 template <typename OpTy, typename ResultOpType>
43 class VectorCreateToLogicElementwise : public mlir::RewritePattern {
44 public:
45  VectorCreateToLogicElementwise(MLIRContext *context)
46  : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
47 
48  LogicalResult
49  matchAndRewrite(Operation *op,
50  mlir::PatternRewriter &rewriter) const override {
51  auto vectorCreateOp = cast<VectorCreateOp>(op);
52  FVectorType type = vectorCreateOp.getType();
53  if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
54  return failure();
55 
56  SmallVector<Value> lhs, rhs;
57 
58  // Vectorize if all operands are `OpTy`. Currently there is no other
59  // condition so it could be too aggressive.
60  if (llvm::all_of(op->getOperands(), [&](Value operand) {
61  auto op = operand.getDefiningOp<OpTy>();
62  if (!op)
63  return false;
64  lhs.push_back(op.getLhs());
65  rhs.push_back(op.getRhs());
66  return true;
67  })) {
68  auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
69  op->getLoc(), vectorCreateOp.getType(), lhs);
70  auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
71  op->getLoc(), vectorCreateOp.getType(), rhs);
72  rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
73  return success();
74  }
75  return failure();
76  }
77 };
78 
79 struct VectorizationPass
80  : public circt::firrtl::impl::VectorizationBase<VectorizationPass> {
81  VectorizationPass() = default;
82  void runOnOperation() override;
83 };
84 
85 } // namespace
86 
87 void VectorizationPass::runOnOperation() {
88  LLVM_DEBUG(debugPassHeader(this)
89  << "\n"
90  << "Module: '" << getOperation().getName() << "'\n";);
91 
92  RewritePatternSet patterns(&getContext());
93  patterns
94  .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
95  VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
96  VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
97  &getContext());
98  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
99  (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
100 }
101 
102 std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
103  return std::make_unique<VectorizationPass>();
104 }
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31