#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "affine-to-loopschedule"

#define GEN_PASS_DEF_AFFINETOLOOPSCHEDULE
#include "circt/Conversion/Passes.h.inc"
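// The include above pulls in the generated pass base class
// (circt::impl::AffineToLoopScheduleBase) used by the pass below; the pass
// itself is declared in circt/Conversion/Passes.td.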
using namespace mlir::arith;
using namespace mlir::memref;
using namespace mlir::scf;
using namespace circt::loopschedule;
struct AffineToLoopSchedule
    : public circt::impl::AffineToLoopScheduleBase<AffineToLoopSchedule> {
  void runOnOperation() override;

  LogicalResult
  lowerAffineStructures(MemoryDependenceAnalysis &dependenceAnalysis);
  LogicalResult populateOperatorTypes(SmallVectorImpl<AffineForOp> &loopNest,
                                      ModuloProblem &problem);
  LogicalResult solveSchedulingProblem(SmallVectorImpl<AffineForOp> &loopNest,
                                       ModuloProblem &problem);
  LogicalResult
  createLoopSchedulePipeline(SmallVectorImpl<AffineForOp> &loopNest,
                             ModuloProblem &problem);
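// Body of the getModuloProblem helper (called from runOnOperation below): it
// copies the cyclic scheduling problem computed by CyclicSchedulingAnalysis
// into a ModuloProblem, carrying over operations, linked operator types,
// latencies, resource types, auxiliary dependences, and their distances.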
  for (auto *op : prob.getOperations()) {
    auto opr = prob.getLinkedOperatorType(op);
    if (opr.has_value()) {
      modProb.setLinkedOperatorType(op, opr.value());
      auto latency = prob.getLatency(opr.value());
      if (latency.has_value())
        modProb.setLatency(opr.value(), latency.value());
    }
    auto rsrc = prob.getLinkedResourceTypes(op);
    if (rsrc.has_value())
      modProb.setLinkedResourceTypes(op, rsrc.value());
    modProb.insertOperation(op);
  }

  // Auxiliary dependences (and their distances) must be copied explicitly;
  // def-use dependences are implied by the IR.
  for (auto *op : prob.getOperations()) {
    for (auto dep : prob.getDependences(op)) {
      if (dep.isAuxiliary()) {
        auto depInserted = modProb.insertDependence(dep);
        assert(succeeded(depInserted));
      }
      auto distance = prob.getDistance(dep);
      if (distance.has_value())
        modProb.setDistance(dep, distance.value());
    }
  }
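// Pass entry point: lower affine structures to memref/scf, build a modulo
// scheduling problem for each perfectly nested affine loop, solve it, and
// materialize the schedule as a LoopSchedulePipelineOp.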
void AffineToLoopSchedule::runOnOperation() {
  // Collect memory dependences before the affine structures are lowered away.
  auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();

  if (failed(lowerAffineStructures(dependenceAnalysis)))
    return signalPassFailure();

  // Get the scheduling analysis for the whole function.
  schedulingAnalysis = &getAnalysis<CyclicSchedulingAnalysis>();

  // Schedule each outermost affine loop and construct a pipeline from it.
  auto outerLoops = getOperation().getOps<AffineForOp>();
  for (auto root : llvm::make_early_inc_range(outerLoops)) {
    SmallVector<AffineForOp> nestedLoops;
    getPerfectlyNestedLoops(nestedLoops, root);

    // Only perfect nests of exactly one loop are handled here.
    if (nestedLoops.size() != 1)
      continue;

    ModuloProblem moduloProblem =
        getModuloProblem(schedulingAnalysis->getProblem(nestedLoops.back()));

    if (failed(populateOperatorTypes(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(solveSchedulingProblem(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(createLoopSchedulePipeline(nestedLoops, moduloProblem)))
      return signalPassFailure();
  }
}
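/// Conversion patterns used by lowerAffineStructures below. AffineLoadLowering
/// and AffineStoreLowering expand the affine map of affine.load/affine.store
/// into explicit indices and replace the ops with memref.load/memref.store,
/// forwarding any recorded memory dependences to the new operations via
/// MemoryDependenceAnalysis::replaceOp.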
  LogicalResult
  matchAndRewrite(AffineLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Expand the affine map into explicit index values.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto resultOperands =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!resultOperands)
      return failure();

    // Build memref.load memref[expandedMap.results].
    auto memrefLoad = rewriter.replaceOpWithNewOp<memref::LoadOp>(
        op, op.getMemRef(), *resultOperands);
    dependenceAnalysis.replaceOp(op, memrefLoad);
    return success();
  }
  LogicalResult
  matchAndRewrite(AffineStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Expand the affine map into explicit index values.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!maybeExpandedMap)
      return failure();

    // Build memref.store value, memref[expandedMap.results].
    auto memrefStore = rewriter.replaceOpWithNewOp<memref::StoreOp>(
        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
    dependenceAnalysis.replaceOp(op, memrefStore);
    return success();
  }
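/// The IfOpHoisting pattern hoists computation out of scf.if branches so that
/// only the yielded values remain; the resulting mux-like scf.if can then be
/// treated as a combinational op during scheduling.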
  LogicalResult
  matchAndRewrite(IfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.modifyOpInPlace(op, [&]() {
      if (!op.thenBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.thenBlock(), --op.thenBlock()->end());
        rewriter.inlineBlockBefore(&op.getThenRegion().front(), op);
      }
      if (op.elseBlock() && !op.elseBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.elseBlock(), --op.elseBlock()->end());
        rewriter.inlineBlockBefore(&op.getElseRegion().front(), op);
      }
    });
    return success();
  }
/// Helper to determine if an scf::IfOp is in mux-like form.
static bool ifOpLegalityCallback(IfOp op) {
  return op.thenBlock()->without_terminator().empty() &&
         (!op.elseBlock() || op.elseBlock()->without_terminator().empty());
}

/// Helper to mark AffineYieldOp legal, unless it is inside a partially
/// converted scf::IfOp.
static bool yieldOpLegalityCallback(AffineYieldOp op) {
  return !op->getParentOfType<IfOp>();
}
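/// Lower affine.if/load/store within the current function (and hoist scf.if
/// bodies into mux-like form) so that scheduling and pipeline creation only
/// see memref, arith, and mux-like scf operations.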
LogicalResult AffineToLoopSchedule::lowerAffineStructures(
    MemoryDependenceAnalysis &dependenceAnalysis) {
  auto *context = &getContext();
  auto op = getOperation();

  ConversionTarget target(*context);
  target.addLegalDialect<AffineDialect, ArithDialect, MemRefDialect,
                         SCFDialect>();
  target.addIllegalOp<AffineIfOp, AffineLoadOp, AffineStoreOp>();
  target.addDynamicallyLegalOp<IfOp>(ifOpLegalityCallback);
  target.addDynamicallyLegalOp<AffineYieldOp>(yieldOpLegalityCallback);

  RewritePatternSet patterns(context);
  populateAffineToStdConversionPatterns(patterns);
  patterns.add<AffineLoadLowering>(context, dependenceAnalysis);
  patterns.add<AffineStoreLowering>(context, dependenceAnalysis);
  patterns.add<IfOpHoisting>(context);

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return failure();

  return success();
}
LogicalResult AffineToLoopSchedule::populateOperatorTypes(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  // Link every operation in the loop body to an operator type; memory
  // accesses additionally get a per-memref resource type so accesses to the
  // same memref can be serialized by the scheduler.
  Operation *unsupported;
  WalkResult result = forOp.getBody()->walk([&](Operation *op) {
    return TypeSwitch<Operation *, WalkResult>(op)
        .Case<AddIOp, IfOp, AffineYieldOp, arith::ConstantOp, CmpIOp,
              IndexCastOp, memref::AllocaOp, YieldOp>([&](Operation *combOp) {
          // Known combinational ops, linked to a zero-latency operator type.
          return WalkResult::advance();
        })
        .Case<AddIOp, CmpIOp>([&](Operation *seqOp) {
          // Single-cycle sequential ops. Note that AddIOp and CmpIOp already
          // match the combinational case above, so this case is effectively
          // unreachable.
          return WalkResult::advance();
        })
        .Case<AffineStoreOp, memref::StoreOp>([&](Operation *memOp) {
          // Stores share a resource type derived from the accessed memref.
          Value memRef = isa<AffineStoreOp>(*memOp)
                             ? cast<AffineStoreOp>(*memOp).getMemRef()
                             : cast<memref::StoreOp>(*memOp).getMemRef();
          auto memRsrc = problem.getOrInsertResourceType(
              "mem_" + std::to_string(hash_value(memRef)) + "_rsrc");
          problem.setLinkedResourceTypes(
              memOp, SmallVector<Problem::ResourceType>{memRsrc});
          return WalkResult::advance();
        })
        .Case<AffineLoadOp, memref::LoadOp>([&](Operation *memOp) {
          // Loads share the same per-memref resource type as stores.
          Value memRef = isa<AffineLoadOp>(*memOp)
                             ? cast<AffineLoadOp>(*memOp).getMemRef()
                             : cast<memref::LoadOp>(*memOp).getMemRef();
          auto memRsrc = problem.getOrInsertResourceType(
              "mem_" + std::to_string(hash_value(memRef)) + "_rsrc");
          problem.setLinkedResourceTypes(
              memOp, SmallVector<Problem::ResourceType>{memRsrc});
          return WalkResult::advance();
        })
        .Case<MulIOp>([&](Operation *mcOp) {
          // Multiplications are treated as multi-cycle ops.
          return WalkResult::advance();
        })
        .Default([&](Operation *badOp) {
          unsupported = op;
          return WalkResult::interrupt();
        });
  });

  if (result.wasInterrupted())
    return forOp.emitError("unsupported operation ") << *unsupported;

  return success();
}
LogicalResult AffineToLoopSchedule::solveSchedulingProblem(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  // Dump the scheduling problem's inputs for debugging.
  LLVM_DEBUG(forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    llvm::dbgs() << "Scheduling inputs for " << *op;
    auto opr = problem.getLinkedOperatorType(op);
    llvm::dbgs() << "\n  opr = " << opr->getAttr();
    llvm::dbgs() << "\n  latency = " << problem.getLatency(*opr);
    for (auto dep : problem.getDependences(op))
      if (dep.isAuxiliary())
        llvm::dbgs() << "\n  dep = { distance = " << problem.getDistance(dep)
                     << ", source = " << *dep.getSource() << " }";
    llvm::dbgs() << "\n\n";
  }));

  if (failed(problem.check()))
    return failure();

  // Solve the problem, anchored at the loop terminator.
  auto *anchor = forOp.getBody()->getTerminator();
  if (failed(scheduleSimplex(problem, anchor)))
    return failure();

  if (failed(problem.verify()))
    return failure();

  // Dump the scheduling results for debugging.
  LLVM_DEBUG({
    llvm::dbgs() << "Scheduled initiation interval = "
                 << *problem.getInitiationInterval() << "\n\n";
    forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
      llvm::dbgs() << "Scheduling outputs for " << *op;
      llvm::dbgs() << "\n  start = " << problem.getStartTime(op);
      llvm::dbgs() << "\n\n";
    });
  });

  return success();
}
LogicalResult AffineToLoopSchedule::createLoopSchedulePipeline(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  auto outerLoop = loopNest.front();
  auto innerLoop = loopNest.back();
  ImplicitLocOpBuilder builder(outerLoop.getLoc(), outerLoop);

  // Create Values for the loop's lower and upper bounds and its step.
  Value lowerBound = lowerAffineLowerBound(innerLoop, builder);
  Value upperBound = lowerAffineUpperBound(innerLoop, builder);
  int64_t stepValue = innerLoop.getStep().getSExtValue();
  auto step = arith::ConstantOp::create(
      builder, IntegerAttr::get(builder.getIndexType(), stepValue));

  // The pipeline's initiation interval comes from the solved problem.
  auto ii = builder.getI64IntegerAttr(*problem.getInitiationInterval());
  TypeRange resultTypes = innerLoop.getResultTypes();

  // The iteration arguments are the induction variable followed by the
  // loop-carried values.
  SmallVector<Value> iterArgs;
  iterArgs.push_back(lowerBound);
  iterArgs.append(innerLoop.getInits().begin(), innerLoop.getInits().end());

  // If possible, attach a constant trip count.
  std::optional<IntegerAttr> tripCountAttr;
  if (auto tripCount = getConstantTripCount(forOp))
    tripCountAttr = builder.getI64IntegerAttr(*tripCount);

  auto pipeline = LoopSchedulePipelineOp::create(builder, resultTypes, ii,
                                                 tripCountAttr, iterArgs);

  // The condition block compares the induction variable to the upper bound.
  Block &condBlock = pipeline.getCondBlock();
  builder.setInsertionPointToStart(&condBlock);
  auto cmpResult = arith::CmpIOp::create(builder, builder.getI1Type(),
                                         arith::CmpIPredicate::ult,
                                         condBlock.getArgument(0), upperBound);
  condBlock.getTerminator()->insertOperands(0, {cmpResult});
  // Group operations by their scheduled start time.
  DenseMap<unsigned, SmallVector<Operation *>> startGroups;
  for (auto *op : problem.getOperations()) {
    if (isa<AffineYieldOp, YieldOp>(op))
      continue;
    auto startTime = problem.getStartTime(op);
    startGroups[*startTime].push_back(op);
  }

  // Map the loop body's block arguments to the pipeline's stages block
  // arguments.
  IRMapping valueMap;
  assert(iterArgs.size() == forOp.getBody()->getNumArguments());
  for (size_t i = 0; i < iterArgs.size(); ++i)
    valueMap.map(forOp.getBody()->getArgument(i),
                 pipeline.getStagesBlock().getArgument(i));

  // Create the stages.
  Block &stagesBlock = pipeline.getStagesBlock();
  builder.setInsertionPointToStart(&stagesBlock);

  // Iterate over the start times in order.
  SmallVector<unsigned> startTimes;
  for (const auto &group : startGroups)
    startTimes.push_back(group.first);
  llvm::sort(startTimes);

  DominanceInfo dom(getOperation());

  // Values that must be pipelined between stages, and their types.
  SmallVector<SmallVector<Value>> registerValues;
  SmallVector<SmallVector<Type>> registerTypes;

  // One value map per stage, seeded from the initial mapping.
  SmallVector<IRMapping> stageValueMaps;

  // For each op, the start time and the last time its results are used.
  DenseMap<Operation *, std::pair<unsigned, unsigned>> pipeTimes;
  // Determine, for each group of operations, which results must be registered
  // between stages and for how long.
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];

    // Check whether an op is the terminator of the scheduled loop.
    auto isLoopTerminator = [forOp](Operation *op) {
      return isa<AffineYieldOp>(op) && op->getParentOp() == forOp;
    };

    // Ensure registerValues has an entry for this start time.
    for (unsigned i = registerValues.size(); i <= startTime; ++i)
      registerValues.emplace_back(SmallVector<Value>());

    for (auto *op : group) {
      if (op->getUsers().empty())
        continue;

      // Find the last stage in which this op's results are used.
      unsigned pipeEndTime = 0;
      for (auto *user : op->getUsers()) {
        unsigned userStartTime = *problem.getStartTime(user);
        if (userStartTime > startTime)
          pipeEndTime = std::max(pipeEndTime, userStartTime);
        else if (isLoopTerminator(user))
          // Values consumed by the loop terminator must survive one stage
          // past the terminator's start time.
          pipeEndTime = std::max(pipeEndTime, userStartTime + 1);
      }

      // Record the pipeline timing of this op.
      pipeTimes[op] = std::pair(startTime, pipeEndTime);

      // Ensure registerValues is large enough.
      for (unsigned i = registerValues.size(); i <= pipeEndTime; ++i)
        registerValues.push_back(SmallVector<Value>());

      // Register the results in the defining stage...
      for (auto result : op->getResults())
        registerValues[startTime].push_back(result);

      // ...and in every later stage up to the last use, starting once the
      // result becomes available.
      unsigned firstUse = std::max(
          startTime + 1,
          startTime + *problem.getLatency(*problem.getLinkedOperatorType(op)));
      for (unsigned i = firstUse; i < pipeEndTime; ++i) {
        for (auto result : op->getResults())
          registerValues[i].push_back(result);
      }
    }
  }

  // Collect the types of the registered values and seed one value map per
  // stage with the initial mapping.
  for (unsigned i = 0; i < registerValues.size(); ++i) {
    SmallVector<mlir::Type> types;
    for (auto val : registerValues[i])
      types.push_back(val.getType());
    registerTypes.push_back(types);
    stageValueMaps.push_back(valueMap);
  }

  // One extra map for values whose destination lands past the last stage.
  stageValueMaps.push_back(valueMap);
  // Create the pipeline stages, in start-time order.
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];
    llvm::sort(group, [&](Operation *a, Operation *b) {
      return dom.properlyDominates(a, b);
    });
    auto stageTypes = registerTypes[startTime];
    // The first stage additionally registers the incremented induction
    // variable.
    if (startTime == 0)
      stageTypes.push_back(lowerBound.getType());

    // Create the stage itself.
    builder.setInsertionPoint(stagesBlock.getTerminator());
    auto startTimeAttr = builder.getIntegerAttr(
        builder.getIntegerType(64, /*isSigned=*/true), startTime);
    auto stage =
        LoopSchedulePipelineStageOp::create(builder, stageTypes, startTimeAttr);
    auto &stageBlock = stage.getBodyBlock();
    auto *stageTerminator = stageBlock.getTerminator();
    builder.setInsertionPointToStart(&stageBlock);

    // Clone the group's operations into the stage, remapping operands through
    // this stage's value map.
    for (auto *op : group) {
      auto *newOp = builder.clone(*op, stageValueMaps[startTime]);

      // Map the original op's results to the cloned op's results.
      for (auto result : op->getResults())
        stageValueMaps[startTime].map(
            result, newOp->getResult(result.getResultNumber()));
    }

    // Pass the registered values to the stage terminator and map each one
    // into the value map of the stage in which it becomes available.
    SmallVector<Value> stageOperands;
    unsigned resIndex = 0;
    for (auto res : registerValues[startTime]) {
      stageOperands.push_back(stageValueMaps[startTime].lookup(res));
      // A result defined in this stage is available startTime + latency
      // cycles later; otherwise it is forwarded to the next stage.
      unsigned destTime = startTime + 1;
      unsigned latency = *problem.getLatency(
          *problem.getLinkedOperatorType(res.getDefiningOp()));
      if (*problem.getStartTime(res.getDefiningOp()) == startTime &&
          latency > 0)
        destTime = startTime + latency;
      destTime = std::min((unsigned)(stageValueMaps.size() - 1), destTime);
      stageValueMaps[destTime].map(res, stage.getResult(resIndex++));
    }
    stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                    stageOperands);

    // The first stage also computes the incremented induction variable.
    if (startTime == 0) {
      auto incResult =
          arith::AddIOp::create(builder, stagesBlock.getArgument(0), step);
      stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                      incResult->getResults());
    }
  }
  // Wire up the stages terminator. The first iter arg is the incremented
  // induction variable produced by the first stage.
  auto stagesTerminator =
      cast<LoopScheduleTerminatorOp>(stagesBlock.getTerminator());

  SmallVector<Value> termIterArgs;
  SmallVector<Value> termResults;
  termIterArgs.push_back(
      stagesBlock.front().getResult(stagesBlock.front().getNumResults() - 1));

  // Loop-carried values and results are looked up in the value map of the
  // stage in which they were last registered.
  for (auto value : forOp.getBody()->getTerminator()->getOperands()) {
    unsigned lookupTime = std::min((unsigned)(stageValueMaps.size() - 1),
                                   pipeTimes[value.getDefiningOp()].second);
    termIterArgs.push_back(stageValueMaps[lookupTime].lookup(value));
    termResults.push_back(stageValueMaps[lookupTime].lookup(value));
  }

  stagesTerminator.getIterArgsMutable().append(termIterArgs);
  stagesTerminator.getResultsMutable().append(termResults);

  // Replace the loop's results with the pipeline's results.
  for (size_t i = 0; i < forOp.getNumResults(); ++i)
    forOp.getResult(i).replaceAllUsesWith(pipeline.getResult(i));

  // Remove the original loop nest from the IR.
  loopNest.front().walk([](Operation *op) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  });

  return success();
}

std::unique_ptr<mlir::Pass> circt::createAffineToLoopSchedule() {
  return std::make_unique<AffineToLoopSchedule>();
}
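// Usage sketch (an assumption, not part of this file): the factory above can
// be added to a pass pipeline, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(circt::createAffineToLoopSchedule());
//   if (failed(pm.run(module)))
//     /* handle the error */;
//
// Whether the pass nests on func.func or runs at module level depends on how
// it is declared in circt/Conversion/Passes.td.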