CIRCT  20.0.0git
CFToHandshake.cpp
Go to the documentation of this file.
1 //===- CFToHandshake.cpp - Convert standard MLIR into dataflow IR ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This is the main Standard to Handshake Conversion Pass Implementation.
9 //
10 //===----------------------------------------------------------------------===//
11 
18 #include "mlir/Analysis/CFGLoopInfo.h"
19 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
20 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
21 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/SCF/IR/SCF.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Diagnostics.h"
33 #include "mlir/IR/Dominance.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/Types.h"
37 #include "mlir/IR/Value.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Support/LLVM.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/Passes.h"
42 #include "llvm/ADT/SmallSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/Support/raw_ostream.h"
45 
46 #include <list>
47 #include <map>
48 
49 namespace circt {
50 #define GEN_PASS_DEF_CFTOHANDSHAKE
51 #define GEN_PASS_DEF_HANDSHAKEREMOVEBLOCK
52 #include "circt/Conversion/Passes.h.inc"
53 } // namespace circt
54 
55 using namespace mlir;
56 using namespace mlir::func;
57 using namespace mlir::affine;
58 using namespace circt;
59 using namespace circt::handshake;
60 using namespace std;
61 
62 // ============================================================================
63 // Partial lowering infrastructure
64 // ============================================================================
65 
66 namespace {
67 template <typename TOp>
68 class LowerOpTarget : public ConversionTarget {
69 public:
70  explicit LowerOpTarget(MLIRContext &context) : ConversionTarget(context) {
71  loweredOps.clear();
72  addLegalDialect<HandshakeDialect>();
73  addLegalDialect<mlir::func::FuncDialect>();
74  addLegalDialect<mlir::arith::ArithDialect>();
75  addIllegalDialect<mlir::scf::SCFDialect>();
76  addIllegalDialect<AffineDialect>();
77 
78  /// The root operation to be replaced is marked dynamically legal
79  /// based on the lowering status of the given operation, see
80  /// PartialLowerOp.
81  addDynamicallyLegalOp<TOp>([&](const auto &op) { return loweredOps[op]; });
82  }
83  DenseMap<Operation *, bool> loweredOps;
84 };
85 
86 /// Default function for partial lowering of handshake::FuncOp. Lowering is
87 /// achieved by a provided partial lowering function.
88 ///
89 /// A partial lowering function may only replace a subset of the operations
90 /// within the funcOp currently being lowered. However, the dialect conversion
91 /// scheme requires the matched root operation to be replaced/updated/erased. It
92 /// is the partial update function's responsibility to ensure this. The parital
93 /// update function may only mutate the IR through the provided
94 /// ConversionPatternRewriter, like any other ConversionPattern.
95 /// Next, the function operation is expected to go
96 /// from illegal to legalized, after matchAndRewrite returned true. To work
97 /// around this, LowerFuncOpTarget::loweredFuncs is used to communicate between
98 /// the target and the conversion, to indicate that the partial lowering was
99 /// completed.
100 template <typename TOp>
101 struct PartialLowerOp : public ConversionPattern {
102  using PartialLoweringFunc =
103  std::function<LogicalResult(TOp, ConversionPatternRewriter &)>;
104 
105 public:
106  PartialLowerOp(LowerOpTarget<TOp> &target, MLIRContext *context,
107  LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
108  : ConversionPattern(TOp::getOperationName(), 1, context), target(target),
109  loweringRes(loweringResRef), fun(fun) {}
110  using ConversionPattern::ConversionPattern;
111  LogicalResult
112  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
113  ConversionPatternRewriter &rewriter) const override {
114  assert(isa<TOp>(op));
115  loweringRes = fun(dyn_cast<TOp>(op), rewriter);
116  target.loweredOps[op] = true;
117  return loweringRes;
118  };
119 
120 private:
121  LowerOpTarget<TOp> &target;
122  LogicalResult &loweringRes;
123  // NOTE: this is basically the rewrite function
124  PartialLoweringFunc fun;
125 };
126 } // namespace
127 
128 // Convenience function for running lowerToHandshake with a partial
129 // handshake::FuncOp lowering function.
130 template <typename TOp>
131 static LogicalResult partiallyLowerOp(
132  const std::function<LogicalResult(TOp, ConversionPatternRewriter &)>
133  &loweringFunc,
134  MLIRContext *ctx, TOp op) {
135 
136  RewritePatternSet patterns(ctx);
137  auto target = LowerOpTarget<TOp>(*ctx);
138  LogicalResult partialLoweringSuccessfull = success();
139  patterns.add<PartialLowerOp<TOp>>(target, ctx, partialLoweringSuccessfull,
140  loweringFunc);
141  return success(
142  applyPartialConversion(op, target, std::move(patterns)).succeeded() &&
143  partialLoweringSuccessfull.succeeded());
144 }
145 
146 class LowerRegionTarget : public ConversionTarget {
147 public:
148  explicit LowerRegionTarget(MLIRContext &context, Region &region)
149  : ConversionTarget(context), region(region) {
150  // The root operation is marked dynamically legal to ensure
151  // the pattern on its region is only applied once.
152  markUnknownOpDynamicallyLegal([&](Operation *op) {
153  if (op != region.getParentOp())
154  return true;
155  return opLowered;
156  });
157  }
158  bool opLowered = false;
159  Region &region;
160 };
161 
162 /// Allows to partially lower a region by matching on the parent operation to
163 /// then call the provided partial lowering function with the region and the
164 /// rewriter.
165 ///
166 /// The interplay with the target is similar to PartialLowerOp
167 struct PartialLowerRegion : public ConversionPattern {
169  std::function<LogicalResult(Region &, ConversionPatternRewriter &)>;
170 
171 public:
172  PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context,
173  LogicalResult &loweringResRef,
174  const PartialLoweringFunc &fun)
175  : ConversionPattern(target.region.getParentOp()->getName().getStringRef(),
176  1, context),
177  target(target), loweringRes(loweringResRef), fun(fun) {}
178  using ConversionPattern::ConversionPattern;
179  LogicalResult
180  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
181  ConversionPatternRewriter &rewriter) const override {
182  rewriter.modifyOpInPlace(
183  op, [&] { loweringRes = fun(target.region, rewriter); });
184 
185  target.opLowered = true;
186  return loweringRes;
187  };
188 
189 private:
191  LogicalResult &loweringRes;
193 };
194 
// Runs a partial conversion on the parent operation of region `r` that applies
// the region lowering function `loweringFunc` once, through
// PartialLowerRegion. Succeeds only if both the conversion and the lowering
// function succeeded.
// NOTE(review): the line carrying this definition's name and its first
// parameter (`loweringFunc`) is missing from this copy of the file; restore it
// from the upstream source before building.
195 LogicalResult
197  MLIRContext *ctx, Region &r) {
198
  // The pattern matches on the region's parent op.
199  Operation *op = r.getParentOp();
200  RewritePatternSet patterns(ctx);
201  auto target = LowerRegionTarget(*ctx, r);
  // Overwritten by the pattern with the result of `loweringFunc`.
202  LogicalResult partialLoweringSuccessfull = success();
203  patterns.add<PartialLowerRegion>(target, ctx, partialLoweringSuccessfull,
204  loweringFunc);
  // Left operand runs the conversion (which sets the lowering result); only
  // then is the lowering result read.
205  return success(
206  applyPartialConversion(op, target, std::move(patterns)).succeeded() &&
207  partialLoweringSuccessfull.succeeded());
208 }
209 
210 // ============================================================================
211 // Start of lowering passes
212 // ============================================================================
213 
214 Value HandshakeLowering::getBlockEntryControl(Block *block) const {
215  auto it = blockEntryControlMap.find(block);
216  assert(it != blockEntryControlMap.end() &&
217  "No block entry control value registerred for this block!");
218  return it->second;
219 }
220 
221 void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
222  blockEntryControlMap[block] = v;
223 }
224 
// Flattens region `r` into its entry block: the operations of all other blocks
// are appended to the entry block and those blocks are erased; cf.br /
// cf.cond_br terminators are erased and the remaining terminator (assumed to
// be return-like) is moved to the end of the entry block.
// NOTE(review): the opening line of this definition (its signature and `{`) is
// missing from this copy of the file; restore it from the upstream source.
226  Block *entryBlock = &r.front();
227  auto &entryBlockOps = entryBlock->getOperations();
228
229  // Move all operations to entry block and erase other blocks.
230  for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
231  entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
232
  // After the splice the block holds no operations, so clear() should be a
  // no-op; dropping uses leaves the block arguments unused.
233  block.clear();
234  block.dropAllDefinedValueUses();
  // NOTE(review): `i` is incremented after each erase, and eraseArgument
  // shifts the later arguments down, so this loop skips every other argument.
  // The remaining arguments presumably go away with `block.erase()` below.
  // Consider dropping the loop or erasing back to front.
235  for (size_t i = 0; i < block.getNumArguments(); i++) {
236  block.eraseArgument(i);
237  }
238  block.erase();
239  }
240
241  // Remove any control flow operations, and move the non-control flow
242  // terminator op to the end of the entry block.
  // Early-increment iteration allows erasing/moving the current op.
