CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
32
33#define DEBUG_TYPE "firrtl-reductions"
34
35using namespace mlir;
36using namespace circt;
37using namespace firrtl;
38using llvm::MapVector;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
41
42//===----------------------------------------------------------------------===//
43// Utilities
44//===----------------------------------------------------------------------===//
45
46namespace detail {
47/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
48/// which is handy for reductions that need to look up a lot of symbols.
50 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
51
52 SymbolTable &getSymbolTable(Operation *op) {
53 return tables->getSymbolTable(op);
54 }
55 SymbolTable &getNearestSymbolTable(Operation *op) {
56 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
57 }
58
59 SymbolUserMap &getSymbolUserMap(Operation *op) {
60 auto it = userMaps.find(op);
61 if (it != userMaps.end())
62 return it->second;
63 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
64 }
65 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
66 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
67 }
68
69 void clear() {
70 tables = std::make_unique<SymbolTableCollection>();
71 userMaps.clear();
72 }
73
74private:
75 std::unique_ptr<SymbolTableCollection> tables;
77};
78} // namespace detail
79
80/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
81/// optional in case another type of module is instantiated.
82static std::optional<firrtl::FModuleOp>
83findInstantiatedModule(firrtl::InstanceOp instOp,
84 ::detail::SymbolCache &symbols) {
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
87 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
89}
90
91/// Utility to track the transitive size of modules.
93 void clear() { moduleSizes.clear(); }
94
95 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
96 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
97 return it->second;
98 uint64_t size = 1;
99 module->walk([&](Operation *op) {
100 size += 1;
101 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
102 if (auto instModule = findInstantiatedModule(instOp, symbols))
103 size += getModuleSize(*instModule, symbols);
104 });
105 moduleSizes.insert({module, size});
106 return size;
107 }
108
109private:
110 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
111};
112
113/// Check that all connections to a value are invalids.
114static bool onlyInvalidated(Value arg) {
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
118 return false;
119 if (use.getOperandNumber() != 0)
120 return false;
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
122 return false;
123 return true;
124 });
125}
126
127/// A tracker for track NLAs affected by a reduction. Performs the necessary
128/// cleanup steps in order to maintain IR validity after the reduction has
129/// applied. For example, removing an instance that forms part of an NLA path
130/// requires that NLA to be removed as well.
132 /// Clear the set of marked NLAs. Call this before attempting a reduction.
133 void clear() { nlasToRemove.clear(); }
134
135 /// Remove all marked annotations. Call this after applying a reduction in
136 /// order to validate the IR.
137 void remove(mlir::ModuleOp module) {
138 unsigned numRemoved = 0;
139 (void)numRemoved;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
143 continue;
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
146 for (auto sym : nlasToRemove) {
147 if (auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
149 ++numRemoved;
150 op->erase();
151 }
152 }
153 }
154 }
155 LLVM_DEBUG({
156 unsigned numLost = nlasToRemove.size() - numRemoved;
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
159 if (numLost > 0)
160 llvm::dbgs() << " (" << numLost << " no longer there)";
161 llvm::dbgs() << "\n";
162 }
163 });
164 }
165
166 /// Mark all NLAs referenced in the given annotation as to be removed. This
167 /// can be an entire array or dictionary of annotations, and the function will
168 /// descend into child annotations appropriately.
169 void markNLAsInAnnotation(Attribute anno) {
170 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (auto attr : array)
177 markNLAsInAnnotation(attr);
178 }
179 }
180
181 /// Mark all NLAs referenced in an operation. Also traverses all nested
182 /// operations. Call this before removing an operation, to mark any associated
183 /// NLAs as to be removed as well.
184 void markNLAsInOperation(Operation *op) {
185 op->walk([&](Operation *op) {
186 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
187 markNLAsInAnnotation(annos);
188 });
189 }
190
191 /// The set of NLAs to remove, identified by their symbol.
192 llvm::DenseSet<StringAttr> nlasToRemove;
193};
194
195//===----------------------------------------------------------------------===//
196// Reduction patterns
197//===----------------------------------------------------------------------===//
198
199namespace {
200
201/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
202struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
203 void beforeReduction(mlir::ModuleOp op) override {
204 nlaRemover.clear();
205 symbols.clear();
206 moduleSizes.clear();
207 innerSymUses = reduce::InnerSymbolUses(op);
208 }
209 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
210
211 uint64_t match(FModuleOp module) override {
212 if (innerSymUses.hasInnerRef(module))
213 return 0;
214 return moduleSizes.getModuleSize(module, symbols);
215 }
216
217 LogicalResult rewrite(FModuleOp module) override {
218 // Hack up a list of known layers.
219 LayerSet layers;
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (auto attr : module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (auto refType = type_dyn_cast<RefType>(type))
224 if (auto layer = refType.getLayer())
225 layers.insert(layer);
226 }
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (auto layer : layers)
230 layersArray.push_back(layer);
231
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
242 module->erase();
243 return success();
244 }
245
246 std::string getName() const override { return "firrtl-module-externalizer"; }
247
248 ::detail::SymbolCache symbols;
249 NLARemover nlaRemover;
250 reduce::InnerSymbolUses innerSymUses;
251 ModuleSizeCache moduleSizes;
252};
253
254/// Invalidate all the leaf fields of a value with a given flippedness by
255/// connecting an invalid value to them. This function handles different FIRRTL
256/// types appropriately:
257/// - Ref types (probes): Creates wire infrastructure with ref.send/ref.define
258/// and invalidates the underlying wire.
259/// - Bundle/Vector types: Recursively descends into elements.
260/// - Base types: Creates InvalidValueOp and connects it.
261/// - Property types: Creates UnknownValueOp and assigns it.
262///
263/// This is useful for ensuring that all output ports of an instance or memory
264/// (including those nested in bundles) are properly invalidated.
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
266 TieOffCache &tieOffCache, bool flip = false) {
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
268 if (!type)
269 return;
270
271 // Handle ref types (probes) by creating wires and defining them properly.
272 if (auto refType = type_dyn_cast<RefType>(type)) {
273 // Input probes are illegal in FIRRTL.
274 assert(!flip && "input probes are not allowed");
275
276 auto underlyingType = refType.getType();
277
278 if (!refType.getForceable()) {
279 // For probe types: create underlying wire, ref.send, ref.define, and
280 // invalidate.
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
284
285 // Invalidate the underlying wire.
286 auto invalid = tieOffCache.getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
288 return;
289 }
290
291 // For rwprobe types: create forceable wire, ref.define, and invalidate.
292 auto forceableWire =
293 WireOp::create(builder, underlyingType,
294 /*name=*/"", NameKindEnum::DroppableName,
295 /*annotations=*/ArrayRef<Attribute>{},
296 /*innerSym=*/StringAttr{},
297 /*forceable=*/true);
298
299 // The forceable wire returns both the wire and the rwprobe.
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
302
303 builder.create<RefDefineOp>(value, forceableRef);
304
305 // Invalidate the underlying wire.
306 auto invalid = tieOffCache.getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
308 return;
309 }
310
311 // Descend into bundles by creating subfield ops.
312 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (auto element : llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
319 }
320 return;
321 }
322
323 // Descend into vectors by creating subindex ops.
324 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache, flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
330 }
331 return;
332 }
333
334 // Only drive outputs.
335 if (flip)
336 return;
337
338 // Create InvalidValueOp for FIRRTLBaseType.
339 if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
342 return;
343 }
344
345 // For property types, use UnknownValueOp to tie off the connection.
346 if (auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
349 }
350}
351
352/// Connect a value to every leave of a destination value.
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
354 Value value) {
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
356 if (!type)
357 return;
358 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (auto element : llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
362 value);
363 return;
364 }
365 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
368 value);
369 return;
370 }
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
372 if (!valueType)
373 return;
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
381 else
382 return;
383 }
384 firrtl::ConnectOp::create(builder, dest, value);
385}
386
387/// Reduce all leaf fields of a value through an XOR tree.
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
390 if (!type)
391 return;
392 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (auto element : llvm::enumerate(bundleType.getElements()))
394 reduceXor(
395 builder, into,
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
397 return;
398 }
399 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
403 return;
404 }
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
408 else
409 return;
410 }
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
412}
413
414/// A sample reduction pattern that maps `firrtl.instance` to a set of
415/// invalidated wires. This often shortcuts a long iterative process of connect
416/// invalidation, module externalization, and wire stripping
417struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
418 void beforeReduction(mlir::ModuleOp op) override {
419 erasedInsts.clear();
420 erasedModules.clear();
421 symbols.clear();
422 nlaRemover.clear();
423 moduleSizes.clear();
424 }
425 void afterReduction(mlir::ModuleOp op) override {
426 // Look into deleted modules to find additional instances that are no longer
427 // instantiated anywhere.
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
439 if (llvm::all_of(
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
443 << moduleOp.getModuleName() << "`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
446 }
447 });
448 }
449
450 for (auto *op : erasedInsts)
451 op->erase();
452 for (auto *op : erasedModules)
453 op->erase();
454 nlaRemover.remove(op);
455 }
456
457 uint64_t match(firrtl::InstanceOp instOp) override {
458 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
460 return 0;
461 }
462
463 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
464 LLVM_DEBUG(llvm::dbgs()
465 << "Stubbing instance `" << instOp.getName() << "`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
467 TieOffCache tieOffCache(builder);
468 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
471 instOp.getPortName(i));
472 auto wire =
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
476 .getResult();
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
480 }
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
486 if (llvm::all_of(
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
490 << moduleOp.getModuleName() << "`\n");
491 erasedModules.insert(moduleOp);
492 }
493 return success();
494 }
495
496 std::string getName() const override { return "instance-stubber"; }
497 bool acceptSizeIncrease() const override { return true; }
498
499 ::detail::SymbolCache symbols;
500 NLARemover nlaRemover;
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
503 ModuleSizeCache moduleSizes;
504};
505
506/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
507/// wires.
508struct MemoryStubber : public OpReduction<firrtl::MemOp> {
509 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
510 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
511 LogicalResult rewrite(firrtl::MemOp memOp) override {
512 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
514 TieOffCache tieOffCache(builder);
515 Value xorInputs;
516 SmallVector<Value> outputs;
517 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
520 memOp.getPortName(i));
521 auto wire =
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
525 .getResult();
526 invalidateOutputs(builder, wire, tieOffCache, true);
527 result.replaceAllUsesWith(wire);
528
529 // Isolate the input and output data fields of the port.
530 Value input, output;
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
534 break;
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
537 break;
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
541 break;
542 case firrtl::MemOp::PortKind::Debug:
543 output = wire;
544 break;
545 }
546
547 if (!isa<firrtl::RefType>(result.getType())) {
548 // Reduce all input ports to a single one through an XOR tree.
549 unsigned numFields =
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
555 }
556 if (input)
557 reduceXor(builder, xorInputs, input);
558 }
559
560 // Track the output port to hook it up to the XORd input later.
561 if (output)
562 outputs.push_back(output);
563 }
564
565 // Hook up the outputs.
566 for (auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
568
569 nlaRemover.markNLAsInOperation(memOp);
570 memOp->erase();
571 return success();
572 }
573 std::string getName() const override { return "memory-stubber"; }
574 bool acceptSizeIncrease() const override { return true; }
575 NLARemover nlaRemover;
576};
577
578/// Check whether an operation interacts with flows in any way, which can make
579/// replacement and operand forwarding harder in some cases.
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
583}
584
585/// A sample reduction pattern that replaces all uses of an operation with one
586/// of its operands. This can help pruning large parts of the expression tree
587/// rapidly.
588template <unsigned OpNum>
589struct FIRRTLOperandForwarder : public Reduction {
590 uint64_t match(Operation *op) override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
592 return 0;
593 if (isFlowSensitiveOp(op))
594 return 0;
595 auto resultTy =
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
597 auto opTy =
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
604 }
605 LogicalResult rewrite(Operation *op) override {
606 assert(match(op));
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
614 Value newOp;
615 if (resultWidth < operandWidth)
616 newOp =
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
620 else
621 newOp = operand;
622 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
623 result.replaceAllUsesWith(newOp);
624 reduce::pruneUnusedOps(op, *this);
625 return success();
626 }
627 std::string getName() const override {
628 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
629 }
630};
631
632/// A sample reduction pattern that replaces FIRRTL operations with a constant
633/// zero of their type.
634struct Constantifier : public Reduction {
635 void beforeReduction(mlir::ModuleOp op) override {
636 symbols.clear();
637
638 // Find valid dummy classes that we can use for anyref casts.
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
645 }
646 }
647 return WalkResult::skip();
648 });
649 }
650
651 uint64_t match(Operation *op) override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
653 Attribute attr;
654 if (!matchPattern(op, m_Constant(&attr)))
655 return 0;
656 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
658 return 0;
659 if (auto strAttr = dyn_cast<StringAttr>(attr))
660 if (strAttr.empty())
661 return 0;
662 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
664 return 0;
665 }
666 if (auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
668 return 0;
669 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
671 return 0;
672
673 // Don't replace anyref casts that already target a dummy class.
674 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
676 auto className =
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
679 return 0;
680 }
681
682 if (op->getNumResults() != 1)
683 return 0;
684 if (op->hasAttr("inner_sym"))
685 return 0;
686 if (isFlowSensitiveOp(op))
687 return 0;
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
691 }
692
693 LogicalResult rewrite(Operation *op) override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
696
697 // Handle UInt/SInt types.
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
700 if (width == -1)
701 width = 64;
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
705 reduce::pruneUnusedOps(op, *this);
706 return success();
707 }
708
709 // Handle property string types.
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr("");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
714 reduce::pruneUnusedOps(op, *this);
715 return success();
716 }
717
718 // Handle property integer types.
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64, true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
723 reduce::pruneUnusedOps(op, *this);
724 return success();
725 }
726
727 // Handle property boolean types.
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
732 reduce::pruneUnusedOps(op, *this);
733 return success();
734 }
735
736 // Handle property double types.
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
741 reduce::pruneUnusedOps(op, *this);
742 return success();
743 }
744
745 // Handle property list types.
746 if (isa<ListType>(type)) {
747 auto newOp =
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
750 reduce::pruneUnusedOps(op, *this);
751 return success();
752 }
753
754 // Handle property path types.
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
757 op->replaceAllUsesWith(newOp);
758 reduce::pruneUnusedOps(op, *this);
759 return success();
760 }
761
762 // Handle anyref types.
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
766 if (!dummy) {
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
773 }
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
775 auto anyrefOp =
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
778 reduce::pruneUnusedOps(op, *this);
779 return success();
780 }
781
782 return failure();
783 }
784
785 std::string getName() const override { return "firrtl-constantifier"; }
786 bool acceptSizeIncrease() const override { return true; }
787
788 ::detail::SymbolCache symbols;
790 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
791};
792
793/// A sample reduction pattern that replaces the right-hand-side of
794/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
795/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
796/// connects and creates opportunities for reduction in DCE/CSE.
797struct ConnectInvalidator : public Reduction {
798 uint64_t match(Operation *op) override {
799 if (!isa<FConnectLike>(op))
800 return 0;
801 if (auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
804 return 0;
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
807 }
808 LogicalResult rewrite(Operation *op) override {
809 assert(match(op));
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
815 if (rhsOp)
816 reduce::pruneUnusedOps(rhsOp, *this);
817 return success();
818 }
819 std::string getName() const override { return "connect-invalidator"; }
820 bool acceptSizeIncrease() const override { return true; }
821};
822
823/// A reduction pattern that removes FIRRTL annotations from ports and
824/// operations. This generates one match per annotation and port annotation,
825/// allowing selective removal of individual annotations.
826struct AnnotationRemover : public Reduction {
827 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
828 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
829
830 void matches(Operation *op,
831 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
832 uint64_t matchId = 0;
833
834 // Generate matches for regular annotations
835 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
836 for (unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
838
839 // Generate matches for port annotations
840 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
841 for (auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
845 }
846
847 LogicalResult rewriteMatches(Operation *op,
848 ArrayRef<uint64_t> matches) override {
849 // Convert matches to a set for fast lookup
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
851
852 // Lambda to process annotations and filter out matched ones
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
860 } else {
861 // Mark NLAs in the removed annotation for cleanup
862 nlaRemover.markNLAsInAnnotation(anno);
863 }
864 matchId++;
865 }
866 return ArrayAttr::get(op->getContext(), newAnnotations);
867 };
868
869 // Remove regular annotations
870 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
871 op->setAttr("annotations", processAnnotations(annos.getValue()));
872 }
873
874 // Remove port annotations
875 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
880 }
881 op->setAttr("portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
883 }
884
885 return success();
886 }
887
888 std::string getName() const override { return "annotation-remover"; }
889 NLARemover nlaRemover;
890};
891
892/// A reduction pattern that replaces ResetType with UInt<1> across an entire
893/// circuit. This walks all operations in the circuit and replaces ResetType in
894/// results, block arguments, and attributes.
895struct SimplifyResets : public OpReduction<CircuitOp> {
896 uint64_t match(CircuitOp circuit) override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
900
901 circuit.walk([&](Operation *op) {
902 for (auto result : op->getResults())
903 walker.walk(result.getType());
904
905 for (auto &region : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
909
910 walker.walk(op->getAttrDictionary());
911 });
912
913 return numResets;
914 }
915
916 LogicalResult rewrite(CircuitOp circuit) override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1, false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1, true);
919
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
923 });
924 replacer.recursivelyReplaceElementsIn(circuit, /*replaceAttrs=*/true,
925 /*replaceLocs=*/false,
926 /*replaceTypes=*/true);
927
928 // Remove annotations related to InferResets pass
929 circuit.walk([&](Operation *op) {
930 // Remove operation annotations
932 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
935 });
936
937 // Remove port annotations for module-like operations
938 if (auto module = dyn_cast<FModuleLike>(op)) {
939 AnnotationSet::removePortAnnotations(module, [&](unsigned portIdx,
940 Annotation anno) {
941 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
944 });
945 }
946 });
947
948 return success();
949 }
950
951 std::string getName() const override { return "firrtl-simplify-resets"; }
952 bool acceptSizeIncrease() const override { return true; }
953};
954
955/// A sample reduction pattern that removes ports from the root `firrtl.module`
956/// if the port is not used or just invalidated.
957struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
958 uint64_t match(firrtl::FModuleOp module) override {
959 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
960 if (!circuit)
961 return 0;
962 return circuit.getNameAttr() == module.getNameAttr();
963 }
964 LogicalResult rewrite(firrtl::FModuleOp module) override {
965 assert(match(module));
966 size_t numPorts = module.getNumPorts();
967 llvm::BitVector dropPorts(numPorts);
968 for (unsigned i = 0; i != numPorts; ++i) {
969 if (onlyInvalidated(module.getArgument(i))) {
970 dropPorts.set(i);
971 for (auto *user :
972 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
973 user->erase();
974 }
975 }
976 module.erasePorts(dropPorts);
977 return success();
978 }
979 std::string getName() const override { return "root-port-pruner"; }
980};
981
982/// A reduction pattern that removes all ports from the root `firrtl.extmodule`.
983/// Since extmodules have no body, all ports can be safely removed for reduction
984/// purposes.
985struct RootExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
986 uint64_t match(firrtl::FExtModuleOp module) override {
987 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
988 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
989 return 0;
990 // We can remove all ports from the root extmodule
991 return module.getNumPorts();
992 }
993
994 LogicalResult rewrite(firrtl::FExtModuleOp module) override {
995 assert(match(module));
996 size_t numPorts = module.getNumPorts();
997 if (numPorts == 0)
998 return failure();
999
1000 llvm::BitVector dropPorts(numPorts);
1001 dropPorts.set(0, numPorts);
1002 module.erasePorts(dropPorts);
1003 return success();
1004 }
1005
1006 std::string getName() const override { return "root-extmodule-port-pruner"; }
1007};
1008
1009/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
1010/// with wires.
1011struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
1012 void beforeReduction(mlir::ModuleOp op) override {
1013 symbols.clear();
1014 nlaRemover.clear();
1015 }
1016 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1017
1018 uint64_t match(firrtl::InstanceOp instOp) override {
1019 return isa<firrtl::FExtModuleOp>(
1020 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1021 }
1022 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
1023 auto portInfo =
1024 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1025 symbols.getNearestSymbolTable(instOp)))
1026 .getPorts();
1027 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1028 TieOffCache tieOffCache(builder);
1029 SmallVector<Value> replacementWires;
1030 for (firrtl::PortInfo info : portInfo) {
1031 auto wire = firrtl::WireOp::create(
1032 builder, info.type,
1033 (Twine(instOp.getName()) + "_" + info.getName()).str())
1034 .getResult();
1035 if (info.isOutput()) {
1036 // Tie off output ports using TieOffCache.
1037 if (auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(info.type)) {
1038 auto inv = tieOffCache.getInvalid(baseType);
1039 firrtl::ConnectOp::create(builder, wire, inv);
1040 } else if (auto propType = dyn_cast<firrtl::PropertyType>(info.type)) {
1041 auto unknown = tieOffCache.getUnknown(propType);
1042 builder.create<firrtl::PropAssignOp>(wire, unknown);
1043 }
1044 }
1045 replacementWires.push_back(wire);
1046 }
1047 nlaRemover.markNLAsInOperation(instOp);
1048 instOp.replaceAllUsesWith(std::move(replacementWires));
1049 instOp->erase();
1050 return success();
1051 }
1052 std::string getName() const override { return "extmodule-instance-remover"; }
1053 bool acceptSizeIncrease() const override { return true; }
1054
1055 ::detail::SymbolCache symbols;
1056 NLARemover nlaRemover;
1057};
1058
1059/// A reduction pattern that removes unused ports from extmodules and regular
1060/// modules. This is particularly useful for reducing test cases with many probe
1061/// ports or other unused ports.
1062///
1063/// Shared helper functions for port pruning reductions.
1064struct PortPrunerHelpers {
1065 /// Compute which ports are unused across all instances of a module.
1066 template <typename ModuleOpType>
1067 static void computeUnusedInstancePorts(ModuleOpType module,
1068 ArrayRef<Operation *> users,
1069 llvm::BitVector &portsToRemove) {
1070 auto ports = module.getPorts();
1071 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1072 bool portUsed = false;
1073 for (auto *user : users) {
1074 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1075 auto result = instOp.getResult(portIdx);
1076 if (!result.use_empty()) {
1077 portUsed = true;
1078 break;
1079 }
1080 }
1081 }
1082 if (!portUsed)
1083 portsToRemove.set(portIdx);
1084 }
1085 }
1086
1087 /// Update all instances of a module to remove the specified ports.
1088 static void
1089 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1090 const llvm::BitVector &portsToRemove) {
1091 // Update all instances to remove the corresponding results
1092 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1093 for (auto *user : users) {
1094 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1095 instancesToUpdate.push_back(instOp);
1096 }
1097
1098 for (auto instOp : instancesToUpdate) {
1099 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1100
1101 // Manually replace uses, skipping erased ports
1102 size_t newResultIdx = 0;
1103 for (size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1104 ++oldResultIdx) {
1105 if (portsToRemove[oldResultIdx]) {
1106 // This port is being removed, assert it has no uses
1107 assert(instOp.getResult(oldResultIdx).use_empty() &&
1108 "removing port with uses");
1109 } else {
1110 // Replace uses of the old result with the new result
1111 instOp.getResult(oldResultIdx)
1112 .replaceAllUsesWith(newInst.getResult(newResultIdx));
1113 ++newResultIdx;
1114 }
1115 }
1116
1117 instOp->erase();
1118 }
1119 }
1120};
1121
1122/// Reduction to remove unused ports from regular modules.
1123struct ModulePortPruner : public OpReduction<firrtl::FModuleOp> {
1124 void beforeReduction(mlir::ModuleOp op) override {
1125 symbols.clear();
1126 nlaRemover.clear();
1127 portsToRemoveMap.clear();
1128 }
1129 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1130
1131 uint64_t match(firrtl::FModuleOp module) override {
1132 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1133 auto &userMap = symbols.getSymbolUserMap(tableOp);
1134 auto ports = module.getPorts();
1135 auto users = userMap.getUsers(module);
1136
1137 // Compute which ports to remove and cache the result
1138 llvm::BitVector portsToRemove(ports.size());
1139
1140 // If the module has no instances, aggressively remove ports that aren't
1141 // used within the module body itself
1142 if (users.empty()) {
1143 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1144 auto arg = module.getArgument(portIdx);
1145 if (arg.use_empty())
1146 portsToRemove.set(portIdx);
1147 }
1148 } else {
1149 // For modules with instances, check if ports are unused across all
1150 // instances
1151 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1152 portsToRemove);
1153 }
1154
1155 auto count = portsToRemove.count();
1156 if (count > 0)
1157 portsToRemoveMap[module] = std::move(portsToRemove);
1158
1159 return count;
1160 }
1161
1162 LogicalResult rewrite(firrtl::FModuleOp module) override {
1163 // Use the cached ports to remove from match()
1164 auto it = portsToRemoveMap.find(module);
1165 if (it == portsToRemoveMap.end())
1166 return failure();
1167
1168 const auto &portsToRemove = it->second;
1169
1170 // Get users for updating instances
1171 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1172 auto &userMap = symbols.getSymbolUserMap(tableOp);
1173 auto users = userMap.getUsers(module);
1174
1175 // Update all instances
1176 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1177 portsToRemove);
1178
1179 // Erase uses of port arguments within the module body.
1180 for (size_t portIdx = 0; portIdx < module.getNumPorts(); ++portIdx)
1181 if (portsToRemove[portIdx])
1182 // Recursively erase each user and its dependent operations
1183 for (auto *user :
1184 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
1185 reduce::pruneUnusedOps(user, *this);
1186
1187 // Remove the ports from the module
1188 module.erasePorts(portsToRemove);
1189
1190 return success();
1191 }
1192
1193 std::string getName() const override { return "module-port-pruner"; }
1194
1195 ::detail::SymbolCache symbols;
1196 NLARemover nlaRemover;
1197 DenseMap<firrtl::FModuleOp, llvm::BitVector> portsToRemoveMap;
1198};
1199
1200/// Reduction to remove unused ports from extmodules.
1201struct ExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
1202 void beforeReduction(mlir::ModuleOp op) override {
1203 symbols.clear();
1204 nlaRemover.clear();
1205 portsToRemoveMap.clear();
1206 }
1207 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1208
1209 uint64_t match(firrtl::FExtModuleOp module) override {
1210 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1211 auto &userMap = symbols.getSymbolUserMap(tableOp);
1212 auto ports = module.getPorts();
1213 auto users = userMap.getUsers(module);
1214
1215 // Compute which ports to remove and cache the result
1216 llvm::BitVector portsToRemove(ports.size());
1217
1218 if (users.empty()) {
1219 // If the extmodule has no instances, aggressively remove all ports
1220 portsToRemove.set();
1221 } else {
1222 // For extmodules with instances, check if ports are unused across all
1223 // instances
1224 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1225 portsToRemove);
1226 }
1227
1228 auto count = portsToRemove.count();
1229 if (count > 0)
1230 portsToRemoveMap[module] = std::move(portsToRemove);
1231
1232 return count;
1233 }
1234
1235 LogicalResult rewrite(firrtl::FExtModuleOp module) override {
1236 // Use the cached ports to remove from match()
1237 auto it = portsToRemoveMap.find(module);
1238 if (it == portsToRemoveMap.end())
1239 return failure();
1240
1241 const auto &portsToRemove = it->second;
1242
1243 // Get users for updating instances
1244 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1245 auto &userMap = symbols.getSymbolUserMap(tableOp);
1246 auto users = userMap.getUsers(module);
1247
1248 // Update all instances.
1249 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1250 portsToRemove);
1251
1252 // Remove the ports from the module (no body to clean up for extmodules).
1253 module.erasePorts(portsToRemove);
1254
1255 return success();
1256 }
1257
1258 std::string getName() const override { return "extmodule-port-pruner"; }
1259
1260 ::detail::SymbolCache symbols;
1261 NLARemover nlaRemover;
1262 DenseMap<firrtl::FExtModuleOp, llvm::BitVector> portsToRemoveMap;
1263};
1264
1265/// A sample reduction pattern that pushes connected values through wires.
1266struct ConnectForwarder : public Reduction {
1267 void beforeReduction(mlir::ModuleOp op) override {
1268 domInfo = std::make_unique<DominanceInfo>(op);
1269 }
1270
1271 uint64_t match(Operation *op) override {
1272 if (!isa<firrtl::FConnectLike>(op))
1273 return 0;
1274 auto dest = op->getOperand(0);
1275 auto src = op->getOperand(1);
1276 auto *destOp = dest.getDefiningOp();
1277 auto *srcOp = src.getDefiningOp();
1278 if (dest == src)
1279 return 0;
1280
1281 // Ensure that the destination is something we should be able to forward
1282 // through.
1283 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1284 destOp))
1285 return 0;
1286
1287 // Ensure that the destination is connected to only once, and all uses of
1288 // the connection occur after the definition of the source.
1289 unsigned numConnects = 0;
1290 for (auto &use : dest.getUses()) {
1291 auto *op = use.getOwner();
1292 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1293 if (++numConnects > 1)
1294 return 0;
1295 continue;
1296 }
1297 // Check if srcOp properly dominates op, but op is not enclosed in srcOp.
1298 // This handles cross-block cases (e.g., layerblocks).
1299 if (srcOp &&
1300 !domInfo->properlyDominates(srcOp, op, /*enclosingOpOk=*/false))
1301 return 0;
1302 }
1303
1304 return 1;
1305 }
1306
1307 LogicalResult rewrite(Operation *op) override {
1308 auto dst = op->getOperand(0);
1309 auto src = op->getOperand(1);
1310 dst.replaceAllUsesWith(src);
1311 op->erase();
1312 if (auto *dstOp = dst.getDefiningOp())
1313 reduce::pruneUnusedOps(dstOp, *this);
1314 if (auto *srcOp = src.getDefiningOp())
1315 reduce::pruneUnusedOps(srcOp, *this);
1316 return success();
1317 }
1318
1319 std::string getName() const override { return "connect-forwarder"; }
1320
1321private:
1322 std::unique_ptr<DominanceInfo> domInfo;
1323};
1324
1325/// A sample reduction pattern that replaces a single-use wire and register with
1326/// an operand of the source value of the connection.
1327template <unsigned OpNum>
1328struct ConnectSourceOperandForwarder : public Reduction {
1329 uint64_t match(Operation *op) override {
1330 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1331 return 0;
1332 auto dest = op->getOperand(0);
1333 auto *destOp = dest.getDefiningOp();
1334
1335 // Ensure that the destination is used only once.
1336 if (!destOp || !destOp->hasOneUse() ||
1337 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1338 return 0;
1339
1340 auto *srcOp = op->getOperand(1).getDefiningOp();
1341 if (!srcOp || OpNum >= srcOp->getNumOperands())
1342 return 0;
1343
1344 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1345 auto opTy =
1346 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1347
1348 return resultTy && opTy &&
1349 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1350 ((resultTy.getBitWidthOrSentinel() == -1) ==
1351 (opTy.getBitWidthOrSentinel() == -1)) &&
1352 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1353 }
1354
1355 LogicalResult rewrite(Operation *op) override {
1356 auto *destOp = op->getOperand(0).getDefiningOp();
1357 auto *srcOp = op->getOperand(1).getDefiningOp();
1358 auto forwardedOperand = srcOp->getOperand(OpNum);
1359 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1360 Value newDest;
1361 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
1362 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1363 wire.getName())
1364 .getResult();
1365 else {
1366 auto regName = destOp->getAttrOfType<StringAttr>("name");
1367 // We can promote the register into a wire but we wouldn't do here because
1368 // the error might be caused by the register.
1369 auto clock = destOp->getOperand(0);
1370 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1371 clock, regName ? regName.str() : "")
1372 .getResult();
1373 }
1374
1375 // Create new connection between a new wire and the forwarded operand.
1376 builder.setInsertionPointAfter(op);
1377 if (isa<firrtl::ConnectOp>(op))
1378 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1379 else
1380 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1381
1382 // Remove the old connection and destination. We don't have to replace them
1383 // because destination has only one use.
1384 op->erase();
1385 destOp->erase();
1386 reduce::pruneUnusedOps(srcOp, *this);
1387
1388 return success();
1389 }
1390 std::string getName() const override {
1391 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1392 }
1393};
1394
1395/// A sample reduction pattern that tries to remove aggregate wires by replacing
1396/// all subaccesses with new independent wires. This can disentangle large
1397/// unused wires that are otherwise difficult to collect due to the subaccesses.
1398struct DetachSubaccesses : public Reduction {
1399 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1400 void afterReduction(mlir::ModuleOp op) override {
1401 for (auto *op : opsToErase)
1402 op->dropAllReferences();
1403 for (auto *op : opsToErase)
1404 op->erase();
1405 }
1406 uint64_t match(Operation *op) override {
1407 // Only applies to wires and registers that are purely used in subaccess
1408 // operations.
1409 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1410 llvm::all_of(op->getUses(), [](auto &use) {
1411 return use.getOperandNumber() == 0 &&
1412 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1413 firrtl::SubaccessOp>(use.getOwner());
1414 });
1415 }
1416 LogicalResult rewrite(Operation *op) override {
1417 assert(match(op));
1418 OpBuilder builder(op);
1419 bool isWire = isa<firrtl::WireOp>(op);
1420 Value invalidClock;
1421 if (!isWire)
1422 invalidClock = firrtl::InvalidValueOp::create(
1423 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1424 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1425 builder.setInsertionPoint(user);
1426 auto type = user->getResult(0).getType();
1427 Operation *replOp;
1428 if (isWire)
1429 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1430 else
1431 replOp =
1432 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1433 user->replaceAllUsesWith(replOp);
1434 opsToErase.insert(user);
1435 }
1436 opsToErase.insert(op);
1437 return success();
1438 }
1439 std::string getName() const override { return "detach-subaccesses"; }
1440 llvm::DenseSet<Operation *> opsToErase;
1441};
1442
1443/// This reduction removes inner symbols on ops. Name preservation creates a lot
1444/// of node ops with symbols to keep name information but it also prevents
1445/// normal canonicalizations.
1446struct NodeSymbolRemover : public Reduction {
1447 void beforeReduction(mlir::ModuleOp op) override {
1448 innerSymUses = reduce::InnerSymbolUses(op);
1449 }
1450
1451 uint64_t match(Operation *op) override {
1452 // Only match ops with an inner symbol.
1453 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1454 if (!sym || sym.empty())
1455 return 0;
1456
1457 // Only match ops that have no references to their inner symbol.
1458 if (innerSymUses.hasInnerRef(op))
1459 return 0;
1460 return 1;
1461 }
1462
1463 LogicalResult rewrite(Operation *op) override {
1464 op->removeAttr("inner_sym");
1465 return success();
1466 }
1467
1468 std::string getName() const override { return "node-symbol-remover"; }
1469 bool acceptSizeIncrease() const override { return true; }
1470
1471 reduce::InnerSymbolUses innerSymUses;
1472};
1473
1474/// Check if inlining the referenced operation into the parent operation would
1475/// cause inner symbol collisions.
1476static bool
1477hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1478 hw::InnerSymbolTableCollection &innerSymTables) {
1479 // Get the inner symbol tables for both operations
1480 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1481 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1482
1483 // Check if any inner symbol name in the target operation already exists
1484 // in the parent operation. Return failure() if a collision is found to stop
1485 // the walk early.
1486 LogicalResult walkResult = targetTable.walkSymbols(
1487 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1488 // Check if this symbol name exists in the parent operation
1489 if (parentTable.lookup(name)) {
1490 // Collision found, return failure to stop the walk
1491 return failure();
1492 }
1493 return success();
1494 });
1495
1496 // If the walk failed, it means we found a collision
1497 return failed(walkResult);
1498}
1499
/// A sample reduction pattern that eagerly inlines instances.
///
/// Only instances of `FModuleOp`s are inlined, and only when the instance is
/// not part of any NLA and the inlined body would not introduce inner symbol
/// names that already exist in the instantiating module.
struct EagerInliner : public OpReduction<InstanceOp> {
  /// Resets the caches and builds one NLATable per circuit in `op`, plus a
  /// fresh inner symbol table collection.
  void beforeReduction(mlir::ModuleOp op) override {
    symbols.clear();
    nlaRemover.clear();
    nlaTables.clear();
    for (auto circuitOp : op.getOps<CircuitOp>())
      nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
    innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
  }
  /// Removes the NLAs marked during rewrites and drops the per-round caches.
  void afterReduction(mlir::ModuleOp op) override {
    nlaRemover.remove(op);
    nlaTables.clear();
    innerSymTables.reset();
  }

