CIRCT  20.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
16 #include <variant>
17 
18 #define DEBUG_TYPE "arc-dedup"
19 
20 namespace circt {
21 namespace arc {
22 #define GEN_PASS_DEF_DEDUP
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
24 } // namespace arc
25 } // namespace circt
26 
27 using namespace circt;
28 using namespace arc;
29 using namespace hw;
30 using llvm::SmallMapVector;
31 using llvm::SmallSetVector;
32 
33 namespace {
/// The two SHA-256 digests computed for one arc definition.
struct StructuralHash {
  /// Raw 32-byte SHA-256 digest.
  using Hash = std::array<std::uint8_t, 32>;

  /// Digest over the complete arc structure, constant values included.
  Hash hash;
  /// Digest that leaves out constant-like ops, so arcs that differ only in
  /// their constants can end up with the same value here.
  Hash constInvariant;
};
39 
40 struct StructuralHasher {
41  explicit StructuralHasher(MLIRContext *context) {}
42 
43  StructuralHash hash(DefineOp arc) {
44  reset();
45  update(arc);
46  return StructuralHash{state.final(), stateConstInvariant.final()};
47  }
48 
49 private:
50  void reset() {
51  currentIndex = 0;
52  disableConstInvariant = 0;
53  indices.clear();
54  indicesConstInvariant.clear();
55  state.init();
56  stateConstInvariant.init();
57  }
58 
59  void update(const void *pointer) {
60  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
61  state.update(ArrayRef<uint8_t>(addr, sizeof pointer));
62  if (disableConstInvariant == 0)
63  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof pointer));
64  }
65 
66  void update(size_t value) {
67  auto *addr = reinterpret_cast<const uint8_t *>(&value);
68  state.update(ArrayRef<uint8_t>(addr, sizeof value));
69  if (disableConstInvariant == 0)
70  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof value));
71  }
72 
73  void update(size_t value, size_t valueConstInvariant) {
74  state.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&value),
75  sizeof value));
76  state.update(ArrayRef<uint8_t>(
77  reinterpret_cast<const uint8_t *>(&valueConstInvariant),
78  sizeof valueConstInvariant));
79  }
80 
81  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
82 
83  void update(Type type) { update(type.getAsOpaquePointer()); }
84 
85  void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
86 
87  void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
88 
89  void update(BlockArgument arg) { update(arg.getType()); }
90 
91  void update(OpResult result) { update(result.getType()); }
92 
93  void update(OpOperand &operand) {
94  // We hash the value's index as it apears in the block.
95  auto it = indices.find(operand.get());
96  auto itCI = indicesConstInvariant.find(operand.get());
97  assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
98  "op should have been previously hashed");
99  update(it->second, itCI->second);
100  }
101 
102  void update(Block &block) {
103  // Assign integer numbers to block arguments and op results. For the const-
104  // invariant hash, assign a zero to block args and constant ops, such that
105  // they hash as the same.
106  for (auto arg : block.getArguments()) {
107  indices.insert({arg, currentIndex++});
108  indicesConstInvariant.insert({arg, 0});
109  }
110  for (auto &op : block) {
111  for (auto result : op.getResults()) {
112  indices.insert({result, currentIndex++});
113  if (op.hasTrait<OpTrait::ConstantLike>())
114  indicesConstInvariant.insert({result, 0});
115  else
116  indicesConstInvariant.insert({result, currentIndexConstInvariant++});
117  }
118  }
119 
120  // Hash the block arguments for the const-invariant hash.
121  ++disableConstInvariant;
122  for (auto arg : block.getArguments())
123  update(arg);
124  --disableConstInvariant;
125 
126  // Hash the operations.
127  for (auto &op : block)
128  update(&op);
129  }
130 
131  void update(Operation *op) {
132  unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
133  disableConstInvariant += skipConstInvariant;
134 
135  update(op->getName());
136 
137  // Hash the attributes. (Excluded in constant invariant hash.)
138  if (!isa<DefineOp>(op)) {
139  for (auto namedAttr : op->getAttrDictionary()) {
140  auto name = namedAttr.getName();
141  auto value = namedAttr.getValue();
142 
143  // Hash the interned pointer.
144  update(name.getAsOpaquePointer());
145  update(value.getAsOpaquePointer());
146  }
147  }
148 
149  // Hash the operands.
150  for (auto &operand : op->getOpOperands())
151  update(operand);
152  // Hash the regions. We need to make sure an empty region doesn't hash
153  // the same as no region, so we include the number of regions.
154  update(op->getNumRegions());
155  for (auto &region : op->getRegions())
156  for (auto &block : region.getBlocks())
157  update(block);
158  // Record any op results.
159  for (auto result : op->getResults())
160  update(result);
161 
162  disableConstInvariant -= skipConstInvariant;
163  }
164 
165  // Every value is assigned a unique id based on their order of appearance.
166  unsigned currentIndex = 0;
167  unsigned currentIndexConstInvariant = 0;
168  DenseMap<Value, unsigned> indices;
169  DenseMap<Value, unsigned> indicesConstInvariant;
170 
171  unsigned disableConstInvariant = 0;
172 
173  // This is the actual running hash calculation. This is a stateful element
174  // that should be reinitialized after each hash is produced.
175  llvm::SHA256 state;
176  llvm::SHA256 stateConstInvariant;
177 };
178 } // namespace
179 
180 namespace {
181 struct StructuralEquivalence {
182  using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
183  explicit StructuralEquivalence(MLIRContext *context) {}
184 
185  void check(DefineOp arcA, DefineOp arcB) {
186  if (!checkImpl(arcA, arcB)) {
187  match = false;
188  matchConstInvariant = false;
189  }
190  }
191 
192  SmallSetVector<OpOperandPair, 1> divergences;
193  bool match;
194  bool matchConstInvariant;
195 
196 private:
197  bool addBlockToWorklist(Block &blockA, Block &blockB) {
198  auto *terminatorA = blockA.getTerminator();
199  auto *terminatorB = blockB.getTerminator();
200  if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
201  return false;
202  if (!addOpToWorklist(terminatorA, terminatorB))
203  return false;
204  // TODO: We should probably bail out if there are any operations in the
205  // block that aren't in the fan-in of the terminator.
206  return true;
207  }
208 
209  bool addOpToWorklist(Operation *opA, Operation *opB,
210  bool *allOperandsHandled = nullptr) {
211  if (opA->getNumOperands() != opB->getNumOperands())
212  return false;
213  for (auto [operandA, operandB] :
214  llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
215  if (!handled.count({&operandA, &operandB})) {
216  worklist.emplace_back(&operandA, &operandB);
217  if (allOperandsHandled)
218  *allOperandsHandled = false;
219  }
220  }
221  return true;
222  }
223 
224  bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
225  if (opA->getName() != opB->getName())
226  return false;
227  if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
228  for (auto [namedAttrA, namedAttrB] :
229  llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
230  if (namedAttrA.getName() != namedAttrB.getName())
231  return false;
232  if (namedAttrA.getValue() == namedAttrB.getValue())
233  continue;
234  bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
235  if (!mayDiverge || !values.first || !values.second)
236  return false;
237  divergences.insert(values);
238  match = false;
239  break;
240  }
241  }
242  return true;
243  }
244 
245  bool checkImpl(DefineOp arcA, DefineOp arcB) {
246  worklist.clear();
247  divergences.clear();
248  match = true;
249  matchConstInvariant = true;
250  handled.clear();
251 
252  if (arcA.getFunctionType().getResults() !=
253  arcB.getFunctionType().getResults())
254  return false;
255 
256  if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
257  return false;
258 
259  while (!worklist.empty()) {
260  OpOperandPair values = worklist.back();
261  if (handled.contains(values)) {
262  worklist.pop_back();
263  continue;
264  }
265 
266  auto valueA = values.first->get();
267  auto valueB = values.second->get();
268  if (valueA.getType() != valueB.getType())
269  return false;
270  auto *opA = valueA.getDefiningOp();
271  auto *opB = valueB.getDefiningOp();
272 
273  // Handle the case where one or both values are block arguments.
274  if (!opA || !opB) {
275  auto argA = dyn_cast<BlockArgument>(valueA);
276  auto argB = dyn_cast<BlockArgument>(valueB);
277  if (argA && argB) {
278  divergences.insert(values);
279  if (argA.getArgNumber() != argB.getArgNumber())
280  match = false;
281  handled.insert(values);
282  worklist.pop_back();
283  continue;
284  }
285  auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
286  auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
287  if ((argA && isConstB) || (argB && isConstA)) {
288  // One value is a block argument, one is a constant.
289  divergences.insert(values);
290  match = false;
291  handled.insert(values);
292  worklist.pop_back();
293  continue;
294  }
295  return false;
296  }
297 
298  // Go through all operands push the ones we haven't visited yet onto the
299  // worklist so they get processed before we continue.
300  bool allHandled = true;
301  if (!addOpToWorklist(opA, opB, &allHandled))
302  return false;
303  if (!allHandled)
304  continue;
305  handled.insert(values);
306  worklist.pop_back();
307 
308  // Compare the two operations and check that they are equal.
309  if (!compareOps(opA, opB, values))
310  return false;
311 
312  // Descend into subregions of the operation.
313  if (opA->getNumRegions() != opB->getNumRegions())
314  return false;
315  for (auto [regionA, regionB] :
316  llvm::zip(opA->getRegions(), opB->getRegions())) {
317  if (regionA.getBlocks().size() != regionB.getBlocks().size())
318  return false;
319  for (auto [blockA, blockB] : llvm::zip(regionA, regionB))
320  if (!addBlockToWorklist(blockA, blockB))
321  return false;
322  }
323  }
324 
325  return true;
326  }
327 
328  SmallVector<OpOperandPair, 0> worklist;
329  DenseSet<OpOperandPair> handled;
330 };
331 } // namespace
332 
334  SmallSetVector<mlir::CallOpInterface, 1> &callSites,
335  ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
337  SmallVector<Value> newOperands;
338  for (auto callOp : callSites) {
339  OpBuilder builder(callOp);
340  newOperands.clear();
341  clonedOps.clear();
342  for (auto mapping : operandMappings) {
343  if (std::holds_alternative<Operation *>(mapping)) {
344  auto *op = std::get<Operation *>(mapping);
345  auto &newOp = clonedOps[op];
346  if (!newOp)
347  newOp = builder.clone(*op);
348  newOperands.push_back(newOp->getResult(0));
349  } else {
350  newOperands.push_back(
351  callOp.getArgOperands()[std::get<unsigned>(mapping)]);
352  }
353  }
354  callOp.getArgOperandsMutable().assign(newOperands);
355  }
356 }
357 
358 static bool isOutlinable(OpOperand &operand) {
359  auto *op = operand.get().getDefiningOp();
360  return !op || op->hasTrait<OpTrait::ConstantLike>();
361 }
362 
363 namespace {
364 struct DedupPass : public arc::impl::DedupBase<DedupPass> {
365  void runOnOperation() override;
366  void replaceArcWith(DefineOp oldArc, DefineOp newArc,
367  SymbolTableCollection &symbolTable);
368 
369  /// A mapping from arc names to arc definitions.
370  DenseMap<StringAttr, DefineOp> arcByName;
371  /// A mapping from arc definitions to call sites.
372  DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
373 };
374 
375 struct ArcHash {
376  DefineOp defineOp;
377  StructuralHash hash;
378  unsigned order;
379  ArcHash(DefineOp defineOp, StructuralHash hash, unsigned order)
380  : defineOp(defineOp), hash(hash), order(order) {}
381 };
382 } // namespace
383 
/// Deduplicate the arc definitions in the operation this pass runs on.
///
/// Works in two phases:
/// 1. Arcs with equal full structural hashes that `StructuralEquivalence`
///    reports as an exact `match` are merged; call sites keep their operands.
/// 2. The remaining arcs are grouped by const-invariant hash. For each arc,
///    the divergent operands found against every `matchConstInvariant`
///    candidate are turned into arc arguments (constants move out to the call
///    sites), and then each candidate whose operands can be mapped onto that
///    new argument list is merged into it.
void DedupPass::runOnOperation() {
  arcByName.clear();
  callSites.clear();
  SymbolTableCollection symbolTable;

  // Compute the structural hash for each arc definition.
  SmallVector<ArcHash> arcHashes;
  StructuralHasher hasher(&getContext());
  for (auto defineOp : getOperation().getOps<DefineOp>()) {
    arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
    arcByName.insert({defineOp.getSymNameAttr(), defineOp});
  }

  // Collect the arc call sites.
  getOperation().walk([&](mlir::CallOpInterface callOp) {
    if (auto defOp =
            dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
      callSites[defOp].insert(callOp);
  });

  // Sort the arcs by hash such that arcs with the same hash are next to each
  // other, and sort arcs with the same hash by order in which they appear in
  // the input. This allows us to iterate through the list and check
  // neighbouring arcs for merge opportunities.
  llvm::stable_sort(arcHashes, [](auto a, auto b) {
    if (a.hash.hash < b.hash.hash)
      return true;
    if (a.hash.hash > b.hash.hash)
      return false;
    return a.order < b.order;
  });

  // Perform deduplications that do not require modification of the arc call
  // sites. (No additional ports.)
  LLVM_DEBUG(llvm::dbgs() << "Check for exact merges (" << arcHashes.size()
                          << " arcs)\n");
  StructuralEquivalence equiv(&getContext());
  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
       ++arcIdx) {
    // A null `defineOp` marks an entry that was already merged away.
    auto [defineOp, hash, order] = arcHashes[arcIdx];
    if (!defineOp)
      continue;
    for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
      auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
      // The list is sorted by hash, so the run of equal hashes ends here.
      if (hash.hash != otherHash.hash)
        break;
      if (!otherDefineOp)
        continue;
      equiv.check(defineOp, otherDefineOp);
      if (!equiv.match)
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << "- Merge " << defineOp.getSymNameAttr() << " <- "
                 << otherDefineOp.getSymNameAttr() << "\n");
      replaceArcWith(otherDefineOp, defineOp, symbolTable);
      arcHashes[otherIdx].defineOp = {};
    }
  }

  // The initial pass over the arcs has set the `defineOp` to null for every arc
  // that was already merged. Now sort the list of arcs as follows:
  // - All merged arcs are moved to the back of the list (`!defineOp`)
  // - Sort unmerged arcs by const-invariant hash
  // - Sort arcs with same hash by order in which they appear in the input
  // This allows us to pop the merged arcs off of the back of the list. Then we
  // can iterate through the list and check neighbouring arcs for merge
  // opportunities.
  llvm::stable_sort(arcHashes, [](auto a, auto b) {
    if (!a.defineOp && !b.defineOp)
      return false;
    if (!a.defineOp)
      return false;
    if (!b.defineOp)
      return true;
    if (a.hash.constInvariant < b.hash.constInvariant)
      return true;
    if (a.hash.constInvariant > b.hash.constInvariant)
      return false;
    return a.order < b.order;
  });
  while (!arcHashes.empty() && !arcHashes.back().defineOp)
    arcHashes.pop_back();

  // Perform deduplication of arcs that differ only in constant values.
  LLVM_DEBUG(llvm::dbgs() << "Check for constant-agnostic merges ("
                          << arcHashes.size() << " arcs)\n");
  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
       ++arcIdx) {
    auto [defineOp, hash, order] = arcHashes[arcIdx];
    if (!defineOp)
      continue;

    // Perform an initial pass over all other arcs with identical
    // const-invariant hash. Check for equivalence between the current arc
    // (`defineOp`) and the other arc (`otherDefineOp`). In case they match
    // iterate over the list of divergences which holds all non-identical
    // OpOperand pairs in the two arcs. These can come in different forms:
    //
    // - (const, const): Both arcs have the operand set to a constant, but the
    //   constant value differs. We'll want to extract these constants.
    // - (arg, const): The current arc has a block argument where the other has
    //   a constant. No changes needed; when we replace the uses of the other
    //   arc with the current one further down we can use the existing
    //   argument to pass in that constant.
    // - (const, arg): The current arc has a constant where the other has a
    //   block argument. We'll want to extract this constant and replace it
    //   with a block argument. This will allow the other arc to be replaced
    //   with the current one.
    // - (arg, arg): Both arcs have the operand set to a block argument, but
    //   they are different argument numbers. This can happen if for example
    //   one of the arcs uses a single argument in two op operands and the
    //   other arc has two separate arguments for the two op operands. We'll
    //   want to ensure the current arc has two arguments in this case, such
    //   that the two can dedup.
    //
    // Whenever an op operand is involved in such a divergence we add it to the
    // list of operands that must be mapped to a distinct block argument. Later
    // we'll go through this list and add additional block arguments as
    // necessary.
    SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
    // Group ID 0 means "not yet assigned"; real IDs start at 1.
    unsigned nextGroupId = 1;
    SmallMapVector<Value,
                   SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
        operandMappings;
    SmallVector<StringAttr> candidateNames;

    for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
      auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
      if (hash.constInvariant != otherHash.constInvariant)
        break;
      if (!otherDefineOp)
        continue;

      equiv.check(defineOp, otherDefineOp);
      if (!equiv.matchConstInvariant)
        continue;
      candidateNames.push_back(otherDefineOp.getSymNameAttr());

      // Iterate over the matching operand pairs ("divergences"), look up the
      // value pair the operands are set to, and then store the current arc's
      // operand in the set that corresponds to this value pair. This builds up
      // `operandMappings` to contain sets of op operands in the current arc
      // that can be routed out to the same block argument. If a block argument
      // of the current arc corresponds to multiple different things in the
      // other arc, this ensures that all unique such combinations get grouped
      // in distinct sets such that we can create an appropriate number of new
      // block args.
      operandMappings.clear();
      for (auto [operand, otherOperand] : equiv.divergences) {
        if (!isOutlinable(*operand) || !isOutlinable(*otherOperand))
          continue;
        operandMappings[operand->get()][otherOperand->get()].insert(operand);
      }

      // Go through the sets of operands that can map to the same block argument
      // for the combination of current and other arc. Assign all operands in
      // each set new unique group IDs. If the operands in the set have multiple
      // IDs, allocate multiple new unique group IDs. This fills the
      // `outlineOperands` map with operands and their corresponding group ID.
      // If we find multiple other arcs that we can potentially combine with the
      // current arc, the operands get distributed into more and more smaller
      // groups. For example, in arc A we can assign operands X and Y to the
      // same block argument, so we assign them the same ID; but in arc B we
      // have to assign X and Y to different block arguments, at which point
      // that same ID we assigned earlier gets reassigned to two new IDs, one
      // for each operand.
      for (auto &[value, mappings] : operandMappings) {
        for (auto &[otherValue, operands] : mappings) {
          SmallDenseMap<unsigned, unsigned> remappedGroupIds;
          for (auto *operand : operands) {
            auto &id = outlineOperands[operand];
            auto &remappedId = remappedGroupIds[id];
            if (remappedId == 0)
              remappedId = nextGroupId++;
            id = remappedId;
          }
        }
      }
    }

    if (outlineOperands.empty())
      continue;
    LLVM_DEBUG({
      llvm::dbgs() << "- Outlining " << outlineOperands.size()
                   << " operands from " << defineOp.getSymNameAttr() << "\n";
      for (auto entry : outlineOperands)
        llvm::dbgs() << " - Operand #" << entry.first->getOperandNumber()
                     << " of " << *entry.first->getOwner() << "\n";
      for (auto name : candidateNames)
        llvm::dbgs() << " - Candidate " << name << "\n";
    });

    // Sort the operands to be outlined. The order is already deterministic at
    // this point, but is not really correlated to the existing block argument
    // order since we gathered these operands by traversing the operations
    // depth-first. Establish an order that first honors existing argument order
    // (putting constants at the back), and then considers the order of
    // operations and op operands.
    // NOTE(review): this sorts the MapVector's storage in place; presumably
    // its internal key->index map is not updated, so `outlineOperands` must
    // only be iterated (never looked up) after this point -- confirm.
    // NOTE(review): operands whose defining ops live in different blocks
    // compare as unordered, which may not be a strict weak ordering once
    // nested regions are involved -- confirm.
    llvm::stable_sort(outlineOperands, [](auto &a, auto &b) {
      auto argA = dyn_cast<BlockArgument>(a.first->get());
      auto argB = dyn_cast<BlockArgument>(b.first->get());
      if (argA && !argB)
        return true;
      if (!argA && argB)
        return false;
      if (argA && argB) {
        if (argA.getArgNumber() < argB.getArgNumber())
          return true;
        if (argA.getArgNumber() > argB.getArgNumber())
          return false;
      }
      auto *opA = a.first->get().getDefiningOp();
      auto *opB = b.first->get().getDefiningOp();
      if (opA == opB)
        return a.first->getOperandNumber() < b.first->getOperandNumber();
      if (opA->getBlock() == opB->getBlock())
        return opA->isBeforeInBlock(opB);
      return false;
    });

    // Build a new set of arc arguments by iterating over the operands that we
    // have determined must be exposed as arguments above. For each operand
    // either reuse its existing block argument (if no other operand in the list
    // has already reused it), or add a new argument for this operand. Also
    // track how each argument must be connected at call sites (outlined
    // constant op or reusing an existing operand).
    unsigned oldArgumentCount = defineOp.getNumArguments();
    SmallDenseMap<unsigned, Value> newArguments; // by group ID
    SmallVector<Type> newInputTypes;
    // One entry per new argument, in argument order.
    SmallVector<std::variant<Operation *, unsigned>> newOperands;
    SmallPtrSet<Operation *, 8> outlinedOps;

    for (auto [operand, groupId] : outlineOperands) {
      auto &arg = newArguments[groupId];
      if (!arg) {
        // `operand` has not been rewired yet, so this is still its old value.
        auto value = operand->get();
        arg = defineOp.getBodyBlock().addArgument(value.getType(),
                                                  value.getLoc());
        newInputTypes.push_back(arg.getType());
        // NOTE(review): assumes `blockArg` belongs to the arc body block so
        // its number indexes the call-site operands; a block argument of a
        // nested region would be mis-mapped -- confirm arcs never have such
        // regions.
        if (auto blockArg = dyn_cast<BlockArgument>(value))
          newOperands.push_back(blockArg.getArgNumber());
        else {
          auto *op = value.getDefiningOp();
          newOperands.push_back(op);
          outlinedOps.insert(op);
        }
      }
      operand->set(arg);
    }

    // Every use of an old argument must have been redirected to a new one;
    // otherwise erasing the old arguments below would be invalid.
    for (auto arg :
         defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
      if (!arg.use_empty()) {
        auto d = defineOp.emitError(
                     "dedup failed to replace all argument uses; arc ")
                 << defineOp.getSymNameAttr() << ", argument "
                 << arg.getArgNumber();
        for (auto &use : arg.getUses())
          d.attachNote(use.getOwner()->getLoc())
              << "used in operand " << use.getOperandNumber() << " here";
        return signalPassFailure();
      }
    }

    defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
    defineOp.setType(FunctionType::get(
        &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
    addCallSiteOperands(callSites[defineOp], newOperands);
    // Constants now materialized at the call sites are dead inside the arc.
    for (auto *op : outlinedOps)
      if (op->use_empty())
        op->erase();

    // Perform the actual deduplication with other arcs.
    for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
      auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
      if (hash.constInvariant != otherHash.constInvariant)
        break;
      if (!otherDefineOp)
        continue;

      // Check for structural equivalence between the two arcs.
      equiv.check(defineOp, otherDefineOp);
      if (!equiv.matchConstInvariant)
        continue;

      // Determine how the other arc's operands map to the arc we're trying to
      // merge into. `newOperands` has already been applied to the current
      // arc's call sites above and is reused here as scratch space, one slot
      // per argument; a null `Operation *` marks an unassigned slot.
      std::variant<Operation *, unsigned> nullOperand = nullptr;
      for (auto &operand : newOperands)
        operand = nullptr;

      bool mappingFailed = false;
      for (auto [operand, otherOperand] : equiv.divergences) {
        auto arg = dyn_cast<BlockArgument>(operand->get());
        if (!arg || !isOutlinable(*otherOperand)) {
          mappingFailed = true;
          break;
        }

        // Determine how the other arc's operand maps to the new connection
        // scheme of the current arc.
        std::variant<Operation *, unsigned> newOperand;
        if (auto otherArg = dyn_cast<BlockArgument>(otherOperand->get()))
          newOperand = otherArg.getArgNumber();
        else
          newOperand = otherOperand->get().getDefiningOp();

        // Ensure that there are no conflicting operand assignments.
        auto &newOperandSlot = newOperands[arg.getArgNumber()];
        if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
          mappingFailed = true;
          break;
        }
        newOperandSlot = newOperand;
      }
      if (mappingFailed) {
        LLVM_DEBUG(llvm::dbgs() << " - Mapping failed; skipping arc\n");
        continue;
      }
      if (llvm::any_of(newOperands,
                       [&](auto operand) { return operand == nullOperand; })) {
        LLVM_DEBUG(llvm::dbgs()
                   << " - Not all operands mapped; skipping arc\n");
        continue;
      }

      // Replace all uses of the other arc with the current arc.
      LLVM_DEBUG(llvm::dbgs()
                 << " - Merged " << defineOp.getSymNameAttr() << " <- "
                 << otherDefineOp.getSymNameAttr() << "\n");
      addCallSiteOperands(callSites[otherDefineOp], newOperands);
      replaceArcWith(otherDefineOp, defineOp, symbolTable);
      arcHashes[otherIdx].defineOp = {};
    }
  }
}
720 
721 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
722  SymbolTableCollection &symbolTable) {
723  ++dedupPassNumArcsDeduped;
724  auto oldArcOps = oldArc.getOps();
725  dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
726  auto &oldUses = callSites[oldArc];
727  auto &newUses = callSites[newArc];
728  auto newArcName = SymbolRefAttr::get(newArc.getSymNameAttr());
729  for (auto callOp : oldUses) {
730  callOp.setCalleeFromCallable(newArcName);
731  newUses.insert(callOp);
732  }
733 
734  oldArc.walk([&](mlir::CallOpInterface callOp) {
735  if (auto defOp =
736  dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
737  callSites[defOp].remove(callOp);
738  });
739  callSites.erase(oldArc);
740  arcByName.erase(oldArc.getSymNameAttr());
741  oldArc->erase();
742 }
743 
744 std::unique_ptr<Pass> arc::createDedupPass() {
745  return std::make_unique<DedupPass>();
746 }
static void addCallSiteOperands(SmallSetVector< mlir::CallOpInterface, 1 > &callSites, ArrayRef< std::variant< Operation *, unsigned >> operandMappings)
Definition: Dedup.cpp:333
static bool isOutlinable(OpOperand &operand)
Definition: Dedup.cpp:358
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:744
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1