13 #include "../PassDetail.h"
18 #include "mlir/Analysis/CFGLoopInfo.h"
19 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
20 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
21 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/SCF/IR/SCF.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Diagnostics.h"
33 #include "mlir/IR/Dominance.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/Types.h"
37 #include "mlir/IR/Value.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Support/LLVM.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/Passes.h"
42 #include "llvm/ADT/SmallSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/Support/raw_ostream.h"
52 using namespace circt;
61 template <
typename TOp>
62 class LowerOpTarget :
public ConversionTarget {
64 explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
66 addLegalDialect<HandshakeDialect>();
67 addLegalDialect<mlir::func::FuncDialect>();
68 addLegalDialect<mlir::arith::ArithDialect>();
69 addIllegalDialect<mlir::scf::SCFDialect>();
70 addIllegalDialect<AffineDialect>();
75 addDynamicallyLegalOp<TOp>([&](
const auto &op) {
return loweredOps[op]; });
77 DenseMap<Operation *, bool> loweredOps;
94 template <
typename TOp>
95 struct PartialLowerOp :
public ConversionPattern {
96 using PartialLoweringFunc =
97 std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
100 PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
101 LogicalResult &loweringResRef,
const PartialLoweringFunc &fun)
102 : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
103 loweringRes(loweringResRef), fun(fun) {}
104 using ConversionPattern::ConversionPattern;
106 matchAndRewrite(Operation *op, ArrayRef<Value> ,
107 ConversionPatternRewriter &rewriter)
const override {
109 loweringRes = fun(dyn_cast<TOp>(op), rewriter);
110 target.loweredOps[op] =
true;
115 LowerOpTarget<TOp> ⌖
116 LogicalResult &loweringRes;
118 PartialLoweringFunc fun;
124 template <
typename TOp>
126 const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
128 MLIRContext *ctx, TOp op) {
131 auto target = LowerOpTarget<TOp>(*ctx);
132 LogicalResult partialLoweringSuccessfull = success();
133 patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
136 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
137 partialLoweringSuccessfull.succeeded());
143 : ConversionTarget(context), region(region) {
146 markUnknownOpDynamicallyLegal([&](Operation *op) {
147 if (op != region.getParentOp())
152 bool opLowered =
false;
163 std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
167 LogicalResult &loweringResRef,
169 : ConversionPattern(target.region.getParentOp()->
getName().getStringRef(),
171 target(target), loweringRes(loweringResRef), fun(fun) {}
172 using ConversionPattern::ConversionPattern;
175 ConversionPatternRewriter &rewriter)
const override {
176 rewriter.modifyOpInPlace(
177 op, [&] { loweringRes = fun(target.region, rewriter); });
179 target.opLowered =
true;
191 MLIRContext *ctx, Region &r) {
193 Operation *op = r.getParentOp();
196 LogicalResult partialLoweringSuccessfull = success();
200 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
201 partialLoweringSuccessfull.succeeded());
208 Value HandshakeLowering::getBlockEntryControl(Block *block)
const {
209 auto it = blockEntryControlMap.find(block);
210 assert(it != blockEntryControlMap.end() &&
211 "No block entry control value registerred for this block!");
215 void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
216 blockEntryControlMap[block] = v;
220 Block *entryBlock = &r.front();
221 auto &entryBlockOps = entryBlock->getOperations();
224 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
225 entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
228 block.dropAllDefinedValueUses();
229 for (
size_t i = 0; i < block.getNumArguments(); i++) {
230 block.eraseArgument(i);
237 for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
238 if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
241 if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
242 terminatorLike.erase();
247 terminatorLike.moveBefore(entryBlock, entryBlock->end());
252 HandshakeLowering::runSSAMaximization(ConversionPatternRewriter &rewriter,
258 if (funcOp.isExternal())
265 if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
266 return emitError(loc) <<
"memref's must be both statically sized and "
273 auto predecessors = block->getPredecessors();
274 return std::distance(predecessors.begin(), predecessors.end());
280 HandshakeLowering::insertMerge(Block *block, Value val,
282 ConversionPatternRewriter &rewriter) {
284 auto insertLoc = block->front().getLoc();
285 SmallVector<Backedge> dataEdges;
286 SmallVector<Value> operands;
290 if (val == getBlockEntryControl(block)) {
293 if (block == &r.front()) {
301 operands.push_back(val);
302 mergeOp = rewriter.create<handshake::MergeOp>(insertLoc, operands);
304 for (
unsigned i = 0; i < numPredecessors; i++) {
305 auto edge = edgeBuilder.
get(rewriter.getNoneType());
306 dataEdges.push_back(edge);
307 operands.push_back(Value(edge));
309 mergeOp = rewriter.create<handshake::ControlMergeOp>(insertLoc, operands);
311 setBlockEntryControl(block, mergeOp->getResult(0));
320 if (numPredecessors <= 1) {
321 if (numPredecessors == 0) {
325 operands.push_back(val);
329 auto edge = edgeBuilder.
get(val.getType());
330 dataEdges.push_back(edge);
331 operands.push_back(Value(edge));
333 auto merge = rewriter.create<handshake::MergeOp>(insertLoc, operands);
341 Backedge indexEdge = edgeBuilder.
get(rewriter.getIndexType());
342 for (
unsigned i = 0; i < numPredecessors; i++) {
343 auto edge = edgeBuilder.
get(val.getType());
344 dataEdges.push_back(edge);
345 operands.push_back(Value(edge));
348 rewriter.create<handshake::MuxOp>(insertLoc, Value(indexEdge), operands);
349 return MergeOpInfo{mux, val, dataEdges, indexEdge};
355 ConversionPatternRewriter &rewriter) {
357 for (Block &block : r) {
358 rewriter.setInsertionPointToStart(&block);
362 for (
auto &arg : block.getArguments()) {
364 if (arg.getType().isa<mlir::MemRefType>())
367 auto mergeInfo = insertMerge(&block, arg, edgeBuilder, rewriter);
368 blockMerges[&block].push_back(mergeInfo);
369 mergePairs[arg] = mergeInfo.op->getResult(0);
379 Value srcVal = mergeInfo.
val;
381 Block *block = mergeInfo.
op->getBlock();
386 unsigned index = srcVal.cast<BlockArgument>().getArgNumber();
387 Operation *termOp = predBlock->getTerminator();
388 if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
390 if (block == br.getTrueDest())
391 return br.getTrueOperand(index);
392 assert(block == br.getFalseDest());
393 return br.getFalseOperand(index);
395 if (isa<mlir::cf::BranchOp>(termOp))
396 return termOp->getOperand(index);
403 for (Block &block : f) {
404 if (!block.isEntryBlock()) {
405 int x = block.getNumArguments() - 1;
406 for (
int i = x; i >= 0; --i)
407 block.eraseArgument(i);
414 template <
typename TOp>
416 auto ops = block->getOps<TOp>();
423 return getFirstOp<ControlMergeOp>(block);
427 for (
auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
428 if (cbranch.isControl())
442 for (Block &block : r) {
443 for (
auto &mergeInfo : blockMerges[&block]) {
446 for (
auto *predBlock : block.getPredecessors()) {
448 assert(mgOperand !=
nullptr);
449 if (!mgOperand.getDefiningOp()) {
450 assert(mergePairs.count(mgOperand));
451 mgOperand = mergePairs[mgOperand];
453 mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
459 for (Operation &opp : block)
460 if (!isa<MergeLikeOpInterface>(opp))
461 opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
467 for (Block &block : r) {
470 assert(cntrlMg !=
nullptr);
472 for (
auto &mergeInfo : blockMerges[&block]) {
473 if (mergeInfo.op != cntrlMg) {
477 assert(mergeInfo.indexEdge.has_value());
478 (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
488 return isa<memref::AllocOp, memref::AllocaOp>(op);
492 HandshakeLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
503 BlockOps mergeOps = insertMergeOps(mergePairs, edgeBuilder, rewriter);
513 for (
auto &u : val.getUses())
515 if (isa<MergeLikeOpInterface>(u.getOwner()))
526 for (
int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
528 Block *succ = block->getSuccessor(i);
529 for (
auto &u : val.getUses()) {
530 if (u.getOwner()->getBlock() == succ)
533 uses = (curr > uses) ? curr : uses;
547 class FeedForwardNetworkRewriter {
550 ConversionPatternRewriter &rewriter)
551 : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
552 domInfo(hl.getRegion().getParentOp()),
553 loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
554 LogicalResult apply();
558 ConversionPatternRewriter &rewriter;
559 PostDominanceInfo postDomInfo;
560 DominanceInfo domInfo;
561 CFGLoopInfo loopInfo;
563 using BlockPair = std::pair<Block *, Block *>;
564 using BlockPairs = SmallVector<BlockPair>;
565 LogicalResult findBlockPairs(BlockPairs &blockPairs);
567 BufferOp buildSplitNetwork(Block *splitBlock,
568 handshake::ConditionalBranchOp &ctrlBr);
569 LogicalResult buildMergeNetwork(Block *
mergeBlock, BufferOp buf,
570 handshake::ConditionalBranchOp &ctrlBr);
573 bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
574 handshake::ConditionalBranchOp &ctrlBr);
575 bool formsIrreducibleCF(Block *splitBlock, Block *
mergeBlock);
580 HandshakeLowering::feedForwardRewriting(ConversionPatternRewriter &rewriter) {
582 if (this->getRegion().hasOneBlock())
584 return FeedForwardNetworkRewriter(*
this, rewriter).apply();
588 for (CFGLoop *loop : loopInfo.getTopLevelLoops())
589 if (!loop->getExitBlock())
594 bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
596 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
597 for (
auto *mergePred :
mergeBlock->getPredecessors()) {
599 if (loop && loop->contains(mergePred))
607 if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
608 if (splitSucc == mergeBlock || mergePred == splitBlock)
610 return domInfo.dominates(splitSucc, mergePred);
618 Block *pred = *block->getPredecessors().begin();
619 return pred->getTerminator();
623 FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
627 Region &r = hl.getRegion();
628 Operation *parentOp = r.getParentOp();
633 "expected loop to only have one exit block.");
636 if (b.getNumSuccessors() < 2)
640 if (loopInfo.getLoopFor(&b))
643 assert(b.getNumSuccessors() == 2);
644 Block *succ0 = b.getSuccessor(0);
645 Block *succ1 = b.getSuccessor(1);
650 Block *
mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
654 return parentOp->emitError(
"expected only reducible control flow.")
656 <<
"This branch is involved in the irreducible control flow";
659 unsigned nonLoopPreds = 0;
660 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
661 for (
auto *pred :
mergeBlock->getPredecessors()) {
662 if (loop && loop->contains(pred))
666 if (nonLoopPreds > 2)
668 ->emitError(
"expected a merge block to have two predecessors. "
669 "Did you run the merge block insertion pass?")
671 <<
"This branch jumps to the illegal block";
679 LogicalResult FeedForwardNetworkRewriter::apply() {
682 if (failed(findBlockPairs(pairs)))
686 handshake::ConditionalBranchOp ctrlBr;
687 BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
688 if (failed(buildMergeNetwork(
mergeBlock, buffer, ctrlBr)))
695 BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
696 Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
697 SmallVector<handshake::ConditionalBranchOp> branches;
698 llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
699 std::back_inserter(branches));
701 auto *findRes = llvm::find_if(branches, [](
auto br) {
702 return br.getDataOperand().getType().
template isa<NoneType>();
705 assert(findRes &&
"expected one branch for the ctrl signal");
708 Value cond = ctrlBr.getConditionOperand();
709 assert(llvm::all_of(branches, [&](
auto branch) {
710 return branch.getConditionOperand() == cond;
713 Location loc = cond.getLoc();
714 rewriter.setInsertionPointAfterValue(cond);
718 size_t bufferSize = 2;
722 return rewriter.create<handshake::BufferOp>(loc, cond, bufferSize,
723 BufferTypeEnum::fifo);
726 LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
727 Block *
mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
729 auto ctrlMerges =
mergeBlock->getOps<handshake::ControlMergeOp>();
730 assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
732 handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
734 if (ctrlMerge.getNumOperands() != 2)
735 return ctrlMerge.emitError(
"expected cmerges to have two operands");
736 rewriter.setInsertionPointAfter(ctrlMerge);
737 Location loc = ctrlMerge->getLoc();
742 bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
743 SmallVector<Value> muxOperands;
745 muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
747 muxOperands = llvm::to_vector(ctrlMerge.getOperands());
749 Value newCtrl = rewriter.create<handshake::MuxOp>(loc, buf, muxOperands);
751 Value cond = buf.getResult();
756 cond = rewriter.create<arith::XOrIOp>(
757 loc, cond.getType(), cond,
758 rewriter.create<arith::ConstantOp>(
759 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
764 rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), cond);
769 rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
773 bool FeedForwardNetworkRewriter::requiresOperandFlip(
774 ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
775 assert(ctrlMerge.getNumOperands() == 2 &&
776 "Loops should already have been handled");
778 Value fstOperand = ctrlMerge.getOperand(0);
780 assert(ctrlBr.getTrueResult().hasOneUse() &&
781 "expected the result of a branch to only have one user");
782 Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
783 if (trueUser == ctrlBr)
785 return ctrlBr.getTrueResult() == fstOperand;
789 Block *trueBlock = trueUser->getBlock();
790 return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
803 class LoopNetworkRewriter {
807 LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
812 using ExitPair = std::pair<Block *, Block *>;
813 LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
822 BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
828 void buildExitNetwork(Block *loopHeader,
830 BufferOp loopPrimingRegister,
834 ConversionPatternRewriter *rewriter =
nullptr;
840 HandshakeLowering::loopNetworkRewriting(ConversionPatternRewriter &rewriter) {
841 return LoopNetworkRewriter(*this).processRegion(r, rewriter);
845 LoopNetworkRewriter::processRegion(Region &r,
846 ConversionPatternRewriter &rewriter) {
850 this->rewriter = &rewriter;
852 Operation *op = r.getParentOp();
854 DominanceInfo domInfo(op);
855 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
857 for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
858 if (!loop->getLoopLatch())
859 return emitError(op->getLoc()) <<
"Multiple loop latches detected "
860 "(backedges from within the loop "
861 "to the loop header). Loop task "
862 "pipelining is only supported for "
863 "loops with unified loop latches.";
866 if (failed(processOuterLoop(op->getLoc(), loop)))
875 auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
876 return block == operand.getParentBlock();
879 inValueIt != mux.getOperands().end() &&
880 "Expected mux to have an operand originating from the requested block.");
888 std::vector<Value> sortedOperands;
889 for (
auto in : cmerge.getOperands()) {
890 auto *srcBlock = in.getParentBlock();
895 for (
unsigned i = 0; i < sortedOperands.size(); ++i) {
896 for (
unsigned j = 0; j < sortedOperands.size(); ++j) {
899 assert(sortedOperands[i] != sortedOperands[j] &&
900 "Cannot have an identical operand from two different blocks!");
904 return sortedOperands;
907 BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
914 llvm::SmallVector<MuxOp> muxesToReplace;
915 llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));
921 assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
922 "Expected control merge to be the control component of a loop header");
923 auto loc = cmerge->getLoc();
926 assert(cmerge->getNumOperands() > 1 &&
"This cannot be a loop header");
930 SmallVector<Value> externalCtrls, loopCtrls;
931 for (
auto cval : cmerge->getOperands()) {
932 if (cval.getParentBlock() == loopLatch)
933 loopCtrls.push_back(cval);
935 externalCtrls.push_back(cval);
937 assert(loopCtrls.size() == 1 &&
938 "Expected a single loop control value to match the single loop latch");
939 Value loopCtrl = loopCtrls.front();
942 rewriter->setInsertionPointToStart(loopHeader);
943 auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);
948 auto primingRegister =
949 rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
951 primingRegister->setAttr(
"initValues", rewriter->getI64ArrayAttr({0}));
955 auto loopCtrlMux = rewriter->create<MuxOp>(
956 loc, primingRegister.getResult(),
957 llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});
961 cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());
964 hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());
975 DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
976 DenseMap<MuxOp, Value> loopDataInputs;
977 for (
auto muxOp : muxesToReplace) {
978 if (muxOp == loopCtrlMux)
983 assert( 1 + externalDataInputs[muxOp].size() ==
984 muxOp.getDataOperands().size() &&
985 "Expected all mux operands to be partitioned between loop and "
986 "external data inputs");
994 for (MuxOp mux : muxesToReplace) {
995 auto externalDataMux = rewriter->create<MuxOp>(
996 loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);
1000 ->create<MuxOp>(loc, primingRegister,
1001 llvm::SmallVector<Value>{externalDataMux,
1002 loopDataInputs[mux]})
1008 rewriter->eraseOp(cmerge);
1011 return primingRegister;
1014 void LoopNetworkRewriter::buildExitNetwork(
1016 BufferOp loopPrimingRegister,
Backedge &loopPrimingInput) {
1017 auto loc = loopPrimingRegister.getLoc();
1026 SmallVector<Value> parityCorrectedConds;
1027 for (
auto &[condBlock, exitBlock] : exitPairs) {
1031 "Expected a conditional control branch op in the loop condition block");
1032 Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
1033 bool isTrueParity = trueUser->getBlock() == exitBlock;
1035 ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
1037 "The user of either the true or the false result should be in the "
1040 Value condValue = condBr.getConditionOperand();
1044 rewriter->setInsertionPoint(condBr);
1045 condValue = rewriter->create<arith::XOrIOp>(
1046 loc, condValue.getType(), condValue,
1047 rewriter->create<arith::ConstantOp>(
1048 loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
1050 parityCorrectedConds.push_back(condValue);
1055 auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
1056 loopPrimingInput.
setValue(exitMerge);
1059 LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1064 SmallVector<Block *> exitBlocks;
1065 loop->getExitBlocks(exitBlocks);
1066 for (
auto *exitNode : exitBlocks) {
1067 for (
auto *pred : exitNode->getPredecessors()) {
1069 if (!loop->contains(pred))
1072 ExitPair condPair = {pred, exitNode};
1073 assert(!exitPairs.count(condPair) &&
1074 "identical condition pairs should never be possible");
1075 exitPairs.insert(condPair);
1078 assert(!exitPairs.empty() &&
"No exits from loop?");
1082 if (exitPairs.size() > 1)
1083 return emitError(loc)
1084 <<
"Multiple exits detected within a loop. Loop task pipelining is "
1085 "only supported for loops with unified loop exit blocks.";
1087 Block *header = loop->getHeader();
1092 auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1093 auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1094 loopPrimingRegisterInput);
1098 buildExitNetwork(header, exitPairs, loopPrimingRegister,
1099 loopPrimingRegisterInput);
1108 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1109 if (condBranchOp.getTrueDest() == succBlock)
1110 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1112 assert(condBranchOp.getFalseDest() == succBlock);
1113 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1117 return newOp->getResult(0);
1121 HandshakeLowering::addBranchOps(ConversionPatternRewriter &rewriter) {
1125 for (Block &block : r) {
1126 for (Operation &op : block) {
1127 for (
auto result : op.getResults())
1129 liveOuts[&block].push_back(result);
1133 for (Block &block : r) {
1134 Operation *termOp = block.getTerminator();
1135 rewriter.setInsertionPoint(termOp);
1137 for (Value val : liveOuts[&block]) {
1142 for (
int i = 0, e = numBranches; i < e; ++i) {
1143 Operation *newOp =
nullptr;
1145 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
1146 newOp = rewriter.create<handshake::ConditionalBranchOp>(
1147 termOp->getLoc(), condBranchOp.getCondition(), val);
1148 else if (isa<mlir::cf::BranchOp>(termOp))
1149 newOp = rewriter.create<handshake::BranchOp>(termOp->getLoc(), val);
1151 if (newOp ==
nullptr)
1154 for (
int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
1155 Block *succ = block.getSuccessor(j);
1158 for (
auto &u : val.getUses()) {
1159 if (u.getOwner()->getBlock() == succ) {
1160 u.getOwner()->replaceUsesOfWith(val, res);
1172 LogicalResult HandshakeLowering::connectConstantsToControl(
1173 ConversionPatternRewriter &rewriter,
bool sourceConstants) {
1179 if (sourceConstants) {
1180 for (
auto constantOp : llvm::make_early_inc_range(
1181 r.template getOps<mlir::arith::ConstantOp>())) {
1182 rewriter.setInsertionPointAfter(constantOp);
1183 auto value = constantOp.getValue();
1184 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1185 constantOp, value.getType(), value,
1186 rewriter.create<handshake::SourceOp>(constantOp.getLoc(),
1187 rewriter.getNoneType()));
1190 for (Block &block : r) {
1191 Value blockEntryCtrl = getBlockEntryControl(&block);
1192 for (
auto constantOp : llvm::make_early_inc_range(
1193 block.template getOps<mlir::arith::ConstantOp>())) {
1194 rewriter.setInsertionPointAfter(constantOp);
1195 auto value = constantOp.getValue();
1196 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1197 constantOp, value.getType(), value, blockEntryCtrl);
1213 : op(op), ctrlOperand(ctrlOperand) {
1214 assert(op && ctrlOperand);
1215 assert(ctrlOperand.getType().isa<NoneType>() &&
1216 "Control operand must be a NoneType");
1229 for (Operation &op : *block) {
1230 if (
auto branchOp = dyn_cast<handshake::BranchOp>(op))
1231 if (branchOp.isControl())
1232 return {branchOp, branchOp.getDataOperand()};
1233 if (
auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1234 if (branchOp.isControl())
1235 return {branchOp, branchOp.getDataOperand()};
1236 if (
auto endOp = dyn_cast<handshake::ReturnOp>(op))
1237 return {endOp, endOp.getOperands().back()};
1239 llvm_unreachable(
"Block terminator must exist");
1244 if (
auto memOp = dyn_cast<memref::LoadOp>(op))
1245 out = memOp.getMemRef();
1246 else if (
auto memOp = dyn_cast<memref::StoreOp>(op))
1247 out = memOp.getMemRef();
1248 else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1249 MemRefAccess access(op);
1250 out = access.memref;
1254 return op->emitOpError(
"Unknown Op type");
1258 return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1259 AffineWriteOpInterface>(op);
1263 HandshakeLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
1266 std::vector<Operation *> opsToErase;
1269 for (
auto arg : r.getArguments()) {
1270 auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
1276 memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
1282 for (Operation &op : r.getOps()) {
1286 rewriter.setInsertionPoint(&op);
1290 Operation *newOp =
nullptr;
1292 llvm::TypeSwitch<Operation *>(&op)
1293 .Case<memref::LoadOp>([&](
auto loadOp) {
1296 SmallVector<Value, 8> operands(loadOp.getIndices());
1299 rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
1300 op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
1302 .Case<memref::StoreOp>([&](
auto storeOp) {
1305 SmallVector<Value, 8> operands(storeOp.getIndices());
1308 newOp = rewriter.create<handshake::StoreOp>(
1309 op.getLoc(), storeOp.getValueToStore(), operands);
1311 .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](
auto) {
1313 MemRefAccess access(&op);
1320 if (
auto loadOp = dyn_cast<AffineReadOpInterface>(op))
1321 map = loadOp.getAffineMap();
1323 map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();
1331 expandAffineMap(rewriter, op.getLoc(), map, access.indices);
1332 assert(operands &&
"Address operands of affine memref access "
1333 "cannot be reduced.");
1335 if (isa<AffineReadOpInterface>(op)) {
1336 auto loadOp = rewriter.create<handshake::LoadOp>(
1337 op.getLoc(), access.memref, *operands);
1339 op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
1341 newOp = rewriter.create<handshake::StoreOp>(
1342 op.getLoc(), op.getOperand(0), *operands);
1345 .Default([&](
auto) {
1346 op.emitOpError(
"Load/store operation cannot be handled.");
1349 memRefOps[memref].push_back(newOp);
1350 opsToErase.push_back(&op);
1354 for (
unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
1355 auto *op = opsToErase[i];
1356 for (
int j = 0, e = op->getNumOperands(); j < e; ++j)
1357 op->eraseOperand(0);
1358 assert(op->getNumOperands() == 0);
1360 rewriter.eraseOp(op);
1369 if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1372 SmallVector<Value, 8> results(loadOp.getAddressResults());
1377 assert(dyn_cast<handshake::StoreOp>(op));
1378 handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1379 SmallVector<Value, 8> results(storeOp.getResults());
1386 for (Block &block : f) {
1388 if (!ctrl.hasOneUse())
1394 ConversionPatternRewriter &rewriter) {
1395 std::vector<Operation *> opsToDelete;
1398 for (
auto &op : r.getOps())
1399 if (
isAllocOp(&op) && op.getResult(0).use_empty())
1400 opsToDelete.push_back(&op);
1402 llvm::for_each(opsToDelete, [&](
auto allocOp) { rewriter.eraseOp(allocOp); });
1406 ArrayRef<BlockControlTerm> controlTerms) {
1407 for (
auto term : controlTerms) {
1408 auto &[op, ctrl] = term;
1409 auto *srcOp = ctrl.getDefiningOp();
1412 if (!isa<JoinOp>(srcOp)) {
1413 rewriter.setInsertionPointAfter(srcOp);
1414 Operation *newJoin = rewriter.create<JoinOp>(srcOp->getLoc(), ctrl);
1415 op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1420 static std::vector<BlockControlTerm>
1422 std::vector<BlockControlTerm> terminators;
1424 for (Operation *op : memOps) {
1426 Block *block = op->getBlock();
1429 if (std::find(terminators.begin(), terminators.end(), term) ==
1431 terminators.push_back(term);
1438 SmallVector<Value, 8> results(op->getOperands());
1439 results.push_back(val);
1440 op->setOperands(results);
1446 for (
auto *op : memOps) {
1447 if (isa<handshake::LoadOp>(op))
1453 Operation *memOp,
int offset,
1454 ArrayRef<int> cntrlInd) {
1457 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1458 auto *op = memOps[i];
1460 auto *srcOp = ctrl.getDefiningOp();
1461 if (!isa<JoinOp>(srcOp)) {
1462 return srcOp->emitOpError(
"Op expected to be a JoinOp");
1469 void HandshakeLowering::setMemOpControlInputs(
1470 ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1471 Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
1472 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1473 std::vector<Value> controlOperands;
1474 Operation *currOp = memOps[i];
1475 Block *currBlock = currOp->getBlock();
1478 Value blockEntryCtrl = getBlockEntryControl(currBlock);
1479 controlOperands.push_back(blockEntryCtrl);
1482 for (
int j = 0, f = i; j < f; ++j) {
1483 Operation *predOp = memOps[j];
1484 Block *predBlock = predOp->getBlock();
1485 if (currBlock == predBlock)
1487 if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1489 controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1493 if (controlOperands.size() == 1)
1498 rewriter.setInsertionPoint(currOp);
1500 rewriter.create<JoinOp>(currOp->getLoc(), controlOperands);
1507 HandshakeLowering::connectToMemory(ConversionPatternRewriter &rewriter,
1512 for (
auto memory : memRefOps) {
1514 Value memrefOperand = memory.first;
1518 bool isExternalMemory = memrefOperand.isa<BlockArgument>();
1520 mlir::MemRefType memrefType =
1521 memrefOperand.getType().cast<mlir::MemRefType>();
1525 std::vector<Value> operands;
1528 std::vector<BlockControlTerm> controlTerms =
1534 for (
auto valOp : controlTerms)
1535 operands.push_back(valOp.ctrlOperand);
1547 std::vector<int> newInd(memory.second.size(), 0);
1549 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1550 auto *op = memory.second[i];
1551 if (isa<handshake::StoreOp>(op)) {
1553 operands.insert(operands.end(), results.begin(), results.end());
1560 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1561 auto *op = memory.second[i];
1562 if (isa<handshake::LoadOp>(op)) {
1564 operands.insert(operands.end(), results.begin(), results.end());
1572 int cntrl_count = lsq ? 0 : memory.second.size();
1574 Block *entryBlock = &r.front();
1575 rewriter.setInsertionPointToStart(entryBlock);
1578 Operation *newOp =
nullptr;
1579 if (isExternalMemory)
1580 newOp = rewriter.create<ExternalMemoryOp>(
1581 entryBlock->front().getLoc(), memrefOperand, operands, ld_count,
1582 cntrl_count - ld_count, mem_count++);
1584 newOp = rewriter.create<MemoryOp>(entryBlock->front().getLoc(), operands,
1585 ld_count, cntrl_count, lsq, mem_count++,
1600 bool control =
true;
1610 setMemOpControlInputs(rewriter, memory.second, newOp, ld_count, newInd);
1622 HandshakeLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
1623 for (Block &block : r) {
1626 Value blockEntryControl = getBlockEntryControl(&block);
1627 for (Operation &op : block) {
1628 if (
auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1629 llvm::SmallVector<Value> operands;
1630 llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1631 operands.push_back(blockEntryControl);
1632 rewriter.setInsertionPoint(callOp);
1633 auto instanceOp = rewriter.create<handshake::InstanceOp>(
1634 callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
1637 for (
auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1638 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1639 rewriter.eraseOp(callOp);
1652 bool maximizeArgument(BlockArgument arg)
override {
1653 return !arg.getType().isa<mlir::MemRefType>();
1657 bool maximizeOp(Operation *op)
override {
return !
isAllocOp(op); }
1665 ConversionPatternRewriter &rewriter) {
1666 HandshakeLoweringSSAStrategy strategy;
1670 static LogicalResult
lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1671 bool sourceConstants,
1672 bool disableTaskPipelining) {
1674 SmallVector<NamedAttribute, 4> attributes;
1675 for (
const auto &attr : funcOp->getAttrs()) {
1676 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1677 attr.getName() == funcOp.getFunctionTypeAttrName())
1679 attributes.push_back(attr);
1683 llvm::SmallVector<mlir::Type, 8> argTypes;
1684 for (
auto &argType : funcOp.getArgumentTypes())
1685 argTypes.push_back(argType);
1688 llvm::SmallVector<mlir::Type, 8> resTypes;
1689 for (
auto resType : funcOp.getResultTypes())
1690 resTypes.push_back(resType);
1692 handshake::FuncOp newFuncOp;
1696 if (partiallyLowerOp<func::FuncOp>(
1697 [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1698 auto noneType = rewriter.getNoneType();
1699 resTypes.push_back(noneType);
1700 argTypes.push_back(noneType);
1701 auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1702 newFuncOp = rewriter.create<handshake::FuncOp>(
1703 funcOp.getLoc(), funcOp.getName(), func_type, attributes);
1704 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1706 if (!newFuncOp.isExternal()) {
1707 newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1709 newFuncOp.resolveArgAndResNames();
1711 rewriter.eraseOp(funcOp);
1722 if (!newFuncOp.isExternal()) {
1723 Block *bodyBlock = newFuncOp.getBodyBlock();
1724 Value entryCtrl = bodyBlock->getArguments().back();
1726 if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1727 fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1736 struct HandshakeRemoveBlockPass
1737 : HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1741 struct CFToHandshakePass :
public CFToHandshakeBase<CFToHandshakePass> {
1742 CFToHandshakePass(
bool sourceConstants,
bool disableTaskPipelining) {
1743 this->sourceConstants = sourceConstants;
1744 this->disableTaskPipelining = disableTaskPipelining;
1746 void runOnOperation()
override {
1747 ModuleOp m = getOperation();
1749 for (
auto funcOp : llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1750 if (failed(
lowerFuncOp(funcOp, &getContext(), sourceConstants,
1751 disableTaskPipelining))) {
1752 signalPassFailure();
1761 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1763 bool disableTaskPipelining) {
1764 return std::make_unique<CFToHandshakePass>(sourceConstants,
1765 disableTaskPipelining);
1768 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1770 return std::make_unique<HandshakeRemoveBlockPass>();
static ConditionalBranchOp getControlCondBranch(Block *block)
static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, bool sourceConstants, bool disableTaskPipelining)
static bool isMemoryOp(Operation *op)
static LogicalResult setJoinControlInputs(ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
static void addJoinOps(ConversionPatternRewriter &rewriter, ArrayRef< BlockControlTerm > controlTerms)
static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter)
static bool isLiveOut(Value val)
static Operation * getControlMerge(Block *block)
static SmallVector< Value, 8 > getResultsToMemory(Operation *op)
static unsigned getBlockPredecessorCount(Block *block)
static int getBranchCount(Value val, Block *block)
void removeBasicBlocks(handshake::FuncOp funcOp)
static Operation * getFirstOp(Block *block)
Returns the first occurrence of an operation of type TOp; otherwise, returns a null op.
static Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock)
static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo, Block *predBlock)
static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type)
static bool isAllocOp(Operation *op)
static Value getOperandFromBlock(MuxOp mux, Block *block)
static LogicalResult getOpMemRef(Operation *op, Value &out)
static LogicalResult partiallyLowerOp(const std::function< LogicalResult(TOp, ConversionPatternRewriter &)> &loweringFunc, MLIRContext *ctx, TOp op)
static void addValueToOperands(Operation *op, Value val)
static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo)
static std::vector< Value > getSortedInputs(ControlMergeOp cmerge, MuxOp mux)
static void removeBlockOperands(Region &f)
static void removeUnusedAllocOps(Region &r, ConversionPatternRewriter &rewriter)
static LogicalResult maximizeSSANoMem(Region &r, ConversionPatternRewriter &rewriter)
Converts every value in the region into maximal SSA form, unless the value is a block argument of typ...
static Operation * findBranchToBlock(Block *block)
static std::vector< BlockControlTerm > getControlTerminators(ArrayRef< Operation * > memOps)
static BlockControlTerm getBlockControlTerminator(Block *block)
static void reconnectMergeOps(Region &r, HandshakeLowering::BlockOps blockMerges, HandshakeLowering::ValueMap &mergePairs)
static void setLoadDataInputs(ArrayRef< Operation * > memOps, Operation *memOp)
assert(baseType &&"element must be base type")
static void insertFork(Value result, OpBuilder &rewriter)
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
LowerRegionTarget(MLIRContext &context, Region &region)
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Strategy class to control the behavior of SSA maximization.
DenseMap< Block *, std::vector< MergeOpInfo > > BlockOps
DenseMap< Block *, std::vector< Value > > BlockValues
llvm::MapVector< Value, std::vector< Operation * > > MemRefToMemoryAccessOp
DenseMap< Value, Value > ValueMap
llvm::function_ref< LogicalResult(Region &, ConversionPatternRewriter &)> RegionLoweringFunc
LogicalResult partiallyLowerRegion(const RegionLoweringFunc &loweringFunc, MLIRContext *ctx, Region &r)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createCFToHandshakePass(bool sourceConstants=false, bool disableTaskPipelining=false)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeRemoveBlockPass()
Holds information about a handshake "basic block terminator" control operation.
friend bool operator==(const BlockControlTerm &lhs, const BlockControlTerm &rhs)
Checks for member-wise equality.
Value ctrlOperand
The operation's control operand (must have type NoneType)
BlockControlTerm(Operation *op, Value ctrlOperand)
Operation * op
The operation.
Allows a region to be partially lowered by matching on the parent operation and then calling the provided part...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value >, ConversionPatternRewriter &rewriter) const override
std::function< LogicalResult(Region &, ConversionPatternRewriter &)> PartialLoweringFunc
PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context, LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
LogicalResult & loweringRes
LowerRegionTarget & target