24#include "mlir/IR/ImplicitLocOpBuilder.h"
25#include "llvm/ADT/APSInt.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/Support/Debug.h"
29#define DEBUG_TYPE "firrtl-reductions"
33using namespace firrtl;
35using llvm::SmallSetVector;
48 return tables->getSymbolTable(op);
58 return userMaps.insert({op, SymbolUserMap(*
tables, op)}).first->second;
65 tables = std::make_unique<SymbolTableCollection>();
70 std::unique_ptr<SymbolTableCollection>
tables;
77static std::optional<firrtl::FModuleOp>
80 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
81 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
83 return moduleOp ? std::optional(moduleOp) : std::nullopt;
94 module->walk([&](Operation *op) {
96 if (
auto instOp = dyn_cast<firrtl::InstanceOp>(op))
110 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
111 auto *op = use.getOwner();
112 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
114 if (use.getOperandNumber() != 0)
116 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
133 unsigned numRemoved = 0;
135 SymbolTableCollection symbolTables;
136 for (Operation &rootOp : *
module.getBody()) {
137 if (!isa<firrtl::CircuitOp>(&rootOp))
139 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
140 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
142 if (
auto *op = symbolTable.lookup(sym)) {
143 if (symbolUserMap.useEmpty(op)) {
152 if (numRemoved > 0 || numLost > 0) {
153 llvm::dbgs() <<
"Removed " << numRemoved <<
" NLAs";
155 llvm::dbgs() <<
" (" << numLost <<
" no longer there)";
156 llvm::dbgs() <<
"\n";
165 if (
auto dict = dyn_cast<DictionaryAttr>(anno)) {
166 if (
auto field = dict.getAs<FlatSymbolRefAttr>(
"circt.nonlocal"))
167 nlasToRemove.insert(field.getAttr());
168 for (
auto namedAttr : dict)
169 markNLAsInAnnotation(namedAttr.getValue());
170 }
else if (
auto array = dyn_cast<ArrayAttr>(anno)) {
171 for (
auto attr : array)
172 markNLAsInAnnotation(attr);
180 op->walk([&](Operation *op) {
181 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
182 markNLAsInAnnotation(annos);
197struct FIRRTLModuleExternalizer :
public OpReduction<FModuleOp> {
204 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
206 uint64_t
match(FModuleOp module)
override {
207 if (innerSymUses.hasInnerRef(module))
209 return moduleSizes.getModuleSize(module, symbols);
212 LogicalResult
rewrite(FModuleOp module)
override {
215 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
216 for (
auto attr :
module.getPortTypes()) {
217 auto type = cast<TypeAttr>(attr).getValue();
218 if (
auto refType = type_dyn_cast<RefType>(type))
219 if (
auto layer = refType.getLayer())
220 layers.insert(layer);
222 SmallVector<Attribute, 4> layersArray;
223 layersArray.reserve(layers.size());
224 for (
auto layer : layers)
225 layersArray.push_back(layer);
227 nlaRemover.markNLAsInOperation(module);
228 OpBuilder builder(module);
229 auto extmodule = FExtModuleOp::create(
230 builder, module->getLoc(),
231 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
232 module.getConventionAttr(), module.getPorts(),
233 builder.getArrayAttr(layersArray), StringRef(),
234 module.getAnnotationsAttr());
235 SymbolTable::setSymbolVisibility(extmodule,
236 SymbolTable::getSymbolVisibility(module));
241 std::string
getName()
const override {
return "firrtl-module-externalizer"; }
253static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
256 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
261 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
262 for (
auto element :
llvm::enumerate(bundleType.getElements())) {
264 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
265 invalidateOutputs(builder, subfield, invalidCache,
266 flip ^ element.value().isFlip);
267 if (subfield.use_empty())
268 subfield.getDefiningOp()->erase();
274 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
275 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
276 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
277 invalidateOutputs(builder, subindex, invalidCache,
flip);
278 if (subindex.use_empty())
279 subindex.getDefiningOp()->erase();
287 Value invalid = invalidCache.lookup(type);
289 invalid = firrtl::InvalidValueOp::create(builder, type);
290 invalidCache.insert({type, invalid});
292 firrtl::ConnectOp::create(builder, value, invalid);
296static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
298 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
301 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
302 for (
auto element :
llvm::enumerate(bundleType.getElements()))
303 connectToLeafs(builder,
304 firrtl::SubfieldOp::create(builder, dest, element.index()),
308 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
309 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
310 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
314 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
317 auto destWidth = type.getBitWidthOrSentinel();
318 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
319 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
320 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
321 if (!isa<firrtl::UIntType>(type)) {
322 if (isa<firrtl::SIntType>(type))
323 value = firrtl::AsSIntPrimOp::create(builder, value);
327 firrtl::ConnectOp::create(builder, dest, value);
331static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
332 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
335 if (
auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
336 for (
auto element :
llvm::enumerate(bundleType.getElements()))
339 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
342 if (
auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
343 for (
unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
344 reduceXor(builder, into,
345 builder.createOrFold<firrtl::SubindexOp>(value, i));
348 if (!isa<firrtl::UIntType>(type)) {
349 if (isa<firrtl::SIntType>(type))
350 value = firrtl::AsUIntPrimOp::create(builder, value);
354 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
360struct InstanceStubber :
public OpReduction<firrtl::InstanceOp> {
363 erasedModules.clear();
371 SmallVector<Operation *> worklist;
372 auto deadInsts = erasedInsts;
373 for (
auto *op : erasedModules)
374 worklist.push_back(op);
375 while (!worklist.empty()) {
376 auto *op = worklist.pop_back_val();
377 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
378 op->walk([&](firrtl::InstanceOp instOp) {
379 auto moduleOp = cast<firrtl::FModuleLike>(
380 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
381 deadInsts.insert(instOp);
383 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
384 [&](Operation *user) { return deadInsts.contains(user); })) {
385 LLVM_DEBUG(llvm::dbgs() <<
"- Removing transitively unused module `"
386 << moduleOp.getModuleName() <<
"`\n");
387 erasedModules.insert(moduleOp);
388 worklist.push_back(moduleOp);
393 for (
auto *op : erasedInsts)
395 for (
auto *op : erasedModules)
397 nlaRemover.remove(op);
400 uint64_t
match(firrtl::InstanceOp instOp)
override {
402 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
406 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
407 LLVM_DEBUG(llvm::dbgs()
408 <<
"Stubbing instance `" << instOp.getName() <<
"`\n");
409 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
411 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
412 auto result = instOp.getResult(i);
413 auto name = builder.getStringAttr(Twine(instOp.getName()) +
"_" +
414 instOp.getPortNameStr(i));
416 firrtl::WireOp::create(builder, result.getType(), name,
417 firrtl::NameKindEnum::DroppableName,
418 instOp.getPortAnnotation(i), StringAttr{})
420 invalidateOutputs(builder, wire, invalidCache,
421 instOp.getPortDirection(i) == firrtl::Direction::In);
422 result.replaceAllUsesWith(wire);
424 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
425 auto moduleOp = cast<firrtl::FModuleLike>(
426 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
427 nlaRemover.markNLAsInOperation(instOp);
428 erasedInsts.insert(instOp);
430 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
431 [&](Operation *user) { return erasedInsts.contains(user); })) {
432 LLVM_DEBUG(llvm::dbgs() <<
"- Removing now unused module `"
433 << moduleOp.getModuleName() <<
"`\n");
434 erasedModules.insert(moduleOp);
439 std::string
getName()
const override {
return "instance-stubber"; }
444 llvm::DenseSet<Operation *> erasedInsts;
445 llvm::DenseSet<Operation *> erasedModules;
451struct MemoryStubber :
public OpReduction<firrtl::MemOp> {
452 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
453 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
454 LogicalResult
rewrite(firrtl::MemOp memOp)
override {
455 LLVM_DEBUG(llvm::dbgs() <<
"Stubbing memory `" << memOp.getName() <<
"`\n");
456 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
459 SmallVector<Value> outputs;
460 for (
unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
461 auto result = memOp.getResult(i);
462 auto name = builder.getStringAttr(Twine(memOp.getName()) +
"_" +
463 memOp.getPortNameStr(i));
465 firrtl::WireOp::create(builder, result.getType(), name,
466 firrtl::NameKindEnum::DroppableName,
467 memOp.getPortAnnotation(i), StringAttr{})
469 invalidateOutputs(builder, wire, invalidCache,
true);
470 result.replaceAllUsesWith(wire);
474 switch (memOp.getPortKind(i)) {
475 case firrtl::MemOp::PortKind::Read:
476 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
478 case firrtl::MemOp::PortKind::Write:
479 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
481 case firrtl::MemOp::PortKind::ReadWrite:
482 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
483 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
485 case firrtl::MemOp::PortKind::Debug:
490 if (!isa<firrtl::RefType>(result.getType())) {
493 cast<firrtl::BundleType>(wire.getType()).getNumElements();
494 for (
unsigned i = 0; i != numFields; ++i) {
495 if (i != 2 && i != 3 && i != 5)
496 reduceXor(builder, xorInputs,
497 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
500 reduceXor(builder, xorInputs, input);
505 outputs.push_back(output);
509 for (
auto output : outputs)
510 connectToLeafs(builder, output, xorInputs);
512 nlaRemover.markNLAsInOperation(memOp);
516 std::string
getName()
const override {
return "memory-stubber"; }
523static bool isFlowSensitiveOp(Operation *op) {
524 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
525 firrtl::InstanceOp, firrtl::SubfieldOp, firrtl::SubindexOp,
526 firrtl::SubaccessOp>(op);
532template <
unsigned OpNum>
533struct FIRRTLOperandForwarder :
public Reduction {
534 uint64_t
match(Operation *op)
override {
535 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
537 if (isFlowSensitiveOp(op))
540 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
542 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
543 return resultTy && opTy &&
544 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
545 (resultTy.getBitWidthOrSentinel() == -1) ==
546 (opTy.getBitWidthOrSentinel() == -1) &&
547 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
549 LogicalResult
rewrite(Operation *op)
override {
551 ImplicitLocOpBuilder builder(op->getLoc(), op);
552 auto result = op->getResult(0);
553 auto operand = op->getOperand(OpNum);
554 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
555 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
556 auto resultWidth = resultTy.getBitWidthOrSentinel();
557 auto operandWidth = operandTy.getBitWidthOrSentinel();
559 if (resultWidth < operandWidth)
561 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
562 else if (resultWidth > operandWidth)
563 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
566 LLVM_DEBUG(llvm::dbgs() <<
"Forwarding " << newOp <<
" in " << *op <<
"\n");
567 result.replaceAllUsesWith(newOp);
571 std::string
getName()
const override {
572 return (
"firrtl-operand" + Twine(OpNum) +
"-forwarder").str();
578struct FIRRTLConstantifier :
public Reduction {
579 uint64_t
match(Operation *op)
override {
580 if (op->hasTrait<OpTrait::ConstantLike>())
582 if (op->getNumResults() != 1 || op->getNumOperands() == 0)
584 if (op->hasAttr(
"inner_sym"))
586 if (isFlowSensitiveOp(op))
588 auto type = dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
589 return isa_and_nonnull<firrtl::UIntType, firrtl::SIntType>(type);
591 LogicalResult
rewrite(Operation *op)
override {
593 OpBuilder builder(op);
594 auto type = cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
595 auto width = type.getBitWidthOrSentinel();
599 firrtl::ConstantOp::create(builder, op->getLoc(), type,
600 APSInt(width, isa<firrtl::UIntType>(type)));
601 op->replaceAllUsesWith(newOp);
605 std::string
getName()
const override {
return "firrtl-constantifier"; }
612struct ConnectInvalidator :
public Reduction {
613 uint64_t
match(Operation *op)
override {
614 if (!isa<FConnectLike>(op))
616 if (
auto *srcOp = op->getOperand(1).getDefiningOp())
617 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
618 isa<InvalidValueOp>(srcOp))
620 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
621 return type && type.isPassive();
623 LogicalResult
rewrite(Operation *op)
override {
625 auto rhs = op->getOperand(1);
626 OpBuilder builder(op);
627 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
628 auto *rhsOp = rhs.getDefiningOp();
629 op->setOperand(1, invOp);
634 std::string
getName()
const override {
return "connect-invalidator"; }
641struct AnnotationRemover :
public Reduction {
642 void beforeReduction(mlir::ModuleOp op)
override { nlaRemover.clear(); }
643 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
646 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
647 uint64_t matchId = 0;
650 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations"))
651 for (
unsigned i = 0; i < annos.size(); ++i)
652 addMatch(1, matchId++);
655 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations"))
656 for (
auto portAnnoArray : portAnnos)
657 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
658 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
659 addMatch(1, matchId++);
663 ArrayRef<uint64_t> matches)
override {
665 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
668 uint64_t matchId = 0;
669 auto processAnnotations =
670 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
671 SmallVector<Attribute> newAnnotations;
672 for (
auto anno : annotations) {
673 if (!matchesSet.contains(matchId)) {
674 newAnnotations.push_back(anno);
677 nlaRemover.markNLAsInAnnotation(anno);
681 return ArrayAttr::get(op->getContext(), newAnnotations);
685 if (
auto annos = op->getAttrOfType<ArrayAttr>(
"annotations")) {
686 op->setAttr(
"annotations", processAnnotations(annos.getValue()));
690 if (
auto portAnnos = op->getAttrOfType<ArrayAttr>(
"portAnnotations")) {
691 SmallVector<Attribute> newPortAnnos;
692 for (
auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
693 newPortAnnos.push_back(
694 processAnnotations(portAnnoArrayAttr.getValue()));
696 op->setAttr(
"portAnnotations",
697 ArrayAttr::get(op->getContext(), newPortAnnos));
703 std::string
getName()
const override {
return "annotation-remover"; }
709struct RootPortPruner :
public OpReduction<firrtl::FModuleOp> {
710 uint64_t
match(firrtl::FModuleOp module)
override {
711 auto circuit =
module->getParentOfType<firrtl::CircuitOp>();
714 return circuit.getNameAttr() ==
module.getNameAttr();
716 LogicalResult
rewrite(firrtl::FModuleOp module)
override {
718 size_t numPorts =
module.getNumPorts();
719 llvm::BitVector dropPorts(numPorts);
720 for (
unsigned i = 0; i != numPorts; ++i) {
724 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
728 module.erasePorts(dropPorts);
731 std::string
getName()
const override {
return "root-port-pruner"; }
736struct ExtmoduleInstanceRemover :
public OpReduction<firrtl::InstanceOp> {
741 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
743 uint64_t
match(firrtl::InstanceOp instOp)
override {
744 return isa<firrtl::FExtModuleOp>(
745 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
747 LogicalResult
rewrite(firrtl::InstanceOp instOp)
override {
749 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
750 symbols.getNearestSymbolTable(instOp)))
752 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
753 SmallVector<Value> replacementWires;
755 auto wire = firrtl::WireOp::create(
757 (Twine(instOp.getName()) +
"_" +
info.getName()).str())
759 if (
info.isOutput()) {
760 auto inv = firrtl::InvalidValueOp::create(builder,
info.type);
761 firrtl::ConnectOp::create(builder, wire, inv);
763 replacementWires.push_back(wire);
765 nlaRemover.markNLAsInOperation(instOp);
766 instOp.replaceAllUsesWith(std::move(replacementWires));
770 std::string
getName()
const override {
return "extmodule-instance-remover"; }
778struct ConnectForwarder :
public Reduction {
779 uint64_t
match(Operation *op)
override {
780 if (!isa<firrtl::FConnectLike>(op))
782 auto dest = op->getOperand(0);
783 auto src = op->getOperand(1);
784 auto *destOp = dest.getDefiningOp();
785 auto *srcOp = src.getDefiningOp();
791 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
797 unsigned numConnects = 0;
798 for (
auto &use : dest.getUses()) {
799 auto *op = use.getOwner();
800 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
801 if (++numConnects > 1)
805 if (srcOp && !srcOp->isBeforeInBlock(op))
812 LogicalResult
rewrite(Operation *op)
override {
813 auto dst = op->getOperand(0);
814 auto src = op->getOperand(1);
815 dst.replaceAllUsesWith(src);
817 if (
auto *dstOp = dst.getDefiningOp())
819 if (
auto *srcOp = src.getDefiningOp())
824 std::string
getName()
const override {
return "connect-forwarder"; }
829template <
unsigned OpNum>
830struct ConnectSourceOperandForwarder :
public Reduction {
831 uint64_t
match(Operation *op)
override {
832 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
834 auto dest = op->getOperand(0);
835 auto *destOp = dest.getDefiningOp();
838 if (!destOp || !destOp->hasOneUse() ||
839 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
842 auto *srcOp = op->getOperand(1).getDefiningOp();
843 if (!srcOp || OpNum >= srcOp->getNumOperands())
846 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
848 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
850 return resultTy && opTy &&
851 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
852 ((resultTy.getBitWidthOrSentinel() == -1) ==
853 (opTy.getBitWidthOrSentinel() == -1)) &&
854 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
857 LogicalResult
rewrite(Operation *op)
override {
858 auto *destOp = op->getOperand(0).getDefiningOp();
859 auto *srcOp = op->getOperand(1).getDefiningOp();
860 auto forwardedOperand = srcOp->getOperand(OpNum);
861 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
863 if (
auto wire = dyn_cast<firrtl::WireOp>(destOp))
864 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
868 auto regName = destOp->getAttrOfType<StringAttr>(
"name");
871 auto clock = destOp->getOperand(0);
872 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
873 clock, regName ? regName.str() :
"")
878 builder.setInsertionPointAfter(op);
879 if (isa<firrtl::ConnectOp>(op))
880 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
882 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
892 std::string
getName()
const override {
893 return (
"connect-source-operand-" + Twine(OpNum) +
"-forwarder").str();
900struct DetachSubaccesses :
public Reduction {
901 void beforeReduction(mlir::ModuleOp op)
override { opsToErase.clear(); }
903 for (
auto *op : opsToErase)
904 op->dropAllReferences();
905 for (
auto *op : opsToErase)
908 uint64_t
match(Operation *op)
override {
911 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
912 llvm::all_of(op->getUses(), [](
auto &use) {
913 return use.getOperandNumber() == 0 &&
914 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
915 firrtl::SubaccessOp>(use.getOwner());
918 LogicalResult
rewrite(Operation *op)
override {
920 OpBuilder builder(op);
921 bool isWire = isa<firrtl::WireOp>(op);
924 invalidClock = firrtl::InvalidValueOp::create(
925 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
926 for (Operation *user :
llvm::make_early_inc_range(op->getUsers())) {
927 builder.setInsertionPoint(user);
928 auto type = user->getResult(0).getType();
931 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
934 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
935 user->replaceAllUsesWith(replOp);
936 opsToErase.insert(user);
938 opsToErase.insert(op);
941 std::string
getName()
const override {
return "detach-subaccesses"; }
942 llvm::DenseSet<Operation *> opsToErase;
948struct NodeSymbolRemover :
public Reduction {
953 uint64_t
match(Operation *op)
override {
955 auto sym = op->getAttrOfType<hw::InnerSymAttr>(
"inner_sym");
956 if (!sym || sym.empty())
960 if (innerSymUses.hasInnerRef(op))
965 LogicalResult
rewrite(Operation *op)
override {
966 op->removeAttr(
"inner_sym");
970 std::string
getName()
const override {
return "node-symbol-remover"; }
977struct EagerInliner :
public OpReduction<InstanceOp> {
981 nlaTable = std::make_unique<NLATable>(op);
982 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
985 nlaRemover.remove(op);
987 innerSymTables.reset();
990 uint64_t
match(InstanceOp instOp)
override {
991 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
993 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
996 if (!isa<FModuleOp>(moduleOp))
1000 DenseSet<hw::HierPathOp> nlas;
1001 nlaTable->getInstanceNLAs(instOp, nlas);
1007 auto referencedModule = cast<FModuleOp>(moduleOp);
1008 auto parentModule = instOp->getParentOfType<FModuleOp>();
1009 if (hasInnerSymbolCollisions(referencedModule, parentModule))
1015 LogicalResult
rewrite(InstanceOp instOp)
override {
1016 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1017 auto moduleOp = cast<FModuleOp>(
1018 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1020 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1021 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1024 IRRewriter rewriter(instOp);
1025 SmallVector<Value> argWires;
1026 for (
unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1027 auto result = instOp.getResult(i);
1028 auto name = rewriter.getStringAttr(Twine(instOp.getName()) +
"_" +
1029 instOp.getPortNameStr(i));
1030 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1031 name, NameKindEnum::DroppableName,
1032 instOp.getPortAnnotation(i), StringAttr{})
1034 result.replaceAllUsesWith(wire);
1035 argWires.push_back(wire);
1039 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1043 nlaRemover.markNLAsInOperation(instOp);
1045 nlaRemover.markNLAsInOperation(moduleOp);
1048 clonedModuleOp.erase();
1054 bool hasInnerSymbolCollisions(FModuleOp referencedModule,
1055 FModuleOp parentModule) {
1057 auto &targetTable = innerSymTables->getInnerSymbolTable(referencedModule);
1058 auto &parentTable = innerSymTables->getInnerSymbolTable(parentModule);
1063 LogicalResult walkResult = targetTable.walkSymbols(
1064 [&](StringAttr name,
1067 if (parentTable.lookup(name)) {
1075 return failed(walkResult);
1078 std::string
getName()
const override {
return "firrtl-eager-inliner"; }
1083 std::unique_ptr<NLATable> nlaTable;
1084 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1098struct ModuleInternalNameSanitizer :
public Reduction {
1099 uint64_t
match(Operation *op)
override {
1101 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1102 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1103 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1104 firrtl::CoverOp>(op);
1106 LogicalResult
rewrite(Operation *op)
override {
1107 TypeSwitch<Operation *, void>(op)
1108 .Case<firrtl::WireOp>([](
auto op) { op.setName(
"wire"); })
1109 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1110 [](
auto op) { op.setName(
"reg"); })
1111 .Case<firrtl::NodeOp>([](
auto op) { op.setName(
"node"); })
1112 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1113 [](
auto op) { op.setName(
"mem"); })
1114 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](
auto op) {
1115 op->setAttr(
"message", StringAttr::get(op.getContext(),
""));
1116 op->setAttr(
"name", StringAttr::get(op.getContext(),
""));
1121 std::string
getName()
const override {
1122 return "module-internal-name-sanitizer";
1127 bool isOneShot()
const override {
return true; }
1141struct ModuleNameSanitizer :
OpReduction<firrtl::CircuitOp> {
1143 const char *names[48] = {
1144 "Foo",
"Bar",
"Baz",
"Qux",
"Quux",
"Quuux",
"Quuuux",
1145 "Quz",
"Corge",
"Grault",
"Bazola",
"Ztesch",
"Thud",
"Grunt",
1146 "Bletch",
"Fum",
"Fred",
"Jim",
"Sheila",
"Barney",
"Flarp",
1147 "Zxc",
"Spqr",
"Wombat",
"Shme",
"Bongo",
"Spam",
"Eggs",
1148 "Snork",
"Zot",
"Blarg",
"Wibble",
"Toto",
"Titi",
"Tata",
1149 "Tutu",
"Pippo",
"Pluto",
"Paperino",
"Aap",
"Noot",
"Mies",
1150 "Oogle",
"Foogle",
"Boogle",
"Zork",
"Gork",
"Bork"};
1152 size_t nameIndex = 0;
1155 if (nameIndex >= 48)
1157 return names[nameIndex++];
1160 size_t portNameIndex = 0;
1162 char getPortName() {
1163 if (portNameIndex >= 26)
1165 return 'a' + portNameIndex++;
1170 LogicalResult
rewrite(firrtl::CircuitOp circuitOp)
override {
1174 auto *circuitName =
getName();
1175 iGraph.getTopLevelModule().setName(circuitName);
1176 circuitOp.setName(circuitName);
1178 for (
auto *node : iGraph) {
1179 auto module = node->getModule<firrtl::FModuleLike>();
1181 bool shouldReplacePorts =
false;
1182 SmallVector<Attribute> newNames;
1183 if (
auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1188 auto oldPorts = fmodule.getPorts();
1189 shouldReplacePorts = !oldPorts.empty();
1190 for (
unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1191 auto port = oldPorts[i];
1193 .
Case<firrtl::ClockType>(
1194 [&](
auto a) {
return ns.
newName(
"clk"); })
1195 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1196 [&](
auto a) {
return ns.
newName(
"rst"); })
1197 .Case<firrtl::RefType>(
1198 [&](
auto a) {
return ns.
newName(
"ref"); })
1199 .Default([&](
auto a) {
1200 return ns.
newName(Twine(getPortName()));
1202 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1204 fmodule->setAttr(
"portNames",
1205 ArrayAttr::get(fmodule.getContext(), newNames));
1208 if (module == iGraph.getTopLevelModule())
1210 auto newName = StringAttr::get(circuitOp.getContext(),
getName());
1211 module.setName(newName);
1212 for (
auto *use : node->uses()) {
1213 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1214 instanceOp.setModuleName(newName);
1215 instanceOp.setName(newName);
1216 if (shouldReplacePorts)
1217 instanceOp.setPortNamesAttr(
1218 ArrayAttr::get(circuitOp.getContext(), newNames));
1227 std::string
getName()
const override {
return "module-name-sanitizer"; }
1231 bool isOneShot()
const override {
return true; }
1250struct ModuleSwapper :
public OpReduction<InstanceOp> {
1252 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1253 struct CircuitState {
1254 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1255 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1256 std::unique_ptr<NLATable> nlaTable;
1262 moduleSizes.clear();
1263 circuitStates.clear();
1266 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1267 auto &state = circuitStates[circuitOp];
1268 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1269 buildModuleTypeGroups(circuitOp, state);
1270 return WalkResult::skip();
1273 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1279 PortSignature getModulePortSignature(FModuleLike module) {
1280 PortSignature signature;
1281 signature.reserve(module.getNumPorts());
1282 for (
unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1283 signature.emplace_back(module.getPortType(i),
module.getPortDirection(i));
1288 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1290 for (
auto module : circuitOp.
getBodyBlock()->getOps<FModuleLike>()) {
1291 auto signature = getModulePortSignature(module);
1292 state.moduleTypeGroups[signature].push_back(module);
1296 for (
auto &[signature, modules] : state.moduleTypeGroups) {
1297 if (modules.size() <= 1)
1300 FModuleLike smallestModule =
nullptr;
1301 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1303 for (
auto module : modules) {
1304 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1305 if (size < smallestSize) {
1306 smallestSize = size;
1307 smallestModule =
module;
1312 for (
auto module : modules) {
1313 if (module != smallestModule) {
1314 state.instanceToCanonicalModule[
module.getModuleNameAttr()] =
1321 uint64_t
match(InstanceOp instOp)
override {
1323 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1325 const auto &state = circuitStates.at(circuitOp);
1328 DenseSet<hw::HierPathOp> nlas;
1329 state.nlaTable->getInstanceNLAs(instOp, nlas);
1334 auto moduleName = instOp.getModuleNameAttr().getAttr();
1335 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1336 if (!canonicalModule)
1340 auto currentModule = cast<FModuleLike>(
1341 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1342 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1343 uint64_t canonicalSize =
1344 moduleSizes.getModuleSize(canonicalModule, symbols);
1345 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1348 LogicalResult
rewrite(InstanceOp instOp)
override {
1350 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1352 const auto &state = circuitStates.at(circuitOp);
1355 auto canonicalModule = state.instanceToCanonicalModule.at(
1356 instOp.getModuleNameAttr().getAttr());
1357 auto canonicalName = canonicalModule.getModuleNameAttr();
1358 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1361 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1366 std::string
getName()
const override {
return "firrtl-module-swapper"; }
1375 DenseMap<CircuitOp, CircuitState> circuitStates;
1393struct ForceDedup :
public OpReduction<CircuitOp> {
1398 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1401 void matches(CircuitOp circuitOp,
1402 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1404 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1408 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1414 uint64_t benefit = modulesAttr.size();
1415 addMatch(benefit, annoIdx);
1420 ArrayRef<uint64_t> matches)
override {
1421 auto *context = circuitOp->getContext();
1425 SmallVector<Annotation> newAnnotations;
1427 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1429 if (!llvm::is_contained(matches, annoIdx)) {
1430 newAnnotations.push_back(anno);
1433 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1435 modulesAttr.size() >= 2);
1438 SmallVector<StringAttr> moduleNames;
1439 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1441 auto refStr = moduleRef.getValue();
1442 auto pipePos = refStr.find(
'|');
1443 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1444 auto moduleName = refStr.substr(pipePos + 1);
1445 moduleNames.push_back(StringAttr::get(context, moduleName));
1450 if (moduleNames.size() < 2)
1455 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1456 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1458 if (newAnnotations.size() == annotations.size())
1463 newAnnoSet.applyToOperation(circuitOp);
1467 std::string
getName()
const override {
return "firrtl-force-dedup"; }
1473 void replaceModuleReferences(CircuitOp circuitOp,
1474 ArrayRef<StringAttr> moduleNames,
1477 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1478 auto &symbolTable = symbols.getSymbolTable(tableOp);
1479 auto *context = circuitOp->getContext();
1483 FModuleLike canonicalModule;
1484 SmallVector<FModuleLike> modulesToReplace;
1485 for (
auto name : moduleNames) {
1486 if (
auto mod = symbolTable.lookup<FModuleLike>(name)) {
1487 if (!canonicalModule)
1488 canonicalModule = mod;
1490 modulesToReplace.push_back(mod);
1493 if (modulesToReplace.empty())
1497 auto canonicalName = canonicalModule.getModuleNameAttr();
1498 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1499 circuitOp.walk([&](InstanceOp instOp) {
1500 auto moduleName = instOp.getModuleNameAttr().getAttr();
1501 if (llvm::is_contained(moduleNames, moduleName) &&
1502 moduleName != canonicalName) {
1503 instOp.setModuleNameAttr(canonicalRef);
1504 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1510 for (
auto oldMod : modulesToReplace) {
1511 SmallVector<hw::HierPathOp> nlaOps(
1512 nlaTable.
lookup(oldMod.getModuleNameAttr()));
1513 for (
auto nlaOp : nlaOps) {
1514 nlaTable.
erase(nlaOp);
1515 StringAttr oldModName = oldMod.getModuleNameAttr();
1516 StringAttr newModName = canonicalName;
1517 SmallVector<Attribute, 4> newPath;
1518 for (
auto nameRef : nlaOp.getNamepath()) {
1519 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1520 if (ref.getModule() == oldModName) {
1521 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1522 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1523 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1524 if (oldInst && newInst) {
1525 oldModName = oldInst.getReferencedModuleNameAttr();
1526 newModName = newInst.getReferencedModuleNameAttr();
1529 newPath.push_back(ref);
1530 }
else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1531 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1533 newPath.push_back(nameRef);
1536 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1542 for (
auto module : modulesToReplace) {
1543 nlaRemover.markNLAsInOperation(module);
1568struct MustDedupChildren :
public OpReduction<CircuitOp> {
1573 void afterReduction(mlir::ModuleOp op)
override { nlaRemover.remove(op); }
1577 void matches(CircuitOp circuitOp,
1578 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
1580 uint64_t matchId = 0;
1582 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1586 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1587 if (!modulesAttr || modulesAttr.size() < 2)
1591 processInstanceGroups(
1592 circuitOp, modulesAttr,
1593 [&](ArrayRef<FInstanceLike>) { addMatch(1, matchId++); });
1598 ArrayRef<uint64_t> matches)
override {
1599 auto *context = circuitOp->getContext();
1601 SmallVector<Annotation> newAnnotations;
1602 uint64_t matchId = 0;
1604 for (
auto [annoIdx, anno] :
llvm::enumerate(annotations)) {
1606 newAnnotations.push_back(anno);
1610 auto modulesAttr = anno.getMember<ArrayAttr>(
"modules");
1611 if (!modulesAttr || modulesAttr.size() < 2) {
1612 newAnnotations.push_back(anno);
1617 bool anyMatchSelected =
false;
1618 processInstanceGroups(
1619 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1621 if (!llvm::is_contained(matches, matchId++))
1623 anyMatchSelected =
true;
1626 SmallSetVector<StringAttr, 4> moduleTargets;
1627 for (
auto instOp : instanceGroup) {
1629 target.circuit = circuitOp.getName();
1630 target.module = instOp.getReferencedModuleName();
1631 moduleTargets.insert(target.toStringAttr(context));
1633 if (moduleTargets.size() < 2)
1637 SmallVector<NamedAttribute> newAnnoAttrs;
1638 newAnnoAttrs.emplace_back(
1639 StringAttr::get(context,
"class"),
1641 newAnnoAttrs.emplace_back(
1642 StringAttr::get(context,
"modules"),
1643 ArrayAttr::get(context,
1644 SmallVector<Attribute>(moduleTargets.begin(),
1645 moduleTargets.end())));
1647 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
1648 newAnnotations.emplace_back(newAnnoDict);
1654 if (anyMatchSelected)
1655 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1657 newAnnotations.push_back(anno);
1662 newAnnoSet.applyToOperation(circuitOp);
1666 std::string
getName()
const override {
return "must-dedup-children"; }
1674 void processInstanceGroups(
1675 CircuitOp circuitOp, ArrayAttr modulesAttr,
1676 llvm::function_ref<
void(ArrayRef<FInstanceLike>)> callback) {
1677 auto &symbolTable = symbols.getSymbolTable(circuitOp);
1680 SmallVector<FModuleLike> modules;
1681 for (
auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1683 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
1684 modules.push_back(mod);
1687 if (modules.size() < 2)
1694 struct InstanceGroup {
1695 SmallVector<FInstanceLike> instances;
1696 bool nameIsUnique =
true;
1698 MapVector<StringAttr, InstanceGroup> instanceGroups;
1699 for (
auto module : modules) {
1701 module.walk([&](FInstanceLike instOp) {
1702 auto name = instOp.getInstanceNameAttr();
1703 auto &group = instanceGroups[name];
1704 if (nameCounts[name]++ > 1)
1705 group.nameIsUnique =
false;
1706 group.instances.push_back(instOp);
1712 for (
auto &[name, group] : instanceGroups)
1713 if (group.nameIsUnique && group.instances.size() >= 2)
1714 callback(group.instances);
1734 patterns.add<AnnotationRemover, 33>();
1737 patterns.add<MustDedupChildren, 30>();
1743 firrtl::createLowerCHIRRTLPass(),
true,
true);
1748 patterns.add<FIRRTLModuleExternalizer, 25>();
1749 patterns.add<InstanceStubber, 24>();
1753 firrtl::createLowerFIRRTLTypes(),
true,
true);
1760 firrtl::createRemoveUnusedPorts({
true}));
1761 patterns.add<NodeSymbolRemover, 15>();
1762 patterns.add<ConnectForwarder, 14>();
1763 patterns.add<ConnectInvalidator, 13>();
1764 patterns.add<FIRRTLConstantifier, 12>();
1765 patterns.add<FIRRTLOperandForwarder<0>, 11>();
1766 patterns.add<FIRRTLOperandForwarder<1>, 10>();
1767 patterns.add<FIRRTLOperandForwarder<2>, 9>();
1768 patterns.add<DetachSubaccesses, 7>();
1770 patterns.add<ExtmoduleInstanceRemover, 4>();
1771 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
1772 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
1773 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
1774 patterns.add<ModuleInternalNameSanitizer, 0>();
1775 patterns.add<ModuleNameSanitizer, 0>();
1779 mlir::DialectRegistry ®istry) {
1780 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
@ None
Don't explicitly preserve any named values.
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
A reduction pattern that applies an mlir::Pass.
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)