18#include "mlir/Analysis/CFGLoopInfo.h"
19#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
20#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
21#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
22#include "mlir/Dialect/Affine/IR/AffineOps.h"
23#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
24#include "mlir/Dialect/Affine/Utils.h"
25#include "mlir/Dialect/Arith/IR/Arith.h"
26#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
27#include "mlir/Dialect/Func/IR/FuncOps.h"
28#include "mlir/Dialect/MemRef/IR/MemRef.h"
29#include "mlir/Dialect/SCF/IR/SCF.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/Diagnostics.h"
33#include "mlir/IR/Dominance.h"
34#include "mlir/IR/OpImplementation.h"
35#include "mlir/IR/PatternMatch.h"
36#include "mlir/IR/Types.h"
37#include "mlir/IR/Value.h"
38#include "mlir/Pass/Pass.h"
39#include "mlir/Support/LLVM.h"
40#include "mlir/Transforms/DialectConversion.h"
41#include "mlir/Transforms/Passes.h"
42#include "llvm/ADT/SmallSet.h"
43#include "llvm/ADT/TypeSwitch.h"
44#include "llvm/Support/raw_ostream.h"
50#define GEN_PASS_DEF_CFTOHANDSHAKE
51#define GEN_PASS_DEF_HANDSHAKEREMOVEBLOCK
52#include "circt/Conversion/Passes.h.inc"
67template <
typename TOp>
68class LowerOpTarget :
public ConversionTarget {
70 explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
72 addLegalDialect<HandshakeDialect>();
73 addLegalDialect<mlir::func::FuncDialect>();
74 addLegalDialect<mlir::arith::ArithDialect>();
75 addIllegalDialect<mlir::scf::SCFDialect>();
76 addIllegalDialect<AffineDialect>();
81 addDynamicallyLegalOp<TOp>([&](
const auto &op) {
return loweredOps[op]; });
83 DenseMap<Operation *, bool> loweredOps;
100template <
typename TOp>
101struct PartialLowerOp :
public ConversionPattern {
102 using PartialLoweringFunc =
103 std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
106 PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
107 LogicalResult &loweringResRef,
const PartialLoweringFunc &fun)
108 : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
109 loweringRes(loweringResRef), fun(fun) {}
110 using ConversionPattern::ConversionPattern;
112 matchAndRewrite(Operation *op, ArrayRef<Value> ,
113 ConversionPatternRewriter &rewriter)
const override {
115 loweringRes = fun(dyn_cast<TOp>(op), rewriter);
116 target.loweredOps[op] =
true;
121 LowerOpTarget<TOp> ⌖
122 LogicalResult &loweringRes;
124 PartialLoweringFunc fun;
130template <
typename TOp>
132 const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
134 MLIRContext *ctx, TOp op) {
137 auto target = LowerOpTarget<TOp>(*ctx);
138 LogicalResult partialLoweringSuccessfull = success();
139 patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
142 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
143 partialLoweringSuccessfull.succeeded());
152 markUnknownOpDynamicallyLegal([&](Operation *op) {
153 if (op !=
region.getParentOp())
169 std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
173 LogicalResult &loweringResRef,
175 : ConversionPattern(
target.region.getParentOp()->getName().getStringRef(),
178 using ConversionPattern::ConversionPattern;
181 ConversionPatternRewriter &rewriter)
const override {
182 rewriter.modifyOpInPlace(
197 MLIRContext *ctx, Region &r) {
199 Operation *op = r.getParentOp();
202 LogicalResult partialLoweringSuccessfull = success();
206 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
207 partialLoweringSuccessfull.succeeded());
217 "No block entry control value registerred for this block!");
// Merge all blocks of region `r` into its entry block: move every non-entry
// block's operations to the end of the entry block, drop uses of the values
// each such block defines, and erase its arguments.
226 Block *entryBlock = &r.front();
227 auto &entryBlockOps = entryBlock->getOperations();
230 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
231 entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
234 block.dropAllDefinedValueUses();
// NOTE(review): Block::eraseArgument(i) shifts later arguments down by one.
// Incrementing `i` after each erase therefore skips every other argument.
// Erase from the back instead, as the reverse loop elsewhere in this file
// does, or repeatedly erase index 0.
235 for (
size_t i = 0; i < block.getNumArguments(); i++) {
236 block.eraseArgument(i);
// After the splice, terminators from the merged blocks sit in the middle of
// the entry block. cf branch terminators are erased; any other terminator is
// moved to the end of the entry block.
243 for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
244 if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
247 if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
248 terminatorLike.erase();
253 terminatorLike.moveBefore(entryBlock, entryBlock->end());
264 if (funcOp.isExternal())
271 if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
272 return emitError(loc) <<
"memref's must be both statically sized and "
279 auto predecessors = block->getPredecessors();
280 return std::distance(predecessors.begin(), predecessors.end());
288 ConversionPatternRewriter &rewriter) {
290 auto insertLoc = block->front().getLoc();
291 SmallVector<Backedge> dataEdges;
292 SmallVector<Value> operands;
299 if (block == &
r.front()) {
307 operands.push_back(val);
308 mergeOp = handshake::MergeOp::create(rewriter, insertLoc, operands);
310 for (
unsigned i = 0; i < numPredecessors; i++) {
311 auto edge = edgeBuilder.
get(rewriter.getNoneType());
312 dataEdges.push_back(edge);
313 operands.push_back(Value(edge));
316 handshake::ControlMergeOp::create(rewriter, insertLoc, operands);
327 if (numPredecessors <= 1) {
328 if (numPredecessors == 0) {
332 operands.push_back(val);
336 auto edge = edgeBuilder.
get(val.getType());
337 dataEdges.push_back(edge);
338 operands.push_back(Value(edge));
340 auto merge = handshake::MergeOp::create(rewriter, insertLoc, operands);
348 Backedge indexEdge = edgeBuilder.
get(rewriter.getIndexType());
349 for (
unsigned i = 0; i < numPredecessors; i++) {
350 auto edge = edgeBuilder.
get(val.getType());
351 dataEdges.push_back(edge);
352 operands.push_back(Value(edge));
355 handshake::MuxOp::create(rewriter, insertLoc, Value(indexEdge), operands);
356 return MergeOpInfo{mux, val, dataEdges, indexEdge};
362 ConversionPatternRewriter &rewriter) {
364 for (Block &block :
r) {
365 rewriter.setInsertionPointToStart(&block);
369 for (
auto &arg : block.getArguments()) {
371 if (isa<mlir::MemRefType>(arg.getType()))
374 auto mergeInfo =
insertMerge(&block, arg, edgeBuilder, rewriter);
375 blockMerges[&block].push_back(mergeInfo);
376 mergePairs[arg] = mergeInfo.op->getResult(0);
386 Value srcVal = mergeInfo.
val;
388 Block *block = mergeInfo.
op->getBlock();
393 unsigned index = cast<BlockArgument>(srcVal).getArgNumber();
394 Operation *termOp = predBlock->getTerminator();
395 if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
397 if (block == br.getTrueDest())
398 return br.getTrueOperand(index);
399 assert(block == br.getFalseDest());
400 return br.getFalseOperand(index);
402 if (isa<mlir::cf::BranchOp>(termOp))
403 return termOp->getOperand(index);
410 for (Block &block : f) {
411 if (!block.isEntryBlock()) {
412 int x = block.getNumArguments() - 1;
413 for (
int i = x; i >= 0; --i)
414 block.eraseArgument(i);
421template <
typename TOp>
423 auto ops = block->getOps<TOp>();
430 return getFirstOp<ControlMergeOp>(block);
434 for (
auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
435 if (cbranch.isControl())
449 for (Block &block : r) {
450 for (
auto &mergeInfo : blockMerges[&block]) {
453 for (
auto *predBlock : block.getPredecessors()) {
455 assert(mgOperand !=
nullptr);
456 if (!mgOperand.getDefiningOp()) {
457 assert(mergePairs.count(mgOperand));
458 mgOperand = mergePairs[mgOperand];
460 mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
466 for (Operation &opp : block)
467 if (!isa<MergeLikeOpInterface>(opp))
468 opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
474 for (Block &block : r) {
477 assert(cntrlMg !=
nullptr);
479 for (
auto &mergeInfo : blockMerges[&block]) {
480 if (mergeInfo.op != cntrlMg) {
484 assert(mergeInfo.indexEdge.has_value());
485 (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
495 return isa<memref::AllocOp, memref::AllocaOp>(op);
520 for (
auto &u : val.getUses())
522 if (isa<MergeLikeOpInterface>(u.getOwner()))
533 for (
int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
535 Block *succ = block->getSuccessor(i);
536 for (
auto &u : val.getUses()) {
537 if (u.getOwner()->getBlock() == succ)
540 uses = (curr > uses) ? curr : uses;
554class FeedForwardNetworkRewriter {
557 ConversionPatternRewriter &rewriter)
558 : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
559 domInfo(hl.getRegion().getParentOp()),
560 loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
561 LogicalResult apply();
565 ConversionPatternRewriter &rewriter;
566 PostDominanceInfo postDomInfo;
567 DominanceInfo domInfo;
568 CFGLoopInfo loopInfo;
570 using BlockPair = std::pair<Block *, Block *>;
571 using BlockPairs = SmallVector<BlockPair>;
572 LogicalResult findBlockPairs(BlockPairs &blockPairs);
574 BufferOp buildSplitNetwork(Block *splitBlock,
575 handshake::ConditionalBranchOp &ctrlBr);
576 LogicalResult buildMergeNetwork(Block *
mergeBlock, BufferOp buf,
577 handshake::ConditionalBranchOp &ctrlBr);
580 bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
581 handshake::ConditionalBranchOp &ctrlBr);
582 bool formsIrreducibleCF(Block *splitBlock, Block *
mergeBlock);
591 return FeedForwardNetworkRewriter(*
this, rewriter).apply();
595 for (CFGLoop *loop : loopInfo.getTopLevelLoops())
596 if (!loop->getExitBlock())
601bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
603 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
604 for (
auto *mergePred :
mergeBlock->getPredecessors()) {
606 if (loop && loop->contains(mergePred))
614 if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
615 if (splitSucc == mergeBlock || mergePred == splitBlock)
617 return domInfo.dominates(splitSucc, mergePred);
625 Block *pred = *block->getPredecessors().begin();
626 return pred->getTerminator();
630FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
634 Region &r = hl.getRegion();
635 Operation *parentOp = r.getParentOp();
640 "expected loop to only have one exit block.");
643 if (b.getNumSuccessors() < 2)
647 if (loopInfo.getLoopFor(&b))
650 assert(b.getNumSuccessors() == 2);
651 Block *succ0 = b.getSuccessor(0);
652 Block *succ1 = b.getSuccessor(1);
657 Block *
mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
661 return parentOp->emitError(
"expected only reducible control flow.")
663 <<
"This branch is involved in the irreducible control flow";
666 unsigned nonLoopPreds = 0;
667 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
668 for (
auto *pred :
mergeBlock->getPredecessors()) {
669 if (loop && loop->contains(pred))
673 if (nonLoopPreds > 2)
675 ->emitError(
"expected a merge block to have two predecessors. "
676 "Did you run the merge block insertion pass?")
678 <<
"This branch jumps to the illegal block";
686LogicalResult FeedForwardNetworkRewriter::apply() {
689 if (failed(findBlockPairs(pairs)))
693 handshake::ConditionalBranchOp ctrlBr;
694 BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
695 if (failed(buildMergeNetwork(
mergeBlock, buffer, ctrlBr)))
// Builds the "split" side of a feed-forward network for `splitBlock`:
// collects the block's handshake::ConditionalBranchOps, locates the one whose
// data operand is NoneType (the control branch), and creates a FIFO BufferOp
// on the branch condition that all of these branches share.
// `ctrlBr` is a reference out-parameter that is read below as the control
// branch. Returns the newly created buffer.
702BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
703 Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
704 SmallVector<handshake::ConditionalBranchOp> branches;
705 llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
706 std::back_inserter(branches));
// Find the branch that carries the control (NoneType) token.
708 auto *findRes = llvm::find_if(branches, [](
auto br) {
709 return llvm::isa<NoneType>(br.getDataOperand().getType());
// NOTE(review): llvm::find_if returns branches.end() when nothing matches,
// and for a SmallVector that is a non-null pointer. This assert therefore can
// never fire. It should be `assert(findRes != branches.end() && ...)`.
712 assert(findRes &&
"expected one branch for the ctrl signal");
715 Value cond = ctrlBr.getConditionOperand();
// Every conditional branch in splitBlock is expected to use the same condition.
716 assert(llvm::all_of(branches, [&](
auto branch) {
717 return branch.getConditionOperand() == cond;
720 Location loc = cond.getLoc();
// Place the buffer directly after the definition of the condition value.
721 rewriter.setInsertionPointAfterValue(cond);
// NOTE(review): the rationale for a depth of 2 is not stated in the visible
// code; confirm before changing it.
725 size_t bufferSize = 2;
729 return handshake::BufferOp::create(rewriter, loc, cond, bufferSize,
730 BufferTypeEnum::fifo);
733LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
734 Block *
mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
736 auto ctrlMerges =
mergeBlock->getOps<handshake::ControlMergeOp>();
737 assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
739 handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
741 if (ctrlMerge.getNumOperands() != 2)
742 return ctrlMerge.emitError(
"expected cmerges to have two operands");
743 rewriter.setInsertionPointAfter(ctrlMerge);
744 Location loc = ctrlMerge->getLoc();
749 bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
750 SmallVector<Value> muxOperands;
752 muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
754 muxOperands = llvm::to_vector(ctrlMerge.getOperands());
756 Value newCtrl = handshake::MuxOp::create(rewriter, loc, buf, muxOperands);
758 Value cond = buf.getResult();
763 cond = arith::XOrIOp::create(
764 rewriter, loc, cond.getType(), cond,
765 arith::ConstantOp::create(
766 rewriter, loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
771 arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), cond);
776 rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
780bool FeedForwardNetworkRewriter::requiresOperandFlip(
781 ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
782 assert(ctrlMerge.getNumOperands() == 2 &&
783 "Loops should already have been handled");
785 Value fstOperand = ctrlMerge.getOperand(0);
787 assert(ctrlBr.getTrueResult().hasOneUse() &&
788 "expected the result of a branch to only have one user");
789 Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
790 if (trueUser == ctrlBr)
792 return ctrlBr.getTrueResult() == fstOperand;
796 Block *trueBlock = trueUser->getBlock();
797 return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
810class LoopNetworkRewriter {
814 LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
819 using ExitPair = std::pair<Block *, Block *>;
820 LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
829 BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
835 void buildExitNetwork(Block *loopHeader,
837 BufferOp loopPrimingRegister,
841 ConversionPatternRewriter *rewriter =
nullptr;
848 return LoopNetworkRewriter(*this).processRegion(
r, rewriter);
852LoopNetworkRewriter::processRegion(Region &r,
853 ConversionPatternRewriter &rewriter) {
857 this->rewriter = &rewriter;
859 Operation *op = r.getParentOp();
861 DominanceInfo domInfo(op);
862 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
864 for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
865 if (!loop->getLoopLatch())
866 return emitError(op->getLoc()) <<
"Multiple loop latches detected "
867 "(backedges from within the loop "
868 "to the loop header). Loop task "
869 "pipelining is only supported for "
870 "loops with unified loop latches.";
873 if (failed(processOuterLoop(op->getLoc(), loop)))
882 auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
883 return block == operand.getParentBlock();
886 inValueIt != mux.getDataOperands().end() &&
887 "Expected mux to have an operand originating from the requested block.");
895 std::vector<Value> sortedOperands;
896 for (
auto in : cmerge.getOperands()) {
897 auto *srcBlock = in.getParentBlock();
902 for (
unsigned i = 0; i < sortedOperands.size(); ++i) {
903 for (
unsigned j = 0; j < sortedOperands.size(); ++j) {
906 assert(sortedOperands[i] != sortedOperands[j] &&
907 "Cannot have an identical operand from two different blocks!");
911 return sortedOperands;
914BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
921 llvm::SmallVector<MuxOp> muxesToReplace;
922 llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));
928 assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
929 "Expected control merge to be the control component of a loop header");
930 auto loc = cmerge->getLoc();
933 assert(cmerge->getNumOperands() > 1 &&
"This cannot be a loop header");
937 SmallVector<Value> externalCtrls, loopCtrls;
938 for (
auto cval : cmerge->getOperands()) {
939 if (cval.getParentBlock() == loopLatch)
940 loopCtrls.push_back(cval);
942 externalCtrls.push_back(cval);
944 assert(loopCtrls.size() == 1 &&
945 "Expected a single loop control value to match the single loop latch");
946 Value loopCtrl = loopCtrls.front();
949 rewriter->setInsertionPointToStart(loopHeader);
950 auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);
955 auto primingRegister =
956 rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
958 primingRegister->setAttr(
"initValues", rewriter->getI64ArrayAttr({0}));
962 auto loopCtrlMux = rewriter->create<MuxOp>(
963 loc, primingRegister.getResult(),
964 llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});
968 cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());
971 hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());
982 DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
983 DenseMap<MuxOp, Value> loopDataInputs;
984 for (
auto muxOp : muxesToReplace) {
985 if (muxOp == loopCtrlMux)
990 assert( 1 + externalDataInputs[muxOp].size() ==
991 muxOp.getDataOperands().size() &&
992 "Expected all mux operands to be partitioned between loop and "
993 "external data inputs");
1001 for (MuxOp mux : muxesToReplace) {
1002 auto externalDataMux = rewriter->create<MuxOp>(
1003 loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);
1005 rewriter->replaceOp(
1007 ->create<MuxOp>(loc, primingRegister,
1008 llvm::SmallVector<Value>{externalDataMux,
1009 loopDataInputs[mux]})
1015 rewriter->eraseOp(cmerge);
1018 return primingRegister;
1021void LoopNetworkRewriter::buildExitNetwork(
1023 BufferOp loopPrimingRegister,
Backedge &loopPrimingInput) {
1024 auto loc = loopPrimingRegister.getLoc();
1033 SmallVector<Value> parityCorrectedConds;
1034 for (
auto &[condBlock, exitBlock] : exitPairs) {
1038 "Expected a conditional control branch op in the loop condition block");
1039 Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
1040 bool isTrueParity = trueUser->getBlock() == exitBlock;
1042 ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
1044 "The user of either the true or the false result should be in the "
1047 Value condValue = condBr.getConditionOperand();
1051 rewriter->setInsertionPoint(condBr);
1052 condValue = rewriter->create<arith::XOrIOp>(
1053 loc, condValue.getType(), condValue,
1054 rewriter->create<arith::ConstantOp>(
1055 loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
1057 parityCorrectedConds.push_back(condValue);
1062 auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
1063 loopPrimingInput.
setValue(exitMerge);
1066LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1071 SmallVector<Block *> exitBlocks;
1072 loop->getExitBlocks(exitBlocks);
1073 for (
auto *exitNode : exitBlocks) {
1074 for (
auto *pred : exitNode->getPredecessors()) {
1076 if (!loop->contains(pred))
1079 ExitPair condPair = {pred, exitNode};
1080 assert(!exitPairs.count(condPair) &&
1081 "identical condition pairs should never be possible");
1082 exitPairs.insert(condPair);
1085 assert(!exitPairs.empty() &&
"No exits from loop?");
1089 if (exitPairs.size() > 1)
1090 return emitError(loc)
1091 <<
"Multiple exits detected within a loop. Loop task pipelining is "
1092 "only supported for loops with unified loop exit blocks.";
1094 Block *header = loop->getHeader();
1099 auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1100 auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1101 loopPrimingRegisterInput);
1105 buildExitNetwork(header, exitPairs, loopPrimingRegister,
1106 loopPrimingRegisterInput);
1115 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1116 if (condBranchOp.getTrueDest() == succBlock)
1117 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1119 assert(condBranchOp.getFalseDest() == succBlock);
1120 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1124 return newOp->getResult(0);
1132 for (Block &block :
r) {
1133 for (Operation &op : block) {
1134 for (
auto result : op.getResults())
1136 liveOuts[&block].push_back(result);
1140 for (Block &block :
r) {
1141 Operation *termOp = block.getTerminator();
1142 rewriter.setInsertionPoint(termOp);
1144 for (Value val : liveOuts[&block]) {
1149 for (
int i = 0, e = numBranches; i < e; ++i) {
1150 Operation *newOp =
nullptr;
1152 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
1153 newOp = handshake::ConditionalBranchOp::create(
1154 rewriter, termOp->getLoc(), condBranchOp.getCondition(), val);
1155 else if (isa<mlir::cf::BranchOp>(termOp))
1156 newOp = handshake::BranchOp::create(rewriter, termOp->getLoc(), val);
1158 if (newOp ==
nullptr)
1161 for (
int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
1162 Block *succ = block.getSuccessor(j);
1165 for (
auto &u : val.getUses()) {
1166 if (u.getOwner()->getBlock() == succ) {
1167 u.getOwner()->replaceUsesOfWith(val, res);
1180 ConversionPatternRewriter &rewriter,
bool sourceConstants) {
1186 if (sourceConstants) {
1187 for (
auto constantOp : llvm::make_early_inc_range(
1188 r.template getOps<mlir::arith::ConstantOp>())) {
1189 rewriter.setInsertionPointAfter(constantOp);
1190 auto value = constantOp.getValue();
1191 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1192 constantOp, value.getType(), value,
1193 handshake::SourceOp::create(rewriter, constantOp.getLoc(),
1194 rewriter.getNoneType()));
1197 for (Block &block :
r) {
1199 for (
auto constantOp : llvm::make_early_inc_range(
1200 block.template getOps<mlir::arith::ConstantOp>())) {
1201 rewriter.setInsertionPointAfter(constantOp);
1202 auto value = constantOp.getValue();
1203 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1204 constantOp, value.getType(), value, blockEntryCtrl);
1220 : op(op), ctrlOperand(ctrlOperand) {
1221 assert(op && ctrlOperand);
1222 assert(isa<NoneType>(ctrlOperand.getType()) &&
1223 "Control operand must be a NoneType");
1236 for (Operation &op : *block) {
1237 if (
auto branchOp = dyn_cast<handshake::BranchOp>(op))
1238 if (branchOp.isControl())
1239 return {branchOp, branchOp.getDataOperand()};
1240 if (
auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1241 if (branchOp.isControl())
1242 return {branchOp, branchOp.getDataOperand()};
1243 if (
auto endOp = dyn_cast<handshake::ReturnOp>(op))
1244 return {endOp, endOp.getOperands().back()};
1246 llvm_unreachable(
"Block terminator must exist");
1251 if (
auto memOp = dyn_cast<memref::LoadOp>(op))
1252 out = memOp.getMemRef();
1253 else if (
auto memOp = dyn_cast<memref::StoreOp>(op))
1254 out = memOp.getMemRef();
1255 else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1256 MemRefAccess access(op);
1257 out = access.memref;
1261 return op->emitOpError(
"Unknown Op type");
1265 return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1266 AffineWriteOpInterface>(op);
1273 std::vector<Operation *> opsToErase;
1276 for (
auto arg :
r.getArguments()) {
1277 auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
1283 memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
1289 for (Operation &op :
r.getOps()) {
1293 rewriter.setInsertionPoint(&op);
1297 Operation *newOp =
nullptr;
1299 llvm::TypeSwitch<Operation *>(&op)
1300 .Case<memref::LoadOp>([&](
auto loadOp) {
1303 SmallVector<Value, 8> operands(loadOp.getIndices());
1305 newOp = handshake::LoadOp::create(rewriter, op.getLoc(), memref,
1307 op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
1309 .Case<memref::StoreOp>([&](
auto storeOp) {
1312 SmallVector<Value, 8> operands(storeOp.getIndices());
1315 newOp = handshake::StoreOp::create(
1316 rewriter, op.getLoc(), storeOp.getValueToStore(), operands);
1318 .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](
auto) {
1320 MemRefAccess access(&op);
1327 if (
auto loadOp = dyn_cast<AffineReadOpInterface>(op))
1328 map = loadOp.getAffineMap();
1330 map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();
1338 expandAffineMap(rewriter, op.getLoc(), map, access.indices);
1339 assert(operands &&
"Address operands of affine memref access "
1340 "cannot be reduced.");
1342 if (isa<AffineReadOpInterface>(op)) {
1343 auto loadOp = handshake::LoadOp::create(rewriter, op.getLoc(),
1344 access.memref, *operands);
1346 op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
1348 newOp = handshake::StoreOp::create(rewriter, op.getLoc(),
1349 op.getOperand(0), *operands);
1352 .Default([&](
auto) {
1353 op.emitOpError(
"Load/store operation cannot be handled.");
1356 memRefOps[memref].push_back(newOp);
1357 opsToErase.push_back(&op);
1361 for (
unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
1362 auto *op = opsToErase[i];
1363 for (
int j = 0, e = op->getNumOperands(); j < e; ++j)
1364 op->eraseOperand(0);
1365 assert(op->getNumOperands() == 0);
1367 rewriter.eraseOp(op);
1376 if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1379 SmallVector<Value, 8> results(loadOp.getAddressResults());
1384 assert(dyn_cast<handshake::StoreOp>(op));
1385 handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1386 SmallVector<Value, 8> results(storeOp.getResults());
1393 for (Block &block : f) {
1395 if (!ctrl.hasOneUse())
1401 ConversionPatternRewriter &rewriter) {
1402 std::vector<Operation *> opsToDelete;
1405 for (
auto &op : r.getOps())
1406 if (
isAllocOp(&op) && op.getResult(0).use_empty())
1407 opsToDelete.push_back(&op);
1409 llvm::for_each(opsToDelete, [&](
auto allocOp) { rewriter.eraseOp(allocOp); });
1413 ArrayRef<BlockControlTerm> controlTerms) {
1414 for (
auto term : controlTerms) {
1415 auto &[op, ctrl] = term;
1416 auto *srcOp = ctrl.getDefiningOp();
1419 if (!isa<JoinOp>(srcOp)) {
1420 rewriter.setInsertionPointAfter(srcOp);
1421 Operation *newJoin = JoinOp::create(rewriter, srcOp->getLoc(), ctrl);
1422 op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1427static std::vector<BlockControlTerm>
1429 std::vector<BlockControlTerm> terminators;
1431 for (Operation *op : memOps) {
1433 Block *block = op->getBlock();
1436 if (std::find(terminators.begin(), terminators.end(), term) ==
1438 terminators.push_back(term);
1445 SmallVector<Value, 8> results(op->getOperands());
1446 results.push_back(val);
1447 op->setOperands(results);
1453 for (
auto *op : memOps) {
1454 if (isa<handshake::LoadOp>(op))
1460 Operation *memOp,
int offset,
1461 ArrayRef<int> cntrlInd) {
1464 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1465 auto *op = memOps[i];
1467 auto *srcOp = ctrl.getDefiningOp();
1468 if (!isa<JoinOp>(srcOp)) {
1469 return srcOp->emitOpError(
"Op expected to be a JoinOp");
1477 ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1478 Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
1479 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1480 std::vector<Value> controlOperands;
1481 Operation *currOp = memOps[i];
1482 Block *currBlock = currOp->getBlock();
1486 controlOperands.push_back(blockEntryCtrl);
1489 for (
int j = 0, f = i; j < f; ++j) {
1490 Operation *predOp = memOps[j];
1491 Block *predBlock = predOp->getBlock();
1492 if (currBlock == predBlock)
1494 if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1496 controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1500 if (controlOperands.size() == 1)
1505 rewriter.setInsertionPoint(currOp);
1507 JoinOp::create(rewriter, currOp->getLoc(), controlOperands);
1519 for (
auto memory : memRefOps) {
1521 Value memrefOperand = memory.first;
1525 bool isExternalMemory = isa<BlockArgument>(memrefOperand);
1527 mlir::MemRefType memrefType =
1528 cast<mlir::MemRefType>(memrefOperand.getType());
1532 std::vector<Value> operands;
1535 std::vector<BlockControlTerm> controlTerms =
1541 for (
auto valOp : controlTerms)
1542 operands.push_back(valOp.ctrlOperand);
1554 std::vector<int> newInd(memory.second.size(), 0);
1556 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1557 auto *op = memory.second[i];
1558 if (isa<handshake::StoreOp>(op)) {
1560 operands.insert(operands.end(), results.begin(), results.end());
1567 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1568 auto *op = memory.second[i];
1569 if (isa<handshake::LoadOp>(op)) {
1571 operands.insert(operands.end(), results.begin(), results.end());
1579 int cntrl_count = lsq ? 0 : memory.second.size();
1581 Block *entryBlock = &
r.front();
1582 rewriter.setInsertionPointToStart(entryBlock);
1585 Operation *newOp =
nullptr;
1586 if (isExternalMemory)
1587 newOp = ExternalMemoryOp::create(rewriter, entryBlock->front().getLoc(),
1588 memrefOperand, operands, ld_count,
1589 cntrl_count - ld_count, mem_count++);
1591 newOp = MemoryOp::create(rewriter, entryBlock->front().getLoc(), operands,
1592 ld_count, cntrl_count, lsq, mem_count++,
1607 bool control =
true;
1630 for (Block &block :
r) {
1634 for (Operation &op : block) {
1635 if (
auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1636 llvm::SmallVector<Value> operands;
1637 llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1638 operands.push_back(blockEntryControl);
1639 rewriter.setInsertionPoint(callOp);
1640 auto instanceOp = handshake::InstanceOp::create(
1641 rewriter, callOp.getLoc(), callOp.getCallee(),
1642 callOp.getResultTypes(), operands);
1644 for (
auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1645 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1646 rewriter.eraseOp(callOp);
1659 bool maximizeArgument(BlockArgument arg)
override {
1660 return !isa<mlir::MemRefType>(arg.getType());
1664 bool maximizeOp(Operation *op)
override {
return !
isAllocOp(op); }
1672 ConversionPatternRewriter &rewriter) {
1673 HandshakeLoweringSSAStrategy
strategy;
1677static LogicalResult
lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1678 bool sourceConstants,
1679 bool disableTaskPipelining) {
1681 SmallVector<NamedAttribute, 4> attributes;
1682 for (
const auto &attr : funcOp->getAttrs()) {
1683 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1684 attr.getName() == funcOp.getFunctionTypeAttrName())
1686 attributes.push_back(attr);
1690 llvm::SmallVector<mlir::Type, 8> argTypes;
1691 for (
auto &argType : funcOp.getArgumentTypes())
1692 argTypes.push_back(argType);
1695 llvm::SmallVector<mlir::Type, 8> resTypes;
1696 for (
auto resType : funcOp.getResultTypes())
1697 resTypes.push_back(resType);
1703 if (partiallyLowerOp<func::FuncOp>(
1704 [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1705 auto noneType = rewriter.getNoneType();
1706 resTypes.push_back(noneType);
1707 argTypes.push_back(noneType);
1708 auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1710 funcOp.getName(), func_type,
1712 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1714 if (!newFuncOp.isExternal()) {
1715 newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1717 newFuncOp.resolveArgAndResNames();
1719 rewriter.eraseOp(funcOp);
1730 if (!newFuncOp.isExternal()) {
1731 Block *bodyBlock = newFuncOp.getBodyBlock();
1732 Value entryCtrl = bodyBlock->getArguments().back();
1734 if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1735 fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1744struct HandshakeRemoveBlockPass
1745 : circt::impl::HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1749struct CFToHandshakePass
1750 :
public circt::impl::CFToHandshakeBase<CFToHandshakePass> {
1751 CFToHandshakePass(
bool sourceConstants,
bool disableTaskPipelining) {
1752 this->sourceConstants = sourceConstants;
1753 this->disableTaskPipelining = disableTaskPipelining;
1755 void runOnOperation()
override {
1756 ModuleOp m = getOperation();
1758 for (
auto funcOp :
llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1759 if (failed(
lowerFuncOp(funcOp, &getContext(), sourceConstants,
1760 disableTaskPipelining))) {
1761 signalPassFailure();
1770std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1772 bool disableTaskPipelining) {
1773 return std::make_unique<CFToHandshakePass>(sourceConstants,
1774 disableTaskPipelining);
1777std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1779 return std::make_unique<HandshakeRemoveBlockPass>();
static ConditionalBranchOp getControlCondBranch(Block *block)
static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, bool sourceConstants, bool disableTaskPipelining)
static Operation * getControlMerge(Block *block)
static bool isMemoryOp(Operation *op)
static std::vector< Value > getSortedInputs(ControlMergeOp cmerge, MuxOp mux)
static LogicalResult setJoinControlInputs(ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
static void addJoinOps(ConversionPatternRewriter &rewriter, ArrayRef< BlockControlTerm > controlTerms)
static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter)
static bool isLiveOut(Value val)
static Operation * getFirstOp(Block *block)
Returns the first occurrence of an operation of type TOp; otherwise, returns a null op.
static unsigned getBlockPredecessorCount(Block *block)
static int getBranchCount(Value val, Block *block)
static Operation * findBranchToBlock(Block *block)
static Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock)
static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo, Block *predBlock)
static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type)
static bool isAllocOp(Operation *op)
static Value getOperandFromBlock(MuxOp mux, Block *block)
static LogicalResult getOpMemRef(Operation *op, Value &out)
static LogicalResult partiallyLowerOp(const std::function< LogicalResult(TOp, ConversionPatternRewriter &)> &loweringFunc, MLIRContext *ctx, TOp op)
static void addValueToOperands(Operation *op, Value val)
static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo)
static SmallVector< Value, 8 > getResultsToMemory(Operation *op)
static void removeBlockOperands(Region &f)
static void removeUnusedAllocOps(Region &r, ConversionPatternRewriter &rewriter)
static LogicalResult maximizeSSANoMem(Region &r, ConversionPatternRewriter &rewriter)
Converts every value in the region into maximal SSA form, unless the value is a block argument of typ...
static BlockControlTerm getBlockControlTerminator(Block *block)
static std::vector< BlockControlTerm > getControlTerminators(ArrayRef< Operation * > memOps)
static void reconnectMergeOps(Region &r, HandshakeLowering::BlockOps blockMerges, HandshakeLowering::ValueMap &mergePairs)
static void setLoadDataInputs(ArrayRef< Operation * > memOps, Operation *memOp)
assert(baseType &&"element must be base type")
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Moves all operations from a source block into a destination block.
static Location getLoc(DefSlot slot)
LowerRegionTarget(MLIRContext &context, Region &region)
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Strategy class to control the behavior of SSA maximization.
BlockOps insertMergeOps(ValueMap &mergePairs, BackedgeBuilder &edgeBuilder, ConversionPatternRewriter &rewriter)
MergeOpInfo insertMerge(Block *block, Value val, BackedgeBuilder &edgeBuilder, ConversionPatternRewriter &rewriter)
LogicalResult loopNetworkRewriting(ConversionPatternRewriter &rewriter)
DenseMap< Block *, std::vector< MergeOpInfo > > BlockOps
LogicalResult feedForwardRewriting(ConversionPatternRewriter &rewriter)
LogicalResult replaceCallOps(ConversionPatternRewriter &rewriter)
void setMemOpControlInputs(ConversionPatternRewriter &rewriter, ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
LogicalResult addMergeOps(ConversionPatternRewriter &rewriter)
LogicalResult replaceMemoryOps(ConversionPatternRewriter &rewriter, MemRefToMemoryAccessOp &memRefOps)
DenseMap< Block *, std::vector< Value > > BlockValues
LogicalResult connectConstantsToControl(ConversionPatternRewriter &rewriter, bool sourceConstants)
llvm::MapVector< Value, std::vector< Operation * > > MemRefToMemoryAccessOp
LogicalResult runSSAMaximization(ConversionPatternRewriter &rewriter, Value entryCtrl)
DenseMap< Block *, Value > blockEntryControlMap
DenseMap< Value, Value > ValueMap
Value getBlockEntryControl(Block *block) const
void setBlockEntryControl(Block *block, Value v)
LogicalResult connectToMemory(ConversionPatternRewriter &rewriter, MemRefToMemoryAccessOp memRefOps, bool lsq)
LogicalResult addBranchOps(ConversionPatternRewriter &rewriter)
FuncOp create(Union[StringAttr, str] sym_name, List[Tuple[str, Type]] args, List[Tuple[str, Type]] results, Dict[str, Attribute] attributes={}, loc=None, ip=None)
void insertFork(Value result, bool isLazy, OpBuilder &rewriter)
Adds fork operations to any value with multiple uses in r.
llvm::function_ref< LogicalResult(Region &, ConversionPatternRewriter &)> RegionLoweringFunc
void removeBasicBlocks(Region &r)
Remove basic blocks inside the given region.
LogicalResult partiallyLowerRegion(const RegionLoweringFunc &loweringFunc, MLIRContext *ctx, Region &r)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createCFToHandshakePass(bool sourceConstants=false, bool disableTaskPipelining=false)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeRemoveBlockPass()
Holds information about a handshake "basic block terminator" control operation.
friend bool operator==(const BlockControlTerm &lhs, const BlockControlTerm &rhs)
Checks for member-wise equality.
Value ctrlOperand
The operation's control operand (must have type NoneType)
BlockControlTerm(Operation *op, Value ctrlOperand)
Operation * op
The operation.
Allows to partially lower a region by matching on the parent operation to then call the provided part...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value >, ConversionPatternRewriter &rewriter) const override
std::function< LogicalResult(Region &, ConversionPatternRewriter &)> PartialLoweringFunc
PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context, LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
LogicalResult & loweringRes
LowerRegionTarget & target