CIRCT 20.0.0git
Loading...
Searching...
No Matches
Vectorization.cpp
Go to the documentation of this file.
1//===- Vectorization.cpp - Vectorize primitive operations ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// This pass performs vectorization for primitive operations, e.g.:
9// vector_create (or a[0], b[0]), (or a[1], b[1]), (or a[2], b[2])
10// => elementwise_or a, b
11//
12//===----------------------------------------------------------------------===//
13
17#include "circt/Support/Debug.h"
18#include "circt/Support/LLVM.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22#include "llvm/Support/Debug.h"
23
24namespace circt {
25namespace firrtl {
26#define GEN_PASS_DEF_VECTORIZATION
27#include "circt/Dialect/FIRRTL/Passes.h.inc"
28} // namespace firrtl
29} // namespace circt
30
31using namespace circt;
32using namespace firrtl;
33
34#define DEBUG_TYPE "firrtl-vectorization"
35
36//===----------------------------------------------------------------------===//
37// Pass Infrastructure
38//===----------------------------------------------------------------------===//
39
40namespace {
41
42template <typename OpTy, typename ResultOpType>
43class VectorCreateToLogicElementwise : public mlir::RewritePattern {
44public:
45 VectorCreateToLogicElementwise(MLIRContext *context)
46 : RewritePattern(VectorCreateOp::getOperationName(), 0, context) {}
47
48 LogicalResult
49 matchAndRewrite(Operation *op,
50 mlir::PatternRewriter &rewriter) const override {
51 auto vectorCreateOp = cast<VectorCreateOp>(op);
52 FVectorType type = vectorCreateOp.getType();
53 if (type.hasUninferredWidth() || !type_isa<UIntType>(type.getElementType()))
54 return failure();
55
56 SmallVector<Value> lhs, rhs;
57
58 // Vectorize if all operands are `OpTy`. Currently there is no other
59 // condition so it could be too aggressive.
60 if (llvm::all_of(op->getOperands(), [&](Value operand) {
61 auto op = operand.getDefiningOp<OpTy>();
62 if (!op)
63 return false;
64 lhs.push_back(op.getLhs());
65 rhs.push_back(op.getRhs());
66 return true;
67 })) {
68 auto lhsVec = rewriter.createOrFold<VectorCreateOp>(
69 op->getLoc(), vectorCreateOp.getType(), lhs);
70 auto rhsVec = rewriter.createOrFold<VectorCreateOp>(
71 op->getLoc(), vectorCreateOp.getType(), rhs);
72 rewriter.replaceOpWithNewOp<ResultOpType>(op, lhsVec, rhsVec);
73 return success();
74 }
75 return failure();
76 }
77};
78
79struct VectorizationPass
80 : public circt::firrtl::impl::VectorizationBase<VectorizationPass> {
81 VectorizationPass() = default;
82 void runOnOperation() override;
83};
84
85} // namespace
86
87void VectorizationPass::runOnOperation() {
88 LLVM_DEBUG(debugPassHeader(this)
89 << "\n"
90 << "Module: '" << getOperation().getName() << "'\n";);
91
92 RewritePatternSet patterns(&getContext());
94 .insert<VectorCreateToLogicElementwise<OrPrimOp, ElementwiseOrPrimOp>,
95 VectorCreateToLogicElementwise<AndPrimOp, ElementwiseAndPrimOp>,
96 VectorCreateToLogicElementwise<XorPrimOp, ElementwiseXorPrimOp>>(
97 &getContext());
98 mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
99 (void)applyPatternsGreedily(getOperation(), frozenPatterns);
100}
101
102std::unique_ptr<mlir::Pass> circt::firrtl::createVectorizationPass() {
103 return std::make_unique<VectorizationPass>();
104}
std::unique_ptr< mlir::Pass > createVectorizationPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition Debug.cpp:31