25#include "mlir/Analysis/TopologicalSortUtils.h"
26#include "mlir/IR/Dominance.h"
27#include "mlir/IR/ImplicitLocOpBuilder.h"
28#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APSInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/Support/Debug.h"
34#define DEBUG_TYPE "firrtl-reductions"
38using namespace firrtl;
40using llvm::SmallDenseSet;
54 return tables->getSymbolTable(op);
64 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
71 tables = std::make_unique<SymbolTableCollection>();
76 std::unique_ptr<SymbolTableCollection>
tables;
83static std::optional<firrtl::FModuleOp>
86 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
87 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
89 return moduleOp ? std::optional(moduleOp) : std::nullopt;
100 module->walk([&](Operation *op) {
102 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
116 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
117 auto *op = use.getOwner();
118 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
120 if (use.getOperandNumber() != 0)
122 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
139 unsigned numRemoved = 0;
141 SymbolTableCollection symbolTables;
142 for (Operation &rootOp : *
module.getBody()) {
143 if (!isa<firrtl::CircuitOp>(&rootOp))
145 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
146 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
148 if (
auto *op = symbolTable.lookup(sym)) {
149 if (symbolUserMap.useEmpty(op)) {
158 if (numRemoved > 0 || numLost > 0) {
159 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
161 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
162 llvm::dbgs() <<
"\n";
171 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
172 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
173 nlasToRemove.insert(field.getAttr());
174 for (
auto namedAttr : dict)
175 markNLAsInAnnotation(namedAttr.getValue());
176 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
177 for (
auto attr : array)
178 markNLAsInAnnotation(attr);
186 op->walk([&](Operation *op) {
187 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
188 markNLAsInAnnotation(annos);
203struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
210 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
212 uint64_t
match(FModuleOp module)
override {
213 if (innerSymUses.hasInnerRef(module))
215 return moduleSizes.getModuleSize(module, symbols);
218 LogicalResult
rewrite(FModuleOp module)
override {
221 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
222 for (
auto attr :
module.getPortTypes()) {
223 auto type = cast<TypeAttr>(attr).getValue();
224 if (
auto refType = type_dyn_cast<RefType>(type))
225 if (
auto layer = refType.getLayer())
226 layers.insert(layer);
228 SmallVector<Attribute, 4> layersArray;
229 layersArray.reserve(layers.size());
230 for (
auto layer : layers)
231 layersArray.push_back(layer);
233 nlaRemover.markNLAsInOperation(module);
234 OpBuilder builder(module);
235 auto extmodule = FExtModuleOp::create(
236 builder, module->getLoc(),
237 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
238 module.getConventionAttr(), module.getPorts(),
239 builder.getArrayAttr(layersArray), StringRef(),
240 module.getAnnotationsAttr());
241 SymbolTable::setSymbolVisibility(extmodule,
242 SymbolTable::getSymbolVisibility(module));
247 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
266static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
268 auto type = type_dyn_cast<FIRRTLType>(value.getType());
273 if (
auto refType = type_dyn_cast<RefType>(type)) {
275 assert(!
flip &&
"input probes are not allowed");
277 auto underlyingType = refType.getType();
279 if (!refType.getForceable()) {
282 auto targetWire = WireOp::create(builder, underlyingType);
283 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
284 builder.create<RefDefineOp>(value, refSend.getResult());
287 auto invalid = tieOffCache.
getInvalid(underlyingType);
288 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
294 WireOp::create(builder, underlyingType,
295 "", NameKindEnum::DroppableName,
296 ArrayRef<Attribute>{},
301 auto targetWire = forceableWire.getResult();
302 auto forceableRef = forceableWire.getDataRef();
304 builder.create<RefDefineOp>(value, forceableRef);
307 auto invalid = tieOffCache.
getInvalid(underlyingType);
308 MatchingConnectOp::create(builder, targetWire, invalid);
313 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
314 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
315 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
316 invalidateOutputs(builder, subfield, tieOffCache,
317 flip ^ element.value().isFlip);
318 if (subfield.use_empty())
319 subfield.getDefiningOp()->erase();
325 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
326 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
327 auto subindex = builder.createOrFold<SubindexOp>(value, i);
328 invalidateOutputs(builder, subindex, tieOffCache,
flip);
329 if (subindex.use_empty())
330 subindex.getDefiningOp()->erase();
340 if (
auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
341 auto invalid = tieOffCache.
getInvalid(baseType);
342 ConnectOp::create(builder, value, invalid);
347 if (
auto propType = type_dyn_cast<PropertyType>(type)) {
348 auto unknown = tieOffCache.
getUnknown(propType);
349 builder.create<PropAssignOp>(value, unknown);
354static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
356 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
359 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
360 for (
auto element :
llvm::enumerate(bundleType.getElements()))
361 connectToLeafs(builder,
362 firrtl::SubfieldOp::create(builder, dest, element.index()),
366 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
367 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
368 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
372 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
375 auto destWidth = type.getBitWidthOrSentinel();
376 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
377 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
378 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
379 if (!isa<firrtl::UIntType>(type)) {
380 if (isa<firrtl::SIntType>(type))
381 value = firrtl::AsSIntPrimOp::create(builder, value);
385 firrtl::ConnectOp::create(builder, dest, value);
389static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
390 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
393 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
394 for (
auto element :
llvm::enumerate(bundleType.getElements()))
397 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
400 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
401 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
402 reduceXor(builder, into,
403 builder.createOrFold<firrtl::SubindexOp>(value, i));
406 if (!isa<firrtl::UIntType>(type)) {
407 if (isa<firrtl::SIntType>(type))
408 value = firrtl::AsUIntPrimOp::create(builder, value);
412 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
418struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
421 erasedModules.clear();
429 SmallVector<Operation *> worklist;
430 auto deadInsts = erasedInsts;
431 for (
auto *op : erasedModules)
432 worklist.push_back(op);
433 while (!worklist.empty()) {
434 auto *op = worklist.pop_back_val();
435 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
436 op->walk([&](firrtl::InstanceOp instOp) {
437 auto moduleOp = cast<firrtl::FModuleLike>(
438 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
439 deadInsts.insert(instOp);
441 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
442 [&](Operation *user) { return deadInsts.contains(user); })) {
443 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
444 << moduleOp.getModuleName() <<
"`\n");
445 erasedModules.insert(moduleOp);
446 worklist.push_back(moduleOp);
451 for (
auto *op : erasedInsts)
453 for (
auto *op : erasedModules)
455 nlaRemover.remove(op);
458 uint64_t
match(firrtl::InstanceOp instOp)
override {
460 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
464 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
465 LLVM_DEBUG(llvm::dbgs()
466 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
467 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
469 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
470 auto result = instOp.getResult(i);
471 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
472 instOp.getPortName(i));
474 firrtl::WireOp::create(builder, result.getType(), name,
475 firrtl::NameKindEnum::DroppableName,
476 instOp.getPortAnnotation(i), StringAttr{})
478 invalidateOutputs(builder, wire, tieOffCache,
479 instOp.getPortDirection(i) == firrtl::Direction::In);
480 result.replaceAllUsesWith(wire);
482 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
483 auto moduleOp = cast<firrtl::FModuleLike>(
484 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
485 nlaRemover.markNLAsInOperation(instOp);
486 erasedInsts.insert(instOp);
488 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
489 [&](Operation *user) { return erasedInsts.contains(user); })) {
490 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
491 << moduleOp.getModuleName() <<
"`\n");
492 erasedModules.insert(moduleOp);
497 std::string
getName()
const override {
return "instance-stubber"; }
502 llvm::DenseSet<Operation *> erasedInsts;
503 llvm::DenseSet<Operation *> erasedModules;
509struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
510 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
511 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
512 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
513 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
514 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
517 SmallVector<Value> outputs;
518 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
519 auto result = memOp.getResult(i);
520 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
521 memOp.getPortName(i));
523 firrtl::WireOp::create(builder, result.getType(), name,
524 firrtl::NameKindEnum::DroppableName,
525 memOp.getPortAnnotation(i), StringAttr{})
527 invalidateOutputs(builder, wire, tieOffCache,
true);
528 result.replaceAllUsesWith(wire);
532 switch (memOp.getPortKind(i)) {
533 case firrtl::MemOp::PortKind::Read:
534 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
536 case firrtl::MemOp::PortKind::Write:
537 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
539 case firrtl::MemOp::PortKind::ReadWrite:
540 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
541 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
543 case firrtl::MemOp::PortKind::Debug:
548 if (!isa<firrtl::RefType>(result.getType())) {
551 cast<firrtl::BundleType>(wire.getType()).getNumElements();
552 for (
unsigned i = 0; i != numFields; ++i) {
553 if (i != 2 && i != 3 && i != 5)
554 reduceXor(builder, xorInputs,
555 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
558 reduceXor(builder, xorInputs, input);
563 outputs.push_back(output);
567 for (
auto output : outputs)
568 connectToLeafs(builder, output, xorInputs);
570 nlaRemover.markNLAsInOperation(memOp);
574 std::string
getName()
const override {
return "memory-stubber"; }
581static bool isFlowSensitiveOp(Operation *op) {
582 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
583 SubaccessOp, ObjectSubfieldOp>(op);
589template <
unsigned OpNum>
590struct FIRRTLOperandForwarder :
public Reduction {
591 uint64_t
match(Operation *op)
override {
592 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
594 if (isFlowSensitiveOp(op))
597 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
599 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
600 return resultTy && opTy &&
601 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
602 (resultTy.getBitWidthOrSentinel() == -1) ==
603 (opTy.getBitWidthOrSentinel() == -1) &&
604 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
606 LogicalResult
rewrite(Operation *op)
override {
608 ImplicitLocOpBuilder builder(op->getLoc(), op);
609 auto result = op->getResult(0);
610 auto operand = op->getOperand(OpNum);
611 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
612 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
613 auto resultWidth = resultTy.getBitWidthOrSentinel();
614 auto operandWidth = operandTy.getBitWidthOrSentinel();
616 if (resultWidth < operandWidth)
618 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
619 else if (resultWidth > operandWidth)
620 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
623 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
624 result.replaceAllUsesWith(newOp);
628 std::string
getName()
const override {
629 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
640 anyrefCastDummy.clear();
641 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
642 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
643 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
644 anyrefCastDummy.insert({circuitOp, classOp});
645 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
648 return WalkResult::skip();
652 uint64_t
match(Operation *op)
override {
653 if (op->hasTrait<OpTrait::ConstantLike>()) {
655 if (!matchPattern(op, m_Constant(&attr)))
657 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
658 if (intAttr.getValue().isZero())
660 if (
auto strAttr = dyn_cast<StringAttr>(attr))
663 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
664 if (floatAttr.getValue().isZero())
667 if (
auto listOp = dyn_cast<ListCreateOp>(op))
668 if (listOp.getElements().empty())
670 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
671 if (pathOp.getTarget().empty())
675 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
676 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
678 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
679 if (anyrefCastDummyNames[circuitOp].contains(className))
683 if (op->getNumResults() != 1)
685 if (op->hasAttr(
"inner_sym"))
687 if (isFlowSensitiveOp(op))
689 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
690 DoubleType, ListType, PathType, AnyRefType>(
691 op->getResult(0).getType());
694 LogicalResult
rewrite(Operation *op)
override {
695 OpBuilder builder(op);
696 auto type = op->getResult(0).getType();
699 if (isa<UIntType, SIntType>(type)) {
700 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
703 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
704 APSInt(width, isa<UIntType>(type)));
705 op->replaceAllUsesWith(newOp);
711 if (isa<StringType>(type)) {
712 auto attr = builder.getStringAttr(
"");
713 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
714 op->replaceAllUsesWith(newOp);
720 if (isa<FIntegerType>(type)) {
721 auto attr = builder.getIntegerAttr(builder.getIntegerType(64,
true), 0);
722 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
723 op->replaceAllUsesWith(newOp);
729 if (isa<BoolType>(type)) {
730 auto attr = builder.getBoolAttr(
false);
731 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
732 op->replaceAllUsesWith(newOp);
738 if (isa<DoubleType>(type)) {
739 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
740 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
741 op->replaceAllUsesWith(newOp);
747 if (isa<ListType>(type)) {
749 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
750 op->replaceAllUsesWith(newOp);
756 if (isa<PathType>(type)) {
757 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
758 op->replaceAllUsesWith(newOp);
764 if (isa<AnyRefType>(type)) {
765 auto circuitOp = op->getParentOfType<CircuitOp>();
766 auto &dummy = anyrefCastDummy[circuitOp];
768 OpBuilder::InsertionGuard guard(builder);
769 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
770 auto &symbolTable = symbols.getNearestSymbolTable(op);
771 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
772 symbolTable.insert(dummy);
773 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
775 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
777 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
778 op->replaceAllUsesWith(anyrefOp);
786 std::string
getName()
const override {
return "firrtl-constantifier"; }
798struct ConnectInvalidator :
public Reduction {
799 uint64_t
match(Operation *op)
override {
800 if (!isa<FConnectLike>(op))
802 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
803 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
804 isa<InvalidValueOp>(srcOp))
806 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
807 return type && type.isPassive();
809 LogicalResult
rewrite(Operation *op)
override {
811 auto rhs = op->getOperand(1);
812 OpBuilder builder(op);
813 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
814 auto *rhsOp = rhs.getDefiningOp();
815 op->setOperand(1, invOp);
820 std::string
getName()
const override {
return "connect-invalidator"; }
827struct AnnotationRemover :
public Reduction {
828 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
829 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
832 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
833 uint64_t matchId = 0;
836 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
837 for (
unsigned i = 0; i < annos.size(); ++i)
838 addMatch(1, matchId++);
841 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
842 for (
auto portAnnoArray : portAnnos)
843 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
844 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
845 addMatch(1, matchId++);
849 ArrayRef<uint64_t> matches)
override {
851 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
854 uint64_t matchId = 0;
855 auto processAnnotations =
856 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
857 SmallVector<Attribute> newAnnotations;
858 for (
auto anno : annotations) {
859 if (!matchesSet.contains(matchId)) {
860 newAnnotations.push_back(anno);
863 nlaRemover.markNLAsInAnnotation(anno);
867 return ArrayAttr::get(op->getContext(), newAnnotations);
871 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
872 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
876 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
877 SmallVector<Attribute> newPortAnnos;
878 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
879 newPortAnnos.push_back(
880 processAnnotations(portAnnoArrayAttr.getValue()));
882 op->setAttr(
"portAnnotations",
883 ArrayAttr::get(op->getContext(), newPortAnnos));
889 std::string
getName()
const override {
return "annotation-remover"; }
896struct SimplifyResets :
public OpReduction<CircuitOp> {
897 uint64_t
match(CircuitOp circuit)
override {
898 uint64_t numResets = 0;
899 AttrTypeWalker walker;
900 walker.addWalk([&](ResetType type) { ++numResets; });
902 circuit.walk([&](Operation *op) {
903 for (
auto result : op->getResults())
904 walker.walk(result.getType());
906 for (
auto ®ion : op->getRegions())
907 for (auto &block : region)
908 for (auto arg : block.getArguments())
909 walker.walk(arg.getType());
911 walker.walk(op->getAttrDictionary());
917 LogicalResult
rewrite(CircuitOp circuit)
override {
918 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
919 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
921 AttrTypeReplacer replacer;
922 replacer.addReplacement([&](ResetType type) {
923 return type.isConst() ? constUint1Type : uint1Type;
925 replacer.recursivelyReplaceElementsIn(circuit,
true,
930 circuit.walk([&](Operation *op) {
933 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
934 fullAsyncResetAnnoClass,
935 ignoreFullAsyncResetAnnoClass);
939 if (
auto module = dyn_cast<FModuleLike>(op)) {
942 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
943 fullAsyncResetAnnoClass,
944 ignoreFullAsyncResetAnnoClass);
952 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
958struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
959 void matches(firrtl::FModuleOp module,
960 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
961 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
962 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
966 size_t numPorts =
module.getNumPorts();
967 for (
unsigned i = 0; i != numPorts; ++i) {
974 ArrayRef<uint64_t> matches)
override {
976 llvm::BitVector dropPorts(module.getNumPorts());
977 for (
auto portIdx : matches)
978 dropPorts.set(portIdx);
981 for (
auto portIdx : matches) {
983 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
988 module.erasePorts(dropPorts);
992 std::string
getName()
const override {
return "root-port-pruner"; }
998struct RootExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
999 void matches(firrtl::FExtModuleOp module,
1000 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1001 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
1002 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1007 size_t numPorts =
module.getNumPorts();
1008 for (
unsigned i = 0; i != numPorts; ++i)
1013 ArrayRef<uint64_t> matches)
override {
1014 if (matches.empty())
1018 llvm::BitVector dropPorts(module.getNumPorts());
1019 for (
auto portIdx : matches)
1020 dropPorts.set(portIdx);
1023 module.erasePorts(dropPorts);
1027 std::string
getName()
const override {
return "root-extmodule-port-pruner"; }
1032struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
1037 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1039 uint64_t
match(firrtl::InstanceOp instOp)
override {
1040 return isa<firrtl::FExtModuleOp>(
1041 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1043 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
1045 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1046 symbols.getNearestSymbolTable(instOp)))
1048 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1050 SmallVector<Value> replacementWires;
1052 auto wire = firrtl::WireOp::create(
1054 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
1056 if (
info.isOutput()) {
1058 if (
auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(
info.type)) {
1060 firrtl::ConnectOp::create(builder, wire, inv);
1061 }
else if (
auto propType = dyn_cast<firrtl::PropertyType>(
info.type)) {
1062 auto unknown = tieOffCache.
getUnknown(propType);
1063 builder.create<firrtl::PropAssignOp>(wire, unknown);
1066 replacementWires.push_back(wire);
1068 nlaRemover.markNLAsInOperation(instOp);
1069 instOp.replaceAllUsesWith(std::move(replacementWires));
1073 std::string
getName()
const override {
return "extmodule-instance-remover"; }
1085struct PortPrunerHelpers {
1087 template <
typename ModuleOpType>
1088 static void computeUnusedInstancePorts(ModuleOpType module,
1089 ArrayRef<Operation *> users,
1090 llvm::BitVector &portsToRemove) {
1091 auto ports =
module.getPorts();
1092 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1093 bool portUsed =
false;
1094 for (
auto *user : users) {
1095 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1096 auto result = instOp.getResult(portIdx);
1097 if (!result.use_empty()) {
1104 portsToRemove.set(portIdx);
1110 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1111 const llvm::BitVector &portsToRemove) {
1113 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1114 for (
auto *user : users) {
1115 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1116 instancesToUpdate.push_back(instOp);
1119 for (
auto instOp : instancesToUpdate) {
1120 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1123 size_t newResultIdx = 0;
1124 for (
size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1126 if (portsToRemove[oldResultIdx]) {
1128 assert(instOp.getResult(oldResultIdx).use_empty() &&
1129 "removing port with uses");
1132 instOp.getResult(oldResultIdx)
1133 .replaceAllUsesWith(newInst->getResult(newResultIdx));
1144struct ModulePortPruner :
public OpReduction<firrtl::FModuleOp> {
1149 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1151 void matches(firrtl::FModuleOp module,
1152 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1153 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1154 auto &userMap = symbols.getSymbolUserMap(tableOp);
1155 auto ports =
module.getPorts();
1156 auto users = userMap.getUsers(module);
1160 llvm::BitVector portsToRemove(ports.size());
1164 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1168 portsToRemove.set();
1173 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1174 if (!portsToRemove[portIdx])
1176 if (!module.getArgument(portIdx).use_empty())
1177 portsToRemove.reset(portIdx);
1181 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1182 if (portsToRemove[portIdx])
1183 addMatch(1, portIdx);
1187 ArrayRef<uint64_t> matches)
override {
1188 if (matches.empty())
1192 llvm::BitVector portsToRemove(module.getNumPorts());
1193 for (
auto portIdx : matches)
1194 portsToRemove.set(portIdx);
1197 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1198 auto &userMap = symbols.getSymbolUserMap(tableOp);
1199 auto users = userMap.getUsers(module);
1202 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1207 module.erasePorts(portsToRemove);
1212 std::string
getName()
const override {
return "module-port-pruner"; }
1219struct ExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
1224 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1226 void matches(firrtl::FExtModuleOp module,
1227 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1228 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1229 auto &userMap = symbols.getSymbolUserMap(tableOp);
1230 auto ports =
module.getPorts();
1231 auto users = userMap.getUsers(module);
1234 llvm::BitVector portsToRemove(ports.size());
1236 if (users.empty()) {
1238 portsToRemove.set();
1242 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1247 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1248 if (portsToRemove[portIdx])
1249 addMatch(1, portIdx);
1253 ArrayRef<uint64_t> matches)
override {
1254 if (matches.empty())
1258 llvm::BitVector portsToRemove(module.getNumPorts());
1259 for (
auto portIdx : matches)
1260 portsToRemove.set(portIdx);
1263 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1264 auto &userMap = symbols.getSymbolUserMap(tableOp);
1265 auto users = userMap.getUsers(module);
1268 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1272 module.erasePorts(portsToRemove);
1277 std::string
getName()
const override {
return "extmodule-port-pruner"; }
1284struct ConnectForwarder :
public Reduction {
1286 domInfo = std::make_unique<DominanceInfo>(op);
1289 uint64_t
match(Operation *op)
override {
1290 if (!isa<firrtl::FConnectLike>(op))
1292 auto dest = op->getOperand(0);
1293 auto src = op->getOperand(1);
1294 auto *destOp = dest.getDefiningOp();
1295 auto *srcOp = src.getDefiningOp();
1301 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1307 unsigned numConnects = 0;
1308 for (
auto &use : dest.getUses()) {
1309 auto *op = use.getOwner();
1310 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1311 if (++numConnects > 1)
1318 !domInfo->properlyDominates(srcOp, op,
false))
1325 LogicalResult
rewrite(Operation *op)
override {
1326 auto dst = op->getOperand(0);
1327 auto src = op->getOperand(1);
1328 dst.replaceAllUsesExcept(src, op);
1330 SmallVector<Operation *> worklist(
1331 {dst.getDefiningOp(), src.getDefiningOp()});
1336 std::string
getName()
const override {
return "connect-forwarder"; }
1339 std::unique_ptr<DominanceInfo> domInfo;
1344template <
unsigned OpNum>
1345struct ConnectSourceOperandForwarder :
public Reduction {
1346 uint64_t
match(Operation *op)
override {
1347 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1349 auto dest = op->getOperand(0);
1350 auto *destOp = dest.getDefiningOp();
1353 if (!destOp || !destOp->hasOneUse() ||
1354 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1357 auto *srcOp = op->getOperand(1).getDefiningOp();
1358 if (!srcOp || OpNum >= srcOp->getNumOperands())
1361 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1363 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1365 return resultTy && opTy &&
1366 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1367 ((resultTy.getBitWidthOrSentinel() == -1) ==
1368 (opTy.getBitWidthOrSentinel() == -1)) &&
1369 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1372 LogicalResult
rewrite(Operation *op)
override {
1373 auto *destOp = op->getOperand(0).getDefiningOp();
1374 auto *srcOp = op->getOperand(1).getDefiningOp();
1375 auto forwardedOperand = srcOp->getOperand(OpNum);
1376 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1378 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1379 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1383 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1386 auto clock = destOp->getOperand(0);
1387 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1388 clock, regName ? regName.str() :
"")
1393 builder.setInsertionPointAfter(op);
1394 if (isa<firrtl::ConnectOp>(op))
1395 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1397 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1407 std::string
getName()
const override {
1408 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1415struct DetachSubaccesses :
public Reduction {
1416 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1418 for (
auto *op : opsToErase)
1419 op->dropAllReferences();
1420 for (
auto *op : opsToErase)
1423 uint64_t
match(Operation *op)
override {
1426 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1427 llvm::all_of(op->getUses(), [](
auto &use) {
1428 return use.getOperandNumber() == 0 &&
1429 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1430 firrtl::SubaccessOp>(use.getOwner());
1433 LogicalResult
rewrite(Operation *op)
override {
1435 OpBuilder builder(op);
1436 bool isWire = isa<firrtl::WireOp>(op);
1439 invalidClock = firrtl::InvalidValueOp::create(
1440 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1441 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1442 builder.setInsertionPoint(user);
1443 auto type = user->getResult(0).getType();
1446 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1449 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1450 user->replaceAllUsesWith(replOp);
1451 opsToErase.insert(user);
1453 opsToErase.insert(op);
1456 std::string
getName()
const override {
return "detach-subaccesses"; }
1457 llvm::DenseSet<Operation *> opsToErase;
1463struct NodeSymbolRemover :
public Reduction {
1468 uint64_t
match(Operation *op)
override {
1470 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1471 if (!sym || sym.empty())
1475 if (innerSymUses.hasInnerRef(op))
1480 LogicalResult
rewrite(Operation *op)
override {
1481 op->removeAttr(
"inner_sym");
1485 std::string
getName()
const override {
return "node-symbol-remover"; }
1494hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1503 LogicalResult walkResult = targetTable.
walkSymbols(
1506 if (parentTable.lookup(name)) {
1514 return failed(walkResult);
1518struct EagerInliner :
public OpReduction<InstanceOp> {
1523 for (
auto circuitOp : op.getOps<CircuitOp>())
1524 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1525 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1528 nlaRemover.remove(op);
1530 innerSymTables.reset();
1533 uint64_t
match(InstanceOp instOp)
override {
1534 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1536 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1539 if (!isa<FModuleOp>(moduleOp))
1543 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1546 auto it = nlaTables.find(circuitOp);
1547 if (it == nlaTables.end() || !it->second)
1549 DenseSet<hw::HierPathOp> nlas;
1550 it->second->getInstanceNLAs(instOp, nlas);
1556 auto parentOp = instOp->getParentOfType<FModuleLike>();
1557 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1563 LogicalResult
rewrite(InstanceOp instOp)
override {
1564 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1565 auto moduleOp = cast<FModuleOp>(
1566 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1568 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1569 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1572 IRRewriter rewriter(instOp);
1573 SmallVector<Value> argWires;
1574 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1575 auto result = instOp.getResult(i);
1576 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1577 instOp.getPortName(i));
1578 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1579 name, NameKindEnum::DroppableName,
1580 instOp.getPortAnnotation(i), StringAttr{})
1582 result.replaceAllUsesWith(wire);
1583 argWires.push_back(wire);
1587 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1591 nlaRemover.markNLAsInOperation(instOp);
1593 nlaRemover.markNLAsInOperation(moduleOp);
1596 clonedModuleOp.erase();
1600 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1605 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1606 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1610struct ObjectInliner :
public OpReduction<ObjectOp> {
1612 blocksToSort.clear();
1615 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1618 for (
auto *block : blocksToSort)
1619 mlir::sortTopologically(block);
1620 blocksToSort.clear();
1621 nlaRemover.remove(op);
1622 innerSymTables.reset();
1625 uint64_t
match(ObjectOp objOp)
override {
1626 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1628 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1631 if (!isa<ClassOp>(classOp))
1636 auto parentOp = objOp->getParentOfType<FModuleLike>();
1637 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1641 for (
auto *user : objOp.getResult().getUsers())
1642 if (!isa<ObjectSubfieldOp>(user))
1648 LogicalResult
rewrite(ObjectOp objOp)
override {
1649 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1650 auto classOp = cast<ClassOp>(
1651 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1652 auto clonedClassOp = classOp.clone();
1655 IRRewriter rewriter(objOp);
1656 SmallVector<Value> portWires;
1657 auto classType = objOp.getType();
1660 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1661 auto element = classType.getElement(i);
1662 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1663 element.name.getValue());
1664 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1665 NameKindEnum::DroppableName,
1666 rewriter.getArrayAttr({}), StringAttr{})
1668 portWires.push_back(wire);
1672 SmallVector<ObjectSubfieldOp> subfieldOps;
1673 for (
auto *user : objOp.getResult().getUsers()) {
1674 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1675 subfieldOps.push_back(subfieldOp);
1676 auto index = subfieldOp.getIndex();
1677 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1681 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1687 SmallVector<FConnectLike> connectsToErase;
1688 for (
auto portWire : portWires) {
1692 for (
auto *user : portWire.getUsers()) {
1693 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1694 if (
connect.getDest() == portWire) {
1696 connectsToErase.push_back(connect);
1706 portWire.replaceAllUsesWith(value);
1707 for (
auto connect : connectsToErase)
1709 if (portWire.use_empty())
1710 portWire.getDefiningOp()->erase();
1711 connectsToErase.clear();
1715 nlaRemover.markNLAsInOperation(objOp);
1720 blocksToSort.insert(objOp->getBlock());
1723 for (
auto subfieldOp : subfieldOps)
1726 clonedClassOp.erase();
1730 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1733 SetVector<Block *> blocksToSort;
1736 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1741struct ResetDisconnector :
public OpReduction<RegResetOp> {
1742 uint64_t
match(RegResetOp op)
override {
return 1; }
1744 LogicalResult
rewrite(RegResetOp regResetOp)
override {
1745 ImplicitLocOpBuilder builder(regResetOp.getLoc(), regResetOp);
1746 auto regOp = RegOp::create(
1747 builder, regResetOp.getResult().getType(), regResetOp.getClockVal(),
1748 regResetOp.getNameAttr(), regResetOp.getNameKindAttr(),
1749 regResetOp.getAnnotationsAttr(), regResetOp.getInnerSymAttr(),
1750 regResetOp.getForceableAttr());
1752 regResetOp.getResult().replaceAllUsesWith(regOp.getResult());
1753 if (regResetOp.getForceable())
1754 regResetOp.getRef().replaceAllUsesWith(regOp.getRef());
1760 std::string
getName()
const override {
return "reset-disconnector"; }
1775 uint64_t
match(Operation *op)
override {
1777 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1778 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1779 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1780 firrtl::CoverOp>(op);
1782 LogicalResult
rewrite(Operation *op)
override {
1783 TypeSwitch<Operation *, void>(op)
1784 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1785 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1786 [](
auto op) { op.setName(
"reg"); })
1787 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1788 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1789 [](
auto op) { op.setName(
"mem"); })
1790 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1791 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1792 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1797 std::string
getName()
const override {
1798 return "module-internal-name-sanitizer";
1803 bool isOneShot()
const override {
return true; }
1823 if (portNameIndex >= 26)
1825 return 'a' + portNameIndex++;
1830 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1835 SymbolTable symTable(circuitOp);
1839 auto renameModule = [&](firrtl::FModuleLike mod,
1840 StringAttr newName) -> LogicalResult {
1841 StringAttr oldName = mod.getModuleNameAttr();
1842 if (failed(symTable.rename(mod, newName)))
1844 nlaTable.renameModule(oldName, newName);
1850 auto topModule = iGraph.getTopLevelModule();
1851 auto *ctx = circuitOp.getContext();
1853 topModule.getModuleName())) {
1854 auto newTopName = StringAttr::get(ctx, nameGenerator.
getNextName(ns));
1855 if (failed(renameModule(topModule, newTopName)))
1857 circuitOp.setName(newTopName.getValue());
1860 for (
auto *node : iGraph) {
1861 auto module = node->getModule<firrtl::FModuleLike>();
1863 bool shouldReplacePorts =
false;
1864 SmallVector<Attribute> newPortNames;
1865 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1870 auto oldPorts = fmodule.getPorts();
1871 shouldReplacePorts = !oldPorts.empty();
1872 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1873 auto port = oldPorts[i];
1875 .
Case<firrtl::ClockType>(
1876 [&](
auto a) {
return ns.
newName(
"clk"); })
1877 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1878 [&](
auto a) {
return ns.
newName(
"rst"); })
1879 .Case<firrtl::RefType>(
1880 [&](
auto a) {
return ns.
newName(
"ref"); })
1881 .Default([&](
auto a) {
1884 newPortNames.push_back(StringAttr::get(ctx, newName));
1886 fmodule->setAttr(
"portNames",
1887 ArrayAttr::get(fmodule.getContext(), newPortNames));
1890 if (module == iGraph.getTopLevelModule())
1894 module.getModuleName()))
1896 auto newName = StringAttr::get(ctx, nameGenerator.
getNextName(ns));
1897 if (failed(renameModule(module, newName)))
1899 for (
auto *use : node->uses()) {
1900 auto useOp = use->getInstance();
1901 if (
auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1905 instanceOp.setName(newName);
1906 if (shouldReplacePorts)
1907 instanceOp.setPortNamesAttr(ArrayAttr::get(ctx, newPortNames));
1908 }
else if (
auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1912 auto oldClassType = objectOp.getType();
1913 auto newClassType = firrtl::ClassType::get(
1914 ctx, FlatSymbolRefAttr::get(newName), oldClassType.getElements());
1915 objectOp.getResult().setType(newClassType);
1916 objectOp.setName(newName);
1924 std::string
getName()
const override {
return "module-name-sanitizer"; }
1928 bool isOneShot()
const override {
return true; }
1947struct ModuleSwapper :
public OpReduction<InstanceOp> {
1949 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1950 struct CircuitState {
1951 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1952 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1953 std::unique_ptr<NLATable> nlaTable;
1959 moduleSizes.clear();
1960 circuitStates.clear();
1963 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1964 auto &state = circuitStates[circuitOp];
1965 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1966 buildModuleTypeGroups(circuitOp, state);
1967 return WalkResult::skip();
1970 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1976 PortSignature getModulePortSignature(FModuleLike module) {
1977 PortSignature signature;
1978 signature.reserve(module.getNumPorts());
1979 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1980 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1985 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1987 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1988 auto signature = getModulePortSignature(module);
1989 state.moduleTypeGroups[signature].push_back(module);
1993 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1994 if (modules.size() <= 1)
1997 FModuleLike smallestModule =
nullptr;
1998 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
2000 for (
auto module : modules) {
2001 uint64_t size = moduleSizes.getModuleSize(module, symbols);
2002 if (size < smallestSize) {
2003 smallestSize = size;
2004 smallestModule =
module;
2009 for (
auto module : modules) {
2010 if (module != smallestModule) {
2011 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
2018 uint64_t
match(InstanceOp instOp)
override {
2020 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2022 const auto &state = circuitStates.at(circuitOp);
2025 DenseSet<hw::HierPathOp> nlas;
2026 state.nlaTable->getInstanceNLAs(instOp, nlas);
2031 auto moduleName = instOp.getModuleNameAttr().getAttr();
2032 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
2033 if (!canonicalModule)
2037 auto currentModule = cast<FModuleLike>(
2038 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
2039 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
2040 uint64_t canonicalSize =
2041 moduleSizes.getModuleSize(canonicalModule, symbols);
2042 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
2045 LogicalResult
rewrite(InstanceOp instOp)
override {
2047 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2049 const auto &state = circuitStates.at(circuitOp);
2052 auto canonicalModule = state.instanceToCanonicalModule.at(
2053 instOp.getModuleNameAttr().getAttr());
2054 auto canonicalName = canonicalModule.getModuleNameAttr();
2055 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2058 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2063 std::string
getName()
const override {
return "firrtl-module-swapper"; }
2072 DenseMap<CircuitOp, CircuitState> circuitStates;
2090struct ForceDedup :
public OpReduction<CircuitOp> {
2094 modulesToErase.clear();
2095 moduleSizes.clear();
2098 nlaRemover.remove(op);
2099 for (
auto mod : modulesToErase)
2104 void matches(CircuitOp circuitOp,
2105 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2106 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2108 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2109 if (!anno.
isClass(mustDeduplicateAnnoClass))
2112 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2113 if (!modulesAttr || modulesAttr.size() < 2)
2119 uint64_t totalSize = 0;
2120 ArrayAttr portTypes;
2121 DenseBoolArrayAttr portDirections;
2122 bool allSame =
true;
2123 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2129 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2134 totalSize += moduleSizes.getModuleSize(mod, symbols);
2136 portTypes = mod.getPortTypesAttr();
2137 portDirections = mod.getPortDirectionsAttr();
2138 }
else if (portTypes != mod.getPortTypesAttr() ||
2139 portDirections != mod.getPortDirectionsAttr()) {
2149 addMatch(totalSize, annoIdx);
2154 ArrayRef<uint64_t> matches)
override {
2155 auto *
context = circuitOp->getContext();
2159 SmallVector<Annotation> newAnnotations;
2161 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2163 if (!llvm::is_contained(matches, annoIdx)) {
2164 newAnnotations.push_back(anno);
2167 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2168 assert(anno.
isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2169 modulesAttr.size() >= 2);
2172 SmallVector<StringAttr> moduleNames;
2173 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2175 auto refStr = moduleRef.getValue();
2176 auto pipePos = refStr.find(
'|');
2177 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2178 auto moduleName = refStr.substr(pipePos + 1);
2179 moduleNames.push_back(StringAttr::get(
context, moduleName));
2184 if (moduleNames.size() < 2)
2189 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2190 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
2192 if (newAnnotations.size() == annotations.size())
2197 newAnnoSet.applyToOperation(circuitOp);
2201 std::string
getName()
const override {
return "firrtl-force-dedup"; }
2207 void replaceModuleReferences(CircuitOp circuitOp,
2208 ArrayRef<StringAttr> moduleNames,
2211 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2212 auto &symbolTable = symbols.getSymbolTable(tableOp);
2213 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2214 auto *
context = circuitOp->getContext();
2218 FModuleLike canonicalModule;
2219 SmallVector<FModuleLike> modulesToReplace;
2220 for (
auto name : moduleNames) {
2221 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
2222 if (!canonicalModule)
2223 canonicalModule = mod;
2225 modulesToReplace.push_back(mod);
2228 if (modulesToReplace.empty())
2232 auto canonicalName = canonicalModule.getModuleNameAttr();
2233 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2234 for (
auto moduleName : moduleNames) {
2235 if (moduleName == canonicalName)
2237 auto *symbolOp = symbolTable.lookup(moduleName);
2240 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
2241 auto instOp = dyn_cast<InstanceOp>(user);
2242 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2244 instOp.setModuleNameAttr(canonicalRef);
2245 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2251 for (
auto oldMod : modulesToReplace) {
2252 SmallVector<hw::HierPathOp> nlaOps(
2253 nlaTable.
lookup(oldMod.getModuleNameAttr()));
2254 for (
auto nlaOp : nlaOps) {
2255 nlaTable.
erase(nlaOp);
2256 StringAttr oldModName = oldMod.getModuleNameAttr();
2257 StringAttr newModName = canonicalName;
2258 SmallVector<Attribute, 4> newPath;
2259 for (
auto nameRef : nlaOp.getNamepath()) {
2260 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2261 if (ref.getModule() == oldModName) {
2262 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2263 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2264 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2265 if (oldInst && newInst) {
2268 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2269 auto newModNames = newInst.getReferencedModuleNamesAttr();
2270 if (!oldModNames.empty() && !newModNames.empty()) {
2271 oldModName = cast<StringAttr>(oldModNames[0]);
2272 newModName = cast<StringAttr>(newModNames[0]);
2276 newPath.push_back(ref);
2277 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2278 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2280 newPath.push_back(nameRef);
2283 nlaOp.setNamepathAttr(ArrayAttr::get(
context, newPath));
2289 for (
auto module : modulesToReplace) {
2290 nlaRemover.markNLAsInOperation(module);
2291 modulesToErase.insert(module);
2297 SetVector<FModuleLike> modulesToErase;
2317struct MustDedupChildren :
public OpReduction<CircuitOp> {
2322 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
2326 void matches(CircuitOp circuitOp,
2327 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2329 uint64_t matchId = 0;
2331 DenseSet<StringRef> modulesAlreadyInMustDedup;
2332 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
2333 if (anno.isClass(mustDeduplicateAnnoClass))
2334 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
2335 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2337 modulesAlreadyInMustDedup.insert(target->module);
2339 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2340 if (!anno.
isClass(mustDeduplicateAnnoClass))
2343 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2344 if (!modulesAttr || modulesAttr.size() < 2)
2348 processInstanceGroups(
2349 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2353 SmallDenseSet<StringAttr, 4> moduleTargets;
2354 for (
auto instOp : instanceGroup) {
2355 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2356 for (
auto moduleName : moduleNames)
2357 moduleTargets.insert(cast<StringAttr>(moduleName));
2359 if (moduleTargets.size() < 2)
2364 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2365 auto moduleNames = inst.getReferencedModuleNames();
2366 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2367 return modulesAlreadyInMustDedup.contains(moduleName);
2372 addMatch(1, matchId - 1);
2378 ArrayRef<uint64_t> matches)
override {
2379 auto *
context = circuitOp->getContext();
2381 SmallVector<Annotation> newAnnotations;
2382 uint64_t matchId = 0;
2384 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2385 if (!anno.
isClass(mustDeduplicateAnnoClass)) {
2386 newAnnotations.push_back(anno);
2390 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2391 if (!modulesAttr || modulesAttr.size() < 2) {
2392 newAnnotations.push_back(anno);
2396 processInstanceGroups(
2397 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2399 if (!llvm::is_contained(matches, matchId++))
2404 for (
auto instOp : instanceGroup) {
2405 auto moduleNames = instOp.getReferencedModuleNames();
2406 for (
auto moduleName : moduleNames) {
2408 target.circuit = circuitOp.getName();
2409 target.module = moduleName;
2410 moduleTargets.insert(target.toStringAttr(
context));
2415 SmallVector<NamedAttribute> newAnnoAttrs;
2416 newAnnoAttrs.emplace_back(
2417 StringAttr::get(
context,
"class"),
2418 StringAttr::get(
context, mustDeduplicateAnnoClass));
2419 newAnnoAttrs.emplace_back(
2420 StringAttr::get(
context,
"modules"),
2422 SmallVector<Attribute>(moduleTargets.begin(),
2423 moduleTargets.end())));
2425 auto newAnnoDict = DictionaryAttr::get(
context, newAnnoAttrs);
2426 newAnnotations.emplace_back(newAnnoDict);
2430 newAnnotations.push_back(anno);
2435 newAnnoSet.applyToOperation(circuitOp);
2439 std::string
getName()
const override {
return "must-dedup-children"; }
2447 void processInstanceGroups(
2448 CircuitOp circuitOp, ArrayAttr modulesAttr,
2449 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2450 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2453 SmallVector<FModuleLike> modules;
2454 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2456 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2457 modules.push_back(mod);
2460 if (modules.size() < 2)
2467 struct InstanceGroup {
2468 SmallVector<FInstanceLike> instances;
2469 bool nameIsUnique =
true;
2472 for (
auto module : modules) {
2474 module.walk([&](FInstanceLike instOp) {
2475 if (isa<ObjectOp>(instOp.getOperation()))
2477 auto name = instOp.getInstanceNameAttr();
2478 auto &group = instanceGroups[name];
2479 if (nameCounts[name]++ > 1)
2480 group.nameIsUnique =
false;
2481 group.instances.push_back(instOp);
2487 for (
auto &[name, group] : instanceGroups)
2488 if (group.nameIsUnique && group.instances.size() >= 2)
2489 callback(group.instances);
2496struct LayerDisable :
public OpReduction<CircuitOp> {
2497 LayerDisable(MLIRContext *
context) {
2498 pm = std::make_unique<mlir::PassManager>(
2499 context,
"builtin.module", mlir::OpPassManager::Nesting::Explicit);
2500 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2503 void beforeReduction(mlir::ModuleOp op)
override { symbolRefAttrMap.clear(); }
2505 void afterReduction(mlir::ModuleOp op)
override { (void)pm->run(op); };
2507 void matches(CircuitOp circuitOp,
2508 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2509 uint64_t matchId = 0;
2511 SmallVector<FlatSymbolRefAttr> nestedRefs;
2512 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2515 rootRef = layerOp.getSymNameAttr();
2517 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2519 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2520 addMatch(1, matchId++);
2522 for (
auto nestedLayerOp : layerOp.getOps<LayerOp>())
2523 addLayer(rootRef, nestedLayerOp);
2525 if (!nestedRefs.empty())
2526 nestedRefs.pop_back();
2529 for (
auto layerOp : circuitOp.getOps<LayerOp>())
2530 addLayer({}, layerOp);
2534 ArrayRef<uint64_t> matches)
override {
2535 SmallVector<Attribute> disableLayers;
2536 if (
auto existingDisables = circuitOp.getDisableLayersAttr()) {
2537 auto disableRange = existingDisables.getAsRange<Attribute>();
2538 disableLayers.append(disableRange.begin(), disableRange.end());
2540 for (
auto match : matches)
2541 disableLayers.push_back(symbolRefAttrMap.at(match));
2543 circuitOp.setDisableLayersAttr(
2544 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2549 std::string
getName()
const override {
return "firrtl-layer-disable"; }
2551 std::unique_ptr<mlir::PassManager> pm;
2552 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2562 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2564 auto elements = listOp.getElements();
2565 for (
size_t i = 0; i < elements.size(); ++i)
2570 ArrayRef<uint64_t>
matches)
override {
2572 llvm::SmallDenseSet<uint64_t, 4> matchesSet(
matches.begin(),
matches.end());
2575 SmallVector<Value> newElements;
2576 auto elements = listOp.getElements();
2577 for (
size_t i = 0; i < elements.size(); ++i) {
2578 if (!matchesSet.contains(i))
2579 newElements.push_back(elements[i]);
2583 OpBuilder builder(listOp);
2584 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2585 listOp.getType(), newElements);
2586 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2593 return "firrtl-list-create-element-remover";
2599 uint64_t
match(FModuleOp module)
override {
2600 return module.getConvention() != Convention::Internal;
2603 LogicalResult
rewrite(FModuleOp module)
override {
2604 module.setConvention(Convention::Internal);
2608 std::string
getName()
const override {
return "module-convention-remover"; }
2615 uint64_t
match(FExtModuleOp extmodule)
override {
2616 return extmodule.getConvention() != Convention::Internal;
2619 LogicalResult
rewrite(FExtModuleOp extmodule)
override {
2620 extmodule.setConvention(Convention::Internal);
2625 return "extmodule-convention-remover";
2642 patterns.add<SimplifyResets, 35>();
2644 patterns.add<MustDedupChildren, 33>();
2645 patterns.add<AnnotationRemover, 32>();
2647 patterns.add<LayerDisable, 30>(getContext());
2653 firrtl::createLowerCHIRRTLPass(),
true,
true);
2658 patterns.add<FIRRTLModuleExternalizer, 25>();
2659 patterns.add<InstanceStubber, 24>();
2664 firrtl::createLowerFIRRTLTypes(),
true,
true);
2671 firrtl::createRemoveUnusedPorts({
true}));
2672 patterns.add<NodeSymbolRemover, 16>();
2674 patterns.add<ConnectForwarder, 14>();
2675 patterns.add<ConnectInvalidator, 13>();
2677 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2678 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2679 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2681 patterns.add<ResetDisconnector, 8>();
2682 patterns.add<DetachSubaccesses, 7>();
2683 patterns.add<ModulePortPruner, 7>();
2684 patterns.add<ExtmodulePortPruner, 6>();
2686 patterns.add<RootExtmodulePortPruner, 5>();
2687 patterns.add<ExtmoduleInstanceRemover, 4>();
2688 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2689 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2690 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2698 mlir::DialectRegistry ®istry) {
2699 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Helper class to cache tie-off values for different FIRRTL types.
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(SmallVectorImpl< Operation * > &worklist, Reduction &reduction)
Starting from an initial worklist of operations, traverse through it and its operands and erase opera...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Reduction that removes the convention attribute from external modules.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(FExtModuleOp extmodule) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(FExtModuleOp extmodule) override
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Reduction that removes the convention attribute from regular modules.
uint64_t match(FModuleOp module) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(FModuleOp module) override
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
The namespace of a CircuitOp, generally inhabited by modules.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describe the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)