#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "affine-to-loopschedule"

using namespace circt;
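
// This pass converts affine loop nests into LoopSchedule pipelines: affine
// memory ops and affine.if are lowered first, then each perfectly nested
// loop is modulo-scheduled and rewritten as a LoopSchedulePipelineOp whose
// stages follow the computed start times.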
struct AffineToLoopSchedule
    : public AffineToLoopScheduleBase<AffineToLoopSchedule> {
  void runOnOperation() override;
  LogicalResult populateOperatorTypes(SmallVectorImpl<AffineForOp> &loopNest,
                                      ModuloProblem &problem);
  LogicalResult solveSchedulingProblem(SmallVectorImpl<AffineForOp> &loopNest,
                                       ModuloProblem &problem);
  LogicalResult
  createLoopSchedulePipeline(SmallVectorImpl<AffineForOp> &loopNest,
                             ModuloProblem &problem);
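
// The fragment below is from the helper that converts the cyclic scheduling
// problem into a ModuloProblem: each operation's linked operator type and
// latency are copied over, the operation is re-inserted, and auxiliary
// dependences are carried across together with their distances.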
    if (opr.has_value()) {
      modProb.setLinkedOperatorType(op, opr.value());
      if (latency.has_value())
        modProb.setLatency(opr.value(), latency.value());
    }
    modProb.insertOperation(op);

    if (dep.isAuxiliary()) {
      auto depInserted = modProb.insertDependence(dep);
      assert(succeeded(depInserted));
    }
    if (distance.has_value())
      modProb.setDistance(dep, distance.value());
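
// Driver: lower affine control structures, run the scheduling analysis, and
// then schedule and rewrite each perfectly nested affine loop in turn.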
void AffineToLoopSchedule::runOnOperation() {
  auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();

  if (failed(lowerAffineStructures(dependenceAnalysis)))
    return signalPassFailure();

  schedulingAnalysis = &getAnalysis<CyclicSchedulingAnalysis>();

  auto outerLoops = getOperation().getOps<AffineForOp>();
  for (auto root : llvm::make_early_inc_range(outerLoops)) {
    SmallVector<AffineForOp> nestedLoops;
    getPerfectlyNestedLoops(nestedLoops, root);

    if (nestedLoops.size() != 1)
      continue;

    ModuloProblem moduloProblem =
        getModuloProblem(schedulingAnalysis->getProblem(nestedLoops.back()));

    if (failed(populateOperatorTypes(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(solveSchedulingProblem(nestedLoops, moduloProblem)))
      return signalPassFailure();

    if (failed(createLoopSchedulePipeline(nestedLoops, moduloProblem)))
      return signalPassFailure();
  }
}
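
// Conversion pattern (shown in part): lowers affine.load by expanding its
// affine map into explicit index computations and replacing it with
// memref.load, keeping the memory dependence analysis up to date via
// replaceOp.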
  LogicalResult matchAndRewrite(AffineLoadOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto resultOperands =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!resultOperands)
      return failure();

    auto memrefLoad = rewriter.replaceOpWithNewOp<memref::LoadOp>(
        op, op.getMemRef(), *resultOperands);

    dependenceAnalysis.replaceOp(op, memrefLoad);

    return success();
  }
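
// Conversion pattern (shown in part): the affine.store counterpart, lowering
// to memref.store with the expanded affine map and again updating the
// dependence analysis to point at the new operation.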
  LogicalResult matchAndRewrite(AffineStoreOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!maybeExpandedMap)
      return failure();

    auto memrefStore = rewriter.replaceOpWithNewOp<memref::StoreOp>(
        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);

    dependenceAnalysis.replaceOp(op, memrefStore);

    return success();
  }
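
// Pattern (shown in part): hoists computation out of scf.if branches so the
// branches only yield values, leaving a mux-like op. The two helpers below
// mark an scf.if legal once it is in that form and keep affine.yield legal
// unless it sits inside a partially converted scf.if.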
  LogicalResult matchAndRewrite(IfOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    rewriter.updateRootInPlace(op, [&]() {
      if (!op.thenBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.thenBlock(), --op.thenBlock()->end());
        rewriter.inlineBlockBefore(&op.getThenRegion().front(), op);
      }
      if (op.elseBlock() && !op.elseBlock()->without_terminator().empty()) {
        rewriter.splitBlock(op.elseBlock(), --op.elseBlock()->end());
        rewriter.inlineBlockBefore(&op.getElseRegion().front(), op);
      }
    });
    return success();
  }

// Helper to determine if an scf::IfOp is in mux-like form.
static bool ifOpLegalityCallback(IfOp op) {
  return op.thenBlock()->without_terminator().empty() &&
         (!op.elseBlock() || op.elseBlock()->without_terminator().empty());
}

// Helper to mark AffineYieldOp legal, unless it is inside a partially
// converted scf::IfOp.
static bool yieldOpLegalityCallback(AffineYieldOp op) {
  return !op->getParentOfType<IfOp>();
}
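
// Lowers affine.if, affine.load, and affine.store (plus the standard affine
// lowering patterns) via partial dialect conversion, so that scheduling and
// pipeline construction only have to deal with memref/scf/arith operations.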
LogicalResult AffineToLoopSchedule::lowerAffineStructures(
    MemoryDependenceAnalysis &dependenceAnalysis) {
  auto *context = &getContext();
  auto op = getOperation();

  ConversionTarget target(*context);
  target.addLegalDialect<AffineDialect, ArithDialect, MemRefDialect,
                         SCFDialect>();
  target.addIllegalOp<AffineIfOp, AffineLoadOp, AffineStoreOp>();

  RewritePatternSet patterns(context);
  populateAffineToStdConversionPatterns(patterns);

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return failure();

  return success();
}
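
// Maps each operation in the loop body to an operator type on the modulo
// problem. The elided case bodies presumably link combinational ops to a
// zero-latency operator type, give multi-cycle ops such as MulIOp a latency,
// and key memory accesses to per-memref operator types; any unrecognized op
// interrupts the walk and is reported below as unsupported.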
LogicalResult AffineToLoopSchedule::populateOperatorTypes(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  Operation *unsupported;
  WalkResult result = forOp.getBody()->walk([&](Operation *op) {
    return TypeSwitch<Operation *, WalkResult>(op)
        .Case<AddIOp, IfOp, AffineYieldOp, arith::ConstantOp, CmpIOp,
              IndexCastOp, memref::AllocaOp, YieldOp>([&](Operation *combOp) {
          return WalkResult::advance();
        })
        .Case<AddIOp, CmpIOp>([&](Operation *seqOp) {
          return WalkResult::advance();
        })
        .Case<AffineStoreOp, memref::StoreOp>([&](Operation *memOp) {
          Value memRef = isa<AffineStoreOp>(*memOp)
                             ? cast<AffineStoreOp>(*memOp).getMemRef()
                             : cast<memref::StoreOp>(*memOp).getMemRef();
          return WalkResult::advance();
        })
        .Case<AffineLoadOp, memref::LoadOp>([&](Operation *memOp) {
          Value memRef = isa<AffineLoadOp>(*memOp)
                             ? cast<AffineLoadOp>(*memOp).getMemRef()
                             : cast<memref::LoadOp>(*memOp).getMemRef();
          return WalkResult::advance();
        })
        .Case<MulIOp>([&](Operation *mcOp) {
          return WalkResult::advance();
        })
        .Default([&](Operation *badOp) {
          unsupported = badOp;
          return WalkResult::interrupt();
        });
  });

  if (result.wasInterrupted())
    return forOp.emitError("unsupported operation ") << *unsupported;

  return success();
}
LogicalResult AffineToLoopSchedule::solveSchedulingProblem(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  LLVM_DEBUG(forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    llvm::dbgs() << "Scheduling inputs for " << *op;
    auto opr = problem.getLinkedOperatorType(op);
    llvm::dbgs() << "\n opr = " << opr;
    llvm::dbgs() << "\n latency = " << problem.getLatency(*opr);
    for (auto dep : problem.getDependences(op))
      if (dep.isAuxiliary())
        llvm::dbgs() << "\n dep = { distance = " << problem.getDistance(dep)
                     << ", source = " << *dep.getSource() << " }";
    llvm::dbgs() << "\n\n";
  }));

  if (failed(problem.check()))
    return failure();

  auto *anchor = forOp.getBody()->getTerminator();

  if (failed(problem.verify()))
    return failure();

  LLVM_DEBUG({
    llvm::dbgs() << "Scheduled initiation interval = "
                 << *problem.getInitiationInterval() << "\n\n";
    forOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
      llvm::dbgs() << "Scheduling outputs for " << *op;
    });
  });

  return success();
}
LogicalResult AffineToLoopSchedule::createLoopSchedulePipeline(
    SmallVectorImpl<AffineForOp> &loopNest, ModuloProblem &problem) {
  auto forOp = loopNest.back();

  auto outerLoop = loopNest.front();
  auto innerLoop = loopNest.back();
  ImplicitLocOpBuilder builder(outerLoop.getLoc(), outerLoop);

  Value lowerBound = lowerAffineLowerBound(innerLoop, builder);
  Value upperBound = lowerAffineUpperBound(innerLoop, builder);
  int64_t stepValue = innerLoop.getStep();
  auto step = builder.create<arith::ConstantOp>(
      IntegerAttr::get(builder.getIndexType(), stepValue));
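
// The pipeline op takes the loop-carried values as iter args, with the
// induction variable prepended, and records the initiation interval (the ii
// value) and the constant trip count (when known) as attributes.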
  TypeRange resultTypes = innerLoop.getResultTypes();

  SmallVector<Value> iterArgs;
  iterArgs.push_back(lowerBound);
  iterArgs.append(innerLoop.getIterOperands().begin(),
                  innerLoop.getIterOperands().end());

  std::optional<IntegerAttr> tripCountAttr;
  if (auto tripCount = getConstantTripCount(forOp))
    tripCountAttr = builder.getI64IntegerAttr(*tripCount);

  auto pipeline = builder.create<LoopSchedulePipelineOp>(
      resultTypes, ii, tripCountAttr, iterArgs);
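
// Condition block: keep iterating while the counter (block argument 0) is
// unsigned-less-than the upper bound.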
  Block &condBlock = pipeline.getCondBlock();
  builder.setInsertionPointToStart(&condBlock);
  auto cmpResult = builder.create<arith::CmpIOp>(
      builder.getI1Type(), arith::CmpIPredicate::ult, condBlock.getArgument(0),
      upperBound);
  condBlock.getTerminator()->insertOperands(0, {cmpResult});
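
// Group the scheduled operations by start time (terminators are skipped) and
// map the original loop-body arguments onto the pipeline's stages block
// arguments so cloned ops can find their operands.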
  DenseMap<unsigned, SmallVector<Operation *>> startGroups;
  for (auto *op : problem.getOperations()) {
    if (isa<AffineYieldOp, YieldOp>(op))
      continue;
    auto startTime = problem.getStartTime(op);
    startGroups[*startTime].push_back(op);
  }

  IRMapping valueMap;
  assert(iterArgs.size() == forOp.getBody()->getNumArguments());
  for (size_t i = 0; i < iterArgs.size(); ++i)
    valueMap.map(forOp.getBody()->getArgument(i),
                 pipeline.getStagesBlock().getArgument(i));
  Block &stagesBlock = pipeline.getStagesBlock();
  builder.setInsertionPointToStart(&stagesBlock);

  SmallVector<unsigned> startTimes;
  for (const auto &group : startGroups)
    startTimes.push_back(group.first);
  llvm::sort(startTimes);

  DominanceInfo dom(getOperation());
  SmallVector<SmallVector<Value>> registerValues;
  SmallVector<SmallVector<Type>> registerTypes;

  SmallVector<IRMapping> stageValueMaps;

  DenseMap<Operation *, std::pair<unsigned, unsigned>> pipeTimes;
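
// First pass over the start-time groups: for every op result, work out the
// interval of stages in which it is live (its pipeline end time is the
// latest start time of its users, plus one for uses by the loop terminator)
// and record, per stage, the values that must be registered through it.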
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];

    auto isLoopTerminator = [forOp](Operation *op) {
      return isa<AffineYieldOp>(op) && op->getParentOp() == forOp;
    };

    for (unsigned i = registerValues.size(); i <= startTime; ++i)
      registerValues.emplace_back(SmallVector<Value>());

    for (auto *op : group) {
      if (op->getUsers().empty())
        continue;

      unsigned pipeEndTime = 0;
      for (auto *user : op->getUsers()) {
        unsigned userStartTime = *problem.getStartTime(user);
        if (userStartTime > startTime)
          pipeEndTime = std::max(pipeEndTime, userStartTime);
        else if (isLoopTerminator(user))
          pipeEndTime = std::max(pipeEndTime, userStartTime + 1);
      }

      pipeTimes[op] = std::pair(startTime, pipeEndTime);

      for (unsigned i = registerValues.size(); i <= pipeEndTime; ++i)
        registerValues.push_back(SmallVector<Value>());

      for (auto result : op->getResults())
        registerValues[startTime].push_back(result);

      unsigned firstUse = std::max(
      for (unsigned i = firstUse; i < pipeEndTime; ++i) {
        for (auto result : op->getResults())
          registerValues[i].push_back(result);
      }
    }
  }

  for (unsigned i = 0; i < registerValues.size(); ++i) {
    SmallVector<mlir::Type> types;
    for (auto val : registerValues[i])
      types.push_back(val.getType());

    registerTypes.push_back(types);
    stageValueMaps.push_back(valueMap);
  }

  stageValueMaps.push_back(valueMap);
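
// Second pass: emit one LoopSchedulePipelineStageOp per start time (ops
// inside a stage are ordered by dominance), clone the group's operations
// into it using the stage's value map, and register each live value by
// yielding it from the stage and mapping it into the stage where it becomes
// visible next.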
  for (auto startTime : startTimes) {
    auto group = startGroups[startTime];
    llvm::sort(group,
               [&](Operation *a, Operation *b) { return dom.dominates(a, b); });

    auto stageTypes = registerTypes[startTime];
    stageTypes.push_back(lowerBound.getType());

    builder.setInsertionPoint(stagesBlock.getTerminator());
    auto startTimeAttr = builder.getIntegerAttr(
        builder.getIntegerType(64, /*isSigned=*/true), startTime);
    auto stage =
        builder.create<LoopSchedulePipelineStageOp>(stageTypes, startTimeAttr);
    auto &stageBlock = stage.getBodyBlock();
    auto *stageTerminator = stageBlock.getTerminator();
    builder.setInsertionPointToStart(&stageBlock);

    for (auto *op : group) {
      auto *newOp = builder.clone(*op, stageValueMaps[startTime]);

      for (auto result : op->getResults())
        stageValueMaps[startTime].map(
            result, newOp->getResult(result.getResultNumber()));
    }

    SmallVector<Value> stageOperands;
    unsigned resIndex = 0;
    for (auto res : registerValues[startTime]) {
      stageOperands.push_back(stageValueMaps[startTime].lookup(res));
      unsigned destTime = startTime + 1;
      if (*problem.getStartTime(res.getDefiningOp()) == startTime &&
        destTime = startTime + latency;
      destTime = std::min((unsigned)(stageValueMaps.size() - 1), destTime);
      stageValueMaps[destTime].map(res, stage.getResult(resIndex++));
    }

    stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                    stageOperands);

    if (startTime == 0) {
      auto incResult =
          builder.create<arith::AddIOp>(stagesBlock.getArgument(0), step);
      stageTerminator->insertOperands(stageTerminator->getNumOperands(),
                                      incResult->getResults());
    }
  }
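
// Finally, wire the pipeline terminator: the incremented induction variable
// and the loop-carried values (looked up in the map of their last pipeline
// stage) become iter args and results, uses of the original affine.for
// results are redirected to the pipeline, and the original loop nest is
// dismantled by dropping its uses and references.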
  auto stagesTerminator =
      cast<LoopScheduleTerminatorOp>(stagesBlock.getTerminator());

  SmallVector<Value> termIterArgs;
  SmallVector<Value> termResults;
  termIterArgs.push_back(
      stagesBlock.front().getResult(stagesBlock.front().getNumResults() - 1));

  for (auto value : forOp.getBody()->getTerminator()->getOperands()) {
    unsigned lookupTime = std::min((unsigned)(stageValueMaps.size() - 1),
                                   pipeTimes[value.getDefiningOp()].second);
    termIterArgs.push_back(stageValueMaps[lookupTime].lookup(value));
    termResults.push_back(stageValueMaps[lookupTime].lookup(value));
  }

  stagesTerminator.getIterArgsMutable().append(termIterArgs);
  stagesTerminator.getResultsMutable().append(termResults);

  for (size_t i = 0; i < forOp.getNumResults(); ++i)
    forOp.getResult(i).replaceAllUsesWith(pipeline.getResult(i));

  loopNest.front().walk([](Operation *op) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
  });

  return success();
}

std::unique_ptr<mlir::Pass> circt::createAffineToLoopSchedule() {
  return std::make_unique<AffineToLoopSchedule>();
}