CIRCT  19.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
16 
17 #define DEBUG_TYPE "arc-dedup"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_DEDUP
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 using namespace hw;
29 using llvm::SmallMapVector;
30 using llvm::SmallSetVector;
31 
32 namespace {
33 struct StructuralHash {
34  using Hash = std::array<uint8_t, 32>;
35  Hash hash;
36  Hash constInvariant; // a hash that ignores constants
37 };
38 
39 struct StructuralHasher {
40  explicit StructuralHasher(MLIRContext *context) {}
41 
42  StructuralHash hash(DefineOp arc) {
43  reset();
44  update(arc);
45  return StructuralHash{state.final(), stateConstInvariant.final()};
46  }
47 
48 private:
49  void reset() {
50  currentIndex = 0;
51  disableConstInvariant = 0;
52  indices.clear();
53  indicesConstInvariant.clear();
54  state.init();
55  stateConstInvariant.init();
56  }
57 
58  void update(const void *pointer) {
59  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
60  state.update(ArrayRef<uint8_t>(addr, sizeof pointer));
61  if (disableConstInvariant == 0)
62  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof pointer));
63  }
64 
65  void update(size_t value) {
66  auto *addr = reinterpret_cast<const uint8_t *>(&value);
67  state.update(ArrayRef<uint8_t>(addr, sizeof value));
68  if (disableConstInvariant == 0)
69  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof value));
70  }
71 
72  void update(size_t value, size_t valueConstInvariant) {
73  state.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&value),
74  sizeof value));
75  state.update(ArrayRef<uint8_t>(
76  reinterpret_cast<const uint8_t *>(&valueConstInvariant),
77  sizeof valueConstInvariant));
78  }
79 
80  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
81 
82  void update(Type type) { update(type.getAsOpaquePointer()); }
83 
84  void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
85 
86  void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
87 
88  void update(BlockArgument arg) { update(arg.getType()); }
89 
90  void update(OpResult result) { update(result.getType()); }
91 
92  void update(OpOperand &operand) {
93  // We hash the value's index as it apears in the block.
94  auto it = indices.find(operand.get());
95  auto itCI = indicesConstInvariant.find(operand.get());
96  assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
97  "op should have been previously hashed");
98  update(it->second, itCI->second);
99  }
100 
101  void update(Block &block) {
102  // Assign integer numbers to block arguments and op results. For the const-
103  // invariant hash, assign a zero to block args and constant ops, such that
104  // they hash as the same.
105  for (auto arg : block.getArguments()) {
106  indices.insert({arg, currentIndex++});
107  indicesConstInvariant.insert({arg, 0});
108  }
109  for (auto &op : block) {
110  for (auto result : op.getResults()) {
111  indices.insert({result, currentIndex++});
112  if (op.hasTrait<OpTrait::ConstantLike>())
113  indicesConstInvariant.insert({result, 0});
114  else
115  indicesConstInvariant.insert({result, currentIndexConstInvariant++});
116  }
117  }
118 
119  // Hash the block arguments for the const-invariant hash.
120  ++disableConstInvariant;
121  for (auto arg : block.getArguments())
122  update(arg);
123  --disableConstInvariant;
124 
125  // Hash the operations.
126  for (auto &op : block)
127  update(&op);
128  }
129 
130  void update(Operation *op) {
131  unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
132  disableConstInvariant += skipConstInvariant;
133 
134  update(op->getName());
135 
136  // Hash the attributes. (Excluded in constant invariant hash.)
137  if (!isa<DefineOp>(op)) {
138  for (auto namedAttr : op->getAttrDictionary()) {
139  auto name = namedAttr.getName();
140  auto value = namedAttr.getValue();
141 
142  // Hash the interned pointer.
143  update(name.getAsOpaquePointer());
144  update(value.getAsOpaquePointer());
145  }
146  }
147 
148  // Hash the operands.
149  for (auto &operand : op->getOpOperands())
150  update(operand);
151  // Hash the regions. We need to make sure an empty region doesn't hash
152  // the same as no region, so we include the number of regions.
153  update(op->getNumRegions());
154  for (auto &region : op->getRegions())
155  for (auto &block : region.getBlocks())
156  update(block);
157  // Record any op results.
158  for (auto result : op->getResults())
159  update(result);
160 
161  disableConstInvariant -= skipConstInvariant;
162  }
163 
164  // Every value is assigned a unique id based on their order of appearance.
165  unsigned currentIndex = 0;
166  unsigned currentIndexConstInvariant = 0;
167  DenseMap<Value, unsigned> indices;
168  DenseMap<Value, unsigned> indicesConstInvariant;
169 
170  unsigned disableConstInvariant = 0;
171 
172  // This is the actual running hash calculation. This is a stateful element
173  // that should be reinitialized after each hash is produced.
174  llvm::SHA256 state;
175  llvm::SHA256 stateConstInvariant;
176 };
177 } // namespace
178 
179 namespace {
180 struct StructuralEquivalence {
181  using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
182  explicit StructuralEquivalence(MLIRContext *context) {}
183 
184  void check(DefineOp arcA, DefineOp arcB) {
185  if (!checkImpl(arcA, arcB)) {
186  match = false;
187  matchConstInvariant = false;
188  }
189  }
190 
191  SmallSetVector<OpOperandPair, 1> divergences;
192  bool match;
193  bool matchConstInvariant;
194 
195 private:
196  bool addBlockToWorklist(Block &blockA, Block &blockB) {
197  auto *terminatorA = blockA.getTerminator();
198  auto *terminatorB = blockB.getTerminator();
199  if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
200  return false;
201  if (!addOpToWorklist(terminatorA, terminatorB))
202  return false;
203  // TODO: We should probably bail out if there are any operations in the
204  // block that aren't in the fan-in of the terminator.
205  return true;
206  }
207 
208  bool addOpToWorklist(Operation *opA, Operation *opB,
209  bool *allOperandsHandled = nullptr) {
210  if (opA->getNumOperands() != opB->getNumOperands())
211  return false;
212  for (auto [operandA, operandB] :
213  llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
214  if (!handled.count({&operandA, &operandB})) {
215  worklist.emplace_back(&operandA, &operandB);
216  if (allOperandsHandled)
217  *allOperandsHandled = false;
218  }
219  }
220  return true;
221  }
222 
223  bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
224  if (opA->getName() != opB->getName())
225  return false;
226  if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
227  for (auto [namedAttrA, namedAttrB] :
228  llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
229  if (namedAttrA.getName() != namedAttrB.getName())
230  return false;
231  if (namedAttrA.getValue() == namedAttrB.getValue())
232  continue;
233  bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
234  if (!mayDiverge || !values.first || !values.second)
235  return false;
236  divergences.insert(values);
237  match = false;
238  break;
239  }
240  }
241  return true;
242  }
243 
244  bool checkImpl(DefineOp arcA, DefineOp arcB) {
245  worklist.clear();
246  divergences.clear();
247  match = true;
248  matchConstInvariant = true;
249  handled.clear();
250 
251  if (arcA.getFunctionType().getResults() !=
252  arcB.getFunctionType().getResults())
253  return false;
254 
255  if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
256  return false;
257 
258  while (!worklist.empty()) {
259  OpOperandPair values = worklist.back();
260  if (handled.contains(values)) {
261  worklist.pop_back();
262  continue;
263  }
264 
265  auto valueA = values.first->get();
266  auto valueB = values.second->get();
267  if (valueA.getType() != valueB.getType())
268  return false;
269  auto *opA = valueA.getDefiningOp();
270  auto *opB = valueB.getDefiningOp();
271 
272  // Handle the case where one or both values are block arguments.
273  if (!opA || !opB) {
274  auto argA = dyn_cast<BlockArgument>(valueA);
275  auto argB = dyn_cast<BlockArgument>(valueB);
276  if (argA && argB) {
277  divergences.insert(values);
278  if (argA.getArgNumber() != argB.getArgNumber())
279  match = false;
280  handled.insert(values);
281  worklist.pop_back();
282  continue;
283  }
284  auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
285  auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
286  if ((argA && isConstB) || (argB && isConstA)) {
287  // One value is a block argument, one is a constant.
288  divergences.insert(values);
289  match = false;
290  handled.insert(values);
291  worklist.pop_back();
292  continue;
293  }
294  return false;
295  }
296 
297  // Go through all operands push the ones we haven't visited yet onto the
298  // worklist so they get processed before we continue.
299  bool allHandled = true;
300  if (!addOpToWorklist(opA, opB, &allHandled))
301  return false;
302  if (!allHandled)
303  continue;
304  handled.insert(values);
305  worklist.pop_back();
306 
307  // Compare the two operations and check that they are equal.
308  if (!compareOps(opA, opB, values))
309  return false;
310 
311  // Descend into subregions of the operation.
312  if (opA->getNumRegions() != opB->getNumRegions())
313  return false;
314  for (auto [regionA, regionB] :
315  llvm::zip(opA->getRegions(), opB->getRegions())) {
316  if (regionA.getBlocks().size() != regionB.getBlocks().size())
317  return false;
318  for (auto [blockA, blockB] : llvm::zip(regionA, regionB))
319  if (!addBlockToWorklist(blockA, blockB))
320  return false;
321  }
322  }
323 
324  return true;
325  }
326 
327  SmallVector<OpOperandPair, 0> worklist;
328  DenseSet<OpOperandPair> handled;
329 };
330 } // namespace
331 
333  SmallSetVector<mlir::CallOpInterface, 1> &callSites,
334  ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
336  SmallVector<Value> newOperands;
337  for (auto callOp : callSites) {
338  OpBuilder builder(callOp);
339  newOperands.clear();
340  clonedOps.clear();
341  for (auto mapping : operandMappings) {
342  if (std::holds_alternative<Operation *>(mapping)) {
343  auto *op = std::get<Operation *>(mapping);
344  auto &newOp = clonedOps[op];
345  if (!newOp)
346  newOp = builder.clone(*op);
347  newOperands.push_back(newOp->getResult(0));
348  } else {
349  newOperands.push_back(
350  callOp.getArgOperands()[std::get<unsigned>(mapping)]);
351  }
352  }
353  callOp.getArgOperandsMutable().assign(newOperands);
354  }
355 }
356 
357 static bool isOutlinable(OpOperand &operand) {
358  auto *op = operand.get().getDefiningOp();
359  return !op || op->hasTrait<OpTrait::ConstantLike>();
360 }
361 
362 namespace {
363 struct DedupPass : public arc::impl::DedupBase<DedupPass> {
364  void runOnOperation() override;
365  void replaceArcWith(DefineOp oldArc, DefineOp newArc,
366  SymbolTableCollection &symbolTable);
367 
368  /// A mapping from arc names to arc definitions.
369  DenseMap<StringAttr, DefineOp> arcByName;
370  /// A mapping from arc definitions to call sites.
371  DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
372 };
373 
374 struct ArcHash {
375  DefineOp defineOp;
376  StructuralHash hash;
377  unsigned order;
378  ArcHash(DefineOp defineOp, StructuralHash hash, unsigned order)
379  : defineOp(defineOp), hash(hash), order(order) {}
380 };
381 } // namespace
382 
/// Deduplicate the arcs in the operation. This runs in two phases:
/// 1. Arcs with equal full structural hashes that `StructuralEquivalence`
///    reports as an exact `match` are merged; only the callee of their call
///    sites changes.
/// 2. The remaining arcs are grouped by const-invariant hash. For each arc,
///    the constants and argument uses that diverge from its candidates are
///    outlined into new arc arguments and its call sites are rewritten. Then
///    every candidate whose operands can be mapped onto the new argument list
///    is merged into it.
383 void DedupPass::runOnOperation() {
384  arcByName.clear();
385  callSites.clear();
386  SymbolTableCollection symbolTable;
387 
388  // Compute the structural hash for each arc definition.
389  SmallVector<ArcHash> arcHashes;
390  StructuralHasher hasher(&getContext());
391  for (auto defineOp : getOperation().getOps<DefineOp>()) {
  // The third argument records the arc's position in the input; it is used
  // as a tie-breaker when sorting below.
392  arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
393  arcByName.insert({defineOp.getSymNameAttr(), defineOp});
394  }
395 
396  // Collect the arc call sites.
397  getOperation().walk([&](mlir::CallOpInterface callOp) {
398  if (auto defOp =
399  dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
400  callSites[defOp].insert(callOp);
401  });
402 
403  // Sort the arcs by hash such that arcs with the same hash are next to each
404  // other, and sort arcs with the same hash by order in which they appear in
405  // the input. This allows us to iterate through the list and check
406  // neighbouring arcs for merge opportunities.
  // NOTE(review): the comparator takes `ArcHash` by value, copying both
  // digests on every comparison; `const auto &` would avoid that.
407  llvm::stable_sort(arcHashes, [](auto a, auto b) {
408  if (a.hash.hash < b.hash.hash)
409  return true;
410  if (a.hash.hash > b.hash.hash)
411  return false;
412  return a.order < b.order;
413  });
414 
415  // Perform deduplications that do not require modification of the arc call
416  // sites. (No additional ports.)
417  LLVM_DEBUG(llvm::dbgs() << "Check for exact merges (" << arcHashes.size()
418  << " arcs)\n");
419  StructuralEquivalence equiv(&getContext());
420  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
421  ++arcIdx) {
422  auto [defineOp, hash, order] = arcHashes[arcIdx];
  // Arcs already merged into an earlier arc have been nulled out.
423  if (!defineOp)
424  continue;
425  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
426  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
  // The list is sorted by hash, so the run of equal hashes ends here.
427  if (hash.hash != otherHash.hash)
428  break;
429  if (!otherDefineOp)
430  continue;
431  equiv.check(defineOp, otherDefineOp);
432  if (!equiv.match)
433  continue;
434  LLVM_DEBUG(llvm::dbgs()
435  << "- Merge " << defineOp.getSymNameAttr() << " <- "
436  << otherDefineOp.getSymNameAttr() << "\n");
437  replaceArcWith(otherDefineOp, defineOp, symbolTable);
  // Null out the merged arc so that later iterations skip it.
438  arcHashes[otherIdx].defineOp = {};
439  }
440  }
441 
442  // The initial pass over the arcs has set the `defineOp` to null for every arc
443  // that was already merged. Now sort the list of arcs as follows:
444  // - All merged arcs are moved to the back of the list (`!defineOp`)
445  // - Sort unmerged arcs by const-invariant hash
446  // - Sort arcs with same hash by order in which they appear in the input
447  // This allows us to pop the merged arcs off of the back of the list. Then we
448  // can iterate through the list and check neighbouring arcs for merge
449  // opportunities.
450  llvm::stable_sort(arcHashes, [](auto a, auto b) {
451  if (!a.defineOp && !b.defineOp)
452  return false;
453  if (!a.defineOp)
454  return false;
455  if (!b.defineOp)
456  return true;
457  if (a.hash.constInvariant < b.hash.constInvariant)
458  return true;
459  if (a.hash.constInvariant > b.hash.constInvariant)
460  return false;
461  return a.order < b.order;
462  });
463  while (!arcHashes.empty() && !arcHashes.back().defineOp)
464  arcHashes.pop_back();
465 
466  // Perform deduplication of arcs that differ only in constant values.
467  LLVM_DEBUG(llvm::dbgs() << "Check for constant-agnostic merges ("
468  << arcHashes.size() << " arcs)\n");
469  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
470  ++arcIdx) {
471  auto [defineOp, hash, order] = arcHashes[arcIdx];
472  if (!defineOp)
473  continue;
474 
475  // Perform an initial pass over all other arcs with identical
476  // const-invariant hash. Check for equivalence between the current arc
477  // (`defineOp`) and the other arc (`otherDefineOp`). In case they match
478  // iterate over the list of divergences which holds all non-identical
479  // OpOperand pairs in the two arcs. These can come in different forms:
480  //
481  // - (const, const): Both arcs have the operand set to a constant, but the
482  // constant value differs. We'll want to extract these constants.
483  // - (arg, const): The current arc has a block argument where the other has
484  // a constant. No changes needed; when we replace the uses of the other
485  // arc with the current one further down we can use the existing
486  // argument to pass in that constant.
487  // - (const, arg): The current arc has a constant where the other has a
488  // block argument. We'll want to extract this constant and replace it
489  // with a block argument. This will allow the other arc to be replaced
490  // with the current one.
491  // - (arg, arg): Both arcs have the operand set to a block argument, but
492  // they are different argument numbers. This can happen if for example
493  // one of the arcs uses a single argument in two op operands and the
494  // other arc has two separate arguments for the two op operands. We'll
495  // want to ensure the current arc has two arguments in this case, such
496  // that the two can dedup.
497  //
498  // Whenever an op operand is involved in such a divergence we add it to the
499  // list of operands that must be mapped to a distinct block argument. Later
500  // we'll go through this list and add additional block arguments as
501  // necessary.
  // `outlineOperands` maps each operand of the current arc to its group ID;
  // operands with the same group ID share one new block argument. IDs start
  // at 1 because 0 marks an operand that has not been assigned a group yet.
502  SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
503  unsigned nextGroupId = 1;
  // Value in the current arc -> value in the other arc -> the current arc's
  // operands that see this value pair. Rebuilt for every candidate.
504  SmallMapVector<Value,
505  SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
506  operandMappings;
507  SmallVector<StringAttr> candidateNames;
508 
509  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
510  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
511  if (hash.constInvariant != otherHash.constInvariant)
512  break;
513  if (!otherDefineOp)
514  continue;
515 
516  equiv.check(defineOp, otherDefineOp);
517  if (!equiv.matchConstInvariant)
518  continue;
519  candidateNames.push_back(otherDefineOp.getSymNameAttr());
520 
521  // Iterate over the matching operand pairs ("divergences"), look up the
522  // value pair the operands are set to, and then store the current arc's
523  // operand in the set that corresponds to this value pair. This builds up
524  // `operandMappings` to contain sets of op operands in the current arc
525  // that can be routed out to the same block argument. If a block argument
526  // of the current arc corresponds to multiple different things in the
527  // other arc, this ensures that all unique such combinations get grouped
528  // in distinct sets such that we can create an appropriate number of new
529  // block args.
530  operandMappings.clear();
531  for (auto [operand, otherOperand] : equiv.divergences) {
532  if (!isOutlinable(*operand) || !isOutlinable(*otherOperand))
533  continue;
534  operandMappings[operand->get()][otherOperand->get()].insert(operand);
535  }
536 
537  // Go through the sets of operands that can map to the same block argument
538  // for the combination of current and other arc. Assign all operands in
539  // each set new unique group IDs. If the operands in the set have multiple
540  // IDs, allocate multiple new unique group IDs. This fills the
541  // `outlineOperands` map with operands and their corresponding group ID.
542  // If we find multiple other arcs that we can potentially combine with the
543  // current arc, the operands get distributed into more and more smaller
544  // groups. For example, in arc A we can assign operands X and Y to the
545  // same block argument, so we assign them the same ID; but in arc B we
546  // have to assign X and Y to different block arguments, at which point
547  // that same ID we assigned earlier gets reassigned to two new IDs, one
548  // for each operand.
549  for (auto &[value, mappings] : operandMappings) {
550  for (auto &[otherValue, operands] : mappings) {
551  SmallDenseMap<unsigned, unsigned> remappedGroupIds;
552  for (auto *operand : operands) {
553  auto &id = outlineOperands[operand];
554  auto &remappedId = remappedGroupIds[id];
555  if (remappedId == 0)
556  remappedId = nextGroupId++;
557  id = remappedId;
558  }
559  }
560  }
561  }
562 
  // No candidate required any rewiring of this arc.
