24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "firrtl-reductions"
37using namespace firrtl;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
53 return tables->getSymbolTable(op);
63 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
70 tables = std::make_unique<SymbolTableCollection>();
75 std::unique_ptr<SymbolTableCollection>
tables;
82static std::optional<firrtl::FModuleOp>
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
99 module->walk([&](Operation *op) {
101 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
119 if (use.getOperandNumber() != 0)
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
138 unsigned numRemoved = 0;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *
module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
147 if (
auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
160 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
161 llvm::dbgs() <<
"\n";
170 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (
auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (
auto attr : array)
177 markNLAsInAnnotation(attr);
185 op->walk([&](Operation *op) {
186 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
187 markNLAsInAnnotation(annos);
202struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
209 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
211 uint64_t
match(FModuleOp module)
override {
212 if (innerSymUses.hasInnerRef(module))
214 return moduleSizes.getModuleSize(module, symbols);
217 LogicalResult
rewrite(FModuleOp module)
override {
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (
auto attr :
module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (
auto refType = type_dyn_cast<RefType>(type))
224 if (
auto layer = refType.getLayer())
225 layers.insert(layer);
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (
auto layer : layers)
230 layersArray.push_back(layer);
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
246 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
272 if (
auto refType = type_dyn_cast<RefType>(type)) {
274 assert(!
flip &&
"input probes are not allowed");
276 auto underlyingType = refType.getType();
278 if (!refType.getForceable()) {
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
286 auto invalid = tieOffCache.
getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
293 WireOp::create(builder, underlyingType,
294 "", NameKindEnum::DroppableName,
295 ArrayRef<Attribute>{},
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
303 builder.create<RefDefineOp>(value, forceableRef);
306 auto invalid = tieOffCache.
getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
312 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
324 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache,
flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
339 if (
auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.
getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
346 if (
auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.
getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
358 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (
auto element :
llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
365 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
384 firrtl::ConnectOp::create(builder, dest, value);
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
392 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (
auto element :
llvm::enumerate(bundleType.getElements()))
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
399 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
417struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
420 erasedModules.clear();
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (
auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
443 << moduleOp.getModuleName() <<
"`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
450 for (
auto *op : erasedInsts)
452 for (
auto *op : erasedModules)
454 nlaRemover.remove(op);
457 uint64_t
match(firrtl::InstanceOp instOp)
override {
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
463 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
464 LLVM_DEBUG(llvm::dbgs()
465 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
468 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
471 instOp.getPortName(i));
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
490 << moduleOp.getModuleName() <<
"`\n");
491 erasedModules.insert(moduleOp);
496 std::string
getName()
const override {
return "instance-stubber"; }
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
508struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
509 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
510 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
511 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
512 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
516 SmallVector<Value> outputs;
517 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
520 memOp.getPortName(i));
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
526 invalidateOutputs(builder, wire, tieOffCache,
true);
527 result.replaceAllUsesWith(wire);
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
542 case firrtl::MemOp::PortKind::Debug:
547 if (!isa<firrtl::RefType>(result.getType())) {
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (
unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
557 reduceXor(builder, xorInputs, input);
562 outputs.push_back(output);
566 for (
auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
569 nlaRemover.markNLAsInOperation(memOp);
573 std::string
getName()
const override {
return "memory-stubber"; }
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
588template <
unsigned OpNum>
589struct FIRRTLOperandForwarder :
public Reduction {
590 uint64_t
match(Operation *op)
override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
593 if (isFlowSensitiveOp(op))
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
605 LogicalResult
rewrite(Operation *op)
override {
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
615 if (resultWidth < operandWidth)
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
622 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
623 result.replaceAllUsesWith(newOp);
627 std::string
getName()
const override {
628 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
647 return WalkResult::skip();
651 uint64_t
match(Operation *op)
override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
654 if (!matchPattern(op, m_Constant(&attr)))
656 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
659 if (
auto strAttr = dyn_cast<StringAttr>(attr))
662 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
666 if (
auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
669 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
674 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
682 if (op->getNumResults() != 1)
684 if (op->hasAttr(
"inner_sym"))
686 if (isFlowSensitiveOp(op))
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
693 LogicalResult
rewrite(Operation *op)
override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr(
"");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64,
true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(
false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
746 if (isa<ListType>(type)) {
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
757 op->replaceAllUsesWith(newOp);
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
785 std::string
getName()
const override {
return "firrtl-constantifier"; }
797struct ConnectInvalidator :
public Reduction {
798 uint64_t
match(Operation *op)
override {
799 if (!isa<FConnectLike>(op))
801 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
808 LogicalResult
rewrite(Operation *op)
override {
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
819 std::string
getName()
const override {
return "connect-invalidator"; }
826struct AnnotationRemover :
public Reduction {
827 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
828 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
831 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
832 uint64_t matchId = 0;
835 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
836 for (
unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
840 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
841 for (
auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
848 ArrayRef<uint64_t> matches)
override {
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (
auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
862 nlaRemover.markNLAsInAnnotation(anno);
866 return ArrayAttr::get(op->getContext(), newAnnotations);
870 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
871 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
875 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
881 op->setAttr(
"portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
888 std::string
getName()
const override {
return "annotation-remover"; }
895struct SimplifyResets :
public OpReduction<CircuitOp> {
896 uint64_t
match(CircuitOp circuit)
override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
901 circuit.walk([&](Operation *op) {
902 for (
auto result : op->getResults())
903 walker.walk(result.getType());
905 for (
auto ®ion : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
910 walker.walk(op->getAttrDictionary());
916 LogicalResult
rewrite(CircuitOp circuit)
override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
924 replacer.recursivelyReplaceElementsIn(circuit,
true,
929 circuit.walk([&](Operation *op) {
932 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
938 if (
auto module = dyn_cast<FModuleLike>(op)) {
941 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
951 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
957struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
958 void matches(firrtl::FModuleOp module,
959 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
960 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
961 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
965 size_t numPorts =
module.getNumPorts();
966 for (
unsigned i = 0; i != numPorts; ++i) {
973 ArrayRef<uint64_t> matches)
override {
975 llvm::BitVector dropPorts(module.getNumPorts());
976 for (
auto portIdx : matches)
977 dropPorts.set(portIdx);
980 for (
auto portIdx : matches) {
982 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
987 module.erasePorts(dropPorts);
991 std::string
getName()
const override {
return "root-port-pruner"; }
997struct RootExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
998 void matches(firrtl::FExtModuleOp module,
999 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1000 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
1001 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1006 size_t numPorts =
module.getNumPorts();
1007 for (
unsigned i = 0; i != numPorts; ++i)
1012 ArrayRef<uint64_t> matches)
override {
1013 if (matches.empty())
1017 llvm::BitVector dropPorts(module.getNumPorts());
1018 for (
auto portIdx : matches)
1019 dropPorts.set(portIdx);
1022 module.erasePorts(dropPorts);
1026 std::string
getName()
const override {
return "root-extmodule-port-pruner"; }
1031struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
1036 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1038 uint64_t
match(firrtl::InstanceOp instOp)
override {
1039 return isa<firrtl::FExtModuleOp>(
1040 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1042 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
1044 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1045 symbols.getNearestSymbolTable(instOp)))
1047 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1049 SmallVector<Value> replacementWires;
1051 auto wire = firrtl::WireOp::create(
1053 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
1055 if (
info.isOutput()) {
1057 if (
auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(
info.type)) {
1059 firrtl::ConnectOp::create(builder, wire, inv);
1060 }
else if (
auto propType = dyn_cast<firrtl::PropertyType>(
info.type)) {
1061 auto unknown = tieOffCache.
getUnknown(propType);
1062 builder.create<firrtl::PropAssignOp>(wire, unknown);
1065 replacementWires.push_back(wire);
1067 nlaRemover.markNLAsInOperation(instOp);
1068 instOp.replaceAllUsesWith(std::move(replacementWires));
1072 std::string
getName()
const override {
return "extmodule-instance-remover"; }
1084struct PortPrunerHelpers {
1086 template <
typename ModuleOpType>
1087 static void computeUnusedInstancePorts(ModuleOpType module,
1088 ArrayRef<Operation *> users,
1089 llvm::BitVector &portsToRemove) {
1090 auto ports =
module.getPorts();
1091 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1092 bool portUsed =
false;
1093 for (
auto *user : users) {
1094 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1095 auto result = instOp.getResult(portIdx);
1096 if (!result.use_empty()) {
1103 portsToRemove.set(portIdx);
1109 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1110 const llvm::BitVector &portsToRemove) {
1112 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1113 for (
auto *user : users) {
1114 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1115 instancesToUpdate.push_back(instOp);
1118 for (
auto instOp : instancesToUpdate) {
1119 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1122 size_t newResultIdx = 0;
1123 for (
size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1125 if (portsToRemove[oldResultIdx]) {
1127 assert(instOp.getResult(oldResultIdx).use_empty() &&
1128 "removing port with uses");
1131 instOp.getResult(oldResultIdx)
1132 .replaceAllUsesWith(newInst->getResult(newResultIdx));
1143struct ModulePortPruner :
public OpReduction<firrtl::FModuleOp> {
1148 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1150 void matches(firrtl::FModuleOp module,
1151 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1152 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1153 auto &userMap = symbols.getSymbolUserMap(tableOp);
1154 auto ports =
module.getPorts();
1155 auto users = userMap.getUsers(module);
1159 llvm::BitVector portsToRemove(ports.size());
1163 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1167 portsToRemove.set();
1172 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1173 if (!portsToRemove[portIdx])
1175 if (!module.getArgument(portIdx).use_empty())
1176 portsToRemove.reset(portIdx);
1180 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1181 if (portsToRemove[portIdx])
1182 addMatch(1, portIdx);
1186 ArrayRef<uint64_t> matches)
override {
1187 if (matches.empty())
1191 llvm::BitVector portsToRemove(module.getNumPorts());
1192 for (
auto portIdx : matches)
1193 portsToRemove.set(portIdx);
1196 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1197 auto &userMap = symbols.getSymbolUserMap(tableOp);
1198 auto users = userMap.getUsers(module);
1201 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1206 module.erasePorts(portsToRemove);
1211 std::string
getName()
const override {
return "module-port-pruner"; }
1218struct ExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
1223 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1225 void matches(firrtl::FExtModuleOp module,
1226 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1227 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1228 auto &userMap = symbols.getSymbolUserMap(tableOp);
1229 auto ports =
module.getPorts();
1230 auto users = userMap.getUsers(module);
1233 llvm::BitVector portsToRemove(ports.size());
1235 if (users.empty()) {
1237 portsToRemove.set();
1241 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1246 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1247 if (portsToRemove[portIdx])
1248 addMatch(1, portIdx);
1252 ArrayRef<uint64_t> matches)
override {
1253 if (matches.empty())
1257 llvm::BitVector portsToRemove(module.getNumPorts());
1258 for (
auto portIdx : matches)
1259 portsToRemove.set(portIdx);
1262 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1263 auto &userMap = symbols.getSymbolUserMap(tableOp);
1264 auto users = userMap.getUsers(module);
1267 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1271 module.erasePorts(portsToRemove);
1276 std::string
getName()
const override {
return "extmodule-port-pruner"; }
1283struct ConnectForwarder :
public Reduction {
1285 domInfo = std::make_unique<DominanceInfo>(op);
1288 uint64_t
match(Operation *op)
override {
1289 if (!isa<firrtl::FConnectLike>(op))
1291 auto dest = op->getOperand(0);
1292 auto src = op->getOperand(1);
1293 auto *destOp = dest.getDefiningOp();
1294 auto *srcOp = src.getDefiningOp();
1300 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1306 unsigned numConnects = 0;
1307 for (
auto &use : dest.getUses()) {
1308 auto *op = use.getOwner();
1309 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1310 if (++numConnects > 1)
1317 !domInfo->properlyDominates(srcOp, op,
false))
1324 LogicalResult
rewrite(Operation *op)
override {
1325 auto dst = op->getOperand(0);
1326 auto src = op->getOperand(1);
1327 dst.replaceAllUsesExcept(src, op);
1329 SmallVector<Operation *> worklist(
1330 {dst.getDefiningOp(), src.getDefiningOp()});
1335 std::string
getName()
const override {
return "connect-forwarder"; }
1338 std::unique_ptr<DominanceInfo> domInfo;
1343template <
unsigned OpNum>
1344struct ConnectSourceOperandForwarder :
public Reduction {
1345 uint64_t
match(Operation *op)
override {
1346 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1348 auto dest = op->getOperand(0);
1349 auto *destOp = dest.getDefiningOp();
1352 if (!destOp || !destOp->hasOneUse() ||
1353 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1356 auto *srcOp = op->getOperand(1).getDefiningOp();
1357 if (!srcOp || OpNum >= srcOp->getNumOperands())
1360 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1362 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1364 return resultTy && opTy &&
1365 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1366 ((resultTy.getBitWidthOrSentinel() == -1) ==
1367 (opTy.getBitWidthOrSentinel() == -1)) &&
1368 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1371 LogicalResult
rewrite(Operation *op)
override {
1372 auto *destOp = op->getOperand(0).getDefiningOp();
1373 auto *srcOp = op->getOperand(1).getDefiningOp();
1374 auto forwardedOperand = srcOp->getOperand(OpNum);
1375 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1377 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1378 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1382 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1385 auto clock = destOp->getOperand(0);
1386 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1387 clock, regName ? regName.str() :
"")
1392 builder.setInsertionPointAfter(op);
1393 if (isa<firrtl::ConnectOp>(op))
1394 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1396 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1406 std::string
getName()
const override {
1407 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1414struct DetachSubaccesses :
public Reduction {
1415 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1417 for (
auto *op : opsToErase)
1418 op->dropAllReferences();
1419 for (
auto *op : opsToErase)
1422 uint64_t
match(Operation *op)
override {
1425 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1426 llvm::all_of(op->getUses(), [](
auto &use) {
1427 return use.getOperandNumber() == 0 &&
1428 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1429 firrtl::SubaccessOp>(use.getOwner());
1432 LogicalResult
rewrite(Operation *op)
override {
1434 OpBuilder builder(op);
1435 bool isWire = isa<firrtl::WireOp>(op);
1438 invalidClock = firrtl::InvalidValueOp::create(
1439 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1440 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1441 builder.setInsertionPoint(user);
1442 auto type = user->getResult(0).getType();
1445 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1448 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1449 user->replaceAllUsesWith(replOp);
1450 opsToErase.insert(user);
1452 opsToErase.insert(op);
1455 std::string
getName()
const override {
return "detach-subaccesses"; }
1456 llvm::DenseSet<Operation *> opsToErase;
1462struct NodeSymbolRemover :
public Reduction {
1467 uint64_t
match(Operation *op)
override {
1469 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1470 if (!sym || sym.empty())
1474 if (innerSymUses.hasInnerRef(op))
1479 LogicalResult
rewrite(Operation *op)
override {
1480 op->removeAttr(
"inner_sym");
1484 std::string
getName()
const override {
return "node-symbol-remover"; }
1493hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1502 LogicalResult walkResult = targetTable.
walkSymbols(
1505 if (parentTable.lookup(name)) {
1513 return failed(walkResult);
1517struct EagerInliner :
public OpReduction<InstanceOp> {
1522 for (
auto circuitOp : op.getOps<CircuitOp>())
1523 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1524 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1527 nlaRemover.remove(op);
1529 innerSymTables.reset();
1532 uint64_t
match(InstanceOp instOp)
override {
1533 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1535 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1538 if (!isa<FModuleOp>(moduleOp))
1542 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1545 auto it = nlaTables.find(circuitOp);
1546 if (it == nlaTables.end() || !it->second)
1548 DenseSet<hw::HierPathOp> nlas;
1549 it->second->getInstanceNLAs(instOp, nlas);
1555 auto parentOp = instOp->getParentOfType<FModuleLike>();
1556 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1562 LogicalResult
rewrite(InstanceOp instOp)
override {
1563 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1564 auto moduleOp = cast<FModuleOp>(
1565 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1567 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1568 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1571 IRRewriter rewriter(instOp);
1572 SmallVector<Value> argWires;
1573 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1574 auto result = instOp.getResult(i);
1575 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1576 instOp.getPortName(i));
1577 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1578 name, NameKindEnum::DroppableName,
1579 instOp.getPortAnnotation(i), StringAttr{})
1581 result.replaceAllUsesWith(wire);
1582 argWires.push_back(wire);
1586 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1590 nlaRemover.markNLAsInOperation(instOp);
1592 nlaRemover.markNLAsInOperation(moduleOp);
1595 clonedModuleOp.erase();
1599 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1604 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1605 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1609struct ObjectInliner :
public OpReduction<ObjectOp> {
1611 blocksToSort.clear();
1614 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1617 for (
auto *block : blocksToSort)
1618 mlir::sortTopologically(block);
1619 blocksToSort.clear();
1620 nlaRemover.remove(op);
1621 innerSymTables.reset();
1624 uint64_t
match(ObjectOp objOp)
override {
1625 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1627 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1630 if (!isa<ClassOp>(classOp))
1635 auto parentOp = objOp->getParentOfType<FModuleLike>();
1636 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1640 for (
auto *user : objOp.getResult().getUsers())
1641 if (!isa<ObjectSubfieldOp>(user))
1647 LogicalResult
rewrite(ObjectOp objOp)
override {
1648 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1649 auto classOp = cast<ClassOp>(
1650 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1651 auto clonedClassOp = classOp.clone();
1654 IRRewriter rewriter(objOp);
1655 SmallVector<Value> portWires;
1656 auto classType = objOp.getType();
1659 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1660 auto element = classType.getElement(i);
1661 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1662 element.name.getValue());
1663 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1664 NameKindEnum::DroppableName,
1665 rewriter.getArrayAttr({}), StringAttr{})
1667 portWires.push_back(wire);
1671 SmallVector<ObjectSubfieldOp> subfieldOps;
1672 for (
auto *user : objOp.getResult().getUsers()) {
1673 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1674 subfieldOps.push_back(subfieldOp);
1675 auto index = subfieldOp.getIndex();
1676 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1680 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1686 SmallVector<FConnectLike> connectsToErase;
1687 for (
auto portWire : portWires) {
1691 for (
auto *user : portWire.getUsers()) {
1692 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1693 if (
connect.getDest() == portWire) {
1695 connectsToErase.push_back(connect);
1705 portWire.replaceAllUsesWith(value);
1706 for (
auto connect : connectsToErase)
1708 if (portWire.use_empty())
1709 portWire.getDefiningOp()->erase();
1710 connectsToErase.clear();
1714 nlaRemover.markNLAsInOperation(objOp);
1719 blocksToSort.insert(objOp->getBlock());
1722 for (
auto subfieldOp : subfieldOps)
1725 clonedClassOp.erase();
1729 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1732 SetVector<Block *> blocksToSort;
1735 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1740struct ResetDisconnector :
public OpReduction<RegResetOp> {
1741 uint64_t
match(RegResetOp op)
override {
return 1; }
1743 LogicalResult
rewrite(RegResetOp regResetOp)
override {
1744 ImplicitLocOpBuilder builder(regResetOp.getLoc(), regResetOp);
1745 auto regOp = RegOp::create(
1746 builder, regResetOp.getResult().getType(), regResetOp.getClockVal(),
1747 regResetOp.getNameAttr(), regResetOp.getNameKindAttr(),
1748 regResetOp.getAnnotationsAttr(), regResetOp.getInnerSymAttr(),
1749 regResetOp.getForceableAttr());
1751 regResetOp.getResult().replaceAllUsesWith(regOp.getResult());
1752 if (regResetOp.getForceable())
1753 regResetOp.getRef().replaceAllUsesWith(regOp.getRef());
1759 std::string
getName()
const override {
return "reset-disconnector"; }
1774 uint64_t
match(Operation *op)
override {
1776 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1777 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1778 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1779 firrtl::CoverOp>(op);
1781 LogicalResult
rewrite(Operation *op)
override {
1782 TypeSwitch<Operation *, void>(op)
1783 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1784 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1785 [](
auto op) { op.setName(
"reg"); })
1786 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1787 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1788 [](
auto op) { op.setName(
"mem"); })
1789 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1790 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1791 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1796 std::string
getName()
const override {
1797 return "module-internal-name-sanitizer";
1802 bool isOneShot()
const override {
return true; }
1822 if (portNameIndex >= 26)
1824 return 'a' + portNameIndex++;
1829 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1834 iGraph.getTopLevelModule().setName(circuitName);
1835 circuitOp.setName(circuitName);
1837 for (
auto *node : iGraph) {
1838 auto module = node->getModule<firrtl::FModuleLike>();
1840 bool shouldReplacePorts =
false;
1841 SmallVector<Attribute> newNames;
1842 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1847 auto oldPorts = fmodule.getPorts();
1848 shouldReplacePorts = !oldPorts.empty();
1849 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1850 auto port = oldPorts[i];
1852 .
Case<firrtl::ClockType>(
1853 [&](
auto a) {
return ns.
newName(
"clk"); })
1854 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1855 [&](
auto a) {
return ns.
newName(
"rst"); })
1856 .Case<firrtl::RefType>(
1857 [&](
auto a) {
return ns.
newName(
"ref"); })
1858 .Default([&](
auto a) {
1861 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1863 fmodule->setAttr(
"portNames",
1864 ArrayAttr::get(fmodule.getContext(), newNames));
1867 if (module == iGraph.getTopLevelModule())
1870 StringAttr::get(circuitOp.getContext(), nameGenerator.
getNextName());
1871 module.setName(newName);
1872 for (
auto *use : node->uses()) {
1873 auto useOp = use->getInstance();
1874 if (
auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1875 instanceOp.setModuleName(newName);
1876 instanceOp.setName(newName);
1877 if (shouldReplacePorts)
1878 instanceOp.setPortNamesAttr(
1879 ArrayAttr::get(circuitOp.getContext(), newNames));
1880 }
else if (
auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1883 auto oldClassType = objectOp.getType();
1884 auto newClassType = firrtl::ClassType::get(
1885 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1886 oldClassType.getElements());
1887 objectOp.getResult().setType(newClassType);
1888 objectOp.setName(newName);
1896 std::string
getName()
const override {
return "module-name-sanitizer"; }
1900 bool isOneShot()
const override {
return true; }
1919struct ModuleSwapper :
public OpReduction<InstanceOp> {
1921 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1922 struct CircuitState {
1923 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1924 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1925 std::unique_ptr<NLATable> nlaTable;
1931 moduleSizes.clear();
1932 circuitStates.clear();
1935 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1936 auto &state = circuitStates[circuitOp];
1937 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1938 buildModuleTypeGroups(circuitOp, state);
1939 return WalkResult::skip();
1942 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1948 PortSignature getModulePortSignature(FModuleLike module) {
1949 PortSignature signature;
1950 signature.reserve(module.getNumPorts());
1951 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1952 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1957 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1959 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1960 auto signature = getModulePortSignature(module);
1961 state.moduleTypeGroups[signature].push_back(module);
1965 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1966 if (modules.size() <= 1)
1969 FModuleLike smallestModule =
nullptr;
1970 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1972 for (
auto module : modules) {
1973 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1974 if (size < smallestSize) {
1975 smallestSize = size;
1976 smallestModule =
module;
1981 for (
auto module : modules) {
1982 if (module != smallestModule) {
1983 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
1990 uint64_t
match(InstanceOp instOp)
override {
1992 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1994 const auto &state = circuitStates.at(circuitOp);
1997 DenseSet<hw::HierPathOp> nlas;
1998 state.nlaTable->getInstanceNLAs(instOp, nlas);
2003 auto moduleName = instOp.getModuleNameAttr().getAttr();
2004 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
2005 if (!canonicalModule)
2009 auto currentModule = cast<FModuleLike>(
2010 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
2011 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
2012 uint64_t canonicalSize =
2013 moduleSizes.getModuleSize(canonicalModule, symbols);
2014 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
2017 LogicalResult
rewrite(InstanceOp instOp)
override {
2019 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2021 const auto &state = circuitStates.at(circuitOp);
2024 auto canonicalModule = state.instanceToCanonicalModule.at(
2025 instOp.getModuleNameAttr().getAttr());
2026 auto canonicalName = canonicalModule.getModuleNameAttr();
2027 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2030 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2035 std::string
getName()
const override {
return "firrtl-module-swapper"; }
2044 DenseMap<CircuitOp, CircuitState> circuitStates;
2062struct ForceDedup :
public OpReduction<CircuitOp> {
2066 modulesToErase.clear();
2067 moduleSizes.clear();
2070 nlaRemover.remove(op);
2071 for (
auto mod : modulesToErase)
2076 void matches(CircuitOp circuitOp,
2077 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2078 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2080 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2081 if (!anno.
isClass(mustDeduplicateAnnoClass))
2084 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2085 if (!modulesAttr || modulesAttr.size() < 2)
2091 uint64_t totalSize = 0;
2092 ArrayAttr portTypes;
2093 DenseBoolArrayAttr portDirections;
2094 bool allSame =
true;
2095 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2101 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2106 totalSize += moduleSizes.getModuleSize(mod, symbols);
2108 portTypes = mod.getPortTypesAttr();
2109 portDirections = mod.getPortDirectionsAttr();
2110 }
else if (portTypes != mod.getPortTypesAttr() ||
2111 portDirections != mod.getPortDirectionsAttr()) {
2121 addMatch(totalSize, annoIdx);
2126 ArrayRef<uint64_t> matches)
override {
2127 auto *
context = circuitOp->getContext();
2131 SmallVector<Annotation> newAnnotations;
2133 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2135 if (!llvm::is_contained(matches, annoIdx)) {
2136 newAnnotations.push_back(anno);
2139 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2140 assert(anno.
isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2141 modulesAttr.size() >= 2);
2144 SmallVector<StringAttr> moduleNames;
2145 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2147 auto refStr = moduleRef.getValue();
2148 auto pipePos = refStr.find(
'|');
2149 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2150 auto moduleName = refStr.substr(pipePos + 1);
2151 moduleNames.push_back(StringAttr::get(
context, moduleName));
2156 if (moduleNames.size() < 2)
2161 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2162 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
2164 if (newAnnotations.size() == annotations.size())
2169 newAnnoSet.applyToOperation(circuitOp);
2173 std::string
getName()
const override {
return "firrtl-force-dedup"; }
2179 void replaceModuleReferences(CircuitOp circuitOp,
2180 ArrayRef<StringAttr> moduleNames,
2183 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2184 auto &symbolTable = symbols.getSymbolTable(tableOp);
2185 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2186 auto *
context = circuitOp->getContext();
2190 FModuleLike canonicalModule;
2191 SmallVector<FModuleLike> modulesToReplace;
2192 for (
auto name : moduleNames) {
2193 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
2194 if (!canonicalModule)
2195 canonicalModule = mod;
2197 modulesToReplace.push_back(mod);
2200 if (modulesToReplace.empty())
2204 auto canonicalName = canonicalModule.getModuleNameAttr();
2205 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2206 for (
auto moduleName : moduleNames) {
2207 if (moduleName == canonicalName)
2209 auto *symbolOp = symbolTable.lookup(moduleName);
2212 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
2213 auto instOp = dyn_cast<InstanceOp>(user);
2214 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2216 instOp.setModuleNameAttr(canonicalRef);
2217 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2223 for (
auto oldMod : modulesToReplace) {
2224 SmallVector<hw::HierPathOp> nlaOps(
2225 nlaTable.
lookup(oldMod.getModuleNameAttr()));
2226 for (
auto nlaOp : nlaOps) {
2227 nlaTable.
erase(nlaOp);
2228 StringAttr oldModName = oldMod.getModuleNameAttr();
2229 StringAttr newModName = canonicalName;
2230 SmallVector<Attribute, 4> newPath;
2231 for (
auto nameRef : nlaOp.getNamepath()) {
2232 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2233 if (ref.getModule() == oldModName) {
2234 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2235 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2236 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2237 if (oldInst && newInst) {
2240 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2241 auto newModNames = newInst.getReferencedModuleNamesAttr();
2242 if (!oldModNames.empty() && !newModNames.empty()) {
2243 oldModName = cast<StringAttr>(oldModNames[0]);
2244 newModName = cast<StringAttr>(newModNames[0]);
2248 newPath.push_back(ref);
2249 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2250 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2252 newPath.push_back(nameRef);
2255 nlaOp.setNamepathAttr(ArrayAttr::get(
context, newPath));
2261 for (
auto module : modulesToReplace) {
2262 nlaRemover.markNLAsInOperation(module);
2263 modulesToErase.insert(module);
2269 SetVector<FModuleLike> modulesToErase;
2289struct MustDedupChildren :
public OpReduction<CircuitOp> {
2294 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
2298 void matches(CircuitOp circuitOp,
2299 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2301 uint64_t matchId = 0;
2303 DenseSet<StringRef> modulesAlreadyInMustDedup;
2304 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
2305 if (anno.isClass(mustDeduplicateAnnoClass))
2306 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
2307 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2309 modulesAlreadyInMustDedup.insert(target->module);
2311 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2312 if (!anno.
isClass(mustDeduplicateAnnoClass))
2315 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2316 if (!modulesAttr || modulesAttr.size() < 2)
2320 processInstanceGroups(
2321 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2325 SmallDenseSet<StringAttr, 4> moduleTargets;
2326 for (
auto instOp : instanceGroup) {
2327 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2328 for (
auto moduleName : moduleNames)
2329 moduleTargets.insert(cast<StringAttr>(moduleName));
2331 if (moduleTargets.size() < 2)
2336 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2337 auto moduleNames = inst.getReferencedModuleNames();
2338 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2339 return modulesAlreadyInMustDedup.contains(moduleName);
2344 addMatch(1, matchId - 1);
2350 ArrayRef<uint64_t> matches)
override {
2351 auto *
context = circuitOp->getContext();
2353 SmallVector<Annotation> newAnnotations;
2354 uint64_t matchId = 0;
2356 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2357 if (!anno.
isClass(mustDeduplicateAnnoClass)) {
2358 newAnnotations.push_back(anno);
2362 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2363 if (!modulesAttr || modulesAttr.size() < 2) {
2364 newAnnotations.push_back(anno);
2368 processInstanceGroups(
2369 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2371 if (!llvm::is_contained(matches, matchId++))
2375 SmallSetVector<StringAttr, 4> moduleTargets;
2376 for (
auto instOp : instanceGroup) {
2377 auto moduleNames = instOp.getReferencedModuleNames();
2378 for (
auto moduleName : moduleNames) {
2380 target.circuit = circuitOp.getName();
2381 target.module = moduleName;
2382 moduleTargets.insert(target.toStringAttr(
context));
2387 SmallVector<NamedAttribute> newAnnoAttrs;
2388 newAnnoAttrs.emplace_back(
2389 StringAttr::get(
context,
"class"),
2390 StringAttr::get(
context, mustDeduplicateAnnoClass));
2391 newAnnoAttrs.emplace_back(
2392 StringAttr::get(
context,
"modules"),
2394 SmallVector<Attribute>(moduleTargets.begin(),
2395 moduleTargets.end())));
2397 auto newAnnoDict = DictionaryAttr::get(
context, newAnnoAttrs);
2398 newAnnotations.emplace_back(newAnnoDict);
2402 newAnnotations.push_back(anno);
2407 newAnnoSet.applyToOperation(circuitOp);
2411 std::string
getName()
const override {
return "must-dedup-children"; }
2419 void processInstanceGroups(
2420 CircuitOp circuitOp, ArrayAttr modulesAttr,
2421 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2422 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2425 SmallVector<FModuleLike> modules;
2426 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2428 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2429 modules.push_back(mod);
2432 if (modules.size() < 2)
2439 struct InstanceGroup {
2440 SmallVector<FInstanceLike> instances;
2441 bool nameIsUnique =
true;
2443 MapVector<StringAttr, InstanceGroup> instanceGroups;
2444 for (
auto module : modules) {
2446 module.walk([&](FInstanceLike instOp) {
2447 if (isa<ObjectOp>(instOp.getOperation()))
2449 auto name = instOp.getInstanceNameAttr();
2450 auto &group = instanceGroups[name];
2451 if (nameCounts[name]++ > 1)
2452 group.nameIsUnique =
false;
2453 group.instances.push_back(instOp);
2459 for (
auto &[name, group] : instanceGroups)
2460 if (group.nameIsUnique && group.instances.size() >= 2)
2461 callback(group.instances);
2468struct LayerDisable :
public OpReduction<CircuitOp> {
2469 LayerDisable(MLIRContext *
context) {
2470 pm = std::make_unique<mlir::PassManager>(
2471 context,
"builtin.module", mlir::OpPassManager::Nesting::Explicit);
2472 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2475 void beforeReduction(mlir::ModuleOp op)
override { symbolRefAttrMap.clear(); }
2477 void afterReduction(mlir::ModuleOp op)
override { (void)pm->run(op); };
2479 void matches(CircuitOp circuitOp,
2480 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2481 uint64_t matchId = 0;
2483 SmallVector<FlatSymbolRefAttr> nestedRefs;
2484 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2487 rootRef = layerOp.getSymNameAttr();
2489 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2491 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2492 addMatch(1, matchId++);
2494 for (
auto nestedLayerOp : layerOp.getOps<LayerOp>())
2495 addLayer(rootRef, nestedLayerOp);
2497 if (!nestedRefs.empty())
2498 nestedRefs.pop_back();
2501 for (
auto layerOp : circuitOp.getOps<LayerOp>())
2502 addLayer({}, layerOp);
2506 ArrayRef<uint64_t> matches)
override {
2507 SmallVector<Attribute> disableLayers;
2508 if (
auto existingDisables = circuitOp.getDisableLayersAttr()) {
2509 auto disableRange = existingDisables.getAsRange<Attribute>();
2510 disableLayers.append(disableRange.begin(), disableRange.end());
2512 for (
auto match : matches)
2513 disableLayers.push_back(symbolRefAttrMap.at(match));
2515 circuitOp.setDisableLayersAttr(
2516 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2521 std::string
getName()
const override {
return "firrtl-layer-disable"; }
2523 std::unique_ptr<mlir::PassManager> pm;
2524 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2534 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2536 auto elements = listOp.getElements();
2537 for (
size_t i = 0; i < elements.size(); ++i)
2542 ArrayRef<uint64_t>
matches)
override {
2544 llvm::SmallDenseSet<uint64_t, 4> matchesSet(
matches.begin(),
matches.end());
2547 SmallVector<Value> newElements;
2548 auto elements = listOp.getElements();
2549 for (
size_t i = 0; i < elements.size(); ++i) {
2550 if (!matchesSet.contains(i))
2551 newElements.push_back(elements[i]);
2555 OpBuilder builder(listOp);
2556 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2557 listOp.getType(), newElements);
2558 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2565 return "firrtl-list-create-element-remover";
2571 uint64_t
match(FModuleOp module)
override {
2572 return module.getConvention() != Convention::Internal;
2575 LogicalResult
rewrite(FModuleOp module)
override {
2576 module.setConvention(Convention::Internal);
2580 std::string
getName()
const override {
return "module-convention-remover"; }
2587 uint64_t
match(FExtModuleOp extmodule)
override {
2588 return extmodule.getConvention() != Convention::Internal;
2591 LogicalResult
rewrite(FExtModuleOp extmodule)
override {
2592 extmodule.setConvention(Convention::Internal);
2597 return "extmodule-convention-remover";
2614 patterns.add<SimplifyResets, 35>();
2616 patterns.add<MustDedupChildren, 33>();
2617 patterns.add<AnnotationRemover, 32>();
2619 patterns.add<LayerDisable, 30>(getContext());
2625 firrtl::createLowerCHIRRTLPass(),
true,
true);
2630 patterns.add<FIRRTLModuleExternalizer, 25>();
2631 patterns.add<InstanceStubber, 24>();
2636 firrtl::createLowerFIRRTLTypes(),
true,
true);
2643 firrtl::createRemoveUnusedPorts({
true}));
2644 patterns.add<NodeSymbolRemover, 16>();
2646 patterns.add<ConnectForwarder, 14>();
2647 patterns.add<ConnectInvalidator, 13>();
2649 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2650 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2651 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2653 patterns.add<ResetDisconnector, 8>();
2654 patterns.add<DetachSubaccesses, 7>();
2655 patterns.add<ModulePortPruner, 7>();
2656 patterns.add<ExtmodulePortPruner, 6>();
2658 patterns.add<RootExtmodulePortPruner, 5>();
2659 patterns.add<ExtmoduleInstanceRemover, 4>();
2660 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2661 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2662 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2670 mlir::DialectRegistry ®istry) {
2671 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that every connection to a value connects an invalid value to it.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Helper class to cache tie-off values for different FIRRTL types.
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(SmallVectorImpl< Operation * > &worklist, Reduction &reduction)
Starting from an initial worklist of operations, traverse through it and its operands and erase opera...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Reduction that removes the convention attribute from external modules.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(FExtModuleOp extmodule) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(FExtModuleOp extmodule) override
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Reduction that removes the convention attribute from regular modules.
uint64_t match(FModuleOp module) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(FModuleOp module) override
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)