24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "firrtl-reductions"
37using namespace firrtl;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
53 return tables->getSymbolTable(op);
63 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
70 tables = std::make_unique<SymbolTableCollection>();
75 std::unique_ptr<SymbolTableCollection>
tables;
82static std::optional<firrtl::FModuleOp>
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
99 module->walk([&](Operation *op) {
101 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
119 if (use.getOperandNumber() != 0)
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
138 unsigned numRemoved = 0;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *
module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
147 if (
auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
160 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
161 llvm::dbgs() <<
"\n";
170 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (
auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (
auto attr : array)
177 markNLAsInAnnotation(attr);
185 op->walk([&](Operation *op) {
186 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
187 markNLAsInAnnotation(annos);
202struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
209 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
211 uint64_t
match(FModuleOp module)
override {
212 if (innerSymUses.hasInnerRef(module))
214 return moduleSizes.getModuleSize(module, symbols);
217 LogicalResult
rewrite(FModuleOp module)
override {
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (
auto attr :
module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (
auto refType = type_dyn_cast<RefType>(type))
224 if (
auto layer = refType.getLayer())
225 layers.insert(layer);
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (
auto layer : layers)
230 layersArray.push_back(layer);
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
246 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
272 if (
auto refType = type_dyn_cast<RefType>(type)) {
274 assert(!
flip &&
"input probes are not allowed");
276 auto underlyingType = refType.getType();
278 if (!refType.getForceable()) {
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
286 auto invalid = tieOffCache.
getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
293 WireOp::create(builder, underlyingType,
294 "", NameKindEnum::DroppableName,
295 ArrayRef<Attribute>{},
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
303 builder.create<RefDefineOp>(value, forceableRef);
306 auto invalid = tieOffCache.
getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
312 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
324 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache,
flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
339 if (
auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.
getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
346 if (
auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.
getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
358 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (
auto element :
llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
365 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
384 firrtl::ConnectOp::create(builder, dest, value);
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
392 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (
auto element :
llvm::enumerate(bundleType.getElements()))
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
399 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
417struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
420 erasedModules.clear();
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (
auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
443 << moduleOp.getModuleName() <<
"`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
450 for (
auto *op : erasedInsts)
452 for (
auto *op : erasedModules)
454 nlaRemover.remove(op);
457 uint64_t
match(firrtl::InstanceOp instOp)
override {
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
463 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
464 LLVM_DEBUG(llvm::dbgs()
465 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
468 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
471 instOp.getPortName(i));
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
490 << moduleOp.getModuleName() <<
"`\n");
491 erasedModules.insert(moduleOp);
496 std::string
getName()
const override {
return "instance-stubber"; }
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
508struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
509 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
510 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
511 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
512 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
516 SmallVector<Value> outputs;
517 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
520 memOp.getPortName(i));
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
526 invalidateOutputs(builder, wire, tieOffCache,
true);
527 result.replaceAllUsesWith(wire);
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
542 case firrtl::MemOp::PortKind::Debug:
547 if (!isa<firrtl::RefType>(result.getType())) {
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (
unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
557 reduceXor(builder, xorInputs, input);
562 outputs.push_back(output);
566 for (
auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
569 nlaRemover.markNLAsInOperation(memOp);
573 std::string
getName()
const override {
return "memory-stubber"; }
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
588template <
unsigned OpNum>
589struct FIRRTLOperandForwarder :
public Reduction {
590 uint64_t
match(Operation *op)
override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
593 if (isFlowSensitiveOp(op))
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
605 LogicalResult
rewrite(Operation *op)
override {
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
615 if (resultWidth < operandWidth)
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
622 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
623 result.replaceAllUsesWith(newOp);
627 std::string
getName()
const override {
628 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (
auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
647 return WalkResult::skip();
651 uint64_t
match(Operation *op)
override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
654 if (!matchPattern(op, m_Constant(&attr)))
656 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
659 if (
auto strAttr = dyn_cast<StringAttr>(attr))
662 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
666 if (
auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
669 if (
auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
674 if (
auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
682 if (op->getNumResults() != 1)
684 if (op->hasAttr(
"inner_sym"))
686 if (isFlowSensitiveOp(op))
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
693 LogicalResult
rewrite(Operation *op)
override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr(
"");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64,
true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(
false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
746 if (isa<ListType>(type)) {
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(),
"");
757 op->replaceAllUsesWith(newOp);
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(),
"Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy,
"dummy");
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
785 std::string
getName()
const override {
return "firrtl-constantifier"; }
797struct ConnectInvalidator :
public Reduction {
798 uint64_t
match(Operation *op)
override {
799 if (!isa<FConnectLike>(op))
801 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
808 LogicalResult
rewrite(Operation *op)
override {
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
819 std::string
getName()
const override {
return "connect-invalidator"; }
826struct AnnotationRemover :
public Reduction {
827 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
828 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
831 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
832 uint64_t matchId = 0;
835 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
836 for (
unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
840 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
841 for (
auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
848 ArrayRef<uint64_t> matches)
override {
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (
auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
862 nlaRemover.markNLAsInAnnotation(anno);
866 return ArrayAttr::get(op->getContext(), newAnnotations);
870 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
871 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
875 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
881 op->setAttr(
"portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
888 std::string
getName()
const override {
return "annotation-remover"; }
895struct SimplifyResets :
public OpReduction<CircuitOp> {
896 uint64_t
match(CircuitOp circuit)
override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
901 circuit.walk([&](Operation *op) {
902 for (
auto result : op->getResults())
903 walker.walk(result.getType());
905 for (
auto ®ion : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
910 walker.walk(op->getAttrDictionary());
916 LogicalResult
rewrite(CircuitOp circuit)
override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1,
false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1,
true);
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
924 replacer.recursivelyReplaceElementsIn(circuit,
true,
929 circuit.walk([&](Operation *op) {
932 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
938 if (
auto module = dyn_cast<FModuleLike>(op)) {
941 return anno.
isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
951 std::string
getName()
const override {
return "firrtl-simplify-resets"; }
957struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
958 void matches(firrtl::FModuleOp module,
959 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
960 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
961 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
965 size_t numPorts =
module.getNumPorts();
966 for (
unsigned i = 0; i != numPorts; ++i) {
973 ArrayRef<uint64_t> matches)
override {
975 llvm::BitVector dropPorts(module.getNumPorts());
976 for (
auto portIdx : matches)
977 dropPorts.set(portIdx);
980 for (
auto portIdx : matches) {
982 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
987 module.erasePorts(dropPorts);
991 std::string
getName()
const override {
return "root-port-pruner"; }
997struct RootExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
998 void matches(firrtl::FExtModuleOp module,
999 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1000 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
1001 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1006 size_t numPorts =
module.getNumPorts();
1007 for (
unsigned i = 0; i != numPorts; ++i)
1012 ArrayRef<uint64_t> matches)
override {
1013 if (matches.empty())
1017 llvm::BitVector dropPorts(module.getNumPorts());
1018 for (
auto portIdx : matches)
1019 dropPorts.set(portIdx);
1022 module.erasePorts(dropPorts);
1026 std::string
getName()
const override {
return "root-extmodule-port-pruner"; }
1031struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
1036 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1038 uint64_t
match(firrtl::InstanceOp instOp)
override {
1039 return isa<firrtl::FExtModuleOp>(
1040 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1042 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
1044 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1045 symbols.getNearestSymbolTable(instOp)))
1047 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1049 SmallVector<Value> replacementWires;
1051 auto wire = firrtl::WireOp::create(
1053 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
1055 if (
info.isOutput()) {
1057 if (
auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(
info.type)) {
1059 firrtl::ConnectOp::create(builder, wire, inv);
1060 }
else if (
auto propType = dyn_cast<firrtl::PropertyType>(
info.type)) {
1061 auto unknown = tieOffCache.
getUnknown(propType);
1062 builder.create<firrtl::PropAssignOp>(wire, unknown);
1065 replacementWires.push_back(wire);
1067 nlaRemover.markNLAsInOperation(instOp);
1068 instOp.replaceAllUsesWith(std::move(replacementWires));
1072 std::string
getName()
const override {
return "extmodule-instance-remover"; }
1084struct PortPrunerHelpers {
1086 template <
typename ModuleOpType>
1087 static void computeUnusedInstancePorts(ModuleOpType module,
1088 ArrayRef<Operation *> users,
1089 llvm::BitVector &portsToRemove) {
1090 auto ports =
module.getPorts();
1091 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1092 bool portUsed =
false;
1093 for (
auto *user : users) {
1094 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1095 auto result = instOp.getResult(portIdx);
1096 if (!result.use_empty()) {
1103 portsToRemove.set(portIdx);
1109 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1110 const llvm::BitVector &portsToRemove) {
1112 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1113 for (
auto *user : users) {
1114 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1115 instancesToUpdate.push_back(instOp);
1118 for (
auto instOp : instancesToUpdate) {
1119 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1122 size_t newResultIdx = 0;
1123 for (
size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1125 if (portsToRemove[oldResultIdx]) {
1127 assert(instOp.getResult(oldResultIdx).use_empty() &&
1128 "removing port with uses");
1131 instOp.getResult(oldResultIdx)
1132 .replaceAllUsesWith(newInst.getResult(newResultIdx));
1143struct ModulePortPruner :
public OpReduction<firrtl::FModuleOp> {
1148 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1150 void matches(firrtl::FModuleOp module,
1151 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1152 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1153 auto &userMap = symbols.getSymbolUserMap(tableOp);
1154 auto ports =
module.getPorts();
1155 auto users = userMap.getUsers(module);
1158 llvm::BitVector portsToRemove(ports.size());
1162 if (users.empty()) {
1163 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1164 auto arg =
module.getArgument(portIdx);
1165 if (arg.use_empty())
1166 portsToRemove.set(portIdx);
1171 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1176 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1177 if (portsToRemove[portIdx])
1178 addMatch(1, portIdx);
1182 ArrayRef<uint64_t> matches)
override {
1183 if (matches.empty())
1187 llvm::BitVector portsToRemove(module.getNumPorts());
1188 for (
auto portIdx : matches)
1189 portsToRemove.set(portIdx);
1192 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1193 auto &userMap = symbols.getSymbolUserMap(tableOp);
1194 auto users = userMap.getUsers(module);
1197 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1201 for (
auto portIdx : matches) {
1204 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
1209 module.erasePorts(portsToRemove);
1214 std::string
getName()
const override {
return "module-port-pruner"; }
1221struct ExtmodulePortPruner :
public OpReduction<firrtl::FExtModuleOp> {
1226 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1228 void matches(firrtl::FExtModuleOp module,
1229 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1230 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1231 auto &userMap = symbols.getSymbolUserMap(tableOp);
1232 auto ports =
module.getPorts();
1233 auto users = userMap.getUsers(module);
1236 llvm::BitVector portsToRemove(ports.size());
1238 if (users.empty()) {
1240 portsToRemove.set();
1244 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1249 for (
size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1250 if (portsToRemove[portIdx])
1251 addMatch(1, portIdx);
1255 ArrayRef<uint64_t> matches)
override {
1256 if (matches.empty())
1260 llvm::BitVector portsToRemove(module.getNumPorts());
1261 for (
auto portIdx : matches)
1262 portsToRemove.set(portIdx);
1265 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1266 auto &userMap = symbols.getSymbolUserMap(tableOp);
1267 auto users = userMap.getUsers(module);
1270 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1274 module.erasePorts(portsToRemove);
1279 std::string
getName()
const override {
return "extmodule-port-pruner"; }
1286struct ConnectForwarder :
public Reduction {
1288 domInfo = std::make_unique<DominanceInfo>(op);
1291 uint64_t
match(Operation *op)
override {
1292 if (!isa<firrtl::FConnectLike>(op))
1294 auto dest = op->getOperand(0);
1295 auto src = op->getOperand(1);
1296 auto *destOp = dest.getDefiningOp();
1297 auto *srcOp = src.getDefiningOp();
1303 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1309 unsigned numConnects = 0;
1310 for (
auto &use : dest.getUses()) {
1311 auto *op = use.getOwner();
1312 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1313 if (++numConnects > 1)
1320 !domInfo->properlyDominates(srcOp, op,
false))
1327 LogicalResult
rewrite(Operation *op)
override {
1328 auto dst = op->getOperand(0);
1329 auto src = op->getOperand(1);
1330 dst.replaceAllUsesWith(src);
1332 if (
auto *dstOp = dst.getDefiningOp())
1334 if (
auto *srcOp = src.getDefiningOp())
1339 std::string
getName()
const override {
return "connect-forwarder"; }
1342 std::unique_ptr<DominanceInfo> domInfo;
1347template <
unsigned OpNum>
1348struct ConnectSourceOperandForwarder :
public Reduction {
1349 uint64_t
match(Operation *op)
override {
1350 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1352 auto dest = op->getOperand(0);
1353 auto *destOp = dest.getDefiningOp();
1356 if (!destOp || !destOp->hasOneUse() ||
1357 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1360 auto *srcOp = op->getOperand(1).getDefiningOp();
1361 if (!srcOp || OpNum >= srcOp->getNumOperands())
1364 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1366 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1368 return resultTy && opTy &&
1369 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1370 ((resultTy.getBitWidthOrSentinel() == -1) ==
1371 (opTy.getBitWidthOrSentinel() == -1)) &&
1372 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1375 LogicalResult
rewrite(Operation *op)
override {
1376 auto *destOp = op->getOperand(0).getDefiningOp();
1377 auto *srcOp = op->getOperand(1).getDefiningOp();
1378 auto forwardedOperand = srcOp->getOperand(OpNum);
1379 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1381 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
1382 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1386 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
1389 auto clock = destOp->getOperand(0);
1390 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1391 clock, regName ? regName.str() :
"")
1396 builder.setInsertionPointAfter(op);
1397 if (isa<firrtl::ConnectOp>(op))
1398 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1400 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1410 std::string
getName()
const override {
1411 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
1418struct DetachSubaccesses :
public Reduction {
1419 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
1421 for (
auto *op : opsToErase)
1422 op->dropAllReferences();
1423 for (
auto *op : opsToErase)
1426 uint64_t
match(Operation *op)
override {
1429 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1430 llvm::all_of(op->getUses(), [](
auto &use) {
1431 return use.getOperandNumber() == 0 &&
1432 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1433 firrtl::SubaccessOp>(use.getOwner());
1436 LogicalResult
rewrite(Operation *op)
override {
1438 OpBuilder builder(op);
1439 bool isWire = isa<firrtl::WireOp>(op);
1442 invalidClock = firrtl::InvalidValueOp::create(
1443 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1444 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
1445 builder.setInsertionPoint(user);
1446 auto type = user->getResult(0).getType();
1449 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1452 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1453 user->replaceAllUsesWith(replOp);
1454 opsToErase.insert(user);
1456 opsToErase.insert(op);
1459 std::string
getName()
const override {
return "detach-subaccesses"; }
1460 llvm::DenseSet<Operation *> opsToErase;
1466struct NodeSymbolRemover :
public Reduction {
1471 uint64_t
match(Operation *op)
override {
1473 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
1474 if (!sym || sym.empty())
1478 if (innerSymUses.hasInnerRef(op))
1483 LogicalResult
rewrite(Operation *op)
override {
1484 op->removeAttr(
"inner_sym");
1488 std::string
getName()
const override {
return "node-symbol-remover"; }
1497hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1506 LogicalResult walkResult = targetTable.
walkSymbols(
1509 if (parentTable.lookup(name)) {
1517 return failed(walkResult);
1521struct EagerInliner :
public OpReduction<InstanceOp> {
1526 for (
auto circuitOp : op.getOps<CircuitOp>())
1527 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1528 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1531 nlaRemover.remove(op);
1533 innerSymTables.reset();
1536 uint64_t
match(InstanceOp instOp)
override {
1537 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1539 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1542 if (!isa<FModuleOp>(moduleOp))
1546 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1549 auto it = nlaTables.find(circuitOp);
1550 if (it == nlaTables.end() || !it->second)
1552 DenseSet<hw::HierPathOp> nlas;
1553 it->second->getInstanceNLAs(instOp, nlas);
1559 auto parentOp = instOp->getParentOfType<FModuleLike>();
1560 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1566 LogicalResult
rewrite(InstanceOp instOp)
override {
1567 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1568 auto moduleOp = cast<FModuleOp>(
1569 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1571 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1572 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1575 IRRewriter rewriter(instOp);
1576 SmallVector<Value> argWires;
1577 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1578 auto result = instOp.getResult(i);
1579 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1580 instOp.getPortName(i));
1581 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1582 name, NameKindEnum::DroppableName,
1583 instOp.getPortAnnotation(i), StringAttr{})
1585 result.replaceAllUsesWith(wire);
1586 argWires.push_back(wire);
1590 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1594 nlaRemover.markNLAsInOperation(instOp);
1596 nlaRemover.markNLAsInOperation(moduleOp);
1599 clonedModuleOp.erase();
1603 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1608 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1609 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1613struct ObjectInliner :
public OpReduction<ObjectOp> {
1615 blocksToSort.clear();
1618 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1621 for (
auto *block : blocksToSort)
1622 mlir::sortTopologically(block);
1623 blocksToSort.clear();
1624 nlaRemover.remove(op);
1625 innerSymTables.reset();
1628 uint64_t
match(ObjectOp objOp)
override {
1629 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1631 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1634 if (!isa<ClassOp>(classOp))
1639 auto parentOp = objOp->getParentOfType<FModuleLike>();
1640 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1644 for (
auto *user : objOp.getResult().getUsers())
1645 if (!isa<ObjectSubfieldOp>(user))
1651 LogicalResult
rewrite(ObjectOp objOp)
override {
1652 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1653 auto classOp = cast<ClassOp>(
1654 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1655 auto clonedClassOp = classOp.clone();
1658 IRRewriter rewriter(objOp);
1659 SmallVector<Value> portWires;
1660 auto classType = objOp.getType();
1663 for (
unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1664 auto element = classType.getElement(i);
1665 auto name = rewriter.getStringAttr(Twine(objOp.getName()) +
"_" +
1666 element.name.getValue());
1667 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1668 NameKindEnum::DroppableName,
1669 rewriter.getArrayAttr({}), StringAttr{})
1671 portWires.push_back(wire);
1675 SmallVector<ObjectSubfieldOp> subfieldOps;
1676 for (
auto *user : objOp.getResult().getUsers()) {
1677 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1678 subfieldOps.push_back(subfieldOp);
1679 auto index = subfieldOp.getIndex();
1680 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1684 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1690 SmallVector<FConnectLike> connectsToErase;
1691 for (
auto portWire : portWires) {
1695 for (
auto *user : portWire.getUsers()) {
1696 if (
auto connect = dyn_cast<FConnectLike>(user)) {
1697 if (
connect.getDest() == portWire) {
1699 connectsToErase.push_back(connect);
1709 portWire.replaceAllUsesWith(value);
1710 for (
auto connect : connectsToErase)
1712 if (portWire.use_empty())
1713 portWire.getDefiningOp()->erase();
1714 connectsToErase.clear();
1718 nlaRemover.markNLAsInOperation(objOp);
1723 blocksToSort.insert(objOp->getBlock());
1726 for (
auto subfieldOp : subfieldOps)
1729 clonedClassOp.erase();
1733 std::string
getName()
const override {
return "firrtl-object-inliner"; }
1736 SetVector<Block *> blocksToSort;
1739 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1754 uint64_t
match(Operation *op)
override {
1756 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1757 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1758 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1759 firrtl::CoverOp>(op);
1761 LogicalResult
rewrite(Operation *op)
override {
1762 TypeSwitch<Operation *, void>(op)
1763 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1764 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1765 [](
auto op) { op.setName(
"reg"); })
1766 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1767 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1768 [](
auto op) { op.setName(
"mem"); })
1769 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1770 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1771 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1776 std::string
getName()
const override {
1777 return "module-internal-name-sanitizer";
1782 bool isOneShot()
const override {
return true; }
1802 if (portNameIndex >= 26)
1804 return 'a' + portNameIndex++;
1809 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1814 iGraph.getTopLevelModule().setName(circuitName);
1815 circuitOp.setName(circuitName);
1817 for (
auto *node : iGraph) {
1818 auto module = node->getModule<firrtl::FModuleLike>();
1820 bool shouldReplacePorts =
false;
1821 SmallVector<Attribute> newNames;
1822 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1827 auto oldPorts = fmodule.getPorts();
1828 shouldReplacePorts = !oldPorts.empty();
1829 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1830 auto port = oldPorts[i];
1832 .
Case<firrtl::ClockType>(
1833 [&](
auto a) {
return ns.
newName(
"clk"); })
1834 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1835 [&](
auto a) {
return ns.
newName(
"rst"); })
1836 .Case<firrtl::RefType>(
1837 [&](
auto a) {
return ns.
newName(
"ref"); })
1838 .Default([&](
auto a) {
1841 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1843 fmodule->setAttr(
"portNames",
1844 ArrayAttr::get(fmodule.getContext(), newNames));
1847 if (module == iGraph.getTopLevelModule())
1850 StringAttr::get(circuitOp.getContext(), nameGenerator.
getNextName());
1851 module.setName(newName);
1852 for (
auto *use : node->uses()) {
1853 auto useOp = use->getInstance();
1854 if (
auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1855 instanceOp.setModuleName(newName);
1856 instanceOp.setName(newName);
1857 if (shouldReplacePorts)
1858 instanceOp.setPortNamesAttr(
1859 ArrayAttr::get(circuitOp.getContext(), newNames));
1860 }
else if (
auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1863 auto oldClassType = objectOp.getType();
1864 auto newClassType = firrtl::ClassType::get(
1865 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1866 oldClassType.getElements());
1867 objectOp.getResult().setType(newClassType);
1868 objectOp.setName(newName);
1876 std::string
getName()
const override {
return "module-name-sanitizer"; }
1880 bool isOneShot()
const override {
return true; }
1899struct ModuleSwapper :
public OpReduction<InstanceOp> {
1901 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1902 struct CircuitState {
1903 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1904 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1905 std::unique_ptr<NLATable> nlaTable;
1911 moduleSizes.clear();
1912 circuitStates.clear();
1915 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1916 auto &state = circuitStates[circuitOp];
1917 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1918 buildModuleTypeGroups(circuitOp, state);
1919 return WalkResult::skip();
1922 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1928 PortSignature getModulePortSignature(FModuleLike module) {
1929 PortSignature signature;
1930 signature.reserve(module.getNumPorts());
1931 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1932 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1937 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1939 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1940 auto signature = getModulePortSignature(module);
1941 state.moduleTypeGroups[signature].push_back(module);
1945 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1946 if (modules.size() <= 1)
1949 FModuleLike smallestModule =
nullptr;
1950 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1952 for (
auto module : modules) {
1953 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1954 if (size < smallestSize) {
1955 smallestSize = size;
1956 smallestModule =
module;
1961 for (
auto module : modules) {
1962 if (module != smallestModule) {
1963 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
1970 uint64_t
match(InstanceOp instOp)
override {
1972 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1974 const auto &state = circuitStates.at(circuitOp);
1977 DenseSet<hw::HierPathOp> nlas;
1978 state.nlaTable->getInstanceNLAs(instOp, nlas);
1983 auto moduleName = instOp.getModuleNameAttr().getAttr();
1984 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1985 if (!canonicalModule)
1989 auto currentModule = cast<FModuleLike>(
1990 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1991 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1992 uint64_t canonicalSize =
1993 moduleSizes.getModuleSize(canonicalModule, symbols);
1994 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1997 LogicalResult
rewrite(InstanceOp instOp)
override {
1999 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2001 const auto &state = circuitStates.at(circuitOp);
2004 auto canonicalModule = state.instanceToCanonicalModule.at(
2005 instOp.getModuleNameAttr().getAttr());
2006 auto canonicalName = canonicalModule.getModuleNameAttr();
2007 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2010 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2015 std::string
getName()
const override {
return "firrtl-module-swapper"; }
2024 DenseMap<CircuitOp, CircuitState> circuitStates;
2042struct ForceDedup :
public OpReduction<CircuitOp> {
2046 modulesToErase.clear();
2047 moduleSizes.clear();
2050 nlaRemover.remove(op);
2051 for (
auto mod : modulesToErase)
2056 void matches(CircuitOp circuitOp,
2057 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2058 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2060 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2061 if (!anno.
isClass(mustDeduplicateAnnoClass))
2064 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2065 if (!modulesAttr || modulesAttr.size() < 2)
2071 uint64_t totalSize = 0;
2072 ArrayAttr portTypes;
2073 DenseBoolArrayAttr portDirections;
2074 bool allSame =
true;
2075 for (
auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2081 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2086 totalSize += moduleSizes.getModuleSize(mod, symbols);
2088 portTypes = mod.getPortTypesAttr();
2089 portDirections = mod.getPortDirectionsAttr();
2090 }
else if (portTypes != mod.getPortTypesAttr() ||
2091 portDirections != mod.getPortDirectionsAttr()) {
2101 addMatch(totalSize, annoIdx);
2106 ArrayRef<uint64_t> matches)
override {
2107 auto *
context = circuitOp->getContext();
2111 SmallVector<Annotation> newAnnotations;
2113 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2115 if (!llvm::is_contained(matches, annoIdx)) {
2116 newAnnotations.push_back(anno);
2119 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2120 assert(anno.
isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2121 modulesAttr.size() >= 2);
2124 SmallVector<StringAttr> moduleNames;
2125 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2127 auto refStr = moduleRef.getValue();
2128 auto pipePos = refStr.find(
'|');
2129 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2130 auto moduleName = refStr.substr(pipePos + 1);
2131 moduleNames.push_back(StringAttr::get(
context, moduleName));
2136 if (moduleNames.size() < 2)
2141 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2142 nlaRemover.markNLAsInAnnotation(anno.
getAttr());
2144 if (newAnnotations.size() == annotations.size())
2149 newAnnoSet.applyToOperation(circuitOp);
2153 std::string
getName()
const override {
return "firrtl-force-dedup"; }
2159 void replaceModuleReferences(CircuitOp circuitOp,
2160 ArrayRef<StringAttr> moduleNames,
2163 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2164 auto &symbolTable = symbols.getSymbolTable(tableOp);
2165 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2166 auto *
context = circuitOp->getContext();
2170 FModuleLike canonicalModule;
2171 SmallVector<FModuleLike> modulesToReplace;
2172 for (
auto name : moduleNames) {
2173 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
2174 if (!canonicalModule)
2175 canonicalModule = mod;
2177 modulesToReplace.push_back(mod);
2180 if (modulesToReplace.empty())
2184 auto canonicalName = canonicalModule.getModuleNameAttr();
2185 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2186 for (
auto moduleName : moduleNames) {
2187 if (moduleName == canonicalName)
2189 auto *symbolOp = symbolTable.lookup(moduleName);
2192 for (
auto *user : symbolUserMap.getUsers(symbolOp)) {
2193 auto instOp = dyn_cast<InstanceOp>(user);
2194 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2196 instOp.setModuleNameAttr(canonicalRef);
2197 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2203 for (
auto oldMod : modulesToReplace) {
2204 SmallVector<hw::HierPathOp> nlaOps(
2205 nlaTable.
lookup(oldMod.getModuleNameAttr()));
2206 for (
auto nlaOp : nlaOps) {
2207 nlaTable.
erase(nlaOp);
2208 StringAttr oldModName = oldMod.getModuleNameAttr();
2209 StringAttr newModName = canonicalName;
2210 SmallVector<Attribute, 4> newPath;
2211 for (
auto nameRef : nlaOp.getNamepath()) {
2212 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2213 if (ref.getModule() == oldModName) {
2214 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2215 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2216 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2217 if (oldInst && newInst) {
2220 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2221 auto newModNames = newInst.getReferencedModuleNamesAttr();
2222 if (!oldModNames.empty() && !newModNames.empty()) {
2223 oldModName = cast<StringAttr>(oldModNames[0]);
2224 newModName = cast<StringAttr>(newModNames[0]);
2228 newPath.push_back(ref);
2229 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2230 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2232 newPath.push_back(nameRef);
2235 nlaOp.setNamepathAttr(ArrayAttr::get(
context, newPath));
2241 for (
auto module : modulesToReplace) {
2242 nlaRemover.markNLAsInOperation(module);
2243 modulesToErase.insert(module);
2249 SetVector<FModuleLike> modulesToErase;
2269struct MustDedupChildren :
public OpReduction<CircuitOp> {
2274 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
2278 void matches(CircuitOp circuitOp,
2279 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2281 uint64_t matchId = 0;
2283 DenseSet<StringRef> modulesAlreadyInMustDedup;
2284 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations))
2285 if (anno.isClass(mustDeduplicateAnnoClass))
2286 if (auto modulesAttr = anno.getMember<ArrayAttr>(
"modules"))
2287 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2289 modulesAlreadyInMustDedup.insert(target->module);
2291 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2292 if (!anno.
isClass(mustDeduplicateAnnoClass))
2295 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2296 if (!modulesAttr || modulesAttr.size() < 2)
2300 processInstanceGroups(
2301 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2305 SmallDenseSet<StringAttr, 4> moduleTargets;
2306 for (
auto instOp : instanceGroup) {
2307 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2308 for (
auto moduleName : moduleNames)
2309 moduleTargets.insert(cast<StringAttr>(moduleName));
2311 if (moduleTargets.size() < 2)
2316 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2317 auto moduleNames = inst.getReferencedModuleNames();
2318 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2319 return modulesAlreadyInMustDedup.contains(moduleName);
2324 addMatch(1, matchId - 1);
2330 ArrayRef<uint64_t> matches)
override {
2331 auto *
context = circuitOp->getContext();
2333 SmallVector<Annotation> newAnnotations;
2334 uint64_t matchId = 0;
2336 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
2337 if (!anno.
isClass(mustDeduplicateAnnoClass)) {
2338 newAnnotations.push_back(anno);
2342 auto modulesAttr = anno.
getMember<ArrayAttr>(
"modules");
2343 if (!modulesAttr || modulesAttr.size() < 2) {
2344 newAnnotations.push_back(anno);
2348 processInstanceGroups(
2349 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2351 if (!llvm::is_contained(matches, matchId++))
2355 SmallSetVector<StringAttr, 4> moduleTargets;
2356 for (
auto instOp : instanceGroup) {
2357 auto moduleNames = instOp.getReferencedModuleNames();
2358 for (
auto moduleName : moduleNames) {
2360 target.circuit = circuitOp.getName();
2361 target.module = moduleName;
2362 moduleTargets.insert(target.toStringAttr(
context));
2367 SmallVector<NamedAttribute> newAnnoAttrs;
2368 newAnnoAttrs.emplace_back(
2369 StringAttr::get(
context,
"class"),
2370 StringAttr::get(
context, mustDeduplicateAnnoClass));
2371 newAnnoAttrs.emplace_back(
2372 StringAttr::get(
context,
"modules"),
2374 SmallVector<Attribute>(moduleTargets.begin(),
2375 moduleTargets.end())));
2377 auto newAnnoDict = DictionaryAttr::get(
context, newAnnoAttrs);
2378 newAnnotations.emplace_back(newAnnoDict);
2382 newAnnotations.push_back(anno);
2387 newAnnoSet.applyToOperation(circuitOp);
2391 std::string
getName()
const override {
return "must-dedup-children"; }
2399 void processInstanceGroups(
2400 CircuitOp circuitOp, ArrayAttr modulesAttr,
2401 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
2402 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2405 SmallVector<FModuleLike> modules;
2406 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2408 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2409 modules.push_back(mod);
2412 if (modules.size() < 2)
2419 struct InstanceGroup {
2420 SmallVector<FInstanceLike> instances;
2421 bool nameIsUnique =
true;
2423 MapVector<StringAttr, InstanceGroup> instanceGroups;
2424 for (
auto module : modules) {
2426 module.walk([&](FInstanceLike instOp) {
2427 if (isa<ObjectOp>(instOp.getOperation()))
2429 auto name = instOp.getInstanceNameAttr();
2430 auto &group = instanceGroups[name];
2431 if (nameCounts[name]++ > 1)
2432 group.nameIsUnique =
false;
2433 group.instances.push_back(instOp);
2439 for (
auto &[name, group] : instanceGroups)
2440 if (group.nameIsUnique && group.instances.size() >= 2)
2441 callback(group.instances);
2448struct LayerDisable :
public OpReduction<CircuitOp> {
2449 LayerDisable(MLIRContext *
context) {
2450 pm = std::make_unique<mlir::PassManager>(
2451 context,
"builtin.module", mlir::OpPassManager::Nesting::Explicit);
2452 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2455 void beforeReduction(mlir::ModuleOp op)
override { symbolRefAttrMap.clear(); }
2457 void afterReduction(mlir::ModuleOp op)
override { (void)pm->run(op); };
2459 void matches(CircuitOp circuitOp,
2460 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2461 uint64_t matchId = 0;
2463 SmallVector<FlatSymbolRefAttr> nestedRefs;
2464 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2467 rootRef = layerOp.getSymNameAttr();
2469 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2471 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2472 addMatch(1, matchId++);
2474 for (
auto nestedLayerOp : layerOp.getOps<LayerOp>())
2475 addLayer(rootRef, nestedLayerOp);
2477 if (!nestedRefs.empty())
2478 nestedRefs.pop_back();
2481 for (
auto layerOp : circuitOp.getOps<LayerOp>())
2482 addLayer({}, layerOp);
2486 ArrayRef<uint64_t> matches)
override {
2487 SmallVector<Attribute> disableLayers;
2488 if (
auto existingDisables = circuitOp.getDisableLayersAttr()) {
2489 auto disableRange = existingDisables.getAsRange<Attribute>();
2490 disableLayers.append(disableRange.begin(), disableRange.end());
2492 for (
auto match : matches)
2493 disableLayers.push_back(symbolRefAttrMap.at(match));
2495 circuitOp.setDisableLayersAttr(
2496 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2501 std::string
getName()
const override {
return "firrtl-layer-disable"; }
2503 std::unique_ptr<mlir::PassManager> pm;
2504 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2514 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
2516 auto elements = listOp.getElements();
2517 for (
size_t i = 0; i < elements.size(); ++i)
2522 ArrayRef<uint64_t>
matches)
override {
2524 llvm::SmallDenseSet<uint64_t, 4> matchesSet(
matches.begin(),
matches.end());
2527 SmallVector<Value> newElements;
2528 auto elements = listOp.getElements();
2529 for (
size_t i = 0; i < elements.size(); ++i) {
2530 if (!matchesSet.contains(i))
2531 newElements.push_back(elements[i]);
2535 OpBuilder builder(listOp);
2536 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2537 listOp.getType(), newElements);
2538 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2545 return "firrtl-list-create-element-remover";
2560 patterns.add<SimplifyResets, 35>();
2562 patterns.add<MustDedupChildren, 33>();
2563 patterns.add<AnnotationRemover, 32>();
2565 patterns.add<LayerDisable, 30>(getContext());
2571 firrtl::createLowerCHIRRTLPass(),
true,
true);
2576 patterns.add<FIRRTLModuleExternalizer, 25>();
2577 patterns.add<InstanceStubber, 24>();
2582 firrtl::createLowerFIRRTLTypes(),
true,
true);
2589 firrtl::createRemoveUnusedPorts({
true}));
2590 patterns.add<NodeSymbolRemover, 15>();
2591 patterns.add<ConnectForwarder, 14>();
2592 patterns.add<ConnectInvalidator, 13>();
2594 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2595 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2596 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2598 patterns.add<DetachSubaccesses, 7>();
2599 patterns.add<ModulePortPruner, 7>();
2600 patterns.add<ExtmodulePortPruner, 6>();
2602 patterns.add<RootExtmodulePortPruner, 5>();
2603 patterns.add<ExtmoduleInstanceRemover, 4>();
2604 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2605 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2606 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2612 mlir::DialectRegistry ®istry) {
2613 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Helper class to cache tie-off values for different FIRRTL types.
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
@ None
Don't explicitly preserve any named values.
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)