13 #include "../PassDetail.h"
17 #include "mlir/Analysis/CFGLoopInfo.h"
18 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
19 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
20 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
23 #include "mlir/Dialect/Affine/Utils.h"
24 #include "mlir/Dialect/Arith/IR/Arith.h"
25 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/Dialect/SCF/IR/SCF.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Diagnostics.h"
32 #include "mlir/IR/Dominance.h"
33 #include "mlir/IR/OpImplementation.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/Types.h"
36 #include "mlir/IR/Value.h"
37 #include "mlir/Pass/Pass.h"
38 #include "mlir/Support/LLVM.h"
39 #include "mlir/Transforms/DialectConversion.h"
40 #include "mlir/Transforms/Passes.h"
41 #include "llvm/ADT/SmallSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/raw_ostream.h"
51 using namespace circt;
// Conversion target used together with PartialLowerOp<TOp>: the Handshake,
// func and arith dialects are legal, the SCF and Affine dialects are illegal,
// and ops of type TOp are dynamically legal depending on loweredOps.
60 template <
typename TOp>
61 class LowerOpTarget :
public ConversionTarget {
63 explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
65 addLegalDialect<HandshakeDialect>();
66 addLegalDialect<mlir::func::FuncDialect>();
67 addLegalDialect<mlir::arith::ArithDialect>();
68 addIllegalDialect<mlir::scf::SCFDialect>();
69 addIllegalDialect<AffineDialect>();
// A TOp instance is legal once its entry in loweredOps is true.
// DenseMap::operator[] default-inserts `false` for ops not seen yet, so ops
// that have not been visited start out illegal. The lambda reaches loweredOps
// through `this` (captured by [&]), so the target must outlive any conversion
// that queries it.
74 addDynamicallyLegalOp<TOp>([&](
const auto &op) {
return loweredOps[op]; });
// Per-op "already lowered" flags; set to true by
// PartialLowerOp::matchAndRewrite after it has processed an op.
76 DenseMap<Operation *, bool> loweredOps;
// Conversion pattern that hands every op of type TOp to a caller-supplied
// lowering function `fun`, stores that function's result in the caller-owned
// `loweringRes`, and marks the op as lowered in the associated LowerOpTarget
// so the conversion driver treats it as legal afterwards.
91 template <
typename TOp>
92 struct PartialLowerOp :
public ConversionPattern {
93 using PartialLoweringFunc =
94 std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
97 PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
98 LogicalResult &loweringResRef,
const PartialLoweringFunc &fun)
// Matches ops named TOp::getOperationName() with benefit 1. `target` and
// `loweringResRef` are held by reference and must outlive the pattern.
99 : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
100 loweringRes(loweringResRef), fun(fun) {}
101 using ConversionPattern::ConversionPattern;
103 matchAndRewrite(Operation *op, ArrayRef<Value> ,
104 ConversionPatternRewriter &rewriter)
const override {
// The op is updated in place; the lowering function's result is written to
// `loweringRes` so the caller can inspect it after the conversion.
// NOTE(review): the pattern only matches TOp, so `cast<TOp>` would state the
// invariant directly. `dyn_cast<TOp>` would silently pass a null op to `fun`
// if that invariant were ever violated.
106 rewriter.updateRootInPlace(
107 op, [&] { loweringRes = fun(dyn_cast<TOp>(op), rewriter); });
108 target.loweredOps[op] =
true;
// NOTE(review): the first declaration below appears mis-encoded. Given the
// constructor parameter `LowerOpTarget<TOp> &target` and the `target(target)`
// initializer, it presumably should read `LowerOpTarget<TOp> &target;` (the
// `&target;` sequence looks like it was rendered as an HTML entity). Confirm
// and fix.
113 LowerOpTarget<TOp> ⌖
114 LogicalResult &loweringRes;
116 PartialLoweringFunc fun;
// Runs a partial dialect conversion rooted at `op` using PartialLowerOp<TOp>
// and a LowerOpTarget<TOp>. The result succeeds only if both the conversion
// driver and the lowering function report success. (Presumably the
// std::function parameter is the lowering function forwarded to the pattern;
// confirm.)
122 template <
typename TOp>
124 const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
126 MLIRContext *ctx, TOp op) {
// PartialLowerOp keeps references to `target` and to
// partialLoweringSuccessfull. Both are locals that stay alive until the
// applyPartialConversion call below has returned.
129 auto target = LowerOpTarget<TOp>(*ctx);
130 LogicalResult partialLoweringSuccessfull = success();
131 patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
// Combine the conversion driver's result with the lowering function's result.
134 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
135 partialLoweringSuccessfull.succeeded());
141 : ConversionTarget(context), region(region) {
144 markUnknownOpDynamicallyLegal([&](Operation *op) {
145 if (op != region.getParentOp())
150 bool opLowered =
false;
161 std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
165 LogicalResult &loweringResRef,
167 : ConversionPattern(target.region.getParentOp()->
getName().getStringRef(),
169 target(target), loweringRes(loweringResRef), fun(fun) {}
170 using ConversionPattern::ConversionPattern;
173 ConversionPatternRewriter &rewriter)
const override {
174 rewriter.updateRootInPlace(
175 op, [&] { loweringRes = fun(target.region, rewriter); });
177 target.opLowered =
true;
189 MLIRContext *ctx, Region &r) {
191 Operation *op = r.getParentOp();
194 LogicalResult partialLoweringSuccessfull = success();
198 applyPartialConversion(op, target, std::move(
patterns)).succeeded() &&
199 partialLoweringSuccessfull.succeeded());
202 #define returnOnError(logicalResult) \
203 if (failed(logicalResult)) \
// Looks up the control value registered for `block` in blockEntryControlMap.
// Asserts that an entry exists.
210 Value HandshakeLowering::getBlockEntryControl(Block *block)
const {
211 auto it = blockEntryControlMap.find(block);
// NOTE(review): "registerred" in the message below is a typo for "registered".
// It is a runtime string, so it is intentionally left unchanged here.
212 assert(it != blockEntryControlMap.end() &&
213 "No block entry control value registerred for this block!");
// Records `v` as the entry control value of `block`, overwriting any value
// previously registered for it.
217 void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
218 blockEntryControlMap[block] = v;
// Collapse the region into its entry block. The operations of every
// non-entry block are spliced onto the end of the entry block, and uses of
// values defined by the emptied block are dropped.
222 Block *entryBlock = &r.front();
223 auto &entryBlockOps = entryBlock->getOperations();
226 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
227 entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
230 block.dropAllDefinedValueUses();
// NOTE(review): eraseArgument(i) shifts every later argument down by one, so
// this forward loop with i++ skips arguments. With three arguments, only the
// first and third original ones are erased. Erasing from the last index down
// (as the non-entry block-argument removal loop in this file does), or always
// erasing index 0, would remove them all. Whether the leftover arguments
// matter depends on what is done with `block` afterwards, which is not
// visible here.
231 for (
size_t i = 0; i < block.getNumArguments(); i++) {
232 block.eraseArgument(i);
// Post-process the ops now gathered in the entry block:
// - ops without the IsTerminator trait are excluded by the first check
//   (presumably skipped);
// - cf.cond_br / cf.br are erased;
// - remaining terminators appear to be moved to the end of the entry block.
239 for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
240 if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
243 if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
244 terminatorLike.erase();
249 terminatorLike.moveBefore(entryBlock, entryBlock->end());
254 if (funcOp.isExternal())
261 if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
262 return emitError(loc) <<
"memref's must be both statically sized and "
269 auto predecessors = block->getPredecessors();
270 return std::distance(predecessors.begin(), predecessors.end());
276 HandshakeLowering::insertMerge(Block *block, Value val,
278 ConversionPatternRewriter &rewriter) {
280 auto insertLoc = block->front().getLoc();
281 SmallVector<Backedge> dataEdges;
282 SmallVector<Value> operands;
286 if (val == getBlockEntryControl(block)) {
289 if (block == &r.front()) {
297 operands.push_back(val);
298 mergeOp = rewriter.create<handshake::MergeOp>(insertLoc, operands);
300 for (
unsigned i = 0; i < numPredecessors; i++) {
301 auto edge = edgeBuilder.
get(rewriter.getNoneType());
302 dataEdges.push_back(edge);
303 operands.push_back(Value(edge));
305 mergeOp = rewriter.create<handshake::ControlMergeOp>(insertLoc, operands);
307 setBlockEntryControl(block, mergeOp->getResult(0));
316 if (numPredecessors <= 1) {
317 if (numPredecessors == 0) {
321 operands.push_back(val);
325 auto edge = edgeBuilder.
get(val.getType());
326 dataEdges.push_back(edge);
327 operands.push_back(Value(edge));
329 auto merge = rewriter.create<handshake::MergeOp>(insertLoc, operands);
337 Backedge indexEdge = edgeBuilder.
get(rewriter.getIndexType());
338 for (
unsigned i = 0; i < numPredecessors; i++) {
339 auto edge = edgeBuilder.
get(val.getType());
340 dataEdges.push_back(edge);
341 operands.push_back(Value(edge));
344 rewriter.create<handshake::MuxOp>(insertLoc, Value(indexEdge), operands);
345 return MergeOpInfo{mux, val, dataEdges, indexEdge};
351 ConversionPatternRewriter &rewriter) {
353 for (Block &block : r) {
354 rewriter.setInsertionPointToStart(&block);
358 for (
auto &arg : block.getArguments()) {
360 if (arg.getType().isa<mlir::MemRefType>())
363 auto mergeInfo = insertMerge(&block, arg, edgeBuilder, rewriter);
364 blockMerges[&block].push_back(mergeInfo);
365 mergePairs[arg] = mergeInfo.op->getResult(0);
375 Value srcVal = mergeInfo.
val;
377 Block *block = mergeInfo.
op->getBlock();
382 unsigned index = srcVal.cast<BlockArgument>().getArgNumber();
383 Operation *termOp = predBlock->getTerminator();
384 if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
386 if (block == br.getTrueDest())
387 return br.getTrueOperand(index);
388 assert(block == br.getFalseDest());
389 return br.getFalseOperand(index);
391 if (isa<mlir::cf::BranchOp>(termOp))
392 return termOp->getOperand(index);
// Remove every argument from each non-entry block. Iterating from the last
// index down keeps the indices still to be visited valid after each erase.
// For a block with no arguments, `getNumArguments() - 1` is converted to int.
// If getNumArguments() returns an unsigned type, this yields -1 (guaranteed
// since C++20, implementation-defined before), so the inner loop does not run.
399 for (Block &block : f) {
400 if (!block.isEntryBlock()) {
401 int x = block.getNumArguments() - 1;
402 for (
int i = x; i >= 0; --i)
403 block.eraseArgument(i);
410 template <
typename TOp>
412 auto ops = block->getOps<TOp>();
419 return getFirstOp<ControlMergeOp>(block);
423 for (
auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
424 if (cbranch.isControl())
438 for (Block &block : r) {
439 for (
auto &mergeInfo : blockMerges[&block]) {
442 for (
auto *predBlock : block.getPredecessors()) {
444 assert(mgOperand !=
nullptr);
445 if (!mgOperand.getDefiningOp()) {
446 assert(mergePairs.count(mgOperand));
447 mgOperand = mergePairs[mgOperand];
449 mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
455 for (Operation &opp : block)
456 if (!isa<MergeLikeOpInterface>(opp))
457 opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
463 for (Block &block : r) {
466 assert(cntrlMg !=
nullptr);
468 for (
auto &mergeInfo : blockMerges[&block]) {
469 if (mergeInfo.op != cntrlMg) {
473 assert(mergeInfo.indexEdge.has_value());
474 (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
484 return isa<memref::AllocOp, memref::AllocaOp>(op);
488 HandshakeLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
499 BlockOps mergeOps = insertMergeOps(mergePairs, edgeBuilder, rewriter);
509 for (
auto &u : val.getUses())
511 if (isa<MergeLikeOpInterface>(u.getOwner()))
522 for (
int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
524 Block *succ = block->getSuccessor(i);
525 for (
auto &u : val.getUses()) {
526 if (u.getOwner()->getBlock() == succ)
529 uses = (curr > uses) ? curr : uses;
543 class FeedForwardNetworkRewriter {
546 ConversionPatternRewriter &rewriter)
547 : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
548 domInfo(hl.getRegion().getParentOp()),
549 loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
550 LogicalResult apply();
554 ConversionPatternRewriter &rewriter;
555 PostDominanceInfo postDomInfo;
556 DominanceInfo domInfo;
557 CFGLoopInfo loopInfo;
559 using BlockPair = std::pair<Block *, Block *>;
560 using BlockPairs = SmallVector<BlockPair>;
561 LogicalResult findBlockPairs(BlockPairs &blockPairs);
563 BufferOp buildSplitNetwork(Block *splitBlock,
564 handshake::ConditionalBranchOp &ctrlBr);
565 LogicalResult buildMergeNetwork(Block *
mergeBlock, BufferOp buf,
566 handshake::ConditionalBranchOp &ctrlBr);
569 bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
570 handshake::ConditionalBranchOp &ctrlBr);
571 bool formsIrreducibleCF(Block *splitBlock, Block *
mergeBlock);
576 HandshakeLowering::feedForwardRewriting(ConversionPatternRewriter &rewriter) {
578 if (this->getRegion().hasOneBlock())
580 return FeedForwardNetworkRewriter(*
this, rewriter).apply();
584 for (CFGLoop *loop : loopInfo.getTopLevelLoops())
585 if (!loop->getExitBlock())
590 bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
592 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
593 for (
auto *mergePred :
mergeBlock->getPredecessors()) {
595 if (loop && loop->contains(mergePred))
603 if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
604 if (splitSucc == mergeBlock || mergePred == splitBlock)
606 return domInfo.dominates(splitSucc, mergePred);
614 Block *pred = *block->getPredecessors().begin();
615 return pred->getTerminator();
619 FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
623 Region &r = hl.getRegion();
624 Operation *parentOp = r.getParentOp();
629 "expected loop to only have one exit block.");
632 if (b.getNumSuccessors() < 2)
636 if (loopInfo.getLoopFor(&b))
639 assert(b.getNumSuccessors() == 2);
640 Block *succ0 = b.getSuccessor(0);
641 Block *succ1 = b.getSuccessor(1);
646 Block *
mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
650 return parentOp->emitError(
"expected only reducible control flow.")
652 <<
"This branch is involved in the irreducible control flow";
655 unsigned nonLoopPreds = 0;
656 CFGLoop *loop = loopInfo.getLoopFor(
mergeBlock);
657 for (
auto *pred :
mergeBlock->getPredecessors()) {
658 if (loop && loop->contains(pred))
662 if (nonLoopPreds > 2)
664 ->emitError(
"expected a merge block to have two predecessors. "
665 "Did you run the merge block insertion pass?")
667 <<
"This branch jumps to the illegal block";
675 LogicalResult FeedForwardNetworkRewriter::apply() {
678 if (failed(findBlockPairs(pairs)))
682 handshake::ConditionalBranchOp ctrlBr;
683 BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
684 if (failed(buildMergeNetwork(
mergeBlock, buffer, ctrlBr)))
// Builds the split half of a feed-forward network for `splitBlock`:
// - collects the block's handshake.cond_br ops;
// - locates the one whose data operand is NoneType (the control branch);
// - buffers the branch condition in a FIFO placed right after the
//   condition's definition.
// The created BufferOp is returned. `ctrlBr` is an out-parameter. It is read
// below, so it is presumably assigned from `findRes` in code not shown here;
// confirm.
691 BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
692 Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
693 SmallVector<handshake::ConditionalBranchOp> branches;
694 llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
695 std::back_inserter(branches));
697 auto *findRes = llvm::find_if(branches, [](
auto br) {
698 return br.getDataOperand().getType().
template isa<NoneType>();
// NOTE(review): `findRes` is an iterator into `branches` (a plain pointer for
// SmallVector) and is never null, so this assert can never fire. The intended
// check is presumably `findRes != branches.end()`.
701 assert(findRes &&
"expected one branch for the ctrl signal");
704 Value cond = ctrlBr.getConditionOperand();
// Every cond_br in the split block must branch on the same condition as ctrlBr.
705 assert(llvm::all_of(branches, [&](
auto branch) {
706 return branch.getConditionOperand() == cond;
709 Location loc = cond.getLoc();
710 rewriter.setInsertionPointAfterValue(cond);
// Depth of the FIFO that buffers the condition (fixed at 2).
714 size_t bufferSize = 2;
718 return rewriter.create<handshake::BufferOp>(loc, cond, bufferSize,
719 BufferTypeEnum::fifo);
722 LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
723 Block *
mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
725 auto ctrlMerges =
mergeBlock->getOps<handshake::ControlMergeOp>();
726 assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
728 handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
730 if (ctrlMerge.getNumOperands() != 2)
731 return ctrlMerge.emitError(
"expected cmerges to have two operands");
732 rewriter.setInsertionPointAfter(ctrlMerge);
733 Location loc = ctrlMerge->getLoc();
738 bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
739 SmallVector<Value> muxOperands;
741 muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
743 muxOperands = llvm::to_vector(ctrlMerge.getOperands());
745 Value newCtrl = rewriter.create<handshake::MuxOp>(loc, buf, muxOperands);
747 Value cond = buf.getResult();
752 cond = rewriter.create<arith::XOrIOp>(
753 loc, cond.getType(), cond,
754 rewriter.create<arith::ConstantOp>(
755 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
760 rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), cond);
765 rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
// Decides whether the two operands of a two-input control merge must be
// swapped so they line up with the true/false outputs of `ctrlBr`.
// buildMergeNetwork uses the result to choose between the reversed and the
// original cmerge operand order for the mux it creates.
769 bool FeedForwardNetworkRewriter::requiresOperandFlip(
770 ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
771 assert(ctrlMerge.getNumOperands() == 2 &&
772 "Loops should already have been handled");
774 Value fstOperand = ctrlMerge.getOperand(0);
776 assert(ctrlBr.getTrueResult().hasOneUse() &&
777 "expected the result of a branch to only have one user");
778 Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
// NOTE(review): the true result is asserted to have exactly one user. If that
// user is `ctrlBr` itself, `ctrlMerge` cannot also consume the true result,
// so the comparison below is always false and this branch has no effect. The
// check presumably should be against `ctrlMerge` (the cmerge directly
// consuming the branch output). Confirm against the intended semantics.
779 if (trueUser == ctrlBr)
781 return ctrlBr.getTrueResult() == fstOperand;
// Otherwise the true result is consumed in another block. Decide by whether
// that block dominates the block defining the cmerge's first operand.
// NOTE(review): getDefiningOp() returns null when the first operand is a block
// argument, which would make this dereference crash.
785 Block *trueBlock = trueUser->getBlock();
786 return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
799 class LoopNetworkRewriter {
803 LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
808 using ExitPair = std::pair<Block *, Block *>;
809 LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
818 BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
824 void buildExitNetwork(Block *loopHeader,
826 BufferOp loopPrimingRegister,
830 ConversionPatternRewriter *rewriter =
nullptr;
836 HandshakeLowering::loopNetworkRewriting(ConversionPatternRewriter &rewriter) {
837 return LoopNetworkRewriter(*this).processRegion(r, rewriter);
841 LoopNetworkRewriter::processRegion(Region &r,
842 ConversionPatternRewriter &rewriter) {
846 this->rewriter = &rewriter;
848 Operation *op = r.getParentOp();
850 DominanceInfo domInfo(op);
851 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
853 for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
854 if (!loop->getLoopLatch())
855 return emitError(op->getLoc()) <<
"Multiple loop latches detected "
856 "(backedges from within the loop "
857 "to the loop header). Loop task "
858 "pipelining is only supported for "
859 "loops with unified loop latches.";
862 if (failed(processOuterLoop(op->getLoc(), loop)))
871 auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
872 return block == operand.getParentBlock();
875 inValueIt != mux.getOperands().end() &&
876 "Expected mux to have an operand originating from the requested block.");
884 std::vector<Value> sortedOperands;
885 for (
auto in : cmerge.getOperands()) {
886 auto *srcBlock = in.getParentBlock();
891 for (
unsigned i = 0; i < sortedOperands.size(); ++i) {
892 for (
unsigned j = 0; j < sortedOperands.size(); ++j) {
895 assert(sortedOperands[i] != sortedOperands[j] &&
896 "Cannot have an identical operand from two different blocks!");
900 return sortedOperands;
903 BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
910 llvm::SmallVector<MuxOp> muxesToReplace;
911 llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));
917 assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
918 "Expected control merge to be the control component of a loop header");
919 auto loc = cmerge->getLoc();
922 assert(cmerge->getNumOperands() > 1 &&
"This cannot be a loop header");
926 SmallVector<Value> externalCtrls, loopCtrls;
927 for (
auto cval : cmerge->getOperands()) {
928 if (cval.getParentBlock() == loopLatch)
929 loopCtrls.push_back(cval);
931 externalCtrls.push_back(cval);
933 assert(loopCtrls.size() == 1 &&
934 "Expected a single loop control value to match the single loop latch");
935 Value loopCtrl = loopCtrls.front();
938 rewriter->setInsertionPointToStart(loopHeader);
939 auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);
944 auto primingRegister =
945 rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
947 primingRegister->setAttr(
"initValues", rewriter->getI64ArrayAttr({0}));
951 auto loopCtrlMux = rewriter->create<MuxOp>(
952 loc, primingRegister.getResult(),
953 llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});
957 cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());
960 hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());
971 DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
972 DenseMap<MuxOp, Value> loopDataInputs;
973 for (
auto muxOp : muxesToReplace) {
974 if (muxOp == loopCtrlMux)
979 assert( 1 + externalDataInputs[muxOp].size() ==
980 muxOp.getDataOperands().size() &&
981 "Expected all mux operands to be partitioned between loop and "
982 "external data inputs");
990 for (MuxOp mux : muxesToReplace) {
991 auto externalDataMux = rewriter->create<MuxOp>(
992 loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);
996 ->create<MuxOp>(loc, primingRegister,
997 llvm::SmallVector<Value>{externalDataMux,
998 loopDataInputs[mux]})
1004 rewriter->eraseOp(cmerge);
1007 return primingRegister;
// Builds the exit side of a loop network. For every (condition block, exit
// block) pair, the condition of that block's control cond_br is normalized so
// all collected conditions share the same parity. The normalized conditions
// are merged into the value that closes the `loopPrimingInput` backedge
// feeding the loop priming register.
1010 void LoopNetworkRewriter::buildExitNetwork(
1012 BufferOp loopPrimingRegister,
Backedge &loopPrimingInput) {
1013 auto loc = loopPrimingRegister.getLoc();
1022 SmallVector<Value> parityCorrectedConds;
1023 for (
auto &[condBlock, exitBlock] : exitPairs) {
1027 "Expected a conditional control branch op in the loop condition block");
// The exit is taken on the true result when the true result's user lives in
// the exit block; otherwise the false result's user is asserted to be there.
// NOTE(review): both `*...getUsers().begin()` dereferences assume the result
// has at least one user; an empty user list would be undefined behavior.
1028 Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
1029 bool isTrueParity = trueUser->getBlock() == exitBlock;
1031 ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
1033 "The user of either the true or the false result should be in the "
1036 Value condValue = condBr.getConditionOperand();
// Invert the condition with `xor cond, true`, inserted just before condBr, so
// that all collected conditions share the same parity. The condition guarding
// this inversion (presumably based on isTrueParity) is not shown here; confirm.
1040 rewriter->setInsertionPoint(condBr);
1041 condValue = rewriter->create<arith::XOrIOp>(
1042 loc, condValue.getType(), condValue,
1043 rewriter->create<arith::ConstantOp>(
1044 loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
1046 parityCorrectedConds.push_back(condValue);
// Merge all parity-corrected conditions and close the priming-register backedge.
1051 auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
1052 loopPrimingInput.
setValue(exitMerge);
1055 LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1060 SmallVector<Block *> exitBlocks;
1061 loop->getExitBlocks(exitBlocks);
1062 for (
auto *exitNode : exitBlocks) {
1063 for (
auto *pred : exitNode->getPredecessors()) {
1065 if (!loop->contains(pred))
1068 ExitPair condPair = {pred, exitNode};
1069 assert(!exitPairs.count(condPair) &&
1070 "identical condition pairs should never be possible");
1071 exitPairs.insert(condPair);
1074 assert(!exitPairs.empty() &&
"No exits from loop?");
1078 if (exitPairs.size() > 1)
1079 return emitError(loc)
1080 <<
"Multiple exits detected within a loop. Loop task pipelining is "
1081 "only supported for loops with unified loop exit blocks.";
1083 Block *header = loop->getHeader();
1088 auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1089 auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1090 loopPrimingRegisterInput);
1094 buildExitNetwork(header, exitPairs, loopPrimingRegister,
1095 loopPrimingRegisterInput);
1104 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1105 if (condBranchOp.getTrueDest() == succBlock)
1106 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1108 assert(condBranchOp.getFalseDest() == succBlock);
1109 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1113 return newOp->getResult(0);
1117 HandshakeLowering::addBranchOps(ConversionPatternRewriter &rewriter) {
1121 for (Block &block : r) {
1122 for (Operation &op : block) {
1123 for (
auto result : op.getResults())
1125 liveOuts[&block].push_back(result);
1129 for (Block &block : r) {
1130 Operation *termOp = block.getTerminator();
1131 rewriter.setInsertionPoint(termOp);
1133 for (Value val : liveOuts[&block]) {
1138 for (
int i = 0, e = numBranches; i < e; ++i) {
1139 Operation *newOp =
nullptr;
1141 if (
auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
1142 newOp = rewriter.create<handshake::ConditionalBranchOp>(
1143 termOp->getLoc(), condBranchOp.getCondition(), val);
1144 else if (isa<mlir::cf::BranchOp>(termOp))
1145 newOp = rewriter.create<handshake::BranchOp>(termOp->getLoc(), val);
1147 if (newOp ==
nullptr)
1150 for (
int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
1151 Block *succ = block.getSuccessor(j);
1154 for (
auto &u : val.getUses()) {
1155 if (u.getOwner()->getBlock() == succ) {
1156 u.getOwner()->replaceUsesOfWith(val, res);
1168 LogicalResult HandshakeLowering::connectConstantsToControl(
1169 ConversionPatternRewriter &rewriter,
bool sourceConstants) {
1175 if (sourceConstants) {
1176 for (
auto constantOp : llvm::make_early_inc_range(
1177 r.template getOps<mlir::arith::ConstantOp>())) {
1178 rewriter.setInsertionPointAfter(constantOp);
1179 auto value = constantOp.getValue();
1180 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1182 rewriter.create<handshake::SourceOp>(constantOp.getLoc(),
1183 rewriter.getNoneType()));
1186 for (Block &block : r) {
1187 Value blockEntryCtrl = getBlockEntryControl(&block);
1188 for (
auto constantOp : llvm::make_early_inc_range(
1189 block.template getOps<mlir::arith::ConstantOp>())) {
1190 rewriter.setInsertionPointAfter(constantOp);
1191 auto value = constantOp.getValue();
1192 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1193 constantOp,
value.getType(),
value, blockEntryCtrl);
1209 : op(op), ctrlOperand(ctrlOperand) {
1210 assert(op && ctrlOperand);
1211 assert(ctrlOperand.getType().isa<NoneType>() &&
1212 "Control operand must be a NoneType");
1225 for (Operation &op : *block) {
1226 if (
auto branchOp = dyn_cast<handshake::BranchOp>(op))
1227 if (branchOp.isControl())
1228 return {branchOp, branchOp.getDataOperand()};
1229 if (
auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1230 if (branchOp.isControl())
1231 return {branchOp, branchOp.getDataOperand()};
1232 if (
auto endOp = dyn_cast<handshake::ReturnOp>(op))
1233 return {endOp, endOp.getOperands().back()};
1235 llvm_unreachable(
"Block terminator must exist");
1240 if (
auto memOp = dyn_cast<memref::LoadOp>(op))
1241 out = memOp.getMemRef();
1242 else if (
auto memOp = dyn_cast<memref::StoreOp>(op))
1243 out = memOp.getMemRef();
1244 else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1245 MemRefAccess access(op);
1246 out = access.memref;
1250 return op->emitOpError(
"Unknown Op type");
1254 return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1255 AffineWriteOpInterface>(op);
1259 HandshakeLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
1262 std::vector<Operation *> opsToErase;
1265 for (
auto arg : r.getArguments()) {
1266 auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
1272 memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
1278 for (Operation &op : r.getOps()) {
1282 rewriter.setInsertionPoint(&op);
1286 Operation *newOp =
nullptr;
1288 llvm::TypeSwitch<Operation *>(&op)
1289 .Case<memref::LoadOp>([&](
auto loadOp) {
1292 SmallVector<Value, 8> operands(loadOp.getIndices());
1295 rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
1296 op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
1298 .Case<memref::StoreOp>([&](
auto storeOp) {
1301 SmallVector<Value, 8> operands(storeOp.getIndices());
1304 newOp = rewriter.create<handshake::StoreOp>(
1305 op.getLoc(), storeOp.getValueToStore(), operands);
1307 .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](
auto) {
1309 MemRefAccess access(&op);
1316 if (
auto loadOp = dyn_cast<AffineReadOpInterface>(op))
1317 map = loadOp.getAffineMap();
1319 map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();
1327 expandAffineMap(rewriter, op.getLoc(), map, access.indices);
1328 assert(operands &&
"Address operands of affine memref access "
1329 "cannot be reduced.");
1331 if (isa<AffineReadOpInterface>(op)) {
1332 auto loadOp = rewriter.create<handshake::LoadOp>(
1333 op.getLoc(), access.memref, *operands);
1335 op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
1337 newOp = rewriter.create<handshake::StoreOp>(
1338 op.getLoc(), op.getOperand(0), *operands);
1341 .Default([&](
auto) {
1342 op.emitOpError(
"Load/store operation cannot be handled.");
1345 memRefOps[memref].push_back(newOp);
1346 opsToErase.push_back(&op);
// Erase the original memory ops collected in opsToErase. Each op's operands
// are removed first (always at index 0, since erasing shifts the remaining
// operands down), then the op itself is erased through the rewriter.
// NOTE(review): the inner loop's `e` shadows the outer loop bound `e` (int vs.
// unsigned). This is harmless, but renaming one of them would read more clearly.
1350 for (
unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
1351 auto *op = opsToErase[i];
1352 for (
int j = 0, e = op->getNumOperands(); j < e; ++j)
1353 op->eraseOperand(0);
1354 assert(op->getNumOperands() == 0);
1356 rewriter.eraseOp(op);
1365 if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1368 SmallVector<Value, 8> results(loadOp.getAddressResults());
1373 assert(dyn_cast<handshake::StoreOp>(op));
1374 handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1375 SmallVector<Value, 8> results(storeOp.getResults());
1382 for (Block &block : f) {
1384 if (!ctrl.hasOneUse())
1390 ConversionPatternRewriter &rewriter) {
1391 std::vector<Operation *> opsToDelete;
1394 for (
auto &op : r.getOps())
1395 if (
isAllocOp(&op) && op.getResult(0).use_empty())
1396 opsToDelete.push_back(&op);
1398 llvm::for_each(opsToDelete, [&](
auto allocOp) { rewriter.eraseOp(allocOp); });
1402 ArrayRef<BlockControlTerm> controlTerms) {
1403 for (
auto term : controlTerms) {
1404 auto &[op, ctrl] = term;
1405 auto *srcOp = ctrl.getDefiningOp();
1408 if (!isa<JoinOp>(srcOp)) {
1409 rewriter.setInsertionPointAfter(srcOp);
1410 Operation *newJoin = rewriter.create<JoinOp>(srcOp->getLoc(), ctrl);
1411 op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1416 static std::vector<BlockControlTerm>
1418 std::vector<BlockControlTerm> terminators;
1420 for (Operation *op : memOps) {
1422 Block *block = op->getBlock();
1425 if (std::find(terminators.begin(), terminators.end(), term) ==
1427 terminators.push_back(term);
1434 SmallVector<Value, 8> results(op->getOperands());
1435 results.push_back(val);
1436 op->setOperands(results);
1442 for (
auto *op : memOps) {
1443 if (isa<handshake::LoadOp>(op))
1449 Operation *memOp,
int offset,
1450 ArrayRef<int> cntrlInd) {
1453 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1454 auto *op = memOps[i];
1456 auto *srcOp = ctrl.getDefiningOp();
1457 if (!isa<JoinOp>(srcOp)) {
1458 return srcOp->emitOpError(
"Op expected to be a JoinOp");
1465 void HandshakeLowering::setMemOpControlInputs(
1466 ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1467 Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
1468 for (
int i = 0, e = memOps.size(); i < e; ++i) {
1469 std::vector<Value> controlOperands;
1470 Operation *currOp = memOps[i];
1471 Block *currBlock = currOp->getBlock();
1474 Value blockEntryCtrl = getBlockEntryControl(currBlock);
1475 controlOperands.push_back(blockEntryCtrl);
1478 for (
int j = 0, f = i; j < f; ++j) {
1479 Operation *predOp = memOps[j];
1480 Block *predBlock = predOp->getBlock();
1481 if (currBlock == predBlock)
1483 if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1485 controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1489 if (controlOperands.size() == 1)
1494 rewriter.setInsertionPoint(currOp);
1496 rewriter.create<JoinOp>(currOp->getLoc(), controlOperands);
1503 HandshakeLowering::connectToMemory(ConversionPatternRewriter &rewriter,
1508 for (
auto memory : memRefOps) {
1510 Value memrefOperand = memory.first;
1514 bool isExternalMemory = memrefOperand.isa<BlockArgument>();
1516 mlir::MemRefType memrefType =
1517 memrefOperand.getType().cast<mlir::MemRefType>();
1521 std::vector<Value> operands;
1524 std::vector<BlockControlTerm> controlTerms =
1530 for (
auto valOp : controlTerms)
1531 operands.push_back(valOp.ctrlOperand);
1543 std::vector<int> newInd(memory.second.size(), 0);
1545 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1546 auto *op = memory.second[i];
1547 if (isa<handshake::StoreOp>(op)) {
1549 operands.insert(operands.end(), results.begin(), results.end());
1556 for (
int i = 0, e = memory.second.size(); i < e; ++i) {
1557 auto *op = memory.second[i];
1558 if (isa<handshake::LoadOp>(op)) {
1560 operands.insert(operands.end(), results.begin(), results.end());
// Control count passed to the memory op: 0 when an LSQ is used, otherwise one
// per memory op recorded for this memref.
// NOTE(review): memory.second.size() (a size_t) is implicitly narrowed to int.
1568 int cntrl_count = lsq ? 0 : memory.second.size();
// The memory op is created at the start of the region's entry block and takes
// the location of that block's first operation.
1570 Block *entryBlock = &r.front();
1571 rewriter.setInsertionPointToStart(entryBlock);
1574 Operation *newOp =
nullptr;
// Memrefs that are block arguments (isExternalMemory) get an
// ExternalMemoryOp. It receives ld_count and cntrl_count - ld_count
// (presumably the store count). Other memrefs get a MemoryOp.
// NOTE(review): if `lsq` were set for an external memory, cntrl_count would be
// 0 and cntrl_count - ld_count negative. Presumably that combination does not
// occur; confirm.
1575 if (isExternalMemory)
1576 newOp = rewriter.create<ExternalMemoryOp>(
1577 entryBlock->front().getLoc(), memrefOperand, operands, ld_count,
1578 cntrl_count - ld_count, mem_count++);
1580 newOp = rewriter.create<MemoryOp>(entryBlock->front().getLoc(), operands,
1581 ld_count, cntrl_count, lsq, mem_count++,
1596 bool control =
true;
1606 setMemOpControlInputs(rewriter, memory.second, newOp, ld_count, newInd);
1618 HandshakeLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
1619 for (Block &block : r) {
1622 Value blockEntryControl = getBlockEntryControl(&block);
1623 for (Operation &op : block) {
1624 if (
auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1625 llvm::SmallVector<Value> operands;
1626 llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1627 operands.push_back(blockEntryControl);
1628 rewriter.setInsertionPoint(callOp);
1629 auto instanceOp = rewriter.create<handshake::InstanceOp>(
1630 callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
1633 for (
auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1634 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1635 rewriter.eraseOp(callOp);
1648 bool maximizeArgument(BlockArgument arg)
override {
1649 return !arg.getType().isa<mlir::MemRefType>();
1653 bool maximizeOp(Operation *op)
override {
return !
isAllocOp(op); }
1661 ConversionPatternRewriter &rewriter) {
1662 HandshakeLoweringSSAStrategy strategy;
1666 static LogicalResult
lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1667 bool sourceConstants,
1668 bool disableTaskPipelining) {
1670 SmallVector<NamedAttribute, 4> attributes;
1671 for (
const auto &attr : funcOp->getAttrs()) {
1672 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1673 attr.getName() == funcOp.getFunctionTypeAttrName())
1675 attributes.push_back(attr);
1679 llvm::SmallVector<mlir::Type, 8> argTypes;
1680 for (
auto &argType : funcOp.getArgumentTypes())
1681 argTypes.push_back(argType);
1684 llvm::SmallVector<mlir::Type, 8> resTypes;
1685 for (
auto resType : funcOp.getResultTypes())
1686 resTypes.push_back(resType);
1688 handshake::FuncOp newFuncOp;
1693 [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1694 auto noneType = rewriter.getNoneType();
1695 resTypes.push_back(noneType);
1696 argTypes.push_back(noneType);
1697 auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1698 newFuncOp = rewriter.create<handshake::FuncOp>(
1699 funcOp.getLoc(), funcOp.getName(), func_type, attributes);
1700 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1702 if (!newFuncOp.isExternal()) {
1703 newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1705 newFuncOp.resolveArgAndResNames();
1707 rewriter.eraseOp(funcOp);
1716 if (!newFuncOp.isExternal()) {
1717 Block *bodyBlock = newFuncOp.getBodyBlock();
1718 Value entryCtrl = bodyBlock->getArguments().back();
1720 if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1721 fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1730 struct HandshakeRemoveBlockPass
1731 : HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1735 struct CFToHandshakePass :
public CFToHandshakeBase<CFToHandshakePass> {
1736 CFToHandshakePass(
bool sourceConstants,
bool disableTaskPipelining) {
1737 this->sourceConstants = sourceConstants;
1738 this->disableTaskPipelining = disableTaskPipelining;
1740 void runOnOperation()
override {
1741 ModuleOp m = getOperation();
1743 for (
auto funcOp : llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1744 if (failed(
lowerFuncOp(funcOp, &getContext(), sourceConstants,
1745 disableTaskPipelining))) {
1746 signalPassFailure();
1755 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1757 bool disableTaskPipelining) {
1758 return std::make_unique<CFToHandshakePass>(sourceConstants,
1759 disableTaskPipelining);
1762 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1764 return std::make_unique<HandshakeRemoveBlockPass>();
static ConditionalBranchOp getControlCondBranch(Block *block)
#define returnOnError(logicalResult)
static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, bool sourceConstants, bool disableTaskPipelining)
static bool isMemoryOp(Operation *op)
static LogicalResult setJoinControlInputs(ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
static void addJoinOps(ConversionPatternRewriter &rewriter, ArrayRef< BlockControlTerm > controlTerms)
static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter)
static bool isLiveOut(Value val)
static Operation * getControlMerge(Block *block)
static SmallVector< Value, 8 > getResultsToMemory(Operation *op)
static unsigned getBlockPredecessorCount(Block *block)
static int getBranchCount(Value val, Block *block)
void removeBasicBlocks(handshake::FuncOp funcOp)
static Operation * getFirstOp(Block *block)
Returns the first occurance of an operation of type TOp, else, returns null op.
static Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock)
static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo, Block *predBlock)
static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type)
static bool isAllocOp(Operation *op)
static Value getOperandFromBlock(MuxOp mux, Block *block)
static LogicalResult getOpMemRef(Operation *op, Value &out)
static LogicalResult partiallyLowerOp(const std::function< LogicalResult(TOp, ConversionPatternRewriter &)> &loweringFunc, MLIRContext *ctx, TOp op)
static void addValueToOperands(Operation *op, Value val)
static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo)
static std::vector< Value > getSortedInputs(ControlMergeOp cmerge, MuxOp mux)
static void removeBlockOperands(Region &f)
static void removeUnusedAllocOps(Region &r, ConversionPatternRewriter &rewriter)
static LogicalResult maximizeSSANoMem(Region &r, ConversionPatternRewriter &rewriter)
Converts every value in the region into maximal SSA form, unless the value is a block argument of typ...
static Operation * findBranchToBlock(Block *block)
static std::vector< BlockControlTerm > getControlTerminators(ArrayRef< Operation * > memOps)
static BlockControlTerm getBlockControlTerminator(Block *block)
static void reconnectMergeOps(Region &r, HandshakeLowering::BlockOps blockMerges, HandshakeLowering::ValueMap &mergePairs)
static void setLoadDataInputs(ArrayRef< Operation * > memOps, Operation *memOp)
assert(baseType &&"element must be base type")
static void insertFork(Value result, OpBuilder &rewriter)
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block in to a destination block.
LowerRegionTarget(MLIRContext &context, Region ®ion)
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Strategy class to control the behavior of SSA maximization.
DenseMap< Block *, std::vector< MergeOpInfo > > BlockOps
DenseMap< Block *, std::vector< Value > > BlockValues
llvm::MapVector< Value, std::vector< Operation * > > MemRefToMemoryAccessOp
DenseMap< Value, Value > ValueMap
llvm::function_ref< LogicalResult(Region &, ConversionPatternRewriter &)> RegionLoweringFunc
LogicalResult partiallyLowerRegion(const RegionLoweringFunc &loweringFunc, MLIRContext *ctx, Region &r)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createCFToHandshakePass(bool sourceConstants=false, bool disableTaskPipelining=false)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeRemoveBlockPass()
Holds information about an handshake "basic block terminator" control operation.
friend bool operator==(const BlockControlTerm &lhs, const BlockControlTerm &rhs)
Checks for member-wise equality.
Value ctrlOperand
The operation's control operand (must have type NoneType)
BlockControlTerm(Operation *op, Value ctrlOperand)
Operation * op
The operation.
Allows to partially lower a region by matching on the parent operation to then call the provided part...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value >, ConversionPatternRewriter &rewriter) const override
std::function< LogicalResult(Region &, ConversionPatternRewriter &)> PartialLoweringFunc
PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context, LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
LogicalResult & loweringRes
LowerRegionTarget & target