24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APSInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/Debug.h"
32#define DEBUG_TYPE "firrtl-reductions"
36using namespace firrtl;
38using llvm::SmallDenseSet;
39using llvm::SmallSetVector;
52 return tables->getSymbolTable(op);
62 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
69 tables = std::make_unique<SymbolTableCollection>();
74 std::unique_ptr<SymbolTableCollection>
tables;
81static std::optional<firrtl::FModuleOp>
84 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
85 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
87 return moduleOp ? std::optional(moduleOp) : std::nullopt;
98 module->walk([&](Operation *op) {
100 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
114 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
115 auto *op = use.getOwner();
116 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
118 if (use.getOperandNumber() != 0)
120 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
137 unsigned numRemoved = 0;
139 SymbolTableCollection symbolTables;
140 for (Operation &rootOp : *
module.getBody()) {
141 if (!isa<firrtl::CircuitOp>(&rootOp))
143 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
144 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
146 if (
auto *op = symbolTable.lookup(sym)) {
147 if (symbolUserMap.useEmpty(op)) {
156 if (numRemoved > 0 || numLost > 0) {
157 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
159 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
160 llvm::dbgs() <<
"\n";
169 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
170 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
171 nlasToRemove.insert(field.getAttr());
172 for (
auto namedAttr : dict)
173 markNLAsInAnnotation(namedAttr.getValue());
174 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
175 for (
auto attr : array)
176 markNLAsInAnnotation(attr);
184 op->walk([&](Operation *op) {
185 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
186 markNLAsInAnnotation(annos);
201struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
208 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
210 uint64_t
match(FModuleOp module)
override {
211 if (innerSymUses.hasInnerRef(module))
213 return moduleSizes.getModuleSize(module, symbols);
216 LogicalResult
rewrite(FModuleOp module)
override {
219 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
220 for (
auto attr :
module.getPortTypes()) {
221 auto type = cast<TypeAttr>(attr).getValue();
222 if (
auto refType = type_dyn_cast<RefType>(type))
223 if (
auto layer = refType.getLayer())
224 layers.insert(layer);
226 SmallVector<Attribute, 4> layersArray;
227 layersArray.reserve(layers.size());
228 for (
auto layer : layers)
229 layersArray.push_back(layer);
231 nlaRemover.markNLAsInOperation(module);
232 OpBuilder builder(module);
233 auto extmodule = FExtModuleOp::create(
234 builder, module->getLoc(),
235 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
236 module.getConventionAttr(), module.getPorts(),
237 builder.getArrayAttr(layersArray), StringRef(),
238 module.getAnnotationsAttr());
239 SymbolTable::setSymbolVisibility(extmodule,
240 SymbolTable::getSymbolVisibility(module));
245 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
257static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
260 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
265 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
266 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
268 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
269 invalidateOutputs(builder, subfield, invalidCache,
270 flip ^ element.value().isFlip);
271 if (subfield.use_empty())
272 subfield.getDefiningOp()->erase();
278 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
279 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
280 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
281 invalidateOutputs(builder, subindex, invalidCache,
flip);
282 if (subindex.use_empty())
283 subindex.getDefiningOp()->erase();
291 Value invalid = invalidCache.lookup(type);
293 invalid = firrtl::InvalidValueOp::create(builder, type);
294 invalidCache.insert({type, invalid});
296 firrtl::ConnectOp::create(builder, value, invalid);
300static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
302 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
305 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
306 for (
auto element :
llvm::enumerate(bundleType.getElements()))
307 connectToLeafs(builder,
308 firrtl::SubfieldOp::create(builder, dest, element.index()),
312 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
313 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
314 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
318 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
321 auto destWidth = type.getBitWidthOrSentinel();
322 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
323 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
324 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
325 if (!isa<firrtl::UIntType>(type)) {
326 if (isa<firrtl::SIntType>(type))
327 value = firrtl::AsSIntPrimOp::create(builder, value);
331 firrtl::ConnectOp::create(builder, dest, value);
335static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
336 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
339 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
340 for (
auto element :
llvm::enumerate(bundleType.getElements()))
343 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
346 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
347 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
348 reduceXor(builder, into,
349 builder.createOrFold<firrtl::SubindexOp>(value, i));
352 if (!isa<firrtl::UIntType>(type)) {
353 if (isa<firrtl::SIntType>(type))
354 value = firrtl::AsUIntPrimOp::create(builder, value);
358 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
364struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
367 erasedModules.clear();
375 SmallVector<Operation *> worklist;
376 auto deadInsts = erasedInsts;
377 for (
auto *op : erasedModules)
378 worklist.push_back(op);
379 while (!worklist.empty()) {
380 auto *op = worklist.pop_back_val();
381 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
382 op->walk([&](firrtl::InstanceOp instOp) {
383 auto moduleOp = cast<firrtl::FModuleLike>(
384 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
385 deadInsts.insert(instOp);
387 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
388 [&](Operation *user) { return deadInsts.contains(user); })) {
389 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
390 << moduleOp.getModuleName() <<
"`\n");
391 erasedModules.insert(moduleOp);
392 worklist.push_back(moduleOp);
397 for (
auto *op : erasedInsts)
399 for (
auto *op : erasedModules)
401 nlaRemover.remove(op);
404 uint64_t
match(firrtl::InstanceOp instOp)
override {
406 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
410 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
411 LLVM_DEBUG(llvm::dbgs()
412 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
413 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
415 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
416 auto result = instOp.getResult(i);
417 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
418 instOp.getPortName(i));
420 firrtl::WireOp::create(builder, result.getType(), name,
421 firrtl::NameKindEnum::DroppableName,
422 instOp.getPortAnnotation(i), StringAttr{})
424 invalidateOutputs(builder, wire, invalidCache,
425 instOp.getPortDirection(i) == firrtl::Direction::In);
426 result.replaceAllUsesWith(wire);
428 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
429 auto moduleOp = cast<firrtl::FModuleLike>(
430 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
431 nlaRemover.markNLAsInOperation(instOp);
432 erasedInsts.insert(instOp);
434 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
435 [&](Operation *user) { return erasedInsts.contains(user); })) {
436 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
437 << moduleOp.getModuleName() <<
"`\n");
438 erasedModules.insert(moduleOp);
443 std::string
getName()
const override {
return "instance-stubber"; }
448 llvm::DenseSet<Operation *> erasedInsts;
449 llvm::DenseSet<Operation *> erasedModules;
455struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
456 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
457 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
458 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
459 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
460 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
463 SmallVector<Value> outputs;
464 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
465 auto result = memOp.getResult(i);
466 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
467 memOp.getPortName(i));
469 firrtl::WireOp::create(builder, result.getType(), name,
470 firrtl::NameKindEnum::DroppableName,
471 memOp.getPortAnnotation(i), StringAttr{})
473 invalidateOutputs(builder, wire, invalidCache,
true);
474 result.replaceAllUsesWith(wire);
478 switch (memOp.getPortKind(i)) {
479 case firrtl::MemOp::PortKind::Read:
480 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
482 case firrtl::MemOp::PortKind::Write:
483 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
485 case firrtl::MemOp::PortKind::ReadWrite:
486 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
487 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
489 case firrtl::MemOp::PortKind::Debug:
494 if (!isa<firrtl::RefType>(result.getType())) {
497 cast<firrtl::BundleType>(wire.getType()).getNumElements();
498 for (
unsigned i = 0; i != numFields; ++i) {
499 if (i != 2 && i != 3 && i != 5)
500 reduceXor(builder, xorInputs,
501 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
504 reduceXor(builder, xorInputs, input);
509 outputs.push_back(output);
513 for (
auto output : outputs)
514 connectToLeafs(builder, output, xorInputs);
516 nlaRemover.markNLAsInOperation(memOp);
520 std::string
getName()
const override {
return "memory-stubber"; }
527static bool isFlowSensitiveOp(Operation *op) {
528 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
529 SubaccessOp, ObjectSubfieldOp>(op);
535template <
unsigned OpNum>
536struct FIRRTLOperandForwarder :
public Reduction {
537 uint64_t
match(Operation *op)
override {
538 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
540 if (isFlowSensitiveOp(op))
543 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
545 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
546 return resultTy && opTy &&
547 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
548 (resultTy.getBitWidthOrSentinel() == -1) ==
549 (opTy.getBitWidthOrSentinel() == -1) &&
550 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
552 LogicalResult
rewrite(Operation *op)
override {
554 ImplicitLocOpBuilder builder(op->getLoc(), op);
555 auto result = op->getResult(0);
556 auto operand = op->getOperand(OpNum);
557 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
558 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
559 auto resultWidth = resultTy.getBitWidthOrSentinel();
560 auto operandWidth = operandTy.getBitWidthOrSentinel();
562 if (resultWidth < operandWidth)
564 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
565 else if (resultWidth > operandWidth)
566 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
569 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
570 result.replaceAllUsesWith(newOp);
574 std::string
getName()
const override {
575 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
586 anyrefCastDummy.clear();
587 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
588 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
589 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
590 anyrefCastDummy.insert({circuitOp, classOp});
591 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
594 return WalkResult::skip();
598 uint64_t
match(Operation *op)
override {
599 if (op->hasTrait<OpTrait::ConstantLike>()) {
601 if (!matchPattern(op, m_Constant(&attr)))
603 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
604 if (intAttr.getValue().isZero())
606 if (
auto strAttr = dyn_cast<StringAttr>(attr))
609 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
610 if (floatAttr.getValue().isZero())
613 if (
auto listOp = dyn_cast<ListCreateOp>(op))
614 if (listOp.getElements().empty())
616 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
617 if (pathOp.getTarget().empty())
621 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
622 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
624 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
625 if (anyrefCastDummyNames[circuitOp].contains(className))
629 if (op->getNumResults() != 1)
631 if (op->hasAttr(
"inner_sym"))
633 if (isFlowSensitiveOp(op))
635 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
636 DoubleType, ListType, PathType, AnyRefType>(
637 op->getResult(0).getType());
640 LogicalResult
rewrite(Operation *op)
override {
641 OpBuilder builder(op);
642 auto type = op->getResult(0).getType();
645 if (isa<UIntType, SIntType>(type)) {
646 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
649 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
650 APSInt(width, isa<UIntType>(type)));
651 op->replaceAllUsesWith(newOp);
657 if (isa<StringType>(type)) {
658 auto attr = builder.getStringAttr(
"");
659 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
660 op->replaceAllUsesWith(newOp);
666 if (isa<FIntegerType>(type)) {
667 auto attr = builder.getIntegerAttr(builder.getI64Type(), 0);
668 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
669 op->replaceAllUsesWith(newOp);
675 if (isa<BoolType>(type)) {
676 auto attr = builder.getBoolAttr(
false);
677 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
678 op->replaceAllUsesWith(newOp);
684 if (isa<DoubleType>(type)) {
685 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
686 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
687 op->replaceAllUsesWith(newOp);
693 if (isa<ListType>(type)) {
695 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
696 op->replaceAllUsesWith(newOp);
702 if (isa<PathType>(type)) {
703 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
704 op->replaceAllUsesWith(newOp);
710 if (isa<AnyRefType>(type)) {
711 auto circuitOp = op->getParentOfType<CircuitOp>();
712 auto &dummy = anyrefCastDummy[circuitOp];
714 OpBuilder::InsertionGuard guard(builder);
715 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
716 auto &symbolTable = symbols.getNearestSymbolTable(op);
717 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
718 symbolTable.insert(dummy);
719 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
721 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
723 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
724 op->replaceAllUsesWith(anyrefOp);
732 std::string
getName()
const override {
return "firrtl-constantifier"; }
744struct ConnectInvalidator :
public Reduction {
745 uint64_t
match(Operation *op)
override {
746 if (!isa<FConnectLike>(op))
748 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
749 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
750 isa<InvalidValueOp>(srcOp))
752 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
753 return type && type.isPassive();
755 LogicalResult
rewrite(Operation *op)
override {
757 auto rhs = op->getOperand(1);
758 OpBuilder builder(op);
759 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
760 auto *rhsOp = rhs.getDefiningOp();
761 op->setOperand(1, invOp);
766 std::string
getName()
const override {
return "connect-invalidator"; }
773struct AnnotationRemover :
public Reduction {
774 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
775 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
778 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
779 uint64_t matchId = 0;
782 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
783 for (
unsigned i = 0; i < annos.size(); ++i)
784 addMatch(1, matchId++);
787 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
788 for (
auto portAnnoArray : portAnnos)
789 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
790 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
791 addMatch(1, matchId++);
795 ArrayRef<uint64_t> matches)
override {
797 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
800 uint64_t matchId = 0;
801 auto processAnnotations =
802 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
803 SmallVector<Attribute> newAnnotations;
804 for (
auto anno : annotations) {
805 if (!matchesSet.contains(matchId)) {
806 newAnnotations.push_back(anno);
809 nlaRemover.markNLAsInAnnotation(anno);
813 return ArrayAttr::get(op->getContext(), newAnnotations);
817 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
818 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
822 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
823 SmallVector<Attribute> newPortAnnos;
824 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
825 newPortAnnos.push_back(
826 processAnnotations(portAnnoArrayAttr.getValue()));
828 op->setAttr(
"portAnnotations",
829 ArrayAttr::get(op->getContext(), newPortAnnos));
835 std::string
getName()
const override {
return "annotation-remover"; }
842struct SimplifyResets :
public OpReduction<CircuitOp> {
843 uint64_t
match(CircuitOp circuit)
override {
844 uint64_t numResets = 0;
845 AttrTypeWalker walker;
846 walker.addWalk([&](ResetType type) { ++numResets; });
848 circuit.walk([&](Operation *op) {
849 for (
auto result : op->getResults())
850 walker.walk(result.getType());
852 for (
auto ®ion : op->getRegions())
853 for (auto &block : region)
854 for (auto arg : block.getArguments())
855 walker.walk(arg.getType());
857 walker.walk(op->getAttrDictionary());
863 LogicalResult
rewrite(CircuitOp circuit)
override {
864 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
865 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
867 AttrTypeReplacer replacer;
868 replacer.addReplacement([&](ResetType type) {
869 return type.isConst() ? constUint1Type : uint1Type;
871 replacer.recursivelyReplaceElementsIn(circuit,
true,
876 circuit.walk([&](Operation *op) {
885 if (
auto module = dyn_cast<FModuleLike>(op)) {
898 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
904struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
905 uint64_t
match(firrtl::FModuleOp module)
override {
906 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
909 return circuit.getNameAttr() ==
module.getNameAttr();
911 LogicalResult
rewrite(firrtl::FModuleOp module)
override {
913 size_t numPorts =
module.getNumPorts();
914 llvm::BitVector dropPorts(numPorts);
915 for (
unsigned i = 0; i != numPorts; ++i) {
919 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
923 module.erasePorts(dropPorts);
926 std::string
getName()
const override {
return "root-port-pruner"; }
931struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
936 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
938 uint64_t
match(firrtl::InstanceOp instOp)
override {
939 return isa<firrtl::FExtModuleOp>(
940 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
942 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
944 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
945 symbols.getNearestSymbolTable(instOp)))
947 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
948 SmallVector<Value> replacementWires;
950 auto wire = firrtl::WireOp::create(
952 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
954 if (
info.isOutput()) {
955 auto inv = firrtl::InvalidValueOp::create(builder,
info.type);
956 firrtl::ConnectOp::create(builder, wire, inv);
958 replacementWires.push_back(wire);
960 nlaRemover.markNLAsInOperation(instOp);
961 instOp.replaceAllUsesWith(std::move(replacementWires));
965 std::string
getName()
const override {
return "extmodule-instance-remover"; }
973struct ConnectForwarder :
public Reduction {
974 uint64_t
match(Operation *op)
override {
975 if (!isa<firrtl::FConnectLike>(op))
977 auto dest = op->getOperand(0);
978 auto src = op->getOperand(1);
979 auto *destOp = dest.getDefiningOp();
980 auto *srcOp = src.getDefiningOp();
986 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
992 unsigned numConnects = 0;
993 for (
auto &use : dest.getUses()) {
994 auto *op = use.getOwner();
995 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
996 if (++numConnects > 1)
1000 if (srcOp && !srcOp->isBeforeInBlock(op))
1007 LogicalResult
rewrite(Operation *op)
override {
1008 auto dst = op->getOperand(0);
1009 auto src = op->getOperand(1);
1010 dst.replaceAllUsesWith(src);
1012 if (
auto *dstOp = dst.getDefiningOp())
1014 if (
auto *srcOp = src.getDefiningOp())
1019 std::string
getName()
const override {
return "connect-forwarder"; }
1024template <
unsigned OpNum>
1025struct ConnectSourceOperandForwarder :
public Reduction {
1026 uint64_t
match(Operation *op)
override {
1027 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1029 auto dest = op->getOperand(0);
1030 auto *destOp = dest.getDefiningOp();
1033 if (!destOp || !destOp->hasOneUse() ||
1034 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1037 auto *srcOp = op->getOperand(1).getDefiningOp();
1038 if (!srcOp || OpNum >= srcOp->getNumOperands())
1041 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1043 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1045 return resultTy && opTy &&
1046 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1047 ((resultTy.getBitWidthOrSentinel() == -1) ==
1048 (opTy.getBitWidthOrSentinel() == -1)) &&
1049 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1052 LogicalResult
rewrite(Operation *op)
override {
1053 auto *destOp = op->getOperand(0).getDefiningOp();
1054 auto *srcOp = op->getOperand(1).getDefiningOp();
1055 auto forwardedOperand = srcOp->getOperand(OpNum);
1056 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1058 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1059 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1063 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1066 auto clock = destOp->getOperand(0);
1067 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1068 clock, regName ? regName.str() :
"")
1073 builder.setInsertionPointAfter(op);
1074 if (isa<firrtl::ConnectOp>(op))
1075 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1077 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1087 std::string
getName()
const override {
1088 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1095struct DetachSubaccesses :
public Reduction {
1096 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1098 for (
auto *op : opsToErase)
1099 op->dropAllReferences();
1100 for (
auto *op : opsToErase)
1103 uint64_t
match(Operation *op)
override {
1106 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1107 llvm::all_of(op->getUses(), [](
auto &use) {
1108 return use.getOperandNumber() == 0 &&
1109 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1110 firrtl::SubaccessOp>(use.getOwner());
1113 LogicalResult
rewrite(Operation *op)
override {
1115 OpBuilder builder(op);
1116 bool isWire = isa<firrtl::WireOp>(op);
1119 invalidClock = firrtl::InvalidValueOp::create(
1120 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1121 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1122 builder.setInsertionPoint(user);
1123 auto type = user->getResult(0).getType();
1126 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1129 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1130 user->replaceAllUsesWith(replOp);
1131 opsToErase.insert(user);
1133 opsToErase.insert(op);
1136 std::string
getName()
const override {
return "detach-subaccesses"; }
1137 llvm::DenseSet<Operation *> opsToErase;
1143struct NodeSymbolRemover :
public Reduction {
1148 uint64_t
match(Operation *op)
override {
1150 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1151 if (!sym || sym.empty())
1155 if (innerSymUses.hasInnerRef(op))
1160 LogicalResult
rewrite(Operation *op)
override {
1161 op->removeAttr(
"inner_sym");
1165 std::string
getName()
const override {
return "node-symbol-remover"; }
1174hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1183 LogicalResult walkResult = targetTable.
walkSymbols(
1186 if (parentTable.lookup(name)) {
1194 return failed(walkResult);
1198struct EagerInliner :
public OpReduction<InstanceOp> {
1203 for (
auto circuitOp : op.getOps<CircuitOp>())
1204 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1205 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1208 nlaRemover.remove(op);
1210 innerSymTables.reset();
1213 uint64_t
match(InstanceOp instOp)
override {
1214 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1216 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1219 if (!isa<FModuleOp>(moduleOp))
1223 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1226 auto it = nlaTables.find(circuitOp);
1227 if (it == nlaTables.end() || !it->second)
1229 DenseSet<hw::HierPathOp> nlas;
1230 it->second->getInstanceNLAs(instOp, nlas);
1236 auto parentOp = instOp->getParentOfType<FModuleLike>();
1237 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1243 LogicalResult
rewrite(InstanceOp instOp)
override {
1244 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1245 auto moduleOp = cast<FModuleOp>(
1246 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1248 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1249 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1252 IRRewriter rewriter(instOp);
1253 SmallVector<Value> argWires;
1254 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1255 auto result = instOp.getResult(i);
1256 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1257 instOp.getPortName(i));
1258 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1259 name, NameKindEnum::DroppableName,
1260 instOp.getPortAnnotation(i), StringAttr{})
1262 result.replaceAllUsesWith(wire);
1263 argWires.push_back(wire);
1267 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1271 nlaRemover.markNLAsInOperation(instOp);
1273 nlaRemover.markNLAsInOperation(moduleOp);
1276 clonedModuleOp.erase();
1280 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1285 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1286 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1290struct ObjectInliner :
public OpReduction<ObjectOp> {
1292 blocksToSort.clear();
1295 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1298 for (
auto *block : blocksToSort)
1299 mlir::sortTopologically(block);
1300 blocksToSort.clear();
1301 nlaRemover.remove(op);
1302 innerSymTables.reset();
1305 uint64_t
match(ObjectOp objOp)
override {
1306 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1308 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1311 if (!isa<ClassOp>(classOp))
1316 auto parentOp = objOp->getParentOfType<FModuleLike>();
1317 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1321 for (
auto *user : objOp.getResult().getUsers())
1322 if (!isa<ObjectSubfieldOp>(user))
1328 LogicalResult
rewrite(ObjectOp objOp)
override {
1329 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1330 auto classOp = cast<ClassOp>(
1331 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1332 auto clonedClassOp = classOp.clone();
1335 IRRewriter rewriter(objOp);
1336 SmallVector<Value> portWires;
1337 auto classType = objOp.getType();
1340 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1341 auto element = classType.getElement(i);
1342 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1343 element.name.getValue());
1344 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1345 NameKindEnum::DroppableName,
1346 rewriter.getArrayAttr({}), StringAttr{})
1348 portWires.push_back(wire);
1352 SmallVector<ObjectSubfieldOp> subfieldOps;
1353 for (
auto *user : objOp.getResult().getUsers()) {
1354 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1355 subfieldOps.push_back(subfieldOp);
1356 auto index = subfieldOp.getIndex();
1357 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1361 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1367 SmallVector<FConnectLike> connectsToErase;
1368 for (
auto portWire : portWires) {
1372 for (
auto *user : portWire.getUsers()) {
1373 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1374 if (
connect.getDest() == portWire) {
1376 connectsToErase.push_back(connect);
1386 portWire.replaceAllUsesWith(value);
1387 for (
auto connect : connectsToErase)
1389 if (portWire.use_empty())
1390 portWire.getDefiningOp()->erase();
1391 connectsToErase.clear();
1395 nlaRemover.markNLAsInOperation(objOp);
1400 blocksToSort.insert(objOp->getBlock());
1403 for (
auto subfieldOp : subfieldOps)
1406 clonedClassOp.erase();
1410 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1413 SetVector<Block *> blocksToSort;
1416 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1430struct ModuleInternalNameSanitizer :
public Reduction {
1431 uint64_t
match(Operation *op)
override {
1433 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1434 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1435 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1436 firrtl::CoverOp>(op);
1438 LogicalResult
rewrite(Operation *op)
override {
1439 TypeSwitch<Operation *, void>(op)
1440 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1441 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1442 [](
auto op) { op.setName(
"reg"); })
1443 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1444 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1445 [](
auto op) { op.setName(
"mem"); })
1446 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1447 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1448 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1453 std::string
getName()
const override {
1454 return "module-internal-name-sanitizer";
1459 bool isOneShot()
const override {
return true; }
1473struct ModuleNameSanitizer :
OpReduction<firrtl::CircuitOp> {
1475 const char *names[48] = {
1476 "Foo",
"Bar",
"Baz",
"Qux",
"Quux",
"Quuux",
"Quuuux",
1477 "Quz",
"Corge",
"Grault",
"Bazola",
"Ztesch",
"Thud",
"Grunt",
1478 "Bletch",
"Fum",
"Fred",
"Jim",
"Sheila",
"Barney",
"Flarp",
1479 "Zxc",
"Spqr",
"Wombat",
"Shme",
"Bongo",
"Spam",
"Eggs",
1480 "Snork",
"Zot",
"Blarg",
"Wibble",
"Toto",
"Titi",
"Tata",
1481 "Tutu",
"Pippo",
"Pluto",
"Paperino",
"Aap",
"Noot",
"Mies",
1482 "Oogle",
"Foogle",
"Boogle",
"Zork",
"Gork",
"Bork"};
1484 size_t nameIndex = 0;
1487 if (nameIndex >= 48)
1489 return names[nameIndex++];
1492 size_t portNameIndex = 0;
1494 char getPortName() {
1495 if (portNameIndex >= 26)
1497 return 'a' + portNameIndex++;
1502 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1506 auto *circuitName =
getName();
1507 iGraph.getTopLevelModule().setName(circuitName);
1508 circuitOp.setName(circuitName);
1510 for (
auto *node : iGraph) {
1511 auto module = node->getModule<firrtl::FModuleLike>();
1513 bool shouldReplacePorts =
false;
1514 SmallVector<Attribute> newNames;
1515 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1520 auto oldPorts = fmodule.getPorts();
1521 shouldReplacePorts = !oldPorts.empty();
1522 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1523 auto port = oldPorts[i];
1525 .
Case<firrtl::ClockType>(
1526 [&](
auto a) {
return ns.
newName(
"clk"); })
1527 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1528 [&](
auto a) {
return ns.
newName(
"rst"); })
1529 .Case<firrtl::RefType>(
1530 [&](
auto a) {
return ns.
newName(
"ref"); })
1531 .Default([&](
auto a) {
1532 return ns.
newName(Twine(getPortName()));
1534 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1536 fmodule->setAttr(
"portNames",
1537 ArrayAttr::get(fmodule.getContext(), newNames));
1540 if (module == iGraph.getTopLevelModule())
1542 auto newName = StringAttr::get(circuitOp.getContext(),
getName());
1543 module.setName(newName);
1544 for (
auto *use : node->uses()) {
1545 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1546 instanceOp.setModuleName(newName);
1547 instanceOp.setName(newName);
1548 if (shouldReplacePorts)
1549 instanceOp.setPortNamesAttr(
1550 ArrayAttr::get(circuitOp.getContext(), newNames));
1559 std::string
getName()
const override {
return "module-name-sanitizer"; }
1563 bool isOneShot()
const override {
return true; }
1582struct ModuleSwapper :
public OpReduction<InstanceOp> {
1584 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1585 struct CircuitState {
1586 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1587 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1588 std::unique_ptr<NLATable> nlaTable;
1594 moduleSizes.clear();
1595 circuitStates.clear();
1598 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1599 auto &state = circuitStates[circuitOp];
1600 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1601 buildModuleTypeGroups(circuitOp, state);
1602 return WalkResult::skip();
1605 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1611 PortSignature getModulePortSignature(FModuleLike module) {
1612 PortSignature signature;
1613 signature.reserve(module.getNumPorts());
1614 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1615 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1620 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1622 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1623 auto signature = getModulePortSignature(module);
1624 state.moduleTypeGroups[signature].push_back(module);
1628 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1629 if (modules.size() <= 1)
1632 FModuleLike smallestModule =
nullptr;
1633 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1635 for (
auto module : modules) {
1636 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1637 if (size < smallestSize) {
1638 smallestSize = size;
1639 smallestModule =
module;
1644 for (
auto module : modules) {
1645 if (module != smallestModule) {
1646 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
1653 uint64_t
match(InstanceOp instOp)
override {
1655 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1657 const auto &state = circuitStates.at(circuitOp);
1660 DenseSet<hw::HierPathOp> nlas;
1661 state.nlaTable->getInstanceNLAs(instOp, nlas);
1666 auto moduleName = instOp.getModuleNameAttr().getAttr();
1667 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1668 if (!canonicalModule)
1672 auto currentModule = cast<FModuleLike>(
1673 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1674 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1675 uint64_t canonicalSize =
1676 moduleSizes.getModuleSize(canonicalModule, symbols);
1677 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1680 LogicalResult
rewrite(InstanceOp instOp)
override {
1682 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1684 const auto &state = circuitStates.at(circuitOp);
1687 auto canonicalModule = state.instanceToCanonicalModule.at(
1688 instOp.getModuleNameAttr().getAttr());
1689 auto canonicalName = canonicalModule.getModuleNameAttr();
1690 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1693 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1698 std::string
getName()
const override {
return "firrtl-module-swapper"; }
1707 DenseMap<CircuitOp, CircuitState> circuitStates;
1725struct ForceDedup :
public OpReduction<CircuitOp> {
1729 modulesToErase.clear();
1730 moduleSizes.clear();
1733 nlaRemover.remove(op);
1734 for (
auto mod : modulesToErase)
1739 void matches(CircuitOp circuitOp,
1740 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1741 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
1743 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1747 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
1748 if (!modulesAttr || modulesAttr.size() < 2)
1754 uint64_t totalSize = 0;
1755 ArrayAttr portTypes;
1756 DenseBoolArrayAttr portDirections;
1757 bool allSame =
true;
1758 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
1764 auto mod = symbolTable.lookup<FModuleLike>(target->module);
1769 totalSize += moduleSizes.getModuleSize(mod, symbols);
1771 portTypes = mod.getPortTypesAttr();
1772 portDirections = mod.getPortDirectionsAttr();
1773 }
else if (portTypes != mod.getPortTypesAttr() ||
1774 portDirections != mod.getPortDirectionsAttr()) {
1784 addMatch(totalSize, annoIdx);
1789 ArrayRef<uint64_t> matches)
override {
1790 auto *context = circuitOp->getContext();
1794 SmallVector<Annotation> newAnnotations;
1796 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1798 if (!llvm::is_contained(matches, annoIdx)) {
1799 newAnnotations.push_back(anno);
1802 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
1804 modulesAttr.size() >= 2);
1807 SmallVector<StringAttr> moduleNames;
1808 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1810 auto refStr = moduleRef.getValue();
1811 auto pipePos = refStr.find(
'|');
1812 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1813 auto moduleName = refStr.substr(pipePos + 1);
1814 moduleNames.push_back(StringAttr::get(context, moduleName));
1819 if (moduleNames.size() < 2)
1824 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1825 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
1827 if (newAnnotations.size() == annotations.size())
1832 newAnnoSet.applyToOperation(circuitOp);
1836 std::string
getName()
const override {
return "firrtl-force-dedup"; }
1842 void replaceModuleReferences(CircuitOp circuitOp,
1843 ArrayRef<StringAttr> moduleNames,
1846 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1847 auto &symbolTable = symbols.getSymbolTable(tableOp);
1848 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
1849 auto *context = circuitOp->getContext();
1853 FModuleLike canonicalModule;
1854 SmallVector<FModuleLike> modulesToReplace;
1855 for (
auto name : moduleNames) {
1856 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
1857 if (!canonicalModule)
1858 canonicalModule = mod;
1860 modulesToReplace.push_back(mod);
1863 if (modulesToReplace.empty())
1867 auto canonicalName = canonicalModule.getModuleNameAttr();
1868 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1869 for (
auto moduleName : moduleNames) {
1870 if (moduleName == canonicalName)
1872 auto *symbolOp = symbolTable.lookup(moduleName);
1875 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
1876 auto instOp = dyn_cast<InstanceOp>(user);
1877 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
1879 instOp.setModuleNameAttr(canonicalRef);
1880 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1886 for (
auto oldMod : modulesToReplace) {
1887 SmallVector<hw::HierPathOp> nlaOps(
1888 nlaTable.
lookup(oldMod.getModuleNameAttr()));
1889 for (
auto nlaOp : nlaOps) {
1890 nlaTable.
erase(nlaOp);
1891 StringAttr oldModName = oldMod.getModuleNameAttr();
1892 StringAttr newModName = canonicalName;
1893 SmallVector<Attribute, 4> newPath;
1894 for (
auto nameRef : nlaOp.getNamepath()) {
1895 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1896 if (ref.getModule() == oldModName) {
1897 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1898 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1899 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1900 if (oldInst && newInst) {
1901 oldModName = oldInst.getReferencedModuleNameAttr();
1902 newModName = newInst.getReferencedModuleNameAttr();
1905 newPath.push_back(ref);
1906 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1907 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1909 newPath.push_back(nameRef);
1912 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1918 for (
auto module : modulesToReplace) {
1919 nlaRemover.markNLAsInOperation(module);
1920 modulesToErase.insert(module);
1926 SetVector<FModuleLike> modulesToErase;
1946struct MustDedupChildren :
public OpReduction<CircuitOp> {
1951 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1955 void matches(CircuitOp circuitOp,
1956 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1958 uint64_t matchId = 0;
1960 DenseSet<StringRef> modulesAlreadyInMustDedup;
1961 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
1963 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
1964 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1966 modulesAlreadyInMustDedup.insert(target->module);
1968 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1972 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
1973 if (!modulesAttr || modulesAttr.size() < 2)
1977 processInstanceGroups(
1978 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1982 SmallDenseSet<StringAttr, 4> moduleTargets;
1983 for (
auto instOp : instanceGroup)
1984 moduleTargets.insert(instOp.getReferencedModuleNameAttr());
1985 if (moduleTargets.size() < 2)
1990 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
1991 return modulesAlreadyInMustDedup.contains(
1992 inst.getReferencedModuleName());
1996 addMatch(1, matchId - 1);
2002 ArrayRef<uint64_t> matches)
override {
2003 auto *context = circuitOp->getContext();
2005 SmallVector<Annotation> newAnnotations;
2006 uint64_t matchId = 0;
2008 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2010 newAnnotations.push_back(anno);
2014 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2015 if (!modulesAttr || modulesAttr.size() < 2) {
2016 newAnnotations.push_back(anno);
2020 processInstanceGroups(
2021 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2023 if (!llvm::is_contained(matches, matchId++))
2027 SmallSetVector<StringAttr, 4> moduleTargets;
2028 for (
auto instOp : instanceGroup) {
2030 target.circuit = circuitOp.getName();
2031 target.module = instOp.getReferencedModuleName();
2032 moduleTargets.insert(target.toStringAttr(context));
2036 SmallVector<NamedAttribute> newAnnoAttrs;
2037 newAnnoAttrs.emplace_back(
2038 StringAttr::get(context,
"class"),
2040 newAnnoAttrs.emplace_back(
2041 StringAttr::get(context,
"modules"),
2042 ArrayAttr::get(context,
2043 SmallVector<Attribute>(moduleTargets.begin(),
2044 moduleTargets.end())));
2046 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2047 newAnnotations.emplace_back(newAnnoDict);
2051 newAnnotations.push_back(anno);
2056 newAnnoSet.applyToOperation(circuitOp);
2060 std::string
getName()
const override {
return "must-dedup-children"; }
2068 void processInstanceGroups(
2069 CircuitOp circuitOp, ArrayAttr modulesAttr,
2070 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2071 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2074 SmallVector<FModuleLike> modules;
2075 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2077 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2078 modules.push_back(mod);
2081 if (modules.size() < 2)
2088 struct InstanceGroup {
2089 SmallVector<FInstanceLike> instances;
2090 bool nameIsUnique =
true;
2092 MapVector<StringAttr, InstanceGroup> instanceGroups;
2093 for (
auto module : modules) {
2095 module.walk([&](FInstanceLike instOp) {
2096 if (isa<ObjectOp>(instOp.getOperation()))
2098 auto name = instOp.getInstanceNameAttr();
2099 auto &group = instanceGroups[name];
2100 if (nameCounts[name]++ > 1)
2101 group.nameIsUnique =
false;
2102 group.instances.push_back(instOp);
2108 for (
auto &[name, group] : instanceGroups)
2109 if (group.nameIsUnique && group.instances.size() >= 2)
2110 callback(group.instances);
2130 patterns.add<SimplifyResets, 34>();
2132 patterns.add<MustDedupChildren, 32>();
2133 patterns.add<AnnotationRemover, 31>();
2140 firrtl::createLowerCHIRRTLPass(),
true,
true);
2145 patterns.add<FIRRTLModuleExternalizer, 25>();
2146 patterns.add<InstanceStubber, 24>();
2151 firrtl::createLowerFIRRTLTypes(),
true,
true);
2158 firrtl::createRemoveUnusedPorts({
true}));
2159 patterns.add<NodeSymbolRemover, 15>();
2160 patterns.add<ConnectForwarder, 14>();
2161 patterns.add<ConnectInvalidator, 13>();
2163 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2164 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2165 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2166 patterns.add<DetachSubaccesses, 7>();
2168 patterns.add<ExtmoduleInstanceRemover, 4>();
2169 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2170 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2171 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2172 patterns.add<ModuleInternalNameSanitizer, 0>();
2173 patterns.add<ModuleNameSanitizer, 0>();
2177 mlir::DialectRegistry ®istry) {
2178 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
constexpr const char * excludeFromFullResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
constexpr const char * fullResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
constexpr const char * fullAsyncResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
constexpr const char * ignoreFullAsyncResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)