11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
18 #define DEBUG_TYPE "arc-dedup"
22 #define GEN_PASS_DEF_DEDUP
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
30 using llvm::SmallMapVector;
31 using llvm::SmallSetVector;
// Structural fingerprint of an arc. Uses elsewhere in this file access an
// exact digest (`.hash`) and a constant-invariant digest (`.constInvariant`).
34 struct StructuralHash {
// A 32-byte digest, the size of a SHA-256 result.
35 using Hash = std::array<uint8_t, 32>;
// Hashes `arc` and returns both digests: the exact one from `state` and the
// constant-invariant one from `stateConstInvariant`.
43 StructuralHash hash(DefineOp arc) {
46 return StructuralHash{state.final(), stateConstInvariant.final()};
// Resets per-arc hashing state (presumably invoked before hashing each arc;
// confirm): re-enables the const-invariant digest, clears the
// const-invariant value numbering and reinitializes its SHA-256 state.
52 disableConstInvariant = 0;
54 indicesConstInvariant.clear();
56 stateConstInvariant.init();
// Feeds the bytes of the pointer value itself (the address, not the pointee)
// into the exact digest and, unless the const-invariant digest is currently
// disabled, into that one as well. The TypeID/Type/Attribute/OperationName
// overloads route their opaque pointers through here.
59 void update(
const void *pointer) {
60 auto *
addr =
reinterpret_cast<const uint8_t *
>(&pointer);
61 state.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
62 if (disableConstInvariant == 0)
63 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
// Feeds the raw bytes of `value` into the exact digest and, unless the
// const-invariant digest is currently disabled, into that one as well.
66 void update(
size_t value) {
67 auto *
addr =
reinterpret_cast<const uint8_t *
>(&value);
68 state.update(ArrayRef<uint8_t>(addr,
sizeof value));
69 if (disableConstInvariant == 0)
70 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof value));
73 void update(
size_t value,
size_t valueConstInvariant) {
// Hashes the two indices of a used value (see update(OpOperand &)).
// `value` is the index in the exact numbering and goes into the exact
// digest. `valueConstInvariant` is the index in the const-invariant
// numbering and belongs in the const-invariant digest, guarded by
// `disableConstInvariant` like every other overload. Feeding it into
// `state` instead would leave the const-invariant digest blind to how
// operands are wired.
74 state.update(ArrayRef<uint8_t>(
reinterpret_cast<const uint8_t *
>(&value),
76 if (disableConstInvariant == 0)
stateConstInvariant.update(ArrayRef<uint8_t>(
77 reinterpret_cast<const uint8_t *
>(&valueConstInvariant),
78 sizeof valueConstInvariant));
// Convenience overloads. Type IDs, types, attributes and operation names are
// hashed by their opaque pointers. Block arguments and op results contribute
// only their type; their position is captured via the index maps filled in
// update(Block &).
81 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
83 void update(Type type) { update(type.getAsOpaquePointer()); }
85 void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
87 void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
89 void update(BlockArgument arg) { update(arg.getType()); }
91 void update(OpResult result) { update(result.getType()); }
// Hashes an operand by the indices assigned to the value it uses: both the
// exact and the const-invariant index are passed to update(size_t, size_t).
// The used value must already have been numbered (asserted).
93 void update(OpOperand &operand) {
95 auto it = indices.find(operand.get());
96 auto itCI = indicesConstInvariant.find(operand.get());
97 assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
98 "op should have been previously hashed");
99 update(it->second, itCI->second);
// Numbers every value defined in `block` (block arguments first, then op
// results in order), then visits the block's arguments and ops.
// In the const-invariant numbering, block arguments and results of
// ConstantLike ops all get index 0, while other results get a running
// counter. As a result, swapping a constant for an argument does not change
// operand indices there.
102 void update(Block &block) {
106 for (
auto arg : block.getArguments()) {
107 indices.insert({arg, currentIndex++});
108 indicesConstInvariant.insert({arg, 0});
110 for (
auto &op : block) {
111 for (
auto result : op.getResults()) {
112 indices.insert({result, currentIndex++});
113 if (op.hasTrait<OpTrait::ConstantLike>())
114 indicesConstInvariant.insert({result, 0});
116 indicesConstInvariant.insert({result, currentIndexConstInvariant++});
// Block arguments are visited with the const-invariant digest disabled, so
// their types only affect the exact digest.
121 ++disableConstInvariant;
122 for (
auto arg : block.getArguments())
124 --disableConstInvariant;
127 for (
auto &op : block)
// Hashes one operation: its name, its attributes (not for DefineOp itself,
// presumably so that e.g. the symbol name does not distinguish otherwise
// identical arcs), its operands, the blocks of its regions and its results.
// While a ConstantLike op is being visited, `disableConstInvariant` is
// raised, so nothing the op contributes reaches the const-invariant digest.
131 void update(Operation *op) {
132 unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
133 disableConstInvariant += skipConstInvariant;
135 update(op->getName());
// Attribute names and values are hashed by their (uniqued) opaque pointers.
138 if (!isa<DefineOp>(op)) {
139 for (
auto namedAttr : op->getAttrDictionary()) {
140 auto name = namedAttr.getName();
141 auto value = namedAttr.getValue();
144 update(name.getAsOpaquePointer());
145 update(value.getAsOpaquePointer());
150 for (
auto &operand : op->getOpOperands())
154 update(op->getNumRegions());
// Fixed: the loop variable was mis-encoded as `®ion`; it is a reference to
// each region and is used by name on the following line.
155 for (
auto &region : op->getRegions())
156 for (
auto &block : region.getBlocks())
159 for (
auto result : op->getResults())
// Restore the const-invariant nesting level from before this op.
162 disableConstInvariant -= skipConstInvariant;
// Next index to hand out in the exact and const-invariant value numberings.
166 unsigned currentIndex = 0;
167 unsigned currentIndexConstInvariant = 0;
// Value -> index maps consulted when hashing operands.
168 DenseMap<Value, unsigned> indices;
169 DenseMap<Value, unsigned> indicesConstInvariant;
// Nesting counter; the const-invariant digest is only updated while this is 0.
171 unsigned disableConstInvariant = 0;
// SHA-256 state accumulating the const-invariant digest.
176 llvm::SHA256 stateConstInvariant;
// Compares two arcs structurally by walking use-def chains backwards from
// their terminators. Operand pairs where the arcs use different ConstantLike
// values and/or block arguments are collected in `divergences`.
// `matchConstInvariant` reports whether the arcs match up to those
// divergences.
181 struct StructuralEquivalence {
182 using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
// NOTE(review): `context` is currently unused.
183 explicit StructuralEquivalence(MLIRContext *context) {}
// Compares `arcA` with `arcB`; a failed comparison clears
// `matchConstInvariant`.
185 void check(DefineOp arcA, DefineOp arcB) {
186 if (!checkImpl(arcA, arcB)) {
188 matchConstInvariant =
false;
// Operand pairs at which the two arcs were found to diverge.
192 SmallSetVector<OpOperandPair, 1> divergences;
// Whether the arcs matched up to `divergences`; set to true at the start of
// checkImpl().
194 bool matchConstInvariant;
// Compares the terminators of the two blocks and queues their operand pairs.
// The terminators are compared with an empty operand pair, i.e. there is no
// operand a divergence on them could be attributed to.
197 bool addBlockToWorklist(Block &blockA, Block &blockB) {
198 auto *terminatorA = blockA.getTerminator();
199 auto *terminatorB = blockB.getTerminator();
200 if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
202 if (!addOpToWorklist(terminatorA, terminatorB))
// Queues each operand pair of `opA`/`opB` that has not been handled yet;
// operand counts must agree. If `allOperandsHandled` is non-null, it is set to
// false whenever a pair had to be queued.
209 bool addOpToWorklist(Operation *opA, Operation *opB,
210 bool *allOperandsHandled =
nullptr) {
211 if (opA->getNumOperands() != opB->getNumOperands())
213 for (
auto [operandA, operandB] :
214 llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
215 if (!handled.count({&operandA, &operandB})) {
216 worklist.emplace_back(&operandA, &operandB);
217 if (allOperandsHandled)
218 *allOperandsHandled =
false;
// Compares op names and attribute dictionaries. A differing attribute value
// is only recorded, as a divergence at `values`, for ConstantLike ops reached
// through a non-empty operand pair.
224 bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
225 if (opA->getName() != opB->getName())
227 if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
228 for (
auto [namedAttrA, namedAttrB] :
229 llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
230 if (namedAttrA.getName() != namedAttrB.getName())
232 if (namedAttrA.getValue() == namedAttrB.getValue())
234 bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
235 if (!mayDiverge || !values.first || !values.second)
237 divergences.insert(values);
// Worklist-driven comparison. Result types must match. Operand pairs reachable
// from the body terminators are then compared: by type, as block arguments,
// and for the argument-vs-constant case. Defining ops, including their nested
// regions, are compared after their own operand pairs have been queued.
245 bool checkImpl(DefineOp arcA, DefineOp arcB) {
249 matchConstInvariant =
true;
252 if (arcA.getFunctionType().getResults() !=
253 arcB.getFunctionType().getResults())
256 if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
259 while (!worklist.empty()) {
260 OpOperandPair values = worklist.back();
261 if (handled.contains(values)) {
266 auto valueA = values.first->get();
267 auto valueB = values.second->get();
268 if (valueA.getType() != valueB.getType())
270 auto *opA = valueA.getDefiningOp();
271 auto *opB = valueB.getDefiningOp();
// Uses of block arguments are recorded as divergences and their argument
// numbers are compared.
275 auto argA = dyn_cast<BlockArgument>(valueA);
276 auto argB = dyn_cast<BlockArgument>(valueB);
278 divergences.insert(values);
279 if (argA.getArgNumber() != argB.getArgNumber())
281 handled.insert(values);
// A block argument on one side and a constant on the other is a divergence.
285 auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
286 auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
287 if ((argA && isConstB) || (argB && isConstA)) {
289 divergences.insert(values);
291 handled.insert(values);
// Queue the defining ops' operand pairs first; `allHandled` tracks whether any
// of them still needed visiting.
300 bool allHandled =
true;
301 if (!addOpToWorklist(opA, opB, &allHandled))
305 handled.insert(values);
309 if (!compareOps(opA, opB, values))
// Nested regions must have the same shape; their blocks are compared through
// their terminators.
313 if (opA->getNumRegions() != opB->getNumRegions())
315 for (
auto [regionA, regionB] :
316 llvm::zip(opA->getRegions(), opB->getRegions())) {
317 if (regionA.getBlocks().size() != regionB.getBlocks().size())
319 for (
auto [blockA, blockB] : llvm::zip(regionA, regionB))
320 if (!addBlockToWorklist(blockA, blockB))
// Pending operand pairs (processed from the back) and pairs already done.
328 SmallVector<OpOperandPair, 0> worklist;
329 DenseSet<OpOperandPair> handled;
334 SmallSetVector<mlir::CallOpInterface, 1> &callSites,
335 ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
// Rebuilds the argument list of every call in `callSites` from
// `operandMappings`:
//  - an `Operation *` entry means "clone this op right before the call and
//    pass its first result";
//  - an `unsigned` entry means "pass the call's existing argument at this
//    index".
// The call's argument operands are then replaced wholesale.
// NOTE(review): `clonedOps` presumably caches clones so an op used by several
// mappings is cloned once per call site; confirm. Also confirm `newOperands`,
// declared outside the per-call loop, is reset for each call.
337 SmallVector<Value> newOperands;
338 for (
auto callOp : callSites) {
339 OpBuilder builder(callOp);
342 for (
auto mapping : operandMappings) {
343 if (std::holds_alternative<Operation *>(mapping)) {
344 auto *op = std::get<Operation *>(mapping);
345 auto &newOp = clonedOps[op];
347 newOp = builder.clone(*op);
348 newOperands.push_back(newOp->getResult(0));
350 newOperands.push_back(
351 callOp.getArgOperands()[std::get<unsigned>(mapping)]);
354 callOp.getArgOperandsMutable().assign(newOperands);
// True if the operand's value has no defining op (i.e. it is a block
// argument) or is produced by a ConstantLike op.
359 auto *op = operand.get().getDefiningOp();
360 return !op || op->hasTrait<OpTrait::ConstantLike>();
// Arc deduplication pass. It merges structurally identical arcs, and arcs that
// differ only in constants or arguments after outlining the differing values
// into new arguments.
364 struct DedupPass :
public arc::impl::DedupBase<DedupPass> {
365 void runOnOperation()
override;
// Redirects all calls of `oldArc` to `newArc` and drops `oldArc` from the
// pass's bookkeeping.
366 void replaceArcWith(DefineOp oldArc, DefineOp newArc,
367 SymbolTableCollection &symbolTable);
// Arc definitions indexed by symbol name.
370 DenseMap<StringAttr, DefineOp> arcByName;
// For each arc, the call ops that resolve to it.
372 DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
// Pairs `defineOp` with its structural hash and its original position
// `order`, which serves as a stable tie-breaker when sorting.
379 ArcHash(DefineOp defineOp, StructuralHash hash,
unsigned order)
380 : defineOp(defineOp), hash(hash), order(order) {}
// Deduplicates the arcs of the current operation in two phases:
//  1. Exact merges: arcs with equal exact hashes that pass the structural
//     equivalence check are merged into the earlier arc.
//  2. Constant-agnostic merges: arcs with equal const-invariant hashes are
//     merged after the values they diverge in (constants / block arguments)
//     are outlined into new arguments of the surviving arc. The call sites
//     are rewritten to pass those values.
384 void DedupPass::runOnOperation() {
387 SymbolTableCollection symbolTable;
390 SmallVector<ArcHash> arcHashes;
// Hash every arc, remembering its original position, and index it by name.
392 for (
auto defineOp : getOperation().getOps<DefineOp>()) {
393 arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
394 arcByName.insert({defineOp.getSymNameAttr(), defineOp});
// Record, for every arc, the call ops that resolve to it.
398 getOperation().walk([&](mlir::CallOpInterface callOp) {
399 if (
auto defOp = dyn_cast_or_null<DefineOp>(
400 callOp.resolveCallableInTable(&symbolTable)))
401 callSites[defOp].insert(callOp);
// Sort by exact hash so that equal hashes are adjacent; ties fall back to the
// original order.
// NOTE(review): the comparator takes ArcHash by value, copying two 32-byte
// digests per comparison; `const auto &` parameters would avoid that.
408 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
409 if (a.hash.hash < b.hash.hash)
411 if (a.hash.hash > b.hash.hash)
413 return a.order < b.order;
// Phase 1: exact merges.
418 LLVM_DEBUG(llvm::dbgs() <<
"Check for exact merges (" << arcHashes.size()
420 StructuralEquivalence equiv(&getContext());
421 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
423 auto [defineOp, hash, order] = arcHashes[arcIdx];
// Candidates are later entries with the same exact hash.
426 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
427 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
428 if (hash.hash != otherHash.hash)
432 equiv.check(defineOp, otherDefineOp);
435 LLVM_DEBUG(llvm::dbgs()
436 <<
"- Merge " << defineOp.getSymNameAttr() <<
" <- "
437 << otherDefineOp.getSymNameAttr() <<
"\n");
// Redirect the duplicate's callers and mark its slot as consumed.
438 replaceArcWith(otherDefineOp, defineOp, symbolTable);
439 arcHashes[otherIdx].defineOp = {};
// Re-sort by const-invariant hash (ties by original order). Consumed entries
// (null defineOp) are expected to end up at the back, where they are popped.
// NOTE(review): same by-value comparator parameters as above.
451 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
452 if (!a.defineOp && !b.defineOp)
458 if (a.hash.constInvariant < b.hash.constInvariant)
460 if (a.hash.constInvariant > b.hash.constInvariant)
462 return a.order < b.order;
464 while (!arcHashes.empty() && !arcHashes.back().defineOp)
465 arcHashes.pop_back();
// Phase 2: constant-agnostic merges.
468 LLVM_DEBUG(llvm::dbgs() <<
"Check for constant-agnostic merges ("
469 << arcHashes.size() <<
" arcs)\n");
470 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
472 auto [defineOp, hash, order] = arcHashes[arcIdx];
// Operands of `defineOp` to outline into new arguments, each mapped to a group
// id. All operands in one group are served by the same new argument.
503 SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
504 unsigned nextGroupId = 1;
505 SmallMapVector<Value,
506 SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
508 SmallVector<StringAttr> candidateNames;
// Collect later arcs with the same const-invariant hash that match up to
// divergences, and plan which operands need outlining.
510 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
511 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
512 if (hash.constInvariant != otherHash.constInvariant)
517 equiv.check(defineOp, otherDefineOp);
518 if (!equiv.matchConstInvariant)
520 candidateNames.push_back(otherDefineOp.getSymNameAttr());
// Bucket divergent operands by (value used in this arc, value used in the
// candidate).
531 operandMappings.clear();
532 for (
auto [operand, otherOperand] : equiv.divergences) {
535 operandMappings[operand->get()][otherOperand->get()].insert(operand);
// NOTE(review): `remappedGroupIds` appears to split existing groups so that
// operands diverging differently in this candidate get fresh group ids;
// confirm.
550 for (
auto &[value, mappings] : operandMappings) {
551 for (
auto &[otherValue, operands] : mappings) {
553 for (
auto *operand : operands) {
554 auto &
id = outlineOperands[operand];
555 auto &remappedId = remappedGroupIds[id];
557 remappedId = nextGroupId++;
// Presumably skips arcs with nothing to outline; confirm.
564 if (outlineOperands.empty())
// Debug dump of the outlining plan and the merge candidates.
567 llvm::dbgs() <<
"- Outlining " << outlineOperands.size()
568 <<
" operands from " << defineOp.getSymNameAttr() <<
"\n";
569 for (
auto entry : outlineOperands)
570 llvm::dbgs() <<
" - Operand #" << entry.first->getOperandNumber()
571 <<
" of " << *entry.first->getOwner() <<
"\n";
572 for (
auto name : candidateNames)
573 llvm::dbgs() <<
" - Candidate " << name <<
"\n";
// Give outlined operands a deterministic order. Block arguments are ordered by
// argument number and op-defined values by position in their block, with
// operand number as a tie-breaker.
582 llvm::stable_sort(outlineOperands, [](
auto &a,
auto &b) {
583 auto argA = dyn_cast<BlockArgument>(a.first->get());
584 auto argB = dyn_cast<BlockArgument>(b.first->get());
590 if (argA.getArgNumber() < argB.getArgNumber())
592 if (argA.getArgNumber() > argB.getArgNumber())
595 auto *opA = a.first->get().getDefiningOp();
596 auto *opB = b.first->get().getDefiningOp();
598 return a.first->getOperandNumber() < b.first->getOperandNumber();
599 if (opA->getBlock() == opB->getBlock())
600 return opA->isBeforeInBlock(opB);
// Rewrite the signature: one new block argument per group. `newOperands`
// records, for the call sites, whether to forward an old argument (its index)
// or to clone the defining op.
610 unsigned oldArgumentCount = defineOp.getNumArguments();
612 SmallVector<Type> newInputTypes;
613 SmallVector<std::variant<Operation *, unsigned>> newOperands;
614 SmallPtrSet<Operation *, 8> outlinedOps;
616 for (
auto [operand, groupId] : outlineOperands) {
617 auto &arg = newArguments[groupId];
619 auto value = operand->get();
620 arg = defineOp.getBodyBlock().addArgument(value.getType(),
622 newInputTypes.push_back(arg.getType());
623 if (
auto blockArg = dyn_cast<BlockArgument>(value))
624 newOperands.push_back(blockArg.getArgNumber());
626 auto *op = value.getDefiningOp();
627 newOperands.push_back(op);
628 outlinedOps.insert(op);
// Every use of an original argument must now be rewired; report any leftover
// use and fail the pass.
635 defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
636 if (!arg.use_empty()) {
637 auto d = defineOp.emitError(
638 "dedup failed to replace all argument uses; arc ")
639 << defineOp.getSymNameAttr() <<
", argument "
640 << arg.getArgNumber();
641 for (
auto &use : arg.getUses())
642 d.attachNote(use.getOwner()->getLoc())
643 <<
"used in operand " << use.getOperandNumber() <<
" here";
644 return signalPassFailure();
// Drop the original arguments and update the function type.
648 defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
650 &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
// NOTE(review): presumably erases the outlined ops, now unused; confirm.
652 for (
auto *op : outlinedOps)
// Merge the candidates into the rewritten arc.
657 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
658 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
659 if (hash.constInvariant != otherHash.constInvariant)
665 equiv.check(defineOp, otherDefineOp);
666 if (!equiv.matchConstInvariant)
// `nullOperand` marks argument slots that have not been mapped yet.
671 std::variant<Operation *, unsigned> nullOperand =
nullptr;
672 for (
auto &operand : newOperands)
// Each divergence should be a use of a new argument. Record what the candidate
// passes for that argument; conflicting assignments fail the mapping.
675 bool mappingFailed =
false;
676 for (
auto [operand, otherOperand] : equiv.divergences) {
677 auto arg = dyn_cast<BlockArgument>(operand->get());
679 mappingFailed =
true;
685 std::variant<Operation *, unsigned> newOperand;
686 if (
auto otherArg = dyn_cast<BlockArgument>(otherOperand->get()))
687 newOperand = otherArg.getArgNumber();
689 newOperand = otherOperand->get().getDefiningOp();
692 auto &newOperandSlot = newOperands[arg.getArgNumber()];
693 if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
694 mappingFailed =
true;
697 newOperandSlot = newOperand;
700 LLVM_DEBUG(llvm::dbgs() <<
" - Mapping failed; skipping arc\n");
// Every new argument must have received a source from the candidate.
703 if (llvm::any_of(newOperands,
704 [&](
auto operand) {
return operand == nullOperand; })) {
705 LLVM_DEBUG(llvm::dbgs()
706 <<
" - Not all operands mapped; skipping arc\n");
711 LLVM_DEBUG(llvm::dbgs()
712 <<
" - Merged " << defineOp.getSymNameAttr() <<
" <- "
713 << otherDefineOp.getSymNameAttr() <<
"\n");
// Redirect the candidate's callers and mark its slot as consumed.
715 replaceArcWith(otherDefineOp, defineOp, symbolTable);
716 arcHashes[otherIdx].defineOp = {};
// Redirects every recorded call of `oldArc` to `newArc` and moves those call
// sites into `newArc`'s set. Calls nested inside `oldArc`'s body are removed
// from their callees' call-site sets, and `oldArc` is forgotten in `callSites`
// and `arcByName`. Also bumps the dedup statistics.
721 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
722 SymbolTableCollection &symbolTable) {
723 ++dedupPassNumArcsDeduped;
724 auto oldArcOps = oldArc.getOps();
725 dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
// DenseMap insertion can grow the table and invalidate references to existing
// entries. Create both entries before binding references; otherwise inserting
// `newArc` (e.g. an arc that had no callers) could leave `oldUses` dangling
// while it is iterated below.
callSites.try_emplace(oldArc);
callSites.try_emplace(newArc);
726 auto &oldUses = callSites[oldArc];
727 auto &newUses = callSites[newArc];
729 for (
auto callOp : oldUses) {
730 callOp.setCalleeFromCallable(newArcName);
731 newUses.insert(callOp);
// Calls inside the old arc's body must no longer count as call sites of their
// callees.
734 oldArc.walk([&](mlir::CallOpInterface callOp) {
735 if (
auto defOp = dyn_cast_or_null<DefineOp>(
736 callOp.resolveCallableInTable(&symbolTable)))
737 callSites[defOp].remove(callOp);
739 callSites.erase(oldArc);
740 arcByName.erase(oldArc.getSymNameAttr());
// Creates a new instance of the arc deduplication pass.
745 return std::make_unique<DedupPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createDedupPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.