15 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Affine/LoopUtils.h"
21 #include "mlir/Dialect/Affine/Utils.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/IR/BuiltinDialect.h"
28 #include "mlir/IR/Dominance.h"
29 #include "mlir/IR/IRMapping.h"
30 #include "mlir/IR/ImplicitLocOpBuilder.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
39 #define DEBUG_TYPE "affine-to-loopschedule"
42 #define GEN_PASS_DEF_AFFINETOLOOPSCHEDULE
43 #include "circt/Conversion/Passes.h.inc"
47 using namespace mlir::arith;
48 using namespace mlir::memref;
49 using namespace mlir::scf;
52 using namespace circt;
55 using namespace circt::loopschedule;
// Pass that converts affine loop nests into `loopschedule.pipeline` ops via
// modulo scheduling.
// NOTE(review): this listing is a fragmentary extract — interior source lines
// are missing (see the gaps in the embedded original line numbers).
59 struct AffineToLoopSchedule
60     :
public circt::impl::AffineToLoopScheduleBase<AffineToLoopSchedule> {
// Pass entry point; defined below.
61 void runOnOperation()
override;
// Assign an operator type (latency model) to every operation in the nest.
67 LogicalResult populateOperatorTypes(SmallVectorImpl<AffineForOp> &loopNest,
// Run the scheduler over the constructed modulo problem.
69 LogicalResult solveSchedulingProblem(SmallVectorImpl<AffineForOp> &loopNest,
// Materialize the computed schedule as a LoopSchedulePipelineOp.
72 createLoopSchedulePipeline(SmallVectorImpl<AffineForOp> &loopNest,
// Fragment of a helper that copies a cyclic scheduling problem into a
// ModuloProblem (`modProb`): per-operation linked operator types and their
// latencies, then auxiliary dependences together with their distances.
// NOTE(review): the enclosing function header and loop structure are missing
// from this extract — confirm against the full source.
84 if (opr.has_value()) {
85 modProb.setLinkedOperatorType(op, opr.value());
87 if (latency.has_value())
88 modProb.setLatency(opr.value(), latency.value());
// Register the operation with the new problem regardless of operator type.
90 modProb.insertOperation(op);
// Only auxiliary (non-SSA) dependences need to be copied explicitly;
// def-use dependences are implicit in the IR.
95 if (dep.isAuxiliary()) {
96 auto depInserted = modProb.insertDependence(dep);
97 assert(succeeded(depInserted));
// Preserve the iteration distance of loop-carried dependences.
101 if (distance.has_value())
102 modProb.setDistance(dep, distance.value())
// Pass driver: lower affine control structures, then schedule and pipeline
// each perfectly nested affine loop nest.
109 void AffineToLoopSchedule::runOnOperation() {
// Memory dependence analysis must be computed before the affine ops are
// rewritten, so dependences can be carried over to the lowered ops.
111 auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();
// Lower affine.if/load/store to standard dialect equivalents first.
114 if (failed(lowerAffineStructures(dependenceAnalysis)))
115 return signalPassFailure();
118 schedulingAnalysis = &getAnalysis<CyclicSchedulingAnalysis>();
// Iterate over top-level loops; early-inc because loops are replaced.
121 auto outerLoops = getOperation().getOps<AffineForOp>();
122 for (
auto root : llvm::make_early_inc_range(outerLoops)) {
123 SmallVector<AffineForOp> nestedLoops;
124 getPerfectlyNestedLoops(nestedLoops, root);
// Only single (non-nested) loops are handled here.
// NOTE(review): the `continue` for the != 1 case is missing from this
// extract — confirm against the full source.
127 if (nestedLoops.size() != 1)
131 getModuloProblem(schedulingAnalysis->getProblem(nestedLoops.back()));
// Populate operator types, solve, then materialize the pipeline; any
// failure aborts the pass.
134 if (failed(populateOperatorTypes(nestedLoops, moduloProblem)))
135 return signalPassFailure();
138 if (failed(solveSchedulingProblem(nestedLoops, moduloProblem)))
139 return signalPassFailure();
142 if (failed(createLoopSchedulePipeline(nestedLoops, moduloProblem)))
143 return signalPassFailure();
// AffineLoadLowering::matchAndRewrite (fragment): expand the affine map into
// explicit index arithmetic and replace the affine.load with a memref.load.
ConversionPatternRewriter &rewriter)
const override {
162 SmallVector<Value, 8> indices(op.getMapOperands());
// Materialize the affine map as standard arithmetic on the indices.
163 auto resultOperands =
164     expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
169 auto memrefLoad = rewriter.replaceOpWithNewOp<memref::LoadOp>(
170     op, op.getMemRef(), *resultOperands);
// Keep the dependence analysis consistent: dependences recorded on the
// affine.load must now point at the new memref.load.
172 dependenceAnalysis.replaceOp(op, memrefLoad);
// AffineStoreLowering::matchAndRewrite (fragment): expand the affine map and
// replace the affine.store with a memref.store.
ConversionPatternRewriter &rewriter)
const override {
196 SmallVector<Value, 8> indices(op.getMapOperands());
197 auto maybeExpandedMap =
198     expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
// Bail out if the map could not be expanded.
// NOTE(review): the `return failure();` body is missing from this extract.
199 if (!maybeExpandedMap)
203 auto memrefStore = rewriter.replaceOpWithNewOp<memref::StoreOp>(
204     op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
// Transfer recorded dependences from the affine.store to the memref.store.
206 dependenceAnalysis.replaceOp(op, memrefStore);
// IfOpHoisting::matchAndRewrite (fragment): hoist all non-terminator ops out
// of both branches of an scf.if, leaving a mux-like if whose branches only
// yield values.
ConversionPatternRewriter &rewriter)
const override {
224 rewriter.modifyOpInPlace(op, [&]() {
// Split each branch just before its terminator and inline the computation
// ahead of the scf.if itself.
225 if (!op.thenBlock()->without_terminator().empty()) {
226 rewriter.splitBlock(op.thenBlock(), --op.thenBlock()->end());
227 rewriter.inlineBlockBefore(&op.getThenRegion().front(), op);
229 if (op.elseBlock() && !op.elseBlock()->without_terminator().empty()) {
230 rewriter.splitBlock(op.elseBlock(), --op.elseBlock()->end());
231 rewriter.inlineBlockBefore(&op.getElseRegion().front(), op);
// ifOpLegalityCallback (fragment): an scf.if is "legal" (mux-like) once both
// branches contain nothing but their terminators.
241 return op.thenBlock()->without_terminator().empty() &&
242     (!op.elseBlock() || op.elseBlock()->without_terminator().empty());
// yieldOpLegalityCallback (fragment): yields are legal unless still inside a
// partially converted scf.if.
248 return !op->getParentOfType<IfOp>();
// Lower affine.if/load/store (and normalize scf.if into mux-like form) using
// a partial dialect conversion, while keeping the dependence analysis valid.
258 LogicalResult AffineToLoopSchedule::lowerAffineStructures(
260 auto *context = &getContext();
261 auto op = getOperation();
263 ConversionTarget target(*context);
// The affine dialect stays legal overall; only the specific ops below are
// forced through the patterns.
264 target.addLegalDialect<AffineDialect, ArithDialect, MemRefDialect,
266 target.addIllegalOp<AffineIfOp, AffineLoadOp, AffineStoreOp>();
270 RewritePatternSet
patterns(context);
// Reuse upstream affine->std patterns alongside the local ones.
271 populateAffineToStdConversionPatterns(
patterns);
// Partial conversion: ops not covered by the target remain untouched.
276 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
// Walk the loop body and link every operation to an operator type that
// models its timing; unsupported operations abort with an error.
286 LogicalResult AffineToLoopSchedule::populateOperatorTypes(
287     SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
289 auto forOp = loopNest.back();
// Remembers the op that caused the walk to be interrupted, for diagnostics.
301 Operation *unsupported;
302 WalkResult result = forOp.getBody()->walk([&](Operation *op) {
303 return TypeSwitch<Operation *, WalkResult>(op)
// Combinational (zero-latency) operations.
304     .Case<AddIOp, IfOp, AffineYieldOp, arith::ConstantOp, CmpIOp,
305         IndexCastOp, memref::AllocaOp, YieldOp>([&](Operation *combOp) {
308 return WalkResult::advance();
// NOTE(review): this Case is unreachable — TypeSwitch takes the first
// matching Case, and AddIOp/CmpIOp already match the combinational Case
// above. Either the first list or this one is wrong; confirm against the
// full source (the extract may be incomplete).
310     .Case<AddIOp, CmpIOp>([&](Operation *seqOp) {
314 return WalkResult::advance();
// Stores: operator type is keyed per-memref so each memory gets its own
// (limited) resource.
316     .Case<AffineStoreOp, memref::StoreOp>([&](Operation *memOp) {
320 Value memRef = isa<AffineStoreOp>(*memOp)
321     ? cast<AffineStoreOp>(*memOp).getMemRef()
322     : cast<memref::StoreOp>(*memOp).getMemRef();
328 return WalkResult::advance();
// Loads: same per-memref keying as stores.
330     .Case<AffineLoadOp, memref::LoadOp>([&](Operation *memOp) {
334 Value memRef = isa<AffineLoadOp>(*memOp)
335     ? cast<AffineLoadOp>(*memOp).getMemRef()
336     : cast<memref::LoadOp>(*memOp).getMemRef();
342 return WalkResult::advance();
// Multiplies get their own (multi-cycle) operator type.
344     .Case<MulIOp>([&](Operation *mcOp) {
347 return WalkResult::advance();
// Anything else is unsupported: record it and stop the walk.
349     .Default([&](Operation *badOp) {
351 return WalkResult::interrupt();
355 if (result.wasInterrupted())
356 return forOp.emitError(
"unsupported operation ") << *unsupported;
// Check the scheduling problem, run the scheduler, and verify the solution,
// with debug dumps of the scheduler inputs and outputs.
362 LogicalResult AffineToLoopSchedule::solveSchedulingProblem(
363     SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
365 auto forOp = loopNest.back();
// Debug-only dump of per-op scheduler inputs (operator type, latency, and
// auxiliary dependences with distances).
368 LLVM_DEBUG(forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
369 llvm::dbgs() <<
"Scheduling inputs for " << *op;
370 auto opr = problem.getLinkedOperatorType(op);
371 llvm::dbgs() <<
"\n opr = " << opr;
372 llvm::dbgs() <<
"\n latency = " << problem.getLatency(*opr);
373 for (auto dep : problem.getDependences(op))
374 if (dep.isAuxiliary())
375 llvm::dbgs() <<
"\n dep = { distance = " << problem.getDistance(dep)
376 <<
", source = " << *dep.getSource() <<
" }";
377 llvm::dbgs() <<
"\n\n";
// Validate the constructed problem before scheduling.
381 if (failed(problem.
check()))
// The terminator anchors the schedule (last operation).
// NOTE(review): the actual scheduler invocation (presumably
// scheduleSimplex(problem, anchor)) between check() and verify() is missing
// from this extract — confirm against the full source.
384 auto *anchor = forOp.getBody()->getTerminator();
// Validate the computed solution.
389 if (failed(problem.
verify()))
// Debug-only dump of the computed II and per-op start times.
394 llvm::dbgs() <<
"Scheduled initiation interval = "
396 forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
397 llvm::dbgs() <<
"Scheduling outputs for " << *op;
398 llvm::dbgs() <<
"\n start = " << problem.
getStartTime(op);
399 llvm::dbgs() <<
"\n\n";
// Materialize the computed modulo schedule: replace the affine loop nest
// with a loopschedule.pipeline op whose stages correspond to schedule start
// cycles, inserting pipeline registers for values that cross stages.
407 LogicalResult AffineToLoopSchedule::createLoopSchedulePipeline(
408     SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
410 auto forOp = loopNest.back();
412 auto outerLoop = loopNest.front();
413 auto innerLoop = loopNest.back();
// Build new IR immediately before the outermost loop.
414 ImplicitLocOpBuilder builder(outerLoop.getLoc(), outerLoop);
// Lower the affine bounds/step to plain SSA values.
417 Value lowerBound = lowerAffineLowerBound(innerLoop, builder);
418 Value upperBound = lowerAffineUpperBound(innerLoop, builder);
419 int64_t stepValue = innerLoop.getStep().getSExtValue();
420 auto step = builder.create<arith::ConstantOp>(
425 TypeRange resultTypes = innerLoop.getResultTypes();
// iter_args = induction variable followed by the loop's iter operands.
429 SmallVector<Value> iterArgs;
430 iterArgs.push_back(lowerBound);
431 iterArgs.append(innerLoop.getInits().begin(), innerLoop.getInits().end());
// Attach a constant trip count when it is statically known.
435 std::optional<IntegerAttr> tripCountAttr;
436 if (
auto tripCount = getConstantTripCount(forOp))
437 tripCountAttr = builder.getI64IntegerAttr(*tripCount);
439 auto pipeline = builder.create<LoopSchedulePipelineOp>(
440     resultTypes, ii, tripCountAttr, iterArgs);
// Condition block: keep iterating while induction variable < upper bound.
444 Block &condBlock = pipeline.getCondBlock();
445 builder.setInsertionPointToStart(&condBlock);
446 auto cmpResult = builder.create<arith::CmpIOp>(
447     builder.getI1Type(), arith::CmpIPredicate::ult, condBlock.getArgument(0),
449 condBlock.getTerminator()->insertOperands(0, {cmpResult});
// Group body ops by their scheduled start cycle; terminators are skipped.
452 DenseMap<unsigned, SmallVector<Operation *>> startGroups;
454 if (isa<AffineYieldOp, YieldOp>(op))
457 startGroups[*startTime].push_back(op);
// Map the original block arguments onto the pipeline's stages block args.
464 assert(iterArgs.size() == forOp.getBody()->getNumArguments());
465 for (
size_t i = 0; i < iterArgs.size(); ++i)
466 valueMap.map(forOp.getBody()->getArgument(i),
467     pipeline.getStagesBlock().getArgument(i));
470 Block &stagesBlock = pipeline.getStagesBlock();
471 builder.setInsertionPointToStart(&stagesBlock);
// Emit stages in increasing start-time order.
474 SmallVector<unsigned> startTimes;
475 for (
const auto &group : startGroups)
476 startTimes.push_back(group.first);
477 llvm::sort(startTimes);
479 DominanceInfo dom(getOperation());
// registerValues[t] = values that must be live (registered) at cycle t;
// registerTypes mirrors it with their types.
482 SmallVector<SmallVector<Value>> registerValues;
483 SmallVector<SmallVector<Type>> registerTypes;
// One IRMapping per stage, mapping original values to that stage's copies.
486 SmallVector<IRMapping> stageValueMaps;
// pipeTimes[op] = (start cycle, last cycle any user needs op's results).
489 DenseMap<Operation *, std::pair<unsigned, unsigned>> pipeTimes;
491 for (
auto startTime : startTimes) {
492 auto group = startGroups[startTime];
// A use by the loop's own yield keeps the value live one extra cycle.
496 auto isLoopTerminator = [forOp](Operation *op) {
497 return isa<AffineYieldOp>(op) && op->getParentOp() == forOp;
501 for (
unsigned i = registerValues.size(); i <= startTime; ++i)
502 registerValues.emplace_back(SmallVector<Value>());
// Compute, per op, how far down the pipeline its results must be carried.
505 for (
auto *op : group) {
506 if (op->getUsers().empty())
509 unsigned pipeEndTime = 0;
510 for (
auto *user : op->getUsers()) {
513 pipeEndTime = std::max(pipeEndTime, userStartTime);
514 else if (isLoopTerminator(user))
516 pipeEndTime = std::max(pipeEndTime, userStartTime + 1);
520 pipeTimes[op] = std::pair(startTime, pipeEndTime);
523 for (
unsigned i = registerValues.size(); i <= pipeEndTime; ++i)
524 registerValues.push_back(SmallVector<Value>());
// The producing stage always registers its results...
527 for (
auto result : op->getResults())
528 registerValues[startTime].push_back(result);
// ...and intermediate stages re-register them until the last use.
531 unsigned firstUse = std::max(
534 for (
unsigned i = firstUse; i < pipeEndTime; ++i) {
535 for (
auto result : op->getResults())
536 registerValues[i].push_back(result);
// Derive per-stage register types and seed each stage's value map.
542 for (
unsigned i = 0; i < registerValues.size(); ++i) {
543 SmallVector<mlir::Type> types;
544 for (
auto val : registerValues[i])
545 types.push_back(val.getType());
547 registerTypes.push_back(types);
548 stageValueMaps.push_back(valueMap);
552 stageValueMaps.push_back(valueMap);
// Create one pipeline stage per start time, in schedule order.
555 for (
auto startTime : startTimes) {
556 auto group = startGroups[startTime];
// Clone in dominance order so operands are mapped before their users.
558     [&](Operation *a, Operation *b) {
return dom.dominates(a, b); });
559 auto stageTypes = registerTypes[startTime];
// The induction variable is threaded through every stage.
562 stageTypes.push_back(lowerBound.getType());
565 builder.setInsertionPoint(stagesBlock.getTerminator());
566 auto startTimeAttr = builder.getIntegerAttr(
567     builder.getIntegerType(64,
true), startTime);
569 builder.create<LoopSchedulePipelineStageOp>(stageTypes, startTimeAttr);
570 auto &stageBlock = stage.getBodyBlock();
571 auto *stageTerminator = stageBlock.getTerminator();
572 builder.setInsertionPointToStart(&stageBlock);
// Clone this cycle's ops into the stage and record their result mappings.
574 for (
auto *op : group) {
575 auto *newOp = builder.clone(*op, stageValueMaps[startTime]);
579 for (
auto result : op->getResults())
580 stageValueMaps[startTime].map(
581     result, newOp->getResult(result.getResultNumber()));
// Register this stage's live values and map them into the stage in which
// they become available (start + latency for multi-cycle ops).
585 SmallVector<Value> stageOperands;
586 unsigned resIndex = 0;
587 for (
auto res : registerValues[startTime]) {
588 stageOperands.push_back(stageValueMaps[startTime].lookup(res));
591 unsigned destTime = startTime + 1;
595 if (*problem.
getStartTime(res.getDefiningOp()) == startTime &&
597 destTime = startTime + latency;
// Clamp to the last stage map so late results still resolve.
598 destTime = std::min((
unsigned)(stageValueMaps.size() - 1), destTime);
599 stageValueMaps[destTime].map(res, stage.getResult(resIndex++));
602 stageTerminator->insertOperands(stageTerminator->getNumOperands(),
// The first stage also increments the induction variable.
606 if (startTime == 0) {
608 builder.create<arith::AddIOp>(stagesBlock.getArgument(0), step);
609 stageTerminator->insertOperands(stageTerminator->getNumOperands(),
610     incResult->getResults());
// Wire the pipeline terminator: iter_args (starting with the incremented
// induction variable from the first stage) and the loop results.
615 auto stagesTerminator =
616     cast<LoopScheduleTerminatorOp>(stagesBlock.getTerminator());
620 SmallVector<Value> termIterArgs;
621 SmallVector<Value> termResults;
622 termIterArgs.push_back(
623     stagesBlock.front().getResult(stagesBlock.front().getNumResults() - 1));
// Yielded values are looked up in the stage where they were last registered.
625 for (
auto value : forOp.getBody()->getTerminator()->getOperands()) {
626 unsigned lookupTime = std::min((
unsigned)(stageValueMaps.size() - 1),
627     pipeTimes[value.getDefiningOp()].second);
629 termIterArgs.push_back(stageValueMaps[lookupTime].lookup(value));
630 termResults.push_back(stageValueMaps[lookupTime].lookup(value));
633 stagesTerminator.getIterArgsMutable().append(termIterArgs);
634 stagesTerminator.getResultsMutable().append(termResults);
// Redirect users of the old loop's results to the pipeline's results.
637 for (
size_t i = 0; i < forOp.getNumResults(); ++i)
638 forOp.getResult(i).replaceAllUsesWith(pipeline.getResult(i));
// Detach the old loop nest's values/references before erasing it.
641 loopNest.front().walk([](Operation *op) {
643 op->dropAllDefinedValueUses();
644 op->dropAllReferences();
// Factory (fragment of circt::createAffineToLoopSchedule): returns a fresh
// pass instance for pipeline registration.
652 return std::make_unique<AffineToLoopSchedule>();
static bool yieldOpLegalityCallback(AffineYieldOp op)
Helper to mark AffineYieldOp legal, unless it is inside a partially converted scf::IfOp.
static bool ifOpLegalityCallback(IfOp op)
Helper to determine if an scf::IfOp is in mux-like form.
assert(baseType && "element must be base type")
Apply the affine map from an 'affine.load' operation to its operands, and feed the results to a newly created memref.load operation that replaces the original affine.load.
MemoryDependenceAnalysis & dependenceAnalysis
LogicalResult matchAndRewrite(AffineLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AffineLoadLowering(MLIRContext *context, MemoryDependenceAnalysis &dependenceAnalysis)
Apply the affine map from an 'affine.store' operation to its operands, and feed the results to a newly created memref.store operation that replaces the original affine.store.
AffineStoreLowering(MLIRContext *context, MemoryDependenceAnalysis &dependenceAnalysis)
MemoryDependenceAnalysis & dependenceAnalysis
LogicalResult matchAndRewrite(AffineStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class models a cyclic scheduling problem.
std::optional< unsigned > getInitiationInterval()
The initiation interval (II) is the number of time steps between subsequent iterations, i.e. a new iteration is started every II time steps.
std::optional< unsigned > getDistance(Dependence dep)
The distance determines whether a dependence has to be satisfied in the same iteration (distance = 0, or not set), or a number of iterations later (distance > 0).
This class models the modulo scheduling problem as the composition of the cyclic problem and the resource-constrained problem.
virtual LogicalResult verify() override
Return success if the computed solution is valid.
void setLatency(OperatorType opr, unsigned val)
virtual LogicalResult check()
Return success if the constructed scheduling problem is valid.
std::optional< OperatorType > getLinkedOperatorType(Operation *op)
The linked operator type provides the runtime characteristics for op.
std::optional< unsigned > getStartTime(Operation *op)
Return the start time for op, as computed by the scheduler.
OperatorType getOrInsertOperatorType(StringRef name)
Retrieves the operator type identified by the client-specific name.
DependenceRange getDependences(Operation *op)
Return a range object to transparently iterate over op's incoming 1) implicit def-use dependences (backed by the SSA graph) and 2) explicitly added auxiliary dependences.
const OperationSet & getOperations()
Return the set of operations.
void setLinkedOperatorType(Operation *op, OperatorType opr)
mlir::StringAttr OperatorType
Operator types are distinguished by name (chosen by the client).
std::optional< unsigned > getLatency(OperatorType opr)
The latency is the number of cycles opr needs to compute its result.
Operation * getContainingOp()
Return the operation containing this problem, e.g. to emit diagnostics.
void setLimit(OperatorType opr, unsigned val)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult scheduleSimplex(Problem &prob, Operation *lastOp)
Solve the basic problem using linear programming and a handwritten implementation of the simplex algorithm.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createAffineToLoopSchedule()
llvm::hash_code hash_value(const T &e)
Helper to hoist computation out of scf::IfOp branches, turning it into a mux-like operation whose branches only yield values.
LogicalResult matchAndRewrite(IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
CyclicSchedulingAnalysis constructs a CyclicProblem for each AffineForOp by performing a memory dependence analysis of the loop body.
MemoryDependenceAnalysis traverses any AffineForOps in the FuncOp body and checks for affine memory access dependences between their operations, storing the results for later retrieval.