10 #include "../PassDetail.h"
16 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
17 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
18 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
20 #include "mlir/Dialect/Affine/IR/AffineOps.h"
21 #include "mlir/Dialect/Affine/LoopUtils.h"
22 #include "mlir/Dialect/Affine/Utils.h"
23 #include "mlir/Dialect/Arith/IR/Arith.h"
24 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/IR/BuiltinDialect.h"
29 #include "mlir/IR/Dominance.h"
30 #include "mlir/IR/IRMapping.h"
31 #include "mlir/IR/ImplicitLocOpBuilder.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
39 #define DEBUG_TYPE "affine-to-loopschedule"
47 using namespace circt;
// Pass that lowers affine loop nests into statically scheduled LoopSchedule
// pipelines.  NOTE(review): this view is a fragmentary extraction -- source
// lines are missing between the visible statements; do not treat it as
// complete, compilable source.
54 struct AffineToLoopSchedule
55 :
public AffineToLoopScheduleBase<AffineToLoopSchedule> {
// Main pass entry point (defined below).
56 void runOnOperation()
override;
// Assign an operator type (and latency) to every operation in the loop
// nest's scheduling problem; reports an error on unsupported operations.
62 LogicalResult populateOperatorTypes(SmallVectorImpl<AffineForOp> &loopNest,
// Check the problem, run the scheduler, and verify the solution.
64 LogicalResult solveSchedulingProblem(SmallVectorImpl<AffineForOp> &loopNest,
// Materialize the computed schedule as a LoopSchedulePipelineOp.
67 createLoopSchedulePipeline(SmallVectorImpl<AffineForOp> &loopNest,
// Fragment of a helper that copies an existing scheduling problem into a
// ModuloProblem: operator types, latencies, operations, and auxiliary
// dependences (with their iteration distances) are transferred one by one.
79 if (opr.has_value()) {
// Re-link the operation to the same operator type in the modulo problem.
80 modProb.setLinkedOperatorType(op, opr.value());
// Carry over the operator type's latency when one was set.
82 if (latency.has_value())
83 modProb.setLatency(opr.value(), latency.value());
// Register the operation itself with the modulo problem.
85 modProb.insertOperation(op);
// Only auxiliary dependences need explicit insertion; def-use dependences
// are implicit (see Problem::getDependences).
90 if (dep.isAuxiliary()) {
91 auto depInserted = modProb.insertDependence(dep);
// Insertion is expected to succeed; checked in assertion-enabled builds.
92 assert(succeeded(depInserted));
// Preserve the inter-iteration distance when the dependence has one.
96 if (distance.has_value())
97 modProb.setDistance(dep, distance.value())
// Pass driver: lower affine control structures, build a scheduling problem
// per loop, and convert each eligible affine loop nest into a pipeline.
104 void AffineToLoopSchedule::runOnOperation() {
// Capture memory dependence analysis before lowering, so rewritten ops can
// be tracked through dependenceAnalysis.replaceOp (see lowering patterns).
106 auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();
// Lower affine.if/load/store to standard-dialect equivalents first.
109 if (failed(lowerAffineStructures(dependenceAnalysis)))
110 return signalPassFailure();
// Scheduling analysis constructs a CyclicProblem for each AffineForOp.
113 schedulingAnalysis = &getAnalysis<CyclicSchedulingAnalysis>();
// early_inc_range: loops are replaced/erased while iterating over them.
116 auto outerLoops = getOperation().getOps<AffineForOp>();
117 for (
auto root : llvm::make_early_inc_range(outerLoops)) {
118 SmallVector<AffineForOp> nestedLoops;
119 getPerfectlyNestedLoops(nestedLoops, root);
// Only single (non-nested) loops are handled here.
122 if (nestedLoops.size() != 1)
// Convert the per-loop cyclic problem into a modulo problem.
126 getModuloProblem(schedulingAnalysis->getProblem(nestedLoops.back()));
// Populate operator types, solve, then materialize the pipeline; any
// failure aborts the whole pass.
129 if (failed(populateOperatorTypes(nestedLoops, moduloProblem)))
130 return signalPassFailure();
133 if (failed(solveSchedulingProblem(nestedLoops, moduloProblem)))
134 return signalPassFailure();
137 if (failed(createLoopSchedulePipeline(nestedLoops, moduloProblem)))
138 return signalPassFailure();
// Fragment of AffineLoadLowering::matchAndRewrite: expands the load's
// affine map into explicit index arithmetic and replaces the affine.load
// with an equivalent memref.load.
155 ConversionPatternRewriter &rewriter)
const override {
// Expand the affine map applied to the load's map operands.
157 SmallVector<Value, 8> indices(op.getMapOperands());
158 auto resultOperands =
159 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
// Build memref.load memref[expandedMap.results], replacing the original op.
164 auto memrefLoad = rewriter.replaceOpWithNewOp<memref::LoadOp>(
165 op, op.getMemRef(), *resultOperands);
// Keep the memory dependence analysis in sync with the rewritten IR.
167 dependenceAnalysis.replaceOp(op, memrefLoad);
// Fragment of AffineStoreLowering::matchAndRewrite: expands the store's
// affine map into explicit index arithmetic and replaces the affine.store
// with an equivalent memref.store.
189 ConversionPatternRewriter &rewriter)
const override {
191 SmallVector<Value, 8> indices(op.getMapOperands());
192 auto maybeExpandedMap =
193 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
// Bail out of the rewrite if the map could not be expanded.
194 if (!maybeExpandedMap)
// Build memref.store valueToStore, memref[expandedMap.results].
198 auto memrefStore = rewriter.replaceOpWithNewOp<memref::StoreOp>(
199 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
// Keep the memory dependence analysis in sync with the rewritten IR.
201 dependenceAnalysis.replaceOp(op, memrefStore);
// Fragment of IfOpHoisting::matchAndRewrite: hoists computation out of
// scf.if branches by splitting each branch block just before its terminator
// and inlining the hoisted prefix before the scf.if, leaving the op in a
// mux-like form (branches contain only their terminators).
218 ConversionPatternRewriter &rewriter)
const override {
219 rewriter.modifyOpInPlace(op, [&]() {
// Hoist the then-branch body (everything except the terminator).
220 if (!op.thenBlock()->without_terminator().empty()) {
// --end() points at the terminator; split immediately before it.
221 rewriter.splitBlock(op.thenBlock(), --op.thenBlock()->end());
222 rewriter.inlineBlockBefore(&op.getThenRegion().front(), op);
// Same hoisting for the else-branch, when present and non-trivial.
224 if (op.elseBlock() && !op.elseBlock()->without_terminator().empty()) {
225 rewriter.splitBlock(op.elseBlock(), --op.elseBlock()->end());
226 rewriter.inlineBlockBefore(&op.getElseRegion().front(), op);
// Helper fragment: an scf.if is mux-like once both branches contain only
// their terminators.
236 return op.thenBlock()->without_terminator().empty() &&
237 (!op.elseBlock() || op.elseBlock()->without_terminator().empty());
// Helper fragment: yields are legal unless nested inside a (partially
// converted) scf.if.
243 return !op->getParentOfType<IfOp>();
// Lower affine.if/load/store (via the standard affine-to-std patterns plus
// the custom patterns above) so scheduling can operate on memref/arith ops.
253 LogicalResult AffineToLoopSchedule::lowerAffineStructures(
255 auto *context = &getContext();
256 auto op = getOperation();
258 ConversionTarget target(*context);
// The listed dialects stay legal as a whole; only the specific affine ops
// below are forced through the conversion.
259 target.addLegalDialect<AffineDialect, ArithDialect, MemRefDialect,
261 target.addIllegalOp<AffineIfOp, AffineLoadOp, AffineStoreOp>();
265 RewritePatternSet
patterns(context);
266 populateAffineToStdConversionPatterns(
patterns);
// Partial conversion: ops outside the target's scope remain untouched.
271 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
// Map every operation in the loop body to an operator type modeling its
// hardware timing; walking stops (and the pass errors out) on any
// unsupported operation.
281 LogicalResult AffineToLoopSchedule::populateOperatorTypes(
282 SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
284 auto forOp = loopNest.back();
// Remember the offending op so the diagnostic below can print it.
296 Operation *unsupported;
297 WalkResult result = forOp.getBody()->walk([&](Operation *op) {
298 return TypeSwitch<Operation *, WalkResult>(op)
// Combinational operations.
299 .Case<AddIOp, IfOp, AffineYieldOp, arith::ConstantOp, CmpIOp,
300 IndexCastOp, memref::AllocaOp, YieldOp>([&](Operation *combOp) {
303 return WalkResult::advance();
// NOTE(review): this Case appears unreachable -- AddIOp and CmpIOp are
// already matched by the Case above, and TypeSwitch stops at the first
// match.  Confirm which operator type these ops are intended to get.
305 .Case<AddIOp, CmpIOp>([&](Operation *seqOp) {
309 return WalkResult::advance();
// Stores: both affine.store and memref.store target a memref; fetch it
// from whichever form the op takes.
311 .Case<AffineStoreOp, memref::StoreOp>([&](Operation *memOp) {
315 Value memRef = isa<AffineStoreOp>(*memOp)
316 ? cast<AffineStoreOp>(*memOp).getMemRef()
317 : cast<memref::StoreOp>(*memOp).getMemRef();
323 return WalkResult::advance();
// Loads: handled analogously to stores.
325 .Case<AffineLoadOp, memref::LoadOp>([&](Operation *memOp) {
329 Value memRef = isa<AffineLoadOp>(*memOp)
330 ? cast<AffineLoadOp>(*memOp).getMemRef()
331 : cast<memref::LoadOp>(*memOp).getMemRef();
337 return WalkResult::advance();
// Multiplies get their own operator type (latency not visible in this
// extraction -- presumably multi-cycle; confirm against the full source).
339 .Case<MulIOp>([&](Operation *mcOp) {
342 return WalkResult::advance();
// Anything else is unsupported: record it and interrupt the walk.
344 .Default([&](Operation *badOp) {
346 return WalkResult::interrupt();
350 if (result.wasInterrupted())
351 return forOp.emitError(
"unsupported operation ") << *unsupported;
// Check the constructed problem, run the scheduler, and verify the computed
// solution; debug builds dump the scheduling inputs and outputs.
357 LogicalResult AffineToLoopSchedule::solveSchedulingProblem(
358 SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
360 auto forOp = loopNest.back();
// Debug dump of problem inputs: linked operator type, its latency, and all
// auxiliary dependences with their distances.
363 LLVM_DEBUG(forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
364 llvm::dbgs() <<
"Scheduling inputs for " << *op;
365 auto opr = problem.getLinkedOperatorType(op);
366 llvm::dbgs() <<
"\n opr = " << opr;
367 llvm::dbgs() <<
"\n latency = " << problem.getLatency(*opr);
368 for (auto dep : problem.getDependences(op))
369 if (dep.isAuxiliary())
370 llvm::dbgs() <<
"\n dep = { distance = " << problem.getDistance(dep)
371 <<
", source = " << *dep.getSource() <<
" }";
372 llvm::dbgs() <<
"\n\n";
// Reject malformed problems before attempting to schedule.
376 if (failed(problem.
check()))
// The loop terminator anchors the schedule as the last operation.
379 auto *anchor = forOp.getBody()->getTerminator();
// Verify the computed solution satisfies all constraints.
384 if (failed(problem.
verify()))
389 llvm::dbgs() <<
"Scheduled initiation interval = "
// Debug dump of the computed start times per operation.
391 forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
392 llvm::dbgs() <<
"Scheduling outputs for " << *op;
393 llvm::dbgs() <<
"\n start = " << problem.
getStartTime(op);
394 llvm::dbgs() <<
"\n\n";
// Materialize the scheduled loop as a LoopSchedulePipelineOp: build the
// pipeline shell (bounds, iter args, condition), group operations into
// stages by start time, register values that cross stage boundaries, and
// wire up the pipeline terminator.
402 LogicalResult AffineToLoopSchedule::createLoopSchedulePipeline(
403 SmallVectorImpl<AffineForOp> &loopNest,
ModuloProblem &problem) {
405 auto forOp = loopNest.back();
407 auto outerLoop = loopNest.front();
408 auto innerLoop = loopNest.back();
// Insert the new pipeline immediately before the outermost loop.
409 ImplicitLocOpBuilder
builder(outerLoop.getLoc(), outerLoop);
// Lower the affine loop bounds to plain SSA values.
412 Value lowerBound = lowerAffineLowerBound(innerLoop,
builder);
413 Value upperBound = lowerAffineUpperBound(innerLoop,
builder);
414 int64_t stepValue = innerLoop.getStep().getSExtValue();
415 auto step =
builder.create<arith::ConstantOp>(
420 TypeRange resultTypes = innerLoop.getResultTypes();
// Iteration args: the induction variable first, then the loop's inits.
424 SmallVector<Value> iterArgs;
425 iterArgs.push_back(lowerBound);
426 iterArgs.append(innerLoop.getInits().begin(), innerLoop.getInits().end());
// Attach a trip-count attribute when it is statically known.
430 std::optional<IntegerAttr> tripCountAttr;
431 if (
auto tripCount = getConstantTripCount(forOp))
432 tripCountAttr =
builder.getI64IntegerAttr(*tripCount);
434 auto pipeline =
builder.create<LoopSchedulePipelineOp>(
435 resultTypes, ii, tripCountAttr, iterArgs);
// Condition block: keep iterating while iv < upperBound (unsigned cmp).
439 Block &condBlock = pipeline.getCondBlock();
440 builder.setInsertionPointToStart(&condBlock);
441 auto cmpResult =
builder.create<arith::CmpIOp>(
442 builder.getI1Type(), arith::CmpIPredicate::ult, condBlock.getArgument(0),
444 condBlock.getTerminator()->insertOperands(0, {cmpResult});
// Group operations by scheduled start time; terminators are not staged.
447 DenseMap<unsigned, SmallVector<Operation *>> startGroups;
449 if (isa<AffineYieldOp, YieldOp>(op))
452 startGroups[*startTime].push_back(op);
// Map the loop body's block arguments onto the pipeline's stage block
// arguments (same count, asserted below).
459 assert(iterArgs.size() == forOp.getBody()->getNumArguments());
460 for (
size_t i = 0; i < iterArgs.size(); ++i)
461 valueMap.map(forOp.getBody()->getArgument(i),
462 pipeline.getStagesBlock().getArgument(i));
465 Block &stagesBlock = pipeline.getStagesBlock();
466 builder.setInsertionPointToStart(&stagesBlock);
// Emit stages in increasing start-time order.
469 SmallVector<unsigned> startTimes;
470 for (
const auto &group : startGroups)
471 startTimes.push_back(group.first);
472 llvm::sort(startTimes);
474 DominanceInfo dom(getOperation());
// registerValues[i]: values stage i must forward to later stages;
// registerTypes mirrors it with the corresponding types.
477 SmallVector<SmallVector<Value>> registerValues;
478 SmallVector<SmallVector<Type>> registerTypes;
// One IRMapping per stage, each seeded from valueMap below.
481 SmallVector<IRMapping> stageValueMaps;
// For each op: (start time, last stage in which its results are used).
484 DenseMap<Operation *, std::pair<unsigned, unsigned>> pipeTimes;
486 for (
auto startTime : startTimes) {
487 auto group = startGroups[startTime];
// The loop's own yield marks values that live into the next iteration.
491 auto isLoopTerminator = [forOp](Operation *op) {
492 return isa<AffineYieldOp>(op) && op->getParentOp() == forOp;
// Grow the register table up to this start time.
496 for (
unsigned i = registerValues.size(); i <= startTime; ++i)
497 registerValues.emplace_back(SmallVector<Value>());
500 for (
auto *op : group) {
// Unused results never need pipeline registers.
501 if (op->getUsers().empty())
// Find the latest stage that consumes this op's results.
504 unsigned pipeEndTime = 0;
505 for (
auto *user : op->getUsers()) {
508 pipeEndTime = std::max(pipeEndTime, userStartTime);
// A use by the loop terminator must survive one extra stage.
509 else if (isLoopTerminator(user))
511 pipeEndTime = std::max(pipeEndTime, userStartTime + 1);
515 pipeTimes[op] = std::pair(startTime, pipeEndTime);
// Grow the register table up to the last use.
518 for (
unsigned i = registerValues.size(); i <= pipeEndTime; ++i)
519 registerValues.push_back(SmallVector<Value>());
// Register the results in the producing stage...
522 for (
auto result : op->getResults())
523 registerValues[startTime].push_back(result);
// ...and in every intermediate stage up to the final consumer.
526 unsigned firstUse = std::max(
529 for (
unsigned i = firstUse; i < pipeEndTime; ++i) {
530 for (
auto result : op->getResults())
531 registerValues[i].push_back(result);
// Collect the register types and seed each stage's value map.
537 for (
unsigned i = 0; i < registerValues.size(); ++i) {
538 SmallVector<mlir::Type> types;
539 for (
auto val : registerValues[i])
540 types.push_back(val.getType());
542 registerTypes.push_back(types);
543 stageValueMaps.push_back(valueMap);
// One extra map for values escaping the final stage.
547 stageValueMaps.push_back(valueMap);
// Second pass: emit one pipeline stage per start time.
550 for (
auto startTime : startTimes) {
551 auto group = startGroups[startTime];
// Clone in dominance order so operands are mapped before their users.
553 [&](Operation *a, Operation *b) {
return dom.dominates(a, b); });
554 auto stageTypes = registerTypes[startTime];
// Stage 0 additionally yields the incremented induction variable.
557 stageTypes.push_back(lowerBound.getType());
560 builder.setInsertionPoint(stagesBlock.getTerminator());
561 auto startTimeAttr =
builder.getIntegerAttr(
562 builder.getIntegerType(64,
true), startTime);
564 builder.create<LoopSchedulePipelineStageOp>(stageTypes, startTimeAttr);
565 auto &stageBlock = stage.getBodyBlock();
566 auto *stageTerminator = stageBlock.getTerminator();
567 builder.setInsertionPointToStart(&stageBlock);
// Clone each group member into the stage, remapping its operands.
569 for (
auto *op : group) {
570 auto *newOp =
builder.clone(*op, stageValueMaps[startTime]);
// Record the cloned results in this stage's value map.
574 for (
auto result : op->getResults())
575 stageValueMaps[startTime].map(
576 result, newOp->getResult(result.getResultNumber()));
// Yield registered values out of the stage and map each one into the
// stage where it is next consumed.
580 SmallVector<Value> stageOperands;
581 unsigned resIndex = 0;
582 for (
auto res : registerValues[startTime]) {
583 stageOperands.push_back(stageValueMaps[startTime].lookup(res));
// By default the value lands in the immediately following stage.
586 unsigned destTime = startTime + 1;
// Multi-cycle producers deliver their result `latency` stages later.
590 if (*problem.
getStartTime(res.getDefiningOp()) == startTime &&
592 destTime = startTime + latency;
// Clamp to the last stage map (the extra one pushed above).
593 destTime = std::min((
unsigned)(stageValueMaps.size() - 1), destTime);
594 stageValueMaps[destTime].map(res, stage.getResult(resIndex++));
597 stageTerminator->insertOperands(stageTerminator->getNumOperands(),
// Stage 0 computes the next induction-variable value.
601 if (startTime == 0) {
603 builder.create<arith::AddIOp>(stagesBlock.getArgument(0), step);
604 stageTerminator->insertOperands(stageTerminator->getNumOperands(),
605 incResult->getResults());
// Wire the pipeline terminator: iter args and results.
610 auto stagesTerminator =
611 cast<LoopScheduleTerminatorOp>(stagesBlock.getTerminator());
// The first iter arg is the incremented induction variable, which is the
// last result of the first stage.
615 SmallVector<Value> termIterArgs;
616 SmallVector<Value> termResults;
617 termIterArgs.push_back(
618 stagesBlock.front().getResult(stagesBlock.front().getNumResults() - 1));
620 for (
auto value : forOp.getBody()->getTerminator()->getOperands()) {
// Look each yielded value up in the map of the stage where it was last
// alive (clamped to the final map).
621 unsigned lookupTime = std::min((
unsigned)(stageValueMaps.size() - 1),
622 pipeTimes[value.getDefiningOp()].second);
624 termIterArgs.push_back(stageValueMaps[lookupTime].lookup(value));
625 termResults.push_back(stageValueMaps[lookupTime].lookup(value));
628 stagesTerminator.getIterArgsMutable().append(termIterArgs);
629 stagesTerminator.getResultsMutable().append(termResults);
// Redirect uses of the original loop's results to the new pipeline.
632 for (
size_t i = 0; i < forOp.getNumResults(); ++i)
633 forOp.getResult(i).replaceAllUsesWith(pipeline.getResult(i));
// Drop uses/references so the now-dead original loop nest can be erased.
636 loopNest.front().walk([](Operation *op) {
638 op->dropAllDefinedValueUses();
639 op->dropAllReferences();
647 return std::make_unique<AffineToLoopSchedule>();
static bool yieldOpLegalityCallback(AffineYieldOp op)
Helper to mark AffineYieldOp legal, unless it is inside a partially converted scf::IfOp.
static bool ifOpLegalityCallback(IfOp op)
Helper to determine if an scf::IfOp is in mux-like form.
assert(baseType && "element must be base type")
Apply the affine map from an 'affine.load' operation to its operands, and feed the results to a newly created memref.load operation (which replaces the original 'affine.load').
MemoryDependenceAnalysis & dependenceAnalysis
LogicalResult matchAndRewrite(AffineLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AffineLoadLowering(MLIRContext *context, MemoryDependenceAnalysis &dependenceAnalysis)
Apply the affine map from an 'affine.store' operation to its operands, and feed the results to a newly created memref.store operation (which replaces the original 'affine.store').
AffineStoreLowering(MLIRContext *context, MemoryDependenceAnalysis &dependenceAnalysis)
MemoryDependenceAnalysis & dependenceAnalysis
LogicalResult matchAndRewrite(AffineStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class models a cyclic scheduling problem.
std::optional< unsigned > getInitiationInterval()
The initiation interval (II) is the number of time steps between subsequent iterations,...
std::optional< unsigned > getDistance(Dependence dep)
The distance determines whether a dependence has to be satisfied in the same iteration (distance=0 or...
This class models the modulo scheduling problem as the composition of the cyclic problem and the reso...
virtual LogicalResult verify() override
Return success if the computed solution is valid.
void setLatency(OperatorType opr, unsigned val)
virtual LogicalResult check()
Return success if the constructed scheduling problem is valid.
std::optional< OperatorType > getLinkedOperatorType(Operation *op)
The linked operator type provides the runtime characteristics for op.
std::optional< unsigned > getStartTime(Operation *op)
Return the start time for op, as computed by the scheduler.
OperatorType getOrInsertOperatorType(StringRef name)
Retrieves the operator type identified by the client-specific name.
DependenceRange getDependences(Operation *op)
Return a range object to transparently iterate over op's incoming 1) implicit def-use dependences (ba...
const OperationSet & getOperations()
Return the set of operations.
void setLinkedOperatorType(Operation *op, OperatorType opr)
mlir::StringAttr OperatorType
Operator types are distinguished by name (chosen by the client).
std::optional< unsigned > getLatency(OperatorType opr)
The latency is the number of cycles opr needs to compute its result.
Operation * getContainingOp()
Return the operation containing this problem, e.g. to emit diagnostics.
void setLimit(OperatorType opr, unsigned val)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult scheduleSimplex(Problem &prob, Operation *lastOp)
Solve the basic problem using linear programming and a handwritten implementation of the simplex algorithm.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createAffineToLoopSchedule()
llvm::hash_code hash_value(const T &e)
Helper to hoist computation out of scf::IfOp branches, turning it into a mux-like operation,...
LogicalResult matchAndRewrite(IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
CyclicSchedulingAnalysis constructs a CyclicProblem for each AffineForOp by performing a memory depen...
MemoryDependenceAnalysis traverses any AffineForOps in the FuncOp body and checks for affine memory a...