11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
17 #define DEBUG_TYPE "arc-dedup"
21 #define GEN_PASS_DEF_DEDUP
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
29 using llvm::SmallMapVector;
30 using llvm::SmallSetVector;
33 struct StructuralHash {
34 using Hash = std::array<uint8_t, 32>;
42 StructuralHash hash(DefineOp arc) {
45 return StructuralHash{state.final(), stateConstInvariant.final()};
51 disableConstInvariant = 0;
53 indicesConstInvariant.clear();
55 stateConstInvariant.init();
58 void update(
const void *pointer) {
59 auto *
addr =
reinterpret_cast<const uint8_t *
>(&pointer);
60 state.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
61 if (disableConstInvariant == 0)
62 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
65 void update(
size_t value) {
66 auto *
addr =
reinterpret_cast<const uint8_t *
>(&value);
67 state.update(ArrayRef<uint8_t>(addr,
sizeof value));
68 if (disableConstInvariant == 0)
69 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof value));
72 void update(
size_t value,
size_t valueConstInvariant) {
73 state.update(ArrayRef<uint8_t>(
reinterpret_cast<const uint8_t *
>(&value),
75 state.update(ArrayRef<uint8_t>(
76 reinterpret_cast<const uint8_t *
>(&valueConstInvariant),
77 sizeof valueConstInvariant));
80 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
82 void update(Type type) { update(type.getAsOpaquePointer()); }
84 void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
86 void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
88 void update(BlockArgument arg) { update(arg.getType()); }
90 void update(OpResult result) { update(result.getType()); }
92 void update(OpOperand &operand) {
94 auto it = indices.find(operand.get());
95 auto itCI = indicesConstInvariant.find(operand.get());
96 assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
97 "op should have been previously hashed");
98 update(it->second, itCI->second);
101 void update(Block &block) {
105 for (
auto arg : block.getArguments()) {
106 indices.insert({arg, currentIndex++});
107 indicesConstInvariant.insert({arg, 0});
109 for (
auto &op : block) {
110 for (
auto result : op.getResults()) {
111 indices.insert({result, currentIndex++});
112 if (op.hasTrait<OpTrait::ConstantLike>())
113 indicesConstInvariant.insert({result, 0});
115 indicesConstInvariant.insert({result, currentIndexConstInvariant++});
120 ++disableConstInvariant;
121 for (
auto arg : block.getArguments())
123 --disableConstInvariant;
126 for (
auto &op : block)
130 void update(Operation *op) {
131 unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
132 disableConstInvariant += skipConstInvariant;
134 update(op->getName());
137 if (!isa<DefineOp>(op)) {
138 for (
auto namedAttr : op->getAttrDictionary()) {
139 auto name = namedAttr.getName();
140 auto value = namedAttr.getValue();
143 update(name.getAsOpaquePointer());
144 update(value.getAsOpaquePointer());
149 for (
auto &operand : op->getOpOperands())
153 update(op->getNumRegions());
154 for (
auto ®ion : op->getRegions())
155 for (
auto &block : region.getBlocks())
158 for (
auto result : op->getResults())
161 disableConstInvariant -= skipConstInvariant;
165 unsigned currentIndex = 0;
166 unsigned currentIndexConstInvariant = 0;
167 DenseMap<Value, unsigned> indices;
168 DenseMap<Value, unsigned> indicesConstInvariant;
170 unsigned disableConstInvariant = 0;
175 llvm::SHA256 stateConstInvariant;
180 struct StructuralEquivalence {
181 using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
182 explicit StructuralEquivalence(MLIRContext *context) {}
184 void check(DefineOp arcA, DefineOp arcB) {
185 if (!checkImpl(arcA, arcB)) {
187 matchConstInvariant =
false;
191 SmallSetVector<OpOperandPair, 1> divergences;
193 bool matchConstInvariant;
196 bool addBlockToWorklist(Block &blockA, Block &blockB) {
197 auto *terminatorA = blockA.getTerminator();
198 auto *terminatorB = blockB.getTerminator();
199 if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
201 if (!addOpToWorklist(terminatorA, terminatorB))
208 bool addOpToWorklist(Operation *opA, Operation *opB,
209 bool *allOperandsHandled =
nullptr) {
210 if (opA->getNumOperands() != opB->getNumOperands())
212 for (
auto [operandA, operandB] :
213 llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
214 if (!handled.count({&operandA, &operandB})) {
215 worklist.emplace_back(&operandA, &operandB);
216 if (allOperandsHandled)
217 *allOperandsHandled =
false;
223 bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
224 if (opA->getName() != opB->getName())
226 if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
227 for (
auto [namedAttrA, namedAttrB] :
228 llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
229 if (namedAttrA.getName() != namedAttrB.getName())
231 if (namedAttrA.getValue() == namedAttrB.getValue())
233 bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
234 if (!mayDiverge || !values.first || !values.second)
236 divergences.insert(values);
244 bool checkImpl(DefineOp arcA, DefineOp arcB) {
248 matchConstInvariant =
true;
251 if (arcA.getFunctionType().getResults() !=
252 arcB.getFunctionType().getResults())
255 if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
258 while (!worklist.empty()) {
259 OpOperandPair values = worklist.back();
260 if (handled.contains(values)) {
265 auto valueA = values.first->get();
266 auto valueB = values.second->get();
267 if (valueA.getType() != valueB.getType())
269 auto *opA = valueA.getDefiningOp();
270 auto *opB = valueB.getDefiningOp();
274 auto argA = dyn_cast<BlockArgument>(valueA);
275 auto argB = dyn_cast<BlockArgument>(valueB);
277 divergences.insert(values);
278 if (argA.getArgNumber() != argB.getArgNumber())
280 handled.insert(values);
284 auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
285 auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
286 if ((argA && isConstB) || (argB && isConstA)) {
288 divergences.insert(values);
290 handled.insert(values);
299 bool allHandled =
true;
300 if (!addOpToWorklist(opA, opB, &allHandled))
304 handled.insert(values);
308 if (!compareOps(opA, opB, values))
312 if (opA->getNumRegions() != opB->getNumRegions())
314 for (
auto [regionA, regionB] :
315 llvm::zip(opA->getRegions(), opB->getRegions())) {
316 if (regionA.getBlocks().size() != regionB.getBlocks().size())
318 for (
auto [blockA, blockB] : llvm::zip(regionA, regionB))
319 if (!addBlockToWorklist(blockA, blockB))
327 SmallVector<OpOperandPair, 0> worklist;
328 DenseSet<OpOperandPair> handled;
333 SmallSetVector<mlir::CallOpInterface, 1> &callSites,
334 ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
336 SmallVector<Value> newOperands;
337 for (
auto callOp : callSites) {
341 for (
auto mapping : operandMappings) {
342 if (std::holds_alternative<Operation *>(mapping)) {
343 auto *op = std::get<Operation *>(mapping);
344 auto &newOp = clonedOps[op];
347 newOperands.push_back(newOp->getResult(0));
349 newOperands.push_back(
350 callOp.getArgOperands()[std::get<unsigned>(mapping)]);
353 callOp.getArgOperandsMutable().assign(newOperands);
358 auto *op = operand.get().getDefiningOp();
359 return !op || op->hasTrait<OpTrait::ConstantLike>();
363 struct DedupPass :
public arc::impl::DedupBase<DedupPass> {
364 void runOnOperation()
override;
365 void replaceArcWith(DefineOp oldArc, DefineOp newArc,
366 SymbolTableCollection &symbolTable);
369 DenseMap<StringAttr, DefineOp> arcByName;
371 DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
378 ArcHash(DefineOp defineOp, StructuralHash hash,
unsigned order)
379 : defineOp(defineOp), hash(hash), order(order) {}
383 void DedupPass::runOnOperation() {
386 SymbolTableCollection symbolTable;
389 SmallVector<ArcHash> arcHashes;
391 for (
auto defineOp : getOperation().getOps<DefineOp>()) {
392 arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
393 arcByName.insert({defineOp.getSymNameAttr(), defineOp});
397 getOperation().walk([&](mlir::CallOpInterface callOp) {
399 dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
400 callSites[defOp].insert(callOp);
407 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
408 if (a.hash.hash < b.hash.hash)
410 if (a.hash.hash > b.hash.hash)
412 return a.order < b.order;
417 LLVM_DEBUG(llvm::dbgs() <<
"Check for exact merges (" << arcHashes.size()
419 StructuralEquivalence equiv(&getContext());
420 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
422 auto [defineOp, hash, order] = arcHashes[arcIdx];
425 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
426 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
427 if (hash.hash != otherHash.hash)
431 equiv.check(defineOp, otherDefineOp);
434 LLVM_DEBUG(llvm::dbgs()
435 <<
"- Merge " << defineOp.getSymNameAttr() <<
" <- "
436 << otherDefineOp.getSymNameAttr() <<
"\n");
437 replaceArcWith(otherDefineOp, defineOp, symbolTable);
438 arcHashes[otherIdx].defineOp = {};
450 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
451 if (!a.defineOp && !b.defineOp)
457 if (a.hash.constInvariant < b.hash.constInvariant)
459 if (a.hash.constInvariant > b.hash.constInvariant)
461 return a.order < b.order;
463 while (!arcHashes.empty() && !arcHashes.back().defineOp)
464 arcHashes.pop_back();
467 LLVM_DEBUG(llvm::dbgs() <<
"Check for constant-agnostic merges ("
468 << arcHashes.size() <<
" arcs)\n");
469 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
471 auto [defineOp, hash, order] = arcHashes[arcIdx];
502 SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
503 unsigned nextGroupId = 1;
504 SmallMapVector<Value,
505 SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
507 SmallVector<StringAttr> candidateNames;
509 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
510 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
511 if (hash.constInvariant != otherHash.constInvariant)
516 equiv.check(defineOp, otherDefineOp);
517 if (!equiv.matchConstInvariant)
519 candidateNames.push_back(otherDefineOp.getSymNameAttr());
530 operandMappings.clear();
531 for (
auto [operand, otherOperand] : equiv.divergences) {
534 operandMappings[operand->get()][otherOperand->get()].insert(operand);
549 for (
auto &[value, mappings] : operandMappings) {
550 for (
auto &[otherValue, operands] : mappings) {
552 for (
auto *operand : operands) {
553 auto &
id = outlineOperands[operand];
554 auto &remappedId = remappedGroupIds[id];
556 remappedId = nextGroupId++;
563 if (outlineOperands.empty())
566 llvm::dbgs() <<
"- Outlining " << outlineOperands.size()
567 <<
" operands from " << defineOp.getSymNameAttr() <<
"\n";
568 for (
auto entry : outlineOperands)
569 llvm::dbgs() <<
" - Operand #" << entry.first->getOperandNumber()
570 <<
" of " << *entry.first->getOwner() <<
"\n";
571 for (
auto name : candidateNames)
572 llvm::dbgs() <<
" - Candidate " << name <<
"\n";
581 llvm::stable_sort(outlineOperands, [](
auto &a,
auto &b) {
582 auto argA = dyn_cast<BlockArgument>(a.first->get());
583 auto argB = dyn_cast<BlockArgument>(b.first->get());
589 if (argA.getArgNumber() < argB.getArgNumber())
591 if (argA.getArgNumber() > argB.getArgNumber())
594 auto *opA = a.first->get().getDefiningOp();
595 auto *opB = b.first->get().getDefiningOp();
597 return a.first->getOperandNumber() < b.first->getOperandNumber();
598 if (opA->getBlock() == opB->getBlock())
599 return opA->isBeforeInBlock(opB);
609 unsigned oldArgumentCount = defineOp.getNumArguments();
611 SmallVector<Type> newInputTypes;
612 SmallVector<std::variant<Operation *, unsigned>> newOperands;
613 SmallPtrSet<Operation *, 8> outlinedOps;
615 for (
auto [operand, groupId] : outlineOperands) {
616 auto &arg = newArguments[groupId];
618 auto value = operand->get();
619 arg = defineOp.getBodyBlock().addArgument(value.getType(),
621 newInputTypes.push_back(arg.getType());
622 if (
auto blockArg = dyn_cast<BlockArgument>(value))
623 newOperands.push_back(blockArg.getArgNumber());
625 auto *op = value.getDefiningOp();
626 newOperands.push_back(op);
627 outlinedOps.insert(op);
634 defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
635 if (!arg.use_empty()) {
636 auto d = defineOp.emitError(
637 "dedup failed to replace all argument uses; arc ")
638 << defineOp.getSymNameAttr() <<
", argument "
639 << arg.getArgNumber();
640 for (
auto &use : arg.getUses())
641 d.attachNote(use.getOwner()->getLoc())
642 <<
"used in operand " << use.getOperandNumber() <<
" here";
643 return signalPassFailure();
647 defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
649 &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
651 for (
auto *op : outlinedOps)
656 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
657 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
658 if (hash.constInvariant != otherHash.constInvariant)
664 equiv.check(defineOp, otherDefineOp);
665 if (!equiv.matchConstInvariant)
670 std::variant<Operation *, unsigned> nullOperand =
nullptr;
671 for (
auto &operand : newOperands)
674 bool mappingFailed =
false;
675 for (
auto [operand, otherOperand] : equiv.divergences) {
676 auto arg = dyn_cast<BlockArgument>(operand->get());
678 mappingFailed =
true;
684 std::variant<Operation *, unsigned> newOperand;
685 if (
auto otherArg = dyn_cast<BlockArgument>(otherOperand->get()))
686 newOperand = otherArg.getArgNumber();
688 newOperand = otherOperand->get().getDefiningOp();
691 auto &newOperandSlot = newOperands[arg.getArgNumber()];
692 if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
693 mappingFailed =
true;
696 newOperandSlot = newOperand;
699 LLVM_DEBUG(llvm::dbgs() <<
" - Mapping failed; skipping arc\n");
702 if (llvm::any_of(newOperands,
703 [&](
auto operand) {
return operand == nullOperand; })) {
704 LLVM_DEBUG(llvm::dbgs()
705 <<
" - Not all operands mapped; skipping arc\n");
710 LLVM_DEBUG(llvm::dbgs()
711 <<
" - Merged " << defineOp.getSymNameAttr() <<
" <- "
712 << otherDefineOp.getSymNameAttr() <<
"\n");
714 replaceArcWith(otherDefineOp, defineOp, symbolTable);
715 arcHashes[otherIdx].defineOp = {};
720 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
721 SymbolTableCollection &symbolTable) {
722 ++dedupPassNumArcsDeduped;
723 auto oldArcOps = oldArc.getOps();
724 dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
725 auto &oldUses = callSites[oldArc];
726 auto &newUses = callSites[newArc];
728 for (
auto callOp : oldUses) {
729 callOp.setCalleeFromCallable(newArcName);
730 newUses.insert(callOp);
733 oldArc.walk([&](mlir::CallOpInterface callOp) {
735 dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
736 callSites[defOp].remove(callOp);
738 callSites.erase(oldArc);
739 arcByName.erase(oldArc.getSymNameAttr());
744 return std::make_unique<DedupPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createDedupPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.