CIRCT  18.0.0git
Vectorization.cpp
Go to the documentation of this file.
1 //===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass performs vectorization for primitive operations, e.g:
9 // vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10 // => elementwise_or a, b
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
19 #include "circt/Support/LLVM.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "firrtl-vectorization"
25 
26 using namespace circt;
27 using namespace firrtl;
28 
29 namespace {
30 //===----------------------------------------------------------------------===//
31 // Pass Infrastructure
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 
36 template <typename OpTy, typename ResultOpType>
37 class VectorCreateToLogicElementwise : public mlir::RewritePattern {
38 public:
39  VectorCreateToLogicElementwise(MLIRContext *context)
40  : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
41 
42  LogicalResult
43  matchAndRewrite(Operation *op,
44  mlir::PatternRewriter &rewriter) const override {
45  auto vectorCreateOp = cast<VectorCreateOp>(op);
46  FVectorType type = vectorCreateOp.getType();
47  if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
48  return failure();
49 
50  SmallVector<Value> lhs, rhs;
51 
52  // Vectorize if all operands are `OpTy`. Currently there is no other
53  // condition so it could be too aggressive.
54  if (llvm::all_of(op->getOperands(), [&](Value operand) {
55  auto op = operand.getDefiningOp<OpTy>();
56  if (!op)
57  return false;
58  lhs.push_back(op.getLhs());
59  rhs.push_back(op.getRhs());
60  return true;
61  })) {
62  auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
63  op->getLoc(), vectorCreateOp.getType(), lhs);
64  auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
65  op->getLoc(), vectorCreateOp.getType(), rhs);
66  rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
67  return success();
68  }
69  return failure();
70  }
71 };
72 } // namespace
73 
74 struct VectorizationPass : public VectorizationBase<VectorizationPass> {
75  VectorizationPass() = default;
76  void runOnOperation() override;
77 };
78 
79 } // namespace
80 
81 void VectorizationPass::runOnOperation() {
82  LLVM_DEBUG(llvm::dbgs() << "===----- Running Vectorization "
83  "--------------------------------------===\n"
84  << "Module: '" << getOperation().getName() << "'\n";);
85 
86  RewritePatternSet patterns(&getContext());
87  patterns
88  .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
89  VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
90  VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
91  &getContext());
92  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
93  (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
94 }
95 
96 std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
97  return std::make_unique<VectorizationPass>();
98 }
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28