AffineParallelUnroll.cpp
//===- AffineParallelUnroll.cpp - Unroll AffineParallelOp ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Unroll AffineParallelOp to facilitate lowering to Calyx ParOp.
//
//===----------------------------------------------------------------------===//

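// For illustration (a sketch, not verified compiler output), a two-iteration
// loop such as
//
//   affine.parallel (%i) = (0) to (2) {
//     ... ops using %i ...
//   }
//
// is rewritten into a single-iteration `affine.parallel` carrying the
// "calyx.unroll" attribute, whose body holds one `scf.execute_region` per
// original iteration, with %i replaced by the constants 0 and 1.
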
#include "circt/Dialect/Calyx/CalyxPasses.h"
#include "circt/Support/LLVM.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace circt {
namespace calyx {
#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
} // namespace calyx
} // namespace circt

using namespace circt;
using namespace mlir;
using namespace mlir::affine;
using namespace mlir::arith;

// Identifies an affine memory access by its memref, access map, and map
// operands; used as a DenseMap key for counting repeated accesses.
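// For illustration (hypothetical values), `affine.load %A[%i + 1]` would be
// keyed as {memref = %A, map = (d0) -> (d0 + 1), operands = [%i]}.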
struct AffineAccessExpr {
  Value memref;
  AffineMap map;
  SmallVector<Value, 4> operands;

  bool operator==(const AffineAccessExpr &other) const {
    return memref == other.memref && map == other.map &&
           operands == other.operands;
  }
};

namespace llvm {
template <>
struct DenseMapInfo<AffineAccessExpr> {
  static AffineAccessExpr getEmptyKey() {
    return {DenseMapInfo<Value>::getEmptyKey(), {}, {}};
  }

  static AffineAccessExpr getTombstoneKey() {
    return {DenseMapInfo<Value>::getTombstoneKey(), {}, {}};
  }

  static unsigned getHashValue(const AffineAccessExpr &expr) {
    unsigned h = DenseMapInfo<Value>::getHashValue(expr.memref);
    h = llvm::hash_combine(h, DenseMapInfo<AffineMap>::getHashValue(expr.map));
    for (Value operand : expr.operands)
      h = llvm::hash_combine(h, DenseMapInfo<Value>::getHashValue(operand));
    return h;
  }

  static bool isEqual(const AffineAccessExpr &a, const AffineAccessExpr &b) {
    return a == b;
  }
};
} // namespace llvm

namespace {
// This pass tries to prevent potential memory banking contention by hoisting
// memory reads after AffineParallelUnroll has run. It only hoists memory
// reads that occur *more than once* inside `scf.execute_region`s. Since
// AffineParallelUnroll converts loop indices to constants, this subsequent
// analysis can safely identify and remove redundant accesses.
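//
// For illustration (a sketch, not verified compiler output): after unrolling,
// two sibling `scf.execute_region`s may each contain
//   %v = affine.load %A[0] : memref<4xi32>
// The resolver then hoists a single such load above the `affine.parallel`
// and redirects both uses to the hoisted result.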
struct MemoryBankConflictResolver {
  LogicalResult run(AffineParallelOp affineParallelOp);

  // Performs memory contention analysis on memory write operations, returning
  // `failure` if it identifies two write operations that write to the same
  // memory reference at the same access indices in different parallel regions.
  LogicalResult writeOpAnalysis(AffineParallelOp affineParallelOp);
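  // For example (illustrative IR), an `affine.store %v, %A[0]` appearing in
  // two different `scf.execute_region`s would contend for the same bank, so
  // the analysis reports failure.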

  // Tries to hoist memory read operations that would cause memory access
  // contention, such as reading from the same memory reference with the same
  // access indices in different parallel regions.
  LogicalResult readOpAnalysis(AffineParallelOp affineParallelOp);

  // Returns true if `readOp`'s access indices are invariant with respect to
  // `affineParallelOp`.
  bool isInvariantToAffineParallel(AffineReadOpInterface readOp,
                                   AffineParallelOp affineParallelOp);
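  // For example (illustrative IR), `affine.load %A[%iv]` is not invariant
  // because its index uses a loop IV, whereas `affine.load %A[0]` is
  // invariant as long as %A is never written inside the loop.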

  // Accumulates all memory reads and writes into the fields below.
  void accumulateReadWriteOps(AffineParallelOp affineParallelOp);

  // Stores all memory reads in the current `affineParallelOp`.
  DenseSet<AffineReadOpInterface> allReadOps;
  // Stores all memory writes in the current `affineParallelOp`.
  DenseSet<AffineWriteOpInterface> allWriteOps;
  // Maps each written memory reference to the `scf.execute_region` op that
  // writes to it, across all regions under the current `affineParallelOp`.
  DenseMap<Value, scf::ExecuteRegionOp> writtenMemRefs;
  // Counts the number of reads of each unique access across all
  // `scf.execute_region` ops under the current `affineParallelOp`.
  DenseMap<AffineAccessExpr, int> readOpCounts;
  // Records the memory reads that have already been hoisted above the
  // current `affineParallelOp`.
  DenseMap<AffineAccessExpr, Value> hoistedReads;
};
} // end anonymous namespace

void MemoryBankConflictResolver::accumulateReadWriteOps(
    AffineParallelOp affineParallelOp) {
  auto executeRegionOps =
      affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
  for (auto executeRegionOp : executeRegionOps) {
    executeRegionOp.walk([&](Operation *op) {
      if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
        return WalkResult::advance();

      if (auto read = dyn_cast<AffineReadOpInterface>(op)) {
        allReadOps.insert(read);

        AffineAccessExpr key{read.getMemRef(), read.getAffineMap(),
                             read.getMapOperands()};
        ++readOpCounts[key];
      } else {
        allWriteOps.insert(cast<AffineWriteOpInterface>(op));
      }

      return WalkResult::advance();
    });
  }
}

LogicalResult
MemoryBankConflictResolver::writeOpAnalysis(AffineParallelOp affineParallelOp) {
  for (auto writeOp : allWriteOps) {
    scf::ExecuteRegionOp parentExecuteRegion =
        writeOp->getParentOfType<scf::ExecuteRegionOp>();
    auto memref = writeOp.getMemRef();

    auto it = writtenMemRefs.find(memref);
    if (it != writtenMemRefs.end() && it->second != parentExecuteRegion) {
      writeOp.emitError("Multiple writes to the same memory reference");
      return failure();
    }

    writtenMemRefs[memref] = parentExecuteRegion;
  }

  return success();
}

bool MemoryBankConflictResolver::isInvariantToAffineParallel(
    AffineReadOpInterface readOp, AffineParallelOp affineParallelOp) {
  // Check if any map operand depends on the loop IVs.
  for (Value iv : affineParallelOp.getIVs()) {
    for (Value operand : readOp.getMapOperands()) {
      if (operand == iv)
        return false;

      // Walk through defining operation chains, such as `affine.apply`s.
      if (Operation *def = operand.getDefiningOp()) {
        if (affineParallelOp->isAncestor(def) && !isa<arith::ConstantOp>(def)) {
          // The operand is computed inside the loop, so it is not invariant.
          return false;
        }
      }
    }
  }

  // Check if the memref is written anywhere inside the loop.
  Value memref = readOp.getMemRef();
  for (Operation *user : memref.getUsers()) {
    if (user == readOp.getOperation())
      continue;

    if (affineParallelOp->isAncestor(user) &&
        hasEffect<MemoryEffects::Write>(user, memref))
      return false;
  }

  return true;
}

LogicalResult
MemoryBankConflictResolver::readOpAnalysis(AffineParallelOp affineParallelOp) {
  OpBuilder builder(affineParallelOp);
  for (auto readOp : allReadOps) {
    AffineAccessExpr key{readOp.getMemRef(), readOp.getAffineMap(),
                         readOp.getMapOperands()};
    auto it = readOpCounts.find(key);
    if (it == readOpCounts.end() || it->second <= 1 ||
        !isInvariantToAffineParallel(readOp, affineParallelOp))
      continue;

    // Only hoist once per unique access.
    auto hoistedReadsIt = hoistedReads.find(key);
    if (hoistedReadsIt == hoistedReads.end()) {
      builder.setInsertionPoint(affineParallelOp);
      Operation *cloned = builder.clone(*readOp.getOperation());
      hoistedReadsIt = hoistedReads.insert({key, cloned->getResult(0)}).first;
    }

    readOp->getResult(0).replaceAllUsesWith(hoistedReadsIt->second);
    readOp.getOperation()->erase();
  }

  return success();
}

LogicalResult
MemoryBankConflictResolver::run(AffineParallelOp affineParallelOp) {
  accumulateReadWriteOps(affineParallelOp);

  if (failed(writeOpAnalysis(affineParallelOp))) {
    return failure();
  }

  if (failed(readOpAnalysis(affineParallelOp))) {
    return failure();
  }

  return success();
}

namespace {

struct AffineParallelUnroll : public OpRewritePattern<AffineParallelOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineParallelOp affineParallelOp,
                                PatternRewriter &rewriter) const override {
    // We assume that the presence of the "calyx.unroll" attribute means that
    // the op has already been unrolled.
    if (affineParallelOp->hasAttr("calyx.unroll"))
      return failure();

    if (!affineParallelOp.getResults().empty()) {
      affineParallelOp.emitError(
          "affine.parallel with reductions is not supported yet");
      return failure();
    }

    Location loc = affineParallelOp.getLoc();

    rewriter.setInsertionPointAfter(affineParallelOp);
    // Create a single-iteration parallel loop op and mark it as special by
    // setting the "calyx.unroll" attribute.
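    // The container is, roughly (a sketch, not verified output):
    //   affine.parallel (%i) = (0) to (1) { ... } {calyx.unroll = true}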
    AffineMap lbMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(0),
                                     rewriter.getContext());
    AffineMap ubMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(1),
                                     rewriter.getContext());
    auto newParallelOp = rewriter.create<AffineParallelOp>(
        loc, /*resultTypes=*/TypeRange(),
        /*reductions=*/SmallVector<arith::AtomicRMWKind>(),
        /*lowerBoundsMap=*/lbMap, /*lowerBoundsOperands=*/SmallVector<Value>(),
        /*upperBoundsMap=*/ubMap, /*upperBoundsOperands=*/SmallVector<Value>(),
        /*steps=*/SmallVector<int64_t>({1}));
    newParallelOp->setAttr("calyx.unroll", rewriter.getBoolAttr(true));

    SmallVector<int64_t> pLoopLowerBounds =
        affineParallelOp.getLowerBoundsMap().getConstantResults();
    if (pLoopLowerBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant lower bounds");
      return failure();
    }
    SmallVector<int64_t> pLoopUpperBounds =
        affineParallelOp.getUpperBoundsMap().getConstantResults();
    if (pLoopUpperBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant upper bounds");
      return failure();
    }
    SmallVector<int64_t, 8> pLoopSteps = affineParallelOp.getSteps();

    Block *pLoopBody = affineParallelOp.getBody();
    MutableArrayRef<BlockArgument> pLoopIVs = affineParallelOp.getIVs();

    OpBuilder insideBuilder(newParallelOp);
    SmallVector<int64_t> indices = pLoopLowerBounds;
    while (true) {
      insideBuilder.setInsertionPointToStart(newParallelOp.getBody());
      // Create an `scf.execute_region` to wrap each unrolled block, since
      // `affine.parallel` permits only one block in its body region.
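      // Each unrolled iteration then looks roughly like (a sketch, not
      // verified output):
      //   scf.execute_region {
      //     ... cloned loop body, with IVs replaced by constants ...
      //     scf.yield
      //   }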
      auto executeRegionOp =
          insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
      Region &executeRegionRegion = executeRegionOp.getRegion();
      Block *executeRegionBlock = &executeRegionRegion.emplaceBlock();

      OpBuilder regionBuilder(executeRegionOp);
      // Each iteration starts with a fresh mapping, so the block arguments of
      // region-based operations (such as `affine.for`) in each cloned body
      // get re-mapped independently.
      IRMapping operandMap;
      regionBuilder.setInsertionPointToEnd(executeRegionBlock);
      // Map induction variables to constant indices.
      for (unsigned i = 0; i < indices.size(); ++i) {
        Value ivConstant =
            regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
        operandMap.map(pLoopIVs[i], ivConstant);
      }

      for (auto it = pLoopBody->begin(); it != std::prev(pLoopBody->end());
           ++it)
        regionBuilder.clone(*it, operandMap);

      // A terminator must always be inserted in `scf.execute_region`'s block.
      regionBuilder.create<scf::YieldOp>(loc);

      // Increment `indices` by `step`, odometer-style: bump the innermost
      // dimension first and carry into outer dimensions on overflow.
      bool done = false;
      for (int dim = indices.size() - 1; dim >= 0; --dim) {
        indices[dim] += pLoopSteps[dim];
        if (indices[dim] < pLoopUpperBounds[dim])
          break;
        indices[dim] = pLoopLowerBounds[dim];
        if (dim == 0)
          // All index combinations have been generated.
          done = true;
      }
      if (done)
        break;
    }

    rewriter.replaceOp(affineParallelOp, newParallelOp);

    return success();
  }
};

struct AffineParallelUnrollPass
    : public circt::calyx::impl::AffineParallelUnrollBase<
          AffineParallelUnrollPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::scf::SCFDialect>();
    registry.insert<mlir::memref::MemRefDialect>();
  }
  void runOnOperation() override;
};

} // end anonymous namespace

void AffineParallelUnrollPass::runOnOperation() {
  auto *ctx = &getContext();
  ConversionTarget target(*ctx);

  RewritePatternSet patterns(ctx);
  patterns.add<AffineParallelUnroll>(ctx);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    getOperation()->emitError("Failed to unroll affine.parallel");
    signalPassFailure();
  }

  // The `AffineParallelUnroll` pattern introduces constant values, so running
  // canonicalization before `MemoryBankConflictResolver` helps ease the
  // analysis in `MemoryBankConflictResolver`.
  PassManager pm(ctx);
  pm.addPass(circt::calyx::createExcludeExecuteRegionCanonicalizePass());

  if (failed(pm.run(getOperation()))) {
    getOperation()->emitError("Nested PassManager failed when running "
                              "ExcludeExecuteRegionCanonicalize pass.");
    signalPassFailure();
  }

  getOperation()->walk([&](AffineParallelOp parOp) {
    if (parOp->hasAttr("calyx.unroll")) {
      if (failed(MemoryBankConflictResolver().run(parOp))) {
        parOp.emitError("Failed to resolve memory banking conflicts");
        signalPassFailure();
      }
    }
  });
}

std::unique_ptr<mlir::Pass> circt::calyx::createAffineParallelUnrollPass() {
  return std::make_unique<AffineParallelUnrollPass>();
}