AffineParallelUnroll.cpp
//===- AffineParallelUnroll.cpp - Unroll AffineParallelOp ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Unroll AffineParallelOp to facilitate lowering to Calyx ParOp.
//
//===----------------------------------------------------------------------===//
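
// For illustration (a hedged sketch with hypothetical IR, not taken from the
// tests): an op such as
//
//   affine.parallel (%i) = (0) to (2) {
//     ... body using %i ...
//   }
//
// becomes a single-iteration affine.parallel carrying the `calyx.unroll`
// attribute, whose body holds one scf.execute_region per original iteration,
// with %i replaced by the constants 0 and 1.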

#include "circt/Dialect/Calyx/CalyxPasses.h"
#include "circt/Support/LLVM.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace circt {
namespace calyx {
#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
} // namespace calyx
} // namespace circt

using namespace circt;
using namespace mlir;
using namespace mlir::affine;
using namespace mlir::arith;

namespace {
// This analysis tries to prevent potential memory banking contention by
// hoisting memory reads after AffineParallelUnroll has run. It only hoists
// memory reads that occur *more than once* across `scf.execute_region`s.
// Since AffineParallelUnroll converts loop indices to constants, this
// follow-up analysis can safely identify and remove redundant accesses.
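//
// A sketch of the conflict this guards against (hypothetical IR): if two
// execute_regions both contain
//
//   %v = affine.load %mem[0] : memref<4xi32>
//
// the two loads would contend for the same memory bank once lowered to
// Calyx; hoisting a single load above the affine.parallel lets both regions
// share its result instead.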
struct MemoryBankConflictResolver {
  LogicalResult run(AffineParallelOp affineParallelOp);

  // Computes and collects all memory accesses (read/write) that have constant
  // access indices.
  DenseMap<Operation *, SmallVector<int64_t, 4>>
  computeConstMemAccessIndices(AffineParallelOp affineParallelOp);

  // Performs memory contention analysis on memory write operations, returning
  // `failure` if it identifies two write operations that write to the same
  // memory reference at the same access indices in different parallel regions.
  LogicalResult
  writeOpAnalysis(DenseMap<Operation *, SmallVector<int64_t, 4>> &,
                  AffineParallelOp affineParallelOp);

  // Tries to hoist memory read operations that would cause memory access
  // contention, such as reading from the same memory reference with the same
  // access indices in different parallel regions.
  void readOpHoistAnalysis(AffineParallelOp affineParallelOp,
                           DenseMap<Operation *, SmallVector<int64_t, 4>> &);

  // Maps each memory write, keyed by its memory reference and constant access
  // indices, to the parallel region that contains it.
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, scf::ExecuteRegionOp>
      constantWriteOpIndices;

  // Counts the occurrences of each memory read, keyed by its memory reference
  // and constant access indices, across all parallel regions.
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, int> constantReadOpCounts;
};
} // end anonymous namespace
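// DenseMap cannot hash a (Value, SmallVector<int64_t, 4>) pair out of the
// box, so provide a DenseMapInfo specialization for the composite key used
// by the two maps above.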
namespace llvm {
template <>
struct DenseMapInfo<std::pair<Value, SmallVector<int64_t, 4>>> {
  using PairType = std::pair<Value, SmallVector<int64_t, 4>>;

  static inline PairType getEmptyKey() {
    return {DenseMapInfo<Value>::getEmptyKey(), {}};
  }

  static inline PairType getTombstoneKey() {
    return {DenseMapInfo<Value>::getTombstoneKey(), {}};
  }

  static unsigned getHashValue(const PairType &pair) {
    unsigned hash = DenseMapInfo<Value>::getHashValue(pair.first);
    for (const auto &v : pair.second)
      hash = llvm::hash_combine(hash, DenseMapInfo<int64_t>::getHashValue(v));
    return hash;
  }

  static bool isEqual(const PairType &lhs, const PairType &rhs) {
    return lhs.first == rhs.first && lhs.second == rhs.second;
  }
};
} // namespace llvm
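// Worked sketch (hypothetical values): for an access whose affine map is
// (d0) -> (d0 + 1) and whose single map operand is the result of
// `arith.constant 2 : index`, partialConstantFold evaluates the result to 3,
// so the op is recorded with constant indices {3}. Ops with any non-constant
// map operand are skipped.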
DenseMap<Operation *, SmallVector<int64_t, 4>>
MemoryBankConflictResolver::computeConstMemAccessIndices(
    AffineParallelOp affineParallelOp) {
  DenseMap<Operation *, SmallVector<int64_t, 4>> constantMemAccessIndices;

  MLIRContext *ctx = affineParallelOp->getContext();
  auto executeRegionOps =
      affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
  for (auto executeRegionOp : executeRegionOps) {
    executeRegionOp.walk([&](Operation *op) {
      if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
        return WalkResult::advance();

      auto read = dyn_cast<AffineReadOpInterface>(op);
      AffineMap map = read ? read.getAffineMap()
                           : cast<AffineWriteOpInterface>(op).getAffineMap();
      ValueRange mapOperands =
          read ? read.getMapOperands()
               : cast<AffineWriteOpInterface>(op).getMapOperands();

      SmallVector<Attribute> operandConsts;
      for (Value operand : mapOperands) {
        if (auto constOp = operand.getDefiningOp<arith::ConstantIndexOp>()) {
          operandConsts.push_back(
              IntegerAttr::get(IndexType::get(ctx), constOp.value()));
        } else {
          // Skip this op if any map operand is non-constant.
          return WalkResult::advance();
        }
      }

      SmallVector<int64_t, 4> foldedResults;
      bool hasPoison = false;
      map.partialConstantFold(operandConsts, &foldedResults, &hasPoison);
      if (!(hasPoison || foldedResults.empty()))
        constantMemAccessIndices[op] = foldedResults;

      return WalkResult::advance();
    });
  }

  return constantMemAccessIndices;
}
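// Sketch of the failure case this detects (hypothetical IR): if one parallel
// region contains `affine.store %v0, %mem[0]` and another contains
// `affine.store %v1, %mem[0]`, the regions would race on the same location,
// so the analysis reports failure. Two such stores inside the *same* region
// are fine, since a region executes sequentially.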
LogicalResult MemoryBankConflictResolver::writeOpAnalysis(
    DenseMap<Operation *, SmallVector<int64_t, 4>> &constantMemAccessIndices,
    AffineParallelOp affineParallelOp) {
  auto executeRegionOps =
      affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
  for (auto executeRegionOp : executeRegionOps) {
    auto walkResult = executeRegionOp.walk([&](Operation *op) {
      if (!isa<AffineWriteOpInterface>(op))
        return WalkResult::advance();

      auto writeOp = cast<AffineWriteOpInterface>(op);

      auto constIndicesIt = constantMemAccessIndices.find(op);
      if (constIndicesIt == constantMemAccessIndices.end())
        // Currently, we skip the analysis for any write op whose access
        // indices are not all constant.
        return WalkResult::advance();

      auto parentExecuteRegionOp =
          writeOp->getParentOfType<scf::ExecuteRegionOp>();
      auto key = std::pair(writeOp.getMemRef(), constIndicesIt->second);
      auto [writeOpIndicesIt, emplaced] =
          constantWriteOpIndices.try_emplace(key, parentExecuteRegionOp);
      if (!emplaced && writeOpIndicesIt->second != parentExecuteRegionOp) {
        // Writing to the same memory reference at the same indices in
        // different parallel regions results in write contention, so we bail
        // out. Writing twice within the same parallel region is fine, because
        // everything there executes sequentially.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (walkResult.wasInterrupted())
      return failure();
  }

  return success();
}
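// Hoisting sketch (hypothetical IR): when multiple unrolled regions load
// %mem[0] and nothing inside the parallel op writes %mem, a single
// `affine.load %mem[0]` is cloned above the affine.parallel and every
// in-region load is replaced with its result, so the regions no longer
// contend for the bank holding %mem[0].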
void MemoryBankConflictResolver::readOpHoistAnalysis(
    AffineParallelOp affineParallelOp,
    DenseMap<Operation *, SmallVector<int64_t, 4>> &constantMemAccessIndices) {
  for (auto &[memOp, constIndices] : constantMemAccessIndices) {
    auto readOp = dyn_cast<AffineReadOpInterface>(memOp);
    if (!readOp)
      continue;

    auto memref = readOp.getMemRef();
    auto key = std::pair(memref, constIndices);
    // Do not hoist a read if its memory reference is written anywhere inside
    // the parallel op.
    if (llvm::any_of(memref.getUsers(), [&](Operation *user) {
          return affineParallelOp->isAncestor(user) &&
                 hasEffect<MemoryEffects::Write>(user, memref);
        })) {
      continue;
    }
    constantReadOpCounts[key]++;
  }

  bool shouldHoist = llvm::any_of(
      constantReadOpCounts, [](const auto &entry) { return entry.second > 1; });
  if (!shouldHoist)
    return;

  OpBuilder builder(affineParallelOp);
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, ValueRange> hoistedReads;
  for (auto &[memOp, constIndices] : constantMemAccessIndices) {
    auto readOp = dyn_cast<AffineReadOpInterface>(memOp);
    if (!readOp)
      continue;

    auto key = std::pair(readOp.getMemRef(), constIndices);
    if (constantReadOpCounts[key] > 1) {
      // Clone the read above the parallel op once per unique (memref,
      // indices) key, then redirect all in-region uses to the hoisted value.
      if (hoistedReads.find(key) == hoistedReads.end()) {
        builder.setInsertionPoint(affineParallelOp);
        Operation *clonedRead = builder.clone(*readOp.getOperation());
        hoistedReads[key] = clonedRead->getOpResults();
      }
      readOp->replaceAllUsesWith(hoistedReads[key]);
    }
  }
}

LogicalResult
MemoryBankConflictResolver::run(AffineParallelOp affineParallelOp) {
  auto constantMemAccessIndices =
      computeConstMemAccessIndices(affineParallelOp);

  if (failed(writeOpAnalysis(constantMemAccessIndices, affineParallelOp))) {
    return failure();
  }

  readOpHoistAnalysis(affineParallelOp, constantMemAccessIndices);

  return success();
}

namespace {

struct AffineParallelUnroll : public OpRewritePattern<AffineParallelOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineParallelOp affineParallelOp,
                                PatternRewriter &rewriter) const override {
    if (affineParallelOp->hasAttr("calyx.unroll"))
      // We assume that the presence of the "calyx.unroll" attribute means the
      // op has already been unrolled.
      return failure();

    if (!affineParallelOp.getResults().empty()) {
      affineParallelOp.emitError(
          "affine.parallel with reductions is not supported yet");
      return failure();
    }

    Location loc = affineParallelOp.getLoc();

    rewriter.setInsertionPointAfter(affineParallelOp);
    // Create a single-iteration parallel loop op and mark it as special by
    // setting the "calyx.unroll" attribute.
    AffineMap lbMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(0),
                                     rewriter.getContext());
    AffineMap ubMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(1),
                                     rewriter.getContext());
    auto newParallelOp = rewriter.create<AffineParallelOp>(
        loc, /*resultTypes=*/TypeRange(),
        /*reductions=*/SmallVector<arith::AtomicRMWKind>(),
        /*lowerBoundsMap=*/lbMap, /*lowerBoundsOperands=*/SmallVector<Value>(),
        /*upperBoundsMap=*/ubMap, /*upperBoundsOperands=*/SmallVector<Value>(),
        /*steps=*/SmallVector<int64_t>({1}));
    newParallelOp->setAttr("calyx.unroll", rewriter.getBoolAttr(true));

    SmallVector<int64_t> pLoopLowerBounds =
        affineParallelOp.getLowerBoundsMap().getConstantResults();
    if (pLoopLowerBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant lower bounds");
      return failure();
    }
    SmallVector<int64_t> pLoopUpperBounds =
        affineParallelOp.getUpperBoundsMap().getConstantResults();
    if (pLoopUpperBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant upper bounds");
      return failure();
    }
    SmallVector<int64_t, 8> pLoopSteps = affineParallelOp.getSteps();

    Block *pLoopBody = affineParallelOp.getBody();
    MutableArrayRef<BlockArgument> pLoopIVs = affineParallelOp.getIVs();

    OpBuilder insideBuilder(newParallelOp);
    SmallVector<int64_t> indices = pLoopLowerBounds;
    while (true) {
      insideBuilder.setInsertionPointToStart(newParallelOp.getBody());
      // Create an `scf.execute_region` to wrap each unrolled block, since
      // `affine.parallel` allows only one block in its body region.
      auto executeRegionOp =
          insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
      Region &executeRegionRegion = executeRegionOp.getRegion();
      Block *executeRegionBlock = &executeRegionRegion.emplaceBlock();

      OpBuilder regionBuilder(executeRegionOp);
      // Each iteration starts with a fresh mapping, so the block arguments of
      // any region-based operation (such as `affine.for`) get remapped
      // independently.
      IRMapping operandMap;
      regionBuilder.setInsertionPointToEnd(executeRegionBlock);
      // Map induction variables to constant indices.
      for (unsigned i = 0; i < indices.size(); ++i) {
        Value ivConstant =
            regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
        operandMap.map(pLoopIVs[i], ivConstant);
      }

      // Clone the loop body, skipping the original terminator.
      for (auto it = pLoopBody->begin(); it != std::prev(pLoopBody->end());
           ++it)
        regionBuilder.clone(*it, operandMap);

      // A terminator must always be inserted in `scf.execute_region`'s block.
      regionBuilder.create<scf::YieldOp>(loc);

      // Increment indices using `step`.
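      // This advances like an odometer over the iteration space: e.g., for
      // bounds (0, 0) to (2, 2) with steps (1, 1), the loop above has already
      // emitted a region for (0,0) and this visits (0,1), (1,0), (1,1) before
      // terminating (an illustrative trace, not from a test).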
      bool done = false;
      for (int dim = indices.size() - 1; dim >= 0; --dim) {
        indices[dim] += pLoopSteps[dim];
        if (indices[dim] < pLoopUpperBounds[dim])
          break;
        indices[dim] = pLoopLowerBounds[dim];
        if (dim == 0)
          // All combinations have been generated.
          done = true;
      }
      if (done)
        break;
    }

    rewriter.replaceOp(affineParallelOp, newParallelOp);

    return success();
  }
};

struct AffineParallelUnrollPass
    : public circt::calyx::impl::AffineParallelUnrollBase<
          AffineParallelUnrollPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::scf::SCFDialect>();
  }
  void runOnOperation() override;
};

} // end anonymous namespace

void AffineParallelUnrollPass::runOnOperation() {
  auto *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<AffineParallelUnroll>(ctx);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    getOperation()->emitError("Failed to unroll affine.parallel");
    signalPassFailure();
  }

  // The `AffineParallelUnroll` pattern introduces constant values, so running
  // canonicalization before `MemoryBankConflictResolver` eases its analysis.
  RewritePatternSet canonicalizePatterns(ctx);
  scf::IndexSwitchOp::getCanonicalizationPatterns(canonicalizePatterns, ctx);
  if (failed(applyPatternsGreedily(getOperation(),
                                   std::move(canonicalizePatterns)))) {
    getOperation()->emitError("Failed to apply canonicalization.");
    signalPassFailure();
  }

  getOperation()->walk([&](AffineParallelOp parOp) {
    if (parOp->hasAttr("calyx.unroll")) {
      if (failed(MemoryBankConflictResolver().run(parOp))) {
        parOp.emitError("Failed to resolve memory banking conflicts");
        signalPassFailure();
      }
    }
  });
}

std::unique_ptr<mlir::Pass> circt::calyx::createAffineParallelUnrollPass() {
  return std::make_unique<AffineParallelUnrollPass>();
}