  uint64_t match(InstanceOp instOp) override {
    auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
    auto *moduleOp =
        instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));

    // Only inline FModuleOp instances
    if (!isa<FModuleOp>(moduleOp))
      return 0;

    // Skip instances that participate in any NLAs
    auto circuitOp = instOp->getParentOfType<CircuitOp>();
    if (!circuitOp)
      return 0;
    auto it = nlaTables.find(circuitOp);
    if (it == nlaTables.end() || !it->second)
      return 0;
    DenseSet<hw::HierPathOp> nlas;
    it->second->getInstanceNLAs(instOp, nlas);
    if (!nlas.empty())
      return 0;

    // Check for inner symbol collisions between the referenced module and the
    // instance's parent module
    // NOTE(review): the inner symbol tables are built once per round, so two
    // inlines into the same parent in one round are checked against the
    // pre-inlining tables; confirm the result is re-verified afterwards.
    auto parentOp = instOp->getParentOfType<FModuleLike>();
    if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
      return 0;

    return 1;
  }

  /// Inlines the referenced module's body in place of `instOp`.
  ///
  /// - Each instance result is replaced by a wire named `<inst>_<port>` that
  ///   carries the port's annotations.
  /// - The wires are passed as the body block's arguments when it is spliced
  ///   in before the instance.
  /// - If this instance is the module's only user, the original body is moved
  ///   and the module erased; otherwise a clone is inlined and then erased.
  LogicalResult rewrite(InstanceOp instOp) override {
    auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
    auto moduleOp = cast<FModuleOp>(
        instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
    // NOTE(review): the symbol user map is cached for the whole round and is
    // not updated by this rewrite (erased instances stay listed, cloned
    // instances are not added), so this count can be stale after earlier
    // rewrites in the same round.
    bool isLastUse =
        (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
    auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();

    // Create wires to replace the instance results.
    IRRewriter rewriter(instOp);
    SmallVector<Value> argWires;
    for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
      auto result = instOp.getResult(i);
      auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
                                         instOp.getPortName(i));
      auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
                                 name, NameKindEnum::DroppableName,
                                 instOp.getPortAnnotation(i), StringAttr{})
                      .getResult();
      result.replaceAllUsesWith(wire);
      argWires.push_back(wire);
    }

    // Splice in the cloned module body.
    rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);

    // Make sure we remove any NLAs that go through this instance, and the
    // module if we're about to delete the module.
    nlaRemover.markNLAsInOperation(instOp);
    if (isLastUse)
      nlaRemover.markNLAsInOperation(moduleOp);

    // The body block has been moved out, so only the (now empty) module shell
    // remains to be erased; for the last use this is the original module.
    instOp.erase();
    clonedModuleOp.erase();
    return success();
  }

  std::string getName() const override { return "firrtl-eager-inliner"; }
  bool acceptSizeIncrease() const override { return true; }

  ::detail::SymbolCache symbols;
  NLARemover nlaRemover;
  /// Per-circuit NLA tables, rebuilt in beforeReduction.
  DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
  /// Inner symbol tables used for collision checks; reset after each round.
  std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
};
1591
/// A reduction pattern that eagerly inlines `ObjectOp`s.
///
/// A clone of the referenced `ClassOp` body is spliced in place of the object.
/// The temporary port wires used during the splice are then forwarded away,
/// since class bodies cannot contain wires.
struct ObjectInliner : public OpReduction<ObjectOp> {
  void beforeReduction(mlir::ModuleOp op) override {
    blocksToSort.clear();
    symbols.clear();
    nlaRemover.clear();
    innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
  }
  /// Topologically sorts every block that received an inlined body, then
  /// removes marked NLAs and drops the inner symbol table cache.
  void afterReduction(mlir::ModuleOp op) override {
    for (auto *block : blocksToSort)
      mlir::sortTopologically(block);
    blocksToSort.clear();
    nlaRemover.remove(op);
    innerSymTables.reset();
  }

