15#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/OperationSupport.h"
22#include "mlir/IR/Visitors.h"
23#include "mlir/Pass/PassManager.h"
24#include "mlir/Support/LLVM.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
31#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
38using namespace mlir::arith;
83struct MemoryBankConflictResolver {
84 LogicalResult
run(AffineParallelOp affineParallelOp);
89 LogicalResult writeOpAnalysis(AffineParallelOp affineParallelOp);
94 LogicalResult readOpAnalysis(AffineParallelOp affineParallelOp);
98 bool isInvariantToAffineParallel(AffineReadOpInterface readOp,
99 AffineParallelOp affineParallelOp);
102 void accumulateReadWriteOps(AffineParallelOp affineParallelOp);
105 DenseSet<AffineReadOpInterface> allReadOps;
107 DenseSet<AffineWriteOpInterface> allWriteOps;
110 DenseMap<Value, scf::ExecuteRegionOp> writtenMemRefs;
113 DenseMap<AffineAccessExpr, int> readOpCounts;
116 DenseMap<AffineAccessExpr, Value> hoistedReads;
120void MemoryBankConflictResolver::accumulateReadWriteOps(
121 AffineParallelOp affineParallelOp) {
122 auto executeRegionOps =
123 affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
124 for (
auto executeRegionOp : executeRegionOps) {
125 executeRegionOp.walk([&](Operation *op) {
126 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
127 return WalkResult::advance();
129 if (
auto read = dyn_cast<AffineReadOpInterface>(op)) {
130 allReadOps.insert(read);
133 read.getMapOperands()};
136 allWriteOps.insert(cast<AffineWriteOpInterface>(op));
139 return WalkResult::advance();
145MemoryBankConflictResolver::writeOpAnalysis(AffineParallelOp affineParallelOp) {
146 for (
auto writeOp : allWriteOps) {
147 scf::ExecuteRegionOp parentExecuteRegion =
148 writeOp->getParentOfType<scf::ExecuteRegionOp>();
149 auto memref = writeOp.getMemRef();
151 auto it = writtenMemRefs.find(memref);
152 if (it != writtenMemRefs.end() && it->second != parentExecuteRegion) {
153 writeOp.emitError(
"Multiple writes to the same memory reference");
157 writtenMemRefs[memref] = parentExecuteRegion;
163bool MemoryBankConflictResolver::isInvariantToAffineParallel(
164 AffineReadOpInterface readOp, AffineParallelOp affineParallelOp) {
166 for (Value iv : affineParallelOp.getIVs()) {
167 for (Value operand : readOp.getMapOperands()) {
172 if (Operation *def = operand.getDefiningOp()) {
173 if (affineParallelOp->isAncestor(def) && !isa<arith::ConstantOp>(def)) {
182 Value memref = readOp.getMemRef();
183 for (Operation *user : memref.getUsers()) {
184 if (user == readOp.getOperation())
187 if (affineParallelOp->isAncestor(user) &&
188 hasEffect<MemoryEffects::Write>(user, memref))
196MemoryBankConflictResolver::readOpAnalysis(AffineParallelOp affineParallelOp) {
197 OpBuilder builder(affineParallelOp);
198 for (
auto readOp : allReadOps) {
200 readOp.getMapOperands()};
201 auto it = readOpCounts.find(key);
202 if (it == readOpCounts.end() || it->second <= 1 ||
203 !isInvariantToAffineParallel(readOp, affineParallelOp))
207 auto hoistedReadsIt = hoistedReads.find(key);
208 if (hoistedReadsIt == hoistedReads.end()) {
209 builder.setInsertionPoint(affineParallelOp);
210 Operation *cloned = builder.clone(*readOp.getOperation());
211 hoistedReadsIt = hoistedReads.insert({key, cloned->getResult(0)}).first;
214 readOp->getResult(0).replaceAllUsesWith(hoistedReadsIt->second);
215 readOp.getOperation()->erase();
222MemoryBankConflictResolver::run(AffineParallelOp affineParallelOp) {
223 accumulateReadWriteOps(affineParallelOp);
225 if (failed(writeOpAnalysis(affineParallelOp))) {
229 if (failed(readOpAnalysis(affineParallelOp))) {
239 using OpRewritePattern::OpRewritePattern;
241 LogicalResult matchAndRewrite(AffineParallelOp affineParallelOp,
242 PatternRewriter &rewriter)
const override {
243 if (affineParallelOp->hasAttr(
"calyx.unroll"))
248 if (!affineParallelOp.getResults().empty()) {
249 affineParallelOp.emitError(
250 "affine.parallel with reductions is not supported yet");
254 Location loc = affineParallelOp.getLoc();
256 rewriter.setInsertionPointAfter(affineParallelOp);
259 AffineMap lbMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(0),
260 rewriter.getContext());
261 AffineMap ubMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(1),
262 rewriter.getContext());
263 auto newParallelOp = rewriter.create<AffineParallelOp>(
265 SmallVector<arith::AtomicRMWKind>(),
266 lbMap, SmallVector<Value>(),
267 ubMap, SmallVector<Value>(),
268 SmallVector<int64_t>({1}));
269 newParallelOp->setAttr(
"calyx.unroll", rewriter.getBoolAttr(
true));
271 SmallVector<int64_t> pLoopLowerBounds =
272 affineParallelOp.getLowerBoundsMap().getConstantResults();
273 if (pLoopLowerBounds.empty()) {
274 affineParallelOp.emitError(
275 "affine.parallel must have constant lower bounds");
278 SmallVector<int64_t> pLoopUpperBounds =
279 affineParallelOp.getUpperBoundsMap().getConstantResults();
280 if (pLoopUpperBounds.empty()) {
281 affineParallelOp.emitError(
282 "affine.parallel must have constant upper bounds");
285 SmallVector<int64_t, 8> pLoopSteps = affineParallelOp.getSteps();
287 Block *pLoopBody = affineParallelOp.getBody();
288 MutableArrayRef<BlockArgument> pLoopIVs = affineParallelOp.getIVs();
290 OpBuilder insideBuilder(newParallelOp);
291 SmallVector<int64_t> indices = pLoopLowerBounds;
293 insideBuilder.setInsertionPointToStart(newParallelOp.getBody());
296 auto executeRegionOp =
297 insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
298 Region &executeRegionRegion = executeRegionOp.getRegion();
299 Block *executeRegionBlock = &executeRegionRegion.emplaceBlock();
301 OpBuilder regionBuilder(executeRegionOp);
305 IRMapping operandMap;
306 regionBuilder.setInsertionPointToEnd(executeRegionBlock);
308 for (
unsigned i = 0; i < indices.size(); ++i) {
310 regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
311 operandMap.map(pLoopIVs[i], ivConstant);
314 for (
auto it = pLoopBody->begin(); it != std::prev(pLoopBody->end());
316 regionBuilder.clone(*it, operandMap);
319 regionBuilder.create<scf::YieldOp>(loc);
323 for (
int dim = indices.size() - 1; dim >= 0; --dim) {
324 indices[dim] += pLoopSteps[dim];
325 if (indices[dim] < pLoopUpperBounds[dim])
327 indices[dim] = pLoopLowerBounds[dim];
336 rewriter.replaceOp(affineParallelOp, newParallelOp);
342struct AffineParallelUnrollPass
343 :
public circt::calyx::impl::AffineParallelUnrollBase<
344 AffineParallelUnrollPass> {
345 void getDependentDialects(DialectRegistry ®istry)
const override {
346 registry.insert<mlir::scf::SCFDialect>();
347 registry.insert<mlir::memref::MemRefDialect>();
349 void runOnOperation()
override;
354void AffineParallelUnrollPass::runOnOperation() {
355 auto *ctx = &getContext();
356 ConversionTarget target(*ctx);
359 patterns.add<AffineParallelUnroll>(ctx);
361 if (failed(applyPatternsGreedily(getOperation(), std::move(
patterns)))) {
362 getOperation()->emitError(
"Failed to unroll affine.parallel");
372 if (failed(pm.run(getOperation()))) {
373 getOperation()->emitError(
"Nested PassManager failed when running "
374 "ExcludeExecuteRegionCanonicalize pass.");
378 getOperation()->walk([&](AffineParallelOp parOp) {
379 if (parOp->hasAttr(
"calyx.unroll")) {
380 if (failed(MemoryBankConflictResolver().run(parOp))) {
381 parOp.emitError(
"Failed to unroll");
389 return std::make_unique<AffineParallelUnrollPass>();
std::unique_ptr< mlir::Pass > createAffineParallelUnrollPass()
std::unique_ptr< mlir::Pass > createExcludeExecuteRegionCanonicalizePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator==(const AffineAccessExpr &other) const
SmallVector< Value, 4 > operands
static AffineAccessExpr getTombstoneKey()
static AffineAccessExpr getEmptyKey()
static bool isEqual(const AffineAccessExpr &a, const AffineAccessExpr &b)
static unsigned getHashValue(const AffineAccessExpr &expr)