18 #include "mlir/Analysis/CFGLoopInfo.h"
19 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
20 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
21 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/SCF/IR/SCF.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Diagnostics.h"
33 #include "mlir/IR/Dominance.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/Types.h"
37 #include "mlir/IR/Value.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Support/LLVM.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/Passes.h"
42 #include "llvm/ADT/SmallSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/Support/raw_ostream.h"
50 #define GEN_PASS_DEF_CFTOHANDSHAKE
51 #define GEN_PASS_DEF_HANDSHAKEREMOVEBLOCK
52 #include "circt/Conversion/Passes.h.inc"
58 using namespace circt;
67 template <
typename TOp>
68 class LowerOpTarget :
public ConversionTarget {
70 explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
72 addLegalDialect<HandshakeDialect>();
73 addLegalDialect<mlir::func::FuncDialect>();
74 addLegalDialect<mlir::arith::ArithDialect>();
75 addIllegalDialect<mlir::scf::SCFDialect>();
76 addIllegalDialect<AffineDialect>();
81 addDynamicallyLegalOp<TOp>([&](
const auto &op) {
return loweredOps[op]; });
83 DenseMap<Operation *, bool> loweredOps;
100 template <
typename TOp>
101 struct PartialLowerOp :
public ConversionPattern {
102 using PartialLoweringFunc =
103 std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
106 PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
107 LogicalResult &loweringResRef,
const PartialLoweringFunc &fun)
108 : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
109 loweringRes(loweringResRef), fun(fun) {}
110 using ConversionPattern::ConversionPattern;
112 matchAndRewrite(Operation *op, ArrayRef<Value> ,
113 ConversionPatternRewriter &rewriter)
const override {
115 loweringRes = fun(dyn_cast<TOp>(op), rewriter);
116 target.loweredOps[op] =
true;
121 LowerOpTarget<TOp> ⌖
122 LogicalResult &loweringRes;
124 PartialLoweringFunc fun;
130 template <
typename TOp>
132 const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
134 MLIRContext *ctx, TOp op) {
137 auto target = LowerOpTarget<TOp>(*ctx);
138 LogicalResult partialLoweringSuccessfull = success();
139 patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
142 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
143 partialLoweringSuccessfull.succeeded());
149 : ConversionTarget(context), region(region) {
152 markUnknownOpDynamicallyLegal([&](Operation *op) {
153 if (op != region.getParentOp())
158 bool opLowered =
false;
169 std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
173 LogicalResult &loweringResRef,
175 : ConversionPattern(target.region.getParentOp()->
getName().getStringRef(),
177 target(target), loweringRes(loweringResRef), fun(fun) {}
178 using ConversionPattern::ConversionPattern;
181 ConversionPatternRewriter &rewriter)
const override {
182 rewriter.modifyOpInPlace(
183 op, [&] { loweringRes = fun(target.region, rewriter); });
185 target.opLowered =
true;
197 MLIRContext *ctx, Region &r) {
199 Operation *op = r.getParentOp();
202 LogicalResult partialLoweringSuccessfull = success();
206 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
207 partialLoweringSuccessfull.succeeded());
214 Value HandshakeLowering::getBlockEntryControl(Block *block)
const {
215 auto it = blockEntryControlMap.find(block);
216 assert(it != blockEntryControlMap.end() &&
217 "No block entry control value registerred for this block!");
221 void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
222 blockEntryControlMap[block] = v;
226 Block *entryBlock = &r.front();
227 auto &entryBlockOps = entryBlock->getOperations();
230 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
231 entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
234 block.dropAllDefinedValueUses();
235 for (
size_t i = 0; i < block.getNumArguments(); i++) {
236 block.eraseArgument(i);
243 for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
244 if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
247 if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
248 terminatorLike.erase();
253 terminatorLike.moveBefore(entryBlock, entryBlock->end());
258 HandshakeLowering::runSSAMaximization(ConversionPatternRewriter &rewriter,
264 if (funcOp.isExternal())
271 if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
272 return emitError(loc) <<
"memref's must be both statically sized and "
279 auto predecessors = block->getPredecessors();
280 return std::distance(predecessors.begin(), predecessors.end());
286 HandshakeLowering::insertMerge(Block *block, Value val,
288 ConversionPatternRewriter &rewriter) {
290 auto insertLoc = block->front().getLoc();
291 SmallVector<Backedge> dataEdges;
292 SmallVector<Value> operands;
296 if (val == getBlockEntryControl(block)) {
299 if (block == &r.front()) {
307 operands.push_back(val);
308 mergeOp = rewriter.create<handshake::MergeOp>(insertLoc, operands);
310 for (
unsigned i = 0; i < numPredecessors; i++) {
311 auto edge = edgeBuilder.
get(rewriter.getNoneType());
312 dataEdges.push_back(edge);
313 operands.push_back(Value(edge));
315 mergeOp = rewriter.create<handshake::ControlMergeOp>(insertLoc, operands);
317 setBlockEntryControl(block, mergeOp->getResult(0));
326 if (numPredecessors <= 1) {
327 if (numPredecessors == 0) {
331 operands.push_back(val);
335 auto edge = edgeBuilder.
get(val.getType());
336 dataEdges.push_back(edge);
337 operands.push_back(Value(edge));
339 auto merge = rewriter.create<handshake::MergeOp>(insertLoc, operands);
347 Backedge indexEdge = edgeBuilder.
get(rewriter.getIndexType());
348 for (
unsigned i = 0; i < numPredecessors; i++) {
349 auto edge = edgeBuilder.
get(val.getType());
350 dataEdges.push_back(edge);
351 operands.push_back(Value(edge));
354 rewriter.create<handshake::MuxOp>(insertLoc, Value(indexEdge), operands);
355 return MergeOpInfo{mux, val, dataEdges, indexEdge};
361 ConversionPatternRewriter &rewriter) {
363 for (Block &block : r) {
364 rewriter.setInsertionPointToStart(&block);
368 for (
auto &arg : block.getArguments()) {
370 if (isa<mlir::MemRefType>(arg.getType()))
373 auto mergeInfo = insertMerge(&block, arg, edgeBuilder, rewriter);
374 blockMerges[&block].push_back(mergeInfo);
375 mergePairs[arg] = mergeInfo.op->getResult(0);
385 Value srcVal = mergeInfo.
val;
387 Block *block = mergeInfo.
op->getBlock();
392 unsigned index = cast<BlockArgument>(srcVal).getArgNumber();
393 Operation *termOp = predBlock->getTerminator();
394 if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
396 if (block == br.getTrueDest())
397 return br.getTrueOperand(index);
398 assert(block == br.getFalseDest());
399 return br.getFalseOperand(index);
401 if (isa<mlir::cf::BranchOp>(termOp))
402 return termOp->getOperand(index);
409 for (Block &block : f) {
410 if (!block.isEntryBlock()) {
411 int x = block.getNumArguments() - 1;
412 for (
int i = x; i >= 0; --i)
413 block.eraseArgument(i);
420 template <
typename TOp>
422 auto ops = block->getOps<TOp>();
429 return getFirstOp<ControlMergeOp>(block);
433 for (
auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
434 if (cbranch.isControl())
448 for (Block &block : r) {
449 for (
auto &mergeInfo : blockMerges[&block]) {
452 for (
auto *predBlock : block.getPredecessors()) {
454 assert(mgOperand !=
nullptr);
455 if (!mgOperand.getDefiningOp()) {
456 assert(mergePairs.count(mgOperand));
457 mgOperand = mergePairs[mgOperand];
459 mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
465 for (Operation &opp : block)
466 if (!isa<MergeLikeOpInterface>(opp))
467 opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
473 for (Block &block : r) {
476 assert(cntrlMg !=
nullptr);
478 for (
auto &mergeInfo : blockMerges[&block]) {
479 if (mergeInfo.op != cntrlMg) {
483 assert(mergeInfo.indexEdge.has_value());
484 (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
494 return isa<memref::AllocOp, memref::AllocaOp>(op);
498 HandshakeLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
509 BlockOps mergeOps = insertMergeOps(mergePairs, edgeBuilder, rewriter);
519 for (
auto &u : val.getUses())
521 if (isa<MergeLikeOpInterface>(u.getOwner()))
532 for (
int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
534 Block *succ = block->getSuccessor(i);
535 for (
auto &u : val.getUses()) {
536 if (u.getOwner()->getBlock() == succ)
539 uses = (curr > uses) ? curr : uses;
553 class FeedForwardNetworkRewriter {
556 ConversionPatternRewriter &rewriter)
557 : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
558 domInfo(hl.getRegion().getParentOp()),
559 loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
560 LogicalResult apply();
564 ConversionPatternRewriter &rewriter;
565 PostDominanceInfo postDomInfo;
566 DominanceInfo domInfo;
567 CFGLoopInfo loopInfo;
569 using BlockPair = std::pair<Block *, Block *>;
570 using BlockPairs = SmallVector<BlockPair>;
571 LogicalResult findBlockPairs(BlockPairs &blockPairs);
573 BufferOp buildSplitNetwork(Block *splitBlock,
574 handshake::ConditionalBranchOp &ctrlBr);
575 LogicalResult buildMergeNetwork(Block *
mergeBlock, BufferOp buf,
576 handshake::ConditionalBranchOp &ctrlBr);
579 bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
580 handshake::ConditionalBranchOp &ctrlBr);
581 bool formsIrreducibleCF(Block *splitBlock, Block *
mergeBlock);
586 HandshakeLowering::feedForwardRewriting(ConversionPatternRewriter &rewriter) {
588 if (this->getRegion().hasOneBlock())
590 return FeedForwardNetworkRewriter(*
this, rewriter).apply();
594 for (CFGLoop *loop : loopInfo.getTopLevelLoops())
595 if (!loop->getExitBlock())
600 bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
602 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
603 for (
auto *mergePred :
mergeBlock->getPredecessors()) {
605 if (loop && loop->contains(mergePred))
613 if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
614 if (splitSucc == mergeBlock || mergePred == splitBlock)
616 return domInfo.dominates(splitSucc, mergePred);
624 Block *pred = *block->getPredecessors().begin();
625 return pred->getTerminator();
629 FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
633 Region &r = hl.getRegion();
634 Operation *parentOp = r.getParentOp();
639 "expected loop to only have one exit block.");
642 if (b.getNumSuccessors() < 2)
646 if (loopInfo.getLoopFor(&b))
649 assert(b.getNumSuccessors() == 2);
650 Block *succ0 = b.getSuccessor(0);
651 Block *succ1 = b.getSuccessor(1);
656 Block *
mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
660 return parentOp->emitError(
"expected only reducible control flow.")
662 <<
"This branch is involved in the irreducible control flow";
665 unsigned nonLoopPreds = 0;
666 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
667 for (
auto *pred :
mergeBlock->getPredecessors()) {
668 if (loop && loop->contains(pred))
672 if (nonLoopPreds > 2)
674 ->emitError(
"expected a merge block to have two predecessors. "
675 "Did you run the merge block insertion pass?")
677 <<
"This branch jumps to the illegal block";
685 LogicalResult FeedForwardNetworkRewriter::apply() {
688 if (failed(findBlockPairs(pairs)))
692 handshake::ConditionalBranchOp ctrlBr;
693 BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
694 if (failed(buildMergeNetwork(
mergeBlock, buffer, ctrlBr)))
701 BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
702 Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
703 SmallVector<handshake::ConditionalBranchOp> branches;
704 llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
705 std::back_inserter(branches));
707 auto *findRes = llvm::find_if(branches, [](
auto br) {
708 return llvm::isa<NoneType>(br.getDataOperand().getType());
711 assert(findRes &&
"expected one branch for the ctrl signal");
714 Value cond = ctrlBr.getConditionOperand();
715 assert(llvm::all_of(branches, [&](
auto branch) {
716 return branch.getConditionOperand() == cond;
719 Location loc = cond.getLoc();
720 rewriter.setInsertionPointAfterValue(cond);
724 size_t bufferSize = 2;
728 return rewriter.create<handshake::BufferOp>(loc, cond, bufferSize,
729 BufferTypeEnum::fifo);
732 LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
733 Block *
mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
735 auto ctrlMerges =
mergeBlock->getOps<handshake::ControlMergeOp>();
736 assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
738 handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
740 if (ctrlMerge.getNumOperands() != 2)
741 return ctrlMerge.emitError(
"expected cmerges to have two operands");
742 rewriter.setInsertionPointAfter(ctrlMerge);
743 Location loc = ctrlMerge->getLoc();
748 bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
749 SmallVector<Value> muxOperands;
751 muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
753 muxOperands = llvm::to_vector(ctrlMerge.getOperands());
755 Value newCtrl = rewriter.create<handshake::MuxOp>(loc, buf, muxOperands);
757 Value cond = buf.getResult();
762 cond = rewriter.create<arith::XOrIOp>(
763 loc, cond.getType(), cond,
764 rewriter.create<arith::ConstantOp>(
765 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
770 rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), cond);
775 rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
779 bool FeedForwardNetworkRewriter::requiresOperandFlip(
780 ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
781 assert(ctrlMerge.getNumOperands() == 2 &&
782 "Loops should already have been handled");
784 Value fstOperand = ctrlMerge.getOperand(0);
786 assert(ctrlBr.getTrueResult().hasOneUse() &&
787 "expected the result of a branch to only have one user");
788 Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
789 if (trueUser == ctrlBr)
791 return ctrlBr.getTrueResult() == fstOperand;
795 Block *trueBlock = trueUser->getBlock();
796 return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
809 class LoopNetworkRewriter {
813 LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
818 using ExitPair = std::pair<Block *, Block *>;
819 LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
828 BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
834 void buildExitNetwork(Block *loopHeader,
836 BufferOp loopPrimingRegister,
840 ConversionPatternRewriter *rewriter =
nullptr;
846 HandshakeLowering::loopNetworkRewriting(ConversionPatternRewriter &rewriter) {
847 return LoopNetworkRewriter(*this).processRegion(r, rewriter);
851 LoopNetworkRewriter::processRegion(Region &r,
852 ConversionPatternRewriter &rewriter) {
856 this->rewriter = &rewriter;
858 Operation *op = r.getParentOp();
860 DominanceInfo domInfo(op);
861 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
863 for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
864 if (!loop->getLoopLatch())
865 return emitError(op->getLoc()) <<
"Multiple loop latches detected "
866 "(backedges from within the loop "
867 "to the loop header). Loop task "
868 "pipelining is only supported for "
869 "loops with unified loop latches.";
872 if (failed(processOuterLoop(op->getLoc(), loop)))
881 auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
882 return block == operand.getParentBlock();
885 inValueIt != mux.getDataOperands().end() &&
886 "Expected mux to have an operand originating from the requested block.");
894 std::vector<Value> sortedOperands;
895 for (
auto in : cmerge.getOperands()) {
896 auto *srcBlock = in.getParentBlock();
901 for (
unsigned i = 0; i < sortedOperands.size(); ++i) {
902 for (
unsigned j = 0; j < sortedOperands.size(); ++j) {
905 assert(sortedOperands[i] != sortedOperands[j] &&
906 "Cannot have an identical operand from two different blocks!");
910 return sortedOperands;
913 BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
920 llvm::SmallVector<MuxOp> muxesToReplace;
921 llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));
927 assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
928 "Expected control merge to be the control component of a loop header");
929 auto loc = cmerge->getLoc();
932 assert(cmerge->getNumOperands() > 1 &&
"This cannot be a loop header");
936 SmallVector<Value> externalCtrls, loopCtrls;
937 for (
auto cval : cmerge->getOperands()) {
938 if (cval.getParentBlock() == loopLatch)
939 loopCtrls.push_back(cval);
941 externalCtrls.push_back(cval);
943 assert(loopCtrls.size() == 1 &&
944 "Expected a single loop control value to match the single loop latch");
945 Value loopCtrl = loopCtrls.front();
948 rewriter->setInsertionPointToStart(loopHeader);
949 auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);
954 auto primingRegister =
955 rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
957 primingRegister->setAttr(
"initValues", rewriter->getI64ArrayAttr({0}));
961 auto loopCtrlMux = rewriter->create<MuxOp>(
962 loc, primingRegister.getResult(),
963 llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});
967 cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());
970 hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());
981 DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
982 DenseMap<MuxOp, Value> loopDataInputs;
983 for (
auto muxOp : muxesToReplace) {
984 if (muxOp == loopCtrlMux)
989 assert( 1 + externalDataInputs[muxOp].size() ==
990 muxOp.getDataOperands().size() &&
991 "Expected all mux operands to be partitioned between loop and "
992 "external data inputs");
1000 for (MuxOp mux : muxesToReplace) {
1001 auto externalDataMux = rewriter->create<MuxOp>(
1002 loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);
1004 rewriter->replaceOp(
1006 ->create<MuxOp>(loc, primingRegister,
1007 llvm::SmallVector<Value>{externalDataMux,
1008 loopDataInputs[mux]})
1014 rewriter->eraseOp(cmerge);
1017 return primingRegister;
1020 void LoopNetworkRewriter::buildExitNetwork(
1022 BufferOp loopPrimingRegister,
Backedge &loopPrimingInput) {
1023 auto loc = loopPrimingRegister.getLoc();
1032 SmallVector<Value> parityCorrectedConds;
1033 for (
auto &[condBlock, exitBlock] : exitPairs) {
1037 "Expected a conditional control branch op in the loop condition block");
1038 Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
1039 bool isTrueParity = trueUser->getBlock() == exitBlock;
1041 ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
1043 "The user of either the true or the false result should be in the "
1046 Value condValue = condBr.getConditionOperand();
1050 rewriter->setInsertionPoint(condBr);
1051 condValue = rewriter->create<arith::XOrIOp>(
1052 loc, condValue.getType(), condValue,
1053 rewriter->create<arith::ConstantOp>(
1054 loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
1056 parityCorrectedConds.push_back(condValue);
1061 auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
1062 loopPrimingInput.
setValue(exitMerge);
1065 LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1070 SmallVector<Block *> exitBlocks;
1071 loop->getExitBlocks(exitBlocks);
1072 for (
auto *exitNode : exitBlocks) {
1073 for (
auto *pred : exitNode->getPredecessors()) {
1075 if (!loop->contains(pred))
1078 ExitPair condPair = {pred, exitNode};
1079 assert(!exitPairs.count(condPair) &&
1080 "identical condition pairs should never be possible");
1081 exitPairs.insert(condPair);
1084 assert(!exitPairs.empty() &&
"No exits from loop?");
1088 if (exitPairs.size() > 1)
1089 return emitError(loc)
1090 <<
"Multiple exits detected within a loop. Loop task pipelining is "
1091 "only supported for loops with unified loop exit blocks.";
1093 Block *header = loop->getHeader();
1098 auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1099 auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1100 loopPrimingRegisterInput);
1104 buildExitNetwork(header, exitPairs, loopPrimingRegister,
1105 loopPrimingRegisterInput);
1114 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1115 if (condBranchOp.getTrueDest() == succBlock)
1116 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1118 assert(condBranchOp.getFalseDest() == succBlock);
1119 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1123 return newOp->getResult(0);
1127 HandshakeLowering::addBranchOps(ConversionPatternRewriter &rewriter) {
1131 for (Block &block : r) {
1132 for (Operation &op : block) {
1133 for (
auto result : op.getResults())
1135 liveOuts[&block].push_back(result);
1139 for (Block &block : r) {
1140 Operation *termOp = block.getTerminator();
1141 rewriter.setInsertionPoint(termOp);
1143 for (Value val : liveOuts[&block]) {
1148 for (
int i = 0, e = numBranches; i < e; ++i) {
1149 Operation *newOp =
nullptr;
1151 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
1152 newOp = rewriter.create<handshake::ConditionalBranchOp>(
1153 termOp->getLoc(), condBranchOp.getCondition(), val);
1154 else if (isa<mlir::cf::BranchOp>(termOp))
1155 newOp = rewriter.create<handshake::BranchOp>(termOp->getLoc(), val);
1157 if (newOp ==
nullptr)
1160 for (
int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
1161 Block *succ = block.getSuccessor(j);
1164 for (
auto &u : val.getUses()) {
1165 if (u.getOwner()->getBlock() == succ) {
1166 u.getOwner()->replaceUsesOfWith(val, res);
1178 LogicalResult HandshakeLowering::connectConstantsToControl(
1179 ConversionPatternRewriter &rewriter,
bool sourceConstants) {
1185 if (sourceConstants) {
1186 for (
auto constantOp : llvm::make_early_inc_range(
1187 r.template getOps<mlir::arith::ConstantOp>())) {
1188 rewriter.setInsertionPointAfter(constantOp);
1189 auto value = constantOp.getValue();
1190 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1191 constantOp, value.getType(), value,
1192 rewriter.create<handshake::SourceOp>(constantOp.getLoc(),
1193 rewriter.getNoneType()));
1196 for (Block &block : r) {
1197 Value blockEntryCtrl = getBlockEntryControl(&block);
1198 for (
auto constantOp : llvm::make_early_inc_range(
1199 block.template getOps<mlir::arith::ConstantOp>())) {
1200 rewriter.setInsertionPointAfter(constantOp);
1201 auto value = constantOp.getValue();
1202 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1203 constantOp, value.getType(), value, blockEntryCtrl);
1219 : op(op), ctrlOperand(ctrlOperand) {
1220 assert(op && ctrlOperand);
1221 assert(isa<NoneType>(ctrlOperand.getType()) &&
1222 "Control operand must be a NoneType");
1235 for (Operation &op : *block) {
1236 if (
auto branchOp = dyn_cast<handshake::BranchOp>(op))
1237 if (branchOp.isControl())
1238 return {branchOp, branchOp.getDataOperand()};
1239 if (
auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1240 if (branchOp.isControl())
1241 return {branchOp, branchOp.getDataOperand()};
1242 if (
auto endOp = dyn_cast<handshake::ReturnOp>(op))
1243 return {endOp, endOp.getOperands().back()};
1245 llvm_unreachable(
"Block terminator must exist");
1250 if (
auto memOp = dyn_cast<memref::LoadOp>(op))
1251 out = memOp.getMemRef();
1252 else if (
auto memOp = dyn_cast<memref::StoreOp>(op))
1253 out = memOp.getMemRef();
1254 else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1255 MemRefAccess access(op);
1256 out = access.memref;
1260 return op->emitOpError(
"Unknown Op type");
1264 return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1265 AffineWriteOpInterface>(op);
1269 HandshakeLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
1272 std::vector<Operation *> opsToErase;
1275 for (
auto arg : r.getArguments()) {
1276 auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
1282 memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
1288 for (Operation &op : r.getOps()) {
1292 rewriter.setInsertionPoint(&op);
1296 Operation *newOp =
nullptr;
1298 llvm::TypeSwitch<Operation *>(&op)
1299 .Case<memref::LoadOp>([&](
auto loadOp) {
1302 SmallVector<Value, 8> operands(loadOp.getIndices());
1305 rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
1306 op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
1308 .Case<memref::StoreOp>([&](
auto storeOp) {
1311 SmallVector<Value, 8> operands(storeOp.getIndices());
1314 newOp = rewriter.create<handshake::StoreOp>(
1315 op.getLoc(), storeOp.getValueToStore(), operands);
1317 .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](
auto) {
1319 MemRefAccess access(&op);
1326 if (
auto loadOp = dyn_cast<AffineReadOpInterface>(op))
1327 map = loadOp.getAffineMap();
1329 map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();
1337 expandAffineMap(rewriter, op.getLoc(), map, access.indices);
1338 assert(operands &&
"Address operands of affine memref access "
1339 "cannot be reduced.");
1341 if (isa<AffineReadOpInterface>(op)) {
1342 auto loadOp = rewriter.create<handshake::LoadOp>(
1343 op.getLoc(), access.memref, *operands);
1345 op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
1347 newOp = rewriter.create<handshake::StoreOp>(
1348 op.getLoc(), op.getOperand(0), *operands);
1351 .Default([&](
auto) {
1352 op.emitOpError(
"Load/store operation cannot be handled.");
1355 memRefOps[memref].push_back(newOp);
1356 opsToErase.push_back(&op);
1360 for (
unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
1361 auto *op = opsToErase[i];
1362 for (
int j = 0, e = op->getNumOperands(); j < e; ++j)
1363 op->eraseOperand(0);
1364 assert(op->getNumOperands() == 0);
1366 rewriter.eraseOp(op);
1375 if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1378 SmallVector<Value, 8> results(loadOp.getAddressResults());
1383 assert(dyn_cast<handshake::StoreOp>(op));
1384 handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1385 SmallVector<Value, 8> results(storeOp.getResults());
1392 for (Block &block : f) {
1394 if (!ctrl.hasOneUse())
1400 ConversionPatternRewriter &rewriter) {
1401 std::vector<Operation *> opsToDelete;
1404 for (
auto &op : r.getOps())
1405 if (
isAllocOp(&op) && op.getResult(0).use_empty())
1406 opsToDelete.push_back(&op);
1408 llvm::for_each(opsToDelete, [&](
auto allocOp) { rewriter.eraseOp(allocOp); });
1412 ArrayRef<BlockControlTerm> controlTerms) {
1413 for (
auto term : controlTerms) {
1414 auto &[op, ctrl] = term;
1415 auto *srcOp = ctrl.getDefiningOp();
1418 if (!isa<JoinOp>(srcOp)) {
1419 rewriter.setInsertionPointAfter(srcOp);
1420 Operation *newJoin = rewriter.create<JoinOp>(srcOp->getLoc(), ctrl);
1421 op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1426 static std::vector<BlockControlTerm>
1428 std::vector<BlockControlTerm> terminators;
1430 for (Operation *op : memOps) {
1432 Block *block = op->getBlock();
1435 if (std::find(terminators.begin(), terminators.end(), term) ==
1437 terminators.push_back(term);
1444 SmallVector<Value, 8> results(op->getOperands());
1445 results.push_back(val);
1446 op->setOperands(results);
1452 for (
auto *op : memOps) {
1453 if (isa<handshake::LoadOp>(op))
1459 Operation *memOp,
int offset,
1460 ArrayRef<int> cntrlInd) {
1463 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1464 auto *op = memOps[i];
1466 auto *srcOp = ctrl.getDefiningOp();
1467 if (!isa<JoinOp>(srcOp)) {
1468 return srcOp->emitOpError(
"Op expected to be a JoinOp");
1475 void HandshakeLowering::setMemOpControlInputs(
1476 ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1477 Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
1478 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1479 std::vector<Value> controlOperands;
1480 Operation *currOp = memOps[i];
1481 Block *currBlock = currOp->getBlock();
1484 Value blockEntryCtrl = getBlockEntryControl(currBlock);
1485 controlOperands.push_back(blockEntryCtrl);
1488 for (
int j = 0, f = i; j < f; ++j) {
1489 Operation *predOp = memOps[j];
1490 Block *predBlock = predOp->getBlock();
1491 if (currBlock == predBlock)
1493 if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1495 controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1499 if (controlOperands.size() == 1)
1504 rewriter.setInsertionPoint(currOp);
1506 rewriter.create<JoinOp>(currOp->getLoc(), controlOperands);
1513 HandshakeLowering::connectToMemory(ConversionPatternRewriter &rewriter,
1518 for (
auto memory : memRefOps) {
1520 Value memrefOperand = memory.first;
1524 bool isExternalMemory = isa<BlockArgument>(memrefOperand);
1526 mlir::MemRefType memrefType =
1527 cast<mlir::MemRefType>(memrefOperand.getType());
1531 std::vector<Value> operands;
1534 std::vector<BlockControlTerm> controlTerms =
1540 for (
auto valOp : controlTerms)
1541 operands.push_back(valOp.ctrlOperand);
1553 std::vector<int> newInd(memory.second.size(), 0);
1555 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1556 auto *op = memory.second[i];
1557 if (isa<handshake::StoreOp>(op)) {
1559 operands.insert(operands.end(), results.begin(), results.end());
1566 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1567 auto *op = memory.second[i];
1568 if (isa<handshake::LoadOp>(op)) {
1570 operands.insert(operands.end(), results.begin(), results.end());
1578 int cntrl_count = lsq ? 0 : memory.second.size();
1580 Block *entryBlock = &r.front();
1581 rewriter.setInsertionPointToStart(entryBlock);
1584 Operation *newOp =
nullptr;
1585 if (isExternalMemory)
1586 newOp = rewriter.create<ExternalMemoryOp>(
1587 entryBlock->front().getLoc(), memrefOperand, operands, ld_count,
1588 cntrl_count - ld_count, mem_count++);
1590 newOp = rewriter.create<MemoryOp>(entryBlock->front().getLoc(), operands,
1591 ld_count, cntrl_count, lsq, mem_count++,
1606 bool control =
true;
1616 setMemOpControlInputs(rewriter, memory.second, newOp, ld_count, newInd);
1628 HandshakeLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
1629 for (Block &block : r) {
1632 Value blockEntryControl = getBlockEntryControl(&block);
1633 for (Operation &op : block) {
1634 if (
auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1635 llvm::SmallVector<Value> operands;
1636 llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1637 operands.push_back(blockEntryControl);
1638 rewriter.setInsertionPoint(callOp);
1639 auto instanceOp = rewriter.create<handshake::InstanceOp>(
1640 callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
1643 for (
auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1644 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1645 rewriter.eraseOp(callOp);
1658 bool maximizeArgument(BlockArgument arg)
override {
1659 return !isa<mlir::MemRefType>(arg.getType());
1663 bool maximizeOp(Operation *op)
override {
return !
isAllocOp(op); }
1671 ConversionPatternRewriter &rewriter) {
1672 HandshakeLoweringSSAStrategy
strategy;
1676 static LogicalResult
lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1677 bool sourceConstants,
1678 bool disableTaskPipelining) {
1680 SmallVector<NamedAttribute, 4> attributes;
1681 for (
const auto &attr : funcOp->getAttrs()) {
1682 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1683 attr.getName() == funcOp.getFunctionTypeAttrName())
1685 attributes.push_back(attr);
1689 llvm::SmallVector<mlir::Type, 8> argTypes;
1690 for (
auto &argType : funcOp.getArgumentTypes())
1691 argTypes.push_back(argType);
1694 llvm::SmallVector<mlir::Type, 8> resTypes;
1695 for (
auto resType : funcOp.getResultTypes())
1696 resTypes.push_back(resType);
1702 if (partiallyLowerOp<func::FuncOp>(
1703 [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1704 auto noneType = rewriter.getNoneType();
1705 resTypes.push_back(noneType);
1706 argTypes.push_back(noneType);
1707 auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1709 funcOp.getLoc(), funcOp.getName(), func_type, attributes);
1710 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1712 if (!newFuncOp.isExternal()) {
1713 newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1715 newFuncOp.resolveArgAndResNames();
1717 rewriter.eraseOp(funcOp);
1728 if (!newFuncOp.isExternal()) {
1729 Block *bodyBlock = newFuncOp.getBodyBlock();
1730 Value entryCtrl = bodyBlock->getArguments().back();
1732 if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1733 fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1742 struct HandshakeRemoveBlockPass
1743 : circt::impl::HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1747 struct CFToHandshakePass
1748 :
public circt::impl::CFToHandshakeBase<CFToHandshakePass> {
1749 CFToHandshakePass(
bool sourceConstants,
bool disableTaskPipelining) {
1750 this->sourceConstants = sourceConstants;
1751 this->disableTaskPipelining = disableTaskPipelining;
1753 void runOnOperation()
override {
1754 ModuleOp m = getOperation();
1756 for (
auto funcOp : llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1757 if (failed(
lowerFuncOp(funcOp, &getContext(), sourceConstants,
1758 disableTaskPipelining))) {
1759 signalPassFailure();
1768 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1770 bool disableTaskPipelining) {
1771 return std::make_unique<CFToHandshakePass>(sourceConstants,
1772 disableTaskPipelining);
1775 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1777 return std::make_unique<HandshakeRemoveBlockPass>();
static ConditionalBranchOp getControlCondBranch(Block *block)
static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, bool sourceConstants, bool disableTaskPipelining)
static bool isMemoryOp(Operation *op)
static LogicalResult setJoinControlInputs(ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
static void addJoinOps(ConversionPatternRewriter &rewriter, ArrayRef< BlockControlTerm > controlTerms)
static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter)
static bool isLiveOut(Value val)
static Operation * getControlMerge(Block *block)
static SmallVector< Value, 8 > getResultsToMemory(Operation *op)
static unsigned getBlockPredecessorCount(Block *block)
static int getBranchCount(Value val, Block *block)
void removeBasicBlocks(handshake::FuncOp funcOp)
static Operation * getFirstOp(Block *block)
Returns the first occurrence of an operation of type TOp, or a null op if there is none.
static Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock)
static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo, Block *predBlock)
static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type)
static bool isAllocOp(Operation *op)
static Value getOperandFromBlock(MuxOp mux, Block *block)
static LogicalResult getOpMemRef(Operation *op, Value &out)
static LogicalResult partiallyLowerOp(const std::function< LogicalResult(TOp, ConversionPatternRewriter &)> &loweringFunc, MLIRContext *ctx, TOp op)
static void addValueToOperands(Operation *op, Value val)
static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo)
static std::vector< Value > getSortedInputs(ControlMergeOp cmerge, MuxOp mux)
static void removeBlockOperands(Region &f)
static void removeUnusedAllocOps(Region &r, ConversionPatternRewriter &rewriter)
static LogicalResult maximizeSSANoMem(Region &r, ConversionPatternRewriter &rewriter)
Converts every value in the region into maximal SSA form, unless the value is a block argument of typ...
static Operation * findBranchToBlock(Block *block)
static std::vector< BlockControlTerm > getControlTerminators(ArrayRef< Operation * > memOps)
static BlockControlTerm getBlockControlTerminator(Block *block)
static void reconnectMergeOps(Region &r, HandshakeLowering::BlockOps blockMerges, HandshakeLowering::ValueMap &mergePairs)
static void setLoadDataInputs(ArrayRef< Operation * > memOps, Operation *memOp)
assert(baseType &&"element must be base type")
static void insertFork(Value result, OpBuilder &rewriter)
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Moves all operations from a source block into a destination block.
LowerRegionTarget(MLIRContext &context, Region &region)
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Strategy class to control the behavior of SSA maximization.
DenseMap< Block *, std::vector< MergeOpInfo > > BlockOps
DenseMap< Block *, std::vector< Value > > BlockValues
llvm::MapVector< Value, std::vector< Operation * > > MemRefToMemoryAccessOp
DenseMap< Value, Value > ValueMap
FuncOp create(Union[StringAttr, str] sym_name, List[Tuple[str, Type]] args, List[Tuple[str, Type]] results, Dict[str, Attribute] attributes={}, loc=None, ip=None)
llvm::function_ref< LogicalResult(Region &, ConversionPatternRewriter &)> RegionLoweringFunc
LogicalResult partiallyLowerRegion(const RegionLoweringFunc &loweringFunc, MLIRContext *ctx, Region &r)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createCFToHandshakePass(bool sourceConstants=false, bool disableTaskPipelining=false)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeRemoveBlockPass()
Holds information about a handshake "basic block terminator" control operation.
friend bool operator==(const BlockControlTerm &lhs, const BlockControlTerm &rhs)
Checks for member-wise equality.
Value ctrlOperand
The operation's control operand (must have type NoneType)
BlockControlTerm(Operation *op, Value ctrlOperand)
Operation * op
The operation.
Allows partially lowering a region by matching on the parent operation and then calling the provided part...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value >, ConversionPatternRewriter &rewriter) const override
std::function< LogicalResult(Region &, ConversionPatternRewriter &)> PartialLoweringFunc
PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context, LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
LogicalResult & loweringRes
LowerRegionTarget & target