CIRCT  19.0.0git
Vectorization.cpp
Go to the documentation of this file.
1 //===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass performs vectorization for primitive operations, e.g:
9 // vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10 // => elementwise_or a, b
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
19 #include "circt/Support/Debug.h"
20 #include "circt/Support/LLVM.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "firrtl-vectorization"
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 namespace {
31 //===----------------------------------------------------------------------===//
32 // Pass Infrastructure
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 
37 template <typename OpTy, typename ResultOpType>
38 class VectorCreateToLogicElementwise : public mlir::RewritePattern {
39 public:
40  VectorCreateToLogicElementwise(MLIRContext *context)
41  : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
42 
43  LogicalResult
44  matchAndRewrite(Operation *op,
45  mlir::PatternRewriter &rewriter) const override {
46  auto vectorCreateOp = cast<VectorCreateOp>(op);
47  FVectorType type = vectorCreateOp.getType();
48  if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
49  return failure();
50 
51  SmallVector<Value> lhs, rhs;
52 
53  // Vectorize if all operands are `OpTy`. Currently there is no other
54  // condition so it could be too aggressive.
55  if (llvm::all_of(op->getOperands(), [&](Value operand) {
56  auto op = operand.getDefiningOp<OpTy>();
57  if (!op)
58  return false;
59  lhs.push_back(op.getLhs());
60  rhs.push_back(op.getRhs());
61  return true;
62  })) {
63  auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
64  op->getLoc(), vectorCreateOp.getType(), lhs);
65  auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
66  op->getLoc(), vectorCreateOp.getType(), rhs);
67  rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
68  return success();
69  }
70  return failure();
71  }
72 };
73 } // namespace
74 
75 struct VectorizationPass : public VectorizationBase<VectorizationPass> {
76  VectorizationPass() = default;
77  void runOnOperation() override;
78 };
79 
80 } // namespace
81 
82 void VectorizationPass::runOnOperation() {
83  LLVM_DEBUG(debugPassHeader(this)
84  << "\n"
85  << "Module: '" << getOperation().getName() << "'\n";);
86 
87  RewritePatternSet patterns(&getContext());
88  patterns
89  .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
90  VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
91  VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
92  &getContext());
93  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
94  (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
95 }
96 
97 std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
98  return std::make_unique<VectorizationPass>();
99 }
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31