243  for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
244  if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
245  continue;
246
247  if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
248  terminatorLike.erase();
249  continue;
250  }
251
252  // Else, assume that this is a return-like terminator op.
253  terminatorLike.moveBefore(entryBlock, entryBlock->end());
254  }
255 }
256 
257 LogicalResult
258 HandshakeLowering::runSSAMaximization(ConversionPatternRewriter &rewriter,
259  Value entryCtrl) {
260  return maximizeSSA(entryCtrl, rewriter);
261 }
262 
263 void removeBasicBlocks(handshake::FuncOp funcOp) {
264  if (funcOp.isExternal())
265  return; // nothing to do, external funcOp.
266 
267  removeBasicBlocks(funcOp.getBody());
268 }
269 
270 static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type) {
271  if (type.getNumDynamicDims() != 0 || type.getShape().size() != 1)
272  return emitError(loc) << "memref's must be both statically sized and "
273  "unidimensional.";
274  return success();
275 }
276 
277 static unsigned getBlockPredecessorCount(Block *block) {
278  // Returns number of block predecessors
279  auto predecessors = block->getPredecessors();
280  return std::distance(predecessors.begin(), predecessors.end());
281 }
282 
283 // Inserts the merge-like operation through which `val` enters `block`:
284 // - the block's entry control value: a MergeOp in the region's entry block,
//   a ControlMergeOp in every other block; the block's registered entry
//   control is then updated to the new op's first result;
// - any other value in a block with at most one predecessor: a MergeOp;
// - otherwise: a MuxOp whose select input is an index-typed backedge.
// Operands that are not known yet are backedges from `edgeBuilder`; they are
// returned in MergeOpInfo::dataEdges (and MergeOpInfo::indexEdge for muxes) so
// that they can be resolved during merge reconnection.
// NOTE(review): the line holding this definition's return type is missing from
// this copy of the file; the `MergeOpInfo{...}` returns suggest it is
// HandshakeLowering::MergeOpInfo; verify against the upstream source.
286 HandshakeLowering::insertMerge(Block *block, Value val,
287  BackedgeBuilder &edgeBuilder,
288  ConversionPatternRewriter &rewriter) {
289  unsigned numPredecessors = getBlockPredecessorCount(block);
290  auto insertLoc = block->front().getLoc();
291  SmallVector<Backedge> dataEdges;
292  SmallVector<Value> operands;
293
294  // Every block (except the entry block) needs to feed its entry control into
295  // a control merge
296  if (val == getBlockEntryControl(block)) {
297
298  Operation *mergeOp;
299  if (block == &r.front()) {
300  // For consistency within the entry block, replace the latter's entry
301  // control with the output of a MergeOp that takes the control-only
302  // network's start point as input. This makes it so that only the
303  // MergeOp's output is used as a control within the entry block, instead
304  // of a combination of the MergeOp's output and the function/block control
305  // argument. Taking this step out should have no impact on functionality
306  // but would make the resulting IR less "regular"
307  operands.push_back(val);
308  mergeOp = rewriter.create<handshake::MergeOp>(insertLoc, operands);
309  } else {
  // One none-typed control backedge per predecessor.
310  for (unsigned i = 0; i < numPredecessors; i++) {
311  auto edge = edgeBuilder.get(rewriter.getNoneType());
312  dataEdges.push_back(edge);
313  operands.push_back(Value(edge));
314  }
315  mergeOp = rewriter.create<handshake::ControlMergeOp>(insertLoc, operands);
316  }
317  setBlockEntryControl(block, mergeOp->getResult(0));
318  return MergeOpInfo{mergeOp, val, dataEdges};
319  }
320
321  // Every live-in value to a block is passed through a merge-like operation,
322  // even when it's not required for circuit correctness (useless merge-like
323  // operations are removed down the line during handshake canonicalization)
324
325  // Insert "dummy" MergeOps for blocks with less than two predecessors
326  if (numPredecessors <= 1) {
327  if (numPredecessors == 0) {
328  // All of the entry block's block arguments get passed through a dummy
329  // MergeOp. There is no need for a backedge here as the unique operand can
330  // be resolved immediately
331  operands.push_back(val);
332  } else {
333  // The value incoming from the single block predecessor will be resolved
334  // later during merge reconnection
335  auto edge = edgeBuilder.get(val.getType());
336  dataEdges.push_back(edge);
337  operands.push_back(Value(edge));
338  }
339  auto merge = rewriter.create<handshake::MergeOp>(insertLoc, operands);
340  return MergeOpInfo{merge, val, dataEdges};
341  }
342
343  // Create a backedge for the index operand, and another one for each data
344  // operand. The index operand will eventually resolve to the current block's
345  // control merge index output, while data operands will resolve to their
346  // respective values from each block predecessor
347  Backedge indexEdge = edgeBuilder.get(rewriter.getIndexType());
348  for (unsigned i = 0; i < numPredecessors; i++) {
349  auto edge = edgeBuilder.get(val.getType());
350  dataEdges.push_back(edge);
351  operands.push_back(Value(edge));
352  }
353  auto mux =
354  rewriter.create<handshake::MuxOp>(insertLoc, Value(indexEdge), operands);
355  return MergeOpInfo{mux, val, dataEdges, indexEdge};
356 }
357 
359 HandshakeLowering::insertMergeOps(HandshakeLowering::ValueMap &mergePairs,
360  BackedgeBuilder &edgeBuilder,
361  ConversionPatternRewriter &rewriter) {
362  HandshakeLowering::BlockOps blockMerges;
363  for (Block &block : r) {
364  rewriter.setInsertionPointToStart(&block);
365 
366  // All of the block's live-ins are passed explictly through block arguments
367  // thanks to prior SSA maximization
368  for (auto &arg : block.getArguments()) {
369  // No merges on memref block arguments; these are handled separately
370  if (isa<mlir::MemRefType>(arg.getType()))
371  continue;
372 
373  auto mergeInfo = insertMerge(&block, arg, edgeBuilder, rewriter);
374  blockMerges[&block].push_back(mergeInfo);
375  mergePairs[arg] = mergeInfo.op->getResult(0);
376  }
377  }
378  return blockMerges;
379 }
380 
381 // Returns the value that `predBlock` forwards to the block argument
// `mergeInfo.val` of the merge's block, i.e. the operand the merge-like op
// must receive from that predecessor. Returns a null Value if `predBlock` does
// not end in a cf.br or cf.cond_br.
// NOTE(review): the line(s) carrying this definition's return type, name and
// first parameter (`mergeInfo`) are missing from this copy of the file.
383  Block *predBlock) {
384  // The input value to the merge operations
385  Value srcVal = mergeInfo.val;
386  // The block the merge operation belongs to
387  Block *block = mergeInfo.op->getBlock();
388
389  // The block terminator is either a cf-level branch or cf-level conditional
390  // branch. In either case, identify the value passed to the block using its
391  // index in the list of block arguments
392  unsigned index = cast<BlockArgument>(srcVal).getArgNumber();
393  Operation *termOp = predBlock->getTerminator();
394  if (mlir::cf::CondBranchOp br = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
395  // Block should be one of the two destinations of the conditional branch
  // NOTE(review): if both destinations are `block`, the true operand is
  // returned for both predecessor edges and the false operand is never used.
396  if (block == br.getTrueDest())
397  return br.getTrueOperand(index);
398  assert(block == br.getFalseDest());
399  return br.getFalseOperand(index);
400  }
401  if (isa<mlir::cf::BranchOp>(termOp))
402  return termOp->getOperand(index);
403  return nullptr;
404 }
405 
406 static void removeBlockOperands(Region &f) {
407  // Remove all block arguments, they are no longer used
408  // eraseArguments also removes corresponding branch operands
409  for (Block &block : f) {
410  if (!block.isEntryBlock()) {
411  int x = block.getNumArguments() - 1;
412  for (int i = x; i >= 0; --i)
413  block.eraseArgument(i);
414  }
415  }
416 }
417 
418 /// Returns the first occurance of an operation of type TOp, else, returns
419 /// null op.
420 template <typename TOp>
421 static Operation *getFirstOp(Block *block) {
422  auto ops = block->getOps<TOp>();
423  if (ops.empty())
424  return nullptr;
425  return *ops.begin();
426 }
427 
428 static Operation *getControlMerge(Block *block) {
429  return getFirstOp<ControlMergeOp>(block);
430 }
431 
432 static ConditionalBranchOp getControlCondBranch(Block *block) {
433  for (auto cbranch : block->getOps<handshake::ConditionalBranchOp>()) {
434  if (cbranch.isControl())
435  return cbranch;
436  }
437  return nullptr;
438 }
439 
// Resolves the backedges created by insertMerge.
// - Each data backedge receives the value forwarded by the corresponding
//   predecessor (block arguments are replaced by their merge result via
//   `mergePairs`).
// - Non-merge uses of each live-in value inside its block are rerouted through
//   the block's merge-like op.
// - The index backedge of every mux receives the index result (result #1) of
//   its block's control merge.
// NOTE(review): `blockMerges` is taken by value, so the whole map is copied;
// passing it by reference would avoid that.
440 static void reconnectMergeOps(Region &r,
441  HandshakeLowering::BlockOps blockMerges,
442  HandshakeLowering::ValueMap &mergePairs) {
443  // At this point all merge-like operations have backedges as operands.
444  // We here replace all backedge values with appropriate value from
445  // predecessor block. The predecessor can either be a merge, the original
446  // defining value, or a branch operand.
447
448  for (Block &block : r) {
449  for (auto &mergeInfo : blockMerges[&block]) {
  // dataEdges were created in predecessor order by insertMerge; walk the
  // predecessors in the same order.
450  int operandIdx = 0;
451  // Set appropriate operand from each predecessor block
452  for (auto *predBlock : block.getPredecessors()) {
453  Value mgOperand = getMergeOperand(mergeInfo, predBlock);
454  assert(mgOperand != nullptr);
  // A value without a defining op is a block argument; use the result of
  // the merge-like op that was inserted for it instead.
455  if (!mgOperand.getDefiningOp()) {
456  assert(mergePairs.count(mgOperand));
457  mgOperand = mergePairs[mgOperand];
458  }
459  mergeInfo.dataEdges[operandIdx].setValue(mgOperand);
460  operandIdx++;
461  }
462
463  // Reconnect all operands originating from livein defining value through
464  // corresponding merge of that block
465  for (Operation &opp : block)
466  if (!isa<MergeLikeOpInterface>(opp))
467  opp.replaceUsesOfWith(mergeInfo.val, mergeInfo.op->getResult(0));
468  }
469  }
470
471  // Connect select operand of muxes to control merge's index result in all
472  // blocks with more than one predecessor
473  for (Block &block : r) {
474  if (getBlockPredecessorCount(&block) > 1) {
475  Operation *cntrlMg = getControlMerge(&block);
476  assert(cntrlMg != nullptr);
477
478  for (auto &mergeInfo : blockMerges[&block]) {
479  if (mergeInfo.op != cntrlMg) {
480  // If the block has multiple predecessors, merge-like operation that
481  // are not the block's control merge must have an index operand (at
482  // this point, an index backedge)
483  assert(mergeInfo.indexEdge.has_value());
484  (*mergeInfo.indexEdge).setValue(cntrlMg->getResult(1));
485  }
486  }
487  }
488  }
489
  // NOTE(review): a source line appears to be missing here in this copy of the
  // file (presumably the call to removeBlockOperands(r), which is otherwise
  // unused in the visible code); verify against the upstream source.
491 }
492 
493 static bool isAllocOp(Operation *op) {
494  return isa<memref::AllocOp, memref::AllocaOp>(op);
495 }
496 
497 LogicalResult
498 HandshakeLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
499 
500  // Stores mapping from each value that pass through a merge operation to the
501  // first result of that merge operation
502  ValueMap mergePairs;
503 
504  // Create backedge builder to manage operands of merge operations between
505  // insertion and reconnection
506  BackedgeBuilder edgeBuilder{rewriter, r.front().front().getLoc()};
507 
508  // Insert merge operations (with backedges instead of actual operands)
509  BlockOps mergeOps = insertMergeOps(mergePairs, edgeBuilder, rewriter);
510 
511  // Reconnect merge operations with values incoming from predecessor blocks
512  // and resolve all backedges that were created during merge insertion
513  reconnectMergeOps(r, mergeOps, mergePairs);
514  return success();
515 }
516 
517 static bool isLiveOut(Value val) {
518  // Identifies liveout values after adding Merges
519  for (auto &u : val.getUses())
520  // Result is liveout if used by some Merge block
521  if (isa<MergeLikeOpInterface>(u.getOwner()))
522  return true;
523  return false;
524 }
525 
526 // A value can have multiple branches in a single successor block
527 // (for instance, there can be an SSA phi and a merge that we insert)
528 // This function determines the number of branches to insert based on the
529 // value uses in successor blocks
530 static int getBranchCount(Value val, Block *block) {
531  int uses = 0;
532  for (int i = 0, e = block->getNumSuccessors(); i < e; ++i) {
533  int curr = 0;
534  Block *succ = block->getSuccessor(i);
535  for (auto &u : val.getUses()) {
536  if (u.getOwner()->getBlock() == succ)
537  curr++;
538  }
539  uses = (curr > uses) ? curr : uses;
540  }
541  return uses;
542 }
543 
544 namespace {
545
546 /// This class inserts a reorder prevention mechanism for blocks with multiple
547 /// successors. Such a mechanism is required to guarantee correct execution in a
548 /// multi-threaded usage of the circuits.
549 ///
550 /// The order of the results matches the order of the traversals of the
551 /// divergence point. A FIFO buffer stores the condition of the conditional
552 /// branch. The buffer feeds a mux that guarantees the correct output order.
553 class FeedForwardNetworkRewriter {
554 public:
555  FeedForwardNetworkRewriter(HandshakeLowering &hl,
556  ConversionPatternRewriter &rewriter)
557  : hl(hl), rewriter(rewriter), postDomInfo(hl.getRegion().getParentOp()),
558  domInfo(hl.getRegion().getParentOp()),
559  loopInfo(domInfo.getDomTree(&hl.getRegion())) {}
  /// Finds all split/merge block pairs and builds the split and merge networks
  /// for each of them.
560  LogicalResult apply();
561
562 private:
563  HandshakeLowering &hl;
564  ConversionPatternRewriter &rewriter;
  // Analyses of the region's parent op. `loopInfo` is initialized from
  // `domInfo`, so `domInfo` must stay declared before `loopInfo` (members are
  // initialized in declaration order).
565  PostDominanceInfo postDomInfo;
566  DominanceInfo domInfo;
567  CFGLoopInfo loopInfo;
568
  // <split block, merge block>
569  using BlockPair = std::pair<Block *, Block *>;
570  using BlockPairs = SmallVector<BlockPair>;
571  LogicalResult findBlockPairs(BlockPairs &blockPairs);
572
  // Builds the condition FIFO for `splitBlock`; returns the control cond_br
  // through `ctrlBr`.
573  BufferOp buildSplitNetwork(Block *splitBlock,
574  handshake::ConditionalBranchOp &ctrlBr);
  // Replaces the control merge of `mergeBlock` by a mux selected by `buf`.
575  LogicalResult buildMergeNetwork(Block *mergeBlock, BufferOp buf,
576  handshake::ConditionalBranchOp &ctrlBr);
577
578  // Determines if the cmerge inputs match the cond_br output order.
579  bool requiresOperandFlip(ControlMergeOp &ctrlMerge,
580  handshake::ConditionalBranchOp &ctrlBr);
  // Returns true if the split/merge pair is involved in irreducible control
  // flow.
581  bool formsIrreducibleCF(Block *splitBlock, Block *mergeBlock);
582 };
583 } // namespace
584 
585 LogicalResult
586 HandshakeLowering::feedForwardRewriting(ConversionPatternRewriter &rewriter) {
587  // Nothing to do on a single block region.
588  if (this->getRegion().hasOneBlock())
589  return success();
590  return FeedForwardNetworkRewriter(*this, rewriter).apply();
591 }
592 
593 [[maybe_unused]] static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo) {
594  for (CFGLoop *loop : loopInfo.getTopLevelLoops())
595  if (!loop->getExitBlock())
596  return false;
597  return true;
598 }
599 
600 bool FeedForwardNetworkRewriter::formsIrreducibleCF(Block *splitBlock,
601  Block *mergeBlock) {
602  CFGLoop *loop = loopInfo.getLoopFor(mergeBlock);
603  for (auto *mergePred : mergeBlock->getPredecessors()) {
604  // Skip loop predecessors
605  if (loop && loop->contains(mergePred))
606  continue;
607 
608  // A DAG-CFG is irreducible, iff a merge block has a predecessor that can be
609  // reached from both successors of a split node, e.g., neither is a
610  // dominator.
611  // => Their control flow can merge in other places, which makes this
612  // irreducible.
613  if (llvm::none_of(splitBlock->getSuccessors(), [&](Block *splitSucc) {
614  if (splitSucc == mergeBlock || mergePred == splitBlock)
615  return true;
616  return domInfo.dominates(splitSucc, mergePred);
617  }))
618  return true;
619  }
620  return false;
621 }
622 
623 static Operation *findBranchToBlock(Block *block) {
624  Block *pred = *block->getPredecessors().begin();
625  return pred->getTerminator();
626 }
627 
// Collects <split block, merge block> pairs into `blockPairs`. A split block
// is a block outside any loop with two distinct successors; its merge block is
// the nearest common post-dominator of those successors. Emits an error on the
// region's parent op and fails if a pair forms irreducible control flow or if
// the merge block has more than two predecessors from outside its loop.
628 LogicalResult
629 FeedForwardNetworkRewriter::findBlockPairs(BlockPairs &blockPairs) {
630  // Assumes that merge block insertion happened beforehand.
631  // Thus, for each split block, there exists one merge block which is the post
632  // dominator of the child nodes.
633  Region &r = hl.getRegion();
634  Operation *parentOp = r.getParentOp();
635
636  // Assumes that each loop has only one exit block. Such an error should
637  // already be reported by the loop rewriting.
638  assert(loopsHaveSingleExit(loopInfo) &&
639  "expected loop to only have one exit block.");
640
641  for (Block &b : r) {
642  if (b.getNumSuccessors() < 2)
643  continue;
644
645  // Skip split blocks that lie inside any loop (getLoopFor is non-null for
  // every block contained in a loop, not only for loop headers).
646  if (loopInfo.getLoopFor(&b))
647  continue;
648
649  assert(b.getNumSuccessors() == 2);
650  Block *succ0 = b.getSuccessor(0);
651  Block *succ1 = b.getSuccessor(1);
652
653  if (succ0 == succ1)
654  continue;
655
  // NOTE(review): if the successors have no common post-dominator this could
  // be null, and it is used below without a check; confirm this cannot happen
  // for the inputs of this pass.
656  Block *mergeBlock = postDomInfo.findNearestCommonDominator(succ0, succ1);
657
658  // Precondition checks
659  if (formsIrreducibleCF(&b, mergeBlock)) {
660  return parentOp->emitError("expected only reducible control flow.")
661  .attachNote(findBranchToBlock(mergeBlock)->getLoc())
662  << "This branch is involved in the irreducible control flow";
663  }
664
  // Count predecessors that are not inside the merge block's loop. Only more
  // than two is rejected; fewer is accepted here.
665  unsigned nonLoopPreds = 0;
666  CFGLoop *loop = loopInfo.getLoopFor(mergeBlock);
667  for (auto *pred : mergeBlock->getPredecessors()) {
668  if (loop && loop->contains(pred))
669  continue;
670  nonLoopPreds++;
671  }
672  if (nonLoopPreds > 2)
673  return parentOp
674  ->emitError("expected a merge block to have two predecessors. "
675  "Did you run the merge block insertion pass?")
676  .attachNote(findBranchToBlock(mergeBlock)->getLoc())
677  << "This branch jumps to the illegal block";
678
679  blockPairs.emplace_back(&b, mergeBlock);
680  }
681
682  return success();
683 }
684 
685 LogicalResult FeedForwardNetworkRewriter::apply() {
686  BlockPairs pairs;
687 
688  if (failed(findBlockPairs(pairs)))
689  return failure();
690 
691  for (auto [splitBlock, mergeBlock] : pairs) {
692  handshake::ConditionalBranchOp ctrlBr;
693  BufferOp buffer = buildSplitNetwork(splitBlock, ctrlBr);
694  if (failed(buildMergeNetwork(mergeBlock, buffer, ctrlBr)))
695  return failure();
696  }
697 
698  return success();
699 }
700 
701 BufferOp FeedForwardNetworkRewriter::buildSplitNetwork(
702  Block *splitBlock, handshake::ConditionalBranchOp &ctrlBr) {
703  SmallVector<handshake::ConditionalBranchOp> branches;
704  llvm::copy(splitBlock->getOps<handshake::ConditionalBranchOp>(),
705  std::back_inserter(branches));
706 
707  auto *findRes = llvm::find_if(branches, [](auto br) {
708  return llvm::isa<NoneType>(br.getDataOperand().getType());
709  });
710 
711  assert(findRes && "expected one branch for the ctrl signal");
712  ctrlBr = *findRes;
713 
714  Value cond = ctrlBr.getConditionOperand();
715  assert(llvm::all_of(branches, [&](auto branch) {
716  return branch.getConditionOperand() == cond;
717  }));
718 
719  Location loc = cond.getLoc();
720  rewriter.setInsertionPointAfterValue(cond);
721 
722  // The buffer size defines the number of threads that can be concurently
723  // traversing the sub-CFG starting at the splitBlock.
724  size_t bufferSize = 2;
725  // TODO how to size these?
726  // Longest path in a CFG-DAG would be O(#blocks)
727 
728  return rewriter.create<handshake::BufferOp>(loc, cond, bufferSize,
729  BufferTypeEnum::fifo);
730 }
731 
// Replaces the single control merge of `mergeBlock` by a mux whose select
// input is the condition FIFO `buf` built for the matching split block.
// - The cmerge's two operands become the mux data operands, reversed when
//   requiresOperandFlip says so.
// - The cmerge's index result is replaced by the (possibly negated) buffered
//   condition cast to index.
// - The mux output becomes the block's entry control.
// Fails with an error if the cmerge does not have exactly two operands.
732 LogicalResult FeedForwardNetworkRewriter::buildMergeNetwork(
733  Block *mergeBlock, BufferOp buf, handshake::ConditionalBranchOp &ctrlBr) {
734  // Replace control merge with mux
735  auto ctrlMerges = mergeBlock->getOps<handshake::ControlMergeOp>();
736  assert(std::distance(ctrlMerges.begin(), ctrlMerges.end()) == 1);
737
738  handshake::ControlMergeOp ctrlMerge = *ctrlMerges.begin();
739  // This input might contain irreducible loops that we cannot handle.
740  if (ctrlMerge.getNumOperands() != 2)
741  return ctrlMerge.emitError("expected cmerges to have two operands");
742  rewriter.setInsertionPointAfter(ctrlMerge);
743  Location loc = ctrlMerge->getLoc();
744
745  // The newly inserted mux has to select the results from the correct operand.
746  // As there is no guarantee on the order of cmerge inputs, the correct order
747  // has to be determined first.
748  bool requiresFlip = requiresOperandFlip(ctrlMerge, ctrlBr);
749  SmallVector<Value> muxOperands;
750  if (requiresFlip)
751  muxOperands = llvm::to_vector(llvm::reverse(ctrlMerge.getOperands()));
752  else
753  muxOperands = llvm::to_vector(ctrlMerge.getOperands());
754
755  Value newCtrl = rewriter.create<handshake::MuxOp>(loc, buf, muxOperands);
756
757  Value cond = buf.getResult();
758  if (requiresFlip) {
759  // As the mux operand order is the flipped cmerge input order, the index
760  // which replaces the output of the cmerge has to be flipped/negated as
761  // well.
  // XOR with an i1 constant 1 negates the (presumably i1) condition.
762  cond = rewriter.create<arith::XOrIOp>(
763  loc, cond.getType(), cond,
764  rewriter.create<arith::ConstantOp>(
765  loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 1)));
766  }
767
768  // The cmerge index result is index-typed (it feeds mux select inputs), so
  // the buffered condition has to be cast to index before replacing it.
