15#include "mlir/IR/Iterators.h"
16#include "mlir/IR/Threading.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "mlir/Pass/Pass.h"
19#include "llvm/ADT/BitVector.h"
20#include "llvm/ADT/DenseMapInfoVariant.h"
21#include "llvm/ADT/PostOrderIterator.h"
22#include "llvm/Support/Debug.h"
24#define DEBUG_TYPE "firrtl-imdeadcodeelim"
28#define GEN_PASS_DEF_IMDEADCODEELIM
29#include "circt/Dialect/FIRRTL/Passes.h.inc"
34using namespace firrtl;
38 return !(mlir::isMemoryEffectFree(op) ||
39 mlir::hasSingleEffect<mlir::MemoryEffects::Allocate>(op) ||
40 mlir::hasSingleEffect<mlir::MemoryEffects::Read>(op));
45 return isa<WireOp, RegResetOp, RegOp, NodeOp, MemOp>(op);
50 if (
auto name = dyn_cast<FNamableOp>(op))
51 if (!name.hasDroppableName())
57struct IMDeadCodeElimPass
58 :
public circt::firrtl::impl::IMDeadCodeElimBase<IMDeadCodeElimPass> {
59 void runOnOperation()
override;
61 void rewriteModuleSignature(FModuleOp module);
62 void rewriteModuleBody(FModuleOp module);
63 void eraseEmptyModule(FModuleOp module);
64 void forwardConstantOutputPort(FModuleOp module);
67 bool isKnownAlive(Value value)
const {
68 assert(value &&
"null should not be used");
69 return liveElements.count(value);
73 bool isAssumedDead(Value value)
const {
return !isKnownAlive(value); }
74 bool isAssumedDead(Operation *op)
const {
75 return llvm::none_of(op->getResults(),
76 [&](Value value) { return isKnownAlive(value); });
80 bool isBlockExecutable(Block *block)
const {
81 return executableBlocks.count(block);
84 void visitUser(Operation *op);
85 void visitValue(Value value);
87 void visitConnect(FConnectLike connect);
88 void visitSubelement(Operation *op);
89 void markBlockExecutable(Block *block);
90 void markBlockUndeletable(Operation *op) {
91 markAlive(op->getParentOfType<FModuleOp>());
94 void markDeclaration(Operation *op);
95 void markInstanceOp(InstanceOp instanceOp);
96 void markObjectOp(ObjectOp objectOp);
97 void markUnknownSideEffectOp(Operation *op);
98 void visitInstanceOp(InstanceOp instance);
99 void visitHierPathOp(hw::HierPathOp hierpath);
100 void visitModuleOp(FModuleOp module);
104 DenseSet<Block *> executableBlocks;
110 std::variant<Value, FModuleOp, InstanceOp, hw::HierPathOp>;
112 void markAlive(ElementType element) {
113 if (!liveElements.insert(element).second)
115 worklist.push_back(element);
120 SmallVector<ElementType, 64> worklist;
121 llvm::DenseSet<ElementType> liveElements;
125 DenseMap<InstanceOp, SmallVector<hw::HierPathOp>> instanceToHierPaths;
128 DenseMap<hw::HierPathOp, SetVector<ElementType>> hierPathToElements;
132 mlir::SymbolTable *symbolTable;
136void IMDeadCodeElimPass::visitInstanceOp(InstanceOp instance) {
137 markBlockUndeletable(instance);
139 auto module = instance.getReferencedModule<FModuleOp>(*instanceGraph);
148 for (
auto hierPath : instanceToHierPaths[instance])
152 for (
auto &blockArg : module.getBody().getArguments()) {
153 auto portNo = blockArg.getArgNumber();
154 if (module.getPortDirection(portNo) == Direction::In &&
155 isKnownAlive(module.getArgument(portNo)))
156 markAlive(instance.getResult(portNo));
160void IMDeadCodeElimPass::visitModuleOp(FModuleOp module) {
162 for (
auto *use : instanceGraph->lookup(module)->uses())
163 markAlive(cast<InstanceOp>(*use->getInstance()));
166void IMDeadCodeElimPass::visitHierPathOp(hw::HierPathOp hierPathOp) {
168 for (
auto path : hierPathOp.getNamepathAttr())
169 if (auto innerRef = dyn_cast<
hw::InnerRefAttr>(path)) {
170 auto *op = innerRefNamespace->lookupOp(innerRef);
171 if (
auto instance = dyn_cast_or_null<InstanceOp>(op))
175 for (
auto elem : hierPathToElements[hierPathOp])
179void IMDeadCodeElimPass::markDeclaration(Operation *op) {
182 for (
auto result : op->getResults())
184 markBlockUndeletable(op);
188void IMDeadCodeElimPass::markUnknownSideEffectOp(Operation *op) {
191 for (
auto result : op->getResults())
193 for (
auto operand : op->getOperands())
195 markBlockUndeletable(op);
198void IMDeadCodeElimPass::visitUser(Operation *op) {
199 LLVM_DEBUG(llvm::dbgs() <<
"Visit: " << *op <<
"\n");
200 if (
auto connectOp = dyn_cast<FConnectLike>(op))
201 return visitConnect(connectOp);
202 if (isa<SubfieldOp, SubindexOp, SubaccessOp, ObjectSubfieldOp>(op))
203 return visitSubelement(op);
206void IMDeadCodeElimPass::markInstanceOp(InstanceOp instance) {
208 Operation *op = instance.getReferencedModule(*instanceGraph);
212 if (!isa<FModuleOp>(op)) {
213 auto module = dyn_cast<FModuleLike>(op);
214 for (
auto resultNo :
llvm::
seq(0u, instance.getNumResults())) {
216 if (module.getPortDirection(resultNo) == Direction::Out)
220 markAlive(instance.getResult(resultNo));
228 auto fModule = cast<FModuleOp>(op);
229 markBlockExecutable(fModule.getBodyBlock());
232void IMDeadCodeElimPass::markObjectOp(ObjectOp
object) {
237void IMDeadCodeElimPass::markBlockExecutable(Block *block) {
238 if (!executableBlocks.insert(block).second)
241 auto fmodule = dyn_cast<FModuleOp>(block->getParentOp());
242 if (fmodule && fmodule.isPublic())
246 for (
auto blockArg : block->getArguments())
253 for (
auto &op : *block) {
255 markDeclaration(&op);
256 else if (
auto instance = dyn_cast<InstanceOp>(op))
257 markInstanceOp(instance);
258 else if (
auto object = dyn_cast<ObjectOp>(op))
259 markObjectOp(
object);
260 else if (isa<FConnectLike>(op))
264 markUnknownSideEffectOp(&op);
267 for (
auto &region : op.getRegions())
268 for (auto &block : region.getBlocks())
269 markBlockExecutable(&block);
276void IMDeadCodeElimPass::forwardConstantOutputPort(FModuleOp module) {
278 SmallVector<std::pair<unsigned, APSInt>> constantPortIndicesAndValues;
279 auto ports =
module.getPorts();
280 auto *instanceGraphNode = instanceGraph->
lookup(module);
282 for (
const auto &e :
llvm::enumerate(ports)) {
283 unsigned index = e.index();
284 auto port = e.value();
285 auto arg =
module.getArgument(index);
293 if (
auto constant =
connect.getSrc().getDefiningOp<ConstantOp>())
294 constantPortIndicesAndValues.push_back({index, constant.getValue()});
298 if (constantPortIndicesAndValues.empty())
302 for (
auto *use : instanceGraphNode->uses()) {
303 auto instance = cast<InstanceOp>(*use->getInstance());
304 ImplicitLocOpBuilder builder(instance.getLoc(), instance);
305 for (
auto [index, constant] : constantPortIndicesAndValues) {
306 auto result = instance.getResult(index);
307 assert(ports[index].isOutput() &&
"must be an output port");
310 result.replaceAllUsesWith(ConstantOp::create(builder, constant));
315void IMDeadCodeElimPass::runOnOperation() {
318 auto circuits = getOperation().getOps<CircuitOp>();
319 if (circuits.empty())
322 auto circuit = *circuits.begin();
324 if (!llvm::hasSingleElement(circuits)) {
325 mlir::emitError(circuit.getLoc(),
326 "cannot process multiple circuit operations")
327 .attachNote((*std::next(circuits.begin())).getLoc())
328 <<
"second circuit here";
329 return signalPassFailure();
332 instanceGraph = &getChildAnalysis<InstanceGraph>(circuit);
333 symbolTable = &getChildAnalysis<SymbolTable>(circuit);
334 auto &istc = getChildAnalysis<hw::InnerSymbolTableCollection>(circuit);
337 innerRefNamespace = &theInnerRefNamespace;
340 getOperation().walk([&](Operation *op) {
341 if (isa<FModuleOp>(op))
344 if (
auto hierPath = dyn_cast<hw::HierPathOp>(op)) {
345 auto namePath = hierPath.getNamepath().getValue();
348 if (hierPath.isPublic() || namePath.size() <= 1 ||
349 isa<hw::InnerRefAttr>(namePath.back()))
350 return markAlive(hierPath);
353 dyn_cast_or_null<firrtl::InstanceOp>(innerRefNamespace->lookupOp(
354 cast<hw::InnerRefAttr>(namePath.drop_back().back()))))
355 instanceToHierPaths[instance].push_back(hierPath);
361 op->getAttrDictionary().walk([&](Attribute attr) {
362 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(attr)) {
364 if (
auto instance = dyn_cast_or_null<firrtl::InstanceOp>(
365 innerRefNamespace->lookupOp(innerRef)))
370 if (
auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
371 auto *symbol = symbolTable->lookup(symbolRef.getAttr());
376 if (
auto hierPath = dyn_cast<hw::HierPathOp>(symbol))
380 if (
auto module = dyn_cast<FModuleOp>(symbol)) {
381 if (!isa<firrtl::InstanceOp>(op)) {
382 LLVM_DEBUG(llvm::dbgs()
383 <<
"Unknown use of " << module.getModuleNameAttr()
384 <<
" in " << op->getName() <<
"\n");
386 markBlockExecutable(module.getBodyBlock());
398 SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
400 llvm::post_order(instanceGraph),
401 [](
auto *node) {
return dyn_cast<FModuleOp>(*node->getModule()); }),
402 [](
auto module) {
return module; }));
406 for (
auto module : modules)
407 forwardConstantOutputPort(module);
409 for (
auto module : circuit.
getBodyBlock()->getOps<FModuleOp>()) {
411 if (module.isPublic()) {
412 markBlockExecutable(module.getBodyBlock());
419 auto visitAnnotation = [&](
int portId,
Annotation anno) ->
bool {
420 auto hierPathSym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal");
421 hw::HierPathOp hierPathOp;
424 symbolTable->template lookup<hw::HierPathOp>(hierPathSym.getAttr());
427 markAlive(hierPathOp);
429 markAlive(module.getArgument(portId));
436 module, std::bind(visitAnnotation, -1, std::placeholders::_1));
440 while (!worklist.empty()) {
441 auto v = worklist.pop_back_val();
442 if (
auto *value = std::get_if<Value>(&v))
444 else if (
auto *instance = std::get_if<InstanceOp>(&v))
445 visitInstanceOp(*instance);
446 else if (
auto *hierpath = std::get_if<hw::HierPathOp>(&v))
447 visitHierPathOp(*hierpath);
448 else if (
auto *module = std::get_if<FModuleOp>(&v))
449 visitModuleOp(*module);
453 for (
auto module :
llvm::make_early_inc_range(
455 if (isBlockExecutable(module.getBodyBlock()))
456 rewriteModuleSignature(module);
467 mlir::parallelForEach(circuit.getContext(),
468 circuit.getBodyBlock()->getOps<FModuleOp>(),
469 [&](
auto op) { rewriteModuleBody(op); });
472 for (
auto op :
llvm::make_early_inc_range(
474 if (!liveElements.count(op))
477 for (
auto module : modules)
478 eraseEmptyModule(module);
481 executableBlocks.clear();
482 liveElements.clear();
483 instanceToHierPaths.clear();
484 hierPathToElements.clear();
487void IMDeadCodeElimPass::visitValue(Value value) {
488 assert(isKnownAlive(value) &&
"only alive values reach here");
491 for (Operation *user : value.getUsers())
494 if (
auto blockArg = dyn_cast<BlockArgument>(value)) {
496 dyn_cast<FModuleOp>(blockArg.getParentBlock()->getParentOp())) {
497 auto portDirection =
module.getPortDirection(blockArg.getArgNumber());
501 if (portDirection == Direction::In) {
502 for (
auto *instRec : instanceGraph->lookup(module)->uses()) {
503 auto instance = cast<InstanceOp>(instRec->getInstance());
504 if (liveElements.contains(instance))
505 markAlive(instance.getResult(blockArg.getArgNumber()));
510 if (!type_isa<DomainType>(blockArg.getType()))
511 for (
auto domain : cast<ArrayAttr>(
512 module.getDomainInfoAttrForPort(blockArg.getArgNumber())))
513 markAlive(module.getArgument(
514 cast<IntegerAttr>(domain).getValue().getZExtValue()));
521 if (
auto instance = value.getDefiningOp<InstanceOp>()) {
522 auto instanceResult = cast<mlir::OpResult>(value);
524 auto module = instance.getReferencedModule<FModuleOp>(*instanceGraph);
527 if (!module || module.getPortDirection(instanceResult.getResultNumber()) ==
533 BlockArgument modulePortVal =
534 module.getArgument(instanceResult.getResultNumber());
535 return markAlive(modulePortVal);
539 if (
auto mem = value.getDefiningOp<MemOp>()) {
540 for (
auto port : mem->getResults())
547 if (
auto op = value.getDefiningOp()) {
548 for (
auto operand : op->getOperands())
550 for (
auto &region : op->getRegions())
551 for (auto &block : region)
552 markBlockExecutable(&block);
556 if (
auto fop = value.getDefiningOp<Forceable>();
557 fop && fop.isForceable() &&
558 (fop.getData() == value || fop.getDataRef() == value)) {
559 markAlive(fop.getData());
560 markAlive(fop.getDataRef());
564void IMDeadCodeElimPass::visitConnect(FConnectLike connect) {
566 if (isKnownAlive(
connect.getDest()))
570void IMDeadCodeElimPass::visitSubelement(Operation *op) {
571 if (isKnownAlive(op->getOperand(0)))
572 markAlive(op->getResult(0));
575void IMDeadCodeElimPass::rewriteModuleBody(FModuleOp module) {
576 assert(isBlockExecutable(module.getBodyBlock()) &&
577 "unreachable modules must be already deleted");
579 auto removeDeadNonLocalAnnotations = [&](
int _,
Annotation anno) ->
bool {
580 auto hierPathSym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal");
584 symbolTable->template lookup<hw::HierPathOp>(hierPathSym.getAttr());
585 return !liveElements.count(hierPathOp);
591 std::bind(removeDeadNonLocalAnnotations, -1, std::placeholders::_1));
594 module.walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
597 LLVM_DEBUG(llvm::dbgs() << "Visit: " << *op << "\n");
598 if (
auto connect = dyn_cast<FConnectLike>(op)) {
599 if (isAssumedDead(
connect.getDest())) {
600 LLVM_DEBUG(llvm::dbgs() <<
"DEAD: " << connect <<
"\n";);
610 LLVM_DEBUG(llvm::dbgs() <<
"DEAD: " << *op <<
"\n";);
611 assert(op->use_empty() &&
"users should be already removed");
618 if (mlir::isOpTriviallyDead(op)) {
625void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
626 assert(isBlockExecutable(module.getBodyBlock()) &&
627 "unreachable modules must be already deleted");
629 LLVM_DEBUG(llvm::dbgs() <<
"Prune ports of module: " << module.getName()
632 auto replaceInstanceResultWithWire = [&](ImplicitLocOpBuilder &builder,
634 InstanceOp instance) {
635 auto result = instance.getResult(index);
636 if (isAssumedDead(result)) {
640 mlir::UnrealizedConversionCastOp::create(
641 builder, ArrayRef<Type>{result.getType()}, ArrayRef<Value>{})
643 result.replaceAllUsesWith(wire);
647 Value wire = WireOp::create(builder, result.getType()).getResult();
648 result.replaceAllUsesWith(wire);
652 liveElements.erase(result);
653 liveElements.insert(wire);
657 for (
auto *use :
llvm::make_early_inc_range(instanceGraphNode->uses())) {
658 auto instance = cast<InstanceOp>(*use->getInstance());
659 if (!liveElements.count(instance)) {
661 ImplicitLocOpBuilder builder(instance.getLoc(), instance);
662 for (
auto index :
llvm::
seq(0u, instance.getNumResults()))
663 replaceInstanceResultWithWire(builder, index, instance);
671 if (module.isPublic())
674 unsigned numOldPorts =
module.getNumPorts();
675 llvm::BitVector deadPortIndexes(numOldPorts);
677 ImplicitLocOpBuilder builder(module.getLoc(), module.getContext());
678 builder.setInsertionPointToStart(module.getBodyBlock());
679 auto oldPorts =
module.getPorts();
681 for (
auto index :
llvm::
seq(0u, numOldPorts)) {
682 auto argument =
module.getArgument(index);
684 "If the port has don't touch, it should be known alive");
690 if (isKnownAlive(argument)) {
694 if (module.getPortDirection(index) == Direction::In)
699 if (llvm::any_of(instanceGraph->
lookup(module)->
uses(),
702 record->getInstance()->getResult(index));
708 auto wire = WireOp::create(builder, argument.getType()).getResult();
711 liveElements.erase(argument);
712 liveElements.insert(wire);
713 argument.replaceAllUsesWith(wire);
714 deadPortIndexes.set(index);
721 mlir::UnrealizedConversionCastOp::create(
722 builder, ArrayRef<Type>{argument.getType()}, ArrayRef<Value>{})
725 argument.replaceAllUsesWith(wire);
726 assert(isAssumedDead(wire) &&
"dummy wire must be dead");
727 deadPortIndexes.set(index);
731 if (deadPortIndexes.none())
736 for (
auto arg : module.getArguments())
737 liveElements.erase(arg);
740 module.erasePorts(deadPortIndexes);
743 for (
auto arg : module.getArguments())
744 liveElements.insert(arg);
747 for (
auto *use :
llvm::make_early_inc_range(instanceGraphNode->uses())) {
748 auto instance = cast<InstanceOp>(*use->getInstance());
749 ImplicitLocOpBuilder builder(instance.getLoc(), instance);
751 for (
auto index : deadPortIndexes.set_bits())
752 replaceInstanceResultWithWire(builder, index, instance);
756 for (
auto oldResult : instance.getResults())
757 liveElements.erase(oldResult);
761 instance.cloneWithErasedPortsAndReplaceUses(deadPortIndexes);
764 for (
auto newResult : newInstance.getResults())
765 liveElements.insert(newResult);
768 if (liveElements.contains(instance)) {
769 liveElements.erase(instance);
770 liveElements.insert(newInstance);
776 numRemovedPorts += deadPortIndexes.count();
779void IMDeadCodeElimPass::eraseEmptyModule(FModuleOp module) {
781 if (!module.getBodyBlock()->empty())
785 if (module.isPublic()) {
786 mlir::emitWarning(module.getLoc())
787 <<
"module `" <<
module.getName()
788 << "` is empty but cannot be removed because the module is public";
792 if (!module.getAnnotations().empty()) {
793 module.emitWarning() << "module `" << module.getName()
794 << "` is empty but cannot be removed "
795 "because the module has annotations "
796 << module.getAnnotations();
800 if (!module.getBodyBlock()->args_empty()) {
801 auto diag =
module.emitWarning()
802 << "module `" << module.getName()
803 << "` is empty but cannot be removed because the "
805 llvm::interleaveComma(module.getPortNames(), diag);
806 diag <<
" are referenced by name or dontTouched";
811 LLVM_DEBUG(llvm::dbgs() <<
"Erase " << module.getName() <<
"\n");
814 instanceGraph->
lookup(module.getModuleNameAttr());
816 SmallVector<Location> instancesWithSymbols;
817 for (
auto *use :
llvm::make_early_inc_range(instanceGraphNode->uses())) {
818 auto instance = cast<InstanceOp>(use->getInstance());
819 if (instance.getInnerSym()) {
820 instancesWithSymbols.push_back(instance.getLoc());
828 if (!instancesWithSymbols.empty()) {
829 auto diag =
module.emitWarning()
830 << "module `" << module.getName()
831 << "` is empty but cannot be removed because an instance is "
832 "referenced by name";
833 diag.attachNote(FusedLoc::get(&getContext(), instancesWithSymbols))
834 <<
"these are instances with symbols";
839 if (liveElements.contains(module))
842 instanceGraph->
erase(instanceGraphNode);
assert(baseType &&"element must be base type")
static bool isDeletableDeclaration(Operation *op)
Return true if this is a wire or register we're allowed to delete.
static bool hasUnknownSideEffect(Operation *op)
static bool isDeclaration(Operation *op)
Return true if this is a wire or a register or a node.
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
llvm::iterator_range< UseIterator > uses()
virtual void replaceInstance(InstanceOpInterface inst, InstanceOpInterface newInst)
Replaces an instance of a module with another instance.
virtual void erase(InstanceGraphNode *node)
Remove this module from the instance graph.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
This is an edge in the InstanceGraph.
connect(destination, source)
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This class represents the namespace in which InnerRef's can be resolved.