563  if (outlineOperands.empty())
564  continue;
565  LLVM_DEBUG({
566  llvm::dbgs() << "- Outlining " << outlineOperands.size()
567  << " operands from " << defineOp.getSymNameAttr() << "\n";
568  for (auto entry : outlineOperands)
569  llvm::dbgs() << " - Operand #" << entry.first->getOperandNumber()
570  << " of " << *entry.first->getOwner() << "\n";
571  for (auto name : candidateNames)
572  llvm::dbgs() << " - Candidate " << name << "\n";
573  });
574 
575  // Sort the operands to be outlined. The order is already deterministic at
576  // this point, but is not really correlated to the existing block argument
577  // order since we gathered these operands by traversing the operations
578  // depth-first. Establish an order that first honors existing argument order
579  // (putting constants at the back), and then considers the order of
580  // operations and op operands.
  // NOTE(review): this sorts the MapVector's entry storage in place, which
  // presumably leaves its internal key->index map stale; the code below only
  // iterates `outlineOperands`, so no lookups happen after this point.
581  llvm::stable_sort(outlineOperands, [](auto &a, auto &b) {
582  auto argA = dyn_cast<BlockArgument>(a.first->get());
583  auto argB = dyn_cast<BlockArgument>(b.first->get());
584  if (argA && !argB)
585  return true;
586  if (!argA && argB)
587  return false;
588  if (argA && argB) {
589  if (argA.getArgNumber() < argB.getArgNumber())
590  return true;
591  if (argA.getArgNumber() > argB.getArgNumber())
592  return false;
593  }
594  auto *opA = a.first->get().getDefiningOp();
595  auto *opB = b.first->get().getDefiningOp();
596  if (opA == opB)
597  return a.first->getOperandNumber() < b.first->getOperandNumber();
598  if (opA->getBlock() == opB->getBlock())
599  return opA->isBeforeInBlock(opB);
600  return false;
601  });
602 
603  // Build a new set of arc arguments by iterating over the operands that we
604  // have determined must be exposed as arguments above. For each operand
605  // either reuse its existing block argument (if no other operand in the list
606  // has already reused it), or add a new argument for this operand. Also
607  // track how each argument must be connected at call sites (outlined
608  // constant op or reusing an existing operand).
  // New arguments are appended after the old ones; the old arguments are
  // erased below once none of them has uses left.
