CIRCT 20.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp ----------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/Pass/Pass.h"
13#include "llvm/ADT/SetVector.h"
14#include "llvm/Support/Debug.h"
15#include "llvm/Support/SHA256.h"
16#include <variant>
17
18#define DEBUG_TYPE "arc-dedup"
19
20namespace circt {
21namespace arc {
22#define GEN_PASS_DEF_DEDUP
23#include "circt/Dialect/Arc/ArcPasses.h.inc"
24} // namespace arc
25} // namespace circt
26
27using namespace circt;
28using namespace arc;
29using namespace hw;
30using llvm::SmallMapVector;
31using llvm::SmallSetVector;
32
33namespace {
34struct StructuralHash {
35 using Hash = std::array<uint8_t, 32>;
36 Hash hash;
37 Hash constInvariant; // a hash that ignores constants
38};
39
40struct StructuralHasher {
41 explicit StructuralHasher(MLIRContext *context) {}
42
43 StructuralHash hash(DefineOp arc) {
44 reset();
45 update(arc);
46 return StructuralHash{state.final(), stateConstInvariant.final()};
47 }
48
49private:
50 void reset() {
51 currentIndex = 0;
52 disableConstInvariant = 0;
53 indices.clear();
54 indicesConstInvariant.clear();
55 state.init();
56 stateConstInvariant.init();
57 }
58
59 void update(const void *pointer) {
60 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
61 state.update(ArrayRef<uint8_t>(addr, sizeof pointer));
62 if (disableConstInvariant == 0)
63 stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof pointer));
64 }
65
66 void update(size_t value) {
67 auto *addr = reinterpret_cast<const uint8_t *>(&value);
68 state.update(ArrayRef<uint8_t>(addr, sizeof value));
69 if (disableConstInvariant == 0)
70 stateConstInvariant.update(ArrayRef<uint8_t>(addr, sizeof value));
71 }
72
73 void update(size_t value, size_t valueConstInvariant) {
74 state.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&value),
75 sizeof value));
76 state.update(ArrayRef<uint8_t>(
77 reinterpret_cast<const uint8_t *>(&valueConstInvariant),
78 sizeof valueConstInvariant));
79 }
80
81 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
82
83 void update(Type type) { update(type.getAsOpaquePointer()); }
84
85 void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
86
87 void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
88
89 void update(BlockArgument arg) { update(arg.getType()); }
90
91 void update(OpResult result) { update(result.getType()); }
92
93 void update(OpOperand &operand) {
94 // We hash the value's index as it apears in the block.
95 auto it = indices.find(operand.get());
96 auto itCI = indicesConstInvariant.find(operand.get());
97 assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
98 "op should have been previously hashed");
99 update(it->second, itCI->second);
100 }
101
102 void update(Block &block) {
103 // Assign integer numbers to block arguments and op results. For the const-
104 // invariant hash, assign a zero to block args and constant ops, such that
105 // they hash as the same.
106 for (auto arg : block.getArguments()) {
107 indices.insert({arg, currentIndex++});
108 indicesConstInvariant.insert({arg, 0});
109 }
110 for (auto &op : block) {
111 for (auto result : op.getResults()) {
112 indices.insert({result, currentIndex++});
113 if (op.hasTrait<OpTrait::ConstantLike>())
114 indicesConstInvariant.insert({result, 0});
115 else
116 indicesConstInvariant.insert({result, currentIndexConstInvariant++});
117 }
118 }
119
120 // Hash the block arguments for the const-invariant hash.
121 ++disableConstInvariant;
122 for (auto arg : block.getArguments())
123 update(arg);
124 --disableConstInvariant;
125
126 // Hash the operations.
127 for (auto &op : block)
128 update(&op);
129 }
130
131 void update(Operation *op) {
132 unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
133 disableConstInvariant += skipConstInvariant;
134
135 update(op->getName());
136
137 // Hash the attributes. (Excluded in constant invariant hash.)
138 if (!isa<DefineOp>(op)) {
139 for (auto namedAttr : op->getAttrDictionary()) {
140 auto name = namedAttr.getName();
141 auto value = namedAttr.getValue();
142
143 // Hash the interned pointer.
144 update(name.getAsOpaquePointer());
145 update(value.getAsOpaquePointer());
146 }
147 }
148
149 // Hash the operands.
150 for (auto &operand : op->getOpOperands())
151 update(operand);
152 // Hash the regions. We need to make sure an empty region doesn't hash
153 // the same as no region, so we include the number of regions.
154 update(op->getNumRegions());
155 for (auto &region : op->getRegions())
156 for (auto &block : region.getBlocks())
157 update(block);
158 // Record any op results.
159 for (auto result : op->getResults())
160 update(result);
161
162 disableConstInvariant -= skipConstInvariant;
163 }
164
165 // Every value is assigned a unique id based on their order of appearance.
166 unsigned currentIndex = 0;
167 unsigned currentIndexConstInvariant = 0;
168 DenseMap<Value, unsigned> indices;
169 DenseMap<Value, unsigned> indicesConstInvariant;
170
171 unsigned disableConstInvariant = 0;
172
173 // This is the actual running hash calculation. This is a stateful element
174 // that should be reinitialized after each hash is produced.
175 llvm::SHA256 state;
176 llvm::SHA256 stateConstInvariant;
177};
178} // namespace
179
180namespace {
181struct StructuralEquivalence {
182 using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
183 explicit StructuralEquivalence(MLIRContext *context) {}
184
185 void check(DefineOp arcA, DefineOp arcB) {
186 if (!checkImpl(arcA, arcB)) {
187 match = false;
188 matchConstInvariant = false;
189 }
190 }
191
192 SmallSetVector<OpOperandPair, 1> divergences;
193 bool match;
194 bool matchConstInvariant;
195
196private:
197 bool addBlockToWorklist(Block &blockA, Block &blockB) {
198 auto *terminatorA = blockA.getTerminator();
199 auto *terminatorB = blockB.getTerminator();
200 if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
201 return false;
202 if (!addOpToWorklist(terminatorA, terminatorB))
203 return false;
204 // TODO: We should probably bail out if there are any operations in the
205 // block that aren't in the fan-in of the terminator.
206 return true;
207 }
208
209 bool addOpToWorklist(Operation *opA, Operation *opB,
210 bool *allOperandsHandled = nullptr) {
211 if (opA->getNumOperands() != opB->getNumOperands())
212 return false;
213 for (auto [operandA, operandB] :
214 llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
215 if (!handled.count({&operandA, &operandB})) {
216 worklist.emplace_back(&operandA, &operandB);
217 if (allOperandsHandled)
218 *allOperandsHandled = false;
219 }
220 }
221 return true;
222 }
223
224 bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
225 if (opA->getName() != opB->getName())
226 return false;
227 if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
228 for (auto [namedAttrA, namedAttrB] :
229 llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
230 if (namedAttrA.getName() != namedAttrB.getName())
231 return false;
232 if (namedAttrA.getValue() == namedAttrB.getValue())
233 continue;
234 bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
235 if (!mayDiverge || !values.first || !values.second)
236 return false;
237 divergences.insert(values);
238 match = false;
239 break;
240 }
241 }
242 return true;
243 }
244
245 bool checkImpl(DefineOp arcA, DefineOp arcB) {
246 worklist.clear();
247 divergences.clear();
248 match = true;
249 matchConstInvariant = true;
250 handled.clear();
251
252 if (arcA.getFunctionType().getResults() !=
253 arcB.getFunctionType().getResults())
254 return false;
255
256 if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
257 return false;
258
259 while (!worklist.empty()) {
260 OpOperandPair values = worklist.back();
261 if (handled.contains(values)) {
262 worklist.pop_back();
263 continue;
264 }
265
266 auto valueA = values.first->get();
267 auto valueB = values.second->get();
268 if (valueA.getType() != valueB.getType())
269 return false;
270 auto *opA = valueA.getDefiningOp();
271 auto *opB = valueB.getDefiningOp();
272
273 // Handle the case where one or both values are block arguments.
274 if (!opA || !opB) {
275 auto argA = dyn_cast<BlockArgument>(valueA);
276 auto argB = dyn_cast<BlockArgument>(valueB);
277 if (argA && argB) {
278 divergences.insert(values);
279 if (argA.getArgNumber() != argB.getArgNumber())
280 match = false;
281 handled.insert(values);
282 worklist.pop_back();
283 continue;
284 }
285 auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
286 auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
287 if ((argA && isConstB) || (argB && isConstA)) {
288 // One value is a block argument, one is a constant.
289 divergences.insert(values);
290 match = false;
291 handled.insert(values);
292 worklist.pop_back();
293 continue;
294 }
295 return false;
296 }
297
298 // Go through all operands push the ones we haven't visited yet onto the
299 // worklist so they get processed before we continue.
300 bool allHandled = true;
301 if (!addOpToWorklist(opA, opB, &allHandled))
302 return false;
303 if (!allHandled)
304 continue;
305 handled.insert(values);
306 worklist.pop_back();
307
308 // Compare the two operations and check that they are equal.
309 if (!compareOps(opA, opB, values))
310 return false;
311
312 // Descend into subregions of the operation.
313 if (opA->getNumRegions() != opB->getNumRegions())
314 return false;
315 for (auto [regionA, regionB] :
316 llvm::zip(opA->getRegions(), opB->getRegions())) {
317 if (regionA.getBlocks().size() != regionB.getBlocks().size())
318 return false;
319 for (auto [blockA, blockB] : llvm::zip(regionA, regionB))
320 if (!addBlockToWorklist(blockA, blockB))
321 return false;
322 }
323 }
324
325 return true;
326 }
327
328 SmallVector<OpOperandPair, 0> worklist;
329 DenseSet<OpOperandPair> handled;
330};
331} // namespace
332
334 SmallSetVector<mlir::CallOpInterface, 1> &callSites,
335 ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
337 SmallVector<Value> newOperands;
338 for (auto callOp : callSites) {
339 OpBuilder builder(callOp);
340 newOperands.clear();
341 clonedOps.clear();
342 for (auto mapping : operandMappings) {
343 if (std::holds_alternative<Operation *>(mapping)) {
344 auto *op = std::get<Operation *>(mapping);
345 auto &newOp = clonedOps[op];
346 if (!newOp)
347 newOp = builder.clone(*op);
348 newOperands.push_back(newOp->getResult(0));
349 } else {
350 newOperands.push_back(
351 callOp.getArgOperands()[std::get<unsigned>(mapping)]);
352 }
353 }
354 callOp.getArgOperandsMutable().assign(newOperands);
355 }
356}
357
358static bool isOutlinable(OpOperand &operand) {
359 auto *op = operand.get().getDefiningOp();
360 return !op || op->hasTrait<OpTrait::ConstantLike>();
361}
362
363namespace {
364struct DedupPass : public arc::impl::DedupBase<DedupPass> {
365 void runOnOperation() override;
366 void replaceArcWith(DefineOp oldArc, DefineOp newArc,
367 SymbolTableCollection &symbolTable);
368
369 /// A mapping from arc names to arc definitions.
370 DenseMap<StringAttr, DefineOp> arcByName;
371 /// A mapping from arc definitions to call sites.
372 DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
373};
374
375struct ArcHash {
376 DefineOp defineOp;
377 StructuralHash hash;
378 unsigned order;
379 ArcHash(DefineOp defineOp, StructuralHash hash, unsigned order)
380 : defineOp(defineOp), hash(hash), order(order) {}
381};
382} // namespace
383
384void DedupPass::runOnOperation() {
385 arcByName.clear();
386 callSites.clear();
387 SymbolTableCollection symbolTable;
388
389 // Compute the structural hash for each arc definition.
390 SmallVector<ArcHash> arcHashes;
391 StructuralHasher hasher(&getContext());
392 for (auto defineOp : getOperation().getOps<DefineOp>()) {
393 arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
394 arcByName.insert({defineOp.getSymNameAttr(), defineOp});
395 }
396
397 // Collect the arc call sites.
398 getOperation().walk([&](mlir::CallOpInterface callOp) {
399 if (auto defOp = dyn_cast_or_null<DefineOp>(
400 callOp.resolveCallableInTable(&symbolTable)))
401 callSites[defOp].insert(callOp);
402 });
403
404 // Sort the arcs by hash such that arcs with the same hash are next to each
405 // other, and sort arcs with the same hash by order in which they appear in
406 // the input. This allows us to iterate through the list and check
407 // neighbouring arcs for merge opportunities.
408 llvm::stable_sort(arcHashes, [](auto a, auto b) {
409 if (a.hash.hash < b.hash.hash)
410 return true;
411 if (a.hash.hash > b.hash.hash)
412 return false;
413 return a.order < b.order;
414 });
415
416 // Perform deduplications that do not require modification of the arc call
417 // sites. (No additional ports.)
418 LLVM_DEBUG(llvm::dbgs() << "Check for exact merges (" << arcHashes.size()
419 << " arcs)\n");
420 StructuralEquivalence equiv(&getContext());
421 for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
422 ++arcIdx) {
423 auto [defineOp, hash, order] = arcHashes[arcIdx];
424 if (!defineOp)
425 continue;
426 for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
427 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
428 if (hash.hash != otherHash.hash)
429 break;
430 if (!otherDefineOp)
431 continue;
432 equiv.check(defineOp, otherDefineOp);
433 if (!equiv.match)
434 continue;
435 LLVM_DEBUG(llvm::dbgs()
436 << "- Merge " << defineOp.getSymNameAttr() << " <- "
437 << otherDefineOp.getSymNameAttr() << "\n");
438 replaceArcWith(otherDefineOp, defineOp, symbolTable);
439 arcHashes[otherIdx].defineOp = {};
440 }
441 }
442
443 // The initial pass over the arcs has set the `defineOp` to null for every arc
444 // that was already merged. Now sort the list of arcs as follows:
445 // - All merged arcs are moved to the back of the list (`!defineOp`)
446 // - Sort unmerged arcs by const-invariant hash
447 // - Sort arcs with same hash by order in which they appear in the input
448 // This allows us to pop the merged arcs off of the back of the list. Then we
449 // can iterate through the list and check neighbouring arcs for merge
450 // opportunities.
451 llvm::stable_sort(arcHashes, [](auto a, auto b) {
452 if (!a.defineOp && !b.defineOp)
453 return false;
454 if (!a.defineOp)
455 return false;
456 if (!b.defineOp)
457 return true;
458 if (a.hash.constInvariant < b.hash.constInvariant)
459 return true;
460 if (a.hash.constInvariant > b.hash.constInvariant)
461 return false;
462 return a.order < b.order;
463 });
464 while (!arcHashes.empty() && !arcHashes.back().defineOp)
465 arcHashes.pop_back();
466
467 // Perform deduplication of arcs that differ only in constant values.
468 LLVM_DEBUG(llvm::dbgs() << "Check for constant-agnostic merges ("
469 << arcHashes.size() << " arcs)\n");
470 for (unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
471 ++arcIdx) {
472 auto [defineOp, hash, order] = arcHashes[arcIdx];
473 if (!defineOp)
474 continue;
475
476 // Perform an initial pass over all other arcs with identical
477 // const-invariant hash. Check for equivalence between the current arc
478 // (`defineOp`) and the other arc (`otherDefineOp`). In case they match
479 // iterate over the list of divergences which holds all non-identical
480 // OpOperand pairs in the two arcs. These can come in different forms:
481 //
482 // - (const, const): Both arcs have the operand set to a constant, but the
483 // constant value differs. We'll want to extract these constants.
484 // - (arg, const): The current arc has a block argument where the other has
485 // a constant. No changes needed; when we replace the uses of the other
486 // arc with the current one further done we can use the existing
487 // argument to pass in that constant.
488 // - (const, arg): The current arc has a constant where the other has a
489 // block argument. We'll want to extract this constant and replace it
490 // with a block argument. This will allow the other arc to be replaced
491 // with the current one.
492 // - (arg, arg): Both arcs have the operand set to a block argument, but
493 // they are different argument numbers. This can happen if for example
494 // one of the arcs uses a single argument in two op operands and the
495 // other arc has two separate arguments for the two op operands. We'll
496 // want to ensure the current arc has two arguments in this case, such
497 // that the two can dedup.
498 //
499 // Whenever an op operand is involved in such a divergence we add it to the
500 // list of operands that must be mapped to a distinct block argument. Later
501 // we'll go through this list and add additional block arguments as
502 // necessary.
503 SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
504 unsigned nextGroupId = 1;
505 SmallMapVector<Value,
506 SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
507 operandMappings;
508 SmallVector<StringAttr> candidateNames;
509
510 for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
511 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
512 if (hash.constInvariant != otherHash.constInvariant)
513 break;
514 if (!otherDefineOp)
515 continue;
516
517 equiv.check(defineOp, otherDefineOp);
518 if (!equiv.matchConstInvariant)
519 continue;
520 candidateNames.push_back(otherDefineOp.getSymNameAttr());
521
522 // Iterate over the matching operand pairs ("divergences"), look up the
523 // value pair the operands are set to, and then store the current arc's
524 // operand in the set that corresponds to this value pair. This builds up
525 // `operandMappings` to contain sets of op operands in the current arc
526 // that can be routed out to the same block argument. If a block argument
527 // of the current arc corresponds to multiple different things in the
528 // other arc, this ensures that all unique such combinations get grouped
529 // in distinct sets such that we can create an appropriate number of new
530 // block args.
531 operandMappings.clear();
532 for (auto [operand, otherOperand] : equiv.divergences) {
533 if (!isOutlinable(*operand) || !isOutlinable(*otherOperand))
534 continue;
535 operandMappings[operand->get()][otherOperand->get()].insert(operand);
536 }
537
538 // Go through the sets of operands that can map to the same block argument
539 // for the combination of current and other arc. Assign all operands in
540 // each set new unique group IDs. If the operands in the set have multiple
541 // IDs, allocate multiple new unique group IDs. This fills the
542 // `outlineOperands` map with operands and their corresponding group ID.
543 // If we find multiple other arcs that we can potentially combine with the
544 // current arc, the operands get distributed into more and more smaller
545 // groups. For example, in arc A we can assign operands X and Y to the
546 // same block argument, so we assign them the same ID; but in arc B we
547 // have to assign X and Y to different block arguments, at which point
548 // that same ID we assigned earlier gets reassigned to two new IDs, one
549 // for each operand.
550 for (auto &[value, mappings] : operandMappings) {
551 for (auto &[otherValue, operands] : mappings) {
552 SmallDenseMap<unsigned, unsigned> remappedGroupIds;
553 for (auto *operand : operands) {
554 auto &id = outlineOperands[operand];
555 auto &remappedId = remappedGroupIds[id];
556 if (remappedId == 0)
557 remappedId = nextGroupId++;
558 id = remappedId;
559 }
560 }
561 }
562 }
563
564 if (outlineOperands.empty())
565 continue;
566 LLVM_DEBUG({
567 llvm::dbgs() << "- Outlining " << outlineOperands.size()
568 << " operands from " << defineOp.getSymNameAttr() << "\n";
569 for (auto entry : outlineOperands)
570 llvm::dbgs() << " - Operand #" << entry.first->getOperandNumber()
571 << " of " << *entry.first->getOwner() << "\n";
572 for (auto name : candidateNames)
573 llvm::dbgs() << " - Candidate " << name << "\n";
574 });
575
576 // Sort the operands to be outlined. The order is already deterministic at
577 // this point, but is not really correlated to the existing block argument
578 // order since we gathered these operands by traversing the operations
579 // depth-first. Establish an order that first honors existing argument order
580 // (putting constants at the back), and then considers the order of
581 // operations and op operands.
582 llvm::stable_sort(outlineOperands, [](auto &a, auto &b) {
583 auto argA = dyn_cast<BlockArgument>(a.first->get());
584 auto argB = dyn_cast<BlockArgument>(b.first->get());
585 if (argA && !argB)
586 return true;
587 if (!argA && argB)
588 return false;
589 if (argA && argB) {
590 if (argA.getArgNumber() < argB.getArgNumber())
591 return true;
592 if (argA.getArgNumber() > argB.getArgNumber())
593 return false;
594 }
595 auto *opA = a.first->get().getDefiningOp();
596 auto *opB = b.first->get().getDefiningOp();
597 if (opA == opB)
598 return a.first->getOperandNumber() < b.first->getOperandNumber();
599 if (opA->getBlock() == opB->getBlock())
600 return opA->isBeforeInBlock(opB);
601 return false;
602 });
603
604 // Build a new set of arc arguments by iterating over the operands that we
605 // have determined must be exposed as arguments above. For each operand
606 // either reuse its existing block argument (if no other operand in the list
607 // has already reused it), or add a new argument for this operand. Also
608 // track how each argument must be connected at call sites (outlined
609 // constant op or reusing an existing operand).
610 unsigned oldArgumentCount = defineOp.getNumArguments();
611 SmallDenseMap<unsigned, Value> newArguments; // by group ID
612 SmallVector<Type> newInputTypes;
613 SmallVector<std::variant<Operation *, unsigned>> newOperands;
614 SmallPtrSet<Operation *, 8> outlinedOps;
615
616 for (auto [operand, groupId] : outlineOperands) {
617 auto &arg = newArguments[groupId];
618 if (!arg) {
619 auto value = operand->get();
620 arg = defineOp.getBodyBlock().addArgument(value.getType(),
621 value.getLoc());
622 newInputTypes.push_back(arg.getType());
623 if (auto blockArg = dyn_cast<BlockArgument>(value))
624 newOperands.push_back(blockArg.getArgNumber());
625 else {
626 auto *op = value.getDefiningOp();
627 newOperands.push_back(op);
628 outlinedOps.insert(op);
629 }
630 }
631 operand->set(arg);
632 }
633
634 for (auto arg :
635 defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
636 if (!arg.use_empty()) {
637 auto d = defineOp.emitError(
638 "dedup failed to replace all argument uses; arc ")
639 << defineOp.getSymNameAttr() << ", argument "
640 << arg.getArgNumber();
641 for (auto &use : arg.getUses())
642 d.attachNote(use.getOwner()->getLoc())
643 << "used in operand " << use.getOperandNumber() << " here";
644 return signalPassFailure();
645 }
646 }
647
648 defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
649 defineOp.setType(FunctionType::get(
650 &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
651 addCallSiteOperands(callSites[defineOp], newOperands);
652 for (auto *op : outlinedOps)
653 if (op->use_empty())
654 op->erase();
655
656 // Perform the actual deduplication with other arcs.
657 for (unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
658 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
659 if (hash.constInvariant != otherHash.constInvariant)
660 break;
661 if (!otherDefineOp)
662 continue;
663
664 // Check for structural equivalence between the two arcs.
665 equiv.check(defineOp, otherDefineOp);
666 if (!equiv.matchConstInvariant)
667 continue;
668
669 // Determine how the other arc's operands map to the arc we're trying to
670 // merge into.
671 std::variant<Operation *, unsigned> nullOperand = nullptr;
672 for (auto &operand : newOperands)
673 operand = nullptr;
674
675 bool mappingFailed = false;
676 for (auto [operand, otherOperand] : equiv.divergences) {
677 auto arg = dyn_cast<BlockArgument>(operand->get());
678 if (!arg || !isOutlinable(*otherOperand)) {
679 mappingFailed = true;
680 break;
681 }
682
683 // Determine how the other arc's operand maps to the new connection
684 // scheme of the current arc.
685 std::variant<Operation *, unsigned> newOperand;
686 if (auto otherArg = dyn_cast<BlockArgument>(otherOperand->get()))
687 newOperand = otherArg.getArgNumber();
688 else
689 newOperand = otherOperand->get().getDefiningOp();
690
691 // Ensure that there are no conflicting operand assignment.
692 auto &newOperandSlot = newOperands[arg.getArgNumber()];
693 if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
694 mappingFailed = true;
695 break;
696 }
697 newOperandSlot = newOperand;
698 }
699 if (mappingFailed) {
700 LLVM_DEBUG(llvm::dbgs() << " - Mapping failed; skipping arc\n");
701 continue;
702 }
703 if (llvm::any_of(newOperands,
704 [&](auto operand) { return operand == nullOperand; })) {
705 LLVM_DEBUG(llvm::dbgs()
706 << " - Not all operands mapped; skipping arc\n");
707 continue;
708 }
709
710 // Replace all uses of the other arc with the current arc.
711 LLVM_DEBUG(llvm::dbgs()
712 << " - Merged " << defineOp.getSymNameAttr() << " <- "
713 << otherDefineOp.getSymNameAttr() << "\n");
714 addCallSiteOperands(callSites[otherDefineOp], newOperands);
715 replaceArcWith(otherDefineOp, defineOp, symbolTable);
716 arcHashes[otherIdx].defineOp = {};
717 }
718 }
719}
720
721void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
722 SymbolTableCollection &symbolTable) {
723 ++dedupPassNumArcsDeduped;
724 auto oldArcOps = oldArc.getOps();
725 dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
726 auto &oldUses = callSites[oldArc];
727 auto &newUses = callSites[newArc];
728 auto newArcName = SymbolRefAttr::get(newArc.getSymNameAttr());
729 for (auto callOp : oldUses) {
730 callOp.setCalleeFromCallable(newArcName);
731 newUses.insert(callOp);
732 }
733
734 oldArc.walk([&](mlir::CallOpInterface callOp) {
735 if (auto defOp = dyn_cast_or_null<DefineOp>(
736 callOp.resolveCallableInTable(&symbolTable)))
737 callSites[defOp].remove(callOp);
738 });
739 callSites.erase(oldArc);
740 arcByName.erase(oldArc.getSymNameAttr());
741 oldArc->erase();
742}
743
744std::unique_ptr<Pass> arc::createDedupPass() {
745 return std::make_unique<DedupPass>();
746}
static void addCallSiteOperands(SmallSetVector< mlir::CallOpInterface, 1 > &callSites, ArrayRef< std::variant< Operation *, unsigned > > operandMappings)
Definition Dedup.cpp:333
static bool isOutlinable(OpOperand &operand)
Definition Dedup.cpp:358
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
std::unique_ptr< mlir::Pass > createDedupPass()
Definition Dedup.cpp:744
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
void update(OpOperand &operand)
Definition Dedup.cpp:179