  /// Matches objects with these properties:
  /// - the object instantiates a `ClassOp`;
  /// - inlining the class would not collide with inner symbols in the parent;
  /// - every use of the object is an `ObjectSubfieldOp`.
  uint64_t match(ObjectOp objOp) override {
    auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
    auto *classOp =
        objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));

    // Only inline `ClassOp`s.
    if (!isa<ClassOp>(classOp))
      return 0;

    // Check for inner symbol collisions between the referenced class and the
    // object's parent module.
    auto parentOp = objOp->getParentOfType<FModuleLike>();
    if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
      return 0;

    // Verify all uses are ObjectSubfieldOp.
    for (auto *user : objOp.getResult().getUsers())
      if (!isa<ObjectSubfieldOp>(user))
        return 0;

    return 1;
  }

  LogicalResult rewrite(ObjectOp objOp) override {
    auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
    auto classOp = cast<ClassOp>(
        objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
    // The class is always cloned; the original class is never modified here.
    auto clonedClassOp = classOp.clone();

    // Create wires to replace the ObjectSubfieldOp results.
    IRRewriter rewriter(objOp);
    SmallVector<Value> portWires;
    auto classType = objOp.getType();

    // Create a wire for each port in the class
    for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
      auto element = classType.getElement(i);
      auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
                                         element.name.getValue());
      auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
                                 NameKindEnum::DroppableName,
                                 rewriter.getArrayAttr({}), StringAttr{})
                      .getResult();
      portWires.push_back(wire);
    }

    // Replace all ObjectSubfieldOp uses with corresponding wires. The subfield
    // ops themselves are erased at the end, once nothing refers to them.
    SmallVector<ObjectSubfieldOp> subfieldOps;
    for (auto *user : objOp.getResult().getUsers()) {
      auto subfieldOp = cast<ObjectSubfieldOp>(user);
      subfieldOps.push_back(subfieldOp);
      auto index = subfieldOp.getIndex();
      subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
    }

    // Splice in the cloned class body.
    rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);

    // After inlining the class body, we need to eliminate `WireOps` since
    // `ClassOps` cannot contain wires. For each port wire, find its single
    // connect, remove it, and replace all uses of the wire with the assigned
    // value.
    SmallVector<FConnectLike> connectsToErase;
    for (auto portWire : portWires) {
      // Find a single value to replace the wire with, and collect all connects
      // to the wire such that we can erase them later. If several connects
      // drive the wire, the source of the last one visited wins.
      Value value;
      for (auto *user : portWire.getUsers()) {
        if (auto connect = dyn_cast<FConnectLike>(user)) {
          if (connect.getDest() == portWire) {
            value = connect.getSrc();
            connectsToErase.push_back(connect);
          }
        }
      }

      // Be very conservative about deleting these wires. Other reductions may
      // leave class ports unconnected, which means that there isn't always a
      // clean replacement available here. Better to just leave the wires in the
      // IR and let the verifier fail later.
      if (value)
        portWire.replaceAllUsesWith(value);
      for (auto connect : connectsToErase)
        connect.erase();
      if (portWire.use_empty())
        portWire.getDefiningOp()->erase();
      connectsToErase.clear();
    }

    // Make sure we remove any NLAs that go through this object.
    nlaRemover.markNLAsInOperation(objOp);

    // Since the above forwarding of SSA values through wires can create
    // dominance issues, mark the region containing the object to be sorted
    // topologically.
    blocksToSort.insert(objOp->getBlock());

    // Erase the object and cloned class.
    for (auto subfieldOp : subfieldOps)
      subfieldOp.erase();
    objOp.erase();
    clonedClassOp.erase();
    return success();
  }

  std::string getName() const override { return "firrtl-object-inliner"; }
  bool acceptSizeIncrease() const override { return true; }

  /// Blocks that received an inlined class body this round.
  SetVector<Block *> blocksToSort;
  ::detail::SymbolCache symbols;
  NLARemover nlaRemover;
  std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
};
1721
1722/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1723/// not an actual reduction, but often removes extraneous information that has
1724/// no bearing on the actual reduction (and would likely be removed before
1725/// sharing the reduction). This makes the following changes:
1726///
1727/// - All wires are renamed to "wire"
1728/// - All registers are renamed to "reg"
1729/// - All nodes are renamed to "node"
1730/// - All memories are renamed to "mem"
1731/// - All verification messages and labels are dropped
1732///
1733struct ModuleInternalNameSanitizer : public Reduction {
1734 uint64_t match(Operation *op) override {
1735 // Only match operations with names.
1736 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1737 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1738 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1739 firrtl::CoverOp>(op);
1740 }
1741 LogicalResult rewrite(Operation *op) override {
1742 TypeSwitch<Operation *, void>(op)
1743 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1744 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1745 [](auto op) { op.setName("reg"); })
1746 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1747 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1748 [](auto op) { op.setName("mem"); })
1749 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1750 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1751 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1752 });
1753 return success();
1754 }
1755
1756 std::string getName() const override {
1757 return "module-internal-name-sanitizer";
1758 }
1759
1760 bool acceptSizeIncrease() const override { return true; }
1761
1762 bool isOneShot() const override { return true; }
1763};
1764
1765/// Psuedo-reduction that sanitizes module, instance, and port names. This
1766/// makes the following changes:
1767///
1768/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1769/// - All instances are renamed to match the new module name
1770/// - All module ports are renamed in the following way:
1771/// - All clocks are reanemd to "clk"
1772/// - All resets are renamed to "rst"
1773/// - All references are renamed to "ref"
1774/// - Anything else is renamed to "port"
1775///
1776struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1777
1778 const char *names[48] = {
1779 "Foo", "Bar", "Baz", "Qux", "Quux", "Quuux", "Quuuux",
1780 "Quz", "Corge", "Grault", "Bazola", "Ztesch", "Thud", "Grunt",
1781 "Bletch", "Fum", "Fred", "Jim", "Sheila", "Barney", "Flarp",
1782 "Zxc", "Spqr", "Wombat", "Shme", "Bongo", "Spam", "Eggs",
1783 "Snork", "Zot", "Blarg", "Wibble", "Toto", "Titi", "Tata",
1784 "Tutu", "Pippo", "Pluto", "Paperino", "Aap", "Noot", "Mies",
1785 "Oogle", "Foogle", "Boogle", "Zork", "Gork", "Bork"};
1786
1787 size_t nameIndex = 0;
1788
1789 const char *getName() {
1790 if (nameIndex >= 48)
1791 nameIndex = 0;
1792 return names[nameIndex++];
1793 };
1794
1795 size_t portNameIndex = 0;
1796
1797 char getPortName() {
1798 if (portNameIndex >= 26)
1799 portNameIndex = 0;
1800 return 'a' + portNameIndex++;
1801 }
1802
1803 void beforeReduction(mlir::ModuleOp op) override { nameIndex = 0; }
1804
1805 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1806
1807 firrtl::InstanceGraph iGraph(circuitOp);
1808
1809 auto *circuitName = getName();
1810 iGraph.getTopLevelModule().setName(circuitName);
1811 circuitOp.setName(circuitName);
1812
1813 for (auto *node : iGraph) {
1814 auto module = node->getModule<firrtl::FModuleLike>();
1815
1816 bool shouldReplacePorts = false;
1817 SmallVector<Attribute> newNames;
1818 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1819 portNameIndex = 0;
1820 // TODO: The namespace should be unnecessary. However, some FIRRTL
1821 // passes expect that port names are unique.
1823 auto oldPorts = fmodule.getPorts();
1824 shouldReplacePorts = !oldPorts.empty();
1825 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1826 auto port = oldPorts[i];
1827 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1828 .Case<firrtl::ClockType>(
1829 [&](auto a) { return ns.newName("clk"); })
1830 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1831 [&](auto a) { return ns.newName("rst"); })
1832 .Case<firrtl::RefType>(
1833 [&](auto a) { return ns.newName("ref"); })
1834 .Default([&](auto a) {
1835 return ns.newName(Twine(getPortName()));
1836 });
1837 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1838 }
1839 fmodule->setAttr("portNames",
1840 ArrayAttr::get(fmodule.getContext(), newNames));
1841 }
1842
1843 if (module == iGraph.getTopLevelModule())
1844 continue;
1845 auto newName = StringAttr::get(circuitOp.getContext(), getName());
1846 module.setName(newName);
1847 for (auto *use : node->uses()) {
1848 auto useOp = use->getInstance();
1849 if (auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1850 instanceOp.setModuleName(newName);
1851 instanceOp.setName(newName);
1852 if (shouldReplacePorts)
1853 instanceOp.setPortNamesAttr(
1854 ArrayAttr::get(circuitOp.getContext(), newNames));
1855 } else if (auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1856 // ObjectOp stores the class name in its result type, so we need to
1857 // create a new ClassType with the new name and set it on the result.
1858 auto oldClassType = objectOp.getType();
1859 auto newClassType = firrtl::ClassType::get(
1860 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1861 oldClassType.getElements());
1862 objectOp.getResult().setType(newClassType);
1863 objectOp.setName(newName);
1864 }
1865 }
1866 }
1867
1868 return success();
1869 }
1870
1871 std::string getName() const override { return "module-name-sanitizer"; }
1872
1873 bool acceptSizeIncrease() const override { return true; }
1874
1875 bool isOneShot() const override { return true; }
1876};
1877
1878/// A reduction pattern that groups modules by their port signature (types and
1879/// directions) and replaces instances with the smallest module in each group.
1880/// This helps reduce the IR by consolidating functionally equivalent modules
1881/// based on their interface.
1882///
1883/// The pattern works by:
1884/// 1. Grouping all modules by their port signature (port types and directions)
1885/// 2. For each group with multiple modules, finding the smallest module using
1886/// the module size cache
1887/// 3. Replacing all instances of larger modules with instances of the smallest
1888/// module in the same group
1889/// 4. Removing the larger modules from the circuit
1890///
1891/// This reduction is useful for reducing circuits where multiple modules have
1892/// the same interface but different implementations, allowing the reducer to
1893/// try the smallest implementation first.
1894struct ModuleSwapper : public OpReduction<InstanceOp> {
1895 // Per-circuit state containing all the information needed for module swapping
1896 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1897 struct CircuitState {
1898 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1899 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1900 std::unique_ptr<NLATable> nlaTable;
1901 };
1902
1903 void beforeReduction(mlir::ModuleOp op) override {
1904 symbols.clear();
1905 nlaRemover.clear();
1906 moduleSizes.clear();
1907 circuitStates.clear();
1908
1909 // Collect module type groups and NLA tables for all circuits up front
1910 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1911 auto &state = circuitStates[circuitOp];
1912 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1913 buildModuleTypeGroups(circuitOp, state);
1914 return WalkResult::skip();
1915 });
1916 }
1917 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1918
1919 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1920 /// This ignores port names, allowing modules with the same port types and
1921 /// directions but different port names to be considered equivalent for
1922 /// swapping.
1923 PortSignature getModulePortSignature(FModuleLike module) {
1924 PortSignature signature;
1925 signature.reserve(module.getNumPorts());
1926 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1927 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1928 return signature;
1929 }
1930
1931 /// Group modules by their port signature and find the smallest in each group.
1932 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1933 // Group modules by their port signature
1934 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1935 auto signature = getModulePortSignature(module);
1936 state.moduleTypeGroups[signature].push_back(module);
1937 }
1938
1939 // For each group, find the smallest module
1940 for (auto &[signature, modules] : state.moduleTypeGroups) {
1941 if (modules.size() <= 1)
1942 continue;
1943
1944 FModuleLike smallestModule = nullptr;
1945 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1946
1947 for (auto module : modules) {
1948 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1949 if (size < smallestSize) {
1950 smallestSize = size;
1951 smallestModule = module;
1952 }
1953 }
1954
1955 // Map all modules in this group to the smallest one
1956 for (auto module : modules) {
1957 if (module != smallestModule) {
1958 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1959 smallestModule;
1960 }
1961 }
1962 }
1963 }
1964
1965 uint64_t match(InstanceOp instOp) override {
1966 // Get the circuit this instance belongs to
1967 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1968 assert(circuitOp);
1969 const auto &state = circuitStates.at(circuitOp);
1970
1971 // Skip instances that participate in any NLAs
1972 DenseSet<hw::HierPathOp> nlas;
1973 state.nlaTable->getInstanceNLAs(instOp, nlas);
1974 if (!nlas.empty())
1975 return 0;
1976
1977 // Check if this instance can be redirected to a smaller module
1978 auto moduleName = instOp.getModuleNameAttr().getAttr();
1979 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1980 if (!canonicalModule)
1981 return 0;
1982
1983 // Benefit is the size difference
1984 auto currentModule = cast<FModuleLike>(
1985 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1986 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1987 uint64_t canonicalSize =
1988 moduleSizes.getModuleSize(canonicalModule, symbols);
1989 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1990 }
1991
1992 LogicalResult rewrite(InstanceOp instOp) override {
1993 // Get the circuit this instance belongs to
1994 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1995 assert(circuitOp);
1996 const auto &state = circuitStates.at(circuitOp);
1997
1998 // Replace the instantiated module with the canonical module.
1999 auto canonicalModule = state.instanceToCanonicalModule.at(
2000 instOp.getModuleNameAttr().getAttr());
2001 auto canonicalName = canonicalModule.getModuleNameAttr();
2002 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2003
2004 // Update port names to match the canonical module
2005 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2006
2007 return success();
2008 }
2009
2010 std::string getName() const override { return "firrtl-module-swapper"; }
2011 bool acceptSizeIncrease() const override { return true; }
2012
2013private:
2014 ::detail::SymbolCache symbols;
2015 NLARemover nlaRemover;
2016 ModuleSizeCache moduleSizes;
2017
2018 // Per-circuit state containing all module swapping information
2019 DenseMap<CircuitOp, CircuitState> circuitStates;
2020};
2021
2022/// A reduction pattern that handles MustDedup annotations by replacing all
2023/// module names in a dedup group with a single module name. This helps reduce
2024/// the IR by consolidating module references that are required to be identical.
2025///
2026/// The pattern works by:
2027/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2028/// 2. For each dedup group, using the first module as the canonical name
2029/// 3. Replacing all instance references to other modules in the group with
2030/// references to the canonical module
2031/// 4. Removing the non-canonical modules from the circuit
2032/// 5. Removing the processed MustDedup annotation
2033///
2034/// This reduction is particularly useful for reducing large circuits where
2035/// multiple modules are known to be identical but haven't been deduplicated
2036/// yet.
2037struct ForceDedup : public OpReduction<CircuitOp> {
2038 void beforeReduction(mlir::ModuleOp op) override {
2039 symbols.clear();
2040 nlaRemover.clear();
2041 modulesToErase.clear();
2042 moduleSizes.clear();
2043 }
2044 void afterReduction(mlir::ModuleOp op) override {
2045 nlaRemover.remove(op);
2046 for (auto mod : modulesToErase)
2047 mod->erase();
2048 }
2049
2050 /// Collect all MustDedup annotations and create matches for each dedup group.
2051 void matches(CircuitOp circuitOp,
2052 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2053 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2054 auto annotations = AnnotationSet(circuitOp);
2055 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2056 if (!anno.isClass(mustDeduplicateAnnoClass))
2057 continue;
2058
2059 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2060 if (!modulesAttr || modulesAttr.size() < 2)
2061 continue;
2062
2063 // Check that all modules have the same port signature. Malformed inputs
2064 // may have modules listed in a MustDedup annotation that have distinct
2065 // port types.
2066 uint64_t totalSize = 0;
2067 ArrayAttr portTypes;
2068 DenseBoolArrayAttr portDirections;
2069 bool allSame = true;
2070 for (auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2071 auto target = tokenizePath(moduleName);
2072 if (!target) {
2073 allSame = false;
2074 break;
2075 }
2076 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2077 if (!mod) {
2078 allSame = false;
2079 break;
2080 }
2081 totalSize += moduleSizes.getModuleSize(mod, symbols);
2082 if (!portTypes) {
2083 portTypes = mod.getPortTypesAttr();
2084 portDirections = mod.getPortDirectionsAttr();
2085 } else if (portTypes != mod.getPortTypesAttr() ||
2086 portDirections != mod.getPortDirectionsAttr()) {
2087 allSame = false;
2088 break;
2089 }
2090 }
2091 if (!allSame)
2092 continue;
2093
2094 // Each dedup group gets its own match with benefit proportional to group
2095 // size.
2096 addMatch(totalSize, annoIdx);
2097 }
2098 }
2099
2100 LogicalResult rewriteMatches(CircuitOp circuitOp,
2101 ArrayRef<uint64_t> matches) override {
2102 auto *context = circuitOp->getContext();
2103 NLATable nlaTable(circuitOp);
2104 hw::InnerSymbolTableCollection innerSymTables;
2105 auto annotations = AnnotationSet(circuitOp);
2106 SmallVector<Annotation> newAnnotations;
2107
2108 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2109 // Check if this annotation was selected.
2110 if (!llvm::is_contained(matches, annoIdx)) {
2111 newAnnotations.push_back(anno);
2112 continue;
2113 }
2114 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2115 assert(anno.isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2116 modulesAttr.size() >= 2);
2117
2118 // Extract module names from the dedup group.
2119 SmallVector<StringAttr> moduleNames;
2120 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2121 // Parse "~CircuitName|ModuleName" format.
2122 auto refStr = moduleRef.getValue();
2123 auto pipePos = refStr.find('|');
2124 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2125 auto moduleName = refStr.substr(pipePos + 1);
2126 moduleNames.push_back(StringAttr::get(context, moduleName));
2127 }
2128 }
2129
2130 // Simply drop the annotation if there's only one module.
2131 if (moduleNames.size() < 2)
2132 continue;
2133
2134 // Replace all instances and references to other modules with the
2135 // first module.
2136 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2137 nlaRemover.markNLAsInAnnotation(anno.getAttr());
2138 }
2139 if (newAnnotations.size() == annotations.size())
2140 return failure();
2141
2142 // Update circuit annotations.
2143 AnnotationSet newAnnoSet(newAnnotations, context);
2144 newAnnoSet.applyToOperation(circuitOp);
2145 return success();
2146 }
2147
2148 std::string getName() const override { return "firrtl-force-dedup"; }
2149 bool acceptSizeIncrease() const override { return true; }
2150
2151private:
2152 /// Replace all references to modules in the dedup group with the canonical
2153 /// module name
2154 void replaceModuleReferences(CircuitOp circuitOp,
2155 ArrayRef<StringAttr> moduleNames,
2156 NLATable &nlaTable,
2157 hw::InnerSymbolTableCollection &innerSymTables) {
2158 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2159 auto &symbolTable = symbols.getSymbolTable(tableOp);
2160 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2161 auto *context = circuitOp->getContext();
2162 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
2163
2164 // Collect the modules.
2165 FModuleLike canonicalModule;
2166 SmallVector<FModuleLike> modulesToReplace;
2167 for (auto name : moduleNames) {
2168 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
2169 if (!canonicalModule)
2170 canonicalModule = mod;
2171 else
2172 modulesToReplace.push_back(mod);
2173 }
2174 }
2175 if (modulesToReplace.empty())
2176 return;
2177
2178 // Replace all instance references.
2179 auto canonicalName = canonicalModule.getModuleNameAttr();
2180 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2181 for (auto moduleName : moduleNames) {
2182 if (moduleName == canonicalName)
2183 continue;
2184 auto *symbolOp = symbolTable.lookup(moduleName);
2185 if (!symbolOp)
2186 continue;
2187 for (auto *user : symbolUserMap.getUsers(symbolOp)) {
2188 auto instOp = dyn_cast<InstanceOp>(user);
2189 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2190 continue;
2191 instOp.setModuleNameAttr(canonicalRef);
2192 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2193 }
2194 }
2195
2196 // Update NLAs to reference the canonical module instead of modules being
2197 // removed using NLATable for better performance.
2198 for (auto oldMod : modulesToReplace) {
2199 SmallVector<hw::HierPathOp> nlaOps(
2200 nlaTable.lookup(oldMod.getModuleNameAttr()));
2201 for (auto nlaOp : nlaOps) {
2202 nlaTable.erase(nlaOp);
2203 StringAttr oldModName = oldMod.getModuleNameAttr();
2204 StringAttr newModName = canonicalName;
2205 SmallVector<Attribute, 4> newPath;
2206 for (auto nameRef : nlaOp.getNamepath()) {
2207 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2208 if (ref.getModule() == oldModName) {
2209 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2210 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2211 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2212 if (oldInst && newInst) {
2213 // Get the first module name from the list (for
2214 // InstanceOp/ObjectOp, there's only one)
2215 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2216 auto newModNames = newInst.getReferencedModuleNamesAttr();
2217 if (!oldModNames.empty() && !newModNames.empty()) {
2218 oldModName = cast<StringAttr>(oldModNames[0]);
2219 newModName = cast<StringAttr>(newModNames[0]);
2220 }
2221 }
2222 }
2223 newPath.push_back(ref);
2224 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2225 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2226 } else {
2227 newPath.push_back(nameRef);
2228 }
2229 }
2230 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
2231 nlaTable.addNLA(nlaOp);
2232 }
2233 }
2234
2235 // Mark NLAs in modules to be removed.
2236 for (auto module : modulesToReplace) {
2237 nlaRemover.markNLAsInOperation(module);
2238 modulesToErase.insert(module);
2239 }
2240 }
2241
2242 ::detail::SymbolCache symbols;
2243 NLARemover nlaRemover;
2244 SetVector<FModuleLike> modulesToErase;
2245 ModuleSizeCache moduleSizes;
2246};
2247
2248/// A reduction pattern that moves `MustDedup` annotations from a module onto
2249/// its child modules. This pattern iterates over all MustDedup annotations,
2250/// collects all `FInstanceLike` ops in each module of the dedup group, and
2251/// creates new MustDedup annotations for corresponding instances across the
2252/// modules. Each set of corresponding instances becomes a separate match of the
2253/// reduction. The reduction also removes the original MustDedup annotation on
2254/// the parent module.
2255///
2256/// The pattern works by:
2257/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2258/// 2. For each dedup group, collecting all FInstanceLike operations in each
2259/// module
2260/// 3. Grouping corresponding instances across modules by their position/name
2261/// 4. Creating new MustDedup annotations for each group of corresponding
2262/// instances
2263/// 5. Removing the original MustDedup annotation from the circuit
2264struct MustDedupChildren : public OpReduction<CircuitOp> {
2265 void beforeReduction(mlir::ModuleOp op) override {
2266 symbols.clear();
2267 nlaRemover.clear();
2268 }
2269 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
2270
2271 /// Collect all MustDedup annotations and create matches for each instance
2272 /// group.
2273 void matches(CircuitOp circuitOp,
2274 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2275 auto annotations = AnnotationSet(circuitOp);
2276 uint64_t matchId = 0;
2277
2278 DenseSet<StringRef> modulesAlreadyInMustDedup;
2279 for (auto [annoIdx, anno] : llvm::enumerate(annotations))
2280 if (anno.isClass(mustDeduplicateAnnoClass))
2281 if (auto modulesAttr = anno.getMember<ArrayAttr>("modules"))
2282 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2283 if (auto target = tokenizePath(moduleRef))
2284 modulesAlreadyInMustDedup.insert(target->module);
2285
2286 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2287 if (!anno.isClass(mustDeduplicateAnnoClass))
2288 continue;
2289
2290 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2291 if (!modulesAttr || modulesAttr.size() < 2)
2292 continue;
2293
2294 // Process each group of corresponding instances
2295 processInstanceGroups(
2296 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2297 matchId++;
2298
2299 // Make sure there are at least two distinct modules.
2300 SmallDenseSet<StringAttr, 4> moduleTargets;
2301 for (auto instOp : instanceGroup) {
2302 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2303 for (auto moduleName : moduleNames)
2304 moduleTargets.insert(cast<StringAttr>(moduleName));
2305 }
2306 if (moduleTargets.size() < 2)
2307 return;
2308
2309 // Make sure none of the modules are not yet in a must dedup
2310 // annotation.
2311 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2312 auto moduleNames = inst.getReferencedModuleNames();
2313 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2314 return modulesAlreadyInMustDedup.contains(moduleName);
2315 });
2316 }))
2317 return;
2318
2319 addMatch(1, matchId - 1);
2320 });
2321 }
2322 }
2323
2324 LogicalResult rewriteMatches(CircuitOp circuitOp,
2325 ArrayRef<uint64_t> matches) override {
2326 auto *context = circuitOp->getContext();
2327 auto annotations = AnnotationSet(circuitOp);
2328 SmallVector<Annotation> newAnnotations;
2329 uint64_t matchId = 0;
2330
2331 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2332 if (!anno.isClass(mustDeduplicateAnnoClass)) {
2333 newAnnotations.push_back(anno);
2334 continue;
2335 }
2336
2337 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2338 if (!modulesAttr || modulesAttr.size() < 2) {
2339 newAnnotations.push_back(anno);
2340 continue;
2341 }
2342
2343 processInstanceGroups(
2344 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2345 // Check if this instance group was selected
2346 if (!llvm::is_contained(matches, matchId++))
2347 return;
2348
2349 // Create the list of modules to put into this new annotation.
2350 SmallSetVector<StringAttr, 4> moduleTargets;
2351 for (auto instOp : instanceGroup) {
2352 auto moduleNames = instOp.getReferencedModuleNames();
2353 for (auto moduleName : moduleNames) {
2354 auto target = TokenAnnoTarget();
2355 target.circuit = circuitOp.getName();
2356 target.module = moduleName;
2357 moduleTargets.insert(target.toStringAttr(context));
2358 }
2359 }
2360
2361 // Create a new MustDedup annotation for this list of modules.
2362 SmallVector<NamedAttribute> newAnnoAttrs;
2363 newAnnoAttrs.emplace_back(
2364 StringAttr::get(context, "class"),
2365 StringAttr::get(context, mustDeduplicateAnnoClass));
2366 newAnnoAttrs.emplace_back(
2367 StringAttr::get(context, "modules"),
2368 ArrayAttr::get(context,
2369 SmallVector<Attribute>(moduleTargets.begin(),
2370 moduleTargets.end())));
2371
2372 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2373 newAnnotations.emplace_back(newAnnoDict);
2374 });
2375
2376 // Keep the original annotation around.
2377 newAnnotations.push_back(anno);
2378 }
2379
2380 // Update circuit annotations
2381 AnnotationSet newAnnoSet(newAnnotations, context);
2382 newAnnoSet.applyToOperation(circuitOp);
2383 return success();
2384 }
2385
2386 std::string getName() const override { return "must-dedup-children"; }
2387 bool acceptSizeIncrease() const override { return true; }
2388
2389private:
2390 /// Helper function to process groups of corresponding instances from a
2391 /// MustDedup annotation. Calls the provided lambda for each group of
2392 /// corresponding instances across the modules. Only calls the lambda if there
2393 /// are at least 2 modules.
2394 void processInstanceGroups(
2395 CircuitOp circuitOp, ArrayAttr modulesAttr,
2396 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
2397 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2398
2399 // Extract module names and get the actual modules
2400 SmallVector<FModuleLike> modules;
2401 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2402 if (auto target = tokenizePath(moduleRef))
2403 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2404 modules.push_back(mod);
2405
2406 // Need at least 2 modules for deduplication
2407 if (modules.size() < 2)
2408 return;
2409
2410 // Collect all FInstanceLike operations from each module and group them by
2411 // name. Instance names are a good key for matching instances across
2412 // modules. But they may not be unique, so we need to be careful to only
2413 // match up instances that are uniquely named within every module.
2414 struct InstanceGroup {
2415 SmallVector<FInstanceLike> instances;
2416 bool nameIsUnique = true;
2417 };
2418 MapVector<StringAttr, InstanceGroup> instanceGroups;
2419 for (auto module : modules) {
2421 module.walk([&](FInstanceLike instOp) {
2422 if (isa<ObjectOp>(instOp.getOperation()))
2423 return;
2424 auto name = instOp.getInstanceNameAttr();
2425 auto &group = instanceGroups[name];
2426 if (nameCounts[name]++ > 1)
2427 group.nameIsUnique = false;
2428 group.instances.push_back(instOp);
2429 });
2430 }
2431
2432 // Call the callback for each group of instances that are uniquely named and
2433 // consist of at least 2 instances.
2434 for (auto &[name, group] : instanceGroups)
2435 if (group.nameIsUnique && group.instances.size() >= 2)
2436 callback(group.instances);
2437 }
2438
2439 ::detail::SymbolCache symbols;
2440 NLARemover nlaRemover;
2441};
2442
2443struct LayerDisable : public OpReduction<CircuitOp> {
2444 LayerDisable(MLIRContext *context) {
2445 pm = std::make_unique<mlir::PassManager>(
2446 context, "builtin.module", mlir::OpPassManager::Nesting::Explicit);
2447 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2448 };
2449
2450 void beforeReduction(mlir::ModuleOp op) override { symbolRefAttrMap.clear(); }
2451
2452 void afterReduction(mlir::ModuleOp op) override { (void)pm->run(op); };
2453
2454 void matches(CircuitOp circuitOp,
2455 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2456 uint64_t matchId = 0;
2457
2458 SmallVector<FlatSymbolRefAttr> nestedRefs;
2459 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2460 LayerOp layerOp) {
2461 if (!rootRef)
2462 rootRef = layerOp.getSymNameAttr();
2463 else
2464 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2465
2466 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2467 addMatch(1, matchId++);
2468
2469 for (auto nestedLayerOp : layerOp.getOps<LayerOp>())
2470 addLayer(rootRef, nestedLayerOp);
2471
2472 if (!nestedRefs.empty())
2473 nestedRefs.pop_back();
2474 };
2475
2476 for (auto layerOp : circuitOp.getOps<LayerOp>())
2477 addLayer({}, layerOp);
2478 }
2479
2480 LogicalResult rewriteMatches(CircuitOp circuitOp,
2481 ArrayRef<uint64_t> matches) override {
2482 SmallVector<Attribute> disableLayers;
2483 if (auto existingDisables = circuitOp.getDisableLayersAttr()) {
2484 auto disableRange = existingDisables.getAsRange<Attribute>();
2485 disableLayers.append(disableRange.begin(), disableRange.end());
2486 }
2487 for (auto match : matches)
2488 disableLayers.push_back(symbolRefAttrMap.at(match));
2489
2490 circuitOp.setDisableLayersAttr(
2491 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2492
2493 return success();
2494 }
2495
2496 std::string getName() const override { return "firrtl-layer-disable"; }
2497
2498 std::unique_ptr<mlir::PassManager> pm;
2499 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2500};
2501
2502} // namespace
2503
2504/// A reduction pattern that removes elements from FIRRTL list create
2505/// operations. This generates one match per element in each list, allowing
2506/// selective removal of individual elements.
2507struct ListCreateElementRemover : public OpReduction<ListCreateOp> {
2508 void matches(ListCreateOp listOp,
2509 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2510 // Create one match for each element in the list
2511 auto elements = listOp.getElements();
2512 for (size_t i = 0; i < elements.size(); ++i)
2513 addMatch(1, i);
2514 }
2515
2516 LogicalResult rewriteMatches(ListCreateOp listOp,
2517 ArrayRef<uint64_t> matches) override {
2518 // Convert matches to a set for fast lookup
2519 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
2520
2521 // Collect elements that should be kept (not in matches)
2522 SmallVector<Value> newElements;
2523 auto elements = listOp.getElements();
2524 for (size_t i = 0; i < elements.size(); ++i) {
2525 if (!matchesSet.contains(i))
2526 newElements.push_back(elements[i]);
2527 }
2528
2529 // Create a new list with the remaining elements
2530 OpBuilder builder(listOp);
2531 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2532 listOp.getType(), newElements);
2533 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2534 listOp.erase();
2535
2536 return success();
2537 }
2538
2539 std::string getName() const override {
2540 return "firrtl-list-create-element-remover";
2541 }
2542};
2543
2544//===----------------------------------------------------------------------===//
2545// Reduction Registration
2546//===----------------------------------------------------------------------===//
2547
2550 // Gather a list of reduction patterns that we should try. Ideally these are
2551 // assigned reasonable benefit indicators (higher benefit patterns are
2552 // prioritized). For example, things that can knock out entire modules while
2553 // being cheap should be tried first (and thus have higher benefit), before
2554 // trying to tweak operands of individual arithmetic ops.
2555 patterns.add<SimplifyResets, 35>();
2556 patterns.add<ForceDedup, 34>();
2557 patterns.add<MustDedupChildren, 33>();
2558 patterns.add<AnnotationRemover, 32>();
2559 patterns.add<ModuleSwapper, 31>();
2560 patterns.add<LayerDisable, 30>(getContext());
2561 patterns.add<PassReduction, 29>(
2562 getContext(),
2563 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
2564 true);
2565 patterns.add<PassReduction, 28>(getContext(),
2566 firrtl::createLowerCHIRRTLPass(), true, true);
2567 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
2568 true, true);
2569 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
2570 true, true);
2571 patterns.add<FIRRTLModuleExternalizer, 25>();
2572 patterns.add<InstanceStubber, 24>();
2573 patterns.add<MemoryStubber, 23>();
2574 patterns.add<EagerInliner, 22>();
2575 patterns.add<ObjectInliner, 22>();
2576 patterns.add<PassReduction, 21>(getContext(),
2577 firrtl::createLowerFIRRTLTypes(), true, true);
2578 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
2579 true, true);
2580 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
2581 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
2582 patterns.add<PassReduction, 17>(
2583 getContext(),
2584 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
2585 patterns.add<NodeSymbolRemover, 15>();
2586 patterns.add<ConnectForwarder, 14>();
2587 patterns.add<ConnectInvalidator, 13>();
2588 patterns.add<Constantifier, 12>();
2589 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2590 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2591 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2593 patterns.add<DetachSubaccesses, 7>();
2594 patterns.add<ModulePortPruner, 7>();
2595 patterns.add<ExtmodulePortPruner, 6>();
2596 patterns.add<RootPortPruner, 5>();
2597 patterns.add<RootExtmodulePortPruner, 5>();
2598 patterns.add<ExtmoduleInstanceRemover, 4>();
2599 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2600 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2601 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2602 patterns.add<ModuleInternalNameSanitizer, 0>();
2603 patterns.add<ModuleNameSanitizer, 0>();
2604}
2605
2607 mlir::DialectRegistry &registry) {
2608 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
2609 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
2610 });
2611}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalid values.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
Helper class to cache tie-off values for different FIRRTL types.
Definition FIRRTLUtils.h:62
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:43
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked NLAs.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions that need to look up a lot of symbols.
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)