24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APSInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/Debug.h"
32#define DEBUG_TYPE "firrtl-reductions"
36using namespace firrtl;
38using llvm::SmallSetVector;
51 return tables->getSymbolTable(op);
61 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
68 tables = std::make_unique<SymbolTableCollection>();
73 std::unique_ptr<SymbolTableCollection>
tables;
80static std::optional<firrtl::FModuleOp>
83 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
84 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
86 return moduleOp ? std::optional(moduleOp) : std::nullopt;
97 module->walk([&](Operation *op) {
99 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
113 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
114 auto *op = use.getOwner();
115 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
117 if (use.getOperandNumber() != 0)
119 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
136 unsigned numRemoved = 0;
138 SymbolTableCollection symbolTables;
139 for (Operation &rootOp : *
module.getBody()) {
140 if (!isa<firrtl::CircuitOp>(&rootOp))
142 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
143 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
145 if (
auto *op = symbolTable.lookup(sym)) {
146 if (symbolUserMap.useEmpty(op)) {
155 if (numRemoved > 0 || numLost > 0) {
156 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
158 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
159 llvm::dbgs() <<
"\n";
168 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
169 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
170 nlasToRemove.insert(field.getAttr());
171 for (
auto namedAttr : dict)
172 markNLAsInAnnotation(namedAttr.getValue());
173 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
174 for (
auto attr : array)
175 markNLAsInAnnotation(attr);
183 op->walk([&](Operation *op) {
184 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
185 markNLAsInAnnotation(annos);
200struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
207 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
209 uint64_t
match(FModuleOp module)
override {
210 if (innerSymUses.hasInnerRef(module))
212 return moduleSizes.getModuleSize(module, symbols);
215 LogicalResult
rewrite(FModuleOp module)
override {
218 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
219 for (
auto attr :
module.getPortTypes()) {
220 auto type = cast<TypeAttr>(attr).getValue();
221 if (
auto refType = type_dyn_cast<RefType>(type))
222 if (
auto layer = refType.getLayer())
223 layers.insert(layer);
225 SmallVector<Attribute, 4> layersArray;
226 layersArray.reserve(layers.size());
227 for (
auto layer : layers)
228 layersArray.push_back(layer);
230 nlaRemover.markNLAsInOperation(module);
231 OpBuilder builder(module);
232 auto extmodule = FExtModuleOp::create(
233 builder, module->getLoc(),
234 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
235 module.getConventionAttr(), module.getPorts(),
236 builder.getArrayAttr(layersArray), StringRef(),
237 module.getAnnotationsAttr());
238 SymbolTable::setSymbolVisibility(extmodule,
239 SymbolTable::getSymbolVisibility(module));
244 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
256static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
259 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
264 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
265 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
267 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
268 invalidateOutputs(builder, subfield, invalidCache,
269 flip ^ element.value().isFlip);
270 if (subfield.use_empty())
271 subfield.getDefiningOp()->erase();
277 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
278 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
279 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
280 invalidateOutputs(builder, subindex, invalidCache,
flip);
281 if (subindex.use_empty())
282 subindex.getDefiningOp()->erase();
290 Value invalid = invalidCache.lookup(type);
292 invalid = firrtl::InvalidValueOp::create(builder, type);
293 invalidCache.insert({type, invalid});
295 firrtl::ConnectOp::create(builder, value, invalid);
299static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
301 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
304 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
305 for (
auto element :
llvm::enumerate(bundleType.getElements()))
306 connectToLeafs(builder,
307 firrtl::SubfieldOp::create(builder, dest, element.index()),
311 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
312 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
313 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
317 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
320 auto destWidth = type.getBitWidthOrSentinel();
321 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
322 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
323 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
324 if (!isa<firrtl::UIntType>(type)) {
325 if (isa<firrtl::SIntType>(type))
326 value = firrtl::AsSIntPrimOp::create(builder, value);
330 firrtl::ConnectOp::create(builder, dest, value);
334static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
335 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
338 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
339 for (
auto element :
llvm::enumerate(bundleType.getElements()))
342 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
345 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
346 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
347 reduceXor(builder, into,
348 builder.createOrFold<firrtl::SubindexOp>(value, i));
351 if (!isa<firrtl::UIntType>(type)) {
352 if (isa<firrtl::SIntType>(type))
353 value = firrtl::AsUIntPrimOp::create(builder, value);
357 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
363struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
366 erasedModules.clear();
374 SmallVector<Operation *> worklist;
375 auto deadInsts = erasedInsts;
376 for (
auto *op : erasedModules)
377 worklist.push_back(op);
378 while (!worklist.empty()) {
379 auto *op = worklist.pop_back_val();
380 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
381 op->walk([&](firrtl::InstanceOp instOp) {
382 auto moduleOp = cast<firrtl::FModuleLike>(
383 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
384 deadInsts.insert(instOp);
386 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
387 [&](Operation *user) { return deadInsts.contains(user); })) {
388 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
389 << moduleOp.getModuleName() <<
"`\n");
390 erasedModules.insert(moduleOp);
391 worklist.push_back(moduleOp);
396 for (
auto *op : erasedInsts)
398 for (
auto *op : erasedModules)
400 nlaRemover.remove(op);
403 uint64_t
match(firrtl::InstanceOp instOp)
override {
405 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
409 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
410 LLVM_DEBUG(llvm::dbgs()
411 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
412 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
414 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
415 auto result = instOp.getResult(i);
416 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
417 instOp.getPortNameStr(i));
419 firrtl::WireOp::create(builder, result.getType(), name,
420 firrtl::NameKindEnum::DroppableName,
421 instOp.getPortAnnotation(i), StringAttr{})
423 invalidateOutputs(builder, wire, invalidCache,
424 instOp.getPortDirection(i) == firrtl::Direction::In);
425 result.replaceAllUsesWith(wire);
427 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
428 auto moduleOp = cast<firrtl::FModuleLike>(
429 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
430 nlaRemover.markNLAsInOperation(instOp);
431 erasedInsts.insert(instOp);
433 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
434 [&](Operation *user) { return erasedInsts.contains(user); })) {
435 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
436 << moduleOp.getModuleName() <<
"`\n");
437 erasedModules.insert(moduleOp);
442 std::string
getName()
const override {
return "instance-stubber"; }
447 llvm::DenseSet<Operation *> erasedInsts;
448 llvm::DenseSet<Operation *> erasedModules;
454struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
455 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
456 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
457 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
458 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
459 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
462 SmallVector<Value> outputs;
463 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
464 auto result = memOp.getResult(i);
465 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
466 memOp.getPortNameStr(i));
468 firrtl::WireOp::create(builder, result.getType(), name,
469 firrtl::NameKindEnum::DroppableName,
470 memOp.getPortAnnotation(i), StringAttr{})
472 invalidateOutputs(builder, wire, invalidCache,
true);
473 result.replaceAllUsesWith(wire);
477 switch (memOp.getPortKind(i)) {
478 case firrtl::MemOp::PortKind::Read:
479 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
481 case firrtl::MemOp::PortKind::Write:
482 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
484 case firrtl::MemOp::PortKind::ReadWrite:
485 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
486 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
488 case firrtl::MemOp::PortKind::Debug:
493 if (!isa<firrtl::RefType>(result.getType())) {
496 cast<firrtl::BundleType>(wire.getType()).getNumElements();
497 for (
unsigned i = 0; i != numFields; ++i) {
498 if (i != 2 && i != 3 && i != 5)
499 reduceXor(builder, xorInputs,
500 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
503 reduceXor(builder, xorInputs, input);
508 outputs.push_back(output);
512 for (
auto output : outputs)
513 connectToLeafs(builder, output, xorInputs);
515 nlaRemover.markNLAsInOperation(memOp);
519 std::string
getName()
const override {
return "memory-stubber"; }
526static bool isFlowSensitiveOp(Operation *op) {
527 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
528 SubaccessOp, ObjectSubfieldOp>(op);
534template <
unsigned OpNum>
535struct FIRRTLOperandForwarder :
public Reduction {
536 uint64_t
match(Operation *op)
override {
537 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
539 if (isFlowSensitiveOp(op))
542 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
544 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
545 return resultTy && opTy &&
546 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
547 (resultTy.getBitWidthOrSentinel() == -1) ==
548 (opTy.getBitWidthOrSentinel() == -1) &&
549 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
551 LogicalResult
rewrite(Operation *op)
override {
553 ImplicitLocOpBuilder builder(op->getLoc(), op);
554 auto result = op->getResult(0);
555 auto operand = op->getOperand(OpNum);
556 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
557 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
558 auto resultWidth = resultTy.getBitWidthOrSentinel();
559 auto operandWidth = operandTy.getBitWidthOrSentinel();
561 if (resultWidth < operandWidth)
563 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
564 else if (resultWidth > operandWidth)
565 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
568 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
569 result.replaceAllUsesWith(newOp);
573 std::string
getName()
const override {
574 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
585 anyrefCastDummy.clear();
586 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
587 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
588 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
589 anyrefCastDummy.insert({circuitOp, classOp});
590 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
593 return WalkResult::skip();
597 uint64_t
match(Operation *op)
override {
598 if (op->hasTrait<OpTrait::ConstantLike>()) {
600 if (!matchPattern(op, m_Constant(&attr)))
602 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
603 if (intAttr.getValue().isZero())
605 if (
auto strAttr = dyn_cast<StringAttr>(attr))
608 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
609 if (floatAttr.getValue().isZero())
612 if (
auto listOp = dyn_cast<ListCreateOp>(op))
613 if (listOp.getElements().empty())
615 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
616 if (pathOp.getTarget().empty())
620 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
621 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
623 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
624 if (anyrefCastDummyNames[circuitOp].contains(className))
628 if (op->getNumResults() != 1)
630 if (op->hasAttr(
"inner_sym"))
632 if (isFlowSensitiveOp(op))
634 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
635 DoubleType, ListType, PathType, AnyRefType>(
636 op->getResult(0).getType());
639 LogicalResult
rewrite(Operation *op)
override {
640 OpBuilder builder(op);
641 auto type = op->getResult(0).getType();
644 if (isa<UIntType, SIntType>(type)) {
645 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
648 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
649 APSInt(width, isa<UIntType>(type)));
650 op->replaceAllUsesWith(newOp);
656 if (isa<StringType>(type)) {
657 auto attr = builder.getStringAttr(
"");
658 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
659 op->replaceAllUsesWith(newOp);
665 if (isa<FIntegerType>(type)) {
666 auto attr = builder.getIntegerAttr(builder.getI64Type(), 0);
667 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
668 op->replaceAllUsesWith(newOp);
674 if (isa<BoolType>(type)) {
675 auto attr = builder.getBoolAttr(
false);
676 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
677 op->replaceAllUsesWith(newOp);
683 if (isa<DoubleType>(type)) {
684 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
685 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
686 op->replaceAllUsesWith(newOp);
692 if (isa<ListType>(type)) {
694 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
695 op->replaceAllUsesWith(newOp);
701 if (isa<PathType>(type)) {
702 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
703 op->replaceAllUsesWith(newOp);
709 if (isa<AnyRefType>(type)) {
710 auto circuitOp = op->getParentOfType<CircuitOp>();
711 auto &dummy = anyrefCastDummy[circuitOp];
713 OpBuilder::InsertionGuard guard(builder);
714 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
715 auto &symbolTable = symbols.getNearestSymbolTable(op);
716 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
717 symbolTable.insert(dummy);
718 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
720 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
722 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
723 op->replaceAllUsesWith(anyrefOp);
731 std::string
getName()
const override {
return "firrtl-constantifier"; }
743struct ConnectInvalidator :
public Reduction {
744 uint64_t
match(Operation *op)
override {
745 if (!isa<FConnectLike>(op))
747 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
748 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
749 isa<InvalidValueOp>(srcOp))
751 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
752 return type && type.isPassive();
754 LogicalResult
rewrite(Operation *op)
override {
756 auto rhs = op->getOperand(1);
757 OpBuilder builder(op);
758 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
759 auto *rhsOp = rhs.getDefiningOp();
760 op->setOperand(1, invOp);
765 std::string
getName()
const override {
return "connect-invalidator"; }
772struct AnnotationRemover :
public Reduction {
773 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
774 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
777 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
778 uint64_t matchId = 0;
781 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
782 for (
unsigned i = 0; i < annos.size(); ++i)
783 addMatch(1, matchId++);
786 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
787 for (
auto portAnnoArray : portAnnos)
788 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
789 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
790 addMatch(1, matchId++);
794 ArrayRef<uint64_t> matches)
override {
796 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
799 uint64_t matchId = 0;
800 auto processAnnotations =
801 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
802 SmallVector<Attribute> newAnnotations;
803 for (
auto anno : annotations) {
804 if (!matchesSet.contains(matchId)) {
805 newAnnotations.push_back(anno);
808 nlaRemover.markNLAsInAnnotation(anno);
812 return ArrayAttr::get(op->getContext(), newAnnotations);
816 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
817 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
821 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
822 SmallVector<Attribute> newPortAnnos;
823 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
824 newPortAnnos.push_back(
825 processAnnotations(portAnnoArrayAttr.getValue()));
827 op->setAttr(
"portAnnotations",
828 ArrayAttr::get(op->getContext(), newPortAnnos));
834 std::string
getName()
const override {
return "annotation-remover"; }
840struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
841 uint64_t
match(firrtl::FModuleOp module)
override {
842 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
845 return circuit.getNameAttr() ==
module.getNameAttr();
847 LogicalResult
rewrite(firrtl::FModuleOp module)
override {
849 size_t numPorts =
module.getNumPorts();
850 llvm::BitVector dropPorts(numPorts);
851 for (
unsigned i = 0; i != numPorts; ++i) {
855 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
859 module.erasePorts(dropPorts);
862 std::string
getName()
const override {
return "root-port-pruner"; }
867struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
872 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
874 uint64_t
match(firrtl::InstanceOp instOp)
override {
875 return isa<firrtl::FExtModuleOp>(
876 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
878 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
880 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
881 symbols.getNearestSymbolTable(instOp)))
883 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
884 SmallVector<Value> replacementWires;
886 auto wire = firrtl::WireOp::create(
888 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
890 if (
info.isOutput()) {
891 auto inv = firrtl::InvalidValueOp::create(builder,
info.type);
892 firrtl::ConnectOp::create(builder, wire, inv);
894 replacementWires.push_back(wire);
896 nlaRemover.markNLAsInOperation(instOp);
897 instOp.replaceAllUsesWith(std::move(replacementWires));
901 std::string
getName()
const override {
return "extmodule-instance-remover"; }
909struct ConnectForwarder :
public Reduction {
910 uint64_t
match(Operation *op)
override {
911 if (!isa<firrtl::FConnectLike>(op))
913 auto dest = op->getOperand(0);
914 auto src = op->getOperand(1);
915 auto *destOp = dest.getDefiningOp();
916 auto *srcOp = src.getDefiningOp();
922 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
928 unsigned numConnects = 0;
929 for (
auto &use : dest.getUses()) {
930 auto *op = use.getOwner();
931 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
932 if (++numConnects > 1)
936 if (srcOp && !srcOp->isBeforeInBlock(op))
943 LogicalResult
rewrite(Operation *op)
override {
944 auto dst = op->getOperand(0);
945 auto src = op->getOperand(1);
946 dst.replaceAllUsesWith(src);
948 if (
auto *dstOp = dst.getDefiningOp())
950 if (
auto *srcOp = src.getDefiningOp())
955 std::string
getName()
const override {
return "connect-forwarder"; }
960template <
unsigned OpNum>
961struct ConnectSourceOperandForwarder :
public Reduction {
962 uint64_t
match(Operation *op)
override {
963 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
965 auto dest = op->getOperand(0);
966 auto *destOp = dest.getDefiningOp();
969 if (!destOp || !destOp->hasOneUse() ||
970 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
973 auto *srcOp = op->getOperand(1).getDefiningOp();
974 if (!srcOp || OpNum >= srcOp->getNumOperands())
977 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
979 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
981 return resultTy && opTy &&
982 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
983 ((resultTy.getBitWidthOrSentinel() == -1) ==
984 (opTy.getBitWidthOrSentinel() == -1)) &&
985 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
988 LogicalResult
rewrite(Operation *op)
override {
989 auto *destOp = op->getOperand(0).getDefiningOp();
990 auto *srcOp = op->getOperand(1).getDefiningOp();
991 auto forwardedOperand = srcOp->getOperand(OpNum);
992 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
994 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
995 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
999 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1002 auto clock = destOp->getOperand(0);
1003 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1004 clock, regName ? regName.str() :
"")
1009 builder.setInsertionPointAfter(op);
1010 if (isa<firrtl::ConnectOp>(op))
1011 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1013 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1023 std::string
getName()
const override {
1024 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1031struct DetachSubaccesses :
public Reduction {
1032 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1034 for (
auto *op : opsToErase)
1035 op->dropAllReferences();
1036 for (
auto *op : opsToErase)
1039 uint64_t
match(Operation *op)
override {
1042 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1043 llvm::all_of(op->getUses(), [](
auto &use) {
1044 return use.getOperandNumber() == 0 &&
1045 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1046 firrtl::SubaccessOp>(use.getOwner());
1049 LogicalResult
rewrite(Operation *op)
override {
1051 OpBuilder builder(op);
1052 bool isWire = isa<firrtl::WireOp>(op);
1055 invalidClock = firrtl::InvalidValueOp::create(
1056 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1057 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1058 builder.setInsertionPoint(user);
1059 auto type = user->getResult(0).getType();
1062 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1065 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1066 user->replaceAllUsesWith(replOp);
1067 opsToErase.insert(user);
1069 opsToErase.insert(op);
1072 std::string
getName()
const override {
return "detach-subaccesses"; }
1073 llvm::DenseSet<Operation *> opsToErase;
1079struct NodeSymbolRemover :
public Reduction {
1084 uint64_t
match(Operation *op)
override {
1086 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1087 if (!sym || sym.empty())
1091 if (innerSymUses.hasInnerRef(op))
1096 LogicalResult
rewrite(Operation *op)
override {
1097 op->removeAttr(
"inner_sym");
1101 std::string
getName()
const override {
return "node-symbol-remover"; }
1110hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1119 LogicalResult walkResult = targetTable.
walkSymbols(
1122 if (parentTable.lookup(name)) {
1130 return failed(walkResult);
1134struct EagerInliner :
public OpReduction<InstanceOp> {
1139 for (
auto circuitOp : op.getOps<CircuitOp>())
1140 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1141 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1144 nlaRemover.remove(op);
1146 innerSymTables.reset();
1149 uint64_t
match(InstanceOp instOp)
override {
1150 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1152 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1155 if (!isa<FModuleOp>(moduleOp))
1159 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1162 auto it = nlaTables.find(circuitOp);
1163 if (it == nlaTables.end() || !it->second)
1165 DenseSet<hw::HierPathOp> nlas;
1166 it->second->getInstanceNLAs(instOp, nlas);
1172 auto parentOp = instOp->getParentOfType<FModuleLike>();
1173 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1179 LogicalResult
rewrite(InstanceOp instOp)
override {
1180 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1181 auto moduleOp = cast<FModuleOp>(
1182 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1184 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1185 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1188 IRRewriter rewriter(instOp);
1189 SmallVector<Value> argWires;
1190 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1191 auto result = instOp.getResult(i);
1192 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1193 instOp.getPortNameStr(i));
1194 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1195 name, NameKindEnum::DroppableName,
1196 instOp.getPortAnnotation(i), StringAttr{})
1198 result.replaceAllUsesWith(wire);
1199 argWires.push_back(wire);
1203 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1207 nlaRemover.markNLAsInOperation(instOp);
1209 nlaRemover.markNLAsInOperation(moduleOp);
1212 clonedModuleOp.erase();
1216 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1221 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1222 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1226struct ObjectInliner :
public OpReduction<ObjectOp> {
1228 blocksToSort.clear();
1231 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1234 for (
auto *block : blocksToSort)
1235 mlir::sortTopologically(block);
1236 blocksToSort.clear();
1237 nlaRemover.remove(op);
1238 innerSymTables.reset();
1241 uint64_t
match(ObjectOp objOp)
override {
1242 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1244 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1247 if (!isa<ClassOp>(classOp))
1252 auto parentOp = objOp->getParentOfType<FModuleLike>();
1253 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1257 for (
auto *user : objOp.getResult().getUsers())
1258 if (!isa<ObjectSubfieldOp>(user))
1264 LogicalResult
rewrite(ObjectOp objOp)
override {
1265 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1266 auto classOp = cast<ClassOp>(
1267 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1268 auto clonedClassOp = classOp.clone();
1271 IRRewriter rewriter(objOp);
1272 SmallVector<Value> portWires;
1273 auto classType = objOp.getType();
1276 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1277 auto element = classType.getElement(i);
1278 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1279 element.name.getValue());
1280 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1281 NameKindEnum::DroppableName,
1282 rewriter.getArrayAttr({}), StringAttr{})
1284 portWires.push_back(wire);
1288 SmallVector<ObjectSubfieldOp> subfieldOps;
1289 for (
auto *user : objOp.getResult().getUsers()) {
1290 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1291 subfieldOps.push_back(subfieldOp);
1292 auto index = subfieldOp.getIndex();
1293 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1297 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1303 SmallVector<FConnectLike> connectsToErase;
1304 for (
auto portWire : portWires) {
1308 for (
auto *user : portWire.getUsers()) {
1309 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1310 if (
connect.getDest() == portWire) {
1312 connectsToErase.push_back(connect);
1322 portWire.replaceAllUsesWith(value);
1323 for (
auto connect : connectsToErase)
1325 if (portWire.use_empty())
1326 portWire.getDefiningOp()->erase();
1327 connectsToErase.clear();
1331 nlaRemover.markNLAsInOperation(objOp);
1336 blocksToSort.insert(objOp->getBlock());
1339 for (
auto subfieldOp : subfieldOps)
1342 clonedClassOp.erase();
1346 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1349 SetVector<Block *> blocksToSort;
1352 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1366struct ModuleInternalNameSanitizer :
public Reduction {
1367 uint64_t
match(Operation *op)
override {
1369 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1370 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1371 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1372 firrtl::CoverOp>(op);
1374 LogicalResult
rewrite(Operation *op)
override {
1375 TypeSwitch<Operation *, void>(op)
1376 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1377 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1378 [](
auto op) { op.setName(
"reg"); })
1379 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1380 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1381 [](
auto op) { op.setName(
"mem"); })
1382 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1383 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1384 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1389 std::string
getName()
const override {
1390 return "module-internal-name-sanitizer";
1395 bool isOneShot()
const override {
return true; }
/// Reduction that gives a circuit short placeholder names. The circuit and
/// its top-level module share one name from `names`. Other modules visited
/// through the instance graph get the following names, and their instances
/// are renamed after the module they instantiate. Ports of `firrtl.module`s
/// are renamed based on their type. Marked one-shot via `isOneShot()`.
1409struct ModuleNameSanitizer :
OpReduction<firrtl::CircuitOp> {
// Fixed table of 48 placeholder names, handed out in order by the
// placeholder-name helper below.
1411 const char *names[48] = {
1412 "Foo",
"Bar",
"Baz",
"Qux",
"Quux",
"Quuux",
"Quuuux",
1413 "Quz",
"Corge",
"Grault",
"Bazola",
"Ztesch",
"Thud",
"Grunt",
1414 "Bletch",
"Fum",
"Fred",
"Jim",
"Sheila",
"Barney",
"Flarp",
1415 "Zxc",
"Spqr",
"Wombat",
"Shme",
"Bongo",
"Spam",
"Eggs",
1416 "Snork",
"Zot",
"Blarg",
"Wibble",
"Toto",
"Titi",
"Tata",
1417 "Tutu",
"Pippo",
"Pluto",
"Paperino",
"Aap",
"Noot",
"Mies",
1418 "Oogle",
"Foogle",
"Boogle",
"Zork",
"Gork",
"Bork"};
// Index of the next entry of `names` to hand out.
1420 size_t nameIndex = 0;
// NOTE(review): the handling once all 48 names are used is on a line not
// visible here. `rewrite` applies module names taken from this table
// without passing them through a namespace. If the index wraps around, a
// circuit with more than 48 modules would receive duplicate module names.
// Confirm how overflow is handled.
1423 if (nameIndex >= 48)
1425 return names[nameIndex++];
// Index of the next single-letter port name. `getPortName()` returns
// 'a' + index. The handling at 26 is on a line not visible here.
1428 size_t portNameIndex = 0;
1430 char getPortName() {
1431 if (portNameIndex >= 26)
1433 return 'a' + portNameIndex++;
// Rename the circuit, its modules, their instances and their ports.
1438 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
// `auto *` binds to the pointer-returning placeholder-name helper, whose
// declaration is not visible here. It does not bind to the
// `std::string getName() const` override that reports the pattern name.
1442 auto *circuitName =
getName();
1443 iGraph.getTopLevelModule().setName(circuitName);
1444 circuitOp.setName(circuitName);
1446 for (
auto *node : iGraph) {
1447 auto module = node->getModule<firrtl::FModuleLike>();
// Port names are rebuilt only for `firrtl.module`s that have ports. For any
// other module kind, `newNames` stays empty and instance port names are
// left untouched.
1449 bool shouldReplacePorts =
false;
1450 SmallVector<Attribute> newNames;
1451 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1456 auto oldPorts = fmodule.getPorts();
1457 shouldReplacePorts = !oldPorts.empty();
1458 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1459 auto port = oldPorts[i];
// Choose a port name from the port type:
//   clocks -> "clk", (async) resets -> "rst", references -> "ref",
//   everything else -> the next single letter.
// `ns.newName` returns a unique name derived from the requested one.
1461 .
Case<firrtl::ClockType>(
1462 [&](
auto a) {
return ns.
newName(
"clk"); })
1463 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1464 [&](
auto a) {
return ns.
newName(
"rst"); })
1465 .Case<firrtl::RefType>(
1466 [&](
auto a) {
return ns.
newName(
"ref"); })
1467 .Default([&](
auto a) {
1468 return ns.
newName(Twine(getPortName()));
1470 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1472 fmodule->setAttr(
"portNames",
1473 ArrayAttr::get(fmodule.getContext(), newNames));
// The top-level module was already renamed together with the circuit. The
// statement guarded by this check is not visible here (presumably a
// `continue`).
1476 if (module == iGraph.getTopLevelModule())
1478 auto newName = StringAttr::get(circuitOp.getContext(),
getName());
1479 module.setName(newName);
// Keep every instance consistent with the renamed module:
//   - retarget it,
//   - name the instance after the module,
//   - copy the new port names if they were rebuilt above.
// NOTE(review): the `dyn_cast` result is used without a null check. A use
// whose instance op is not a `firrtl::InstanceOp` would therefore call
// methods on a null op. Either use `cast<>` to assert, or skip such uses.
// Confirm which instance-like ops can reach this loop.
1480 for (
auto *use : node->uses()) {
1481 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1482 instanceOp.setModuleName(newName);
1483 instanceOp.setName(newName);
1484 if (shouldReplacePorts)
1485 instanceOp.setPortNamesAttr(
1486 ArrayAttr::get(circuitOp.getContext(), newNames));
// Name of this reduction pattern.
1495 std::string
getName()
const override {
return "module-name-sanitizer"; }
// Do not reapply this reduction after it has succeeded once.
1499 bool isOneShot()
const override {
return true; }
/// Reduction that retargets an instance to another module with the same port
/// signature (identical port types and directions, in order). The
/// replacement is the smallest module in that signature group.
1518struct ModuleSwapper :
public OpReduction<InstanceOp> {
// Ordered list of (port type, port direction) pairs. Modules with equal
// signatures can replace one another at instance sites.
1520 using PortSignature = SmallVector<std::pair<Type, Direction>>;
// Per-circuit analysis results.
1521 struct CircuitState {
// Modules of the circuit, grouped by port signature.
1522 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
// Keyed by module name (not instance name). Maps to the smallest module in
// the same signature group. Only modules that are not themselves the
// smallest get an entry.
1523 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1524 std::unique_ptr<NLATable> nlaTable;
// The enclosing method's signature is not visible here.
// Reset cached module sizes and per-circuit state, then analyze every
// circuit: build its NLA table and its signature groups. Returning
// `WalkResult::skip()` keeps the pre-order walk from descending into the
// circuit body.
1530 moduleSizes.clear();
1531 circuitStates.clear();
1534 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1535 auto &state = circuitStates[circuitOp];
1536 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1537 buildModuleTypeGroups(circuitOp, state);
1538 return WalkResult::skip();
// Remove any NLAs that were marked through `nlaRemover`.
1541 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
// Collect the (type, direction) of each port, in port order.
1547 PortSignature getModulePortSignature(FModuleLike module) {
1548 PortSignature signature;
1549 signature.reserve(module.getNumPorts());
1550 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1551 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
// Group the circuit's modules by port signature. Then map every module that
// is not the smallest of its group to the smallest one. Sizes come from
// `moduleSizes`. On ties, the first module in body order wins because the
// size comparison is strict.
1556 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1558 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1559 auto signature = getModulePortSignature(module);
1560 state.moduleTypeGroups[signature].push_back(module);
1564 for (
auto &[signature, modules] : state.moduleTypeGroups) {
// A single-module group has nothing to swap with. The statement that skips
// it is not visible here.
1565 if (modules.size() <= 1)
1568 FModuleLike smallestModule =
nullptr;
1569 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1571 for (
auto module : modules) {
1572 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1573 if (size < smallestSize) {
1574 smallestSize = size;
1575 smallestModule =
module;
1580 for (
auto module : modules) {
1581 if (module != smallestModule) {
1582 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
// Benefit is the size saved by the swap, or 1 when the current module is not
// larger than the replacement. NLAs through this instance are collected
// first; how a non-empty set is handled is on lines not visible here.
1589 uint64_t
match(InstanceOp instOp)
override {
1591 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1593 const auto &state = circuitStates.at(circuitOp);
1596 DenseSet<hw::HierPathOp> nlas;
1597 state.nlaTable->getInstanceNLAs(instOp, nlas);
1602 auto moduleName = instOp.getModuleNameAttr().getAttr();
1603 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
// No entry means this module is already the smallest in its group, or has
// no group peers. The early exit is on a line not visible here.
1604 if (!canonicalModule)
1608 auto currentModule = cast<FModuleLike>(
1609 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1610 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1611 uint64_t canonicalSize =
1612 moduleSizes.getModuleSize(canonicalModule, symbols);
1613 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
// Retarget the instance to its canonical module.
// - `DenseMap::at` asserts that an entry exists; `match` looks it up before
//   reporting a benefit.
// - Only port names need copying: port types and directions are equal by
//   construction of the signature groups.
1616 LogicalResult
rewrite(InstanceOp instOp)
override {
1618 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1620 const auto &state = circuitStates.at(circuitOp);
1623 auto canonicalModule = state.instanceToCanonicalModule.at(
1624 instOp.getModuleNameAttr().getAttr());
1625 auto canonicalName = canonicalModule.getModuleNameAttr();
1626 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1629 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
// Name of this reduction pattern.
1634 std::string
getName()
const override {
return "firrtl-module-swapper"; }
// Analysis state for each circuit.
1643 DenseMap<CircuitOp, CircuitState> circuitStates;
/// Reduction over circuit annotations that carry a `modules` list. The class
/// filter is on lines not visible here; the name suggests force-dedup
/// annotations. For a selected annotation:
///   - all listed modules are replaced at their instance sites by the first
///     listed module found in the circuit;
///   - hierarchical paths (NLAs) through the replaced modules are rewritten.
/// Each qualifying annotation is one match.
1661struct ForceDedup :
public OpReduction<CircuitOp> {
// Remove the NLAs marked for removal during the reduction.
1666 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
// One match per qualifying annotation, identified by its index. The benefit
// is the number of listed modules. The filtering lines between the loop
// header and the benefit computation are not visible here.
1669 void matches(CircuitOp circuitOp,
1670 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1672 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1676 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1682 uint64_t benefit = modulesAttr.size();
1683 addMatch(benefit, annoIdx);
// Dedup the modules of every selected annotation and rebuild the circuit's
// annotation set.
1688 ArrayRef<uint64_t> matches)
override {
1689 auto *context = circuitOp->getContext();
1693 SmallVector<Annotation> newAnnotations;
1695 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
// Annotations that were not selected are kept unchanged.
1697 if (!llvm::is_contained(matches, annoIdx)) {
1698 newAnnotations.push_back(anno);
// The statement ending two lines below starts on a line not visible here.
// It checks that a selected annotation lists at least two modules.
1701 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1703 modulesAttr.size() >= 2);
// Extract module names from targets by keeping everything after the first
// '|'.
// NOTE(review): this assumes plain module targets such as
// "~Circuit|Module". Anything following the module name would be kept
// verbatim. Confirm that only module targets occur here.
1706 SmallVector<StringAttr> moduleNames;
1707 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1709 auto refStr = moduleRef.getValue();
1710 auto pipePos = refStr.find(
'|');
1711 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1712 auto moduleName = refStr.substr(pipePos + 1);
1713 moduleNames.push_back(StringAttr::get(context, moduleName));
1718 if (moduleNames.size() < 2)
// Retarget instances and NLAs to the canonical module. Then mark the NLAs
// referenced by the annotation itself for removal.
1723 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1724 nlaRemover.markNLAsInAnnotation(anno.getAttr());
// Equal sizes mean no annotation was dropped. The statement this guards is
// not visible here.
1726 if (newAnnotations.size() == annotations.size())
1731 newAnnoSet.applyToOperation(circuitOp);
// Name of this reduction pattern.
1735 std::string
getName()
const override {
return "firrtl-force-dedup"; }
// Replace every instance of the modules named in `moduleNames` with an
// instance of the first of those modules found in the circuit's symbol
// table. Also update NLAs that pass through the replaced modules.
1741 void replaceModuleReferences(CircuitOp circuitOp,
1742 ArrayRef<StringAttr> moduleNames,
1745 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1746 auto &symbolTable = symbols.getSymbolTable(tableOp);
1747 auto *context = circuitOp->getContext();
// The first resolvable module becomes canonical. Later ones are collected
// in `modulesToReplace`; the line between the two statements below is not
// visible here (presumably an `else`).
1751 FModuleLike canonicalModule;
1752 SmallVector<FModuleLike> modulesToReplace;
1753 for (
auto name : moduleNames) {
1754 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
1755 if (!canonicalModule)
1756 canonicalModule = mod;
1758 modulesToReplace.push_back(mod);
1761 if (modulesToReplace.empty())
1765 auto canonicalName = canonicalModule.getModuleNameAttr();
1766 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
// Only `InstanceOp`s are visited by this walk. Port names are copied from
// the canonical module.
1767 circuitOp.walk([&](InstanceOp instOp) {
1768 auto moduleName = instOp.getModuleNameAttr().getAttr();
1769 if (llvm::is_contained(moduleNames, moduleName) &&
1770 moduleName != canonicalName) {
1771 instOp.setModuleNameAttr(canonicalRef);
1772 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
// Rewrite the namepath of every NLA that involves a replaced module.
1778 for (
auto oldMod : modulesToReplace) {
// Copy the NLAs first. The loop erases them from `nlaTable`, which may
// invalidate the ArrayRef returned by `lookup`.
1779 SmallVector<hw::HierPathOp> nlaOps(
1780 nlaTable.
lookup(oldMod.getModuleNameAttr()));
1781 for (
auto nlaOp : nlaOps) {
1782 nlaTable.
erase(nlaOp);
1783 StringAttr oldModName = oldMod.getModuleNameAttr();
1784 StringAttr newModName = canonicalName;
1785 SmallVector<Attribute, 4> newPath;
// Walk the namepath and redirect elements from the old module to the new
// one:
//   - An inner ref into the old module is re-homed to the new module.
//   - If that inner ref names an instance in both modules, follow both
//     instances one level down, so deeper path elements are redirected too.
//   - Any other element must be a flat symbol ref (the `cast` asserts
//     this). It is replaced when it equals the current old module name.
1786 for (
auto nameRef : nlaOp.getNamepath()) {
1787 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1788 if (ref.getModule() == oldModName) {
1789 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1790 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1791 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1792 if (oldInst && newInst) {
1793 oldModName = oldInst.getReferencedModuleNameAttr();
1794 newModName = newInst.getReferencedModuleNameAttr();
1797 newPath.push_back(ref);
1798 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1799 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1801 newPath.push_back(nameRef);
// NOTE(review): the NLA was erased from `nlaTable` above. No visible line
// re-registers it with its rewritten path (e.g. via `NLATable::addNLA`).
// Confirm the table stays consistent for later lookups.
1804 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
// Mark NLAs referenced from inside the replaced modules for removal.
1810 for (
auto module : modulesToReplace) {
1811 nlaRemover.markNLAsInOperation(module);
/// Reduction that derives new dedup requirements for child instances. The
/// annotation class checks are on lines not visible here; the name suggests
/// must-dedup annotations. The steps are:
///   - take a circuit annotation that lists at least two modules;
///   - group the instances with the same instance name across those modules;
///   - turn a selected group into a new annotation listing the modules those
///     instances reference.
1836struct MustDedupChildren :
public OpReduction<CircuitOp> {
// Remove the NLAs marked for removal during the reduction.
1841 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
// Every eligible instance group becomes one match with benefit 1. Match IDs
// are assigned sequentially across all annotations. `rewriteMatches`
// therefore has to enumerate annotations and groups in the same order,
// with the same filters.
1845 void matches(CircuitOp circuitOp,
1846 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1848 uint64_t matchId = 0;
1850 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1854 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1855 if (!modulesAttr || modulesAttr.size() < 2)
1859 processInstanceGroups(
1860 circuitOp, modulesAttr,
1861 [&](ArrayRef<FInstanceLike>) { addMatch(1, matchId++); });
// Add one new annotation per selected instance group, then rebuild the
// circuit's annotation set.
1866 ArrayRef<uint64_t> matches)
override {
1867 auto *context = circuitOp->getContext();
1869 SmallVector<Annotation> newAnnotations;
1870 uint64_t matchId = 0;
1872 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
// Kept unchanged. The condition guarding this is on a line not visible
// here.
1874 newAnnotations.push_back(anno);
1878 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1879 if (!modulesAttr || modulesAttr.size() < 2) {
1880 newAnnotations.push_back(anno);
// Re-run the grouping with the same numbering as `matches()`. Only groups
// whose ID was selected produce a new annotation.
1885 bool anyMatchSelected =
false;
1886 processInstanceGroups(
1887 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1889 if (!llvm::is_contained(matches, matchId++))
1891 anyMatchSelected =
true;
// Collect the distinct targets of the modules referenced by the group's
// instances. A new annotation needs at least two distinct modules.
1894 SmallSetVector<StringAttr, 4> moduleTargets;
1895 for (
auto instOp : instanceGroup) {
1897 target.circuit = circuitOp.getName();
1898 target.module = instOp.getReferencedModuleName();
1899 moduleTargets.insert(target.toStringAttr(context));
1901 if (moduleTargets.size() < 2)
// New annotation with "class" (its value is on a line not visible here) and
// "modules".
1905 SmallVector<NamedAttribute> newAnnoAttrs;
1906 newAnnoAttrs.emplace_back(
1907 StringAttr::get(context,
"class"),
1909 newAnnoAttrs.emplace_back(
1910 StringAttr::get(context,
"modules"),
1911 ArrayAttr::get(context,
1912 SmallVector<Attribute>(moduleTargets.begin(),
1913 moduleTargets.end())));
1915 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
1916 newAnnotations.emplace_back(newAnnoDict);
// If any group of this annotation was selected, mark the NLAs referenced by
// the original annotation for removal. Whether the original annotation is
// re-added below depends on a line not visible here.
1922 if (anyMatchSelected)
1923 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1925 newAnnotations.push_back(anno);
1930 newAnnoSet.applyToOperation(circuitOp);
// Name of this reduction pattern.
1934 std::string
getName()
const override {
return "must-dedup-children"; }
// Steps:
//   1. Resolve the annotation's module targets to modules. The target
//      parsing is on a line not visible here.
//   2. Group the instances inside those modules by instance name.
//   3. Invoke `callback` for each group of at least two instances whose name
//      was judged unique within each module.
// `MapVector` keeps insertion order, so groups are visited in the same order
// on every call. The match numbering depends on this.
1942 void processInstanceGroups(
1943 CircuitOp circuitOp, ArrayAttr modulesAttr,
1944 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
1945 auto &symbolTable = symbols.getSymbolTable(circuitOp);
1948 SmallVector<FModuleLike> modules;
1949 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1951 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
1952 modules.push_back(mod);
// Fewer than two resolved modules leaves nothing to group. The early exit is
// on a line not visible here.
1955 if (modules.size() < 2)
1962 struct InstanceGroup {
1963 SmallVector<FInstanceLike> instances;
1964 bool nameIsUnique =
true;
1966 MapVector<StringAttr, InstanceGroup> instanceGroups;
1967 for (
auto module : modules) {
1969 module.walk([&](FInstanceLike instOp) {
1970 auto name = instOp.getInstanceNameAttr();
1971 auto &group = instanceGroups[name];
// NOTE(review): off-by-one. `nameCounts[name]++ > 1` compares the count
// *before* this occurrence, so it first fires on the third instance with
// the same name. A name that appears exactly twice in one module stays
// "unique", and those two instances can form a group on their own.
// `nameCounts[name]++ > 0` ("seen before") looks like the intent.
// Confirm this, and check that `nameCounts` (declared on a line not
// visible here) is reset for each module.
1972 if (nameCounts[name]++ > 1)
1973 group.nameIsUnique =
false;
1974 group.instances.push_back(instOp);
1980 for (
auto &[name, group] : instanceGroups)
1981 if (group.nameIsUnique && group.instances.size() >= 2)
1982 callback(group.instances);
2002 patterns.add<AnnotationRemover, 33>();
2005 patterns.add<MustDedupChildren, 30>();
2011 firrtl::createLowerCHIRRTLPass(),
true,
true);
2016 patterns.add<FIRRTLModuleExternalizer, 25>();
2017 patterns.add<InstanceStubber, 24>();
2022 firrtl::createLowerFIRRTLTypes(),
true,
true);
2029 firrtl::createRemoveUnusedPorts({
true}));
2030 patterns.add<NodeSymbolRemover, 15>();
2031 patterns.add<ConnectForwarder, 14>();
2032 patterns.add<ConnectInvalidator, 13>();
2034 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2035 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2036 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2037 patterns.add<DetachSubaccesses, 7>();
2039 patterns.add<ExtmoduleInstanceRemover, 4>();
2040 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2041 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2042 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2043 patterns.add<ModuleInternalNameSanitizer, 0>();
2044 patterns.add<ModuleNameSanitizer, 0>();
2048 mlir::DialectRegistry ®istry) {
2049 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value connect it to an invalid value.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)