CIRCT  18.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
16 
17 #define DEBUG_TYPE "arc-dedup"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_DEDUP
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 using namespace hw;
29 using llvm::SmallMapVector;
30 using llvm::SmallSetVector;
31 
32 namespace {
/// The pair of digests that `StructuralHasher` produces for one arc.
struct StructuralHash {
  /// A raw SHA-256 digest.
  using Hash = std::array<std::uint8_t, 32>;
  /// Digest over the complete structure of the arc, constant values included.
  Hash hash;
  /// Digest that leaves out constant-like ops and gives block arguments and
  /// constants the same value index, so that arcs differing only in constant
  /// values can collide.
  Hash constInvariant;
};
38 
39 struct StructuralHasher {
40  explicit StructuralHasher(MLIRContext *context) {}
41 
42  StructuralHash hash(DefineOp arc) {
43  reset();
44  update(arc);
45  return StructuralHash{state.final(), stateConstInvariant.final()};
46  }
47 
48 private:
49  void reset() {
50  currentIndex = 0;
51  disableConstInvariant = 0;
52  indices.clear();
53  indicesConstInvariant.clear();
54  state.init();
55  stateConstInvariant.init();
56  }
57 
58  void update(const void *pointer) {
59  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
60  state.update(ArrayRef<uint8_t>(addr, sizeof pointer));
61  if (disableConstInvariant == 0)
62  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof pointer));
63  }
64 
65  void update(size_t value) {
66  auto *addr = reinterpret_cast<const uint8_t *>(&value);
67  state.update(ArrayRef<uint8_t>(addr, sizeof value));
68  if (disableConstInvariant == 0)
69  stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof value));
70  }
71 
72  void update(size_t value, size_t valueConstInvariant) {
73  state.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&value),
74  sizeof value));
75  state.update(ArrayRef<uint8_t>(
76  reinterpret_cast<const uint8_t *>(&valueConstInvariant),
77  sizeof valueConstInvariant));
78  }
79 
80  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
81 
82  void update(Type type) { update(type.getAsOpaquePointer()); }
83 
84  void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
85 
86  void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
87 
88  void update(BlockArgument arg) { update(arg.getType()); }
89 
90  void update(OpResult result) { update(result.getType()); }
91 
92  void update(OpOperand &operand) {
93  // We hash the value's index as it apears in the block.
94  auto it = indices.find(operand.get());
95  auto itCI = indicesConstInvariant.find(operand.get());
96  assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
97  "op should have been previously hashed");
98  update(it->second, itCI->second);
99  }
100 
101  void update(Block &block) {
102  // Assign integer numbers to block arguments and op results. For the const-
103  // invariant hash, assign a zero to block args and constant ops, such that
104  // they hash as the same.
105  for (auto arg : block.getArguments()) {
106  indices.insert({arg, currentIndex++});
107  indicesConstInvariant.insert({arg, 0});
108  }
109  for (auto &op : block) {
110  for (auto result : op.getResults()) {
111  indices.insert({result, currentIndex++});
112  if (op.hasTrait<OpTrait::ConstantLike>())
113  indicesConstInvariant.insert({result, 0});
114  else
115  indicesConstInvariant.insert({result, currentIndexConstInvariant++});
116  }
117  }
118 
119  // Hash the block arguments for the const-invariant hash.
120  ++disableConstInvariant;
121  for (auto arg : block.getArguments())
122  update(arg);
123  --disableConstInvariant;
124 
125  // Hash the operations.
126  for (auto &op : block)
127  update(&op);
128  }
129 
130  void update(Operation *op) {
131  unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
132  disableConstInvariant += skipConstInvariant;
133 
134  update(op->getName());
135 
136  // Hash the attributes. (Excluded in constant invariant hash.)
137  if (!isa<DefineOp>(op)) {
138  for (auto namedAttr : op->getAttrDictionary()) {
139  auto name = namedAttr.getName();
140  auto value = namedAttr.getValue();
141 
142  // Hash the interned pointer.
143  update(name.getAsOpaquePointer());
144  update(value.getAsOpaquePointer());
145  }
146  }
147 
148  // Hash the operands.
149  for (auto &operand : op->getOpOperands())
150  update(operand);
151  // Hash the regions. We need to make sure an empty region doesn't hash
152  // the same as no region, so we include the number of regions.
153  update(op->getNumRegions());
154  for (auto &region : op->getRegions())
155  for (auto &block : region.getBlocks())
156  update(block);
157  // Record any op results.
158  for (auto result : op->getResults())
159  update(result);
160 
161  disableConstInvariant -= skipConstInvariant;
162  }
163 
164  // Every value is assigned a unique id based on their order of appearance.
165  unsigned currentIndex = 0;
166  unsigned currentIndexConstInvariant = 0;
167  DenseMap<Value, unsigned> indices;
168  DenseMap<Value, unsigned> indicesConstInvariant;
169 
170  unsigned disableConstInvariant = 0;
171 
172  // This is the actual running hash calculation. This is a stateful element
173  // that should be reinitialized after each hash is produced.
174  llvm::SHA256 state;
175  llvm::SHA256 stateConstInvariant;
176 };
177 } // namespace
178 
179 namespace {
180 struct StructuralEquivalence {
181  using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
182  explicit StructuralEquivalence(MLIRContext *context) {}
183 
184  void check(DefineOp arcA, DefineOp arcB) {
185  if (!checkImpl(arcA, arcB)) {
186  match = false;
187  matchConstInvariant = false;
188  }
189  }
190 
191  SmallSetVector<OpOperandPair, 1> divergences;
192  bool match;
193  bool matchConstInvariant;
194 
195 private:
196  bool addBlockToWorklist(Block &blockA, Block &blockB) {
197  auto *terminatorA = blockA.getTerminator();
198  auto *terminatorB = blockB.getTerminator();
199  if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
200  return false;
201  if (!addOpToWorklist(terminatorA, terminatorB))
202  return false;
203  // TODO: We should probably bail out if there are any operations in the
204  // block that aren't in the fan-in of the terminator.
205  return true;
206  }
207 
208  bool addOpToWorklist(Operation *opA, Operation *opB,
209  bool *allOperandsHandled = nullptr) {
210  if (opA->getNumOperands() != opB->getNumOperands())
211  return false;
212  for (auto [operandA, operandB] :
213  llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
214  if (!handled.count({&operandA, &operandB})) {
215  worklist.emplace_back(&operandA, &operandB);
216  if (allOperandsHandled)
217  *allOperandsHandled = false;
218  }
219  }
220  return true;
221  }
222 
223  bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
224  if (opA->getName() != opB->getName())
225  return false;
226  if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
227  for (auto [namedAttrA, namedAttrB] :
228  llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
229  if (namedAttrA.getName() != namedAttrB.getName())
230  return false;
231  if (namedAttrA.getValue() == namedAttrB.getValue())
232  continue;
233  bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
234  if (!mayDiverge || !values.first || !values.second)
235  return false;
236  divergences.insert(values);
237  match = false;
238  break;
239  }
240  }
241  return true;
242  }
243 
244  bool checkImpl(DefineOp arcA, DefineOp arcB) {
245  worklist.clear();
246  divergences.clear();
247  match = true;
248  matchConstInvariant = true;
249  handled.clear();
250 
251  if (arcA.getFunctionType().getResults() !=
252  arcB.getFunctionType().getResults())
253  return false;
254 
255  if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
256  return false;
257 
258  while (!worklist.empty()) {
259  OpOperandPair values = worklist.back();
260  if (handled.contains(values)) {
261  worklist.pop_back();
262  continue;
263  }
264 
265  auto valueA = values.first->get();
266  auto valueB = values.second->get();
267  if (valueA.getType() != valueB.getType())
268  return false;
269  auto *opA = valueA.getDefiningOp();
270  auto *opB = valueB.getDefiningOp();
271 
272  // Handle the case where one or both values are block arguments.
273  if (!opA || !opB) {
274  auto argA = valueA.dyn_cast<BlockArgument>();
275  auto argB = valueB.dyn_cast<BlockArgument>();
276  if (argA && argB) {
277  divergences.insert(values);
278  if (argA.getArgNumber() != argB.getArgNumber())
279  match = false;
280  handled.insert(values);
281  worklist.pop_back();
282  continue;
283  }
284  auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
285  auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
286  if ((argA && isConstB) || (argB && isConstA)) {
287  // One value is a block argument, one is a constant.
288  divergences.insert(values);
289  match = false;
290  handled.insert(values);
291  worklist.pop_back();
292  continue;
293  }
294  return false;
295  }
296 
297  // Go through all operands push the ones we haven't visited yet onto the
298  // worklist so they get processed before we continue.
299  bool allHandled = true;
300  if (!addOpToWorklist(opA, opB, &allHandled))
301  return false;
302  if (!allHandled)
303  continue;
304  handled.insert(values);
305  worklist.pop_back();
306 
307  // Compare the two operations and check that they are equal.
308  if (!compareOps(opA, opB, values))
309  return false;
310 
311  // Descend into subregions of the operation.
312  if (opA->getNumRegions() != opB->getNumRegions())
313  return false;
314  for (auto [regionA, regionB] :
315  llvm::zip(opA->getRegions(), opB->getRegions())) {
316  if (regionA.getBlocks().size() != regionB.getBlocks().size())
317  return false;
318  for (auto [blockA, blockB] : llvm::zip(regionA, regionB))
319  if (!addBlockToWorklist(blockA, blockB))
320  return false;
321  }
322  }
323 
324  return true;
325  }
326 
327  SmallVector<OpOperandPair, 0> worklist;
328  DenseSet<OpOperandPair> handled;
329 };
330 } // namespace
331 
333  MutableArrayRef<mlir::CallOpInterface> callSites,
334  ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
336  SmallVector<Value> newOperands;
337  for (auto &callOp : callSites) {
338  OpBuilder builder(callOp);
339  newOperands.clear();
340  clonedOps.clear();
341  for (auto mapping : operandMappings) {
342  if (std::holds_alternative<Operation *>(mapping)) {
343  auto *op = std::get<Operation *>(mapping);
344  auto &newOp = clonedOps[op];
345  if (!newOp)
346  newOp = builder.clone(*op);
347  newOperands.push_back(newOp->getResult(0));
348  } else {
349  newOperands.push_back(
350  callOp.getArgOperands()[std::get<unsigned>(mapping)]);
351  }
352  }
353  callOp.getArgOperandsMutable().assign(newOperands);
354  }
355 }
356 
357 static bool isOutlinable(OpOperand &operand) {
358  auto *op = operand.get().getDefiningOp();
359  return !op || op->hasTrait<OpTrait::ConstantLike>();
360 }
361 
362 namespace {
363 struct DedupPass : public arc::impl::DedupBase<DedupPass> {
364  void runOnOperation() override;
365  void replaceArcWith(DefineOp oldArc, DefineOp newArc);
366 
367  /// A mapping from arc names to arc definitions.
368  DenseMap<StringAttr, DefineOp> arcByName;
369  /// A mapping from arc definitions to call sites.
370  DenseMap<DefineOp, SmallVector<mlir::CallOpInterface, 1>> callSites;
371 };
372 
373 struct ArcHash {
374  DefineOp defineOp;
375  StructuralHash hash;
376  unsigned order;
377  ArcHash(DefineOp defineOp, StructuralHash hash, unsigned order)
378  : defineOp(defineOp), hash(hash), order(order) {}
379 };
380 } // namespace
381 
382 void DedupPass::runOnOperation() {
383  arcByName.clear();
384  callSites.clear();
385  SymbolTableCollection symbolTable;
386 
387  // Compute the structural hash for each arc definition.
388  SmallVector<ArcHash> arcHashes;
389  StructuralHasher hasher(&getContext());
390  for (auto defineOp : getOperation().getOps<DefineOp>()) {
391  arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
392  arcByName.insert({defineOp.getSymNameAttr(), defineOp});
393  }
394 
395  // Collect the arc call sites.
396  getOperation().walk([&](mlir::CallOpInterface callOp) {
397  if (auto defOp =
398  dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
399  callSites[arcByName.lookup(callOp.getCallableForCallee()
400  .get<mlir::SymbolRefAttr>()
401  .getLeafReference())]
402  .push_back(callOp);
403  });
404 
405  // Sort the arcs by hash such that arcs with the same hash are next to each
406  // other, and sort arcs with the same hash by order in which they appear in
407  // the input. This allows us to iterate through the list and check
408  // neighbouring arcs for merge opportunities.
409  llvm::stable_sort(arcHashes, [](auto a, auto b) {
410  if (a.hash.hash < b.hash.hash)
411  return true;
412  if (a.hash.hash > b.hash.hash)
413  return false;
414  return a.order < b.order;
415  });
416 
417  // Perform deduplications that do not require modification of the arc call
418  // sites. (No additional ports.)
419  LLVM_DEBUG(llvm::dbgs() << "Check for exact merges (" << arcHashes.size()
420  << " arcs)\n");
421  StructuralEquivalence equiv(&getContext());
422  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
423  ++arcIdx) {
424  auto [defineOp, hash, order] = arcHashes[arcIdx];
425  if (!defineOp)
426  continue;
427  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
428  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
429  if (hash.hash != otherHash.hash)
430  break;
431  if (!otherDefineOp)
432  continue;
433  equiv.check(defineOp, otherDefineOp);
434  if (!equiv.match)
435  continue;
436  LLVM_DEBUG(llvm::dbgs()
437  << "- Merge " << defineOp.getSymNameAttr() << " <- "
438  << otherDefineOp.getSymNameAttr() << "\n");
439  replaceArcWith(otherDefineOp, defineOp);
440  arcHashes[otherIdx].defineOp = {};
441  }
442  }
443 
444  // The initial pass over the arcs has set the `defineOp` to null for every arc
445  // that was already merged. Now sort the list of arcs as follows:
446  // - All merged arcs are moved to the back of the list (`!defineOp`)
447  // - Sort unmerged arcs by const-invariant hash
448  // - Sort arcs with same hash by order in which they appear in the input
449  // This allows us to pop the merged arcs off of the back of the list. Then we
450  // can iterate through the list and check neighbouring arcs for merge
451  // opportunities.
452  llvm::stable_sort(arcHashes, [](auto a, auto b) {
453  if (!a.defineOp && !b.defineOp)
454  return false;
455  if (!a.defineOp)
456  return false;
457  if (!b.defineOp)
458  return true;
459  if (a.hash.constInvariant < b.hash.constInvariant)
460  return true;
461  if (a.hash.constInvariant > b.hash.constInvariant)
462  return false;
463  return a.order < b.order;
464  });
465  while (!arcHashes.empty() && !arcHashes.back().defineOp)
466  arcHashes.pop_back();
467 
468  // Perform deduplication of arcs that differ only in constant values.
469  LLVM_DEBUG(llvm::dbgs() << "Check for constant-agnostic merges ("
470  << arcHashes.size() << " arcs)\n");
471  for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
472  ++arcIdx) {
473  auto [defineOp, hash, order] = arcHashes[arcIdx];
474  if (!defineOp)
475  continue;
476 
477  // Perform an initial pass over all other arcs with identical
478  // const-invariant hash. Check for equivalence between the current arc
479  // (`defineOp`) and the other arc (`otherDefineOp`). In case they match
480  // iterate over the list of divergences which holds all non-identical
481  // OpOperand pairs in the two arcs. These can come in different forms:
482  //
483  // - (const, const): Both arcs have the operand set to a constant, but the
484  // constant value differs. We'll want to extract these constants.
485  // - (arg, const): The current arc has a block argument where the other has
486  // a constant. No changes needed; when we replace the uses of the other
487  // arc with the current one further done we can use the existing
488  // argument to pass in that constant.
489  // - (const, arg): The current arc has a constant where the other has a
490  // block argument. We'll want to extract this constant and replace it
491  // with a block argument. This will allow the other arc to be replaced
492  // with the current one.
493  // - (arg, arg): Both arcs have the operand set to a block argument, but
494  // they are different argument numbers. This can happen if for example
495  // one of the arcs uses a single argument in two op operands and the
496  // other arc has two separate arguments for the two op operands. We'll
497  // want to ensure the current arc has two arguments in this case, such
498  // that the two can dedup.
499  //
500  // Whenever an op operand is involved in such a divergence we add it to the
501  // list of operands that must be mapped to a distinct block argument. Later
502  // we'll go through this list and add additional block arguments as
503  // necessary.
504  SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
505  unsigned nextGroupId = 1;
506  SmallMapVector<Value,
507  SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
508  operandMappings;
509  SmallVector<StringAttr> candidateNames;
510 
511  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
512  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
513  if (hash.constInvariant != otherHash.constInvariant)
514  break;
515  if (!otherDefineOp)
516  continue;
517 
518  equiv.check(defineOp, otherDefineOp);
519  if (!equiv.matchConstInvariant)
520  continue;
521  candidateNames.push_back(otherDefineOp.getSymNameAttr());
522 
523  // Iterate over the matching operand pairs ("divergences"), look up the
524  // value pair the operands are set to, and then store the current arc's
525  // operand in the set that corresponds to this value pair. This builds up
526  // `operandMappings` to contain sets of op operands in the current arc
527  // that can be routed out to the same block argument. If a block argument
528  // of the current arc corresponds to multiple different things in the
529  // other arc, this ensures that all unique such combinations get grouped
530  // in distinct sets such that we can create an appropriate number of new
531  // block args.
532  operandMappings.clear();
533  for (auto [operand, otherOperand] : equiv.divergences) {
534  if (!isOutlinable(*operand) || !isOutlinable(*otherOperand))
535  continue;
536  operandMappings[operand->get()][otherOperand->get()].insert(operand);
537  }
538 
539  // Go through the sets of operands that can map to the same block argument
540  // for the combination of current and other arc. Assign all operands in
541  // each set new unique group IDs. If the operands in the set have multiple
542  // IDs, allocate multiple new unique group IDs. This fills the
543  // `outlineOperands` map with operands and their corresponding group ID.
544  // If we find multiple other arcs that we can potentially combine with the
545  // current arc, the operands get distributed into more and more smaller
546  // groups. For example, in arc A we can assign operands X and Y to the
547  // same block argument, so we assign them the same ID; but in arc B we
548  // have to assign X and Y to different block arguments, at which point
549  // that same ID we assigned earlier gets reassigned to two new IDs, one
550  // for each operand.
551  for (auto &[value, mappings] : operandMappings) {
552  for (auto &[otherValue, operands] : mappings) {
553  SmallDenseMap<unsigned, unsigned> remappedGroupIds;
554  for (auto *operand : operands) {
555  auto &id = outlineOperands[operand];
556  auto &remappedId = remappedGroupIds[id];
557  if (remappedId == 0)
558  remappedId = nextGroupId++;
559  id = remappedId;
560  }
561  }
562  }
563  }
564 
565  if (outlineOperands.empty())
566  continue;
567  LLVM_DEBUG({
568  llvm::dbgs() << "- Outlining " << outlineOperands.size()
569  << " operands from " << defineOp.getSymNameAttr() << "\n";
570  for (auto entry : outlineOperands)
571  llvm::dbgs() << " - Operand #" << entry.first->getOperandNumber()
572  << " of " << *entry.first->getOwner() << "\n";
573  for (auto name : candidateNames)
574  llvm::dbgs() << " - Candidate " << name << "\n";
575  });
576 
577  // Sort the operands to be outlined. The order is already deterministic at
578  // this point, but is not really correlated to the existing block argument
579  // order since we gathered these operands by traversing the operations
580  // depth-first. Establish an order that first honors existing argument order
581  // (putting constants at the back), and then considers the order of
582  // operations and op operands.
583  llvm::stable_sort(outlineOperands, [](auto &a, auto &b) {
584  auto argA = a.first->get().template dyn_cast<BlockArgument>();
585  auto argB = b.first->get().template dyn_cast<BlockArgument>();
586  if (argA && !argB)
587  return true;
588  if (!argA && argB)
589  return false;
590  if (argA && argB) {
591  if (argA.getArgNumber() < argB.getArgNumber())
592  return true;
593  if (argA.getArgNumber() > argB.getArgNumber())
594  return false;
595  }
596  auto *opA = a.first->get().getDefiningOp();
597  auto *opB = b.first->get().getDefiningOp();
598  if (opA == opB)
599  return a.first->getOperandNumber() < b.first->getOperandNumber();
600  if (opA->getBlock() == opB->getBlock())
601  return opA->isBeforeInBlock(opB);
602  return false;
603  });
604 
605  // Build a new set of arc arguments by iterating over the operands that we
606  // have determined must be exposed as arguments above. For each operand
607  // either reuse its existing block argument (if no other operand in the list
608  // has already reused it), or add a new argument for this operand. Also
609  // track how each argument must be connected at call sites (outlined
610  // constant op or reusing an existing operand).
611  unsigned oldArgumentCount = defineOp.getNumArguments();
612  SmallDenseMap<unsigned, Value> newArguments; // by group ID
613  SmallVector<Type> newInputTypes;
614  SmallVector<std::variant<Operation *, unsigned>> newOperands;
615  SmallPtrSet<Operation *, 8> outlinedOps;
616 
617  for (auto [operand, groupId] : outlineOperands) {
618  auto &arg = newArguments[groupId];
619  if (!arg) {
620  auto value = operand->get();
621  arg = defineOp.getBodyBlock().addArgument(value.getType(),
622  value.getLoc());
623  newInputTypes.push_back(arg.getType());
624  if (auto blockArg = value.dyn_cast<BlockArgument>())
625  newOperands.push_back(blockArg.getArgNumber());
626  else {
627  auto *op = value.getDefiningOp();
628  newOperands.push_back(op);
629  outlinedOps.insert(op);
630  }
631  }
632  operand->set(arg);
633  }
634 
635  for (auto arg :
636  defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
637  if (!arg.use_empty()) {
638  auto d = defineOp.emitError(
639  "dedup failed to replace all argument uses; arc ")
640  << defineOp.getSymNameAttr() << ", argument "
641  << arg.getArgNumber();
642  for (auto &use : arg.getUses())
643  d.attachNote(use.getOwner()->getLoc())
644  << "used in operand " << use.getOperandNumber() << " here";
645  return signalPassFailure();
646  }
647  }
648 
649  defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
650  defineOp.setType(FunctionType::get(
651  &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
652  addCallSiteOperands(callSites[defineOp], newOperands);
653  for (auto *op : outlinedOps)
654  if (op->use_empty())
655  op->erase();
656 
657  // Perform the actual deduplication with other arcs.
658  for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
659  auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
660  if (hash.constInvariant != otherHash.constInvariant)
661  break;
662  if (!otherDefineOp)
663  continue;
664 
665  // Check for structural equivalence between the two arcs.
666  equiv.check(defineOp, otherDefineOp);
667  if (!equiv.matchConstInvariant)
668  continue;
669 
670  // Determine how the other arc's operands map to the arc we're trying to
671  // merge into.
672  std::variant<Operation *, unsigned> nullOperand = nullptr;
673  for (auto &operand : newOperands)
674  operand = nullptr;
675 
676  bool mappingFailed = false;
677  for (auto [operand, otherOperand] : equiv.divergences) {
678  auto arg = operand->get().dyn_cast<BlockArgument>();
679  if (!arg || !isOutlinable(*otherOperand)) {
680  mappingFailed = true;
681  break;
682  }
683 
684  // Determine how the other arc's operand maps to the new connection
685  // scheme of the current arc.
686  std::variant<Operation *, unsigned> newOperand;
687  if (auto otherArg = otherOperand->get().dyn_cast<BlockArgument>())
688  newOperand = otherArg.getArgNumber();
689  else
690  newOperand = otherOperand->get().getDefiningOp();
691 
692  // Ensure that there are no conflicting operand assignment.
693  auto &newOperandSlot = newOperands[arg.getArgNumber()];
694  if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
695  mappingFailed = true;
696  break;
697  }
698  newOperandSlot = newOperand;
699  }
700  if (mappingFailed) {
701  LLVM_DEBUG(llvm::dbgs() << " - Mapping failed; skipping arc\n");
702  continue;
703  }
704  if (llvm::any_of(newOperands,
705  [&](auto operand) { return operand == nullOperand; })) {
706  LLVM_DEBUG(llvm::dbgs()
707  << " - Not all operands mapped; skipping arc\n");
708  continue;
709  }
710 
711  // Replace all uses of the other arc with the current arc.
712  LLVM_DEBUG(llvm::dbgs()
713  << " - Merged " << defineOp.getSymNameAttr() << " <- "
714  << otherDefineOp.getSymNameAttr() << "\n");
715  addCallSiteOperands(callSites[otherDefineOp], newOperands);
716  replaceArcWith(otherDefineOp, defineOp);
717  arcHashes[otherIdx].defineOp = {};
718  }
719  }
720 }
721 
722 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc) {
723  ++dedupPassNumArcsDeduped;
724  auto oldArcOps = oldArc.getOps();
725  dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
726  auto &oldUses = callSites[oldArc];
727  auto &newUses = callSites[newArc];
728  auto newArcName = SymbolRefAttr::get(newArc.getSymNameAttr());
729  for (auto callOp : oldUses) {
730  callOp.setCalleeFromCallable(newArcName);
731  newUses.push_back(callOp);
732  }
733  callSites.erase(oldArc);
734  arcByName.erase(oldArc.getSymNameAttr());
735  oldArc->erase();
736 }
737 
738 std::unique_ptr<Pass> arc::createDedupPass() {
739  return std::make_unique<DedupPass>();
740 }
static void addCallSiteOperands(MutableArrayRef< mlir::CallOpInterface > callSites, ArrayRef< std::variant< Operation *, unsigned >> operandMappings)
Definition: Dedup.cpp:332
static bool isOutlinable(OpOperand &operand)
Definition: Dedup.cpp:357
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
Builder builder
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:738
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
Definition: hw.py:1
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28