CIRCT  19.0.0git
Vectorization.cpp
Go to the documentation of this file.
1 //===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass performs vectorization for primitive operations, e.g.:
9 // vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10 // => elementwise_or a, b
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
19 #include "circt/Support/Debug.h"
20 #include "circt/Support/LLVM.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "firrtl-vectorization"
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 namespace {
31 //===----------------------------------------------------------------------===//
32 // Pass Infrastructure
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 
37 template <typename OpTy, typename ResultOpType>
38 class VectorCreateToLogicElementwise : public mlir::RewritePattern {
39 public:
40  VectorCreateToLogicElementwise(MLIRContext *context)
41  : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
42 
43  LogicalResult
44  matchAndRewrite(Operation *op,
45  mlir::PatternRewriter &rewriter) const override {
46  auto vectorCreateOp = cast<VectorCreateOp>(op);
47  FVectorType type = vectorCreateOp.getType();
48  if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
49  return failure();
50 
51  SmallVector<Value> lhs, rhs;
52 
53  // Vectorize if all operands are `OpTy`. Currently there is no other
54  // condition so it could be too aggressive.
55  if (llvm::all_of(op->getOperands(), [&](Value operand) {
56  auto op = operand.getDefiningOp<OpTy>();
57  if (!op)
58  return false;
59  lhs.push_back(op.getLhs());
60  rhs.push_back(op.getRhs());
61  return true;
62  })) {
63  auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
64  op->getLoc(), vectorCreateOp.getType(), lhs);
65  auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
66  op->getLoc(), vectorCreateOp.getType(), rhs);
67  rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
68  return success();
69  }
70  return failure();
71  }
72 };
73 } // namespace
74 
75 struct VectorizationPass : public VectorizationBase<VectorizationPass> {
76  VectorizationPass() = default;
77  void runOnOperation() override;
78 };
79 
80 } // namespace
81 
82 void VectorizationPass::runOnOperation() {
83  LLVM_DEBUG(debugPassHeader(this)
84  << "\n"
85  << "Module: '" << getOperation().getName() << "'\n";);
86 
87  RewritePatternSet patterns(&getContext());
88  patterns
89  .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
90  VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
91  VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
92  &getContext());
93  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
94  (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
95 }
96 
97 std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
98  return std::make_unique<VectorizationPass>();
99 }
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31