769  Value condAsIndex =
770  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), cond);
771
772  hl.setBlockEntryControl(mergeBlock, newCtrl);
773
774  // Replace with new ctrl value from mux and the index
775  rewriter.replaceOp(ctrlMerge, {newCtrl, condAsIndex});
776  return success();
777 }
778 
779 bool FeedForwardNetworkRewriter::requiresOperandFlip(
780  ControlMergeOp &ctrlMerge, handshake::ConditionalBranchOp &ctrlBr) {
781  assert(ctrlMerge.getNumOperands() == 2 &&
782  "Loops should already have been handled");
783 
784  Value fstOperand = ctrlMerge.getOperand(0);
785 
786  assert(ctrlBr.getTrueResult().hasOneUse() &&
787  "expected the result of a branch to only have one user");
788  Operation *trueUser = *ctrlBr.getTrueResult().user_begin();
789  if (trueUser == ctrlBr)
790  // The cmerge directly consumes the cond_br output.
791  return ctrlBr.getTrueResult() == fstOperand;
792 
793  // The cmerge is consumed in an intermediate block. Find out if this block is
794  // a predecessor of the "true" successor of the cmerge.
795  Block *trueBlock = trueUser->getBlock();
796  return domInfo.dominates(trueBlock, fstOperand.getDefiningOp()->getBlock());
797 }
798 
799 namespace {
800 // This class creates the loop 'continue' and 'exit' network around backedges
801 // in the CFG.
802 // Loops are detected with CFGLoopInfo, which processRegion builds from the
803 // dominance tree of the region. Only outermost (top-level) loops are
804 // processed. Inner loops are not included in the loop network (since we only
805 // care about restricting different function invocations from activating a
806 // loop, not preventing loop pipelining within a single function
807 // invocation).
808 // NOTE: the `rewriter` member is only set once processRegion has been called.
809 class LoopNetworkRewriter {
810 public:
811  LoopNetworkRewriter(HandshakeLowering &hl) : hl(hl) {}
812
  // Builds the loop networks for all outermost loops of `r`.
813  LogicalResult processRegion(Region &r, ConversionPatternRewriter &rewriter);
814
815 private:
816  // An exit pair is a pair of <in loop block, outside loop block> that
817  // indicates where control leaves a loop.
818  using ExitPair = std::pair<Block *, Block *>;
  // Builds the continue and exit networks of one outermost loop; `loc` is used
  // for diagnostics.
819  LogicalResult processOuterLoop(Location loc, CFGLoop *loop);
820
821  // Builds the loop continue network in between the loop header and its loop
822  // latch. The loop continuation network will replace the existing control
823  // merge in the loop header with a mux + loop priming register.
824  // The 'loopPrimingInput' is a backedge that will later be assigned by
825  // 'buildExitNetwork'. The value is to be used as the input to the loop
826  // priming buffer.
827  // Returns a reference to the loop priming register.
828  BufferOp buildContinueNetwork(Block *loopHeader, Block *loopLatch,
829  Backedge &loopPrimingInput);
830
831  // Builds the loop exit network. This detects the conditional operands used in
832  // each of the exit blocks, matches their parity with the convention used to
833  // prime the loop register, and assigns it to the loop priming register input.
834  void buildExitNetwork(Block *loopHeader,
835  const llvm::SmallSet<ExitPair, 2> &exitPairs,
836  BufferOp loopPrimingRegister,
837  Backedge &loopPrimingInput);
838
839 private:
  // Set by processRegion; null before that.
840  ConversionPatternRewriter *rewriter = nullptr;
841  HandshakeLowering &hl;
842 };
843 } // namespace
844 
845 LogicalResult
846 HandshakeLowering::loopNetworkRewriting(ConversionPatternRewriter &rewriter) {
847  return LoopNetworkRewriter(*this).processRegion(r, rewriter);
848 }
849 
850 LogicalResult
851 LoopNetworkRewriter::processRegion(Region &r,
852  ConversionPatternRewriter &rewriter) {
853  // Nothing to do on a single block region.
854  if (r.hasOneBlock())
855  return success();
856  this->rewriter = &rewriter;
857 
858  Operation *op = r.getParentOp();
859 
860  DominanceInfo domInfo(op);
861  CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
862 
863  for (CFGLoop *loop : loopInfo.getTopLevelLoops()) {
864  if (!loop->getLoopLatch())
865  return emitError(op->getLoc()) << "Multiple loop latches detected "
866  "(backedges from within the loop "
867  "to the loop header). Loop task "
868  "pipelining is only supported for "
869  "loops with unified loop latches.";
870 
871  // This is the start of an outer loop - go process!
872  if (failed(processOuterLoop(op->getLoc(), loop)))
873  return failure();
874  }
875 
876  return success();
877 }
878 
879 // Returns the operand of the 'mux' operation which originated from 'block'.
880 static Value getOperandFromBlock(MuxOp mux, Block *block) {
881  auto inValueIt = llvm::find_if(mux.getDataOperands(), [&](Value operand) {
882  return block == operand.getParentBlock();
883  });
884  assert(
885  inValueIt != mux.getDataOperands().end() &&
886  "Expected mux to have an operand originating from the requested block.");
887  return *inValueIt;
888 }
889 
890 // Returns a list of operands from 'mux' which corresponds to the inputs of the
891 // 'cmerge' operation. The results are sorted such that the i'th cmerge operand
892 // and the i'th sorted operand originate from the same block.
893 static std::vector<Value> getSortedInputs(ControlMergeOp cmerge, MuxOp mux) {
894  std::vector<Value> sortedOperands;
895  for (auto in : cmerge.getOperands()) {
896  auto *srcBlock = in.getParentBlock();
897  sortedOperands.push_back(getOperandFromBlock(mux, srcBlock));
898  }
899 
900  // Sanity check: ensure that operands are unique
901  for (unsigned i = 0; i < sortedOperands.size(); ++i) {
902  for (unsigned j = 0; j < sortedOperands.size(); ++j) {
903  if (i == j)
904  continue;
905  assert(sortedOperands[i] != sortedOperands[j] &&
906  "Cannot have an identical operand from two different blocks!");
907  }
908  }
909 
910  return sortedOperands;
911 }
912 
/// Replaces the control merge of 'loopHeader' with a loop continuation
/// network and returns the loop priming register.
///
/// The network consists of:
///  - a new ControlMergeOp over all control inputs of the original control
///    merge that do not originate from 'loopLatch' (the external controls);
///  - a sequential BufferOp (the priming register), fed by
///    'loopPrimingInput' and initialized to 0;
///  - a MuxOp selected by the priming register choosing between the external
///    control merge (select 0) and the single control from 'loopLatch'
///    (select 1). This mux becomes the header's new block entry control.
/// Every MuxOp originally in the header is rebuilt the same way: a mux over
/// its external inputs (selected by the new control merge's index) followed
/// by a priming-register-selected mux between that and the latch input.
///
/// 'loopPrimingInput' is expected to be assigned later by buildExitNetwork.
BufferOp LoopNetworkRewriter::buildContinueNetwork(Block *loopHeader,
                                                   Block *loopLatch,
                                                   Backedge &loopPrimingInput) {
  // Gather the muxes to replace before modifying block internals; it's been
  // found that if this is not done, we have determinism issues wrt. generating
  // the same order of replaced muxes on repeated runs of an identical
  // conversion.
  llvm::SmallVector<MuxOp> muxesToReplace;
  llvm::copy(loopHeader->getOps<MuxOp>(), std::back_inserter(muxesToReplace));

  // Fetch the control merge of the block; it is assumed that, at this point of
  // lowering, no other form of control can be used for the loop header block
  // than a control merge.
  auto *cmerge = getControlMerge(loopHeader);
  assert(hl.getBlockEntryControl(loopHeader) == cmerge->getResult(0) &&
         "Expected control merge to be the control component of a loop header");
  auto loc = cmerge->getLoc();

  // sanity check: cmerge should have >1 input to actually be a loop
  assert(cmerge->getNumOperands() > 1 && "This cannot be a loop header");

  // Partition the control merge inputs into those originating from backedges,
  // and those originating elsewhere.
  SmallVector<Value> externalCtrls, loopCtrls;
  for (auto cval : cmerge->getOperands()) {
    if (cval.getParentBlock() == loopLatch)
      loopCtrls.push_back(cval);
    else
      externalCtrls.push_back(cval);
  }
  assert(loopCtrls.size() == 1 &&
         "Expected a single loop control value to match the single loop latch");
  Value loopCtrl = loopCtrls.front();

  // Merge the external controls. The loop partition holds a single control
  // value (asserted above) and needs no merge.
  rewriter->setInsertionPointToStart(loopHeader);
  auto externalCtrlMerge = rewriter->create<ControlMergeOp>(loc, externalCtrls);

  // Create loop mux and the loop priming register. The loop mux will on select
  // "0" select external control, and internal control at "1". This convention
  // must be followed by the loop exit network.
  auto primingRegister =
      rewriter->create<BufferOp>(loc, loopPrimingInput, 1, BufferTypeEnum::seq);
  // Initialize the priming register to path 0.
  primingRegister->setAttr("initValues", rewriter->getI64ArrayAttr({0}));

  // The loop control mux will deterministically select between control entering
  // the loop from any external block or the single loop backedge.
  auto loopCtrlMux = rewriter->create<MuxOp>(
      loc, primingRegister.getResult(),
      llvm::SmallVector<Value>{externalCtrlMerge.getResult(), loopCtrl});

  // Replace the existing control merge 'result' output with the loop control
  // mux.
  cmerge->getResult(0).replaceAllUsesWith(loopCtrlMux.getResult());

  // Register the new block entry control value
  hl.setBlockEntryControl(loopHeader, loopCtrlMux.getResult());

  // Next, we need to consider how to replace the control merge 'index' output,
  // used to drive input selection to the block.

  // Inputs to the loop header will be sourced from muxes with inputs from both
  // the loop latch as well as external blocks. Iterate through these and sort
  // based on the input ordering to the external/internal control merge.
  // We do this by maintaining a mapping between the external and loop data
  // inputs for each data mux in the design. The key of these maps is the
  // original mux (that is to be replaced).
  DenseMap<MuxOp, std::vector<Value>> externalDataInputs;
  DenseMap<MuxOp, Value> loopDataInputs;
  for (auto muxOp : muxesToReplace) {
    // NOTE(review): muxesToReplace was collected before loopCtrlMux was
    // created, so this check appears never to fire.
    if (muxOp == loopCtrlMux)
      continue;

    externalDataInputs[muxOp] = getSortedInputs(externalCtrlMerge, muxOp);
    loopDataInputs[muxOp] = getOperandFromBlock(muxOp, loopLatch);
    assert(/*loop latch input*/ 1 + externalDataInputs[muxOp].size() ==
               muxOp.getDataOperands().size() &&
           "Expected all mux operands to be partitioned between loop and "
           "external data inputs");
  }

  // With this, we now replace each of the data input muxes in the loop header.
  // We instantiate a single mux driven by the external control merge.
  // This, as well as the corresponding data input coming from within the single
  // loop latch, will then be selected between by a 3rd mux, based on the
  // priming register.
  for (MuxOp mux : muxesToReplace) {
    auto externalDataMux = rewriter->create<MuxOp>(
        loc, externalCtrlMerge.getIndex(), externalDataInputs[mux]);

    rewriter->replaceOp(
        mux, rewriter
                 ->create<MuxOp>(loc, primingRegister,
                                 llvm::SmallVector<Value>{externalDataMux,
                                                          loopDataInputs[mux]})
                 .getResult());
  }

  // Now all values defined by the original cmerge should have been replaced,
  // and it may be erased.
  rewriter->eraseOp(cmerge);

  // Return the priming register to be referenced by the exit network builder.
  return primingRegister;
}
1019 
/// Builds the loop exit network and assigns its output to 'loopPrimingInput'.
///
/// For each <condition block, exit block> pair, the condition operand of the
/// block's control conditional branch is normalized so that it is 0 when
/// control takes the exit edge. When the "true" result's first user lives in
/// the exit block, the condition is inverted by XOR-ing it with a constant 1.
/// All normalized conditions are merged with a MergeOp, whose result becomes
/// the value of 'loopPrimingInput'.
/// 'loopHeader' is not used by this function.
void LoopNetworkRewriter::buildExitNetwork(
    Block *loopHeader, const llvm::SmallSet<ExitPair, 2> &exitPairs,
    BufferOp loopPrimingRegister, Backedge &loopPrimingInput) {
  auto loc = loopPrimingRegister.getLoc();

  // Iterate over the exit pairs to gather up the condition signals that need to
  // be connected to the exit network. In doing so, we parity-correct these
  // condition values based on the convention used in buildContinueNetwork: the
  // loop mux selects external control on "0" and internal control on "1".
  // The loop exit network must follow this convention: external control must
  // be selected when exiting the loop (to reprime the register).
  SmallVector<Value> parityCorrectedConds;
  for (auto &[condBlock, exitBlock] : exitPairs) {
    auto condBr = getControlCondBranch(condBlock);
    assert(
        condBr &&
        "Expected a conditional control branch op in the loop condition block");
    // NOTE(review): assumes both branch results have at least one user; the
    // first user of each is dereferenced without a check.
    Operation *trueUser = *condBr.getTrueResult().getUsers().begin();
    bool isTrueParity = trueUser->getBlock() == exitBlock;
    assert(isTrueParity ^
               ((*condBr.getFalseResult().getUsers().begin())->getBlock() ==
                exitBlock) &&
           "The user of either the true or the false result should be in the "
           "exit block");

    Value condValue = condBr.getConditionOperand();
    if (isTrueParity) {
      // This goes against the convention, and we have to invert the condition
      // value before connecting it to the exit network.
      rewriter->setInsertionPoint(condBr);
      condValue = rewriter->create<arith::XOrIOp>(
          loc, condValue.getType(), condValue,
          rewriter->create<arith::ConstantOp>(
              loc, rewriter->getIntegerAttr(rewriter->getI1Type(), 1)));
    }
    parityCorrectedConds.push_back(condValue);
  }

  // Merge all of the parity-corrected exit conditions and assign them
  // to the loop priming input.
  // NOTE(review): no insertion point is set here; the MergeOp is created
  // wherever the rewriter currently points (before the last inverted branch,
  // or wherever earlier rewriting left it).
  auto exitMerge = rewriter->create<MergeOp>(loc, parityCorrectedConds);
  loopPrimingInput.setValue(exitMerge);
}
1064 
1065 LogicalResult LoopNetworkRewriter::processOuterLoop(Location loc,
1066  CFGLoop *loop) {
1067  // We determine the exit pairs of the loop; this is the in-loop nodes
1068  // which branch off to the exit nodes.
1069  llvm::SmallSet<ExitPair, 2> exitPairs;
1070  SmallVector<Block *> exitBlocks;
1071  loop->getExitBlocks(exitBlocks);
1072  for (auto *exitNode : exitBlocks) {
1073  for (auto *pred : exitNode->getPredecessors()) {
1074  // is the predecessor inside the loop?
1075  if (!loop->contains(pred))
1076  continue;
1077 
1078  ExitPair condPair = {pred, exitNode};
1079  assert(!exitPairs.count(condPair) &&
1080  "identical condition pairs should never be possible");
1081  exitPairs.insert(condPair);
1082  }
1083  }
1084  assert(!exitPairs.empty() && "No exits from loop?");
1085 
1086  // The first precondition to our loop transformation is that only a single
1087  // exit pair exists in the loop.
1088  if (exitPairs.size() > 1)
1089  return emitError(loc)
1090  << "Multiple exits detected within a loop. Loop task pipelining is "
1091  "only supported for loops with unified loop exit blocks.";
1092 
1093  Block *header = loop->getHeader();
1094  BackedgeBuilder bebuilder(*rewriter, header->front().getLoc());
1095 
1096  // Build the loop continue network. Loop continuation is triggered solely by
1097  // backedges to the header.
1098  auto loopPrimingRegisterInput = bebuilder.get(rewriter->getI1Type());
1099  auto loopPrimingRegister = buildContinueNetwork(header, loop->getLoopLatch(),
1100  loopPrimingRegisterInput);
1101 
1102  // Build the loop exit network. Loop exiting is driven solely by exit pairs
1103  // from the loop.
1104  buildExitNetwork(header, exitPairs, loopPrimingRegister,
1105  loopPrimingRegisterInput);
1106 
1107  return success();
1108 }
1109 
1110 // Return the appropriate branch result based on successor block which uses it
1111 static Value getSuccResult(Operation *termOp, Operation *newOp,
1112  Block *succBlock) {
1113  // For conditional block, check if result goes to true or to false successor
1114  if (auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp)) {
1115  if (condBranchOp.getTrueDest() == succBlock)
1116  return dyn_cast<handshake::ConditionalBranchOp>(newOp).getTrueResult();
1117  else {
1118  assert(condBranchOp.getFalseDest() == succBlock);
1119  return dyn_cast<handshake::ConditionalBranchOp>(newOp).getFalseResult();
1120  }
1121  }
1122  // If the block is unconditional, newOp has only one result
1123  return newOp->getResult(0);
1124 }
1125 
/// Inserts handshake branches for every live-out value of every block.
///
/// Live-outs are the operation results for which isLiveOut() holds; block
/// arguments are not considered. For each live-out, getBranchCount() copies
/// of a branch are created right before the block terminator: a
/// handshake::ConditionalBranchOp for a cf::CondBranchOp terminator, a
/// handshake::BranchOp for a cf::BranchOp terminator, and nothing otherwise.
/// For each successor, the first user of the value found in that successor
/// is rewired to the branch result leading to it.
LogicalResult
HandshakeLowering::addBranchOps(ConversionPatternRewriter &rewriter) {

  BlockValues liveOuts;

  // Collect, per block, the operation results that are live out of it.
  for (Block &block : r) {
    for (Operation &op : block) {
      for (auto result : op.getResults())
        if (isLiveOut(result))
          liveOuts[&block].push_back(result);
    }
  }

  for (Block &block : r) {
    Operation *termOp = block.getTerminator();
    rewriter.setInsertionPoint(termOp);

    for (Value val : liveOuts[&block]) {
      // Count the number of branches which the liveout needs
      int numBranches = getBranchCount(val, &block);

      // Instantiate branches and connect to Merges
      for (int i = 0, e = numBranches; i < e; ++i) {
        Operation *newOp = nullptr;

        if (auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(termOp))
          newOp = rewriter.create<handshake::ConditionalBranchOp>(
              termOp->getLoc(), condBranchOp.getCondition(), val);
        else if (isa<mlir::cf::BranchOp>(termOp))
          newOp = rewriter.create<handshake::BranchOp>(termOp->getLoc(), val);

        // Other terminators get no branch.
        if (newOp == nullptr)
          continue;

        for (int j = 0, e = block.getNumSuccessors(); j < e; ++j) {
          Block *succ = block.getSuccessor(j);
          Value res = getSuccResult(termOp, newOp, succ);

          // Rewire the first user of 'val' located in 'succ'.
          // replaceUsesOfWith mutates val's use list, which is why the loop
          // stops immediately afterwards. NOTE(review): only one user op per
          // successor is rewired; this presumably relies on each live-out
          // being consumed once per successor (e.g. by a merge) — confirm.
          for (auto &u : val.getUses()) {
            if (u.getOwner()->getBlock() == succ) {
              u.getOwner()->replaceUsesOfWith(val, res);
              break;
            }
          }
        }
      }
    }
  }

  return success();
}
1177 
1178 LogicalResult HandshakeLowering::connectConstantsToControl(
1179  ConversionPatternRewriter &rewriter, bool sourceConstants) {
1180  // Create new constants which have a control-only input to trigger them.
1181  // These are conneted to the control network or optionally to a Source
1182  // operation (always triggering). Control-network connected constants may
1183  // help debugability, but result in a slightly larger circuit.
1184 
1185  if (sourceConstants) {
1186  for (auto constantOp : llvm::make_early_inc_range(
1187  r.template getOps<mlir::arith::ConstantOp>())) {
1188  rewriter.setInsertionPointAfter(constantOp);
1189  auto value = constantOp.getValue();
1190  rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1191  constantOp, value.getType(), value,
1192  rewriter.create<handshake::SourceOp>(constantOp.getLoc(),
1193  rewriter.getNoneType()));
1194  }
1195  } else {
1196  for (Block &block : r) {
1197  Value blockEntryCtrl = getBlockEntryControl(&block);
1198  for (auto constantOp : llvm::make_early_inc_range(
1199  block.template getOps<mlir::arith::ConstantOp>())) {
1200  rewriter.setInsertionPointAfter(constantOp);
1201  auto value = constantOp.getValue();
1202  rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
1203  constantOp, value.getType(), value, blockEntryCtrl);
1204  }
1205  }
1206  }
1207  return success();
1208 }
1209 
1210 /// Holds information about an handshake "basic block terminator" control
1211 /// operation
1213  /// The operation
1214  Operation *op;
1215  /// The operation's control operand (must have type NoneType)
1217 
1218  BlockControlTerm(Operation *op, Value ctrlOperand)
1219  : op(op), ctrlOperand(ctrlOperand) {
1220  assert(op && ctrlOperand);
1221  assert(isa<NoneType>(ctrlOperand.getType()) &&
1222  "Control operand must be a NoneType");
1223  }
1224 
1225  /// Checks for member-wise equality
1226  friend bool operator==(const BlockControlTerm &lhs,
1227  const BlockControlTerm &rhs) {
1228  return lhs.op == rhs.op && lhs.ctrlOperand == rhs.ctrlOperand;
1229  }
1230 };
1231 
1233  // Identify the control terminator operation and its control operand in the
1234  // given block. One such operation must exist in the block
1235  for (Operation &op : *block) {
1236  if (auto branchOp = dyn_cast<handshake::BranchOp>(op))
1237  if (branchOp.isControl())
1238  return {branchOp, branchOp.getDataOperand()};
1239  if (auto branchOp = dyn_cast<handshake::ConditionalBranchOp>(op))
1240  if (branchOp.isControl())
1241  return {branchOp, branchOp.getDataOperand()};
1242  if (auto endOp = dyn_cast<handshake::ReturnOp>(op))
1243  return {endOp, endOp.getOperands().back()};
1244  }
1245  llvm_unreachable("Block terminator must exist");
1246 }
1247 
1248 static LogicalResult getOpMemRef(Operation *op, Value &out) {
1249  out = Value();
1250  if (auto memOp = dyn_cast<memref::LoadOp>(op))
1251  out = memOp.getMemRef();
1252  else if (auto memOp = dyn_cast<memref::StoreOp>(op))
1253  out = memOp.getMemRef();
1254  else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
1255  MemRefAccess access(op);
1256  out = access.memref;
1257  }
1258  if (out != Value())
1259  return success();
1260  return op->emitOpError("Unknown Op type");
1261 }
1262 
1263 static bool isMemoryOp(Operation *op) {
1264  return isa<memref::LoadOp, memref::StoreOp, AffineReadOpInterface,
1265  AffineWriteOpInterface>(op);
1266 }
1267 
/// Replaces every memref/affine load and store in the region with the
/// corresponding handshake::LoadOp / handshake::StoreOp, and records the new
/// ops per accessed memref in 'memRefOps', in the order the region's ops are
/// visited.
///
/// Memref-typed arguments of the region are registered in 'memRefOps' too
/// (with an empty access list), so unused memref arguments still get an
/// entry. Fails if a memref argument has an invalid memref type or if the
/// memref of a memory op cannot be determined.
LogicalResult
HandshakeLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
                                    MemRefToMemoryAccessOp &memRefOps) {

  std::vector<Operation *> opsToErase;

  // Enrich the memRefOps context with BlockArguments, in case they aren't used.
  for (auto arg : r.getArguments()) {
    auto memrefType = dyn_cast<mlir::MemRefType>(arg.getType());
    if (!memrefType)
      continue;
    // Ensure that this is a valid memref-typed value.
    if (failed(isValidMemrefType(arg.getLoc(), memrefType)))
      return failure();
    memRefOps.insert(std::make_pair(arg, std::vector<Operation *>()));
  }

  // Replace load and store ops with the corresponding handshake ops
  // Need to traverse ops in blocks to store them in memRefOps in program
  // order
  for (Operation &op : r.getOps()) {
    if (!isMemoryOp(&op))
      continue;

    // New ops are created right before the op being replaced.
    rewriter.setInsertionPoint(&op);
    Value memref;
    if (getOpMemRef(&op, memref).failed())
      return failure();
    Operation *newOp = nullptr;

    llvm::TypeSwitch<Operation *>(&op)
        .Case<memref::LoadOp>([&](auto loadOp) {
          // Get operands which correspond to address indices
          // This will add all operands except alloc
          SmallVector<Value, 8> operands(loadOp.getIndices());

          newOp =
              rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
          op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
        })
        .Case<memref::StoreOp>([&](auto storeOp) {
          // Get operands which correspond to address indices
          // This will add all operands except alloc and data
          SmallVector<Value, 8> operands(storeOp.getIndices());

          // Create new op where operands are store data and address indices
          newOp = rewriter.create<handshake::StoreOp>(
              op.getLoc(), storeOp.getValueToStore(), operands);
        })
        .Case<AffineReadOpInterface, AffineWriteOpInterface>([&](auto) {
          // Get essential memref access information.
          MemRefAccess access(&op);
          // The address of an affine load/store operation can be a result
          // of an affine map, which is a linear combination of constants
          // and parameters. Therefore, we should extract the affine map of
          // each address and expand it into proper expressions that
          // calculate the result.
          AffineMap map;
          if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
            map = loadOp.getAffineMap();
          else
            map = dyn_cast<AffineWriteOpInterface>(op).getAffineMap();

          // The returned object from expandAffineMap is an optional list of
          // the expansion results from the given affine map, which are the
          // actual address indices that can be used as operands for
          // handshake LoadOp/StoreOp. The following processing requires it
          // to be a valid result.
          // NOTE(review): only asserted; in release builds '*operands' below
          // is dereferenced even if the expansion failed.
          auto operands =
              expandAffineMap(rewriter, op.getLoc(), map, access.indices);
          assert(operands && "Address operands of affine memref access "
                             "cannot be reduced.");

          if (isa<AffineReadOpInterface>(op)) {
            auto loadOp = rewriter.create<handshake::LoadOp>(
                op.getLoc(), access.memref, *operands);
            newOp = loadOp;
            op.getResult(0).replaceAllUsesWith(loadOp.getDataResult());
          } else {
            newOp = rewriter.create<handshake::StoreOp>(
                op.getLoc(), op.getOperand(0), *operands);
          }
        })
        .Default([&](auto) {
          // Not expected to be reached: isMemoryOp() above only admits the
          // cases handled explicitly. If it were, 'newOp' would stay null and
          // be recorded below.
          op.emitOpError("Load/store operation cannot be handled.");
        });

    memRefOps[memref].push_back(newOp);
    opsToErase.push_back(&op);
  }

  // Erase old memory ops. Their operands are removed one by one first,
  // dropping the old ops' uses of those values before the ops are erased.
  for (unsigned i = 0, e = opsToErase.size(); i != e; ++i) {
    auto *op = opsToErase[i];
    for (int j = 0, e = op->getNumOperands(); j < e; ++j)
      op->eraseOperand(0);
    assert(op->getNumOperands() == 0);

    rewriter.eraseOp(op);
  }

  return success();
}
1371 
1372 static SmallVector<Value, 8> getResultsToMemory(Operation *op) {
1373  // Get load/store results which are given as inputs to MemoryOp
1374 
1375  if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
1376  // For load, get all address outputs/indices
1377  // (load also has one data output which goes to successor operation)
1378  SmallVector<Value, 8> results(loadOp.getAddressResults());
1379  return results;
1380 
1381  } else {
1382  // For store, all outputs (data and address indices) go to memory
1383  assert(dyn_cast<handshake::StoreOp>(op));
1384  handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
1385  SmallVector<Value, 8> results(storeOp.getResults());
1386  return results;
1387  }
1388 }
1389 
1390 static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter) {
1391 
1392  for (Block &block : f) {
1393  Value ctrl = getBlockControlTerminator(&block).ctrlOperand;
1394  if (!ctrl.hasOneUse())
1395  insertFork(ctrl, true, rewriter);
1396  }
1397 }
1398 
1399 static void removeUnusedAllocOps(Region &r,
1400  ConversionPatternRewriter &rewriter) {
1401  std::vector<Operation *> opsToDelete;
1402 
1403  // Remove alloc operations whose result have no use
1404  for (auto &op : r.getOps())
1405  if (isAllocOp(&op) && op.getResult(0).use_empty())
1406  opsToDelete.push_back(&op);
1407 
1408  llvm::for_each(opsToDelete, [&](auto allocOp) { rewriter.eraseOp(allocOp); });
1409 }
1410 
1411 static void addJoinOps(ConversionPatternRewriter &rewriter,
1412  ArrayRef<BlockControlTerm> controlTerms) {
1413  for (auto term : controlTerms) {
1414  auto &[op, ctrl] = term;
1415  auto *srcOp = ctrl.getDefiningOp();
1416 
1417  // Insert only single join per block
1418  if (!isa<JoinOp>(srcOp)) {
1419  rewriter.setInsertionPointAfter(srcOp);
1420  Operation *newJoin = rewriter.create<JoinOp>(srcOp->getLoc(), ctrl);
1421  op->replaceUsesOfWith(ctrl, newJoin->getResult(0));
1422  }
1423  }
1424 }
1425 
1426 static std::vector<BlockControlTerm>
1427 getControlTerminators(ArrayRef<Operation *> memOps) {
1428  std::vector<BlockControlTerm> terminators;
1429 
1430  for (Operation *op : memOps) {
1431  // Get block from which the mem op originates
1432  Block *block = op->getBlock();
1433  // Identify the control terminator in the block
1434  auto term = getBlockControlTerminator(block);
1435  if (std::find(terminators.begin(), terminators.end(), term) ==
1436  terminators.end())
1437  terminators.push_back(term);
1438  }
1439  return terminators;
1440 }
1441 
1442 static void addValueToOperands(Operation *op, Value val) {
1443 
1444  SmallVector<Value, 8> results(op->getOperands());
1445  results.push_back(val);
1446  op->setOperands(results);
1447 }
1448 
1449 static void setLoadDataInputs(ArrayRef<Operation *> memOps, Operation *memOp) {
1450  // Set memory outputs as load input data
1451  int ld_count = 0;
1452  for (auto *op : memOps) {
1453  if (isa<handshake::LoadOp>(op))
1454  addValueToOperands(op, memOp->getResult(ld_count++));
1455  }
1456 }
1457 
1458 static LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
1459  Operation *memOp, int offset,
1460  ArrayRef<int> cntrlInd) {
1461  // Connect all memory ops to the join of that block (ensures that all mem
1462  // ops terminate before a new block starts)
1463  for (int i = 0, e = memOps.size(); i < e; ++i) {
1464  auto *op = memOps[i];
1465  Value ctrl = getBlockControlTerminator(op->getBlock()).ctrlOperand;
1466  auto *srcOp = ctrl.getDefiningOp();
1467  if (!isa<JoinOp>(srcOp)) {
1468  return srcOp->emitOpError("Op expected to be a JoinOp");
1469  }
1470  addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
1471  }
1472  return success();
1473 }
1474 
1475 void HandshakeLowering::setMemOpControlInputs(
1476  ConversionPatternRewriter &rewriter, ArrayRef<Operation *> memOps,
1477  Operation *memOp, int offset, ArrayRef<int> cntrlInd) {
1478  for (int i = 0, e = memOps.size(); i < e; ++i) {
1479  std::vector<Value> controlOperands;
1480  Operation *currOp = memOps[i];
1481  Block *currBlock = currOp->getBlock();
1482 
1483  // Set load/store control inputs from the block input control value
1484  Value blockEntryCtrl = getBlockEntryControl(currBlock);
1485  controlOperands.push_back(blockEntryCtrl);
1486 
1487  // Set load/store control inputs from predecessors in block
1488  for (int j = 0, f = i; j < f; ++j) {
1489  Operation *predOp = memOps[j];
1490  Block *predBlock = predOp->getBlock();
1491  if (currBlock == predBlock)
1492  // Any dependency but RARs
1493  if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
1494  // cntrlInd maps memOps index to correct control output index
1495  controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
1496  }
1497 
1498  // If there is only one control input, add directly to memory op
1499  if (controlOperands.size() == 1)
1500  addValueToOperands(currOp, controlOperands[0]);
1501 
1502  // If multiple, join them and connect join output to memory op
1503  else {
1504  rewriter.setInsertionPoint(currOp);
1505  Operation *joinOp =
1506  rewriter.create<JoinOp>(currOp->getLoc(), controlOperands);
1507  addValueToOperands(currOp, joinOp->getResult(0));
1508  }
1509  }
1510 }
1511 
1512 LogicalResult
1513 HandshakeLowering::connectToMemory(ConversionPatternRewriter &rewriter,
1514  MemRefToMemoryAccessOp memRefOps, bool lsq) {
1515  // Add MemoryOps which represent the memory interface
1516  // Connect memory operations and control appropriately
1517  int mem_count = 0;
1518  for (auto memory : memRefOps) {
1519  // First operand corresponds to memref (alloca or function argument)
1520  Value memrefOperand = memory.first;
1521 
1522  // A memory is external if the memref that defines it is provided as a
1523  // function (block) argument.
1524  bool isExternalMemory = isa<BlockArgument>(memrefOperand);
1525 
1526  mlir::MemRefType memrefType =
1527  cast<mlir::MemRefType>(memrefOperand.getType());
1528  if (failed(isValidMemrefType(memrefOperand.getLoc(), memrefType)))
1529  return failure();
1530 
1531  std::vector<Value> operands;
1532 
1533  // Get control terminators whose control operand need to connect to memory
1534  std::vector<BlockControlTerm> controlTerms =
1535  getControlTerminators(memory.second);
1536 
1537  // In case of LSQ interface, set control values as inputs (used to
1538  // trigger allocation to LSQ)
1539  if (lsq)
1540  for (auto valOp : controlTerms)
1541  operands.push_back(valOp.ctrlOperand);
1542 
1543  // Add load indices and store data+indices to memory operands
1544  // Count number of loads so that we can generate appropriate number of
1545  // memory outputs (data to load ops)
1546 
1547  // memory.second is in program order
1548  // Enforce MemoryOp port ordering as follows:
1549  // Operands: all stores then all loads (stdata1, staddr1, stdata2,...,
1550  // ldaddr1, ldaddr2,....) Outputs: all load outputs, ordered the same as
1551  // load addresses (lddata1, lddata2, ...), followed by all none outputs,
1552  // ordered as operands (stnone1, stnone2,...ldnone1, ldnone2,...)
1553  std::vector<int> newInd(memory.second.size(), 0);
1554  int ind = 0;
1555  for (int i = 0, e = memory.second.size(); i < e; ++i) {
1556  auto *op = memory.second[i];
1557  if (isa<handshake::StoreOp>(op)) {
1558  SmallVector<Value, 8> results = getResultsToMemory(op);
1559  operands.insert(operands.end(), results.begin(), results.end());
1560  newInd[i] = ind++;
1561  }
1562  }
1563 
1564  int ld_count = 0;
1565 
1566  for (int i = 0, e = memory.second.size(); i < e; ++i) {
1567  auto *op = memory.second[i];
1568  if (isa<handshake::LoadOp>(op)) {
1569  SmallVector<Value, 8> results = getResultsToMemory(op);
1570  operands.insert(operands.end(), results.begin(), results.end());
1571 
1572  ld_count++;
1573  newInd[i] = ind++;
1574  }
1575  }
1576 
1577  // control-only outputs for each access (indicate access completion)
1578  int cntrl_count = lsq ? 0 : memory.second.size();
1579 
1580  Block *entryBlock = &r.front();
1581  rewriter.setInsertionPointToStart(entryBlock);
1582 
1583  // Place memory op next to the alloc op
1584  Operation *newOp = nullptr;
1585  if (isExternalMemory)
1586  newOp = rewriter.create<ExternalMemoryOp>(
1587  entryBlock->front().getLoc(), memrefOperand, operands, ld_count,
1588  cntrl_count - ld_count, mem_count++);
1589  else
1590  newOp = rewriter.create<MemoryOp>(entryBlock->front().getLoc(), operands,
1591  ld_count, cntrl_count, lsq, mem_count++,
1592  memrefOperand);
1593 
1594  setLoadDataInputs(memory.second, newOp);
1595 
1596  if (!lsq) {
1597  // Create Joins which join done signals from memory with the
1598  // control-only network
1599  addJoinOps(rewriter, controlTerms);
1600 
1601  // Connect all load/store done signals to the join of their block
1602  // Ensure that the block terminates only after all its accesses have
1603  // completed
1604  // True is default. When no sync needed, set to false (for now,
1605  // user-determined)
1606  bool control = true;
1607 
1608  if (control &&
1609  setJoinControlInputs(memory.second, newOp, ld_count, newInd).failed())
1610  return failure();
1611 
1612  // Set control-only inputs to each memory op
1613  // Ensure that op starts only after prior blocks have completed
1614  // Ensure that op starts only after predecessor ops (with RAW, WAR, or
1615  // WAW) have completed
1616  setMemOpControlInputs(rewriter, memory.second, newOp, ld_count, newInd);
1617  }
1618  }
1619 
1620  if (lsq)
1621  addLazyForks(r, rewriter);
1622 
1623  removeUnusedAllocOps(r, rewriter);
1624  return success();
1625 }
1626 
1627 LogicalResult
1628 HandshakeLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
1629  for (Block &block : r) {
1630  /// An instance is activated whenever control arrives at the basic block
1631  /// of the source callOp.
1632  Value blockEntryControl = getBlockEntryControl(&block);
1633  for (Operation &op : block) {
1634  if (auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
1635  llvm::SmallVector<Value> operands;
1636  llvm::copy(callOp.getOperands(), std::back_inserter(operands));
1637  operands.push_back(blockEntryControl);
1638  rewriter.setInsertionPoint(callOp);
1639  auto instanceOp = rewriter.create<handshake::InstanceOp>(
1640  callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
1641  operands);
1642  // Replace all results of the source callOp.
1643  for (auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
1644  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1645  rewriter.eraseOp(callOp);
1646  }
1647  }
1648  }
1649  return success();
1650 }
1651 
1652 namespace {
1653 /// Strategy class for SSA maximization during cf-to-handshake conversion.
1654 /// Block arguments of type MemRefType and allocation operations are not
1655 /// considered for SSA maximization.
1656 class HandshakeLoweringSSAStrategy : public SSAMaximizationStrategy {
1657  /// Filters out block arguments of type MemRefType
1658  bool maximizeArgument(BlockArgument arg) override {
1659  return !isa<mlir::MemRefType>(arg.getType());
1660  }
1661 
1662  /// Filters out allocation operations
1663  bool maximizeOp(Operation *op) override { return !isAllocOp(op); }
1664 };
1665 } // namespace
1666 
1667 /// Converts every value in the region into maximal SSA form, unless the value
1668 /// is a block argument of type MemRefType or the result of an allocation
1669 /// operation.
1670 static LogicalResult maximizeSSANoMem(Region &r,
1671  ConversionPatternRewriter &rewriter) {
1672  HandshakeLoweringSSAStrategy strategy;
1673  return maximizeSSA(r, strategy, rewriter);
1674 }
1675 
1676 static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
1677  bool sourceConstants,
1678  bool disableTaskPipelining) {
1679  // Only retain those attributes that are not constructed by build.
1680  SmallVector<NamedAttribute, 4> attributes;
1681  for (const auto &attr : funcOp->getAttrs()) {
1682  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
1683  attr.getName() == funcOp.getFunctionTypeAttrName())
1684  continue;
1685  attributes.push_back(attr);
1686  }
1687 
1688  // Get function arguments
1689  llvm::SmallVector<mlir::Type, 8> argTypes;
1690  for (auto &argType : funcOp.getArgumentTypes())
1691  argTypes.push_back(argType);
1692 
1693  // Get function results
1694  llvm::SmallVector<mlir::Type, 8> resTypes;
1695  for (auto resType : funcOp.getResultTypes())
1696  resTypes.push_back(resType);
1697 
1698  handshake::FuncOp newFuncOp;
1699 
1700  // Add control input/output to function arguments/results and create a
1701  // handshake::FuncOp of appropriate type
1702  if (partiallyLowerOp<func::FuncOp>(
1703  [&](func::FuncOp funcOp, PatternRewriter &rewriter) {
1704  auto noneType = rewriter.getNoneType();
1705  resTypes.push_back(noneType);
1706  argTypes.push_back(noneType);
1707  auto func_type = rewriter.getFunctionType(argTypes, resTypes);
1708  newFuncOp = rewriter.create<handshake::FuncOp>(
1709  funcOp.getLoc(), funcOp.getName(), func_type, attributes);
1710  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1711  newFuncOp.end());
1712  if (!newFuncOp.isExternal()) {
1713  newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
1714  funcOp.getLoc());
1715  newFuncOp.resolveArgAndResNames();
1716  }
1717  rewriter.eraseOp(funcOp);
1718  return success();
1719  },
1720  ctx, funcOp)
1721  .failed())
1722  return failure();
1723 
1724  // Apply SSA maximization
1725  if (partiallyLowerRegion(maximizeSSANoMem, ctx, newFuncOp.getBody()).failed())
1726  return failure();
1727 
1728  if (!newFuncOp.isExternal()) {
1729  Block *bodyBlock = newFuncOp.getBodyBlock();
1730  Value entryCtrl = bodyBlock->getArguments().back();
1731  HandshakeLowering fol(newFuncOp.getBody());
1732  if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
1733  fol, sourceConstants, disableTaskPipelining, entryCtrl)))
1734  return failure();
1735  }
1736 
1737  return success();
1738 }
1739 
1740 namespace {
1741 
1742 struct HandshakeRemoveBlockPass
1743  : circt::impl::HandshakeRemoveBlockBase<HandshakeRemoveBlockPass> {
1744  void runOnOperation() override { removeBasicBlocks(getOperation()); }
1745 };
1746 
1747 struct CFToHandshakePass
1748  : public circt::impl::CFToHandshakeBase<CFToHandshakePass> {
1749  CFToHandshakePass(bool sourceConstants, bool disableTaskPipelining) {
1750  this->sourceConstants = sourceConstants;
1751  this->disableTaskPipelining = disableTaskPipelining;
1752  }
1753  void runOnOperation() override {
1754  ModuleOp m = getOperation();
1755 
1756  for (auto funcOp : llvm::make_early_inc_range(m.getOps<func::FuncOp>())) {
1757  if (failed(lowerFuncOp(funcOp, &getContext(), sourceConstants,
1758  disableTaskPipelining))) {
1759  signalPassFailure();
1760  return;
1761  }
1762  }
1763  }
1764 };
1765 
1766 } // namespace
1767 
1768 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1769 circt::createCFToHandshakePass(bool sourceConstants,
1770  bool disableTaskPipelining) {
1771  return std::make_unique<CFToHandshakePass>(sourceConstants,
1772  disableTaskPipelining);
1773 }
1774 
1775 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
1777  return std::make_unique<HandshakeRemoveBlockPass>();
1778 }
static ConditionalBranchOp getControlCondBranch(Block *block)
static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, bool sourceConstants, bool disableTaskPipelining)
static bool isMemoryOp(Operation *op)
static LogicalResult setJoinControlInputs(ArrayRef< Operation * > memOps, Operation *memOp, int offset, ArrayRef< int > cntrlInd)
static void addJoinOps(ConversionPatternRewriter &rewriter, ArrayRef< BlockControlTerm > controlTerms)
static void addLazyForks(Region &f, ConversionPatternRewriter &rewriter)
static bool isLiveOut(Value val)
static Operation * getControlMerge(Block *block)
static SmallVector< Value, 8 > getResultsToMemory(Operation *op)
static unsigned getBlockPredecessorCount(Block *block)
static int getBranchCount(Value val, Block *block)
void removeBasicBlocks(handshake::FuncOp funcOp)
static Operation * getFirstOp(Block *block)
Returns the first occurrence of an operation of type TOp; otherwise, returns a null op.
static Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock)
static Value getMergeOperand(HandshakeLowering::MergeOpInfo mergeInfo, Block *predBlock)
static LogicalResult isValidMemrefType(Location loc, mlir::MemRefType type)
static bool isAllocOp(Operation *op)
static Value getOperandFromBlock(MuxOp mux, Block *block)
static LogicalResult getOpMemRef(Operation *op, Value &out)
static LogicalResult partiallyLowerOp(const std::function< LogicalResult(TOp, ConversionPatternRewriter &)> &loweringFunc, MLIRContext *ctx, TOp op)
static void addValueToOperands(Operation *op, Value val)
static bool loopsHaveSingleExit(CFGLoopInfo &loopInfo)
static std::vector< Value > getSortedInputs(ControlMergeOp cmerge, MuxOp mux)
static void removeBlockOperands(Region &f)
static void removeUnusedAllocOps(Region &r, ConversionPatternRewriter &rewriter)
static LogicalResult maximizeSSANoMem(Region &r, ConversionPatternRewriter &rewriter)
Converts every value in the region into maximal SSA form, unless the value is a block argument of type MemRefType or the result of an allocation operation.
static Operation * findBranchToBlock(Block *block)
static std::vector< BlockControlTerm > getControlTerminators(ArrayRef< Operation * > memOps)
static BlockControlTerm getBlockControlTerminator(Block *block)
static void reconnectMergeOps(Region &r, HandshakeLowering::BlockOps blockMerges, HandshakeLowering::ValueMap &mergePairs)
static void setLoadDataInputs(ArrayRef< Operation * > memOps, Operation *memOp)
assert(baseType &&"element must be base type")
static void insertFork(Value result, OpBuilder &rewriter)
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
Definition: ExpandWhens.cpp:35
Strategy strategy
LowerRegionTarget(MLIRContext &context, Region &region)
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Strategy class to control the behavior of SSA maximization.
Definition: Passes.h:65
DenseMap< Block *, std::vector< MergeOpInfo > > BlockOps
Definition: CFToHandshake.h:65
DenseMap< Block *, std::vector< Value > > BlockValues
Definition: CFToHandshake.h:64
llvm::MapVector< Value, std::vector< Operation * > > MemRefToMemoryAccessOp
Definition: CFToHandshake.h:68
DenseMap< Value, Value > ValueMap
Definition: CFToHandshake.h:66
llvm::function_ref< LogicalResult(Region &, ConversionPatternRewriter &)> RegionLoweringFunc
Definition: CFToHandshake.h:43
LogicalResult partiallyLowerRegion(const RegionLoweringFunc &loweringFunc, MLIRContext *ctx, Region &r)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
Definition: MaximizeSSA.cpp:85
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createCFToHandshakePass(bool sourceConstants=false, bool disableTaskPipelining=false)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeRemoveBlockPass()
Holds information about a handshake "basic block terminator" control operation.
friend bool operator==(const BlockControlTerm &lhs, const BlockControlTerm &rhs)
Checks for member-wise equality.
Value ctrlOperand
The operation's control operand (must have type NoneType)
BlockControlTerm(Operation *op, Value ctrlOperand)
Operation * op
The operation.
Allows partially lowering a region by matching on its parent operation and then calling the provided partial lowering function.
PartialLoweringFunc fun
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value >, ConversionPatternRewriter &rewriter) const override
std::function< LogicalResult(Region &, ConversionPatternRewriter &)> PartialLoweringFunc
PartialLowerRegion(LowerRegionTarget &target, MLIRContext *context, LogicalResult &loweringResRef, const PartialLoweringFunc &fun)
LogicalResult & loweringRes
LowerRegionTarget & target