15#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/OperationSupport.h"
22#include "mlir/IR/Visitors.h"
23#include "mlir/Pass/PassManager.h"
24#include "mlir/Support/LLVM.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
31#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
38using namespace mlir::arith;
// Post-unroll analysis for an `affine.parallel` op tagged `calyx.unroll`.
// It rejects memrefs that are written from more than one
// `scf.execute_region` in the parallel body. It also hoists reads that are
// repeated with the same access pattern and are loop-invariant, so that a
// single read remains in front of the loop.
75struct MemoryBankConflictResolver {
// Entry point: gathers reads/writes, then runs write and read analysis.
76 LogicalResult
run(AffineParallelOp affineParallelOp);
// Fails (with a diagnostic) if a memref is written from more than one
// `scf.execute_region` of the parallel body.
81 LogicalResult writeOpAnalysis(AffineParallelOp affineParallelOp);
// Hoists duplicated, loop-invariant reads to just before `affineParallelOp`.
86 LogicalResult readOpAnalysis(AffineParallelOp affineParallelOp);
// Reports whether `readOp`'s map operands and memref are left unchanged by
// every iteration of `affineParallelOp`.
90 bool isInvariantToAffineParallel(AffineReadOpInterface readOp,
91 AffineParallelOp affineParallelOp);
// Populates allReadOps / allWriteOps from the `scf.execute_region` ops that
// sit directly in the parallel body.
94 void accumulateReadWriteOps(AffineParallelOp affineParallelOp);
// Every affine read found inside the body's execute regions.
97 DenseSet<AffineReadOpInterface> allReadOps;
// Every affine write found inside the body's execute regions.
99 DenseSet<AffineWriteOpInterface> allWriteOps;
// memref -> the `scf.execute_region` that writes it.
102 DenseMap<Value, scf::ExecuteRegionOp> writtenMemRefs;
// Access pattern (memref, map, operands) -> number of reads using it;
// readOpAnalysis only hoists keys whose count is greater than 1.
105 DenseMap<AffineAccessExpr, int> readOpCounts;
// Access pattern -> result of the hoisted read that replaces its duplicates.
108 DenseMap<AffineAccessExpr, Value> hoistedReads;
// Walks (recursively) every `scf.execute_region` that is a direct child of
// the parallel loop body, and records each affine read and write found there.
112void MemoryBankConflictResolver::accumulateReadWriteOps(
113 AffineParallelOp affineParallelOp) {
114 auto executeRegionOps =
115 affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
116 for (
auto executeRegionOp : executeRegionOps) {
117 executeRegionOp.walk([&](Operation *op) {
// Only affine memory accesses matter here; skip everything else.
118 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
119 return WalkResult::advance();
121 if (
auto read = dyn_cast<AffineReadOpInterface>(op)) {
// Record the read. Its (memref, map, operands) access key is presumably
// counted into readOpCounts here for readOpAnalysis — TODO confirm.
122 allReadOps.insert(read);
125 read.getMapOperands()};
// Not a read, so given the filter above it must be a write.
128 allWriteOps.insert(cast<AffineWriteOpInterface>(op));
131 return WalkResult::advance();
// Checks that each memref written in the parallel body is written from at
// most one `scf.execute_region`. Otherwise it emits an error on the
// conflicting write.
// NOTE(review): allWriteOps is a DenseSet (pointer-hashed), so which write
// receives the diagnostic is not deterministic across runs.
137MemoryBankConflictResolver::writeOpAnalysis(AffineParallelOp affineParallelOp) {
138 for (
auto writeOp : allWriteOps) {
// Nearest enclosing execute region of this write.
139 scf::ExecuteRegionOp parentExecuteRegion =
140 writeOp->getParentOfType<scf::ExecuteRegionOp>();
141 auto memref = writeOp.getMemRef();
// A write to the same memref already recorded from a *different* region
// is a conflict.
143 auto it = writtenMemRefs.find(memref);
144 if (it != writtenMemRefs.end() && it->second != parentExecuteRegion) {
145 writeOp.emitError(
"Multiple writes to the same memory reference");
// Remember which region owns writes to this memref.
149 writtenMemRefs[memref] = parentExecuteRegion;
// Reports whether `readOp` accesses the same location with the same contents
// in every iteration of `affineParallelOp`. Three conditions apply:
// - Its map operands must not be defined by an op nested in the loop
//   (`arith.constant` excepted).
// - Its map operands are presumably also not the loop IVs themselves, which
//   is what the outer `iv` loop is for — TODO confirm.
// - No op nested in the loop, other than `readOp`, may have a write effect
//   on its memref.
155bool MemoryBankConflictResolver::isInvariantToAffineParallel(
156 AffineReadOpInterface readOp, AffineParallelOp affineParallelOp) {
158 for (Value iv : affineParallelOp.getIVs()) {
159 for (Value operand : readOp.getMapOperands()) {
// An operand produced inside the loop (other than a constant) may differ
// per iteration.
// NOTE(review): readOpAnalysis clones the read in front of the loop without
// remapping its operands. A constant defined *inside* the loop would then
// not dominate the hoisted clone. Confirm that constants have already been
// moved out, e.g. by the canonicalization that runs before this analysis.
164 if (Operation *def = operand.getDefiningOp()) {
165 if (affineParallelOp->isAncestor(def) && !isa<arith::ConstantOp>(def)) {
// The memref must not be written anywhere inside the loop.
174 Value memref = readOp.getMemRef();
175 for (Operation *user : memref.getUsers()) {
176 if (user == readOp.getOperation())
179 if (affineParallelOp->isAncestor(user) &&
180 hasEffect<MemoryEffects::Write>(user, memref))
// Hoists reads that are issued more than once with the same access key and
// are loop-invariant.
// - The first such read visited for a key is cloned just before
//   `affineParallelOp`.
// - Every qualifying read for that key (the first one included) then has
//   its uses redirected to the clone, and is erased.
// NOTE(review): allReadOps is a DenseSet, so which read is cloned, and
// therefore which Location the hoisted op carries, depends on pointer-hash
// order.
188MemoryBankConflictResolver::readOpAnalysis(AffineParallelOp affineParallelOp) {
189 OpBuilder builder(affineParallelOp);
190 for (
auto readOp : allReadOps) {
192 readOp.getMapOperands()};
// Only duplicated (count > 1), loop-invariant reads are candidates. Reuse
// the hoisted read for this key if one exists; otherwise clone this read
// in front of the loop.
193 auto it = readOpCounts.find(key);
194 if (it == readOpCounts.end() || it->second <= 1 ||
195 !isInvariantToAffineParallel(readOp, affineParallelOp))
199 auto hoistedReadsIt = hoistedReads.find(key);
200 if (hoistedReadsIt == hoistedReads.end()) {
201 builder.setInsertionPoint(affineParallelOp);
202 Operation *cloned = builder.clone(*readOp.getOperation());
203 hoistedReadsIt = hoistedReads.insert({key, cloned->getResult(0)}).first;
// Redirect users to the hoisted value and drop the now-redundant read.
// Erasing is safe mid-iteration because allReadOps itself is not modified.
// Its entry becomes a dangling handle, so this resolver must not be reused.
206 readOp->getResult(0).replaceAllUsesWith(hoistedReadsIt->second);
207 readOp.getOperation()->erase();
// Runs conflict resolution on one unrolled parallel op in three steps:
// gather reads/writes, reject conflicting writes, then hoist duplicated
// reads.
214MemoryBankConflictResolver::run(AffineParallelOp affineParallelOp) {
215 accumulateReadWriteOps(affineParallelOp);
// Write conflicts are checked before any read is hoisted.
217 if (failed(writeOpAnalysis(affineParallelOp))) {
221 if (failed(readOpAnalysis(affineParallelOp))) {
231 using OpRewritePattern::OpRewritePattern;
// Fully unrolls an `affine.parallel` with constant bounds.
// - The op is replaced by a single-iteration `affine.parallel` (0 to 1,
//   step 1, tagged `calyx.unroll`).
// - The new body holds one `scf.execute_region` per iteration point.
// - Inside each region, every IV is replaced by an `arith.constant` index
//   holding that point's value.
//
// NOTE(review): patterns must not modify the IR when they return failure().
// Here `newParallelOp` is created and inserted before the constant-bound
// checks below. If those error paths return failure(), move the checks ahead
// of the creation.
// NOTE(review): `insideBuilder` and `regionBuilder` are plain OpBuilders, so
// ops they create are not reported to the PatternRewriter's listener.
// Building through `rewriter` would keep the greedy driver informed.
233 LogicalResult matchAndRewrite(AffineParallelOp affineParallelOp,
234 PatternRewriter &rewriter)
const override {
// Ops already produced by this pattern carry `calyx.unroll`. Skipping them
// stops the greedy driver from re-matching its own output.
235 if (affineParallelOp->hasAttr(
"calyx.unroll"))
// Op results correspond to reductions, which are not handled.
240 if (!affineParallelOp.getResults().empty()) {
241 affineParallelOp.emitError(
242 "affine.parallel with reductions is not supported yet");
246 Location loc = affineParallelOp.getLoc();
248 rewriter.setInsertionPointAfter(affineParallelOp);
// Single-iteration wrapper: lb 0, ub 1, step 1, no results/reductions.
251 AffineMap lbMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(0),
252 rewriter.getContext());
253 AffineMap ubMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(1),
254 rewriter.getContext());
255 auto newParallelOp = AffineParallelOp::create(
256 rewriter, loc, TypeRange(),
257 SmallVector<arith::AtomicRMWKind>(),
258 lbMap, SmallVector<Value>(),
259 ubMap, SmallVector<Value>(),
260 SmallVector<int64_t>({1}));
261 newParallelOp->setAttr(
"calyx.unroll", rewriter.getBoolAttr(
true));
// Only constant bounds can be unrolled at compile time. getConstantResults()
// returns an empty vector when any bound expression is non-constant.
// NOTE(review): it yields one entry per bound-map *result*. If a dimension
// has several bound expressions (max/min groups), these vectors would no
// longer line up with pLoopIVs / pLoopSteps. Confirm single-expression
// bounds are guaranteed here.
263 SmallVector<int64_t> pLoopLowerBounds =
264 affineParallelOp.getLowerBoundsMap().getConstantResults();
265 if (pLoopLowerBounds.empty()) {
266 affineParallelOp.emitError(
267 "affine.parallel must have constant lower bounds");
270 SmallVector<int64_t> pLoopUpperBounds =
271 affineParallelOp.getUpperBoundsMap().getConstantResults();
272 if (pLoopUpperBounds.empty()) {
273 affineParallelOp.emitError(
274 "affine.parallel must have constant upper bounds");
277 SmallVector<int64_t, 8> pLoopSteps = affineParallelOp.getSteps();
279 Block *pLoopBody = affineParallelOp.getBody();
280 MutableArrayRef<BlockArgument> pLoopIVs = affineParallelOp.getIVs();
282 OpBuilder insideBuilder(newParallelOp);
// Current iteration point, starting at the lower bounds.
283 SmallVector<int64_t> indices = pLoopLowerBounds;
285 insideBuilder.setInsertionPointToStart(newParallelOp.getBody());
// One execute region per iteration point, holding a copy of the body.
288 auto executeRegionOp =
289 scf::ExecuteRegionOp::create(insideBuilder, loc, TypeRange{});
290 Region &executeRegionRegion = executeRegionOp.getRegion();
291 Block *executeRegionBlock = &executeRegionRegion.emplaceBlock();
293 OpBuilder regionBuilder(executeRegionOp);
// Map every IV to a constant holding its value at this iteration point.
297 IRMapping operandMap;
298 regionBuilder.setInsertionPointToEnd(executeRegionBlock);
300 for (
unsigned i = 0; i < indices.size(); ++i) {
302 arith::ConstantIndexOp::create(regionBuilder, loc, indices[i]);
303 operandMap.map(pLoopIVs[i], ivConstant);
// Clone the body, minus its terminator, with IVs remapped to constants.
306 for (
auto it = pLoopBody->begin(); it != std::prev(pLoopBody->end());
308 regionBuilder.clone(*it, operandMap);
// Terminate the execute region.
311 scf::YieldOp::create(regionBuilder, loc);
// Advance odometer-style: bump the innermost (last) dimension; when it
// reaches its upper bound, reset it and carry into the next outer one.
// NOTE(review): if any lower bound is >= its upper bound (empty iteration
// space), confirm the body is not still emitted once for the initial point.
315 for (
int dim = indices.size() - 1; dim >= 0; --dim) {
316 indices[dim] += pLoopSteps[dim];
317 if (indices[dim] < pLoopUpperBounds[dim])
319 indices[dim] = pLoopLowerBounds[dim];
// The original op has no results (checked above), so the result-less
// wrapper can replace it directly.
328 rewriter.replaceOp(affineParallelOp, newParallelOp);
// Pass driver for the calyx AffineParallelUnroll pass (tablegen base
// `AffineParallelUnrollBase`).
334struct AffineParallelUnrollPass
335 :
public circt::calyx::impl::AffineParallelUnrollBase<
336 AffineParallelUnrollPass> {
// Dialects whose ops this pass may create and therefore must be loaded.
// NOTE(review): `®istry` below is an HTML-entity mangling of `&registry`
// introduced by extraction; the real source reads
// `DialectRegistry &registry`.
// NOTE(review): the unroll pattern also creates `arith.constant` ops, but
// arith is not registered here. It is presumably loaded transitively
// (e.g. via the affine dialect) — confirm.
337 void getDependentDialects(DialectRegistry ®istry)
const override {
338 registry.insert<mlir::scf::SCFDialect>();
339 registry.insert<mlir::memref::MemRefDialect>();
// Unroll, canonicalize, then resolve memory-bank conflicts.
341 void runOnOperation()
override;
// Pass pipeline:
//   1. greedily apply the AffineParallelUnroll pattern;
//   2. run the ExcludeExecuteRegionCanonicalize pass through a nested
//      PassManager `pm`;
//   3. run MemoryBankConflictResolver on every `affine.parallel` tagged
//      `calyx.unroll`.
346void AffineParallelUnrollPass::runOnOperation() {
347 auto *ctx = &getContext();
// NOTE(review): `target` is not used anywhere in the visible code. The
// greedy driver does not take a ConversionTarget, so this looks removable
// — confirm.
348 ConversionTarget target(*ctx);
351 patterns.add<AffineParallelUnroll>(ctx);
// Step 1: unroll.
353 if (failed(applyPatternsGreedily(getOperation(), std::move(
patterns)))) {
354 getOperation()->emitError(
"Failed to unroll affine.parallel");
// Step 2: canonicalize the unrolled IR before conflict analysis.
364 if (failed(pm.run(getOperation()))) {
365 getOperation()->emitError(
"Nested PassManager failed when running "
366 "ExcludeExecuteRegionCanonicalize pass.");
// Step 3: resolve memory-bank conflicts. A fresh resolver per op keeps
// analysis state from leaking between ops.
// NOTE(review): the callback mutates IR. It erases reads nested in `parOp`
// and inserts clones before it. This relies on the walk having already
// visited `parOp`'s nested ops, which holds for the default post-order
// walk — keep that order if this is changed.
370 getOperation()->walk([&](AffineParallelOp parOp) {
371 if (parOp->hasAttr(
"calyx.unroll")) {
// NOTE(review): this fires when conflict resolution fails, not unrolling;
// a more specific message would help users.
372 if (failed(MemoryBankConflictResolver().run(parOp))) {
373 parOp.emitError(
"Failed to unroll");
// Factory: returns a default-constructed pass instance (no options set here).
381 return std::make_unique<AffineParallelUnrollPass>();
std::unique_ptr< mlir::Pass > createAffineParallelUnrollPass()
std::unique_ptr< mlir::Pass > createExcludeExecuteRegionCanonicalizePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator==(const AffineAccessExpr &other) const
SmallVector< Value, 4 > operands
static bool isEqual(const AffineAccessExpr &a, const AffineAccessExpr &b)
static unsigned getHashValue(const AffineAccessExpr &expr)