11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/SHA256.h"
17 #define DEBUG_TYPE "arc-dedup"
21 #define GEN_PASS_DEF_DEDUP
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
29 using llvm::SmallMapVector;
30 using llvm::SmallSetVector;
/// Computes two SHA-256 digests describing the structure of an arc:
///  - `state` receives every byte passed to the `update` overloads;
///  - `stateConstInvariant` receives bytes from the single-value overloads
///    only while `disableConstInvariant == 0`, so block arguments and
///    constant-like ops are left out of it.
/// Values are identified by sequential indices (`indices`,
/// `indicesConstInvariant`) rather than by address.
/// NOTE(review): the embedded line numbers skip (e.g. 34 -> 42), so blank,
/// brace-only and some code lines of this struct are not present here.
33 struct StructuralHash {
// 32-byte digest, the output size of SHA-256.
34 using Hash = std::array<uint8_t, 32>;
/// Returns the full and the const-invariant digest of `arc`. Lines 43-44
/// are not visible; presumably they reset the state and hash `arc`.
42 StructuralHash hash(DefineOp arc) {
45 return StructuralHash{state.final(), stateConstInvariant.final()};
// Fragment of a reset routine whose signature is not visible: clears the
// suppression counter and the const-invariant index map and re-initialises
// the const-invariant SHA-256 state (the matching resets of `state`,
// `indices` and the index counters are presumably on the hidden lines).
51 disableConstInvariant = 0;
53 indicesConstInvariant.clear();
55 stateConstInvariant.init();
/// Feeds the raw bytes of the pointer value itself (not the pointee) into
/// `state`, and into `stateConstInvariant` unless suppressed.
58 void update(
const void *pointer) {
59 auto *
addr =
reinterpret_cast<const uint8_t *
>(&pointer);
60 state.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
61 if (disableConstInvariant == 0)
62 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof pointer));
/// Feeds the raw bytes of `value` into `state`, and into
/// `stateConstInvariant` unless suppressed.
65 void update(
size_t value) {
66 auto *
addr =
reinterpret_cast<const uint8_t *
>(&
value);
67 state.update(ArrayRef<uint8_t>(addr,
sizeof value));
68 if (disableConstInvariant == 0)
69 stateConstInvariant.update(ArrayRef<uint8_t>(addr,
sizeof value));
/// Hashes an operand's (index, const-invariant index) pair.
/// NOTE(review): both visible calls update `state`; `valueConstInvariant`
/// is never fed into `stateConstInvariant`, unlike the single-value
/// overloads above, which feed it while `disableConstInvariant == 0`. Since
/// `update(OpOperand &)` hashes operands only through this overload, the
/// const-invariant digest does not capture operand wiring at all; arcs that
/// differ only in wiring then share that digest and are told apart only by
/// the later StructuralEquivalence check. If unintended, the second call
/// should target `stateConstInvariant` (guarded like the overloads above).
/// Line 74 is not visible.
72 void update(
size_t value,
size_t valueConstInvariant) {
73 state.update(ArrayRef<uint8_t>(
reinterpret_cast<const uint8_t *
>(&
value),
75 state.update(ArrayRef<uint8_t>(
76 reinterpret_cast<const uint8_t *
>(&valueConstInvariant),
77 sizeof valueConstInvariant));
// Type IDs, types, attributes and operation names are hashed by their
// opaque pointer value.
80 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
82 void update(Type type) { update(type.getAsOpaquePointer()); }
84 void update(Attribute attr) { update(attr.getAsOpaquePointer()); }
86 void update(mlir::OperationName name) { update(name.getAsOpaquePointer()); }
// Block arguments and op results contribute only their type here; their
// position is captured through `indices` when they are used as operands.
88 void update(BlockArgument arg) { update(arg.getType()); }
90 void update(OpResult result) { update(result.getType()); }
/// Hashes an operand by the indices of the value it uses. The value must
/// already have been numbered by `update(Block &)`.
92 void update(OpOperand &operand) {
94 auto it = indices.find(operand.get());
95 auto itCI = indicesConstInvariant.find(operand.get());
96 assert(it != indices.end() && itCI != indicesConstInvariant.end() &&
97 "op should have been previously hashed");
98 update(it->second, itCI->second);
/// Numbers every value defined directly in `block`, then hashes the block's
/// arguments and ops.
101 void update(Block &block) {
// Arguments: next sequential index; const-invariant index is always 0.
105 for (
auto arg : block.getArguments()) {
106 indices.insert({arg, currentIndex++});
107 indicesConstInvariant.insert({arg, 0});
// Op results: next sequential index. Constant-like results get
// const-invariant index 0; the other results (presumably via an `else` on
// the hidden line 114) take the next const-invariant index, so constants do
// not shift the const-invariant indices of other values.
109 for (
auto &op : block) {
110 for (
auto result : op.getResults()) {
111 indices.insert({result, currentIndex++});
112 if (op.hasTrait<OpTrait::ConstantLike>())
113 indicesConstInvariant.insert({result, 0});
115 indicesConstInvariant.insert({result, currentIndexConstInvariant++});
// Hash the arguments with const-invariant hashing suppressed (loop body on
// the hidden line 122).
120 ++disableConstInvariant;
121 for (
auto arg : block.getArguments())
123 --disableConstInvariant;
// Hash every op of the block (loop body on a hidden line).
126 for (
auto &op : block)
/// Hashes one operation: its name, its attributes (skipped for DefineOp),
/// then loops over its operands, region count, nested blocks and results
/// (loop bodies on hidden lines). Constant-like ops are hashed with
/// const-invariant hashing suppressed, so they only reach `state`.
130 void update(Operation *op) {
// The suppression counter is raised here and lowered by the same amount at
// the end, so nested suppressions compose.
131 unsigned skipConstInvariant = op->hasTrait<OpTrait::ConstantLike>();
132 disableConstInvariant += skipConstInvariant;
134 update(op->getName());
// A DefineOp's own attributes are skipped - presumably so that its symbol
// name does not make otherwise identical arcs hash differently.
137 if (!isa<DefineOp>(op)) {
138 for (
auto namedAttr : op->getAttrDictionary()) {
139 auto name = namedAttr.getName();
140 auto value = namedAttr.getValue();
// Lines 141-142 are not visible; they may filter some attributes.
143 update(name.getAsOpaquePointer());
144 update(
value.getAsOpaquePointer());
149 for (
auto &operand : op->getOpOperands())
153 update(op->getNumRegions());
// NOTE(review): `®ion` in the range-for below is not valid C++; it looks
// like `&region` mangled through an HTML `&reg;` entity and should read
// `auto &region`.
154 for (
auto ®ion : op->getRegions())
155 for (
auto &block : region.getBlocks())
158 for (
auto result : op->getResults())
// Restore the suppression counter to its value on entry.
161 disableConstInvariant -= skipConstInvariant;
// Next sequential value index.
165 unsigned currentIndex = 0;
// Next const-invariant index; not advanced for arguments or constant-like
// results.
166 unsigned currentIndexConstInvariant = 0;
// Value -> index maps filled by update(Block &), read by update(OpOperand &).
167 DenseMap<Value, unsigned> indices;
168 DenseMap<Value, unsigned> indicesConstInvariant;
// While non-zero, the single-value update overloads skip stateConstInvariant.
170 unsigned disableConstInvariant = 0;
// Const-invariant digest state (the declaration of `state` is not visible).
175 llvm::SHA256 stateConstInvariant;
/// Decides whether two arcs are structurally equivalent by walking their
/// use-def graphs in lock step, starting from the terminators of their body
/// blocks. Mismatches confined to constant-like ops or block arguments are
/// recorded in `divergences` instead of failing the walk, and
/// `matchConstInvariant` reports whether the arcs match modulo those.
180 struct StructuralEquivalence {
181 using OpOperandPair = std::pair<OpOperand *, OpOperand *>;
// The context is currently unused.
182 explicit StructuralEquivalence(MLIRContext *context) {}
/// Runs the comparison; if `checkImpl` fails, the arcs do not match even
/// modulo constants. Line 186 is not visible (presumably it also clears the
/// exact-match flag).
184 void check(DefineOp arcA, DefineOp arcB) {
185 if (!checkImpl(arcA, arcB)) {
187 matchConstInvariant =
false;
// Operand pairs at which the arcs use different constants or arguments.
191 SmallSetVector<OpOperandPair, 1> divergences;
// NOTE(review): not initialised by the constructor; only meaningful after
// `check` has run.
193 bool matchConstInvariant;
/// Compares the terminators of two corresponding blocks and queues their
/// operand pairs. The terminators are compared with an empty operand pair,
/// so `compareOps` cannot record a terminator attribute mismatch as a
/// divergence.
196 bool addBlockToWorklist(Block &blockA, Block &blockB) {
197 auto *terminatorA = blockA.getTerminator();
198 auto *terminatorB = blockB.getTerminator();
199 if (!compareOps(terminatorA, terminatorB, OpOperandPair()))
201 if (!addOpToWorklist(terminatorA, terminatorB))
/// Queues every operand pair of `opA`/`opB` that is not yet in `handled`. If
/// `allOperandsHandled` is non-null it is cleared when anything is queued.
/// The result for differing operand counts is on a hidden line (presumably
/// `return false`).
208 bool addOpToWorklist(Operation *opA, Operation *opB,
209 bool *allOperandsHandled =
nullptr) {
210 if (opA->getNumOperands() != opB->getNumOperands())
212 for (
auto [operandA, operandB] :
213 llvm::zip(opA->getOpOperands(), opB->getOpOperands())) {
214 if (!handled.count({&operandA, &operandB})) {
215 worklist.emplace_back(&operandA, &operandB);
216 if (allOperandsHandled)
217 *allOperandsHandled =
false;
/// Compares two ops by name and attributes. A differing attribute value is
/// tolerated only on a constant-like op reached through a non-null operand
/// pair, in which case the pair is recorded as a divergence.
/// NOTE(review): `llvm::zip` stops at the shorter dictionary, so two
/// dictionaries of different sizes whose common prefix matches are not
/// rejected by the visible lines - confirm a size check exists on the
/// hidden lines 237-243.
223 bool compareOps(Operation *opA, Operation *opB, OpOperandPair values) {
224 if (opA->getName() != opB->getName())
226 if (opA->getAttrDictionary() != opB->getAttrDictionary()) {
227 for (
auto [namedAttrA, namedAttrB] :
228 llvm::zip(opA->getAttrDictionary(), opB->getAttrDictionary())) {
229 if (namedAttrA.getName() != namedAttrB.getName())
231 if (namedAttrA.getValue() == namedAttrB.getValue())
233 bool mayDiverge = opA->hasTrait<OpTrait::ConstantLike>();
234 if (!mayDiverge || !values.first || !values.second)
236 divergences.insert(values);
/// Worklist-driven comparison; returns false on a structural mismatch.
/// The setup lines 245-250 are only partly visible (presumably they also
/// clear the worklist, `handled` and `divergences`).
244 bool checkImpl(DefineOp arcA, DefineOp arcB) {
248 matchConstInvariant =
true;
// Result types must agree exactly.
251 if (arcA.getFunctionType().getResults() !=
252 arcB.getFunctionType().getResults())
255 if (!addBlockToWorklist(arcA.getBodyBlock(), arcB.getBodyBlock()))
// Process operand pairs from the back of the worklist until none remain.
258 while (!worklist.empty()) {
259 OpOperandPair values = worklist.back();
260 if (handled.contains(values)) {
// Corresponding values must have identical types.
265 auto valueA = values.first->get();
266 auto valueB = values.second->get();
267 if (valueA.getType() != valueB.getType())
269 auto *opA = valueA.getDefiningOp();
270 auto *opB = valueB.getDefiningOp();
// (Lines 271-273 are hidden.) A pair of block arguments is recorded as a
// divergence; differing argument numbers are handled on the hidden line 279.
274 auto argA = valueA.dyn_cast<BlockArgument>();
275 auto argB = valueB.dyn_cast<BlockArgument>();
277 divergences.insert(values);
278 if (argA.getArgNumber() != argB.getArgNumber())
280 handled.insert(values);
// A block argument on one side and a constant-like op on the other is also
// recorded as a divergence.
284 auto isConstA = opA && opA->hasTrait<OpTrait::ConstantLike>();
285 auto isConstB = opB && opB->hasTrait<OpTrait::ConstantLike>();
286 if ((argA && isConstB) || (argB && isConstA)) {
288 divergences.insert(values);
290 handled.insert(values);
// Queue the defining ops' operands first; if anything was newly queued, this
// pair is revisited later (hidden lines 301-303).
299 bool allHandled =
true;
300 if (!addOpToWorklist(opA, opB, &allHandled))
304 handled.insert(values);
308 if (!compareOps(opA, opB, values))
// The ops must have the same number of regions and blocks; each block pair
// is compared through its terminators.
312 if (opA->getNumRegions() != opB->getNumRegions())
314 for (
auto [regionA, regionB] :
315 llvm::zip(opA->getRegions(), opB->getRegions())) {
316 if (regionA.getBlocks().size() != regionB.getBlocks().size())
318 for (
auto [blockA, blockB] : llvm::zip(regionA, regionB))
319 if (!addBlockToWorklist(blockA, blockB))
// Pending operand pairs, processed from the back.
327 SmallVector<OpOperandPair, 0> worklist;
// Operand pairs that have already been compared.
328 DenseSet<OpOperandPair> handled;
/// Fragment of a helper whose name line (331-332) is not visible. For each
/// call site it rebuilds the argument list from `operandMappings`:
/// - an `unsigned` entry forwards the call's existing argument at that index;
/// - an `Operation *` entry uses result 0 of that op's entry in `clonedOps`.
/// NOTE(review): `clonedOps` is declared on hidden lines, and is presumably
/// filled there with clones placed at the call site. `newOperands` is
/// declared outside the loop and no visible line clears it. Confirm that both
/// are reset for every call site; otherwise one call's operands leak into the
/// next.
333 MutableArrayRef<mlir::CallOpInterface> callSites,
334 ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
336 SmallVector<Value> newOperands;
337 for (
auto &callOp : callSites) {
341 for (
auto mapping : operandMappings) {
// Op mapping: use result 0 of the entry for `op` (filled on hidden lines
// 345-346).
342 if (std::holds_alternative<Operation *>(mapping)) {
343 auto *op = std::get<Operation *>(mapping);
344 auto &newOp = clonedOps[op];
347 newOperands.push_back(newOp->getResult(0));
// Index mapping: forward the call's original argument.
349 newOperands.push_back(
350 callOp.getArgOperands()[std::get<unsigned>(mapping)]);
// Replace the call's argument operands with the rebuilt list.
353 callOp.getArgOperandsMutable().assign(newOperands);
// Fragment of a predicate whose signature (lines 356-357) is not visible.
// Returns true when the operand's value has no defining op (a block
// argument) or is produced by a constant-like op.
358 auto *op = operand.get().getDefiningOp();
359 return !op || op->hasTrait<OpTrait::ConstantLike>();
/// Pass that merges structurally equivalent arcs (`DefineOp`s) and redirects
/// their call sites to the surviving arc.
363 struct DedupPass :
public arc::impl::DedupBase<DedupPass> {
364 void runOnOperation()
override;
// Redirects the call sites of `oldArc` to `newArc` and drops `oldArc` from
// the lookup maps.
365 void replaceArcWith(DefineOp oldArc, DefineOp newArc);
// Arcs by symbol name; an entry is erased when its arc is replaced.
368 DenseMap<StringAttr, DefineOp> arcByName;
// Call-like ops that call each arc; moved to the surviving arc on
// replacement.
370 DenseMap<DefineOp, SmallVector<mlir::CallOpInterface, 1>> callSites;
// Constructor of the per-arc sort record (its enclosing struct header is not
// visible). `order` is the arc's position among the module's DefineOps and
// breaks ties between equal digests.
377 ArcHash(DefineOp defineOp, StructuralHash hash,
unsigned order)
378 : defineOp(defineOp), hash(hash), order(order) {}
/// Deduplicates the arcs of the module in two phases:
///  1. Exact merges: arcs with identical full digests that
///     StructuralEquivalence accepts are merged into the earliest such arc.
///  2. Constant-agnostic merges: among arcs with identical const-invariant
///     digests, the values that differ (constants/arguments) are outlined
///     into new arguments of one arc. Candidates whose differences map onto
///     those arguments are then merged into it.
/// Many lines of this function are hidden (gaps in the embedded numbering).
382 void DedupPass::runOnOperation() {
// Used to resolve call targets below.
385 SymbolTableCollection symbolTable;
// One record per arc: its digests and its original position.
388 SmallVector<ArcHash> arcHashes;
390 for (
auto defineOp : getOperation().getOps<DefineOp>()) {
391 arcHashes.emplace_back(defineOp, hasher.hash(defineOp), arcHashes.size());
392 arcByName.insert({defineOp.getSymNameAttr(), defineOp});
// Record each call-like op whose callee resolves to a DefineOp, keyed by the
// arc found through the callee's leaf symbol name.
396 getOperation().walk([&](mlir::CallOpInterface callOp) {
398 dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
399 callSites[arcByName.lookup(callOp.getCallableForCallee()
400 .get<mlir::SymbolRefAttr>()
401 .getLeafReference())]
// Sort by full digest so identical arcs are adjacent; ties keep module order.
// NOTE(review): the comparator takes `ArcHash` by value, copying two 32-byte
// digests per comparison; `const auto &` parameters would avoid that.
409 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
410 if (a.hash.hash < b.hash.hash)
412 if (a.hash.hash > b.hash.hash)
414 return a.order < b.order;
// Phase 1: exact merges.
419 LLVM_DEBUG(
llvm::dbgs() <<
"Check for exact merges (" << arcHashes.size()
421 StructuralEquivalence equiv(&getContext());
// Compare each arc with the following arcs that share its full digest.
// The hidden lines presumably skip consumed records and stop at the first
// different digest, which the sort makes sufficient.
422 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
424 auto [defineOp, hash, order] = arcHashes[arcIdx];
427 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
428 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
429 if (hash.hash != otherHash.hash)
// Hidden lines 434-436 presumably skip pairs that are not an exact match.
433 equiv.check(defineOp, otherDefineOp);
437 <<
"- Merge " << defineOp.getSymNameAttr() <<
" <- "
438 << otherDefineOp.getSymNameAttr() <<
"\n");
// Merge the later arc into the earlier one and mark its record as consumed.
439 replaceArcWith(otherDefineOp, defineOp);
440 arcHashes[otherIdx].defineOp = {};
// Phase 2: re-sort by const-invariant digest. The ordering of consumed
// (null) records is decided on hidden lines 454-458; the trimming below
// relies on them ending up at the back.
452 llvm::stable_sort(arcHashes, [](
auto a,
auto b) {
453 if (!a.defineOp && !b.defineOp)
459 if (a.hash.constInvariant < b.hash.constInvariant)
461 if (a.hash.constInvariant > b.hash.constInvariant)
463 return a.order < b.order;
// Drop consumed records from the back.
465 while (!arcHashes.empty() && !arcHashes.back().defineOp)
466 arcHashes.pop_back();
469 LLVM_DEBUG(
llvm::dbgs() <<
"Check for constant-agnostic merges ("
470 << arcHashes.size() <<
" arcs)\n");
// For each remaining arc, look for arcs that match it modulo constants.
471 for (
unsigned arcIdx = 0, arcEnd = arcHashes.size(); arcIdx != arcEnd;
473 auto [defineOp, hash, order] = arcHashes[arcIdx];
// Per-arc outlining state:
// - `outlineOperands`: operand -> group id; operands in one group are fed by
//   the same new argument;
// - `nextGroupId`: the next fresh group id;
// - a value -> (other value -> operands) map, named on a hidden line and used
//   as `operandMappings` below.
504 SmallMapVector<OpOperand *, unsigned, 8> outlineOperands;
505 unsigned nextGroupId = 1;
506 SmallMapVector<Value,
507 SmallMapVector<Value, SmallSetVector<OpOperand *, 1>, 2>, 2>
509 SmallVector<StringAttr> candidateNames;
// Collect candidates with the same const-invariant digest that match modulo
// constants.
511 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
512 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
513 if (hash.constInvariant != otherHash.constInvariant)
518 equiv.check(defineOp, otherDefineOp);
519 if (!equiv.matchConstInvariant)
521 candidateNames.push_back(otherDefineOp.getSymNameAttr());
// Group this candidate's divergent operands by (value in this arc, value in
// the candidate).
532 operandMappings.clear();
533 for (
auto [operand, otherOperand] : equiv.divergences) {
536 operandMappings[operand->get()][otherOperand->get()].insert(operand);
// Refine the groups: within each (value, other value) bucket, an operand's
// current group id is mapped through `remappedGroupIds` (declared, and
// presumably reset, on hidden lines). A fresh id is allocated when needed;
// the condition is on hidden line 557.
551 for (
auto &[
value, mappings] : operandMappings) {
552 for (
auto &[otherValue, operands] : mappings) {
554 for (
auto *operand : operands) {
555 auto &
id = outlineOperands[operand];
556 auto &remappedId = remappedGroupIds[id];
558 remappedId = nextGroupId++;
// Nothing to outline for this arc (hidden lines presumably move on).
565 if (outlineOperands.empty())
// Debug listing of the planned outlining (its LLVM_DEBUG opening is hidden).
568 llvm::dbgs() <<
"- Outlining " << outlineOperands.size()
569 <<
" operands from " << defineOp.getSymNameAttr() <<
"\n";
570 for (
auto entry : outlineOperands)
571 llvm::dbgs() <<
" - Operand #" << entry.first->getOperandNumber()
572 <<
" of " << *entry.first->getOwner() <<
"\n";
573 for (
auto name : candidateNames)
574 llvm::dbgs() <<
" - Candidate " << name <<
"\n";
// Order the outlined operands deterministically:
// - block arguments by argument number;
// - values defined in the same block by op position;
// - remaining ties by operand number (guard conditions on hidden lines).
583 llvm::stable_sort(outlineOperands, [](
auto &a,
auto &b) {
584 auto argA = a.first->get().
template dyn_cast<BlockArgument>();
585 auto argB = b.first->get().
template dyn_cast<BlockArgument>();
591 if (argA.getArgNumber() < argB.getArgNumber())
593 if (argA.getArgNumber() > argB.getArgNumber())
596 auto *opA = a.first->get().getDefiningOp();
597 auto *opB = b.first->get().getDefiningOp();
599 return a.first->getOperandNumber() < b.first->getOperandNumber();
600 if (opA->getBlock() == opB->getBlock())
601 return opA->isBeforeInBlock(opB);
// Build the new signature: new block arguments are appended after the old
// ones, presumably one per group (see the hidden check on line 619).
// `newOperands` records how a caller supplies each new argument: either the
// index of an old argument, or the op defining the value. Those defining ops
// are collected in `outlinedOps`. The hidden lines presumably also redirect
// each operand to its new argument.
611 unsigned oldArgumentCount = defineOp.getNumArguments();
613 SmallVector<Type> newInputTypes;
614 SmallVector<std::variant<Operation *, unsigned>> newOperands;
615 SmallPtrSet<Operation *, 8> outlinedOps;
617 for (
auto [operand, groupId] : outlineOperands) {
618 auto &arg = newArguments[groupId];
620 auto value = operand->get();
621 arg = defineOp.getBodyBlock().addArgument(
value.getType(),
623 newInputTypes.push_back(arg.getType());
624 if (
auto blockArg =
value.dyn_cast<BlockArgument>())
625 newOperands.push_back(blockArg.getArgNumber());
627 auto *op =
value.getDefiningOp();
628 newOperands.push_back(op);
629 outlinedOps.insert(op);
// Every old argument must now be unused; otherwise report an error that
// lists the remaining uses and fail the pass.
636 defineOp.getBodyBlock().getArguments().slice(0, oldArgumentCount)) {
637 if (!arg.use_empty()) {
638 auto d = defineOp.emitError(
639 "dedup failed to replace all argument uses; arc ")
640 << defineOp.getSymNameAttr() <<
", argument "
641 << arg.getArgNumber();
642 for (
auto &use : arg.getUses())
643 d.attachNote(use.getOwner()->getLoc())
644 <<
"used in operand " << use.getOperandNumber() <<
" here";
645 return signalPassFailure();
// Drop the old arguments and install the new function type (the call that
// sets it starts on a hidden line).
649 defineOp.getBodyBlock().eraseArguments(0, oldArgumentCount);
651 &getContext(), newInputTypes, defineOp.getFunctionType().getResults()));
// Outlined ops are processed by a hidden loop body (presumably erased once
// the call sites supply their values).
653 for (
auto *op : outlinedOps)
// Merge every candidate whose divergences can be expressed through the new
// arguments.
658 for (
unsigned otherIdx = arcIdx + 1; otherIdx != arcEnd; ++otherIdx) {
659 auto [otherDefineOp, otherHash, otherOrder] = arcHashes[otherIdx];
660 if (hash.constInvariant != otherHash.constInvariant)
666 equiv.check(defineOp, otherDefineOp);
667 if (!equiv.matchConstInvariant)
// Start the candidate's mapping with every new-argument slot unset (loop
// body on the hidden line 674).
672 std::variant<Operation *, unsigned> nullOperand =
nullptr;
673 for (
auto &operand : newOperands)
// Translate each divergence. This arc's side must be a block argument;
// otherwise the mapping fails (check on hidden line 679). The candidate's
// side supplies either an argument index or a defining op.
676 bool mappingFailed =
false;
677 for (
auto [operand, otherOperand] : equiv.divergences) {
678 auto arg = operand->get().dyn_cast<BlockArgument>();
680 mappingFailed =
true;
686 std::variant<Operation *, unsigned> newOperand;
687 if (
auto otherArg = otherOperand->get().dyn_cast<BlockArgument>())
688 newOperand = otherArg.getArgNumber();
690 newOperand = otherOperand->get().getDefiningOp();
// Two different values for the same new argument make the mapping fail.
693 auto &newOperandSlot = newOperands[arg.getArgNumber()];
694 if (newOperandSlot != nullOperand && newOperandSlot != newOperand) {
695 mappingFailed =
true;
698 newOperandSlot = newOperand;
// Skip the candidate if the mapping failed or left a new argument unset.
701 LLVM_DEBUG(
llvm::dbgs() <<
" - Mapping failed; skipping arc\n");
704 if (llvm::any_of(newOperands,
705 [&](
auto operand) {
return operand == nullOperand; })) {
707 <<
" - Not all operands mapped; skipping arc\n");
713 <<
" - Merged " << defineOp.getSymNameAttr() <<
" <- "
714 << otherDefineOp.getSymNameAttr() <<
"\n");
// Otherwise merge the candidate into this arc and mark its record consumed.
716 replaceArcWith(otherDefineOp, defineOp);
717 arcHashes[otherIdx].defineOp = {};
/// Redirects all recorded call sites of `oldArc` to `newArc`, updates the
/// dedup statistics and removes `oldArc` from the lookup maps. Line 728
/// (declaration of `newArcName`) and lines 735-736 are not visible here.
/// NOTE(review): `oldUses` and `newUses` are references into the same
/// DenseMap, obtained with two `operator[]` calls. If the second call inserts
/// `newArc` (e.g. an arc with no recorded call sites) and the map grows,
/// `oldUses` dangles and the loop below reads freed memory. Take `oldUses`
/// by value (e.g. `auto oldUses = std::move(callSites[oldArc]);`) before
/// looking up `newUses`.
722 void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc) {
// Statistics: one more deduplicated arc, plus the number of ops returned by
// `oldArc.getOps()`.
723 ++dedupPassNumArcsDeduped;
724 auto oldArcOps = oldArc.getOps();
725 dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
726 auto &oldUses = callSites[oldArc];
727 auto &newUses = callSites[newArc];
// Retarget each call and hand it over to the surviving arc.
729 for (
auto callOp : oldUses) {
730 callOp.setCalleeFromCallable(newArcName);
731 newUses.push_back(callOp);
// Forget the replaced arc.
733 callSites.erase(oldArc);
734 arcByName.erase(oldArc.getSymNameAttr());
// Factory body: returns a new DedupPass instance (the signature on line 738
// is not visible here).
739 return std::make_unique<DedupPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createDedupPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()