#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace circt {
namespace calyx {
#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
#include "circt/Dialect/Calyx/CalyxPasses.h.inc"
} // namespace calyx
} // namespace circt

using namespace circt;
using namespace mlir;
using namespace mlir::affine;
using namespace mlir::arith;
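// Resolves memory access conflicts among the `scf.execute_region` ops that
// unrolling produces: it rejects bodies in which two different unrolled
// iterations write the same memref element at constant indices, and hoists
// repeated reads of the same constant element out of the parallel loop.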
struct MemoryBankConflictResolver {
  LogicalResult run(AffineParallelOp affineParallelOp);

  // Collect all affine read/write ops whose access indices fold to constants.
  DenseMap<Operation *, SmallVector<int64_t, 4>>
  computeConstMemAccessIndices(AffineParallelOp affineParallelOp);

  // Fail if two different `scf.execute_region`s write the same memref element.
  LogicalResult
  writeOpAnalysis(DenseMap<Operation *, SmallVector<int64_t, 4>> &,
                  AffineParallelOp affineParallelOp);

  // Hoist reads of the same memref element that occur in multiple iterations.
  void readOpHoistAnalysis(AffineParallelOp affineParallelOp,
                           DenseMap<Operation *, SmallVector<int64_t, 4>> &);

  // The `scf.execute_region` that writes each (memref, constant indices) pair.
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, scf::ExecuteRegionOp>
      constantWriteOpIndices;

  // Number of reads of each (memref, constant indices) pair.
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, int> constantReadOpCounts;
};
namespace llvm {
// DenseMapInfo for the (memref value, constant indices) keys used above.
template <>
struct DenseMapInfo<std::pair<Value, SmallVector<int64_t, 4>>> {
  using PairType = std::pair<Value, SmallVector<int64_t, 4>>;

  static PairType getEmptyKey() {
    return {DenseMapInfo<Value>::getEmptyKey(), SmallVector<int64_t, 4>()};
  }

  static PairType getTombstoneKey() {
    return {DenseMapInfo<Value>::getTombstoneKey(), SmallVector<int64_t, 4>()};
  }

  static unsigned getHashValue(const PairType &pair) {
    unsigned hash = DenseMapInfo<Value>::getHashValue(pair.first);
    for (const auto &v : pair.second)
      hash = llvm::hash_combine(hash, v);
    return hash;
  }

  static bool isEqual(const PairType &lhs, const PairType &rhs) {
    return lhs.first == rhs.first && lhs.second == rhs.second;
  }
};
} // namespace llvm
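// Walks every `scf.execute_region` in the parallel body and records, for each
// affine read/write whose map operands are `arith.constant index` values, the
// constant-folded access indices.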
DenseMap<Operation *, SmallVector<int64_t, 4>>
MemoryBankConflictResolver::computeConstMemAccessIndices(
    AffineParallelOp affineParallelOp) {
  DenseMap<Operation *, SmallVector<int64_t, 4>> constantMemAccessIndices;

  MLIRContext *ctx = affineParallelOp->getContext();
  auto executeRegionOps =
      affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();
  for (auto executeRegionOp : executeRegionOps) {
    executeRegionOp.walk([&](Operation *op) {
      if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
        return WalkResult::advance();

      auto read = dyn_cast<AffineReadOpInterface>(op);
      AffineMap map = read ? read.getAffineMap()
                           : cast<AffineWriteOpInterface>(op).getAffineMap();
      ValueRange mapOperands =
          read ? read.getMapOperands()
               : cast<AffineWriteOpInterface>(op).getMapOperands();

      // Gather the constant values of the map operands; skip the op as soon
      // as one operand is not a constant index.
      SmallVector<Attribute> operandConsts;
      for (Value operand : mapOperands) {
        if (auto constOp = operand.getDefiningOp<arith::ConstantIndexOp>()) {
          operandConsts.push_back(
              IntegerAttr::get(IndexType::get(ctx), constOp.value()));
        } else {
          return WalkResult::advance();
        }
      }

      // Fold the access map with the constant operands; record the result if
      // every index folds to a constant.
      SmallVector<int64_t, 4> foldedResults;
      bool hasPoison = false;
      map.partialConstantFold(operandConsts, &foldedResults, &hasPoison);
      if (!(hasPoison || foldedResults.empty()))
        constantMemAccessIndices[op] = foldedResults;

      return WalkResult::advance();
    });
  }

  return constantMemAccessIndices;
}
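// Detects write/write conflicts: if two different `scf.execute_region`s (that
// is, two different unrolled iterations) write the same memref at the same
// constant indices, the unrolled loop cannot safely execute in parallel.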
LogicalResult MemoryBankConflictResolver::writeOpAnalysis(
    DenseMap<Operation *, SmallVector<int64_t, 4>> &constantMemAccessIndices,
    AffineParallelOp affineParallelOp) {
  auto executeRegionOps =
      affineParallelOp.getBody()->getOps<scf::ExecuteRegionOp>();

  for (auto executeRegionOp : executeRegionOps) {
    auto walkResult = executeRegionOp.walk([&](Operation *op) {
      if (!isa<AffineWriteOpInterface>(op))
        return WalkResult::advance();

      auto writeOp = cast<AffineWriteOpInterface>(op);

      // Only writes with constant access indices are analyzed here.
      auto constIndicesIt = constantMemAccessIndices.find(op);
      if (constIndicesIt == constantMemAccessIndices.end())
        return WalkResult::advance();

      auto parentExecuteRegionOp =
          writeOp->getParentOfType<scf::ExecuteRegionOp>();
      auto key = std::pair(writeOp.getMemRef(), constIndicesIt->second);
      auto [writeOpIndicesIt, emplaced] =
          constantWriteOpIndices.try_emplace(key, parentExecuteRegionOp);
      if (!emplaced && writeOpIndicesIt->second != parentExecuteRegionOp) {
        // A different execute region already writes this element: conflict.
        return WalkResult::interrupt();
      }

      return WalkResult::advance();
    });

    if (walkResult.wasInterrupted())
      return failure();
  }

  return success();
}
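// Counts reads of each (memref, constant indices) pair that is never written
// inside the parallel loop, then hoists a single copy of every read that
// occurs more than once out of the loop and redirects the duplicates to it.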
void MemoryBankConflictResolver::readOpHoistAnalysis(
    AffineParallelOp affineParallelOp,
    DenseMap<Operation *, SmallVector<int64_t, 4>> &constantMemAccessIndices) {
  for (auto &[memOp, constIndices] : constantMemAccessIndices) {
    auto readOp = dyn_cast<AffineReadOpInterface>(memOp);
    if (!readOp)
      continue;

    auto memref = readOp.getMemRef();
    auto key = std::pair(memref, constIndices);

    // Do not hoist a read whose memref is also written inside the loop.
    if (llvm::any_of(memref.getUsers(), [&](Operation *user) {
          return affineParallelOp->isAncestor(user) &&
                 hasEffect<MemoryEffects::Write>(user, memref);
        }))
      continue;

    constantReadOpCounts[key]++;
  }

  bool shouldHoist = llvm::any_of(
      constantReadOpCounts,
      [](const auto &entry) { return entry.second > 1; });
  if (!shouldHoist)
    return;

  OpBuilder builder(affineParallelOp);
  DenseMap<std::pair<Value, SmallVector<int64_t, 4>>, ValueRange> hoistedReads;
  for (auto &[memOp, constIndices] : constantMemAccessIndices) {
    auto readOp = dyn_cast<AffineReadOpInterface>(memOp);
    if (!readOp)
      continue;

    auto key = std::pair(readOp.getMemRef(), constIndices);
    if (constantReadOpCounts[key] > 1) {
      // Clone the read once before the parallel op, then redirect every
      // duplicated read to the hoisted result.
      if (hoistedReads.find(key) == hoistedReads.end()) {
        builder.setInsertionPoint(affineParallelOp);
        Operation *clonedRead = builder.clone(*readOp.getOperation());
        hoistedReads[key] = clonedRead->getOpResults();
      }
      readOp->replaceAllUsesWith(hoistedReads[key]);
    }
  }
}
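// Entry point: runs the write-conflict analysis and, if it succeeds, the
// read-hoisting transformation on one unrolled `affine.parallel`.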
LogicalResult
MemoryBankConflictResolver::run(AffineParallelOp affineParallelOp) {
  auto constantMemAccessIndices =
      computeConstMemAccessIndices(affineParallelOp);

  if (failed(writeOpAnalysis(constantMemAccessIndices, affineParallelOp))) {
    return failure();
  }

  readOpHoistAnalysis(affineParallelOp, constantMemAccessIndices);

  return success();
}
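// Rewrites an `affine.parallel` with constant bounds into a single-iteration
// `affine.parallel` tagged `calyx.unroll`, whose body holds one
// `scf.execute_region` per original iteration with the induction variables
// substituted by constants. Roughly (illustrative sketch, not verbatim
// output):
//
//   affine.parallel (%i) = (0) to (2) {
//     ... use %i ...
//   }
//
// becomes
//
//   affine.parallel (%x) = (0) to (1) {
//     scf.execute_region { ... use 0 ... }
//     scf.execute_region { ... use 1 ... }
//   } {calyx.unroll = true}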
struct AffineParallelUnroll : public OpRewritePattern<AffineParallelOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineParallelOp affineParallelOp,
                                PatternRewriter &rewriter) const override {
    // Ops tagged with `calyx.unroll` have already been rewritten.
    if (affineParallelOp->hasAttr("calyx.unroll"))
      return failure();

    if (!affineParallelOp.getResults().empty()) {
      affineParallelOp.emitError(
          "affine.parallel with reductions is not supported yet");
      return failure();
    }

    Location loc = affineParallelOp.getLoc();

    rewriter.setInsertionPointAfter(affineParallelOp);

    // Create a single-iteration affine.parallel op (bounds [0, 1), step 1)
    // that will hold the fully unrolled body.
    AffineMap lbMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(0),
                                     rewriter.getContext());
    AffineMap ubMap = AffineMap::get(0, 0, rewriter.getAffineConstantExpr(1),
                                     rewriter.getContext());
    auto newParallelOp = rewriter.create<AffineParallelOp>(
        loc, TypeRange(),
        SmallVector<arith::AtomicRMWKind>(),
        lbMap, SmallVector<Value>(),
        ubMap, SmallVector<Value>(),
        SmallVector<int64_t>({1}));
    newParallelOp->setAttr("calyx.unroll", rewriter.getBoolAttr(true));
    SmallVector<int64_t> pLoopLowerBounds =
        affineParallelOp.getLowerBoundsMap().getConstantResults();
    if (pLoopLowerBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant lower bounds");
      return failure();
    }
    SmallVector<int64_t> pLoopUpperBounds =
        affineParallelOp.getUpperBoundsMap().getConstantResults();
    if (pLoopUpperBounds.empty()) {
      affineParallelOp.emitError(
          "affine.parallel must have constant upper bounds");
      return failure();
    }
    SmallVector<int64_t, 8> pLoopSteps = affineParallelOp.getSteps();

    Block *pLoopBody = affineParallelOp.getBody();
    MutableArrayRef<BlockArgument> pLoopIVs = affineParallelOp.getIVs();

    OpBuilder insideBuilder(newParallelOp);
    // Current point of the original iteration space, starting at the lower
    // bounds.
    SmallVector<int64_t> indices = pLoopLowerBounds;

    insideBuilder.setInsertionPointToStart(newParallelOp.getBody());
    // Total number of iterations of the (possibly multi-dimensional) loop
    // nest.
    int64_t totalIterations = 1;
    for (unsigned dim = 0; dim < indices.size(); ++dim)
      totalIterations *= (pLoopUpperBounds[dim] - pLoopLowerBounds[dim] +
                          pLoopSteps[dim] - 1) /
                         pLoopSteps[dim];

    // Emit one `scf.execute_region` per iteration, cloning the original body
    // with the induction variables replaced by constants.
    for (int64_t iter = 0; iter < totalIterations; ++iter) {
      auto executeRegionOp =
          insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
      Region &executeRegionRegion = executeRegionOp.getRegion();
      Block *executeRegionBlock = &executeRegionRegion.emplaceBlock();

      OpBuilder regionBuilder(executeRegionOp);

      // Map each induction variable to its constant value for this iteration.
      IRMapping operandMap;
      regionBuilder.setInsertionPointToEnd(executeRegionBlock);

      for (unsigned i = 0; i < indices.size(); ++i) {
        Value ivConstant =
            regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
        operandMap.map(pLoopIVs[i], ivConstant);
      }

      // Clone the body except for its terminator, then terminate the
      // execute_region with an scf.yield.
      for (auto it = pLoopBody->begin(); it != std::prev(pLoopBody->end());
           ++it)
        regionBuilder.clone(*it, operandMap);

      regionBuilder.create<scf::YieldOp>(loc);

      // Advance `indices` to the next iteration point, odometer-style, with
      // the last dimension varying fastest.
      for (int dim = indices.size() - 1; dim >= 0; --dim) {
        indices[dim] += pLoopSteps[dim];
        if (indices[dim] < pLoopUpperBounds[dim])
          break;
        indices[dim] = pLoopLowerBounds[dim];
      }
    }

    rewriter.replaceOp(affineParallelOp, newParallelOp);

    return success();
  }
};
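// The pass wrapper. SCF must be declared as a dependent dialect because the
// rewrite pattern introduces `scf.execute_region` and `scf.yield` ops.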
struct AffineParallelUnrollPass
    : public circt::calyx::impl::AffineParallelUnrollBase<
          AffineParallelUnrollPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::scf::SCFDialect>();
  }
  void runOnOperation() override;
};
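// Applies the unrolling pattern greedily, runs `scf.index_switch`
// canonicalization on the result, and finally runs the
// MemoryBankConflictResolver on every loop tagged with `calyx.unroll`.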
void AffineParallelUnrollPass::runOnOperation() {
  auto *ctx = &getContext();
  ConversionTarget target(*ctx);

  RewritePatternSet patterns(ctx);
  patterns.add<AffineParallelUnroll>(ctx);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    getOperation()->emitError("Failed to unroll affine.parallel");
    signalPassFailure();
    return;
  }

  // Run `scf.index_switch` canonicalization so that switches whose conditions
  // became constant after unrolling are folded away.
  RewritePatternSet canonicalizePatterns(ctx);
  scf::IndexSwitchOp::getCanonicalizationPatterns(canonicalizePatterns, ctx);
  if (failed(applyPatternsGreedily(getOperation(),
                                   std::move(canonicalizePatterns)))) {
    getOperation()->emitError("Failed to apply canonicalization.");
    signalPassFailure();
    return;
  }

  getOperation()->walk([&](AffineParallelOp parOp) {
    if (parOp->hasAttr("calyx.unroll")) {
      if (failed(MemoryBankConflictResolver().run(parOp))) {
        parOp.emitError("Failed to unroll");
        signalPassFailure();
      }
    }
  });
}

std::unique_ptr<mlir::Pass> circt::calyx::createAffineParallelUnrollPass() {
  return std::make_unique<AffineParallelUnrollPass>();
}