24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "firrtl-reductions"
37using namespace firrtl;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
53 return tables->getSymbolTable(op);
63 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
70 tables = std::make_unique<SymbolTableCollection>();
75 std::unique_ptr<SymbolTableCollection>
tables;
82static std::optional<firrtl::FModuleOp>
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
99 module->walk([&](Operation *op) {
101 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
119 if (use.getOperandNumber() != 0)
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
138 unsigned numRemoved = 0;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *
module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
147 if (
auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
160 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
161 llvm::dbgs() <<
"\n";
170 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (
auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (
auto attr : array)
177 markNLAsInAnnotation(attr);
185 op->walk([&](Operation *op) {
186 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
187 markNLAsInAnnotation(annos);
202struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): `nlaRemover` is declared outside this view; presumably this
/// erases the NLAs that `rewrite` marked via `markNLAsInOperation` — confirm.
209 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
211 uint64_t
match(FModuleOp module)
override {
212 if (innerSymUses.hasInnerRef(module))
214 return moduleSizes.getModuleSize(module, symbols);
217 LogicalResult
rewrite(FModuleOp module)
override {
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (
auto attr :
module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (
auto refType = type_dyn_cast<RefType>(type))
224 if (
auto layer = refType.getLayer())
225 layers.insert(layer);
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (
auto layer : layers)
230 layersArray.push_back(layer);
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
/// Returns this reduction's name string, "firrtl-module-externalizer".
246 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
272 if (
auto refType = type_dyn_cast<RefType>(type)) {
274 assert(!
flip &&
"input probes are not allowed");
276 auto underlyingType = refType.getType();
278 if (!refType.getForceable()) {
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
286 auto invalid = tieOffCache.
getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
293 WireOp::create(builder, underlyingType,
294 "", NameKindEnum::DroppableName,
295 ArrayRef<Attribute>{},
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
303 builder.create<RefDefineOp>(value, forceableRef);
306 auto invalid = tieOffCache.
getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
312 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
324 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache,
flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
339 if (
auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.
getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
346 if (
auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.
getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
358 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (
auto element :
llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
365 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
384 firrtl::ConnectOp::create(builder, dest, value);
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
392 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (
auto element :
llvm::enumerate(bundleType.getElements()))
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
399 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
417struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
420 erasedModules.clear();
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (
auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
443 << moduleOp.getModuleName() <<
"`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
450 for (
auto *op : erasedInsts)
452 for (
auto *op : erasedModules)
454 nlaRemover.remove(op);
457 uint64_t
match(firrtl::InstanceOp instOp)
override {
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
463 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
464 LLVM_DEBUG(llvm::dbgs()
465 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
468 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
471 instOp.getPortName(i));
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
490 << moduleOp.getModuleName() <<
"`\n");
491 erasedModules.insert(moduleOp);
/// Returns this reduction's name string, "instance-stubber".
496 std::string
getName()
const override {
return "instance-stubber"; }
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
508struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
/// Override of the `beforeReduction` hook: resets `nlaRemover` by calling its
/// `clear()`. The module argument `op` is not used.
/// NOTE(review): presumably this drops NLA marks left over from an earlier
/// run — confirm against the definition of `nlaRemover`.
509 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): presumably this erases the NLAs that `rewrite` marked via
/// `markNLAsInOperation(memOp)` — confirm.
510 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
511 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
512 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
516 SmallVector<Value> outputs;
517 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
520 memOp.getPortName(i));
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
526 invalidateOutputs(builder, wire, tieOffCache,
true);
527 result.replaceAllUsesWith(wire);
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
542 case firrtl::MemOp::PortKind::Debug:
547 if (!isa<firrtl::RefType>(result.getType())) {
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (
unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
557 reduceXor(builder, xorInputs, input);
562 outputs.push_back(output);
566 for (
auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
569 nlaRemover.markNLAsInOperation(memOp);
/// Returns this reduction's name string, "memory-stubber".
573 std::string
getName()
const override {
return "memory-stubber"; }
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
588template <
unsigned OpNum>
589struct FIRRTLOperandForwarder :
public Reduction {
590 uint64_t
match(Operation *op)
override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
593 if (isFlowSensitiveOp(op))
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
605 LogicalResult
rewrite(Operation *op)
override {
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
615 if (resultWidth < operandWidth)
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
622 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
623 result.replaceAllUsesWith(newOp);
627 std::string
getName()
const override {
628 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
647 return WalkResult::skip();
651 uint64_t
match(Operation *op)
override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
654 if (!matchPattern(op, m_Constant(&attr)))
656 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
659 if (
auto strAttr = dyn_cast<StringAttr>(attr))
662 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
666 if (
auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
669 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
674 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
682 if (op->getNumResults() != 1)
684 if (op->hasAttr(
"inner_sym"))
686 if (isFlowSensitiveOp(op))
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
693 LogicalResult
rewrite(Operation *op)
override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr(
"");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64,
true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(
false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
746 if (isa<ListType>(type)) {
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
757 op->replaceAllUsesWith(newOp);
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
/// Returns this reduction's name string, "firrtl-constantifier".
785 std::string
getName()
const override {
return "firrtl-constantifier"; }
797struct ConnectInvalidator :
public Reduction {
798 uint64_t
match(Operation *op)
override {
799 if (!isa<FConnectLike>(op))
801 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
808 LogicalResult
rewrite(Operation *op)
override {
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
/// Returns this reduction's name string, "connect-invalidator".
819 std::string
getName()
const override {
return "connect-invalidator"; }
826struct AnnotationRemover :
public Reduction {
/// Override of the `beforeReduction` hook: resets `nlaRemover` by calling its
/// `clear()`. The module argument `op` is not used.
/// NOTE(review): presumably this drops NLA marks left over from an earlier
/// run — confirm against the definition of `nlaRemover`.
827 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): presumably this erases the NLAs recorded through
/// `markNLAsInAnnotation` while annotations were being dropped — confirm.
828 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
831 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
832 uint64_t matchId = 0;
835 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
836 for (
unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
840 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
841 for (
auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
848 ArrayRef<uint64_t> matches)
override {
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (
auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
862 nlaRemover.markNLAsInAnnotation(anno);
866 return ArrayAttr::get(op->getContext(), newAnnotations);
870 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
871 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
875 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
881 op->setAttr(
"portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
/// Returns this reduction's name string, "annotation-remover".
888 std::string
getName()
const override {
return "annotation-remover"; }
895struct SimplifyResets :
public OpReduction<CircuitOp> {
896 uint64_t
match(CircuitOp circuit)
override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
901 circuit.walk([&](Operation *op) {
902 for (
auto result : op->getResults())
903 walker.walk(result.getType());
905 for (
auto ®ion : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
910 walker.walk(op->getAttrDictionary());
916 LogicalResult
rewrite(CircuitOp circuit)
override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
924 replacer.recursivelyReplaceElementsIn(circuit,
true,
929 circuit.walk([&](Operation *op) {
932 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
938 if (
auto module = dyn_cast<FModuleLike>(op)) {
941 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
/// Returns this reduction's name string, "firrtl-simplify-resets".
951 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
957struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
958 uint64_t
match(firrtl::FModuleOp module)
override {
959 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
962 return circuit.getNameAttr() ==
module.getNameAttr();
964 LogicalResult
rewrite(firrtl::FModuleOp module)
override {
966 size_t numPorts =
module.getNumPorts();
967 llvm::BitVector dropPorts(numPorts);
968 for (
unsigned i = 0; i != numPorts; ++i) {
972 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
976 module.erasePorts(dropPorts);
/// Returns this reduction's name string, "root-port-pruner".
979 std::string
getName()
const override {
return "root-port-pruner"; }
985struct RootExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
986 uint64_t
match(firrtl::FExtModuleOp module)
override {
987 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
988 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
991 return module.getNumPorts();
994 LogicalResult
rewrite(firrtl::FExtModuleOp module)
override {
996 size_t numPorts =
module.getNumPorts();
1000 llvm::BitVector dropPorts(numPorts);
1001 dropPorts.set(0, numPorts);
1002 module.erasePorts(dropPorts);
/// Returns this reduction's name string, "root-extmodule-port-pruner".
1006 std::string
getName()
const override {
return "root-extmodule-port-pruner"; }
1011struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): presumably this erases the NLAs that `rewrite` marked via
/// `markNLAsInOperation(instOp)` — confirm.
1016 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1018 uint64_t
match(firrtl::InstanceOp instOp)
override {
1019 return isa<firrtl::FExtModuleOp>(
1020 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1022 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
1024 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1025 symbols.getNearestSymbolTable(instOp)))
1027 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1029 SmallVector<Value> replacementWires;
1031 auto wire = firrtl::WireOp::create(
1033 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
1035 if (
info.isOutput()) {
1037 if (
auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(
info.type)) {
1039 firrtl::ConnectOp::create(builder, wire, inv);
1040 }
else if (
auto propType = dyn_cast<firrtl::PropertyType>(
info.type)) {
1041 auto unknown = tieOffCache.
getUnknown(propType);
1042 builder.create<firrtl::PropAssignOp>(wire, unknown);
1045 replacementWires.push_back(wire);
1047 nlaRemover.markNLAsInOperation(instOp);
1048 instOp.replaceAllUsesWith(std::move(replacementWires));
/// Returns this reduction's name string, "extmodule-instance-remover".
1052 std::string
getName()
const override {
return "extmodule-instance-remover"; }
1064struct PortPrunerHelpers {
1066 template <
typename ModuleOpType>
1067 static void computeUnusedInstancePorts(ModuleOpType module,
1068 ArrayRef<Operation *> users,
1069 llvm::BitVector &portsToRemove) {
1070 auto ports =
module.getPorts();
1071 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1072 bool portUsed =
false;
1073 for (
auto *user : users) {
1074 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1075 auto result = instOp.getResult(portIdx);
1076 if (!result.use_empty()) {
1083 portsToRemove.set(portIdx);
1089 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1090 const llvm::BitVector &portsToRemove) {
1092 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1093 for (
auto *user : users) {
1094 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1095 instancesToUpdate.push_back(instOp);
1098 for (
auto instOp : instancesToUpdate) {
1099 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1102 size_t newResultIdx = 0;
1103 for (
size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1105 if (portsToRemove[oldResultIdx]) {
1107 assert(instOp.getResult(oldResultIdx).use_empty() &&
1108 "removing port with uses");
1111 instOp.getResult(oldResultIdx)
1112 .replaceAllUsesWith(newInst.getResult(newResultIdx));
1123struct ModulePortPruner :
public OpReduction<firrtl::FModuleOp> {
1127 portsToRemoveMap.clear();
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): where this struct marks NLAs for removal is not visible in
/// this view — confirm what `remove` acts on here.
1129 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1131 uint64_t
match(firrtl::FModuleOp module)
override {
1132 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1133 auto &userMap = symbols.getSymbolUserMap(tableOp);
1134 auto ports =
module.getPorts();
1135 auto users = userMap.getUsers(module);
1138 llvm::BitVector portsToRemove(ports.size());
1142 if (users.empty()) {
1143 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1144 auto arg =
module.getArgument(portIdx);
1145 if (arg.use_empty())
1146 portsToRemove.set(portIdx);
1151 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1155 auto count = portsToRemove.count();
1157 portsToRemoveMap[module] = std::move(portsToRemove);
1162 LogicalResult
rewrite(firrtl::FModuleOp module)
override {
1164 auto it = portsToRemoveMap.find(module);
1165 if (it == portsToRemoveMap.end())
1168 const auto &portsToRemove = it->second;
1171 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1172 auto &userMap = symbols.getSymbolUserMap(tableOp);
1173 auto users = userMap.getUsers(module);
1176 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1180 for (
size_t portIdx = 0; portIdx <
module.getNumPorts(); ++portIdx)
1181 if (portsToRemove[portIdx])
1184 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
1188 module.erasePorts(portsToRemove);
/// Returns this reduction's name string, "module-port-pruner".
1193 std::string
getName()
const override {
return "module-port-pruner"; }
1197 DenseMap<firrtl::FModuleOp, llvm::BitVector> portsToRemoveMap;
1201struct ExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
1205 portsToRemoveMap.clear();
/// Override of the `afterReduction` hook: passes the top-level module `op` to
/// `nlaRemover.remove`.
/// NOTE(review): where this struct marks NLAs for removal is not visible in
/// this view — confirm what `remove` acts on here.
1207 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1209 uint64_t
match(firrtl::FExtModuleOp module)
override {
1210 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1211 auto &userMap = symbols.getSymbolUserMap(tableOp);
1212 auto ports =
module.getPorts();
1213 auto users = userMap.getUsers(module);
1216 llvm::BitVector portsToRemove(ports.size());
1218 if (users.empty()) {
1220 portsToRemove.set();
1224 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1228 auto count = portsToRemove.count();
1230 portsToRemoveMap[module] = std::move(portsToRemove);
1235 LogicalResult
rewrite(firrtl::FExtModuleOp module)
override {
1237 auto it = portsToRemoveMap.find(module);
1238 if (it == portsToRemoveMap.end())
1241 const auto &portsToRemove = it->second;
1244 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1245 auto &userMap = symbols.getSymbolUserMap(tableOp);
1246 auto users = userMap.getUsers(module);
1249 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1253 module.erasePorts(portsToRemove);
/// Returns this reduction's name string, "extmodule-port-pruner".
1258 std::string
getName()
const override {
return "extmodule-port-pruner"; }
1262 DenseMap<firrtl::FExtModuleOp, llvm::BitVector> portsToRemoveMap;
1266struct ConnectForwarder :
public Reduction {
1268 domInfo = std::make_unique<DominanceInfo>(op);
1271 uint64_t
match(Operation *op)
override {
1272 if (!isa<firrtl::FConnectLike>(op))
1274 auto dest = op->getOperand(0);
1275 auto src = op->getOperand(1);
1276 auto *destOp = dest.getDefiningOp();
1277 auto *srcOp = src.getDefiningOp();
1283 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1289 unsigned numConnects = 0;
1290 for (
auto &use : dest.getUses()) {
1291 auto *op = use.getOwner();
1292 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1293 if (++numConnects > 1)
1300 !domInfo->properlyDominates(srcOp, op,
false))
1307 LogicalResult
rewrite(Operation *op)
override {
1308 auto dst = op->getOperand(0);
1309 auto src = op->getOperand(1);
1310 dst.replaceAllUsesWith(src);
1312 if (
auto *dstOp = dst.getDefiningOp())
1314 if (
auto *srcOp = src.getDefiningOp())
/// Returns this reduction's name string, "connect-forwarder".
1319 std::string
getName()
const override {
return "connect-forwarder"; }
1322 std::unique_ptr<DominanceInfo> domInfo;
1327template <
unsigned OpNum>
1328struct ConnectSourceOperandForwarder :
public Reduction {
1329 uint64_t
match(Operation *op)
override {
1330 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1332 auto dest = op->getOperand(0);
1333 auto *destOp = dest.getDefiningOp();
1336 if (!destOp || !destOp->hasOneUse() ||
1337 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1340 auto *srcOp = op->getOperand(1).getDefiningOp();
1341 if (!srcOp || OpNum >= srcOp->getNumOperands())
1344 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1346 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1348 return resultTy && opTy &&
1349 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1350 ((resultTy.getBitWidthOrSentinel() == -1) ==
1351 (opTy.getBitWidthOrSentinel() == -1)) &&
1352 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1355 LogicalResult
rewrite(Operation *op)
override {
1356 auto *destOp = op->getOperand(0).getDefiningOp();
1357 auto *srcOp = op->getOperand(1).getDefiningOp();
1358 auto forwardedOperand = srcOp->getOperand(OpNum);
1359 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1361 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1362 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1366 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1369 auto clock = destOp->getOperand(0);
1370 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1371 clock, regName ? regName.str() :
"")
1376 builder.setInsertionPointAfter(op);
1377 if (isa<firrtl::ConnectOp>(op))
1378 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1380 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1390 std::string
getName()
const override {
1391 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1398struct DetachSubaccesses :
public Reduction {
/// Override of the `beforeReduction` hook: empties the `opsToErase` set so
/// that no operations queued earlier remain in it. The module argument `op`
/// is not used.
1399 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1401 for (
auto *op : opsToErase)
1402 op->dropAllReferences();
1403 for (
auto *op : opsToErase)
1406 uint64_t
match(Operation *op)
override {
1409 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1410 llvm::all_of(op->getUses(), [](
auto &use) {
1411 return use.getOperandNumber() == 0 &&
1412 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1413 firrtl::SubaccessOp>(use.getOwner());
1416 LogicalResult
rewrite(Operation *op)
override {
1418 OpBuilder builder(op);
1419 bool isWire = isa<firrtl::WireOp>(op);
1422 invalidClock = firrtl::InvalidValueOp::create(
1423 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1424 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1425 builder.setInsertionPoint(user);
1426 auto type = user->getResult(0).getType();
1429 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1432 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1433 user->replaceAllUsesWith(replOp);
1434 opsToErase.insert(user);
1436 opsToErase.insert(op);
/// Returns this reduction's name string, "detach-subaccesses".
1439 std::string
getName()
const override {
return "detach-subaccesses"; }
1440 llvm::DenseSet<Operation *> opsToErase;
1446struct NodeSymbolRemover :
public Reduction {
1451 uint64_t
match(Operation *op)
override {
1453 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1454 if (!sym || sym.empty())
1458 if (innerSymUses.hasInnerRef(op))
1463 LogicalResult
rewrite(Operation *op)
override {
1464 op->removeAttr(
"inner_sym");
1468 std::string
getName()
const override {
return "node-symbol-remover"; }
1477hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1486 LogicalResult walkResult = targetTable.
walkSymbols(
1489 if (parentTable.lookup(name)) {
1497 return failed(walkResult);
1501struct EagerInliner :
public OpReduction<InstanceOp> {
1506 for (
auto circuitOp : op.getOps<CircuitOp>())
1507 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1508 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1511 nlaRemover.remove(op);
1513 innerSymTables.reset();
1516 uint64_t
match(InstanceOp instOp)
override {
1517 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1519 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1522 if (!isa<FModuleOp>(moduleOp))
1526 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1529 auto it = nlaTables.find(circuitOp);
1530 if (it == nlaTables.end() || !it->second)
1532 DenseSet<hw::HierPathOp> nlas;
1533 it->second->getInstanceNLAs(instOp, nlas);
1539 auto parentOp = instOp->getParentOfType<FModuleLike>();
1540 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1546 LogicalResult
rewrite(InstanceOp instOp)
override {
1547 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1548 auto moduleOp = cast<FModuleOp>(
1549 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1551 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1552 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1555 IRRewriter rewriter(instOp);
1556 SmallVector<Value> argWires;
1557 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1558 auto result = instOp.getResult(i);
1559 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1560 instOp.getPortName(i));
1561 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1562 name, NameKindEnum::DroppableName,
1563 instOp.getPortAnnotation(i), StringAttr{})
1565 result.replaceAllUsesWith(wire);
1566 argWires.push_back(wire);
1570 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1574 nlaRemover.markNLAsInOperation(instOp);
1576 nlaRemover.markNLAsInOperation(moduleOp);
1579 clonedModuleOp.erase();
1583 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1588 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1589 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1593struct ObjectInliner :
public OpReduction<ObjectOp> {
1595 blocksToSort.clear();
1598 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1601 for (
auto *block : blocksToSort)
1602 mlir::sortTopologically(block);
1603 blocksToSort.clear();
1604 nlaRemover.remove(op);
1605 innerSymTables.reset();
1608 uint64_t
match(ObjectOp objOp)
override {
1609 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1611 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1614 if (!isa<ClassOp>(classOp))
1619 auto parentOp = objOp->getParentOfType<FModuleLike>();
1620 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1624 for (
auto *user : objOp.getResult().getUsers())
1625 if (!isa<ObjectSubfieldOp>(user))
1631 LogicalResult
rewrite(ObjectOp objOp)
override {
1632 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1633 auto classOp = cast<ClassOp>(
1634 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1635 auto clonedClassOp = classOp.clone();
1638 IRRewriter rewriter(objOp);
1639 SmallVector<Value> portWires;
1640 auto classType = objOp.getType();
1643 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1644 auto element = classType.getElement(i);
1645 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1646 element.name.getValue());
1647 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1648 NameKindEnum::DroppableName,
1649 rewriter.getArrayAttr({}), StringAttr{})
1651 portWires.push_back(wire);
1655 SmallVector<ObjectSubfieldOp> subfieldOps;
1656 for (
auto *user : objOp.getResult().getUsers()) {
1657 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1658 subfieldOps.push_back(subfieldOp);
1659 auto index = subfieldOp.getIndex();
1660 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1664 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1670 SmallVector<FConnectLike> connectsToErase;
1671 for (
auto portWire : portWires) {
1675 for (
auto *user : portWire.getUsers()) {
1676 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1677 if (
connect.getDest() == portWire) {
1679 connectsToErase.push_back(connect);
1689 portWire.replaceAllUsesWith(value);
1690 for (
auto connect : connectsToErase)
1692 if (portWire.use_empty())
1693 portWire.getDefiningOp()->erase();
1694 connectsToErase.clear();
1698 nlaRemover.markNLAsInOperation(objOp);
1703 blocksToSort.insert(objOp->getBlock());
1706 for (
auto subfieldOp : subfieldOps)
1709 clonedClassOp.erase();
1713 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1716 SetVector<Block *> blocksToSort;
1719 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1733struct ModuleInternalNameSanitizer :
public Reduction {
1734 uint64_t
match(Operation *op)
override {
1736 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1737 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1738 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1739 firrtl::CoverOp>(op);
1741 LogicalResult
rewrite(Operation *op)
override {
1742 TypeSwitch<Operation *, void>(op)
1743 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1744 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1745 [](
auto op) { op.setName(
"reg"); })
1746 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1747 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1748 [](
auto op) { op.setName(
"mem"); })
1749 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1750 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1751 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1756 std::string
getName()
const override {
1757 return "module-internal-name-sanitizer";
1762 bool isOneShot()
const override {
return true; }
1776struct ModuleNameSanitizer :
OpReduction<firrtl::CircuitOp> {
1778 const char *names[48] = {
1779 "Foo",
"Bar",
"Baz",
"Qux",
"Quux",
"Quuux",
"Quuuux",
1780 "Quz",
"Corge",
"Grault",
"Bazola",
"Ztesch",
"Thud",
"Grunt",
1781 "Bletch",
"Fum",
"Fred",
"Jim",
"Sheila",
"Barney",
"Flarp",
1782 "Zxc",
"Spqr",
"Wombat",
"Shme",
"Bongo",
"Spam",
"Eggs",
1783 "Snork",
"Zot",
"Blarg",
"Wibble",
"Toto",
"Titi",
"Tata",
1784 "Tutu",
"Pippo",
"Pluto",
"Paperino",
"Aap",
"Noot",
"Mies",
1785 "Oogle",
"Foogle",
"Boogle",
"Zork",
"Gork",
"Bork"};
1787 size_t nameIndex = 0;
1790 if (nameIndex >= 48)
1792 return names[nameIndex++];
1795 size_t portNameIndex = 0;
1797 char getPortName() {
1798 if (portNameIndex >= 26)
1800 return 'a' + portNameIndex++;
1805 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1809 auto *circuitName =
getName();
1810 iGraph.getTopLevelModule().setName(circuitName);
1811 circuitOp.setName(circuitName);
1813 for (
auto *node : iGraph) {
1814 auto module = node->getModule<firrtl::FModuleLike>();
1816 bool shouldReplacePorts =
false;
1817 SmallVector<Attribute> newNames;
1818 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1823 auto oldPorts = fmodule.getPorts();
1824 shouldReplacePorts = !oldPorts.empty();
1825 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1826 auto port = oldPorts[i];
1828 .
Case<firrtl::ClockType>(
1829 [&](
auto a) {
return ns.
newName(
"clk"); })
1830 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1831 [&](
auto a) {
return ns.
newName(
"rst"); })
1832 .Case<firrtl::RefType>(
1833 [&](
auto a) {
return ns.
newName(
"ref"); })
1834 .Default([&](
auto a) {
1835 return ns.
newName(Twine(getPortName()));
1837 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1839 fmodule->setAttr(
"portNames",
1840 ArrayAttr::get(fmodule.getContext(), newNames));
1843 if (module == iGraph.getTopLevelModule())
1845 auto newName = StringAttr::get(circuitOp.getContext(),
getName());
1846 module.setName(newName);
1847 for (
auto *use : node->uses()) {
1848 auto useOp = use->getInstance();
1849 if (
auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1850 instanceOp.setModuleName(newName);
1851 instanceOp.setName(newName);
1852 if (shouldReplacePorts)
1853 instanceOp.setPortNamesAttr(
1854 ArrayAttr::get(circuitOp.getContext(), newNames));
1855 }
else if (
auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1858 auto oldClassType = objectOp.getType();
1859 auto newClassType = firrtl::ClassType::get(
1860 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1861 oldClassType.getElements());
1862 objectOp.getResult().setType(newClassType);
1863 objectOp.setName(newName);
1871 std::string
getName()
const override {
return "module-name-sanitizer"; }
1875 bool isOneShot()
const override {
return true; }
1894struct ModuleSwapper :
public OpReduction<InstanceOp> {
1896 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1897 struct CircuitState {
1898 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1899 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1900 std::unique_ptr<NLATable> nlaTable;
1906 moduleSizes.clear();
1907 circuitStates.clear();
1910 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1911 auto &state = circuitStates[circuitOp];
1912 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1913 buildModuleTypeGroups(circuitOp, state);
1914 return WalkResult::skip();
1917 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1923 PortSignature getModulePortSignature(FModuleLike module) {
1924 PortSignature signature;
1925 signature.reserve(module.getNumPorts());
1926 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1927 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1932 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1934 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1935 auto signature = getModulePortSignature(module);
1936 state.moduleTypeGroups[signature].push_back(module);
1940 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1941 if (modules.size() <= 1)
1944 FModuleLike smallestModule =
nullptr;
1945 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1947 for (
auto module : modules) {
1948 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1949 if (size < smallestSize) {
1950 smallestSize = size;
1951 smallestModule =
module;
1956 for (
auto module : modules) {
1957 if (module != smallestModule) {
1958 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
1965 uint64_t
match(InstanceOp instOp)
override {
1967 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1969 const auto &state = circuitStates.at(circuitOp);
1972 DenseSet<hw::HierPathOp> nlas;
1973 state.nlaTable->getInstanceNLAs(instOp, nlas);
1978 auto moduleName = instOp.getModuleNameAttr().getAttr();
1979 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1980 if (!canonicalModule)
1984 auto currentModule = cast<FModuleLike>(
1985 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1986 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1987 uint64_t canonicalSize =
1988 moduleSizes.getModuleSize(canonicalModule, symbols);
1989 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1992 LogicalResult
rewrite(InstanceOp instOp)
override {
1994 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1996 const auto &state = circuitStates.at(circuitOp);
1999 auto canonicalModule = state.instanceToCanonicalModule.at(
2000 instOp.getModuleNameAttr().getAttr());
2001 auto canonicalName = canonicalModule.getModuleNameAttr();
2002 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2005 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2010 std::string
getName()
const override {
return "firrtl-module-swapper"; }
2019 DenseMap<CircuitOp, CircuitState> circuitStates;
2037struct ForceDedup :
public OpReduction<CircuitOp> {
2041 modulesToErase.clear();
2042 moduleSizes.clear();
2045 nlaRemover.remove(op);
2046 for (
auto mod : modulesToErase)
2051 void matches(CircuitOp circuitOp,
2052 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2053 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2055 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2056 if (!anno.
isClass(mustDeduplicateAnnoClass))
2059 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2060 if (!modulesAttr || modulesAttr.size() < 2)
2066 uint64_t totalSize = 0;
2067 ArrayAttr portTypes;
2068 DenseBoolArrayAttr portDirections;
2069 bool allSame =
true;
2070 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2076 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2081 totalSize += moduleSizes.getModuleSize(mod, symbols);
2083 portTypes = mod.getPortTypesAttr();
2084 portDirections = mod.getPortDirectionsAttr();
2085 }
else if (portTypes != mod.getPortTypesAttr() ||
2086 portDirections != mod.getPortDirectionsAttr()) {
2096 addMatch(totalSize, annoIdx);
2101 ArrayRef<uint64_t> matches)
override {
2102 auto *
context = circuitOp->getContext();
2106 SmallVector<Annotation> newAnnotations;
2108 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2110 if (!llvm::is_contained(matches, annoIdx)) {
2111 newAnnotations.push_back(anno);
2114 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2115 assert(anno.
isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2116 modulesAttr.size() >= 2);
2119 SmallVector<StringAttr> moduleNames;
2120 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2122 auto refStr = moduleRef.getValue();
2123 auto pipePos = refStr.find(
'|');
2124 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2125 auto moduleName = refStr.substr(pipePos + 1);
2126 moduleNames.push_back(StringAttr::get(
context, moduleName));
2131 if (moduleNames.size() < 2)
2136 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2137 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
2139 if (newAnnotations.size() == annotations.size())
2144 newAnnoSet.applyToOperation(circuitOp);
2148 std::string
getName()
const override {
return "firrtl-force-dedup"; }
2154 void replaceModuleReferences(CircuitOp circuitOp,
2155 ArrayRef<StringAttr> moduleNames,
2158 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2159 auto &symbolTable = symbols.getSymbolTable(tableOp);
2160 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2161 auto *
context = circuitOp->getContext();
2165 FModuleLike canonicalModule;
2166 SmallVector<FModuleLike> modulesToReplace;
2167 for (
auto name : moduleNames) {
2168 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
2169 if (!canonicalModule)
2170 canonicalModule = mod;
2172 modulesToReplace.push_back(mod);
2175 if (modulesToReplace.empty())
2179 auto canonicalName = canonicalModule.getModuleNameAttr();
2180 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2181 for (
auto moduleName : moduleNames) {
2182 if (moduleName == canonicalName)
2184 auto *symbolOp = symbolTable.lookup(moduleName);
2187 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
2188 auto instOp = dyn_cast<InstanceOp>(user);
2189 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2191 instOp.setModuleNameAttr(canonicalRef);
2192 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2198 for (
auto oldMod : modulesToReplace) {
2199 SmallVector<hw::HierPathOp> nlaOps(
2200 nlaTable.
lookup(oldMod.getModuleNameAttr()));
2201 for (
auto nlaOp : nlaOps) {
2202 nlaTable.
erase(nlaOp);
2203 StringAttr oldModName = oldMod.getModuleNameAttr();
2204 StringAttr newModName = canonicalName;
2205 SmallVector<Attribute, 4> newPath;
2206 for (
auto nameRef : nlaOp.getNamepath()) {
2207 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2208 if (ref.getModule() == oldModName) {
2209 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2210 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2211 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2212 if (oldInst && newInst) {
2215 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2216 auto newModNames = newInst.getReferencedModuleNamesAttr();
2217 if (!oldModNames.empty() && !newModNames.empty()) {
2218 oldModName = cast<StringAttr>(oldModNames[0]);
2219 newModName = cast<StringAttr>(newModNames[0]);
2223 newPath.push_back(ref);
2224 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2225 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2227 newPath.push_back(nameRef);
2230 nlaOp.setNamepathAttr(ArrayAttr::get(
context, newPath));
2236 for (
auto module : modulesToReplace) {
2237 nlaRemover.markNLAsInOperation(module);
2238 modulesToErase.insert(module);
2244 SetVector<FModuleLike> modulesToErase;
2264struct MustDedupChildren :
public OpReduction<CircuitOp> {
2269 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
2273 void matches(CircuitOp circuitOp,
2274 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2276 uint64_t matchId = 0;
2278 DenseSet<StringRef> modulesAlreadyInMustDedup;
2279 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
2280 if (anno.isClass(mustDeduplicateAnnoClass))
2281 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
2282 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2284 modulesAlreadyInMustDedup.insert(target->module);
2286 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2287 if (!anno.
isClass(mustDeduplicateAnnoClass))
2290 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2291 if (!modulesAttr || modulesAttr.size() < 2)
2295 processInstanceGroups(
2296 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2300 SmallDenseSet<StringAttr, 4> moduleTargets;
2301 for (
auto instOp : instanceGroup) {
2302 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2303 for (
auto moduleName : moduleNames)
2304 moduleTargets.insert(cast<StringAttr>(moduleName));
2306 if (moduleTargets.size() < 2)
2311 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2312 auto moduleNames = inst.getReferencedModuleNames();
2313 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2314 return modulesAlreadyInMustDedup.contains(moduleName);
2319 addMatch(1, matchId - 1);
2325 ArrayRef<uint64_t> matches)
override {
2326 auto *
context = circuitOp->getContext();
2328 SmallVector<Annotation> newAnnotations;
2329 uint64_t matchId = 0;
2331 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2332 if (!anno.
isClass(mustDeduplicateAnnoClass)) {
2333 newAnnotations.push_back(anno);
2337 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2338 if (!modulesAttr || modulesAttr.size() < 2) {
2339 newAnnotations.push_back(anno);
2343 processInstanceGroups(
2344 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2346 if (!llvm::is_contained(matches, matchId++))
2350 SmallSetVector<StringAttr, 4> moduleTargets;
2351 for (
auto instOp : instanceGroup) {
2352 auto moduleNames = instOp.getReferencedModuleNames();
2353 for (
auto moduleName : moduleNames) {
2355 target.circuit = circuitOp.getName();
2356 target.module = moduleName;
2357 moduleTargets.insert(target.toStringAttr(
context));
2362 SmallVector<NamedAttribute> newAnnoAttrs;
2363 newAnnoAttrs.emplace_back(
2364 StringAttr::get(
context,
"class"),
2365 StringAttr::get(
context, mustDeduplicateAnnoClass));
2366 newAnnoAttrs.emplace_back(
2367 StringAttr::get(
context,
"modules"),
2369 SmallVector<Attribute>(moduleTargets.begin(),
2370 moduleTargets.end())));
2372 auto newAnnoDict = DictionaryAttr::get(
context, newAnnoAttrs);
2373 newAnnotations.emplace_back(newAnnoDict);
2377 newAnnotations.push_back(anno);
2382 newAnnoSet.applyToOperation(circuitOp);
2386 std::string
getName()
const override {
return "must-dedup-children"; }
2394 void processInstanceGroups(
2395 CircuitOp circuitOp, ArrayAttr modulesAttr,
2396 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2397 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2400 SmallVector<FModuleLike> modules;
2401 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2403 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2404 modules.push_back(mod);
2407 if (modules.size() < 2)
2414 struct InstanceGroup {
2415 SmallVector<FInstanceLike> instances;
2416 bool nameIsUnique =
true;
2418 MapVector<StringAttr, InstanceGroup> instanceGroups;
2419 for (
auto module : modules) {
2421 module.walk([&](FInstanceLike instOp) {
2422 if (isa<ObjectOp>(instOp.getOperation()))
2424 auto name = instOp.getInstanceNameAttr();
2425 auto &group = instanceGroups[name];
2426 if (nameCounts[name]++ > 1)
2427 group.nameIsUnique =
false;
2428 group.instances.push_back(instOp);
2434 for (
auto &[name, group] : instanceGroups)
2435 if (group.nameIsUnique && group.instances.size() >= 2)
2436 callback(group.instances);
2443struct LayerDisable :
public OpReduction<CircuitOp> {
2444 LayerDisable(MLIRContext *
context) {
2445 pm = std::make_unique<mlir::PassManager>(
2446 context,
"builtin.module", mlir::OpPassManager::Nesting::Explicit);
2447 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2450 void beforeReduction(mlir::ModuleOp op)
override { symbolRefAttrMap.clear(); }
2452 void afterReduction(mlir::ModuleOp op)
override { (void)pm->run(op); };
2454 void matches(CircuitOp circuitOp,
2455 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2456 uint64_t matchId = 0;
2458 SmallVector<FlatSymbolRefAttr> nestedRefs;
2459 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2462 rootRef = layerOp.getSymNameAttr();
2464 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2466 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2467 addMatch(1, matchId++);
2469 for (
auto nestedLayerOp : layerOp.getOps<LayerOp>())
2470 addLayer(rootRef, nestedLayerOp);
2472 if (!nestedRefs.empty())
2473 nestedRefs.pop_back();
2476 for (
auto layerOp : circuitOp.getOps<LayerOp>())
2477 addLayer({}, layerOp);
2481 ArrayRef<uint64_t> matches)
override {
2482 SmallVector<Attribute> disableLayers;
2483 if (
auto existingDisables = circuitOp.getDisableLayersAttr()) {
2484 auto disableRange = existingDisables.getAsRange<Attribute>();
2485 disableLayers.append(disableRange.begin(), disableRange.end());
2487 for (
auto match : matches)
2488 disableLayers.push_back(symbolRefAttrMap.at(match));
2490 circuitOp.setDisableLayersAttr(
2491 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2496 std::string
getName()
const override {
return "firrtl-layer-disable"; }
2498 std::unique_ptr<mlir::PassManager> pm;
2499 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2509 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2511 auto elements = listOp.getElements();
2512 for (
size_t i = 0; i < elements.size(); ++i)
2517 ArrayRef<uint64_t>
matches)
override {
2519 llvm::SmallDenseSet<uint64_t, 4> matchesSet(
matches.begin(),
matches.end());
2522 SmallVector<Value> newElements;
2523 auto elements = listOp.getElements();
2524 for (
size_t i = 0; i < elements.size(); ++i) {
2525 if (!matchesSet.contains(i))
2526 newElements.push_back(elements[i]);
2530 OpBuilder builder(listOp);
2531 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2532 listOp.getType(), newElements);
2533 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2540 return "firrtl-list-create-element-remover";
2555 patterns.add<SimplifyResets, 35>();
2557 patterns.add<MustDedupChildren, 33>();
2558 patterns.add<AnnotationRemover, 32>();
2560 patterns.add<LayerDisable, 30>(getContext());
2566 firrtl::createLowerCHIRRTLPass(),
true,
true);
2571 patterns.add<FIRRTLModuleExternalizer, 25>();
2572 patterns.add<InstanceStubber, 24>();
2577 firrtl::createLowerFIRRTLTypes(),
true,
true);
2584 firrtl::createRemoveUnusedPorts({
true}));
2585 patterns.add<NodeSymbolRemover, 15>();
2586 patterns.add<ConnectForwarder, 14>();
2587 patterns.add<ConnectInvalidator, 13>();
2589 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2590 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2591 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2593 patterns.add<DetachSubaccesses, 7>();
2594 patterns.add<ModulePortPruner, 7>();
2595 patterns.add<ExtmodulePortPruner, 6>();
2597 patterns.add<RootExtmodulePortPruner, 5>();
2598 patterns.add<ExtmoduleInstanceRemover, 4>();
2599 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2600 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2601 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2602 patterns.add<ModuleInternalNameSanitizer, 0>();
2603 patterns.add<ModuleNameSanitizer, 0>();
2607 mlir::DialectRegistry &registry) {
2608 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Helper class to cache tie-off values for different FIRRTL types.
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all the ways in which this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all the ways in which this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)