609  unsigned oldArgumentCount = defineOp.getNumArguments();
610  SmallDenseMap<unsigned, Value> newArguments; // by group ID
611  SmallVector<Type> newInputTypes;
612  SmallVector<std::variant<Operation *, unsigned>> newOperands;
613  SmallPtrSet<Operation *, 8> outlinedOps;
614 
615  for (auto [operand, groupId] : outlineOperands) {
616  auto &arg = newArguments[groupId];
617  if (!arg) {
618  auto value = operand->get();
619  arg = defineOp.getBodyBlock().addArgument(value.getType(),
620  value.getLoc());
621  newInputTypes.push_back(arg.getType());
  // Record how call sites must drive the new argument: forward the old
  // argument by number, or clone the outlined constant op.
622  if (auto blockArg = dyn_cast<BlockArgument>(value))
623  newOperands.push_back(blockArg.getArgNumber());
624  else {
625  auto *op = value.getDefiningOp();
626  newOperands.push_back(op);
627  outlinedOps.insert(op);
628  }
629  }
630  operand->set(arg);
631  }
632 
  // Every use of an old argument must have been rerouted above.
  // NOTE(review): only uses in the terminators' fan-in become divergences, so
  // an old argument used elsewhere (e.g. by a dead op) would appear to end up
  // here and fail the pass — confirm whether that can occur.
