#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "affine-to-loopschedule"

namespace circt {
#define GEN_PASS_DEF_AFFINETOLOOPSCHEDULE
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace mlir::arith;
using namespace mlir::memref;
using namespace mlir::scf;
using namespace circt::loopschedule;
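
// Pass driver: schedules each eligible affine loop nest with a modulo
// scheduler and rewrites it into a loopschedule.pipeline.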
struct AffineToLoopSchedule
    : public circt::impl::AffineToLoopScheduleBase<AffineToLoopSchedule> {
  void runOnOperation() override;

  LogicalResult
  lowerAffineStructures(MemoryDependenceAnalysis &dependenceAnalysis);
  LogicalResult populateOperatorTypes(SmallVectorImpl<AffineForOp> &loopNest,
                                      ModuloProblem &problem);
  LogicalResult solveSchedulingProblem(SmallVectorImpl<AffineForOp> &loopNest,
                                       ModuloProblem &problem);
  LogicalResult
  createLoopSchedulePipeline(SmallVectorImpl<AffineForOp> &loopNest,
                             ModuloProblem &problem);

  CyclicSchedulingAnalysis *schedulingAnalysis;
};
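
// Populate a ModuloProblem with the operations, operator types, resources,
// and dependences already collected in the CyclicProblem.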
static ModuloProblem getModuloProblem(CyclicProblem &prob) {
  auto modProb = ModuloProblem::get(prob.getContainingOp());

  // Copy over the operator types, latencies, and linked resources.
  for (auto *op : prob.getOperations()) {
    auto opr = prob.getLinkedOperatorType(op);
    if (opr.has_value()) {
      modProb.setLinkedOperatorType(op, opr.value());
      auto latency = prob.getLatency(opr.value());
      if (latency.has_value())
        modProb.setLatency(opr.value(), latency.value());
    }
    auto rsrc = prob.getLinkedResourceTypes(op);
    if (rsrc.has_value())
      modProb.setLinkedResourceTypes(op, rsrc.value());
    modProb.insertOperation(op);
  }

  // Copy over the auxiliary dependences and any dependence distances.
  for (auto *op : prob.getOperations()) {
    for (auto dep : prob.getDependences(op)) {
      if (dep.isAuxiliary()) {
        auto depInserted = modProb.insertDependence(dep);
        assert(succeeded(depInserted));
        (void)depInserted;
      }
      auto distance = prob.getDistance(dep);
      if (distance.has_value())
        modProb.setDistance(dep, distance.value());
    }
  }

  return modProb;
}
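
// Pass entry point: lower affine structures, schedule each loop nest, and
// convert it into a pipeline.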
void AffineToLoopSchedule::runOnOperation() {
  // Collect memory dependences between affine accesses.
  auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();

  // Lower affine control flow while keeping the dependence analysis intact.
  if (failed(lowerAffineStructures(dependenceAnalysis)))
    return signalPassFailure();

  schedulingAnalysis = &getAnalysis<CyclicSchedulingAnalysis>();

  auto outerLoops = getOperation().getOps<AffineForOp>();
  for (auto root : llvm::make_early_inc_range(outerLoops)) {
    SmallVector<AffineForOp> nestedLoops;
    getPerfectlyNestedLoops(nestedLoops, root);

    // Only single, innermost loops are handled for now.
    if (nestedLoops.size() != 1)
      continue;

    ModuloProblem moduloProblem =
        getModuloProblem(schedulingAnalysis->getProblem(nestedLoops.back()));

    if (failed(populateOperatorTypes(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(solveSchedulingProblem(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(createLoopSchedulePipeline(nestedLoops, moduloProblem)))
      return signalPassFailure();
  }
}
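
// Apply the affine map from an 'affine.load' operation to its operands, and
// feed the results to a newly created memref.load, keeping the memory
// dependence analysis pointed at the replacement op.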
  LogicalResult
  matchAndRewrite(AffineLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto resultOperands =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!resultOperands)
      return failure();

    auto memrefLoad = rewriter.replaceOpWithNewOp<memref::LoadOp>(
        op, op.getMemRef(), *resultOperands);
    dependenceAnalysis.replaceOp(op, memrefLoad);
    return success();
  }
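
// Apply the affine map from an 'affine.store' operation to its operands, and
// feed the results to a newly created memref.store, keeping the memory
// dependence analysis pointed at the replacement op.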
  LogicalResult
  matchAndRewrite(AffineStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!maybeExpandedMap)
      return failure();

    auto memrefStore = rewriter.replaceOpWithNewOp<memref::StoreOp>(
        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
    dependenceAnalysis.replaceOp(op, memrefStore);
    return success();
  }
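
// Helper to hoist computation out of scf::IfOp branches, turning the op into
// a mux-like operation and exposing the hoisted computation to the scheduler.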
  LogicalResult
  matchAndRewrite(IfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.modifyOpInPlace(op, [&]() {
      if (!op.thenBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.thenBlock(), --op.thenBlock()->end());
        rewriter.inlineBlockBefore(&op.getThenRegion().front(), op);
      }
      if (op.elseBlock() && !op.elseBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.elseBlock(), --op.elseBlock()->end());
        rewriter.inlineBlockBefore(&op.getElseRegion().front(), op);
      }
    });
    return success();
  }
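
// Helper to determine if an scf::IfOp is in mux-like form, i.e. both branches
// contain nothing but their terminator.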
static bool ifOpLegalityCallback(IfOp op) {
  return op.thenBlock()->without_terminator().empty() &&
         (!op.elseBlock() || op.elseBlock()->without_terminator().empty());
}
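
// Helper to mark AffineYieldOp legal, unless it is inside a partially
// converted scf::IfOp.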
static bool yieldOpLegalityCallback(AffineYieldOp op) {
  return !op->getParentOfType<IfOp>();
}
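
// Lower affine.if/load/store to SCF and memref equivalents, hoisting scf.if
// bodies into mux-like form so the remaining operations can be scheduled.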
LogicalResult AffineToLoopSchedule::lowerAffineStructures(
    MemoryDependenceAnalysis &dependenceAnalysis) {
  auto *context = &getContext();
  auto op = getOperation();

  ConversionTarget target(*context);
  target.addLegalDialect<AffineDialect, ArithDialect, MemRefDialect,
                         SCFDialect>();
  target.addIllegalOp<AffineIfOp, AffineLoadOp, AffineStoreOp>();
  target.addDynamicallyLegalOp<IfOp>(ifOpLegalityCallback);
  target.addDynamicallyLegalOp<AffineYieldOp>(yieldOpLegalityCallback);

  RewritePatternSet patterns(context);
  populateAffineToStdConversionPatterns(patterns);
  patterns.add<AffineLoadLowering>(context, dependenceAnalysis);
  patterns.add<AffineStoreLowering>(context, dependenceAnalysis);

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return failure();

  return success();
}
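
// Map each operation in the loop body to an operator type (and, for memory
// accesses, to a per-memref resource) understood by the scheduler.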
LogicalResult AffineToLoopSchedule::populateOperatorTypes(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  // Scheduling considers only the innermost loop for now.
  auto forOp = loopNest.back();

  // Map each operation to an operator type; reject anything unsupported.
  Operation *unsupported;
  WalkResult result = forOp.getBody()->walk([&](Operation *op) {
    return TypeSwitch<Operation *, WalkResult>(op)
        .Case<AddIOp, IfOp, AffineYieldOp, arith::ConstantOp, CmpIOp,
              IndexCastOp, memref::AllocaOp, YieldOp>([&](Operation *combOp) {
          // Combinational operations.
          return WalkResult::advance();
        })
        .Case<AddIOp, CmpIOp>([&](Operation *seqOp) {
          // Sequential operations.
          return WalkResult::advance();
        })
        .Case<AffineStoreOp, memref::StoreOp>([&](Operation *memOp) {
          // Stores are linked to a per-memory resource.
          Value memRef = isa<AffineStoreOp>(*memOp)
                             ? cast<AffineStoreOp>(*memOp).getMemRef()
                             : cast<memref::StoreOp>(*memOp).getMemRef();
          auto memRsrc = problem.getOrInsertResourceType(
              "mem_" + std::to_string(hash_value(memRef)) + "_rsrc");
          problem.setLinkedResourceTypes(
              memOp, SmallVector<Problem::ResourceType>{memRsrc});
          return WalkResult::advance();
        })
        .Case<AffineLoadOp, memref::LoadOp>([&](Operation *memOp) {
          // Loads share the same per-memory resource as stores.
          Value memRef = isa<AffineLoadOp>(*memOp)
                             ? cast<AffineLoadOp>(*memOp).getMemRef()
                             : cast<memref::LoadOp>(*memOp).getMemRef();
          auto memRsrc = problem.getOrInsertResourceType(
              "mem_" + std::to_string(hash_value(memRef)) + "_rsrc");
          problem.setLinkedResourceTypes(
              memOp, SmallVector<Problem::ResourceType>{memRsrc});
          return WalkResult::advance();
        })
        .Case<MulIOp>([&](Operation *mcOp) {
          // Multi-cycle operations.
          return WalkResult::advance();
        })
        .Default([&](Operation *badOp) {
          unsupported = op;
          return WalkResult::interrupt();
        });
  });

  if (result.wasInterrupted())
    return forOp.emitError("unsupported operation ") << *unsupported;

  return success();
}
LogicalResult AffineToLoopSchedule::solveSchedulingProblem(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  // Dump the scheduling problem's inputs for debugging.
  LLVM_DEBUG(forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    llvm::dbgs() << "Scheduling inputs for " << *op;
    auto opr = problem.getLinkedOperatorType(op);
    llvm::dbgs() << "\n  opr = " << opr->getAttr();
    llvm::dbgs() << "\n  latency = " << problem.getLatency(*opr);
    for (auto dep : problem.getDependences(op))
      if (dep.isAuxiliary())
        llvm::dbgs() << "\n  dep = { distance = " << problem.getDistance(dep)
                     << ", source = " << *dep.getSource() << " }";
    llvm::dbgs() << "\n\n";
  }));

  if (failed(problem.check()))
    return failure();

  auto *anchor = forOp.getBody()->getTerminator();
  if (failed(scheduleSimplex(problem, anchor)))
    return failure();

  if (failed(problem.verify()))
    return failure();

  // Dump the scheduling results for debugging.
  LLVM_DEBUG({
    llvm::dbgs() << "Scheduled initiation interval = "
                 << problem.getInitiationInterval() << "\n\n";
    forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
      llvm::dbgs() << "Scheduling outputs for " << *op;
      llvm::dbgs() << "\n  start = " << problem.getStartTime(op);
      llvm::dbgs() << "\n\n";
    });
  });

  return success();
}
LogicalResult AffineToLoopSchedule::createLoopSchedulePipeline(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  auto outerLoop = loopNest.front();
  auto innerLoop = loopNest.back();
  ImplicitLocOpBuilder builder(outerLoop.getLoc(), outerLoop);

  // Lower the loop bounds and step outside the pipeline.
  Value lowerBound = lowerAffineLowerBound(innerLoop, builder);
  Value upperBound = lowerAffineUpperBound(innerLoop, builder);
  int64_t stepValue = innerLoop.getStep().getSExtValue();
  auto step = builder.create<arith::ConstantOp>(
      IntegerAttr::get(builder.getIndexType(), stepValue));

  // The pipeline has the same result types as the inner loop; an extra iter
  // arg carries the induction variable.
  TypeRange resultTypes = innerLoop.getResultTypes();

  auto ii = builder.getI64IntegerAttr(problem.getInitiationInterval().value());

  SmallVector<Value> iterArgs;
  iterArgs.push_back(lowerBound);
  iterArgs.append(innerLoop.getInits().begin(), innerLoop.getInits().end());

  // If the trip count is a known constant, attach it as an attribute.
  std::optional<IntegerAttr> tripCountAttr;
  if (auto tripCount = getConstantTripCount(forOp))
    tripCountAttr = builder.getI64IntegerAttr(*tripCount);

  auto pipeline = builder.create<LoopSchedulePipelineOp>(
      resultTypes, ii, tripCountAttr, iterArgs);
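
  // Create the condition block: continue iterating while the induction
  // variable is less than the upper bound.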
  Block &condBlock = pipeline.getCondBlock();
  builder.setInsertionPointToStart(&condBlock);
  auto cmpResult = builder.create<arith::CmpIOp>(
      builder.getI1Type(), arith::CmpIPredicate::ult, condBlock.getArgument(0),
      upperBound);
  condBlock.getTerminator()->insertOperands(0, {cmpResult});
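
  // Group the scheduled operations by their assigned start time.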
  DenseMap<unsigned, SmallVector<Operation *>> startGroups;
  for (auto *op : problem.getOperations()) {
    if (isa<AffineYieldOp, YieldOp>(op))
      continue;
    auto startTime = problem.getStartTime(op);
    startGroups[*startTime].push_back(op);
  }
  // Map the block arguments of the original loop body to the pipeline's
  // stages block arguments.
  IRMapping valueMap;
  assert(iterArgs.size() == forOp.getBody()->getNumArguments());
  for (size_t i = 0; i < iterArgs.size(); ++i)
    valueMap.map(forOp.getBody()->getArgument(i),
                 pipeline.getStagesBlock().getArgument(i));
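
  // Build the pipeline stages inside the pipeline's stages block.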
  Block &stagesBlock = pipeline.getStagesBlock();
  builder.setInsertionPointToStart(&stagesBlock);

  SmallVector<unsigned> startTimes;
  for (const auto &group : startGroups)
    startTimes.push_back(group.first);
  llvm::sort(startTimes);

  DominanceInfo dom(getOperation());

  SmallVector<SmallVector<Value>> registerValues;
  SmallVector<SmallVector<Type>> registerTypes;
  SmallVector<IRMapping> stageValueMaps;
  DenseMap<Operation *, std::pair<unsigned, unsigned>> pipeTimes;
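
  // First pass over the start groups: for each operation, compute how far its
  // results must be piped through later stages (pipeTimes) and record the
  // values each stage must register.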
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];

    // A use by the loop terminator keeps a value alive one extra stage.
    auto isLoopTerminator = [forOp](Operation *op) {
      return isa<AffineYieldOp>(op) && op->getParentOp() == forOp;
    };

    // Ensure a register set exists for every stage up to this start time.
    for (unsigned i = registerValues.size(); i <= startTime; ++i)
      registerValues.emplace_back(SmallVector<Value>());

    for (auto *op : group) {
      if (op->getUsers().empty())
        continue;

      // Find the latest stage that consumes one of this op's results.
      unsigned pipeEndTime = 0;
      for (auto *user : op->getUsers()) {
        unsigned userStartTime = *problem.getStartTime(user);
        if (userStartTime > startTime)
          pipeEndTime = std::max(pipeEndTime, userStartTime);
        else if (isLoopTerminator(user))
          pipeEndTime = std::max(pipeEndTime, userStartTime + 1);
      }

      pipeTimes[op] = std::pair(startTime, pipeEndTime);

      // Add register slots for each stage the results must live through.
      for (unsigned i = registerValues.size(); i <= pipeEndTime; ++i)
        registerValues.push_back(SmallVector<Value>());

      for (auto result : op->getResults())
        registerValues[startTime].push_back(result);

      // Results become available after the op's latency; register them in
      // every stage from then until their last use.
      unsigned firstUse = std::max(
          startTime + 1,
          startTime + *problem.getLatency(*problem.getLinkedOperatorType(op)));
      for (unsigned i = firstUse; i < pipeEndTime; ++i) {
        for (auto result : op->getResults())
          registerValues[i].push_back(result);
      }
    }
  }
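
  // Derive the register types for every stage and seed each stage's value map
  // with the initial mapping.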
  for (unsigned i = 0; i < registerValues.size(); ++i) {
    SmallVector<mlir::Type> types;
    for (auto val : registerValues[i])
      types.push_back(val.getType());

    registerTypes.push_back(types);
    stageValueMaps.push_back(valueMap);
  }

  // One extra map for values produced by the final stage.
  stageValueMaps.push_back(valueMap);
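
  // Second pass: create a pipeline stage op per start time, clone the
  // scheduled operations into it, and yield the values later stages consume.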
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];

    // Clone the ops in dominance order so uses follow their definitions.
    llvm::sort(group,
               [&](Operation *a, Operation *b) { return dom.dominates(a, b); });

    auto stageTypes = registerTypes[startTime];
    // The first stage also produces the incremented induction variable.
    if (startTime == 0)
      stageTypes.push_back(lowerBound.getType());

    // Create the stage itself, right before the stages block terminator.
    builder.setInsertionPoint(stagesBlock.getTerminator());
    auto startTimeAttr = builder.getIntegerAttr(
        builder.getIntegerType(64, /*isSigned=*/true), startTime);
    auto stage =
        builder.create<LoopSchedulePipelineStageOp>(stageTypes, startTimeAttr);
    auto &stageBlock = stage.getBodyBlock();
    auto *stageTerminator = stageBlock.getTerminator();
    builder.setInsertionPointToStart(&stageBlock);

    // Clone the ops into the stage, remapping operands through the stage's
    // value map.
    for (auto *op : group) {
      auto *newOp = builder.clone(*op, stageValueMaps[startTime]);

      // Further uses within this stage should refer to the cloned results.
      for (auto result : op->getResults())
        stageValueMaps[startTime].map(
            result, newOp->getResult(result.getResultNumber()));
    }

    // Yield the values later stages need, and map each one to the stage
    // result in the value map of the stage that will consume it.
    SmallVector<Value> stageOperands;
    unsigned resIndex = 0;
    for (auto res : registerValues[startTime]) {
      stageOperands.push_back(stageValueMaps[startTime].lookup(res));
      unsigned destTime = startTime + 1;
      unsigned latency = *problem.getLatency(
          *problem.getLinkedOperatorType(res.getDefiningOp()));
      // Multi-cycle results only become available after their full latency.
      if (*problem.getStartTime(res.getDefiningOp()) == startTime &&
          latency > 1)
        destTime = startTime + latency;
      destTime = std::min((unsigned)(stageValueMaps.size() - 1), destTime);
      stageValueMaps[destTime].map(res, stage.getResult(resIndex++));
    }
    stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                    stageOperands);

    // The first stage increments the induction variable.
    if (startTime == 0) {
      auto incResult =
          builder.create<arith::AddIOp>(stagesBlock.getArgument(0), step);
      stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                      incResult->getResults());
    }
  }
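
  // Wire the stages block terminator: its iter args feed the next iteration
  // and its results become the pipeline's results.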
  auto stagesTerminator =
      cast<LoopScheduleTerminatorOp>(stagesBlock.getTerminator());

  // The incremented induction variable (last result of the first stage) is
  // fed back as the first iter arg.
  SmallVector<Value> termIterArgs;
  SmallVector<Value> termResults;
  termIterArgs.push_back(
      stagesBlock.front().getResult(stagesBlock.front().getNumResults() - 1));

  for (auto value : forOp.getBody()->getTerminator()->getOperands()) {
    unsigned lookupTime = std::min((unsigned)(stageValueMaps.size() - 1),
                                   pipeTimes[value.getDefiningOp()].second);
    termIterArgs.push_back(stageValueMaps[lookupTime].lookup(value));
    termResults.push_back(stageValueMaps[lookupTime].lookup(value));
  }

  stagesTerminator.getIterArgsMutable().append(termIterArgs);
  stagesTerminator.getResultsMutable().append(termResults);
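
  // Replace the original loop's results and erase the now-unused loop nest.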
  for (size_t i = 0; i < forOp.getNumResults(); ++i)
    forOp.getResult(i).replaceAllUsesWith(pipeline.getResult(i));

  loopNest.front().walk([](Operation *op) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  });

  return success();
}
std::unique_ptr<mlir::Pass> circt::createAffineToLoopSchedule() {
  return std::make_unique<AffineToLoopSchedule>();
}
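
// Usage sketch (assumed, not part of this file): the factory above yields a
// standard mlir::Pass, so it can be added to a pass manager like any other
// conversion pass; `module` here is an assumed mlir::ModuleOp.
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(circt::createAffineToLoopSchedule());
//   if (failed(pm.run(module)))
//     llvm::errs() << "affine-to-loopschedule failed\n";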