25#include "mlir/Analysis/TopologicalSortUtils.h"
26#include "mlir/IR/Dominance.h"
27#include "mlir/IR/ImplicitLocOpBuilder.h"
28#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APSInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/Support/Debug.h"
34#define DEBUG_TYPE "firrtl-reductions"
38using namespace firrtl;
40using llvm::SmallDenseSet;
54 return tables->getSymbolTable(op);
64 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
71 tables = std::make_unique<SymbolTableCollection>();
76 std::unique_ptr<SymbolTableCollection>
tables;
83static std::optional<firrtl::FModuleOp>
86 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
87 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
89 return moduleOp ? std::optional(moduleOp) : std::nullopt;
100 module->walk([&](Operation *op) {
102 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
116 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
117 auto *op = use.getOwner();
118 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
120 if (use.getOperandNumber() != 0)
122 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
139 unsigned numRemoved = 0;
141 SymbolTableCollection symbolTables;
142 for (Operation &rootOp : *
module.getBody()) {
143 if (!isa<firrtl::CircuitOp>(&rootOp))
145 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
146 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
148 if (
auto *op = symbolTable.lookup(sym)) {
149 if (symbolUserMap.useEmpty(op)) {
158 if (numRemoved > 0 || numLost > 0) {
159 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
161 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
162 llvm::dbgs() <<
"\n";
171 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
172 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
173 nlasToRemove.insert(field.getAttr());
174 for (
auto namedAttr : dict)
175 markNLAsInAnnotation(namedAttr.getValue());
176 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
177 for (
auto attr : array)
178 markNLAsInAnnotation(attr);
186 op->walk([&](Operation *op) {
187 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
188 markNLAsInAnnotation(annos);
203struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
210 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
212 uint64_t
match(FModuleOp module)
override {
213 if (innerSymUses.hasInnerRef(module))
215 return moduleSizes.getModuleSize(module, symbols);
218 LogicalResult
rewrite(FModuleOp module)
override {
221 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
222 for (
auto attr :
module.getPortTypes()) {
223 auto type = cast<TypeAttr>(attr).getValue();
224 if (
auto refType = type_dyn_cast<RefType>(type))
225 if (
auto layer = refType.getLayer())
226 layers.insert(layer);
228 SmallVector<Attribute, 4> layersArray;
229 layersArray.reserve(layers.size());
230 for (
auto layer : layers)
231 layersArray.push_back(layer);
233 nlaRemover.markNLAsInOperation(module);
234 OpBuilder builder(module);
235 auto extmodule = FExtModuleOp::create(
236 builder, module->getLoc(),
237 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
238 module.getConventionAttr(), module.getPorts(),
239 builder.getArrayAttr(layersArray), StringRef(),
240 module.getAnnotationsAttr());
241 SymbolTable::setSymbolVisibility(extmodule,
242 SymbolTable::getSymbolVisibility(module));
247 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
266static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
268 auto type = type_dyn_cast<FIRRTLType>(value.getType());
273 if (
auto refType = type_dyn_cast<RefType>(type)) {
275 assert(!
flip &&
"input probes are not allowed");
277 auto underlyingType = refType.getType();
279 if (!refType.getForceable()) {
282 auto targetWire = WireOp::create(builder, underlyingType);
283 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
284 builder.create<RefDefineOp>(value, refSend.getResult());
287 auto invalid = tieOffCache.
getInvalid(underlyingType);
288 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
294 WireOp::create(builder, underlyingType,
295 "", NameKindEnum::DroppableName,
296 ArrayRef<Attribute>{},
301 auto targetWire = forceableWire.getResult();
302 auto forceableRef = forceableWire.getDataRef();
304 builder.create<RefDefineOp>(value, forceableRef);
307 auto invalid = tieOffCache.
getInvalid(underlyingType);
308 MatchingConnectOp::create(builder, targetWire, invalid);
313 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
314 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
315 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
316 invalidateOutputs(builder, subfield, tieOffCache,
317 flip ^ element.value().isFlip);
318 if (subfield.use_empty())
319 subfield.getDefiningOp()->erase();
325 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
326 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
327 auto subindex = builder.createOrFold<SubindexOp>(value, i);
328 invalidateOutputs(builder, subindex, tieOffCache,
flip);
329 if (subindex.use_empty())
330 subindex.getDefiningOp()->erase();
340 if (
auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
341 auto invalid = tieOffCache.
getInvalid(baseType);
342 ConnectOp::create(builder, value, invalid);
347 if (
auto propType = type_dyn_cast<PropertyType>(type)) {
348 auto unknown = tieOffCache.
getUnknown(propType);
349 builder.create<PropAssignOp>(value, unknown);
354static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
356 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
359 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
360 for (
auto element :
llvm::enumerate(bundleType.getElements()))
361 connectToLeafs(builder,
362 firrtl::SubfieldOp::create(builder, dest, element.index()),
366 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
367 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
368 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
372 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
375 auto destWidth = type.getBitWidthOrSentinel();
376 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
377 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
378 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
379 if (!isa<firrtl::UIntType>(type)) {
380 if (isa<firrtl::SIntType>(type))
381 value = firrtl::AsSIntPrimOp::create(builder, value);
385 firrtl::ConnectOp::create(builder, dest, value);
389static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
390 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
393 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
394 for (
auto element :
llvm::enumerate(bundleType.getElements()))
397 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
400 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
401 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
402 reduceXor(builder, into,
403 builder.createOrFold<firrtl::SubindexOp>(value, i));
406 if (!isa<firrtl::UIntType>(type)) {
407 if (isa<firrtl::SIntType>(type))
408 value = firrtl::AsUIntPrimOp::create(builder, value);
412 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
418struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
421 erasedModules.clear();
429 SmallVector<Operation *> worklist;
430 auto deadInsts = erasedInsts;
431 for (
auto *op : erasedModules)
432 worklist.push_back(op);
433 while (!worklist.empty()) {
434 auto *op = worklist.pop_back_val();
435 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
436 op->walk([&](firrtl::InstanceOp instOp) {
437 auto moduleOp = cast<firrtl::FModuleLike>(
438 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
439 deadInsts.insert(instOp);
441 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
442 [&](Operation *user) { return deadInsts.contains(user); })) {
443 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
444 << moduleOp.getModuleName() <<
"`\n");
445 erasedModules.insert(moduleOp);
446 worklist.push_back(moduleOp);
451 for (
auto *op : erasedInsts)
453 for (
auto *op : erasedModules)
455 nlaRemover.remove(op);
458 uint64_t
match(firrtl::InstanceOp instOp)
override {
460 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
464 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
465 LLVM_DEBUG(llvm::dbgs()
466 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
467 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
469 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
470 auto result = instOp.getResult(i);
471 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
472 instOp.getPortName(i));
474 firrtl::WireOp::create(builder, result.getType(), name,
475 firrtl::NameKindEnum::DroppableName,
476 instOp.getPortAnnotation(i), StringAttr{})
478 invalidateOutputs(builder, wire, tieOffCache,
479 instOp.getPortDirection(i) == firrtl::Direction::In);
480 result.replaceAllUsesWith(wire);
482 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
483 auto moduleOp = cast<firrtl::FModuleLike>(
484 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
485 nlaRemover.markNLAsInOperation(instOp);
486 erasedInsts.insert(instOp);
488 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
489 [&](Operation *user) { return erasedInsts.contains(user); })) {
490 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
491 << moduleOp.getModuleName() <<
"`\n");
492 erasedModules.insert(moduleOp);
497 std::string
getName()
const override {
return "instance-stubber"; }
502 llvm::DenseSet<Operation *> erasedInsts;
503 llvm::DenseSet<Operation *> erasedModules;
509struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
510 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
511 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
512 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
513 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
514 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
517 SmallVector<Value> outputs;
518 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
519 auto result = memOp.getResult(i);
520 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
521 memOp.getPortName(i));
523 firrtl::WireOp::create(builder, result.getType(), name,
524 firrtl::NameKindEnum::DroppableName,
525 memOp.getPortAnnotation(i), StringAttr{})
527 invalidateOutputs(builder, wire, tieOffCache,
true);
528 result.replaceAllUsesWith(wire);
532 switch (memOp.getPortKind(i)) {
533 case firrtl::MemOp::PortKind::Read:
534 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
536 case firrtl::MemOp::PortKind::Write:
537 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
539 case firrtl::MemOp::PortKind::ReadWrite:
540 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
541 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
543 case firrtl::MemOp::PortKind::Debug:
548 if (!isa<firrtl::RefType>(result.getType())) {
551 cast<firrtl::BundleType>(wire.getType()).getNumElements();
552 for (
unsigned i = 0; i != numFields; ++i) {
553 if (i != 2 && i != 3 && i != 5)
554 reduceXor(builder, xorInputs,
555 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
558 reduceXor(builder, xorInputs, input);
563 outputs.push_back(output);
567 for (
auto output : outputs)
568 connectToLeafs(builder, output, xorInputs);
570 nlaRemover.markNLAsInOperation(memOp);
574 std::string
getName()
const override {
return "memory-stubber"; }
581static bool isFlowSensitiveOp(Operation *op) {
582 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
583 SubaccessOp, ObjectSubfieldOp>(op);
589template <
unsigned OpNum>
590struct FIRRTLOperandForwarder :
public Reduction {
591 uint64_t
match(Operation *op)
override {
592 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
594 if (isFlowSensitiveOp(op))
597 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
599 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
600 return resultTy && opTy &&
601 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
602 (resultTy.getBitWidthOrSentinel() == -1) ==
603 (opTy.getBitWidthOrSentinel() == -1) &&
604 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
606 LogicalResult
rewrite(Operation *op)
override {
608 ImplicitLocOpBuilder builder(op->getLoc(), op);
609 auto result = op->getResult(0);
610 auto operand = op->getOperand(OpNum);
611 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
612 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
613 auto resultWidth = resultTy.getBitWidthOrSentinel();
614 auto operandWidth = operandTy.getBitWidthOrSentinel();
616 if (resultWidth < operandWidth)
618 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
619 else if (resultWidth > operandWidth)
620 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
623 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
624 result.replaceAllUsesWith(newOp);
628 std::string
getName()
const override {
629 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
640 anyrefCastDummy.clear();
641 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
642 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
643 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
644 anyrefCastDummy.insert({circuitOp, classOp});
645 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
648 return WalkResult::skip();
652 uint64_t
match(Operation *op)
override {
653 if (op->hasTrait<OpTrait::ConstantLike>()) {
655 if (!matchPattern(op, m_Constant(&attr)))
657 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
658 if (intAttr.getValue().isZero())
660 if (
auto strAttr = dyn_cast<StringAttr>(attr))
663 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
664 if (floatAttr.getValue().isZero())
667 if (
auto listOp = dyn_cast<ListCreateOp>(op))
668 if (listOp.getElements().empty())
670 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
671 if (pathOp.getTarget().empty())
675 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
676 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
678 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
679 if (anyrefCastDummyNames[circuitOp].contains(className))
683 if (op->getNumResults() != 1)
685 if (op->hasAttr(
"inner_sym"))
687 if (isFlowSensitiveOp(op))
689 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
690 DoubleType, ListType, PathType, AnyRefType>(
691 op->getResult(0).getType());
694 LogicalResult
rewrite(Operation *op)
override {
695 OpBuilder builder(op);
696 auto type = op->getResult(0).getType();
699 if (isa<UIntType, SIntType>(type)) {
700 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
703 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
704 APSInt(width, isa<UIntType>(type)));
705 op->replaceAllUsesWith(newOp);
711 if (isa<StringType>(type)) {
712 auto attr = builder.getStringAttr(
"");
713 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
714 op->replaceAllUsesWith(newOp);
720 if (isa<FIntegerType>(type)) {
721 auto attr = builder.getIntegerAttr(builder.getIntegerType(64,
true), 0);
722 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
723 op->replaceAllUsesWith(newOp);
729 if (isa<BoolType>(type)) {
730 auto attr = builder.getBoolAttr(
false);
731 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
732 op->replaceAllUsesWith(newOp);
738 if (isa<DoubleType>(type)) {
739 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
740 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
741 op->replaceAllUsesWith(newOp);
747 if (isa<ListType>(type)) {
749 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
750 op->replaceAllUsesWith(newOp);
756 if (isa<PathType>(type)) {
757 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
758 op->replaceAllUsesWith(newOp);
764 if (isa<AnyRefType>(type)) {
765 auto circuitOp = op->getParentOfType<CircuitOp>();
766 auto &dummy = anyrefCastDummy[circuitOp];
768 OpBuilder::InsertionGuard guard(builder);
769 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
770 auto &symbolTable = symbols.getNearestSymbolTable(op);
771 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
772 symbolTable.insert(dummy);
773 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
775 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
777 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
778 op->replaceAllUsesWith(anyrefOp);
786 std::string
getName()
const override {
return "firrtl-constantifier"; }
798struct ConnectInvalidator :
public Reduction {
799 uint64_t
match(Operation *op)
override {
800 if (!isa<FConnectLike>(op))
802 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
803 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
804 isa<InvalidValueOp>(srcOp))
806 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
807 return type && type.isPassive();
809 LogicalResult
rewrite(Operation *op)
override {
811 auto rhs = op->getOperand(1);
812 OpBuilder builder(op);
813 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
814 auto *rhsOp = rhs.getDefiningOp();
815 op->setOperand(1, invOp);
820 std::string
getName()
const override {
return "connect-invalidator"; }
827struct AnnotationRemover :
public Reduction {
828 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
829 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
832 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
833 uint64_t matchId = 0;
836 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
837 for (
unsigned i = 0; i < annos.size(); ++i)
838 addMatch(1, matchId++);
841 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
842 for (
auto portAnnoArray : portAnnos)
843 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
844 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
845 addMatch(1, matchId++);
849 ArrayRef<uint64_t> matches)
override {
851 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
854 uint64_t matchId = 0;
855 auto processAnnotations =
856 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
857 SmallVector<Attribute> newAnnotations;
858 for (
auto anno : annotations) {
859 if (!matchesSet.contains(matchId)) {
860 newAnnotations.push_back(anno);
863 nlaRemover.markNLAsInAnnotation(anno);
867 return ArrayAttr::get(op->getContext(), newAnnotations);
871 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
872 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
876 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
877 SmallVector<Attribute> newPortAnnos;
878 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
879 newPortAnnos.push_back(
880 processAnnotations(portAnnoArrayAttr.getValue()));
882 op->setAttr(
"portAnnotations",
883 ArrayAttr::get(op->getContext(), newPortAnnos));
889 std::string
getName()
const override {
return "annotation-remover"; }
896struct SimplifyResets :
public OpReduction<CircuitOp> {
897 uint64_t
match(CircuitOp circuit)
override {
898 uint64_t numResets = 0;
899 AttrTypeWalker walker;
900 walker.addWalk([&](ResetType type) { ++numResets; });
902 circuit.walk([&](Operation *op) {
903 for (
auto result : op->getResults())
904 walker.walk(result.getType());
906 for (
auto ®ion : op->getRegions())
907 for (auto &block : region)
908 for (auto arg : block.getArguments())
909 walker.walk(arg.getType());
911 walker.walk(op->getAttrDictionary());
917 LogicalResult
rewrite(CircuitOp circuit)
override {
918 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
919 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
921 AttrTypeReplacer replacer;
922 replacer.addReplacement([&](ResetType type) {
923 return type.isConst() ? constUint1Type : uint1Type;
925 replacer.recursivelyReplaceElementsIn(circuit,
true,
930 circuit.walk([&](Operation *op) {
933 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
934 fullAsyncResetAnnoClass,
935 ignoreFullAsyncResetAnnoClass);
939 if (
auto module = dyn_cast<FModuleLike>(op)) {
942 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
943 fullAsyncResetAnnoClass,
944 ignoreFullAsyncResetAnnoClass);
952 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
958struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
959 void matches(firrtl::FModuleOp module,
960 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
961 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
962 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
966 size_t numPorts =
module.getNumPorts();
967 for (
unsigned i = 0; i != numPorts; ++i) {
974 ArrayRef<uint64_t> matches)
override {
976 llvm::BitVector dropPorts(module.getNumPorts());
977 for (
auto portIdx : matches)
978 dropPorts.set(portIdx);
981 for (
auto portIdx : matches) {
983 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
988 module.erasePorts(dropPorts);
992 std::string
getName()
const override {
return "root-port-pruner"; }
998struct RootExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
999 void matches(firrtl::FExtModuleOp module,
1000 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1001 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
1002 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1007 size_t numPorts =
module.getNumPorts();
1008 for (
unsigned i = 0; i != numPorts; ++i)
1013 ArrayRef<uint64_t> matches)
override {
1014 if (matches.empty())
1018 llvm::BitVector dropPorts(module.getNumPorts());
1019 for (
auto portIdx : matches)
1020 dropPorts.set(portIdx);
1023 module.erasePorts(dropPorts);
1027 std::string
getName()
const override {
return "root-extmodule-port-pruner"; }
1032struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
1037 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1039 uint64_t
match(firrtl::InstanceOp instOp)
override {
1040 return isa<firrtl::FExtModuleOp>(
1041 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1043 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
1045 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1046 symbols.getNearestSymbolTable(instOp)))
1048 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1050 SmallVector<Value> replacementWires;
1052 auto wire = firrtl::WireOp::create(
1054 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
1056 if (
info.isOutput()) {
1058 if (
auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(
info.type)) {
1060 firrtl::ConnectOp::create(builder, wire, inv);
1061 }
else if (
auto propType = dyn_cast<firrtl::PropertyType>(
info.type)) {
1062 auto unknown = tieOffCache.
getUnknown(propType);
1063 builder.create<firrtl::PropAssignOp>(wire, unknown);
1066 replacementWires.push_back(wire);
1068 nlaRemover.markNLAsInOperation(instOp);
1069 instOp.replaceAllUsesWith(std::move(replacementWires));
1073 std::string
getName()
const override {
return "extmodule-instance-remover"; }
1085struct PortPrunerHelpers {
1087 template <
typename ModuleOpType>
1088 static void computeUnusedInstancePorts(ModuleOpType module,
1089 ArrayRef<Operation *> users,
1090 llvm::BitVector &portsToRemove) {
1091 auto ports =
module.getPorts();
1092 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1093 bool portUsed =
false;
1094 for (
auto *user : users) {
1095 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1096 auto result = instOp.getResult(portIdx);
1097 if (!result.use_empty()) {
1104 portsToRemove.set(portIdx);
1110 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1111 const llvm::BitVector &portsToRemove) {
1113 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1114 for (
auto *user : users) {
1115 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1116 instancesToUpdate.push_back(instOp);
1119 for (
auto instOp : instancesToUpdate) {
1120 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1123 size_t newResultIdx = 0;
1124 for (
size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1126 if (portsToRemove[oldResultIdx]) {
1128 assert(instOp.getResult(oldResultIdx).use_empty() &&
1129 "removing port with uses");
1132 instOp.getResult(oldResultIdx)
1133 .replaceAllUsesWith(newInst->getResult(newResultIdx));
1144struct ModulePortPruner :
public OpReduction<firrtl::FModuleOp> {
1149 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1151 void matches(firrtl::FModuleOp module,
1152 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1153 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1154 auto &userMap = symbols.getSymbolUserMap(tableOp);
1155 auto ports =
module.getPorts();
1156 auto users = userMap.getUsers(module);
1160 llvm::BitVector portsToRemove(ports.size());
1164 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1168 portsToRemove.set();
1173 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1174 if (!portsToRemove[portIdx])
1176 if (!module.getArgument(portIdx).use_empty())
1177 portsToRemove.reset(portIdx);
1181 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1182 if (portsToRemove[portIdx])
1183 addMatch(1, portIdx);
1187 ArrayRef<uint64_t> matches)
override {
1188 if (matches.empty())
1192 llvm::BitVector portsToRemove(module.getNumPorts());
1193 for (
auto portIdx : matches)
1194 portsToRemove.set(portIdx);
1197 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1198 auto &userMap = symbols.getSymbolUserMap(tableOp);
1199 auto users = userMap.getUsers(module);
1202 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1207 module.erasePorts(portsToRemove);
1212 std::string
getName()
const override {
return "module-port-pruner"; }
1219struct ExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
1224 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1226 void matches(firrtl::FExtModuleOp module,
1227 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1228 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1229 auto &userMap = symbols.getSymbolUserMap(tableOp);
1230 auto ports =
module.getPorts();
1231 auto users = userMap.getUsers(module);
1234 llvm::BitVector portsToRemove(ports.size());
1236 if (users.empty()) {
1238 portsToRemove.set();
1242 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1247 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1248 if (portsToRemove[portIdx])
1249 addMatch(1, portIdx);
1253 ArrayRef<uint64_t> matches)
override {
1254 if (matches.empty())
1258 llvm::BitVector portsToRemove(module.getNumPorts());
1259 for (
auto portIdx : matches)
1260 portsToRemove.set(portIdx);
1263 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1264 auto &userMap = symbols.getSymbolUserMap(tableOp);
1265 auto users = userMap.getUsers(module);
1268 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1272 module.erasePorts(portsToRemove);
1277 std::string
getName()
const override {
return "extmodule-port-pruner"; }
1284struct ConnectForwarder :
public Reduction {
1286 domInfo = std::make_unique<DominanceInfo>(op);
1289 uint64_t
match(Operation *op)
override {
1290 if (!isa<firrtl::FConnectLike>(op))
1292 auto dest = op->getOperand(0);
1293 auto src = op->getOperand(1);
1294 auto *destOp = dest.getDefiningOp();
1295 auto *srcOp = src.getDefiningOp();
1301 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1307 unsigned numConnects = 0;
1308 for (
auto &use : dest.getUses()) {
1309 auto *op = use.getOwner();
1310 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1311 if (++numConnects > 1)
1318 !domInfo->properlyDominates(srcOp, op,
false))
1325 LogicalResult
rewrite(Operation *op)
override {
1326 auto dst = op->getOperand(0);
1327 auto src = op->getOperand(1);
1328 dst.replaceAllUsesExcept(src, op);
1330 SmallVector<Operation *> worklist(
1331 {dst.getDefiningOp(), src.getDefiningOp()});
1336 std::string
getName()
const override {
return "connect-forwarder"; }
1339 std::unique_ptr<DominanceInfo> domInfo;
1344template <
unsigned OpNum>
1345struct ConnectSourceOperandForwarder :
public Reduction {
1346 uint64_t
match(Operation *op)
override {
1347 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1349 auto dest = op->getOperand(0);
1350 auto *destOp = dest.getDefiningOp();
1353 if (!destOp || !destOp->hasOneUse() ||
1354 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1357 auto *srcOp = op->getOperand(1).getDefiningOp();
1358 if (!srcOp || OpNum >= srcOp->getNumOperands())
1361 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1363 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1365 return resultTy && opTy &&
1366 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1367 ((resultTy.getBitWidthOrSentinel() == -1) ==
1368 (opTy.getBitWidthOrSentinel() == -1)) &&
1369 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1372 LogicalResult
rewrite(Operation *op)
override {
1373 auto *destOp = op->getOperand(0).getDefiningOp();
1374 auto *srcOp = op->getOperand(1).getDefiningOp();
1375 auto forwardedOperand = srcOp->getOperand(OpNum);
1376 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1378 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1379 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1383 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1386 auto clock = destOp->getOperand(0);
1387 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1388 clock, regName ? regName.str() :
"")
1393 builder.setInsertionPointAfter(op);
1394 if (isa<firrtl::ConnectOp>(op))
1395 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1397 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1407 std::string
getName()
const override {
1408 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1415struct DetachSubaccesses :
public Reduction {
1416 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1418 for (
auto *op : opsToErase)
1419 op->dropAllReferences();
1420 for (
auto *op : opsToErase)
1423 uint64_t
match(Operation *op)
override {
1426 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1427 llvm::all_of(op->getUses(), [](
auto &use) {
1428 return use.getOperandNumber() == 0 &&
1429 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1430 firrtl::SubaccessOp>(use.getOwner());
1433 LogicalResult
rewrite(Operation *op)
override {
1435 OpBuilder builder(op);
1436 bool isWire = isa<firrtl::WireOp>(op);
1439 invalidClock = firrtl::InvalidValueOp::create(
1440 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1441 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1442 builder.setInsertionPoint(user);
1443 auto type = user->getResult(0).getType();
1446 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1449 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1450 user->replaceAllUsesWith(replOp);
1451 opsToErase.insert(user);
1453 opsToErase.insert(op);
1456 std::string
getName()
const override {
return "detach-subaccesses"; }
1457 llvm::DenseSet<Operation *> opsToErase;
1463struct NodeSymbolRemover :
public Reduction {
1468 uint64_t
match(Operation *op)
override {
1470 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1471 if (!sym || sym.empty())
1475 if (innerSymUses.hasInnerRef(op))
1480 LogicalResult
rewrite(Operation *op)
override {
1481 op->removeAttr(
"inner_sym");
1485 std::string
getName()
const override {
return "node-symbol-remover"; }
1494hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1503 LogicalResult walkResult = targetTable.
walkSymbols(
1506 if (parentTable.lookup(name)) {
1514 return failed(walkResult);
1518struct EagerInliner :
public OpReduction<InstanceOp> {
1523 for (
auto circuitOp : op.getOps<CircuitOp>())
1524 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1525 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1528 nlaRemover.remove(op);
1530 innerSymTables.reset();
1533 uint64_t
match(InstanceOp instOp)
override {
1534 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1536 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1539 if (!isa<FModuleOp>(moduleOp))
1543 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1546 auto it = nlaTables.find(circuitOp);
1547 if (it == nlaTables.end() || !it->second)
1549 DenseSet<hw::HierPathOp> nlas;
1550 it->second->getInstanceNLAs(instOp, nlas);
1556 auto parentOp = instOp->getParentOfType<FModuleLike>();
1557 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1563 LogicalResult
rewrite(InstanceOp instOp)
override {
1564 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1565 auto moduleOp = cast<FModuleOp>(
1566 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1568 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1569 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1572 IRRewriter rewriter(instOp);
1573 SmallVector<Value> argWires;
1574 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1575 auto result = instOp.getResult(i);
1576 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1577 instOp.getPortName(i));
1578 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1579 name, NameKindEnum::DroppableName,
1580 instOp.getPortAnnotation(i), StringAttr{})
1582 result.replaceAllUsesWith(wire);
1583 argWires.push_back(wire);
1587 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1591 nlaRemover.markNLAsInOperation(instOp);
1593 nlaRemover.markNLAsInOperation(moduleOp);
1596 clonedModuleOp.erase();
1600 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1605 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1606 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1610struct ObjectInliner :
public OpReduction<ObjectOp> {
1612 blocksToSort.clear();
1615 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1618 for (
auto *block : blocksToSort)
1619 mlir::sortTopologically(block);
1620 blocksToSort.clear();
1621 nlaRemover.remove(op);
1622 innerSymTables.reset();
1625 uint64_t
match(ObjectOp objOp)
override {
1626 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1628 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1631 if (!isa<ClassOp>(classOp))
1636 auto parentOp = objOp->getParentOfType<FModuleLike>();
1637 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1641 for (
auto *user : objOp.getResult().getUsers())
1642 if (!isa<ObjectSubfieldOp>(user))
1648 LogicalResult
rewrite(ObjectOp objOp)
override {
1649 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1650 auto classOp = cast<ClassOp>(
1651 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1652 auto clonedClassOp = classOp.clone();
1655 IRRewriter rewriter(objOp);
1656 SmallVector<Value> portWires;
1657 auto classType = objOp.getType();
1660 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1661 auto element = classType.getElement(i);
1662 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1663 element.name.getValue());
1664 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1665 NameKindEnum::DroppableName,
1666 rewriter.getArrayAttr({}), StringAttr{})
1668 portWires.push_back(wire);
1672 SmallVector<ObjectSubfieldOp> subfieldOps;
1673 for (
auto *user : objOp.getResult().getUsers()) {
1674 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1675 subfieldOps.push_back(subfieldOp);
1676 auto index = subfieldOp.getIndex();
1677 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1681 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1687 SmallVector<FConnectLike> connectsToErase;
1688 for (
auto portWire : portWires) {
1692 for (
auto *user : portWire.getUsers()) {
1693 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1694 if (
connect.getDest() == portWire) {
1696 connectsToErase.push_back(connect);
1706 portWire.replaceAllUsesWith(value);
1707 for (
auto connect : connectsToErase)
1709 if (portWire.use_empty())
1710 portWire.getDefiningOp()->erase();
1711 connectsToErase.clear();
1715 nlaRemover.markNLAsInOperation(objOp);
1720 blocksToSort.insert(objOp->getBlock());
1723 for (
auto subfieldOp : subfieldOps)
1726 clonedClassOp.erase();
1730 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1733 SetVector<Block *> blocksToSort;
1736 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1741struct ResetDisconnector :
public OpReduction<RegResetOp> {
1742 uint64_t
match(RegResetOp op)
override {
return 1; }
1744 LogicalResult
rewrite(RegResetOp regResetOp)
override {
1745 ImplicitLocOpBuilder builder(regResetOp.getLoc(), regResetOp);
1746 auto regOp = RegOp::create(
1747 builder, regResetOp.getResult().getType(), regResetOp.getClockVal(),
1748 regResetOp.getNameAttr(), regResetOp.getNameKindAttr(),
1749 regResetOp.getAnnotationsAttr(), regResetOp.getInnerSymAttr(),
1750 regResetOp.getForceableAttr());
1752 regResetOp.getResult().replaceAllUsesWith(regOp.getResult());
1753 if (regResetOp.getForceable())
1754 regResetOp.getRef().replaceAllUsesWith(regOp.getRef());
1760 std::string
getName()
const override {
return "reset-disconnector"; }
1775 uint64_t
match(Operation *op)
override {
1777 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1778 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1779 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1780 firrtl::CoverOp>(op);
1782 LogicalResult
rewrite(Operation *op)
override {
1783 TypeSwitch<Operation *, void>(op)
1784 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1785 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1786 [](
auto op) { op.setName(
"reg"); })
1787 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1788 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1789 [](
auto op) { op.setName(
"mem"); })
1790 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1791 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1792 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1797 std::string
getName()
const override {
1798 return "module-internal-name-sanitizer";
1803 bool isOneShot()
const override {
return true; }
1823 if (portNameIndex >= 26)
1825 return 'a' + portNameIndex++;
1830 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1835 SymbolTable symTable(circuitOp);
1839 auto renameModule = [&](firrtl::FModuleLike mod,
1840 StringAttr newName) -> LogicalResult {
1841 StringAttr oldName = mod.getModuleNameAttr();
1842 if (failed(symTable.rename(mod, newName)))
1844 nlaTable.renameModule(oldName, newName);
1850 auto topModule = iGraph.getTopLevelModule();
1851 auto *ctx = circuitOp.getContext();
1853 topModule.getModuleName())) {
1854 auto newTopName = StringAttr::get(ctx, nameGenerator.
getNextName(ns));
1855 if (failed(renameModule(topModule, newTopName)))
1857 circuitOp.setName(newTopName.getValue());
1860 for (
auto *node : iGraph) {
1861 auto module = node->getModule<firrtl::FModuleLike>();
1863 bool shouldReplacePorts =
false;
1864 SmallVector<Attribute> newPortNames;
1865 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1870 auto oldPorts = fmodule.getPorts();
1871 shouldReplacePorts = !oldPorts.empty();
1872 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1873 auto port = oldPorts[i];
1875 .
Case<firrtl::ClockType>(
1876 [&](
auto a) {
return ns.
newName(
"clk"); })
1877 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1878 [&](
auto a) {
return ns.
newName(
"rst"); })
1879 .Case<firrtl::RefType>(
1880 [&](
auto a) {
return ns.
newName(
"ref"); })
1881 .Default([&](
auto a) {
1884 newPortNames.push_back(StringAttr::get(ctx, newName));
1886 fmodule->setAttr(
"portNames",
1887 ArrayAttr::get(fmodule.getContext(), newPortNames));
1890 if (module == iGraph.getTopLevelModule())
1894 module.getModuleName()))
1896 auto newName = StringAttr::get(ctx, nameGenerator.
getNextName(ns));
1897 if (failed(renameModule(module, newName)))
1899 for (
auto *use : node->uses()) {
1900 auto useOp = use->getInstance();
1901 if (
auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1905 instanceOp.setName(newName);
1906 if (shouldReplacePorts)
1907 instanceOp.setPortNamesAttr(ArrayAttr::get(ctx, newPortNames));
1908 }
else if (
auto instanceChoiceOp =
1909 dyn_cast<firrtl::InstanceChoiceOp>(*useOp)) {
1910 if (instanceChoiceOp.getDefaultTargetAttr().getAttr() == newName)
1911 instanceChoiceOp.setName(newName);
1912 if (shouldReplacePorts)
1913 instanceChoiceOp.setPortNamesAttr(
1914 ArrayAttr::get(ctx, newPortNames));
1915 }
else if (
auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1919 auto oldClassType = objectOp.getType();
1920 auto newClassType = firrtl::ClassType::get(
1921 ctx, FlatSymbolRefAttr::get(newName), oldClassType.getElements());
1922 objectOp.getResult().setType(newClassType);
1923 objectOp.setName(newName);
1931 std::string
getName()
const override {
return "module-name-sanitizer"; }
1935 bool isOneShot()
const override {
return true; }
1954struct ModuleSwapper :
public OpReduction<InstanceOp> {
1956 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1957 struct CircuitState {
1958 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1959 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1960 std::unique_ptr<NLATable> nlaTable;
1966 moduleSizes.clear();
1967 circuitStates.clear();
1970 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1971 auto &state = circuitStates[circuitOp];
1972 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1973 buildModuleTypeGroups(circuitOp, state);
1974 return WalkResult::skip();
1977 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1983 PortSignature getModulePortSignature(FModuleLike module) {
1984 PortSignature signature;
1985 signature.reserve(module.getNumPorts());
1986 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1987 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1992 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1994 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1995 auto signature = getModulePortSignature(module);
1996 state.moduleTypeGroups[signature].push_back(module);
2000 for (
auto &[signature, modules] : state.moduleTypeGroups) {
2001 if (modules.size() <= 1)
2004 FModuleLike smallestModule =
nullptr;
2005 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
2007 for (
auto module : modules) {
2008 uint64_t size = moduleSizes.getModuleSize(module, symbols);
2009 if (size < smallestSize) {
2010 smallestSize = size;
2011 smallestModule =
module;
2016 for (
auto module : modules) {
2017 if (module != smallestModule) {
2018 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
2025 uint64_t
match(InstanceOp instOp)
override {
2027 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2029 const auto &state = circuitStates.at(circuitOp);
2032 DenseSet<hw::HierPathOp> nlas;
2033 state.nlaTable->getInstanceNLAs(instOp, nlas);
2038 auto moduleName = instOp.getModuleNameAttr().getAttr();
2039 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
2040 if (!canonicalModule)
2044 auto currentModule = cast<FModuleLike>(
2045 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
2046 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
2047 uint64_t canonicalSize =
2048 moduleSizes.getModuleSize(canonicalModule, symbols);
2049 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
2052 LogicalResult
rewrite(InstanceOp instOp)
override {
2054 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2056 const auto &state = circuitStates.at(circuitOp);
2059 auto canonicalModule = state.instanceToCanonicalModule.at(
2060 instOp.getModuleNameAttr().getAttr());
2061 auto canonicalName = canonicalModule.getModuleNameAttr();
2062 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2065 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2070 std::string
getName()
const override {
return "firrtl-module-swapper"; }
2079 DenseMap<CircuitOp, CircuitState> circuitStates;
2097struct ForceDedup :
public OpReduction<CircuitOp> {
2101 modulesToErase.clear();
2102 moduleSizes.clear();
2105 nlaRemover.remove(op);
2106 for (
auto mod : modulesToErase)
2111 void matches(CircuitOp circuitOp,
2112 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2113 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2115 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2116 if (!anno.
isClass(mustDeduplicateAnnoClass))
2119 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2120 if (!modulesAttr || modulesAttr.size() < 2)
2126 uint64_t totalSize = 0;
2127 ArrayAttr portTypes;
2128 DenseBoolArrayAttr portDirections;
2129 bool allSame =
true;
2130 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2136 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2141 totalSize += moduleSizes.getModuleSize(mod, symbols);
2143 portTypes = mod.getPortTypesAttr();
2144 portDirections = mod.getPortDirectionsAttr();
2145 }
else if (portTypes != mod.getPortTypesAttr() ||
2146 portDirections != mod.getPortDirectionsAttr()) {
2156 addMatch(totalSize, annoIdx);
2161 ArrayRef<uint64_t> matches)
override {
2162 auto *
context = circuitOp->getContext();
2166 SmallVector<Annotation> newAnnotations;
2168 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2170 if (!llvm::is_contained(matches, annoIdx)) {
2171 newAnnotations.push_back(anno);
2174 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2175 assert(anno.
isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2176 modulesAttr.size() >= 2);
2179 SmallVector<StringAttr> moduleNames;
2180 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2182 auto refStr = moduleRef.getValue();
2183 auto pipePos = refStr.find(
'|');
2184 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2185 auto moduleName = refStr.substr(pipePos + 1);
2186 moduleNames.push_back(StringAttr::get(
context, moduleName));
2191 if (moduleNames.size() < 2)
2196 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2197 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
2199 if (newAnnotations.size() == annotations.size())
2204 newAnnoSet.applyToOperation(circuitOp);
2208 std::string
getName()
const override {
return "firrtl-force-dedup"; }
2214 void replaceModuleReferences(CircuitOp circuitOp,
2215 ArrayRef<StringAttr> moduleNames,
2218 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2219 auto &symbolTable = symbols.getSymbolTable(tableOp);
2220 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2221 auto *
context = circuitOp->getContext();
2225 FModuleLike canonicalModule;
2226 SmallVector<FModuleLike> modulesToReplace;
2227 for (
auto name : moduleNames) {
2228 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
2229 if (!canonicalModule)
2230 canonicalModule = mod;
2232 modulesToReplace.push_back(mod);
2235 if (modulesToReplace.empty())
2239 auto canonicalName = canonicalModule.getModuleNameAttr();
2240 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2241 for (
auto moduleName : moduleNames) {
2242 if (moduleName == canonicalName)
2244 auto *symbolOp = symbolTable.lookup(moduleName);
2247 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
2248 auto instOp = dyn_cast<InstanceOp>(user);
2249 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2251 instOp.setModuleNameAttr(canonicalRef);
2252 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2258 for (
auto oldMod : modulesToReplace) {
2259 SmallVector<hw::HierPathOp> nlaOps(
2260 nlaTable.
lookup(oldMod.getModuleNameAttr()));
2261 for (
auto nlaOp : nlaOps) {
2262 nlaTable.
erase(nlaOp);
2263 StringAttr oldModName = oldMod.getModuleNameAttr();
2264 StringAttr newModName = canonicalName;
2265 SmallVector<Attribute, 4> newPath;
2266 for (
auto nameRef : nlaOp.getNamepath()) {
2267 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2268 if (ref.getModule() == oldModName) {
2269 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2270 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2271 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2272 if (oldInst && newInst) {
2275 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2276 auto newModNames = newInst.getReferencedModuleNamesAttr();
2277 if (!oldModNames.empty() && !newModNames.empty()) {
2278 oldModName = cast<StringAttr>(oldModNames[0]);
2279 newModName = cast<StringAttr>(newModNames[0]);
2283 newPath.push_back(ref);
2284 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2285 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2287 newPath.push_back(nameRef);
2290 nlaOp.setNamepathAttr(ArrayAttr::get(
context, newPath));
2296 for (
auto module : modulesToReplace) {
2297 nlaRemover.markNLAsInOperation(module);
2298 modulesToErase.insert(module);
2304 SetVector<FModuleLike> modulesToErase;
2324struct MustDedupChildren :
public OpReduction<CircuitOp> {
2329 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
2333 void matches(CircuitOp circuitOp,
2334 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2336 uint64_t matchId = 0;
2338 DenseSet<StringRef> modulesAlreadyInMustDedup;
2339 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
2340 if (anno.isClass(mustDeduplicateAnnoClass))
2341 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
2342 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2344 modulesAlreadyInMustDedup.insert(target->module);
2346 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2347 if (!anno.
isClass(mustDeduplicateAnnoClass))
2350 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2351 if (!modulesAttr || modulesAttr.size() < 2)
2355 processInstanceGroups(
2356 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2360 SmallDenseSet<StringAttr, 4> moduleTargets;
2361 for (
auto instOp : instanceGroup) {
2362 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2363 for (
auto moduleName : moduleNames)
2364 moduleTargets.insert(cast<StringAttr>(moduleName));
2366 if (moduleTargets.size() < 2)
2371 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2372 auto moduleNames = inst.getReferencedModuleNames();
2373 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2374 return modulesAlreadyInMustDedup.contains(moduleName);
2379 addMatch(1, matchId - 1);
2385 ArrayRef<uint64_t> matches)
override {
2386 auto *
context = circuitOp->getContext();
2388 SmallVector<Annotation> newAnnotations;
2389 uint64_t matchId = 0;
2391 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2392 if (!anno.
isClass(mustDeduplicateAnnoClass)) {
2393 newAnnotations.push_back(anno);
2397 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2398 if (!modulesAttr || modulesAttr.size() < 2) {
2399 newAnnotations.push_back(anno);
2403 processInstanceGroups(
2404 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2406 if (!llvm::is_contained(matches, matchId++))
2411 for (
auto instOp : instanceGroup) {
2412 auto moduleNames = instOp.getReferencedModuleNames();
2413 for (
auto moduleName : moduleNames) {
2415 target.circuit = circuitOp.getName();
2416 target.module = moduleName;
2417 moduleTargets.insert(target.toStringAttr(
context));
2422 SmallVector<NamedAttribute> newAnnoAttrs;
2423 newAnnoAttrs.emplace_back(
2424 StringAttr::get(
context,
"class"),
2425 StringAttr::get(
context, mustDeduplicateAnnoClass));
2426 newAnnoAttrs.emplace_back(
2427 StringAttr::get(
context,
"modules"),
2429 SmallVector<Attribute>(moduleTargets.begin(),
2430 moduleTargets.end())));
2432 auto newAnnoDict = DictionaryAttr::get(
context, newAnnoAttrs);
2433 newAnnotations.emplace_back(newAnnoDict);
2437 newAnnotations.push_back(anno);
2442 newAnnoSet.applyToOperation(circuitOp);
2446 std::string
getName()
const override {
return "must-dedup-children"; }
2454 void processInstanceGroups(
2455 CircuitOp circuitOp, ArrayAttr modulesAttr,
2456 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2457 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2460 SmallVector<FModuleLike> modules;
2461 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2463 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2464 modules.push_back(mod);
2467 if (modules.size() < 2)
2474 struct InstanceGroup {
2475 SmallVector<FInstanceLike> instances;
2476 bool nameIsUnique =
true;
2479 for (
auto module : modules) {
2481 module.walk([&](FInstanceLike instOp) {
2482 if (isa<ObjectOp>(instOp.getOperation()))
2484 auto name = instOp.getInstanceNameAttr();
2485 auto &group = instanceGroups[name];
2486 if (nameCounts[name]++ > 1)
2487 group.nameIsUnique =
false;
2488 group.instances.push_back(instOp);
2494 for (
auto &[name, group] : instanceGroups)
2495 if (group.nameIsUnique && group.instances.size() >= 2)
2496 callback(group.instances);
2503struct LayerDisable :
public OpReduction<CircuitOp> {
2504 LayerDisable(MLIRContext *
context) {
2505 pm = std::make_unique<mlir::PassManager>(
2506 context,
"builtin.module", mlir::OpPassManager::Nesting::Explicit);
2507 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2510 void beforeReduction(mlir::ModuleOp op)
override { symbolRefAttrMap.clear(); }
2512 void afterReduction(mlir::ModuleOp op)
override { (void)pm->run(op); };
2514 void matches(CircuitOp circuitOp,
2515 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2516 uint64_t matchId = 0;
2518 SmallVector<FlatSymbolRefAttr> nestedRefs;
2519 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2522 rootRef = layerOp.getSymNameAttr();
2524 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2526 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2527 addMatch(1, matchId++);
2529 for (
auto nestedLayerOp : layerOp.getOps<LayerOp>())
2530 addLayer(rootRef, nestedLayerOp);
2532 if (!nestedRefs.empty())
2533 nestedRefs.pop_back();
2536 for (
auto layerOp : circuitOp.getOps<LayerOp>())
2537 addLayer({}, layerOp);
2541 ArrayRef<uint64_t> matches)
override {
2542 SmallVector<Attribute> disableLayers;
2543 if (
auto existingDisables = circuitOp.getDisableLayersAttr()) {
2544 auto disableRange = existingDisables.getAsRange<Attribute>();
2545 disableLayers.append(disableRange.begin(), disableRange.end());
2547 for (
auto match : matches)
2548 disableLayers.push_back(symbolRefAttrMap.at(match));
2550 circuitOp.setDisableLayersAttr(
2551 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2556 std::string
getName()
const override {
return "firrtl-layer-disable"; }
2558 std::unique_ptr<mlir::PassManager> pm;
2559 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2569 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2571 auto elements = listOp.getElements();
2572 for (
size_t i = 0; i < elements.size(); ++i)
2577 ArrayRef<uint64_t>
matches)
override {
2579 llvm::SmallDenseSet<uint64_t, 4> matchesSet(
matches.begin(),
matches.end());
2582 SmallVector<Value> newElements;
2583 auto elements = listOp.getElements();
2584 for (
size_t i = 0; i < elements.size(); ++i) {
2585 if (!matchesSet.contains(i))
2586 newElements.push_back(elements[i]);
2590 OpBuilder builder(listOp);
2591 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2592 listOp.getType(), newElements);
2593 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2600 return "firrtl-list-create-element-remover";
2606 uint64_t
match(FModuleOp module)
override {
2607 return module.getConvention() != Convention::Internal;
2610 LogicalResult
rewrite(FModuleOp module)
override {
2611 module.setConvention(Convention::Internal);
2615 std::string
getName()
const override {
return "module-convention-remover"; }
2622 uint64_t
match(FExtModuleOp extmodule)
override {
2623 return extmodule.getConvention() != Convention::Internal;
2626 LogicalResult
rewrite(FExtModuleOp extmodule)
override {
2627 extmodule.setConvention(Convention::Internal);
2632 return "extmodule-convention-remover";
2649 patterns.add<SimplifyResets, 35>();
2651 patterns.add<MustDedupChildren, 33>();
2652 patterns.add<AnnotationRemover, 32>();
2654 patterns.add<LayerDisable, 30>(getContext());
2660 firrtl::createLowerCHIRRTLPass(),
true,
true);
2665 patterns.add<FIRRTLModuleExternalizer, 25>();
2666 patterns.add<InstanceStubber, 24>();
2671 firrtl::createLowerFIRRTLTypes(),
true,
true);
2678 firrtl::createRemoveUnusedPorts({
true}));
2679 patterns.add<NodeSymbolRemover, 16>();
2681 patterns.add<ConnectForwarder, 14>();
2682 patterns.add<ConnectInvalidator, 13>();
2684 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2685 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2686 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2688 patterns.add<ResetDisconnector, 8>();
2689 patterns.add<DetachSubaccesses, 7>();
2690 patterns.add<ModulePortPruner, 7>();
2691 patterns.add<ExtmodulePortPruner, 6>();
2693 patterns.add<RootExtmodulePortPruner, 5>();
2694 patterns.add<ExtmoduleInstanceRemover, 4>();
2695 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2696 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2697 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2705 mlir::DialectRegistry ®istry) {
2706 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Helper class to cache tie-off values for different FIRRTL types.
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(SmallVectorImpl< Operation * > &worklist, Reduction &reduction)
Starting from an initial worklist of operations, traverse through it and its operands and erase opera...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Reduction that removes the convention attribute from external modules.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(FExtModuleOp extmodule) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(FExtModuleOp extmodule) override
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Reduction that removes the convention attribute from regular modules.
uint64_t match(FModuleOp module) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(FModuleOp module) override
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
The namespace of a CircuitOp, generally inhabited by modules.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)