633  for (auto arg :
634  defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
635  if (!arg.use_empty()) {
636  auto d = defineOp.emitError(
637  "dedup failed to replace all argument uses; arc ")
638  << defineOp.getSymNameAttr() << ", argument "
639  << arg.getArgNumber();
640  for (auto &use : arg.getUses())
641  d.attachNote(use.getOwner()->getLoc())
642  << "used in operand " << use.getOperandNumber() << " here";
643  return signalPassFailure();
644  }
645  }
646 
  // Drop the old arguments, update the arc's type, rewrite its existing call
  // sites to the new argument list, and erase outlined constants that became
  // unused.
647  defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
648  defineOp.setType(FunctionType::get(
649  &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
650  addCallSiteOperands(callSites[defineOp], newOperands);
651  for (auto *op : outlinedOps)
652  if (op->use_empty())
653  op->erase();
654 
655  // Perform the actual deduplication with other arcs.
656  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
657  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
658  if (hash.constInvariant != otherHash.constInvariant)
659  break;
660  if (!otherDefineOp)
661  continue;
662 
663  // Check for structural equivalence between the two arcs.
664  equiv.check(defineOp, otherDefineOp);
665  if (!equiv.matchConstInvariant)
666  continue;
667 
668  // Determine how the other arc's operands map to the arc we're trying to
669  // merge into.
  // `newOperands` is reused as the mapping for the other arc; `nullOperand`
  // marks argument slots that have not been assigned yet.
670  std::variant<Operation *, unsigned> nullOperand = nullptr;
671  for (auto &operand : newOperands)
672  operand = nullptr;
673 
674  bool mappingFailed = false;
675  for (auto [operand, otherOperand] : equiv.divergences) {
676  auto arg = dyn_cast<BlockArgument>(operand->get());
677  if (!arg || !isOutlinable(*otherOperand)) {
678  mappingFailed = true;
679  break;
680  }
681 
682  // Determine how the other arc's operand maps to the new connection
683  // scheme of the current arc.
684  std::variant<Operation *, unsigned> newOperand;
685  if (auto otherArg = dyn_cast<BlockArgument>(otherOperand->get()))
686  newOperand = otherArg.getArgNumber();
687  else
688  newOperand = otherOperand->get().getDefiningOp();
689 
690  // Ensure that there are no conflicting operand assignments.
691  auto &newOperandSlot = newOperands[arg.getArgNumber()];
692  if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
693  mappingFailed = true;
694  break;
695  }
696  newOperandSlot = newOperand;
697  }
698  if (mappingFailed) {
699  LLVM_DEBUG(llvm::dbgs() << " - Mapping failed; skipping arc\n");
700  continue;
701  }
  // Every argument of the current arc needs a source in the other arc.
702  if (llvm::any_of(newOperands,
703  [&](auto operand) { return operand == nullOperand; })) {
704  LLVM_DEBUG(llvm::dbgs()
705  << " - Not all operands mapped; skipping arc\n");
706  continue;
707  }
708 
709  // Replace all uses of the other arc with the current arc.
710  LLVM_DEBUG(llvm::dbgs()
711  << " - Merged " << defineOp.getSymNameAttr() << " <- "
712  << otherDefineOp.getSymNameAttr() << "\n");
  // Rewrite the other arc's call sites to the new argument list before they
  // are redirected to the current arc.
713  addCallSiteOperands(callSites[otherDefineOp], newOperands);
714  replaceArcWith(otherDefineOp, defineOp, symbolTable);
715  arcHashes[otherIdx].defineOp = {};
716  }
717  }
718 }
719 
720 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
721  SymbolTableCollection &symbolTable) {
722  ++dedupPassNumArcsDeduped;
723  auto oldArcOps = oldArc.getOps();
724  dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
725  auto &oldUses = callSites[oldArc];
726  auto &newUses = callSites[newArc];
727  auto newArcName = SymbolRefAttr::get(newArc.getSymNameAttr());
728  for (auto callOp : oldUses) {
729  callOp.setCalleeFromCallable(newArcName);
730  newUses.insert(callOp);
731  }
732 
733  oldArc.walk([&](mlir::CallOpInterface callOp) {
734  if (auto defOp =
735  dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
736  callSites[defOp].remove(callOp);
737  });
738  callSites.erase(oldArc);
739  arcByName.erase(oldArc.getSymNameAttr());
740  oldArc->erase();
741 }
742 
743 std::unique_ptr<Pass> arc::createDedupPass() {
744  return std::make_unique<DedupPass>();
745 }
static void addCallSiteOperands(SmallSetVector< mlir::CallOpInterface, 1 > &callSites, ArrayRef< std::variant< Operation *, unsigned >> operandMappings)
Definition: Dedup.cpp:332
static bool isOutlinable(OpOperand &operand)
Definition: Dedup.cpp:357
assert(baseType &&"element must be base type")
Builder builder
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:743
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1