CIRCT 22.0.0git
CFToHandshake.cpp
1//===- CFToHandshake.cpp - Convert standard MLIR into dataflow IR ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// This is the main Standard to Handshake Conversion Pass Implementation.
9//
10//===----------------------------------------------------------------------===//
11
18#include "mlir/Analysis/CFGLoopInfo.h"
19#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
20#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
21#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
22#include "mlir/Dialect/Affine/IR/AffineOps.h"
23#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
24#include "mlir/Dialect/Affine/Utils.h"
25#include "mlir/Dialect/Arith/IR/Arith.h"
26#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
27#include "mlir/Dialect/Func/IR/FuncOps.h"
28#include "mlir/Dialect/MemRef/IR/MemRef.h"
29#include "mlir/Dialect/SCF/IR/SCF.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/Diagnostics.h"
33#include "mlir/IR/Dominance.h"
34#include "mlir/IR/OpImplementation.h"
35#include "mlir/IR/PatternMatch.h"
36#include "mlir/IR/Types.h"
37#include "mlir/IR/Value.h"
38#include "mlir/Pass/Pass.h"
39#include "mlir/Support/LLVM.h"
40#include "mlir/Transforms/DialectConversion.h"
41#include "mlir/Transforms/Passes.h"
42#include "llvm/ADT/SmallSet.h"
43#include "llvm/ADT/TypeSwitch.h"
44#include "llvm/Support/raw_ostream.h"
45
46#include <list>
47#include <map>
48
49namespace circt {
50#define GEN_PASS_DEF_CFTOHANDSHAKE
51#define GEN_PASS_DEF_HANDSHAKEREMOVEBLOCK
52#include "circt/Conversion/Passes.h.inc"
53} // namespace circt
54
55using namespace mlir;
56using namespace mlir::func;
57using namespace mlir::affine;
58using namespace circt;
59using namespace circt::handshake;
60using namespace std;
61
62// ============================================================================
63// Partial lowering infrastructure
64// ============================================================================
65
66namespace {
67template <typename TOp>
68class LowerOpTarget : public ConversionTarget {
69public:
70 explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
71 loweredOps.clear();
72 addLegalDialect<HandshakeDialect>();
73 addLegalDialect<mlir::func::FuncDialect>();
74 addLegalDialect<mlir::arith::ArithDialect>();
75 addIllegalDialect<mlir::scf::SCFDialect>();
76 addIllegalDialect<AffineDialect>();
77
78 /// The root operation to be replaced is marked dynamically legal
79 /// based on the lowering status of the given operation, see
80 /// PartialLowerOp.
81 addDynamicallyLegalOp<TOp>([&](const auto &op) { return loweredOps[op]; });
82 }
83 DenseMap<Operation *, bool> loweredOps;
84};
85
86/// Conversion pattern for partial lowering of an operation of type TOp.
87/// Lowering is achieved by a provided partial lowering function.
88///
89/// A partial lowering function may only replace a subset of the operations
90/// within the funcOp currently being lowered. However, the dialect conversion
91/// scheme requires the matched root operation to be replaced/updated/erased;
92/// it is the partial lowering function's responsibility to ensure this. The
93/// partial lowering function may only mutate the IR through the provided
94/// ConversionPatternRewriter, like any other ConversionPattern.
95/// Furthermore, the conversion framework expects the matched operation to go
96/// from illegal to legal once matchAndRewrite succeeds, even though only part
97/// of it was lowered. To satisfy this, LowerOpTarget::loweredOps is used to
98/// communicate between the target and the pattern that the partial lowering
99/// has completed.
100template <typename TOp>
101struct PartialLowerOp : public ConversionPattern {
102 using PartialLoweringFunc =
103 std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
104
105public:
106 PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
107 LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
108 : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
109 loweringRes(loweringResRef), fun(fun) {}
110 using ConversionPattern::ConversionPattern;
111 LogicalResult
112 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
113 ConversionPatternRewriter &rewriter) const override {
114 assert(isa<TOp>(op));
115 loweringRes = fun(dyn_cast<TOp>(op), rewriter);
116 target.loweredOps[op] = true;
117 return loweringRes;
118 };
119
120private:
121 LowerOpTarget<TOp> &target;
122 LogicalResult &loweringRes;
123 // NOTE: this is basically the rewrite function
124 PartialLoweringFunc fun;
125};
126} // namespace
127
128// Convenience function for applying a partial lowering function to an
129// operation of type TOp through the dialect conversion framework.
130template <typename TOp>
131static LogicalResult partiallyLowerOp(
132 const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
133 &loweringFunc,
134 MLIRContext *ctx, TOp op) {
135
136 RewritePatternSet patterns(ctx);
137 auto target = LowerOpTarget<TOp>(*ctx);
138 LogicalResult partialLoweringSuccessfull = success();
139 patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
140 loweringFunc);
141 return success(
142 applyPartialConversion(op, target, std::move(patterns)).succeeded() &&
143 partialLoweringSuccessfull.succeeded());
144}
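// Editorial sketch (not part of the upstream file): a typical invocation wraps
// a single lowering step in a lambda so that the dialect conversion driver
// handles legality tracking and rollback. The names below are illustrative.
//
//   handshake::FuncOp funcOp = ...;
//   LogicalResult res = partiallyLowerOp<handshake::FuncOp>(
//       [&](handshake::FuncOp op, ConversionPatternRewriter &rewriter) {
//         // Mutate the IR only through `rewriter`, e.g. wrap the change:
//         rewriter.modifyOpInPlace(op, [&] { /* rewrite op's body */ });
//         return success();
//       },
//       funcOp.getContext(), funcOp);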
145
146class LowerRegionTarget : public ConversionTarget {
147public:
148 explicit LowerRegionTarget(MLIRContext &context, Region &region)
149 : ConversionTarget(context), region(region) {
150 // The root operation is marked dynamically legal to ensure
151 // the pattern on its region is only applied once.
152 markUnknownOpDynamicallyLegal([&](Operation *op) {
153 if (op != region.getParentOp())
154 return true;
155 return opLowered;
156 });
157 }
158 bool opLowered = false;
159 Region &region;
160};
161
162/// Allows partial lowering of a region by matching on its parent operation
163/// and then calling the provided partial lowering function with the region
164/// and the rewriter.
165///
166/// The interplay with the target is similar to PartialLowerOp.
167struct PartialLowerRegion : public ConversionPattern {
168 using RegionLoweringFunc =
169 std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
170
171public:
172 PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context,
173 LogicalResult &loweringResRef,
174 const RegionLoweringFunc &fun)
175 : ConversionPattern(target.region.getParentOp()->getName().getStringRef(),
176 1, context),
177 target(target), loweringRes(loweringResRef), fun(fun) {}
178 using ConversionPattern::ConversionPattern;
179 LogicalResult
180 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
181 ConversionPatternRewriter &rewriter) const override {
182 rewriter.modifyOpInPlace(
183 op, [&] { loweringRes = fun(target.region, rewriter); });
184
185 target.opLowered = true;
186 return loweringRes;
187 };
188
189private:
190 LowerRegionTarget &target;
191 LogicalResult &loweringRes;
192 RegionLoweringFunc fun;
193};
194
195LogicalResult
196handshake::partiallyLowerRegion(const RegionLoweringFunc &loweringFunc,
197 MLIRContext *ctx, Region &r) {
198
199 Operation *op = r.getParentOp();
200 RewritePatternSet patterns(ctx);
201 auto target = LowerRegionTarget(*ctx, r);
202 LogicalResult partialLoweringSuccessfull = success();
203 patterns.add<PartialLowerRegion>(target, ctx, partialLoweringSuccessfull,
204 loweringFunc);
205 return success(
206 applyPartialConversion(op, target, std::move(patterns)).succeeded() &&
207 partialLoweringSuccessfull.succeeded());
208}
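// Editorial sketch (not part of the upstream file): partiallyLowerRegion is
// the region-scoped analogue of partiallyLowerOp. A hypothetical caller could
// apply one in-place region rewrite as follows:
//
//   Region &body = funcOp.getBody();
//   LogicalResult res = handshake::partiallyLowerRegion(
//       [&](Region &region, ConversionPatternRewriter &rewriter) {
//         // Perform any in-place rewrite of `region` through `rewriter`.
//         return success();
//       },
//       funcOp.getContext(), body);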
209
210// ============================================================================
211// Start of lowering passes
212// ============================================================================
213
214Value HandshakeLowering::getBlockEntryControl(Block *block) const {
215 auto it = blockEntryControlMap.find(block);
216 assert(it != blockEntryControlMap.end() &&
217 "No block entry control value registered for this block!");
218 return it->second;
219}
220
221void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
222 blockEntryControlMap[block] = v;
223}
224
225void handshake::removeBasicBlocks(Region &r) {
226 Block *entryBlock = &r.front();
227 auto &entryBlockOps = entryBlock->getOperations();
228
229 // Move all operations to entry block and erase other blocks.
230 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
231 entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
232
233 block.clear();
234 block.dropAllDefinedValueUses();
235 for (size_t i = 0; i < block.getNumArguments(); i++) {
236 block.eraseArgument(i);
237 }
238 block.erase();
239 }
240
241 // Remove any control flow operations, and move the non-control flow
242 // terminator op to the end of the entry block.
243 for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
244 if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
245 continue;
246
247 if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
248 terminatorLike.erase();
249 continue;
250 }
251
252 // Else, assume that this is a return-like terminator op.
253 terminatorLike.moveBefore(entryBlock, entryBlock->end());
254 }
255}
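// Editorial note: schematically, a multi-block region such as
//
//   ^bb0: ...
//         cf.cond_br %c, ^bb1, ^bb2
//   ^bb1: ...
//         cf.br ^bb2
//   ^bb2: ...
//         <return-like terminator>
//
// has all operations spliced into ^bb0, the cf.br/cf.cond_br terminators
// erased, and the return-like terminator moved to the end of ^bb0, leaving the
// single-block region expected by the handshake dialect.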
256
257LogicalResult
258HandshakeLowering::runSSAMaximization(ConversionPatternRewriter &rewriter,
259 Value entryCtrl) {
260 return maximizeSSA(entryCtrl, rewriter);
261}
262
263void handshake::removeBasicBlocks(handshake::FuncOp funcOp) {
264 if (funcOp.isExternal())
265 return; // nothing to do, external funcOp.
266
267 removeBasicBlocks(funcOp.getBody());
268}
269
270static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type) {
271 if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
272 return emitError(loc) << "memrefs must be statically sized and "
273 "unidimensional.";
274 return success();
275}
276
277static unsigned getBlockPredecessorCount(Block *block) {
278 // Returns number of block predecessors
279 auto predecessors = block->getPredecessors();
280 return std::distance(predecessors.begin(), predecessors.end());
281}
282
283// Insert the appropriate merge-like operation: a ControlMergeOp for the
284// control-only path, a MergeOp for blocks with at most one predecessor, and a MuxOp otherwise.
285HandshakeLowering::MergeOpInfo
286HandshakeLowering::insertMerge(Block *block, Value val,
287 BackedgeBuilder &edgeBuilder,
288 ConversionPatternRewriter &rewriter) {
289 unsigned numPredecessors = getBlockPredecessorCount(block);
290 auto insertLoc = block->front().getLoc();
291 SmallVector<Backedge> dataEdges;
292 SmallVector<Value> operands;
293
294 // Every block (except the entry block) needs to feed its entry control into
295 // a control merge
296 if (val == getBlockEntryControl(block)) {
297
298 Operation *mergeOp;
299 if (block == &r.front()) {
300 // For consistency within the entry block, replace the latter's entry
301 // control with the output of a MergeOp that takes the control-only
302 // network's start point as input. This makes it so that only the
303 // MergeOp's output is used as a control within the entry block, instead
304 // of a combination of the MergeOp's output and the function/block control
305 // argument. Taking this step out should have no impact on functionality
306 // but would make the resulting IR less "regular"
307 operands.push_back(val);
308 mergeOp = handshake::MergeOp::create(rewriter, insertLoc, operands);
309 } else {
310 for (unsigned i = 0; i < numPredecessors; i++) {
311 auto edge = edgeBuilder.get(rewriter.getNoneType());
312 dataEdges.push_back(edge);
313 operands.push_back(Value(edge));
314 }
315 mergeOp =
316 handshake::ControlMergeOp::create(rewriter, insertLoc, operands);
317 }
318 setBlockEntryControl(block, mergeOp->getResult(0));
319 return MergeOpInfo{mergeOp, val, dataEdges};
320 }
321
322 // Every live-in value to a block is passed through a merge-like operation,
323 // even when it's not required for circuit correctness (useless merge-like
324 // operations are removed down the line during handshake canonicalization)
325
326 // Insert "dummy" MergeOp's for blocks with less than two predecessors
327 if (numPredecessors <= 1) {
328 if (numPredecessors == 0) {
329 // All of the entry block's block arguments get passed through a dummy
330 // MergeOp. There is no need for a backedge here as the unique operand can
331 // be resolved immediately
332 operands.push_back(val);
333 } else {
334 // The value incoming from the single block predecessor will be resolved
335 // later during merge reconnection
336 auto edge = edgeBuilder.get(val.getType());
337 dataEdges.push_back(edge);
338 operands.push_back(Value(edge));
339 }
340 auto merge = handshake::MergeOp::create(rewriter, insertLoc, operands);
341 return MergeOpInfo{merge, val, dataEdges};
342 }
343
344 // Create a backedge for the index operand, and another one for each data
345 // operand. The index operand will eventually resolve to the current block's
346 // control merge index output, while data operands will resolve to their
347 // respective values from each block predecessor
348 Backedge indexEdge = edgeBuilder.get(rewriter.getIndexType());
349 for (unsigned i = 0; i < numPredecessors; i++) {
350 auto edge = edgeBuilder.get(val.getType());
351 dataEdges.push_back(edge);
352 operands.push_back(Value(edge));
353 }
354 auto mux =
355 handshake::MuxOp::create(rewriter, insertLoc, Value(indexEdge), operands);
356 return MergeOpInfo{mux, val, dataEdges, indexEdge};
357}
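// Editorial summary of the merge flavors created above, for a block B with
// predecessors P1..Pn (schematic, not actual handshake assembly):
//
//   entry control, entry block:      merge(start ctrl)            -> ctrl
//   entry control, other blocks:     cmerge(P1.ctrl, .., Pn.ctrl) -> (ctrl, index)
//   data live-in, <= 1 predecessor:  merge(v)                     -> v'
//   data live-in, >= 2 predecessors: mux(index, v_P1, .., v_Pn)   -> v'
//
// The backedges created here stand in for the per-predecessor values and the
// cmerge index until reconnectMergeOps resolves them.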
358
359HandshakeLowering::BlockOps
360HandshakeLowering::insertMergeOps(HandshakeLowering::ValueMap &mergePairs,
361 BackedgeBuilder &edgeBuilder,
362 ConversionPatternRewriter &rewriter) {
363 HandshakeLowering::BlockOps blockMerges;
364 for (Block &block : r) {
365 rewriter.setInsertionPointToStart(&block);
366
367 // All of the block's live-ins are passed explicitly through block arguments
368 // thanks to prior SSA maximization
369 for (auto &arg : block.getArguments()) {
370 // No merges on memref block arguments; these are handled separately
371 if (isa<mlir::MemRefType>(arg.getType()))
372 continue;
373
374 auto mergeInfo = insertMerge(&block, arg, edgeBuilder, rewriter);
375 blockMerges[&block].push_back(mergeInfo);
376 mergePairs[arg] = mergeInfo.op->getResult(0);
377 }
378 }
379 return blockMerges;
380}
381
382// Get value from predBlock which will be set as operand of op (merge)
383static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo,
384 Block *predBlock) {
385 // The input value to the merge operations
386 Value srcVal = mergeInfo.val;
387 // The block the merge operation belongs to
388 Block *block = mergeInfo.op->getBlock();
389
390 // The block terminator is either a cf-level branch or cf-level conditional
391 // branch. In either case, identify the value passed to the block using its
392 // index in the list of block arguments
393 unsigned index = cast<BlockArgument>(srcVal).getArgNumber();
394 Operation *termOp = predBlock->getTerminator();
395 if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
396 // Block should be one of the two destinations of the conditional branch
397 if (block == br.getTrueDest())
398 return br.getTrueOperand(index);
399 assert(block == br.getFalseDest());
400 return br.getFalseOperand(index);
401 }
402 if (isa<mlir::cf::BranchOp>(termOp))
403 return termOp->getOperand(index);
404 return nullptr;
405}
406
407static void removeBlockOperands(Region &f) {
408 // Remove all block arguments; they are no longer used.
409 // eraseArguments also removes corresponding branch operands
410 for (Block &block : f) {
411 if (!block.isEntryBlock()) {
412 int x = block.getNumArguments() - 1;
413 for (int i = x; i >= 0; --i)
414 block.eraseArgument(i);
415 }
416 }
417}
418
419/// Returns the first occurrence of an operation of type TOp in the block,
420/// or nullptr if there is none.
421template <typename TOp>
422static Operation *getFirstOp(Block *block) {
423 auto ops = block->getOps<TOp>();
424 if (ops.empty())
425 return nullptr;
426 return *ops.begin();
427}
428
429static Operation *getControlMerge(Block *block) {
430 return getFirstOp<ControlMergeOp>(block);
431}
432
433static ConditionalBranchOp getControlCondBranch(Block *block) {
434 for (auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
435 if (cbranch.isControl())
436 return cbranch;
437 }
438 return nullptr;
439}
440
441static void reconnectMergeOps(Region &r,
442 HandshakeLowering::BlockOps blockMerges,
443 HandshakeLowering::ValueMap &mergePairs) {
444 // At this point all merge-like operations have backedges as operands.
445 // Here we replace all backedge values with the appropriate value from the
446 // predecessor block. The replacement can either be a merge result, the
447 // original defining value, or a branch operand.
448
449 for (Block &block : r) {
450 for (auto &mergeInfo : blockMerges[&block]) {
451 int operandIdx = 0;
452 // Set appropriate operand from each predecessor block
453 for (auto *predBlock : block.getPredecessors()) {
454 Value mgOperand = getMergeOperand(mergeInfo, predBlock);
455 assert(mgOperand != nullptr);
456 if (!mgOperand.getDefiningOp()) {
457 assert(mergePairs.count(mgOperand));
458 mgOperand = mergePairs[mgOperand];
459 }
460 mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
461 operandIdx++;
462 }
463
464 // Reroute all uses of the live-in value within this block through the
465 // corresponding merge-like operation
466 for (Operation &opp : block)
467 if (!isa<MergeLikeOpInterface>(opp))
468 opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
469 }
470 }
471
472 // Connect select operand of muxes to control merge's index result in all
473 // blocks with more than one predecessor
474 for (Block &block : r) {
475 if (getBlockPredecessorCount(&block) > 1) {
476 Operation *cntrlMg = getControlMerge(&block);
477 assert(cntrlMg != nullptr);
478
479 for (auto &mergeInfo : blockMerges[&block]) {
480 if (mergeInfo.op != cntrlMg) {
481 // If the block has multiple predecessors, merge-like operations that
482 // are not the block's control merge must have an index operand (at
483 // this point, an index backedge)
484 assert(mergeInfo.indexEdge.has_value());
485 (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
486 }
487 }
488 }
489 }
490
490
491 removeBlockOperands(r);
492}
493
494static bool isAllocOp(Operation *op) {
495 return isa<memref::AllocOp, memref::AllocaOp>(op);
496}
497
498LogicalResult
499HandshakeLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
500
501 // Stores a mapping from each value that passes through a merge operation to
502 // the first result of that merge operation
503 ValueMap mergePairs;
504
505 // Create backedge builder to manage operands of merge operations between
506 // insertion and reconnection
507 BackedgeBuilder edgeBuilder{rewriter, r.front().front().getLoc()};
508
509 // Insert merge operations (with backedges instead of actual operands)
510 BlockOps mergeOps = insertMergeOps(mergePairs, edgeBuilder, rewriter);
511
512 // Reconnect merge operations with values incoming from predecessor blocks
513 // and resolve all backedges that were created during merge insertion
514 reconnectMergeOps(r, mergeOps, mergePairs);
515 return success();
516}
517
518static bool isLiveOut(Value val) {
519 // Identifies liveout values after adding Merges
520 for (auto &u : val.getUses())
521 // Result is liveout if used by some merge-like operation
522 if (isa<MergeLikeOpInterface>(u.getOwner()))
523 return true;
524 return false;
525}
526
527// A value can have multiple branches in a single successor block
528// (for instance, there can be an SSA phi and a merge that we insert)
529// This function determines the number of branches to insert based on the
530// value uses in successor blocks
531static int getBranchCount(Value val, Block *block) {
532 int uses = 0;
533 for (int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
534 int curr = 0;
535 Block *succ = block->getSuccessor(i);
536 for (auto &u : val.getUses()) {
537 if (u.getOwner()->getBlock() == succ)
538 curr++;
539 }
540 uses = (curr > uses) ? curr : uses;
541 }
542 return uses;
543}
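// Editorial example: if a value %v defined in block B is used twice in
// successor S1 (say, by two different merge-like operations) and once in
// successor S2, getBranchCount(%v, B) returns 2, the maximum number of uses in
// any single successor, so that each use can later be fed by its own branch
// result in addBranchOps.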
544
545namespace {
546
547/// This class inserts a reorder prevention mechanism for blocks with multiple
548/// successors. Such a mechanism is required to guarantee correct execution in a
549/// multi-threaded usage of the circuits.
550///
551/// The order of the results matches the order in which the divergence point
552/// was traversed. A FIFO buffer stores the condition of the conditional
553/// branch and feeds a mux that guarantees the correct output order.
554class FeedForwardNetworkRewriter {
555public:
556 FeedForwardNetworkRewriter(HandshakeLowering &hl,
557 ConversionPatternRewriter &rewriter)
558 : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
559 domInfo(hl.getRegion().getParentOp()),
560 loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
561 LogicalResult apply();
562
563private:
564 HandshakeLowering &hl;
565 ConversionPatternRewriter &rewriter;
566 PostDominanceInfo postDomInfo;
567 DominanceInfo domInfo;
568 CFGLoopInfo loopInfo;
569
570 using BlockPair = std::pair<Block *, Block *>;
571 using BlockPairs = SmallVector<BlockPair>;
572 LogicalResult findBlockPairs(BlockPairs &blockPairs);
573
574 BufferOp buildSplitNetwork(Block *splitBlock,
575 handshake::ConditionalBranchOp &ctrlBr);
576 LogicalResult buildMergeNetwork(Block *mergeBlock, BufferOp buf,
577 handshake::ConditionalBranchOp &ctrlBr);
578
579 // Determines if the cmerge inputs match the cond_br output order.
580 bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
581 handshake::ConditionalBranchOp &ctrlBr);
582 bool formsIrreducibleCF(Block *splitBlock, Block *mergeBlock);
583};
584} // namespace
585
586LogicalResult
587HandshakeLowering::feedForwardRewriting(ConversionPatternRewriter &rewriter) {
588 // Nothing to do on a single block region.
589 if (this->getRegion().hasOneBlock())
590 return success();
591 return FeedForwardNetworkRewriter(*this, rewriter).apply();
592}
593
594[[maybe_unused]] static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo) {
595 for (CFGLoop *loop : loopInfo.getTopLevelLoops())
596 if (!loop->getExitBlock())
597 return false;
598 return true;
599}
600
601bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
602 Block *mergeBlock) {
603 CFGLoop *loop = loopInfo.getLoopFor(mergeBlock);
604 for (auto *mergePred : mergeBlock->getPredecessors()) {
605 // Skip loop predecessors
606 if (loop && loop->contains(mergePred))
607 continue;
608
609 // A DAG-CFG is irreducible iff a merge block has a predecessor that can be
610 // reached from both successors of a split node, i.e., neither successor
611 // dominates it.
612 // => Control flow can then merge somewhere other than the merge block,
613 // which makes this irreducible.
614 if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
615 if (splitSucc == mergeBlock || mergePred == splitBlock)
616 return true;
617 return domInfo.dominates(splitSucc, mergePred);
618 }))
619 return true;
620 }
621 return false;
622}
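// Editorial example of the pattern rejected above (schematic CFG):
//
//        S          S is the split block; M is the merge block computed as the
//       / \         nearest common post-dominator of S's successors A and B.
//      A   B        P is reachable from both A and B, so neither successor of
//       \ / \       S dominates it; control from the two paths can merge at P
//        P   |      before reaching M, and formsIrreducibleCF(S, M) reports
//         \ /       irreducible control flow.
//          M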
623
624static Operation *findBranchToBlock(Block *block) {
625 Block *pred = *block->getPredecessors().begin();
626 return pred->getTerminator();
627}
628
629LogicalResult
630FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
631 // Assumes that merge block insertion happened beforehand.
632 // Thus, for each split block, there exists one merge block which is the
633 // post-dominator of its successors.
634 Region &r = hl.getRegion();
635 Operation *parentOp = r.getParentOp();
636
637 // Assumes that each loop has only one exit block. Such an error should
638 // already be reported by the loop rewriting.
639 assert(loopsHaveSingleExit(loopInfo) &&
640 "expected loop to only have one exit block.");
641
642 for (Block &b : r) {
643 if (b.getNumSuccessors() < 2)
644 continue;
645
646 // Loop headers cannot be merge blocks.
647 if (loopInfo.getLoopFor(&b))
648 continue;
649
650 assert(b.getNumSuccessors() == 2);
651 Block *succ0 = b.getSuccessor(0);
652 Block *succ1 = b.getSuccessor(1);
653
654 if (succ0 == succ1)
655 continue;
656
657 Block *mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
658
659 // Precondition checks
660 if (formsIrreducibleCF(&b, mergeBlock)) {
661 return parentOp->emitError("expected only reducible control flow.")
662 .attachNote(findBranchToBlock(mergeBlock)->getLoc())
663 << "This branch is involved in the irreducible control flow";
664 }
665
666 unsigned nonLoopPreds = 0;
667 CFGLoop *loop = loopInfo.getLoopFor(mergeBlock);
668 for (auto *pred : mergeBlock->getPredecessors()) {
669 if (loop && loop->contains(pred))
670 continue;
671 nonLoopPreds++;
672 }
673 if (nonLoopPreds > 2)
674 return parentOp
675 ->emitError("expected a merge block to have two predecessors. "
676 "Did you run the merge block insertion pass?")
677 .attachNote(findBranchToBlock(mergeBlock)->getLoc())
678 << "This branch jumps to the illegal block";
679
680 blockPairs.emplace_back(&b, mergeBlock);
681 }
682
683 return success();
684}
685
686LogicalResult FeedForwardNetworkRewriter::apply() {
687 BlockPairs pairs;
688
689 if (failed(findBlockPairs(pairs)))
690 return failure();
691
692 for (auto [splitBlock, mergeBlock] : pairs) {
693 handshake::ConditionalBranchOp ctrlBr;
694 BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
695 if (failed(buildMergeNetwork(mergeBlock, buffer, ctrlBr)))
696 return failure();
697 }
698
699 return success();
700}
701
702BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
703 Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
704 SmallVector<handshake::ConditionalBranchOp> branches;
705 llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
706 std::back_inserter(branches));
707
708 auto *findRes = llvm::find_if(branches, [](auto br) {
709 return llvm::isa<NoneType>(br.getDataOperand().getType());
710 });
711
712 assert(findRes != branches.end() && "expected one branch for the ctrl signal");
713 ctrlBr = *findRes;
714
715 Value cond = ctrlBr.getConditionOperand();
716 assert(llvm::all_of(branches, [&](auto branch) {
717 return branch.getConditionOperand() == cond;
718 }));
719
720 Location loc = cond.getLoc();
721 rewriter.setInsertionPointAfterValue(cond);
722
723 // The buffer size defines the number of threads that can concurrently
724 // traverse the sub-CFG starting at the splitBlock.
725 size_t bufferSize = 2;
726 // TODO how to size these?
727 // Longest path in a CFG-DAG would be O(#blocks)
728
729 return handshake::BufferOp::create(rewriter, loc, cond, bufferSize,
730 BufferTypeEnum::fifo);
731}
732
733LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
734 Block *mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
735 // Replace control merge with mux
736 auto ctrlMerges = mergeBlock->getOps<handshake::ControlMergeOp>();
737 assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
738
739 handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
740 // This input might contain irreducible loops that we cannot handle.
741 if (ctrlMerge.getNumOperands() != 2)
742 return ctrlMerge.emitError("expected cmerges to have two operands");
743 rewriter.setInsertionPointAfter(ctrlMerge);
744 Location loc = ctrlMerge->getLoc();
745
746 // The newly inserted mux has to select the results from the correct operand.
747 // As there is no guarantee on the order of cmerge inputs, the correct order
748 // has to be determined first.
749 bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
750 SmallVector<Value> muxOperands;
751 if (requiresFlip)
752 muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
753 else
754 muxOperands = llvm::to_vector(ctrlMerge.getOperands());
755
756 Value newCtrl = handshake::MuxOp::create(rewriter, loc, buf, muxOperands);
757
758 Value cond = buf.getResult();
759 if (requiresFlip) {
760 // As the mux operand order is the flipped cmerge input order, the index
761 // which replaces the output of the cmerge has to be flipped/negated as
762 // well.
763 cond = arith::XOrIOp::create(
764 rewriter, loc, cond.getType(), cond,
765 arith::ConstantOp::create(
766 rewriter, loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
767 }
768
769 // A cast to index is required to match the type of the cmerge's index result.
770 Value condAsIndex =
771 arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), cond);
772
773 hl.setBlockEntryControl(mergeBlock, newCtrl);
774
775 // Replace with new ctrl value from mux and the index
776 rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
777 return success();
778}
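// Editorial schematic of the resulting feed-forward network (assumptions: one
// control-only cond_br in the split block, one cmerge in the merge block):
//
//   split block:   %cond --+--> cond_br (ctrl and data branches)
//                          +--> buffer<fifo, 2> %buf
//
//   merge block:   mux(select = index_cast(%buf or its negation),
//                      inputs = former cmerge inputs, flipped if needed)
//
// The FIFO replays the branch decisions in the order the split block was
// traversed, so tokens from multiple in-flight activations leave the merge
// block in that same order.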
779
780bool FeedForwardNetworkRewriter::requiresOperandFlip(
781 ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
782 assert(ctrlMerge.getNumOperands() == 2 &&
783 "Loops should already have been handled");
784
785 Value fstOperand = ctrlMerge.getOperand(0);
786
787 assert(ctrlBr.getTrueResult().hasOneUse() &&
788 "expected the result of a branch to only have one user");
789 Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
790 if (trueUser == ctrlMerge)
791 // The cmerge directly consumes the cond_br output.
792 return ctrlBr.getTrueResult() == fstOperand;
793
794 // The cond_br result is consumed in an intermediate block. Check whether that
795 // block dominates the cmerge's first operand, i.e., whether it comes from the "true" path.
796 Block *trueBlock = trueUser->getBlock();
797 return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
798}
799
800namespace {
801// This class creates the loop 'continue' and 'exit' networks around backedges
802// in the CFG.
803// We don't have a standard-dialect-based LoopInfo utility in MLIR
804// (which could probably deduce most of the information that we need for this
805// transformation), so we roll our own loop-detection analysis. This is
806// simplified by the need to only detect outermost loops. Inner loops are
807// not included in the loop network (since we only care about restricting
808// different function invocations from activating a loop, not about preventing
809// loop pipelining within a single function invocation).
810class LoopNetworkRewriter {
811public:
812 LoopNetworkRewriter(HandshakeLowering &hl) : hl(hl) {}
813
814 LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
815
816private:
817 // An exit pair is a pair of <in loop block, outside loop block> that
818 // indicates where control leaves a loop.
819 using ExitPair = std::pair<Block *, Block *>;
820 LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
821
822 // Builds the loop continue network in between the loop header and its loop
823 // latch. The loop continuation network will replace the existing control
824 // merge in the loop header with a mux + loop priming register.
825 // The 'loopPrimingInput' is a backedge that will later be assigned by
826 // 'buildExitNetwork'. The value is to be used as the input to the loop
827 // priming buffer.
828 // Returns a reference to the loop priming register.
829 BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
830 Backedge &loopPrimingInput);
831
832 // Builds the loop exit network. This detects the conditional operands used in
833 // each of the exit blocks, matches their parity with the convention used to
834 // prime the loop register, and assigns it to the loop priming register input.
835 void buildExitNetwork(Block *loopHeader,
836 const llvm::SmallSet<ExitPair, 2> &exitPairs,
837 BufferOp loopPrimingRegister,
838 Backedge &loopPrimingInput);
839
840private:
841 ConversionPatternRewriter *rewriter = nullptr;
842 HandshakeLowering &hl;
843};
844} // namespace
845
846LogicalResult
847HandshakeLowering::loopNetworkRewriting(ConversionPatternRewriter &rewriter) {
848 return LoopNetworkRewriter(*this).processRegion(r, rewriter);
849}
850
851LogicalResult
852LoopNetworkRewriter::processRegion(Region &r,
853 ConversionPatternRewriter &rewriter) {
854 // Nothing to do on a single block region.
855 if (r.hasOneBlock())
856 return success();
857 this->rewriter = &rewriter;
858
859 Operation *op = r.getParentOp();
860
861 DominanceInfo domInfo(op);
862 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
863
864 for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
865 if (!loop->getLoopLatch())
866 return emitError(op->getLoc()) << "Multiple loop latches detected "
867 "(backedges from within the loop "
868 "to the loop header). Loop task "
869 "pipelining is only supported for "
870 "loops with unified loop latches.";
871
872 // This is the start of an outer loop - go process!
873 if (failed(processOuterLoop(op->getLoc(), loop)))
874 return failure();
875 }
876
877 return success();
878}
879
880// Returns the operand of the 'mux' operation which originated from 'block'.
881static Value getOperandFromBlock(MuxOp mux, Block *block) {
882 auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
883 return block == operand.getParentBlock();
884 });
885 assert(
886 inValueIt != mux.getDataOperands().end() &&
887 "Expected mux to have an operand originating from the requested block.");
888 return *inValueIt;
889}
890
891// Returns a list of operands from 'mux' which corresponds to the inputs of the
892// 'cmerge' operation. The results are sorted such that the i'th cmerge operand
893// and the i'th sorted operand originate from the same block.
894static std::vector<Value> getSortedInputs(ControlMergeOp cmerge, MuxOp mux) {
895 std::vector<Value> sortedOperands;
896 for (auto in : cmerge.getOperands()) {
897 auto *srcBlock = in.getParentBlock();
898 sortedOperands.push_back(getOperandFromBlock(mux, srcBlock));
899 }
900
901 // Sanity check: ensure that operands are unique
902 for (unsigned i = 0; i < sortedOperands.size(); ++i) {
903 for (unsigned j = 0; j < sortedOperands.size(); ++j) {
904 if (i == j)
905 continue;
906 assert(sortedOperands[i] != sortedOperands[j] &&
907 "Cannot have an identical operand from two different blocks!");
908 }
909 }
910
911 return sortedOperands;
912}
913
914BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
915 Block *loopLatch,
916 Backedge &loopPrimingInput) {
917 // Gather the muxes to replace before modifying block internals; it's been
918 // found that if this is not done, we have determinism issues wrt. generating
919 // the same order of replaced muxes on repeated runs of an identical
920 // conversion.
921 llvm::SmallVector<MuxOp> muxesToReplace;
922 llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));
923
924 // Fetch the control merge of the block; it is assumed that, at this point of
925 // lowering, no other form of control can be used for the loop header block
926 // than a control merge.
927 auto *cmerge = getControlMerge(loopHeader);
928 assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
929 "Expected control merge to be the control component of a loop header");
930 auto loc = cmerge->getLoc();
931
932 // sanity check: cmerge should have >1 input to actually be a loop
933 assert(cmerge->getNumOperands() > 1 && "This cannot be a loop header");
934
935 // Partition the control merge inputs into those originating from backedges,
936 // and those originating elsewhere.
937 SmallVector<Value> externalCtrls, loopCtrls;
938 for (auto cval : cmerge->getOperands()) {
939 if (cval.getParentBlock() == loopLatch)
940 loopCtrls.push_back(cval);
941 else
942 externalCtrls.push_back(cval);
943 }
944 assert(loopCtrls.size() == 1 &&
945 "Expected a single loop control value to match the single loop latch");
946 Value loopCtrl = loopCtrls.front();
947
948 // Merge all of the controls in each partition
949 rewriter->setInsertionPointToStart(loopHeader);
950 auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);
951
952 // Create the loop mux and the loop priming register. The loop mux selects
953 // external control on select "0" and internal (loop) control on select "1".
954 // This convention must be followed by the loop exit network.
955 auto primingRegister =
956 rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
957 // Initialize the priming register to path 0.
958 primingRegister->setAttr("initValues", rewriter->getI64ArrayAttr({0}));
959
960 // The loop control mux will deterministically select between control entering
961 // the loop from any external block or the single loop backedge.
962 auto loopCtrlMux = rewriter->create<MuxOp>(
963 loc, primingRegister.getResult(),
964 llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});
965
966 // Replace the existing control merge 'result' output with the loop control
967 // mux.
968 cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());
969
970 // Register the new block entry control value
971 hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());
972
973 // Next, we need to consider how to replace the control merge 'index' output,
974 // used to drive input selection to the block.
975
976 // Inputs to the loop header will be sourced from muxes with inputs from both
977 // the loop latch as well as external blocks. Iterate through these and sort
978 // based on the input ordering to the external/internal control merge.
979 // We do this by maintaining a mapping between the external and loop data
980 // inputs for each data mux in the design. The key of these maps is the
981 // original mux (that is to be replaced).
982 DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
983 DenseMap<MuxOp, Value> loopDataInputs;
984 for (auto muxOp : muxesToReplace) {
985 if (muxOp == loopCtrlMux)
986 continue;
987
988 externalDataInputs[muxOp] = getSortedInputs(externalCtrlMerge, muxOp);
989 loopDataInputs[muxOp] = getOperandFromBlock(muxOp, loopLatch);
990 assert(/*loop latch input*/ 1 + externalDataInputs[muxOp].size() ==
991 muxOp.getDataOperands().size() &&
992 "Expected all mux operands to be partitioned between loop and "
993 "external data inputs");
994 }
995
996 // With this, we now replace each of the data input muxes in the loop header.
997 // We instantiate a single mux driven by the external control merge.
998 // This, as well as the corresponding data input coming from within the single
999 // loop latch, will then be selected between by a 3rd mux, based on the
1000 // priming register.
1001 for (MuxOp mux : muxesToReplace) {
1002 auto externalDataMux = rewriter->create<MuxOp>(
1003 loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);
1004
1005 rewriter->replaceOp(
1006 mux, rewriter
1007 ->create<MuxOp>(loc, primingRegister,
1008 llvm::SmallVector<Value>{externalDataMux,
1009 loopDataInputs[mux]})
1010 .getResult());
1011 }
1012
1013 // Now all values defined by the original cmerge should have been replaced,
1014 // and it may be erased.
1015 rewriter->eraseOp(cmerge);
1016
1017 // Return the priming register to be referenced by the exit network builder.
1018 return primingRegister;
1019}
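// Editorial schematic of the loop continue network built above:
//
//   external ctrl inputs -> cmerge ------------------+
//                             (index)                v
//   loop latch ctrl ----------------------> mux(priming register) -> header ctrl
//
//   each data live-in:  mux(cmerge index, external inputs) ---+
//                       loop latch input ---> mux(priming register) -> live-in
//
// The priming register is a depth-1 "seq" buffer initialized to 0 (external
// path). The exit network feeds it 1 while the loop keeps iterating and 0 when
// the loop exits, re-priming the header for the next outer activation.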
1020
1021void LoopNetworkRewriter::buildExitNetwork(
1022 Block *loopHeader, const llvm::SmallSet<ExitPair, 2> &exitPairs,
1023 BufferOp loopPrimingRegister, Backedge &loopPrimingInput) {
1024 auto loc = loopPrimingRegister.getLoc();
1025
1026 // Iterate over the exit pairs to gather up the condition signals that need
1027 // to be connected to the exit network. In doing so, we parity-correct these
1028 // condition values based on the convention used in buildContinueNetwork: the
1029 // loop mux selects external control on "0" and internal control on "1", and
1030 // this convention must be followed by the loop exit network. External
1031 // control must be selected when exiting the loop (to re-prime the
1032 // register).
1033 SmallVector<Value> parityCorrectedConds;
1034 for (auto &[condBlock, exitBlock] : exitPairs) {
1035 auto condBr = getControlCondBranch(condBlock);
1036 assert(
1037 condBr &&
1038 "Expected a conditional control branch op in the loop condition block");
1039 Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
1040 bool isTrueParity = trueUser->getBlock() == exitBlock;
1041 assert(isTrueParity ^
1042 ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
1043 exitBlock) &&
1044 "The user of either the true or the false result should be in the "
1045 "exit block");
1046
1047 Value condValue = condBr.getConditionOperand();
1048 if (isTrueParity) {
1049 // This goes against the convention, and we have to invert the condition
1050 // value before connecting it to the exit network.
1051 rewriter->setInsertionPoint(condBr);
1052 condValue = rewriter->create<arith::XOrIOp>(
1053 loc, condValue.getType(), condValue,
1054 rewriter->create<arith::ConstantOp>(
1055 loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
1056 }
1057 parityCorrectedConds.push_back(condValue);
1058 }
1059
1060 // Merge all of the parity-corrected exit conditions and assign them
1061 // to the loop priming input.
1062 auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
1063 loopPrimingInput.setValue(exitMerge);
1064}
1065
1066LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1067 CFGLoop *loop) {
1068 // We determine the exit pairs of the loop; these are the in-loop blocks
1069 // which branch off to the exit blocks.
1070 llvm::SmallSet<ExitPair, 2> exitPairs;
1071 SmallVector<Block *> exitBlocks;
1072 loop->getExitBlocks(exitBlocks);
1073 for (auto *exitNode : exitBlocks) {
1074 for (auto *pred : exitNode->getPredecessors()) {
1075 // is the predecessor inside the loop?
1076 if (!loop->contains(pred))
1077 continue;
1078
1079 ExitPair condPair = {pred, exitNode};
1080 assert(!exitPairs.count(condPair) &&
1081 "identical condition pairs should never be possible");
1082 exitPairs.insert(condPair);
1083 }
1084 }
1085 assert(!exitPairs.empty() && "No exits from loop?");
1086
1087 // The first precondition to our loop transformation is that only a single
1088 // exit pair exists in the loop.
1089 if (exitPairs.size() > 1)
1090 return emitError(loc)
1091 << "Multiple exits detected within a loop. Loop task pipelining is "
1092 "only supported for loops with unified loop exit blocks.";
1093
1094 Block *header = loop->getHeader();
1095 BackedgeBuilder bebuilder(*rewriter, header->front().getLoc());
1096
1097 // Build the loop continue network. Loop continuation is triggered solely by
1098 // backedges to the header.
1099 auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1100 auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1101 loopPrimingRegisterInput);
1102
1103 // Build the loop exit network. Loop exiting is driven solely by exit pairs
1104 // from the loop.
1105 buildExitNetwork(header, exitPairs, loopPrimingRegister,
1106 loopPrimingRegisterInput);
1107
1108 return success();
1109}
1110
1111// Return the appropriate branch result based on the successor block which uses it
1112static Value getSuccResult(Operation *termOp, Operation *newOp,
1113 Block *succBlock) {
1114 // For a conditional branch, check if the result goes to the true or false successor
1115 if (auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1116 if (condBranchOp.getTrueDest() == succBlock)
1117 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1118 else {
1119 assert(condBranchOp.getFalseDest() == succBlock);
1120 return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1121 }
1122 }
1123 // If the block is unconditional, newOp has only one result
1124 return newOp->getResult(0);
1125}
1126
1127LogicalResult
1128HandshakeLowering::addBranchOps(ConversionPatternRewriter &rewriter) {
1129
1130 BlockValues liveOuts;
1131
1132 for (Block &block : r) {
1133 for (Operation &op : block) {
1134 for (auto result : op.getResults())
1135 if (isLiveOut(result))
1136 liveOuts[&block].push_back(result);
1137 }
1138 }
1139
1140 for (Block &block : r) {
1141 Operation *termOp = block.getTerminator();
1142 rewriter.setInsertionPoint(termOp);
1143
1144 for (Value val : liveOuts[&block]) {
1145 // Count the number of branches which the liveout needs
1146 int numBranches = getBranchCount(val, &block);
1147
1148 // Instantiate branches and connect to Merges
1149 for (int i = 0, e = numBranches; i < e; ++i) {
1150 Operation *newOp = nullptr;
1151
1152 if (auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
1153 newOp = handshake::ConditionalBranchOp::create(
1154 rewriter, termOp->getLoc(), condBranchOp.getCondition(), val);
1155 else if (isa<mlir::cf::BranchOp>(termOp))
1156 newOp = handshake::BranchOp::create(rewriter, termOp->getLoc(), val);
1157
1158 if (newOp == nullptr)
1159 continue;
1160
1161 for (int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
1162 Block *succ = block.getSuccessor(j);
1163 Value res = getSuccResult(termOp, newOp, succ);
1164
1165 for (auto &u : val.getUses()) {
1166 if (u.getOwner()->getBlock() == succ) {
1167 u.getOwner()->replaceUsesOfWith(val, res);
1168 break;
1169 }
1170 }
1171 }
1172 }
1173 }
1174 }
1175
1176 return success();
1177}
1178
1179LogicalResult HandshakeLowering::connectConstantsToControl(
1180 ConversionPatternRewriter &rewriter, bool sourceConstants) {
1181 // Create new constants which have a control-only input to trigger them.
1182 // These are connected to the control network or optionally to a Source
1183 // operation (always triggering). Control-network connected constants may
1184 // help debuggability, but result in a slightly larger circuit.
1185
1186 if (sourceConstants) {
1187 for (auto constantOp : llvm::make_early_inc_range(
1188 r.template getOps<mlir::arith::ConstantOp>())) {
1189 rewriter.setInsertionPointAfter(constantOp);
1190 auto value = constantOp.getValue();
1191 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1192 constantOp, value.getType(), value,
1193 handshake::SourceOp::create(rewriter, constantOp.getLoc(),
1194 rewriter.getNoneType()));
1195 }
1196 } else {
1197 for (Block &block : r) {
1198 Value blockEntryCtrl = getBlockEntryControl(&block);
1199 for (auto constantOp : llvm::make_early_inc_range(
1200 block.template getOps<mlir::arith::ConstantOp>())) {
1201 rewriter.setInsertionPointAfter(constantOp);
1202 auto value = constantOp.getValue();
1203 rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1204 constantOp, value.getType(), value, blockEntryCtrl);
1205 }
1206 }
1207 }
1208 return success();
1209}
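// Editorial example (schematic, not exact handshake assembly): an arith
// constant such as
//
//   %c1 = arith.constant 1 : i32
//
// becomes a handshake constant with an explicit trigger operand, either an
// always-firing source (sourceConstants == true) or the block's entry control:
//
//   %t  = source                        // or: %t = <block entry control>
//   %c1 = constant %t {value = 1 : i32}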
1210
1211/// Holds information about a handshake "basic block terminator" control
1212/// operation
1213struct BlockControlTerm {
1214 /// The operation
1215 Operation *op;
1216 /// The operation's control operand (must have type NoneType)
1217 Value ctrlOperand;
1218
1219 BlockControlTerm(Operation *op, Value ctrlOperand)
1220 : op(op), ctrlOperand(ctrlOperand) {
1221 assert(op && ctrlOperand);
1222 assert(isa<NoneType>(ctrlOperand.getType()) &&
1223 "Control operand must be a NoneType");
1224 }
1225
1226 /// Checks for member-wise equality
1227 friend bool operator==(const BlockControlTerm &lhs,
1228 const BlockControlTerm &rhs) {
1229 return lhs.op == rhs.op && lhs.ctrlOperand == rhs.ctrlOperand;
1230 }
1231};
1232
1233static BlockControlTerm getBlockControlTerminator(Block *block) {
1234 // Identify the control terminator operation and its control operand in the
1235 // given block. One such operation must exist in the block
1236 for (Operation &op : *block) {
1237 if (auto branchOp = dyn_cast<handshake::BranchOp>(op))
1238 if (branchOp.isControl())
1239 return {branchOp, branchOp.getDataOperand()};
1240 if (auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1241 if (branchOp.isControl())
1242 return {branchOp, branchOp.getDataOperand()};
1243 if (auto endOp = dyn_cast<handshake::ReturnOp>(op))
1244 return {endOp, endOp.getOperands().back()};
1245 }
1246 llvm_unreachable("Block terminator must exist");
1247}
1248
1249static LogicalResult getOpMemRef(Operation *op, Value &out) {
1250 out = Value();
1251 if (auto memOp = dyn_cast<memref::LoadOp>(op))
1252 out = memOp.getMemRef();
1253 else if (auto memOp = dyn_cast<memref::StoreOp>(op))
1254 out = memOp.getMemRef();
1255 else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1256 MemRefAccess access(op);
1257 out = access.memref;
1258 }
1259 if (out != Value())
1260 return success();
1261 return op->emitOpError("Unknown Op type");
1262}
1263
1264static bool isMemoryOp(Operation *op) {
1265 return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1266 AffineWriteOpInterface>(op);
1267}
1268
1269LogicalResult
1270HandshakeLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
1271 MemRefToMemoryAccessOp &memRefOps) {
1272
1273 std::vector<Operation *> opsToErase;
1274
1275 // Enrich the memRefOps context with BlockArguments, in case they aren't used.
1276 for (auto arg : r.getArguments()) {
1277 auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
1278 if (!memrefType)
1279 continue;
1280 // Ensure that this is a valid memref-typed value.
1281 if (failed(isValidMemrefType(arg.getLoc(), memrefType)))
1282 return failure();
1283 memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
1284 }
1285
1286 // Replace load and store ops with the corresponding handshake ops
1287 // Need to traverse ops in blocks to store them in memRefOps in program
1288 // order
1289 for (Operation &op : r.getOps()) {
1290 if (!isMemoryOp(&op))
1291 continue;
1292
1293 rewriter.setInsertionPoint(&op);
1294 Value memref;
1295 if (getOpMemRef(&op, memref).failed())
1296 return failure();
1297 Operation *newOp = nullptr;
1298
1299 llvm::TypeSwitch<Operation *>(&op)
1300 .Case<memref::LoadOp>([&](auto loadOp) {
1301 // Get operands which correspond to address indices
1302 // This will add all operands except alloc
1303 SmallVector<Value, 8> operands(loadOp.getIndices());
1304
1305 newOp = handshake::LoadOp::create(rewriter, op.getLoc(), memref,
1306 operands);
1307 op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
1308 })
1309 .Case<memref::StoreOp>([&](auto storeOp) {
1310 // Get operands which correspond to address indices
1311 // This will add all operands except alloc and data
1312 SmallVector<Value, 8> operands(storeOp.getIndices());
1313
1314 // Create new op where operands are store data and address indices
1315 newOp = handshake::StoreOp::create(
1316 rewriter, op.getLoc(), storeOp.getValueToStore(), operands);
1317 })
1318 .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](auto) {
1319 // Get essential memref access information.
1320 MemRefAccess access(&op);
1321 // The address of an affine load/store operation can be a result
1322 // of an affine map, which is a linear combination of constants
1323 // and parameters. Therefore, we should extract the affine map of
1324 // each address and expand it into proper expressions that
1325 // calculate the result.
1326 AffineMap map;
1327 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
1328 map = loadOp.getAffineMap();
1329 else
1330 map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();
1331
1332 // The returned object from expandAffineMap is an optional list of
1333 // the expansion results from the given affine map, which are the
1334 // actual address indices that can be used as operands for
1335 // handshake LoadOp/StoreOp. The following processing requires it
1336 // to be a valid result.
1337 auto operands =
1338 expandAffineMap(rewriter, op.getLoc(), map, access.indices);
1339 assert(operands && "Address operands of affine memref access "
1340 "cannot be reduced.");
1341
1342 if (isa<AffineReadOpInterface>(op)) {
1343 auto loadOp = handshake::LoadOp::create(rewriter, op.getLoc(),
1344 access.memref, *operands);
1345 newOp = loadOp;
1346 op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
1347 } else {
1348 newOp = handshake::StoreOp::create(rewriter, op.getLoc(),
1349 op.getOperand(0), *operands);
1350 }
1351 })
1352 .Default([&](auto) {
1353 op.emitOpError("Load/store operation cannot be handled.");
1354 });
1355
1356 memRefOps[memref].push_back(newOp);
1357 opsToErase.push_back(&op);
1358 }
1359
1360 // Erase old memory ops
1361 for (unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
1362 auto *op = opsToErase[i];
1363 for (int j = 0, e = op->getNumOperands(); j < e; ++j)
1364 op->eraseOperand(0);
1365 assert(op->getNumOperands() == 0);
1366
1367 rewriter.eraseOp(op);
1368 }
1369
1370 return success();
1371}
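// Editorial example: a memref access such as
//
//   %v = memref.load %mem[%i] : memref<16xi32>
//
// is rewritten above into a handshake load whose operands are only the address
// indices (%i). The data input from the memory interface and the control token
// are appended later (see setLoadDataInputs and the control wiring below) once
// the MemoryOp/ExternalMemoryOp for %mem has been created.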
1372
1373static SmallVector<Value, 8> getResultsToMemory(Operation *op) {
1374 // Get load/store results which are given as inputs to MemoryOp
1375
1376 if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1377 // For load, get all address outputs/indices
1378 // (load also has one data output which goes to successor operation)
1379 SmallVector<Value, 8> results(loadOp.getAddressResults());
1380 return results;
1381
1382 } else {
1383 // For store, all outputs (data and address indices) go to memory
1384 assert(dyn_cast<handshake::StoreOp>(op));
1385 handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1386 SmallVector<Value, 8> results(storeOp.getResults());
1387 return results;
1388 }
1389}
1390
1391static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter) {
1392
1393 for (Block &block : f) {
1394 Value ctrl = getBlockControlTerminator(&block).ctrlOperand;
1395 if (!ctrl.hasOneUse())
1396 insertFork(ctrl, true, rewriter);
1397 }
1398}
1399
1400static void removeUnusedAllocOps(Region &r,
1401 ConversionPatternRewriter &rewriter) {
1402 std::vector<Operation *> opsToDelete;
1403
1404 // Remove alloc operations whose result have no use
1405 for (auto &op : r.getOps())
1406 if (isAllocOp(&op) && op.getResult(0).use_empty())
1407 opsToDelete.push_back(&op);
1408
1409 llvm::for_each(opsToDelete, [&](auto allocOp) { rewriter.eraseOp(allocOp); });
1410}
1411
1412static void addJoinOps(ConversionPatternRewriter &rewriter,
1413 ArrayRef<BlockControlTerm> controlTerms) {
1414 for (auto term : controlTerms) {
1415 auto &[op, ctrl] = term;
1416 auto *srcOp = ctrl.getDefiningOp();
1417
1418 // Insert only single join per block
1419 if (!isa<JoinOp>(srcOp)) {
1420 rewriter.setInsertionPointAfter(srcOp);
1421 Operation *newJoin = JoinOp::create(rewriter, srcOp->getLoc(), ctrl);
1422 op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1423 }
1424 }
1425}
1426
1427static std::vector<BlockControlTerm>
1428getControlTerminators(ArrayRef<Operation *> memOps) {
1429 std::vector<BlockControlTerm> terminators;
1430
1431 for (Operation *op : memOps) {
1432 // Get block from which the mem op originates
1433 Block *block = op->getBlock();
1434 // Identify the control terminator in the block
1435 auto term = getBlockControlTerminator(block);
1436 if (std::find(terminators.begin(), terminators.end(), term) ==
1437 terminators.end())
1438 terminators.push_back(term);
1439 }
1440 return terminators;
1441}
1442
1443static void addValueToOperands(Operation *op, Value val) {
1444
1445 SmallVector<Value, 8> results(op->getOperands());
1446 results.push_back(val);
1447 op->setOperands(results);
1448}
1449
1450static void setLoadDataInputs(ArrayRef<Operation *> memOps, Operation *memOp) {
1451 // Set memory outputs as load input data
1452 int ld_count = 0;
1453 for (auto *op : memOps) {
1454 if (isa<handshake::LoadOp>(op))
1455 addValueToOperands(op, memOp->getResult(ld_count++));
1456 }
1457}
1458
1459static LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
1460 Operation *memOp, int offset,
1461 ArrayRef<int> cntrlInd) {
1462 // Connect all memory ops to the join of that block (ensures that all mem
1463 // ops terminate before a new block starts)
1464 for (int i = 0, e = memOps.size(); i < e; ++i) {
1465 auto *op = memOps[i];
1466 Value ctrl = getBlockControlTerminator(op->getBlock()).ctrlOperand;
1467 auto *srcOp = ctrl.getDefiningOp();
1468 if (!isa<JoinOp>(srcOp)) {
1469 return srcOp->emitOpError("Op expected to be a JoinOp");
1470 }
1471 addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
1472 }
1473 return success();
1474}
1475
1476static void setMemOpControlInputs(
1477 ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1478 Operation *memOp, int offset, ArrayRef<int> cntrlInd) {
1479 for (int i = 0, e = memOps.size(); i < e; ++i) {
1480 std::vector<Value> controlOperands;
1481 Operation *currOp = memOps[i];
1482 Block *currBlock = currOp->getBlock();
1483
1484 // Set load/store control inputs from the block input control value
1485 Value blockEntryCtrl = getBlockEntryControl(currBlock);
1486 controlOperands.push_back(blockEntryCtrl);
1487
1488 // Set load/store control inputs from predecessors in block
1489 for (int j = 0, f = i; j < f; ++j) {
1490 Operation *predOp = memOps[j];
1491 Block *predBlock = predOp->getBlock();
1492 if (currBlock == predBlock)
1493 // Any dependency but RARs
1494 if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1495 // cntrlInd maps memOps index to correct control output index
1496 controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1497 }
1498
1499 // If there is only one control input, add directly to memory op
1500 if (controlOperands.size() == 1)
1501 addValueToOperands(currOp, controlOperands[0]);
1502
1503 // If multiple, join them and connect join output to memory op
1504 else {
1505 rewriter.setInsertionPoint(currOp);
1506 Operation *joinOp =
1507 JoinOp::create(rewriter, currOp->getLoc(), controlOperands);
1508 addValueToOperands(currOp, joinOp->getResult(0));
1509 }
1510 }
1511}
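// NOTE (illustrative sketch, hypothetical access pattern): for a store
// followed by a load from the same memref within one block (a RAW pair),
// the resulting control inputs are schematically
//   store: ctrl = blockEntryCtrl
//   load:  ctrl = join(blockEntryCtrl, stnone1)
// whereas two loads with no intervening store (a RAR pair) both take
// blockEntryCtrl directly, since RAR dependencies are skipped above.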
1512
1513LogicalResult
1514HandshakeLowering::connectToMemory(ConversionPatternRewriter &rewriter,
1515 MemRefToMemoryAccessOp memRefOps, bool lsq) {
1516 // Add MemoryOps which represent the memory interface
1517 // Connect memory operations and control appropriately
1518 int mem_count = 0;
1519 for (auto memory : memRefOps) {
1520 // First operand corresponds to memref (alloca or function argument)
1521 Value memrefOperand = memory.first;
1522
1523 // A memory is external if the memref that defines it is provided as a
1524 // function (block) argument.
1525 bool isExternalMemory = isa<BlockArgument>(memrefOperand);
1526
1527 mlir::MemRefType memrefType =
1528 cast<mlir::MemRefType>(memrefOperand.getType());
1529 if (failed(isValidMemrefType(memrefOperand.getLoc(), memrefType)))
1530 return failure();
1531
1532 std::vector<Value> operands;
1533
1534 // Get the control terminators whose control operands need to connect to memory
1535 std::vector<BlockControlTerm> controlTerms =
1536 getControlTerminators(memory.second);
1537
1538 // In the case of an LSQ interface, add the control values as inputs (used
1539 // to trigger allocation in the LSQ)
1540 if (lsq)
1541 for (auto valOp : controlTerms)
1542 operands.push_back(valOp.ctrlOperand);
1543
1544 // Add load indices and store data+indices to the memory operands.
1545 // Count the number of loads so that we can generate the appropriate
1546 // number of memory outputs (data returned to the load ops).
1547
1548 // memory.second is in program order. Enforce MemoryOp port ordering:
1549 // Operands: all stores, then all loads
1550 //   (stdata1, staddr1, stdata2, ..., ldaddr1, ldaddr2, ...).
1551 // Outputs: all load data outputs, ordered as the load addresses
1552 //   (lddata1, lddata2, ...), followed by all none-typed done outputs,
1553 //   ordered as the operands (stnone1, stnone2, ..., ldnone1, ldnone2, ...).
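// NOTE (illustrative example, hypothetical access pattern): for two stores
// S1, S2 and two loads L1, L2 in that program order, the MemoryOp is built
// with
//   operands = [stdata1, staddr1, stdata2, staddr2, ldaddr1, ldaddr2]
//   results  = [lddata1, lddata2, stnone1, stnone2, ldnone1, ldnone2]
// and newInd records each access's position among the done results
// (S1 -> 0, S2 -> 1, L1 -> 2, L2 -> 3), later offset by ld_count.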
1554 std::vector<int> newInd(memory.second.size(), 0);
1555 int ind = 0;
1556 for (int i = 0, e = memory.second.size(); i < e; ++i) {
1557 auto *op = memory.second[i];
1558 if (isa<handshake::StoreOp>(op)) {
1559 SmallVector<Value, 8> results = getResultsToMemory(op);
1560 operands.insert(operands.end(), results.begin(), results.end());
1561 newInd[i] = ind++;
1562 }
1563 }
1564
1565 int ld_count = 0;
1566
1567 for (int i = 0, e = memory.second.size(); i < e; ++i) {
1568 auto *op = memory.second[i];
1569 if (isa<handshake::LoadOp>(op)) {
1570 SmallVector<Value, 8> results = getResultsToMemory(op);
1571 operands.insert(operands.end(), results.begin(), results.end());
1572
1573 ld_count++;
1574 newInd[i] = ind++;
1575 }
1576 }
1577
1578 // control-only outputs for each access (indicate access completion)
1579 int cntrl_count = lsq ? 0 : memory.second.size();
1580
1581 Block *entryBlock = &r.front();
1582 rewriter.setInsertionPointToStart(entryBlock);
1583
1584 // Place the memory op at the start of the entry block (next to the alloc op, if any)
1585 Operation *newOp = nullptr;
1586 if (isExternalMemory)
1587 newOp = ExternalMemoryOp::create(rewriter, entryBlock->front().getLoc(),
1588 memrefOperand, operands, ld_count,
1589 cntrl_count - ld_count, mem_count++);
1590 else
1591 newOp = MemoryOp::create(rewriter, entryBlock->front().getLoc(), operands,
1592 ld_count, cntrl_count, lsq, mem_count++,
1593 memrefOperand);
1594
1595 setLoadDataInputs(memory.second, newOp);
1596
1597 if (!lsq) {
1598 // Create Joins which join done signals from memory with the
1599 // control-only network
1600 addJoinOps(rewriter, controlTerms);
1601
1602 // Connect all load/store done signals to the join of their block
1603 // Ensure that the block terminates only after all its accesses have
1604 // completed
1605 // Synchronization is enabled by default; set this to false when no
1606 // synchronization is needed (for now, this is user-determined)
1607 bool control = true;
1608
1609 if (control &&
1610 setJoinControlInputs(memory.second, newOp, ld_count, newInd).failed())
1611 return failure();
1612
1613 // Set the control-only inputs of each memory op. This ensures that an op
1614 // starts only after all prior blocks have completed, and only after the
1615 // predecessor ops with which it has a RAW, WAR, or WAW dependency have
1616 // completed.
1617 setMemOpControlInputs(rewriter, memory.second, newOp, ld_count, newInd);
1618 }
1619 }
1620
1621 if (lsq)
1622 addLazyForks(r, rewriter);
1623
1624 removeUnusedAllocOps(r, rewriter);
1625 return success();
1626}
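// NOTE (minimal call sketch, assuming a HandshakeLowering instance `lowering`
// and a populated memRefOps map):
//   if (failed(lowering.connectToMemory(rewriter, memRefOps, /*lsq=*/false)))
//     return failure();
// With lsq == true no done outputs are generated (cntrl_count is 0) and the
// join-based synchronization is skipped; the block control values are instead
// passed to the interface as allocation triggers, with lazy forks added
// afterwards.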
1627
1628LogicalResult
1629HandshakeLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
1630 for (Block &block : r) {
1631 /// An instance is activated whenever control arrives at the basic block
1632 /// of the source callOp.
1633 Value blockEntryControl = getBlockEntryControl(&block);
1634 for (Operation &op : block) {
1635 if (auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1636 llvm::SmallVector<Value> operands;
1637 llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1638 operands.push_back(blockEntryControl);
1639 rewriter.setInsertionPoint(callOp);
1640 auto instanceOp = handshake::InstanceOp::create(
1641 rewriter, callOp.getLoc(), callOp.getCallee(),
1642 callOp.getResultTypes(), operands);
1643 // Replace all results of the source callOp.
1644 for (auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1645 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1646 rewriter.eraseOp(callOp);
1647 }
1648 }
1649 }
1650 return success();
1651}
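// NOTE (schematic, not exact assembly syntax): a call such as
//   %r = func.call @f(%a, %b) : (i32, i32) -> i32
// in a block with entry control %ctrl becomes
//   %r = handshake.instance @f(%a, %b, %ctrl)
// so the callee instance fires only when control reaches the caller's block.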
1652
1653namespace {
1654/// Strategy class for SSA maximization during cf-to-handshake conversion.
1655/// Block arguments of type MemRefType and allocation operations are not
1656/// considered for SSA maximization.
1657class HandshakeLoweringSSAStrategy : public SSAMaximizationStrategy {
1658 /// Filters out block arguments of type MemRefType
1659 bool maximizeArgument(BlockArgument arg) override {
1660 return !isa<mlir::MemRefType>(arg.getType());
1661 }
1662
1663 /// Filters out allocation operations
1664 bool maximizeOp(Operation *op) override { return !isAllocOp(op); }
1665};
1666} // namespace
1667
1668/// Converts every value in the region into maximal SSA form, unless the value
1669/// is a block argument of type MemRefType or the result of an allocation
1670/// operation.
1671static LogicalResult maximizeSSANoMem(Region &r,
1672 ConversionPatternRewriter &rewriter) {
1673 HandshakeLoweringSSAStrategy strategy;
1674 return maximizeSSA(r, strategy, rewriter);
1675}
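// NOTE (illustrative): with this strategy an i32 value used across basic
// blocks is rewritten to flow through block arguments (maximal SSA), while a
// memref-typed block argument or the result of an alloc op is left untouched,
// presumably so the memory lowering can still match it directly.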
1676
1677static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1678 bool sourceConstants,
1679 bool disableTaskPipelining) {
1680 // Only retain those attributes that are not constructed by build.
1681 SmallVector<NamedAttribute, 4> attributes;
1682 for (const auto &attr : funcOp->getAttrs()) {
1683 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1684 attr.getName() == funcOp.getFunctionTypeAttrName())
1685 continue;
1686 attributes.push_back(attr);
1687 }
1688
1689 // Get function arguments
1690 llvm::SmallVector<mlir::Type, 8> argTypes;
1691 for (auto &argType : funcOp.getArgumentTypes())
1692 argTypes.push_back(argType);
1693
1694 // Get function results
1695 llvm::SmallVector<mlir::Type, 8> resTypes;
1696 for (auto resType : funcOp.getResultTypes())
1697 resTypes.push_back(resType);
1698
1699 handshake::FuncOp newFuncOp;
1700
1701 // Add control input/output to function arguments/results and create a
1702 // handshake::FuncOp of appropriate type
1703 if (partiallyLowerOp<func::FuncOp>(
1704 [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1705 auto noneType = rewriter.getNoneType();
1706 resTypes.push_back(noneType);
1707 argTypes.push_back(noneType);
1708 auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1709 newFuncOp = handshake::FuncOp::create(rewriter, funcOp.getLoc(),
1710 funcOp.getName(), func_type,
1711 attributes);
1712 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1713 newFuncOp.end());
1714 if (!newFuncOp.isExternal()) {
1715 newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1716 funcOp.getLoc());
1717 newFuncOp.resolveArgAndResNames();
1718 }
1719 rewriter.eraseOp(funcOp);
1720 return success();
1721 },
1722 ctx, funcOp)
1723 .failed())
1724 return failure();
1725
1726 // Apply SSA maximization
1727 if (partiallyLowerRegion(maximizeSSANoMem, ctx, newFuncOp.getBody()).failed())
1728 return failure();
1729
1730 if (!newFuncOp.isExternal()) {
1731 Block *bodyBlock = newFuncOp.getBodyBlock();
1732 Value entryCtrl = bodyBlock->getArguments().back();
1733 HandshakeLowering fol(newFuncOp.getBody());
1734 if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1735 fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1736 return failure();
1737 }
1738
1739 return success();
1740}
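// NOTE (schematic): the net effect on the signature is
//   func.func @foo(%a: i32) -> i32
// becoming
//   handshake.func @foo(%a: i32, %ctrl: none) -> (i32, none)
// where the trailing none-typed argument is the entry control token passed to
// HandshakeLowering as entryCtrl above.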
1741
1742namespace {
1743
1744struct HandshakeRemoveBlockPass
1745 : circt::impl::HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1746 void runOnOperation() override { removeBasicBlocks(getOperation()); }
1747};
1748
1749struct CFToHandshakePass
1750 : public circt::impl::CFToHandshakeBase<CFToHandshakePass> {
1751 CFToHandshakePass(bool sourceConstants, bool disableTaskPipelining) {
1752 this->sourceConstants = sourceConstants;
1753 this->disableTaskPipelining = disableTaskPipelining;
1754 }
1755 void runOnOperation() override {
1756 ModuleOp m = getOperation();
1757
1758 for (auto funcOp : llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1759 if (failed(lowerFuncOp(funcOp, &getContext(), sourceConstants,
1760 disableTaskPipelining))) {
1761 signalPassFailure();
1762 return;
1763 }
1764 }
1765 }
1766};
1767
1768} // namespace
1769
1770std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1771 circt::createCFToHandshakePass(bool sourceConstants,
1772 bool disableTaskPipelining) {
1773 return std::make_unique<CFToHandshakePass>(sourceConstants,
1774 disableTaskPipelining);
1775}
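// NOTE (minimal driver sketch; `module`, `context`, and an include of
// mlir/Pass/PassManager.h are assumed):
//   mlir::PassManager pm(&context);
//   pm.addPass(circt::createCFToHandshakePass(/*sourceConstants=*/false,
//                                             /*disableTaskPipelining=*/false));
//   if (failed(pm.run(module)))
//     return failure();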
1776
1777std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1778 circt::createHandshakeRemoveBlockPass() {
1779 return std::make_unique<HandshakeRemoveBlockPass>();
1780}