Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
AffinePloopUnparallelize.cpp
Go to the documentation of this file.
//===- AffinePloopUnparallelize.cpp ---------------------------------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Affine/Utils.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Arith/Transforms/Passes.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/IR/Visitors.h"
21#include "mlir/Pass/PassManager.h"
22#include "mlir/Support/LogicalResult.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26namespace circt {
27namespace calyx {
28#define GEN_PASS_DEF_AFFINEPLOOPUNPARALLELIZE
29#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
30} // namespace calyx
31} // namespace circt
32
33using namespace mlir;
34using namespace mlir::arith;
35using namespace mlir::memref;
36using namespace mlir::scf;
37using namespace mlir::func;
38using namespace circt;
39
41 : public OpConversionPattern<affine::AffineParallelOp> {
42 using OpConversionPattern::OpConversionPattern;
43
44public:
45 LogicalResult
46 matchAndRewrite(affine::AffineParallelOp affineParallelOp, OpAdaptor adaptor,
47 ConversionPatternRewriter &rewriter) const override {
48 if (affineParallelOp.getIVs().size() != 1)
49 return rewriter.notifyMatchFailure(affineParallelOp,
50 "currently only support single IV");
51
52 auto loc = affineParallelOp.getLoc();
53 auto upperBoundTuple = mlir::affine::expandAffineMap(
54 rewriter, loc, affineParallelOp.getUpperBoundsMap(),
55 affineParallelOp.getUpperBoundsOperands());
56 if (!upperBoundTuple)
57 return rewriter.notifyMatchFailure(affineParallelOp,
58 "does not have upper bounds");
59 Value upperBound = (*upperBoundTuple)[0];
60
61 auto lowerBoundTuple = mlir::affine::expandAffineMap(
62 rewriter, loc, affineParallelOp.getLowerBoundsMap(),
63 affineParallelOp.getLowerBoundsOperands());
64 if (!lowerBoundTuple)
65 return rewriter.notifyMatchFailure(affineParallelOp,
66 "does not have lower bounds");
67 Value lowerBound = (*lowerBoundTuple)[0];
68
69 auto step = affineParallelOp.getSteps()[0];
70
71 auto factorAttr =
72 affineParallelOp->getAttrOfType<IntegerAttr>("unparallelize.factor");
73 if (!factorAttr)
74 return rewriter.notifyMatchFailure(affineParallelOp,
75 "Missing 'unparallelize.factor'");
76
77 int64_t factor = factorAttr.getInt();
78
79 SmallVector<scf::IndexSwitchOp> simplifiableIndexSwitchOps =
80 collectSimplifiableIndexSwitchOps(affineParallelOp, factor);
81
82 auto outerLoop = rewriter.create<affine::AffineForOp>(
83 loc, lowerBound, rewriter.getDimIdentityMap(), upperBound,
84 rewriter.getDimIdentityMap(), step * factor);
85
86 rewriter.setInsertionPointToStart(outerLoop.getBody());
87 AffineMap lbMap = AffineMap::get(
88 /*dimCount=*/0, /*symbolCount=*/0,
89 /*results=*/rewriter.getAffineConstantExpr(0), rewriter.getContext());
90 AffineMap ubMap = AffineMap::get(
91 0, 0, rewriter.getAffineConstantExpr(factor), rewriter.getContext());
92 auto innerParallel = rewriter.create<affine::AffineParallelOp>(
93 loc, /*resultTypes=*/TypeRange(),
94 /*reductions=*/SmallVector<arith::AtomicRMWKind>(),
95 /*lowerBoundsMap=*/lbMap, /*lowerBoundsOperands=*/SmallVector<Value>(),
96 /*upperBoundsMap=*/ubMap, /*upperBoundsOperands=*/SmallVector<Value>(),
97 /*steps=*/SmallVector<int64_t>({step}));
98
99 if (!innerParallel.getBody()->empty()) {
100 Operation &lastOp = innerParallel.getBody()->back();
101 if (isa<affine::AffineYieldOp>(lastOp))
102 lastOp.erase();
103 }
104 rewriter.setInsertionPointToStart(innerParallel.getBody());
105
106 // `newIndex` will be the newly created `affine.for`'s IV added with the
107 // inner `affine.parallel`'s IV.
108 auto addMap = AffineMap::get(
109 2, 0, rewriter.getAffineDimExpr(0) + rewriter.getAffineDimExpr(1),
110 rewriter.getContext());
111
112 auto newIndex = rewriter.create<affine::AffineApplyOp>(
113 loc, addMap,
114 ValueRange{outerLoop.getInductionVar(), innerParallel.getIVs()[0]});
115
116 Block *srcBlock = affineParallelOp.getBody();
117 Block *destBlock = innerParallel.getBody();
118
119 // Move all operations except the terminator from `srcBlock` to `destBlock`.
120 destBlock->getOperations().splice(
121 destBlock->end(), // insert at the end of `destBlock`
122 srcBlock->getOperations(), // move ops from `srcBlock`
123 srcBlock->begin(), // start at beginning of `srcBlock`
124 std::prev(srcBlock->end()) // stop before the terminator op
125 );
126
127 // Remap occurrences of the old induction variable in the moved ops.
128 destBlock->walk([&](Operation *op) {
129 for (OpOperand &operand : op->getOpOperands()) {
130 if (operand.get() == affineParallelOp.getIVs()[0])
131 operand.set(newIndex);
132 }
133 });
134
135 rewriter.setInsertionPointToEnd(destBlock);
136 rewriter.create<affine::AffineYieldOp>(loc);
137
138 for (auto indexSwitchOp : simplifiableIndexSwitchOps) {
139 indexSwitchOp.setOperand(innerParallel.getIVs().front());
140 }
141
142 return success();
143 }
144
145private:
146 // Collect all simplifiable `scf.index_switch` ops in `affineParallelOp`. An
147 // `scf.index_switch` op is simpliiable if its argument only depends on
148 // `affineParallelOp`'s loop IV and if it's a result of a modulo expression.
149 SmallVector<scf::IndexSwitchOp>
150 collectSimplifiableIndexSwitchOps(affine::AffineParallelOp affineParallelOp,
151 int64_t factor) const {
152 SmallVector<scf::IndexSwitchOp> result;
153 affineParallelOp->walk([&](scf::IndexSwitchOp indexSwitchOp) {
154 auto switchArg = indexSwitchOp.getArg();
155 auto affineApplyOp =
156 dyn_cast_or_null<affine::AffineApplyOp>(switchArg.getDefiningOp());
157 if (!affineApplyOp || affineApplyOp->getNumOperands() != 1 ||
158 affineApplyOp->getNumResults() != 1)
159 return WalkResult::advance();
160
161 auto affineMap = affineApplyOp.getAffineMap();
162 auto binExpr = dyn_cast<AffineBinaryOpExpr>(affineMap.getResult(0));
163 if (!binExpr || binExpr.getKind() != AffineExprKind::Mod)
164 return WalkResult::advance();
165
166 if (affineApplyOp.getOperand(0) != affineParallelOp.getIVs().front())
167 return WalkResult::advance();
168
169 auto rhs = binExpr.getRHS();
170 auto constRhs = dyn_cast<AffineConstantExpr>(rhs);
171 if (!constRhs || factor != constRhs.getValue())
172 return WalkResult::advance();
173
174 result.push_back(indexSwitchOp);
175 return WalkResult::advance();
176 });
177 return result;
178 }
179};
180
181namespace {
182class AffinePloopUnparallelizePass
183 : public circt::calyx::impl::AffinePloopUnparallelizeBase<
184 AffinePloopUnparallelizePass> {
185 void runOnOperation() override;
186};
187} // namespace
188
189void AffinePloopUnparallelizePass::runOnOperation() {
190 MLIRContext *ctx = &getContext();
191
192 ConversionTarget target(*ctx);
193 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
194 scf::SCFDialect, affine::AffineDialect>();
195
196 RewritePatternSet patterns(ctx);
198 GreedyRewriteConfig config;
199 config.strictMode = GreedyRewriteStrictness::ExistingOps;
200 if (failed(
201 applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
202 signalPassFailure();
203 }
204}
205
207 return std::make_unique<AffinePloopUnparallelizePass>();
208}
SmallVector< scf::IndexSwitchOp > collectSimplifiableIndexSwitchOps(affine::AffineParallelOp affineParallelOp, int64_t factor) const
LogicalResult matchAndRewrite(affine::AffineParallelOp affineParallelOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
std::unique_ptr< mlir::Pass > createAffinePloopUnparallelizePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.