CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
25#include "mlir/Analysis/TopologicalSortUtils.h"
26#include "mlir/IR/Dominance.h"
27#include "mlir/IR/ImplicitLocOpBuilder.h"
28#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APSInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/Support/Debug.h"
33
34#define DEBUG_TYPE "firrtl-reductions"
35
36using namespace mlir;
37using namespace circt;
38using namespace firrtl;
39using llvm::MapVector;
40using llvm::SmallDenseSet;
42
43//===----------------------------------------------------------------------===//
44// Utilities
45//===----------------------------------------------------------------------===//
46
47namespace detail {
48/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
49/// which is handy for reductions that need to look up a lot of symbols.
51 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
52
53 SymbolTable &getSymbolTable(Operation *op) {
54 return tables->getSymbolTable(op);
55 }
56 SymbolTable &getNearestSymbolTable(Operation *op) {
57 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
58 }
59
60 SymbolUserMap &getSymbolUserMap(Operation *op) {
61 auto it = userMaps.find(op);
62 if (it != userMaps.end())
63 return it->second;
64 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
65 }
66 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
67 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
68 }
69
70 void clear() {
71 tables = std::make_unique<SymbolTableCollection>();
72 userMaps.clear();
73 }
74
75private:
76 std::unique_ptr<SymbolTableCollection> tables;
78};
79} // namespace detail
80
81/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
82/// optional in case another type of module is instantiated.
83static std::optional<firrtl::FModuleOp>
84findInstantiatedModule(firrtl::InstanceOp instOp,
85 ::detail::SymbolCache &symbols) {
86 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
87 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
88 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
89 return moduleOp ? std::optional(moduleOp) : std::nullopt;
90}
91
92/// Utility to track the transitive size of modules.
94 void clear() { moduleSizes.clear(); }
95
96 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
97 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
98 return it->second;
99 uint64_t size = 1;
100 module->walk([&](Operation *op) {
101 size += 1;
102 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
103 if (auto instModule = findInstantiatedModule(instOp, symbols))
104 size += getModuleSize(*instModule, symbols);
105 });
106 moduleSizes.insert({module, size});
107 return size;
108 }
109
110private:
111 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
112};
113
114/// Check that all connections to a value are invalids.
115static bool onlyInvalidated(Value arg) {
116 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
117 auto *op = use.getOwner();
118 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
119 return false;
120 if (use.getOperandNumber() != 0)
121 return false;
122 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
123 return false;
124 return true;
125 });
126}
127
/// A tracker for NLAs affected by a reduction. Performs the necessary
129/// cleanup steps in order to maintain IR validity after the reduction has
130/// applied. For example, removing an instance that forms part of an NLA path
131/// requires that NLA to be removed as well.
133 /// Clear the set of marked NLAs. Call this before attempting a reduction.
134 void clear() { nlasToRemove.clear(); }
135
136 /// Remove all marked annotations. Call this after applying a reduction in
137 /// order to validate the IR.
138 void remove(mlir::ModuleOp module) {
139 unsigned numRemoved = 0;
140 (void)numRemoved;
141 SymbolTableCollection symbolTables;
142 for (Operation &rootOp : *module.getBody()) {
143 if (!isa<firrtl::CircuitOp>(&rootOp))
144 continue;
145 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
146 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
147 for (auto sym : nlasToRemove) {
148 if (auto *op = symbolTable.lookup(sym)) {
149 if (symbolUserMap.useEmpty(op)) {
150 ++numRemoved;
151 op->erase();
152 }
153 }
154 }
155 }
156 LLVM_DEBUG({
157 unsigned numLost = nlasToRemove.size() - numRemoved;
158 if (numRemoved > 0 || numLost > 0) {
159 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
160 if (numLost > 0)
161 llvm::dbgs() << " (" << numLost << " no longer there)";
162 llvm::dbgs() << "\n";
163 }
164 });
165 }
166
167 /// Mark all NLAs referenced in the given annotation as to be removed. This
168 /// can be an entire array or dictionary of annotations, and the function will
169 /// descend into child annotations appropriately.
170 void markNLAsInAnnotation(Attribute anno) {
171 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
172 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
173 nlasToRemove.insert(field.getAttr());
174 for (auto namedAttr : dict)
175 markNLAsInAnnotation(namedAttr.getValue());
176 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
177 for (auto attr : array)
178 markNLAsInAnnotation(attr);
179 }
180 }
181
182 /// Mark all NLAs referenced in an operation. Also traverses all nested
183 /// operations. Call this before removing an operation, to mark any associated
184 /// NLAs as to be removed as well.
185 void markNLAsInOperation(Operation *op) {
186 op->walk([&](Operation *op) {
187 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
188 markNLAsInAnnotation(annos);
189 });
190 }
191
192 /// The set of NLAs to remove, identified by their symbol.
193 llvm::DenseSet<StringAttr> nlasToRemove;
194};
195
196//===----------------------------------------------------------------------===//
197// Reduction patterns
198//===----------------------------------------------------------------------===//
199
200namespace {
201
202/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
203struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
204 void beforeReduction(mlir::ModuleOp op) override {
205 nlaRemover.clear();
206 symbols.clear();
207 moduleSizes.clear();
208 innerSymUses = reduce::InnerSymbolUses(op);
209 }
210 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
211
212 uint64_t match(FModuleOp module) override {
213 if (innerSymUses.hasInnerRef(module))
214 return 0;
215 return moduleSizes.getModuleSize(module, symbols);
216 }
217
218 LogicalResult rewrite(FModuleOp module) override {
219 // Hack up a list of known layers.
220 LayerSet layers;
221 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
222 for (auto attr : module.getPortTypes()) {
223 auto type = cast<TypeAttr>(attr).getValue();
224 if (auto refType = type_dyn_cast<RefType>(type))
225 if (auto layer = refType.getLayer())
226 layers.insert(layer);
227 }
228 SmallVector<Attribute, 4> layersArray;
229 layersArray.reserve(layers.size());
230 for (auto layer : layers)
231 layersArray.push_back(layer);
232
233 nlaRemover.markNLAsInOperation(module);
234 OpBuilder builder(module);
235 auto extmodule = FExtModuleOp::create(
236 builder, module->getLoc(),
237 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
238 module.getConventionAttr(), module.getPorts(),
239 builder.getArrayAttr(layersArray), StringRef(),
240 module.getAnnotationsAttr());
241 SymbolTable::setSymbolVisibility(extmodule,
242 SymbolTable::getSymbolVisibility(module));
243 module->erase();
244 return success();
245 }
246
247 std::string getName() const override { return "firrtl-module-externalizer"; }
248
249 ::detail::SymbolCache symbols;
250 NLARemover nlaRemover;
251 reduce::InnerSymbolUses innerSymUses;
252 ModuleSizeCache moduleSizes;
253};
254
255/// Invalidate all the leaf fields of a value with a given flippedness by
256/// connecting an invalid value to them. This function handles different FIRRTL
257/// types appropriately:
258/// - Ref types (probes): Creates wire infrastructure with ref.send/ref.define
259/// and invalidates the underlying wire.
260/// - Bundle/Vector types: Recursively descends into elements.
261/// - Base types: Creates InvalidValueOp and connects it.
262/// - Property types: Creates UnknownValueOp and assigns it.
263///
264/// This is useful for ensuring that all output ports of an instance or memory
265/// (including those nested in bundles) are properly invalidated.
266static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
267 TieOffCache &tieOffCache, bool flip = false) {
268 auto type = type_dyn_cast<FIRRTLType>(value.getType());
269 if (!type)
270 return;
271
272 // Handle ref types (probes) by creating wires and defining them properly.
273 if (auto refType = type_dyn_cast<RefType>(type)) {
274 // Input probes are illegal in FIRRTL.
275 assert(!flip && "input probes are not allowed");
276
277 auto underlyingType = refType.getType();
278
279 if (!refType.getForceable()) {
280 // For probe types: create underlying wire, ref.send, ref.define, and
281 // invalidate.
282 auto targetWire = WireOp::create(builder, underlyingType);
283 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
284 builder.create<RefDefineOp>(value, refSend.getResult());
285
286 // Invalidate the underlying wire.
287 auto invalid = tieOffCache.getInvalid(underlyingType);
288 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
289 return;
290 }
291
292 // For rwprobe types: create forceable wire, ref.define, and invalidate.
293 auto forceableWire =
294 WireOp::create(builder, underlyingType,
295 /*name=*/"", NameKindEnum::DroppableName,
296 /*annotations=*/ArrayRef<Attribute>{},
297 /*innerSym=*/StringAttr{},
298 /*forceable=*/true);
299
300 // The forceable wire returns both the wire and the rwprobe.
301 auto targetWire = forceableWire.getResult();
302 auto forceableRef = forceableWire.getDataRef();
303
304 builder.create<RefDefineOp>(value, forceableRef);
305
306 // Invalidate the underlying wire.
307 auto invalid = tieOffCache.getInvalid(underlyingType);
308 MatchingConnectOp::create(builder, targetWire, invalid);
309 return;
310 }
311
312 // Descend into bundles by creating subfield ops.
313 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
314 for (auto element : llvm::enumerate(bundleType.getElements())) {
315 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
316 invalidateOutputs(builder, subfield, tieOffCache,
317 flip ^ element.value().isFlip);
318 if (subfield.use_empty())
319 subfield.getDefiningOp()->erase();
320 }
321 return;
322 }
323
324 // Descend into vectors by creating subindex ops.
325 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
326 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
327 auto subindex = builder.createOrFold<SubindexOp>(value, i);
328 invalidateOutputs(builder, subindex, tieOffCache, flip);
329 if (subindex.use_empty())
330 subindex.getDefiningOp()->erase();
331 }
332 return;
333 }
334
335 // Only drive outputs.
336 if (flip)
337 return;
338
339 // Create InvalidValueOp for FIRRTLBaseType.
340 if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
341 auto invalid = tieOffCache.getInvalid(baseType);
342 ConnectOp::create(builder, value, invalid);
343 return;
344 }
345
346 // For property types, use UnknownValueOp to tie off the connection.
347 if (auto propType = type_dyn_cast<PropertyType>(type)) {
348 auto unknown = tieOffCache.getUnknown(propType);
349 builder.create<PropAssignOp>(value, unknown);
350 }
351}
352
353/// Connect a value to every leave of a destination value.
354static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
355 Value value) {
356 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
357 if (!type)
358 return;
359 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
360 for (auto element : llvm::enumerate(bundleType.getElements()))
361 connectToLeafs(builder,
362 firrtl::SubfieldOp::create(builder, dest, element.index()),
363 value);
364 return;
365 }
366 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
367 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
368 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
369 value);
370 return;
371 }
372 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
373 if (!valueType)
374 return;
375 auto destWidth = type.getBitWidthOrSentinel();
376 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
377 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
378 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
379 if (!isa<firrtl::UIntType>(type)) {
380 if (isa<firrtl::SIntType>(type))
381 value = firrtl::AsSIntPrimOp::create(builder, value);
382 else
383 return;
384 }
385 firrtl::ConnectOp::create(builder, dest, value);
386}
387
388/// Reduce all leaf fields of a value through an XOR tree.
389static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
390 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
391 if (!type)
392 return;
393 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
394 for (auto element : llvm::enumerate(bundleType.getElements()))
395 reduceXor(
396 builder, into,
397 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
398 return;
399 }
400 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
401 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
402 reduceXor(builder, into,
403 builder.createOrFold<firrtl::SubindexOp>(value, i));
404 return;
405 }
406 if (!isa<firrtl::UIntType>(type)) {
407 if (isa<firrtl::SIntType>(type))
408 value = firrtl::AsUIntPrimOp::create(builder, value);
409 else
410 return;
411 }
412 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
413}
414
415/// A sample reduction pattern that maps `firrtl.instance` to a set of
416/// invalidated wires. This often shortcuts a long iterative process of connect
417/// invalidation, module externalization, and wire stripping
418struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
419 void beforeReduction(mlir::ModuleOp op) override {
420 erasedInsts.clear();
421 erasedModules.clear();
422 symbols.clear();
423 nlaRemover.clear();
424 moduleSizes.clear();
425 }
426 void afterReduction(mlir::ModuleOp op) override {
427 // Look into deleted modules to find additional instances that are no longer
428 // instantiated anywhere.
429 SmallVector<Operation *> worklist;
430 auto deadInsts = erasedInsts;
431 for (auto *op : erasedModules)
432 worklist.push_back(op);
433 while (!worklist.empty()) {
434 auto *op = worklist.pop_back_val();
435 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
436 op->walk([&](firrtl::InstanceOp instOp) {
437 auto moduleOp = cast<firrtl::FModuleLike>(
438 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
439 deadInsts.insert(instOp);
440 if (llvm::all_of(
441 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
442 [&](Operation *user) { return deadInsts.contains(user); })) {
443 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
444 << moduleOp.getModuleName() << "`\n");
445 erasedModules.insert(moduleOp);
446 worklist.push_back(moduleOp);
447 }
448 });
449 }
450
451 for (auto *op : erasedInsts)
452 op->erase();
453 for (auto *op : erasedModules)
454 op->erase();
455 nlaRemover.remove(op);
456 }
457
458 uint64_t match(firrtl::InstanceOp instOp) override {
459 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
460 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
461 return 0;
462 }
463
464 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
465 LLVM_DEBUG(llvm::dbgs()
466 << "Stubbing instance `" << instOp.getName() << "`\n");
467 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
468 TieOffCache tieOffCache(builder);
469 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
470 auto result = instOp.getResult(i);
471 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
472 instOp.getPortName(i));
473 auto wire =
474 firrtl::WireOp::create(builder, result.getType(), name,
475 firrtl::NameKindEnum::DroppableName,
476 instOp.getPortAnnotation(i), StringAttr{})
477 .getResult();
478 invalidateOutputs(builder, wire, tieOffCache,
479 instOp.getPortDirection(i) == firrtl::Direction::In);
480 result.replaceAllUsesWith(wire);
481 }
482 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
483 auto moduleOp = cast<firrtl::FModuleLike>(
484 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
485 nlaRemover.markNLAsInOperation(instOp);
486 erasedInsts.insert(instOp);
487 if (llvm::all_of(
488 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
489 [&](Operation *user) { return erasedInsts.contains(user); })) {
490 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
491 << moduleOp.getModuleName() << "`\n");
492 erasedModules.insert(moduleOp);
493 }
494 return success();
495 }
496
497 std::string getName() const override { return "instance-stubber"; }
498 bool acceptSizeIncrease() const override { return true; }
499
500 ::detail::SymbolCache symbols;
501 NLARemover nlaRemover;
502 llvm::DenseSet<Operation *> erasedInsts;
503 llvm::DenseSet<Operation *> erasedModules;
504 ModuleSizeCache moduleSizes;
505};
506
507/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
508/// wires.
509struct MemoryStubber : public OpReduction<firrtl::MemOp> {
510 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
511 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
512 LogicalResult rewrite(firrtl::MemOp memOp) override {
513 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
514 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
515 TieOffCache tieOffCache(builder);
516 Value xorInputs;
517 SmallVector<Value> outputs;
518 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
519 auto result = memOp.getResult(i);
520 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
521 memOp.getPortName(i));
522 auto wire =
523 firrtl::WireOp::create(builder, result.getType(), name,
524 firrtl::NameKindEnum::DroppableName,
525 memOp.getPortAnnotation(i), StringAttr{})
526 .getResult();
527 invalidateOutputs(builder, wire, tieOffCache, true);
528 result.replaceAllUsesWith(wire);
529
530 // Isolate the input and output data fields of the port.
531 Value input, output;
532 switch (memOp.getPortKind(i)) {
533 case firrtl::MemOp::PortKind::Read:
534 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
535 break;
536 case firrtl::MemOp::PortKind::Write:
537 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
538 break;
539 case firrtl::MemOp::PortKind::ReadWrite:
540 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
541 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
542 break;
543 case firrtl::MemOp::PortKind::Debug:
544 output = wire;
545 break;
546 }
547
548 if (!isa<firrtl::RefType>(result.getType())) {
549 // Reduce all input ports to a single one through an XOR tree.
550 unsigned numFields =
551 cast<firrtl::BundleType>(wire.getType()).getNumElements();
552 for (unsigned i = 0; i != numFields; ++i) {
553 if (i != 2 && i != 3 && i != 5)
554 reduceXor(builder, xorInputs,
555 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
556 }
557 if (input)
558 reduceXor(builder, xorInputs, input);
559 }
560
561 // Track the output port to hook it up to the XORd input later.
562 if (output)
563 outputs.push_back(output);
564 }
565
566 // Hook up the outputs.
567 for (auto output : outputs)
568 connectToLeafs(builder, output, xorInputs);
569
570 nlaRemover.markNLAsInOperation(memOp);
571 memOp->erase();
572 return success();
573 }
574 std::string getName() const override { return "memory-stubber"; }
575 bool acceptSizeIncrease() const override { return true; }
576 NLARemover nlaRemover;
577};
578
579/// Check whether an operation interacts with flows in any way, which can make
580/// replacement and operand forwarding harder in some cases.
581static bool isFlowSensitiveOp(Operation *op) {
582 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
583 SubaccessOp, ObjectSubfieldOp>(op);
584}
585
586/// A sample reduction pattern that replaces all uses of an operation with one
587/// of its operands. This can help pruning large parts of the expression tree
588/// rapidly.
589template <unsigned OpNum>
590struct FIRRTLOperandForwarder : public Reduction {
591 uint64_t match(Operation *op) override {
592 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
593 return 0;
594 if (isFlowSensitiveOp(op))
595 return 0;
596 auto resultTy =
597 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
598 auto opTy =
599 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
600 return resultTy && opTy &&
601 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
602 (resultTy.getBitWidthOrSentinel() == -1) ==
603 (opTy.getBitWidthOrSentinel() == -1) &&
604 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
605 }
606 LogicalResult rewrite(Operation *op) override {
607 assert(match(op));
608 ImplicitLocOpBuilder builder(op->getLoc(), op);
609 auto result = op->getResult(0);
610 auto operand = op->getOperand(OpNum);
611 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
612 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
613 auto resultWidth = resultTy.getBitWidthOrSentinel();
614 auto operandWidth = operandTy.getBitWidthOrSentinel();
615 Value newOp;
616 if (resultWidth < operandWidth)
617 newOp =
618 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
619 else if (resultWidth > operandWidth)
620 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
621 else
622 newOp = operand;
623 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
624 result.replaceAllUsesWith(newOp);
625 reduce::pruneUnusedOps(op, *this);
626 return success();
627 }
628 std::string getName() const override {
629 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
630 }
631};
632
633/// A sample reduction pattern that replaces FIRRTL operations with a constant
634/// zero of their type.
635struct Constantifier : public Reduction {
636 void beforeReduction(mlir::ModuleOp op) override {
637 symbols.clear();
638
639 // Find valid dummy classes that we can use for anyref casts.
640 anyrefCastDummy.clear();
641 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
642 for (auto classOp : circuitOp.getOps<ClassOp>()) {
643 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
644 anyrefCastDummy.insert({circuitOp, classOp});
645 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
646 }
647 }
648 return WalkResult::skip();
649 });
650 }
651
652 uint64_t match(Operation *op) override {
653 if (op->hasTrait<OpTrait::ConstantLike>()) {
654 Attribute attr;
655 if (!matchPattern(op, m_Constant(&attr)))
656 return 0;
657 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
658 if (intAttr.getValue().isZero())
659 return 0;
660 if (auto strAttr = dyn_cast<StringAttr>(attr))
661 if (strAttr.empty())
662 return 0;
663 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
664 if (floatAttr.getValue().isZero())
665 return 0;
666 }
667 if (auto listOp = dyn_cast<ListCreateOp>(op))
668 if (listOp.getElements().empty())
669 return 0;
670 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
671 if (pathOp.getTarget().empty())
672 return 0;
673
674 // Don't replace anyref casts that already target a dummy class.
675 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
676 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
677 auto className =
678 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
679 if (anyrefCastDummyNames[circuitOp].contains(className))
680 return 0;
681 }
682
683 if (op->getNumResults() != 1)
684 return 0;
685 if (op->hasAttr("inner_sym"))
686 return 0;
687 if (isFlowSensitiveOp(op))
688 return 0;
689 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
690 DoubleType, ListType, PathType, AnyRefType>(
691 op->getResult(0).getType());
692 }
693
694 LogicalResult rewrite(Operation *op) override {
695 OpBuilder builder(op);
696 auto type = op->getResult(0).getType();
697
698 // Handle UInt/SInt types.
699 if (isa<UIntType, SIntType>(type)) {
700 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
701 if (width == -1)
702 width = 64;
703 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
704 APSInt(width, isa<UIntType>(type)));
705 op->replaceAllUsesWith(newOp);
706 reduce::pruneUnusedOps(op, *this);
707 return success();
708 }
709
710 // Handle property string types.
711 if (isa<StringType>(type)) {
712 auto attr = builder.getStringAttr("");
713 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
714 op->replaceAllUsesWith(newOp);
715 reduce::pruneUnusedOps(op, *this);
716 return success();
717 }
718
719 // Handle property integer types.
720 if (isa<FIntegerType>(type)) {
721 auto attr = builder.getIntegerAttr(builder.getIntegerType(64, true), 0);
722 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
723 op->replaceAllUsesWith(newOp);
724 reduce::pruneUnusedOps(op, *this);
725 return success();
726 }
727
728 // Handle property boolean types.
729 if (isa<BoolType>(type)) {
730 auto attr = builder.getBoolAttr(false);
731 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
732 op->replaceAllUsesWith(newOp);
733 reduce::pruneUnusedOps(op, *this);
734 return success();
735 }
736
737 // Handle property double types.
738 if (isa<DoubleType>(type)) {
739 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
740 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
741 op->replaceAllUsesWith(newOp);
742 reduce::pruneUnusedOps(op, *this);
743 return success();
744 }
745
746 // Handle property list types.
747 if (isa<ListType>(type)) {
748 auto newOp =
749 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
750 op->replaceAllUsesWith(newOp);
751 reduce::pruneUnusedOps(op, *this);
752 return success();
753 }
754
755 // Handle property path types.
756 if (isa<PathType>(type)) {
757 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
758 op->replaceAllUsesWith(newOp);
759 reduce::pruneUnusedOps(op, *this);
760 return success();
761 }
762
763 // Handle anyref types.
764 if (isa<AnyRefType>(type)) {
765 auto circuitOp = op->getParentOfType<CircuitOp>();
766 auto &dummy = anyrefCastDummy[circuitOp];
767 if (!dummy) {
768 OpBuilder::InsertionGuard guard(builder);
769 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
770 auto &symbolTable = symbols.getNearestSymbolTable(op);
771 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
772 symbolTable.insert(dummy);
773 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
774 }
775 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
776 auto anyrefOp =
777 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
778 op->replaceAllUsesWith(anyrefOp);
779 reduce::pruneUnusedOps(op, *this);
780 return success();
781 }
782
783 return failure();
784 }
785
786 std::string getName() const override { return "firrtl-constantifier"; }
787 bool acceptSizeIncrease() const override { return true; }
788
789 ::detail::SymbolCache symbols;
791 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
792};
793
794/// A sample reduction pattern that replaces the right-hand-side of
795/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
796/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
797/// connects and creates opportunities for reduction in DCE/CSE.
798struct ConnectInvalidator : public Reduction {
799 uint64_t match(Operation *op) override {
800 if (!isa<FConnectLike>(op))
801 return 0;
802 if (auto *srcOp = op->getOperand(1).getDefiningOp())
803 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
804 isa<InvalidValueOp>(srcOp))
805 return 0;
806 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
807 return type && type.isPassive();
808 }
809 LogicalResult rewrite(Operation *op) override {
810 assert(match(op));
811 auto rhs = op->getOperand(1);
812 OpBuilder builder(op);
813 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
814 auto *rhsOp = rhs.getDefiningOp();
815 op->setOperand(1, invOp);
816 if (rhsOp)
817 reduce::pruneUnusedOps(rhsOp, *this);
818 return success();
819 }
820 std::string getName() const override { return "connect-invalidator"; }
821 bool acceptSizeIncrease() const override { return true; }
822};
823
824/// A reduction pattern that removes FIRRTL annotations from ports and
825/// operations. This generates one match per annotation and port annotation,
826/// allowing selective removal of individual annotations.
827struct AnnotationRemover : public Reduction {
828 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
829 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
830
831 void matches(Operation *op,
832 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
833 uint64_t matchId = 0;
834
835 // Generate matches for regular annotations
836 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
837 for (unsigned i = 0; i < annos.size(); ++i)
838 addMatch(1, matchId++);
839
840 // Generate matches for port annotations
841 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
842 for (auto portAnnoArray : portAnnos)
843 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
844 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
845 addMatch(1, matchId++);
846 }
847
848 LogicalResult rewriteMatches(Operation *op,
849 ArrayRef<uint64_t> matches) override {
850 // Convert matches to a set for fast lookup
851 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
852
853 // Lambda to process annotations and filter out matched ones
854 uint64_t matchId = 0;
855 auto processAnnotations =
856 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
857 SmallVector<Attribute> newAnnotations;
858 for (auto anno : annotations) {
859 if (!matchesSet.contains(matchId)) {
860 newAnnotations.push_back(anno);
861 } else {
862 // Mark NLAs in the removed annotation for cleanup
863 nlaRemover.markNLAsInAnnotation(anno);
864 }
865 matchId++;
866 }
867 return ArrayAttr::get(op->getContext(), newAnnotations);
868 };
869
870 // Remove regular annotations
871 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
872 op->setAttr("annotations", processAnnotations(annos.getValue()));
873 }
874
875 // Remove port annotations
876 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
877 SmallVector<Attribute> newPortAnnos;
878 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
879 newPortAnnos.push_back(
880 processAnnotations(portAnnoArrayAttr.getValue()));
881 }
882 op->setAttr("portAnnotations",
883 ArrayAttr::get(op->getContext(), newPortAnnos));
884 }
885
886 return success();
887 }
888
889 std::string getName() const override { return "annotation-remover"; }
890 NLARemover nlaRemover;
891};
892
893/// A reduction pattern that replaces ResetType with UInt<1> across an entire
894/// circuit. This walks all operations in the circuit and replaces ResetType in
895/// results, block arguments, and attributes.
896struct SimplifyResets : public OpReduction<CircuitOp> {
897 uint64_t match(CircuitOp circuit) override {
898 uint64_t numResets = 0;
899 AttrTypeWalker walker;
900 walker.addWalk([&](ResetType type) { ++numResets; });
901
902 circuit.walk([&](Operation *op) {
903 for (auto result : op->getResults())
904 walker.walk(result.getType());
905
906 for (auto &region : op->getRegions())
907 for (auto &block : region)
908 for (auto arg : block.getArguments())
909 walker.walk(arg.getType());
910
911 walker.walk(op->getAttrDictionary());
912 });
913
914 return numResets;
915 }
916
917 LogicalResult rewrite(CircuitOp circuit) override {
918 auto uint1Type = UIntType::get(circuit->getContext(), 1, false);
919 auto constUint1Type = UIntType::get(circuit->getContext(), 1, true);
920
921 AttrTypeReplacer replacer;
922 replacer.addReplacement([&](ResetType type) {
923 return type.isConst() ? constUint1Type : uint1Type;
924 });
925 replacer.recursivelyReplaceElementsIn(circuit, /*replaceAttrs=*/true,
926 /*replaceLocs=*/false,
927 /*replaceTypes=*/true);
928
929 // Remove annotations related to InferResets pass
930 circuit.walk([&](Operation *op) {
931 // Remove operation annotations
933 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
934 fullAsyncResetAnnoClass,
935 ignoreFullAsyncResetAnnoClass);
936 });
937
938 // Remove port annotations for module-like operations
939 if (auto module = dyn_cast<FModuleLike>(op)) {
940 AnnotationSet::removePortAnnotations(module, [&](unsigned portIdx,
941 Annotation anno) {
942 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
943 fullAsyncResetAnnoClass,
944 ignoreFullAsyncResetAnnoClass);
945 });
946 }
947 });
948
949 return success();
950 }
951
952 std::string getName() const override { return "firrtl-simplify-resets"; }
953 bool acceptSizeIncrease() const override { return true; }
954};
955
956/// A sample reduction pattern that removes ports from the root `firrtl.module`
957/// if the port is not used or just invalidated.
958struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
959 void matches(firrtl::FModuleOp module,
960 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
961 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
962 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
963 return;
964
965 // Generate one match per port that can be removed
966 size_t numPorts = module.getNumPorts();
967 for (unsigned i = 0; i != numPorts; ++i) {
968 if (onlyInvalidated(module.getArgument(i)))
969 addMatch(1, i);
970 }
971 }
972
973 LogicalResult rewriteMatches(firrtl::FModuleOp module,
974 ArrayRef<uint64_t> matches) override {
975 // Build a BitVector of ports to remove
976 llvm::BitVector dropPorts(module.getNumPorts());
977 for (auto portIdx : matches)
978 dropPorts.set(portIdx);
979
980 // Erase users of the ports being removed
981 for (auto portIdx : matches) {
982 for (auto *user :
983 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
984 user->erase();
985 }
986
987 // Remove the ports from the module
988 module.erasePorts(dropPorts);
989 return success();
990 }
991
992 std::string getName() const override { return "root-port-pruner"; }
993};
994
995/// A reduction pattern that removes all ports from the root `firrtl.extmodule`.
996/// Since extmodules have no body, all ports can be safely removed for reduction
997/// purposes.
998struct RootExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
999 void matches(firrtl::FExtModuleOp module,
1000 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1001 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
1002 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1003 return;
1004
1005 // Generate one match per port (all ports can be removed from root
1006 // extmodule)
1007 size_t numPorts = module.getNumPorts();
1008 for (unsigned i = 0; i != numPorts; ++i)
1009 addMatch(1, i);
1010 }
1011
1012 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1013 ArrayRef<uint64_t> matches) override {
1014 if (matches.empty())
1015 return failure();
1016
1017 // Build a BitVector of ports to remove
1018 llvm::BitVector dropPorts(module.getNumPorts());
1019 for (auto portIdx : matches)
1020 dropPorts.set(portIdx);
1021
1022 // Remove the ports from the module
1023 module.erasePorts(dropPorts);
1024 return success();
1025 }
1026
1027 std::string getName() const override { return "root-extmodule-port-pruner"; }
1028};
1029
1030/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
1031/// with wires.
1032struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
1033 void beforeReduction(mlir::ModuleOp op) override {
1034 symbols.clear();
1035 nlaRemover.clear();
1036 }
1037 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1038
1039 uint64_t match(firrtl::InstanceOp instOp) override {
1040 return isa<firrtl::FExtModuleOp>(
1041 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1042 }
1043 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
1044 auto portInfo =
1045 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1046 symbols.getNearestSymbolTable(instOp)))
1047 .getPorts();
1048 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1049 TieOffCache tieOffCache(builder);
1050 SmallVector<Value> replacementWires;
1051 for (firrtl::PortInfo info : portInfo) {
1052 auto wire = firrtl::WireOp::create(
1053 builder, info.type,
1054 (Twine(instOp.getName()) + "_" + info.getName()).str())
1055 .getResult();
1056 if (info.isOutput()) {
1057 // Tie off output ports using TieOffCache.
1058 if (auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(info.type)) {
1059 auto inv = tieOffCache.getInvalid(baseType);
1060 firrtl::ConnectOp::create(builder, wire, inv);
1061 } else if (auto propType = dyn_cast<firrtl::PropertyType>(info.type)) {
1062 auto unknown = tieOffCache.getUnknown(propType);
1063 builder.create<firrtl::PropAssignOp>(wire, unknown);
1064 }
1065 }
1066 replacementWires.push_back(wire);
1067 }
1068 nlaRemover.markNLAsInOperation(instOp);
1069 instOp.replaceAllUsesWith(std::move(replacementWires));
1070 instOp->erase();
1071 return success();
1072 }
1073 std::string getName() const override { return "extmodule-instance-remover"; }
1074 bool acceptSizeIncrease() const override { return true; }
1075
1076 ::detail::SymbolCache symbols;
1077 NLARemover nlaRemover;
1078};
1079
1080/// A reduction pattern that removes unused ports from extmodules and regular
1081/// modules. This is particularly useful for reducing test cases with many probe
1082/// ports or other unused ports.
1083///
1084/// Shared helper functions for port pruning reductions.
1085struct PortPrunerHelpers {
1086 /// Compute which ports are unused across all instances of a module.
1087 template <typename ModuleOpType>
1088 static void computeUnusedInstancePorts(ModuleOpType module,
1089 ArrayRef<Operation *> users,
1090 llvm::BitVector &portsToRemove) {
1091 auto ports = module.getPorts();
1092 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1093 bool portUsed = false;
1094 for (auto *user : users) {
1095 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1096 auto result = instOp.getResult(portIdx);
1097 if (!result.use_empty()) {
1098 portUsed = true;
1099 break;
1100 }
1101 }
1102 }
1103 if (!portUsed)
1104 portsToRemove.set(portIdx);
1105 }
1106 }
1107
1108 /// Update all instances of a module to remove the specified ports.
1109 static void
1110 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1111 const llvm::BitVector &portsToRemove) {
1112 // Update all instances to remove the corresponding results
1113 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1114 for (auto *user : users) {
1115 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1116 instancesToUpdate.push_back(instOp);
1117 }
1118
1119 for (auto instOp : instancesToUpdate) {
1120 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1121
1122 // Manually replace uses, skipping erased ports
1123 size_t newResultIdx = 0;
1124 for (size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1125 ++oldResultIdx) {
1126 if (portsToRemove[oldResultIdx]) {
1127 // This port is being removed, assert it has no uses
1128 assert(instOp.getResult(oldResultIdx).use_empty() &&
1129 "removing port with uses");
1130 } else {
1131 // Replace uses of the old result with the new result
1132 instOp.getResult(oldResultIdx)
1133 .replaceAllUsesWith(newInst->getResult(newResultIdx));
1134 ++newResultIdx;
1135 }
1136 }
1137
1138 instOp->erase();
1139 }
1140 }
1141};
1142
1143/// Reduction to remove unused ports from regular modules.
1144struct ModulePortPruner : public OpReduction<firrtl::FModuleOp> {
1145 void beforeReduction(mlir::ModuleOp op) override {
1146 symbols.clear();
1147 nlaRemover.clear();
1148 }
1149 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1150
1151 void matches(firrtl::FModuleOp module,
1152 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1153 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1154 auto &userMap = symbols.getSymbolUserMap(tableOp);
1155 auto ports = module.getPorts();
1156 auto users = userMap.getUsers(module);
1157
1158 // Compute which ports can be removed. A port can only be removed if it
1159 // is unused in both the module body and across all instances.
1160 llvm::BitVector portsToRemove(ports.size());
1161
1162 // Check if ports are unused across all instances.
1163 if (!users.empty())
1164 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1165 portsToRemove);
1166 else
1167 // If there are no instances, all ports are candidates for removal.
1168 portsToRemove.set();
1169
1170 // Additionally check if ports are unused within the module body itself.
1171 // A port must be unused in both instances and the module body to be
1172 // removable.
1173 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1174 if (!portsToRemove[portIdx])
1175 continue;
1176 if (!module.getArgument(portIdx).use_empty())
1177 portsToRemove.reset(portIdx);
1178 }
1179
1180 // Generate one match per removable port.
1181 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1182 if (portsToRemove[portIdx])
1183 addMatch(1, portIdx);
1184 }
1185
1186 LogicalResult rewriteMatches(firrtl::FModuleOp module,
1187 ArrayRef<uint64_t> matches) override {
1188 if (matches.empty())
1189 return failure();
1190
1191 // Build a BitVector of ports to remove
1192 llvm::BitVector portsToRemove(module.getNumPorts());
1193 for (auto portIdx : matches)
1194 portsToRemove.set(portIdx);
1195
1196 // Get users for updating instances
1197 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1198 auto &userMap = symbols.getSymbolUserMap(tableOp);
1199 auto users = userMap.getUsers(module);
1200
1201 // Update all instances
1202 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1203 portsToRemove);
1204
1205 // Remove the ports from the module. We don't need to erase users because
1206 // matches() already ensured that these ports have no users.
1207 module.erasePorts(portsToRemove);
1208
1209 return success();
1210 }
1211
1212 std::string getName() const override { return "module-port-pruner"; }
1213
1214 ::detail::SymbolCache symbols;
1215 NLARemover nlaRemover;
1216};
1217
1218/// Reduction to remove unused ports from extmodules.
1219struct ExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
1220 void beforeReduction(mlir::ModuleOp op) override {
1221 symbols.clear();
1222 nlaRemover.clear();
1223 }
1224 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1225
1226 void matches(firrtl::FExtModuleOp module,
1227 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1228 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1229 auto &userMap = symbols.getSymbolUserMap(tableOp);
1230 auto ports = module.getPorts();
1231 auto users = userMap.getUsers(module);
1232
1233 // Compute which ports can be removed
1234 llvm::BitVector portsToRemove(ports.size());
1235
1236 if (users.empty()) {
1237 // If the extmodule has no instances, aggressively remove all ports
1238 portsToRemove.set();
1239 } else {
1240 // For extmodules with instances, check if ports are unused across all
1241 // instances
1242 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1243 portsToRemove);
1244 }
1245
1246 // Generate one match per removable port
1247 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1248 if (portsToRemove[portIdx])
1249 addMatch(1, portIdx);
1250 }
1251
1252 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1253 ArrayRef<uint64_t> matches) override {
1254 if (matches.empty())
1255 return failure();
1256
1257 // Build a BitVector of ports to remove
1258 llvm::BitVector portsToRemove(module.getNumPorts());
1259 for (auto portIdx : matches)
1260 portsToRemove.set(portIdx);
1261
1262 // Get users for updating instances
1263 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1264 auto &userMap = symbols.getSymbolUserMap(tableOp);
1265 auto users = userMap.getUsers(module);
1266
1267 // Update all instances.
1268 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1269 portsToRemove);
1270
1271 // Remove the ports from the module (no body to clean up for extmodules).
1272 module.erasePorts(portsToRemove);
1273
1274 return success();
1275 }
1276
1277 std::string getName() const override { return "extmodule-port-pruner"; }
1278
1279 ::detail::SymbolCache symbols;
1280 NLARemover nlaRemover;
1281};
1282
1283/// A sample reduction pattern that pushes connected values through wires.
1284struct ConnectForwarder : public Reduction {
1285 void beforeReduction(mlir::ModuleOp op) override {
1286 domInfo = std::make_unique<DominanceInfo>(op);
1287 }
1288
1289 uint64_t match(Operation *op) override {
1290 if (!isa<firrtl::FConnectLike>(op))
1291 return 0;
1292 auto dest = op->getOperand(0);
1293 auto src = op->getOperand(1);
1294 auto *destOp = dest.getDefiningOp();
1295 auto *srcOp = src.getDefiningOp();
1296 if (dest == src)
1297 return 0;
1298
1299 // Ensure that the destination is something we should be able to forward
1300 // through.
1301 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1302 destOp))
1303 return 0;
1304
1305 // Ensure that the destination is connected to only once, and all uses of
1306 // the connection occur after the definition of the source.
1307 unsigned numConnects = 0;
1308 for (auto &use : dest.getUses()) {
1309 auto *op = use.getOwner();
1310 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1311 if (++numConnects > 1)
1312 return 0;
1313 continue;
1314 }
1315 // Check if srcOp properly dominates op, but op is not enclosed in srcOp.
1316 // This handles cross-block cases (e.g., layerblocks).
1317 if (srcOp &&
1318 !domInfo->properlyDominates(srcOp, op, /*enclosingOpOk=*/false))
1319 return 0;
1320 }
1321
1322 return 1;
1323 }
1324
1325 LogicalResult rewrite(Operation *op) override {
1326 auto dst = op->getOperand(0);
1327 auto src = op->getOperand(1);
1328 dst.replaceAllUsesExcept(src, op);
1329 op->erase();
1330 SmallVector<Operation *> worklist(
1331 {dst.getDefiningOp(), src.getDefiningOp()});
1332 reduce::pruneUnusedOps(worklist, *this);
1333 return success();
1334 }
1335
1336 std::string getName() const override { return "connect-forwarder"; }
1337
1338private:
1339 std::unique_ptr<DominanceInfo> domInfo;
1340};
1341
1342/// A sample reduction pattern that replaces a single-use wire and register with
1343/// an operand of the source value of the connection.
1344template <unsigned OpNum>
1345struct ConnectSourceOperandForwarder : public Reduction {
1346 uint64_t match(Operation *op) override {
1347 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1348 return 0;
1349 auto dest = op->getOperand(0);
1350 auto *destOp = dest.getDefiningOp();
1351
1352 // Ensure that the destination is used only once.
1353 if (!destOp || !destOp->hasOneUse() ||
1354 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1355 return 0;
1356
1357 auto *srcOp = op->getOperand(1).getDefiningOp();
1358 if (!srcOp || OpNum >= srcOp->getNumOperands())
1359 return 0;
1360
1361 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1362 auto opTy =
1363 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1364
1365 return resultTy && opTy &&
1366 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1367 ((resultTy.getBitWidthOrSentinel() == -1) ==
1368 (opTy.getBitWidthOrSentinel() == -1)) &&
1369 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1370 }
1371
1372 LogicalResult rewrite(Operation *op) override {
1373 auto *destOp = op->getOperand(0).getDefiningOp();
1374 auto *srcOp = op->getOperand(1).getDefiningOp();
1375 auto forwardedOperand = srcOp->getOperand(OpNum);
1376 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1377 Value newDest;
1378 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
1379 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1380 wire.getName())
1381 .getResult();
1382 else {
1383 auto regName = destOp->getAttrOfType<StringAttr>("name");
1384 // We can promote the register into a wire but we wouldn't do here because
1385 // the error might be caused by the register.
1386 auto clock = destOp->getOperand(0);
1387 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1388 clock, regName ? regName.str() : "")
1389 .getResult();
1390 }
1391
1392 // Create new connection between a new wire and the forwarded operand.
1393 builder.setInsertionPointAfter(op);
1394 if (isa<firrtl::ConnectOp>(op))
1395 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1396 else
1397 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1398
1399 // Remove the old connection and destination. We don't have to replace them
1400 // because destination has only one use.
1401 op->erase();
1402 destOp->erase();
1403 reduce::pruneUnusedOps(srcOp, *this);
1404
1405 return success();
1406 }
1407 std::string getName() const override {
1408 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1409 }
1410};
1411
1412/// A sample reduction pattern that tries to remove aggregate wires by replacing
1413/// all subaccesses with new independent wires. This can disentangle large
1414/// unused wires that are otherwise difficult to collect due to the subaccesses.
1415struct DetachSubaccesses : public Reduction {
1416 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1417 void afterReduction(mlir::ModuleOp op) override {
1418 for (auto *op : opsToErase)
1419 op->dropAllReferences();
1420 for (auto *op : opsToErase)
1421 op->erase();
1422 }
1423 uint64_t match(Operation *op) override {
1424 // Only applies to wires and registers that are purely used in subaccess
1425 // operations.
1426 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1427 llvm::all_of(op->getUses(), [](auto &use) {
1428 return use.getOperandNumber() == 0 &&
1429 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1430 firrtl::SubaccessOp>(use.getOwner());
1431 });
1432 }
1433 LogicalResult rewrite(Operation *op) override {
1434 assert(match(op));
1435 OpBuilder builder(op);
1436 bool isWire = isa<firrtl::WireOp>(op);
1437 Value invalidClock;
1438 if (!isWire)
1439 invalidClock = firrtl::InvalidValueOp::create(
1440 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1441 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1442 builder.setInsertionPoint(user);
1443 auto type = user->getResult(0).getType();
1444 Operation *replOp;
1445 if (isWire)
1446 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1447 else
1448 replOp =
1449 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1450 user->replaceAllUsesWith(replOp);
1451 opsToErase.insert(user);
1452 }
1453 opsToErase.insert(op);
1454 return success();
1455 }
1456 std::string getName() const override { return "detach-subaccesses"; }
1457 llvm::DenseSet<Operation *> opsToErase;
1458};
1459
1460/// This reduction removes inner symbols on ops. Name preservation creates a lot
1461/// of node ops with symbols to keep name information but it also prevents
1462/// normal canonicalizations.
1463struct NodeSymbolRemover : public Reduction {
1464 void beforeReduction(mlir::ModuleOp op) override {
1465 innerSymUses = reduce::InnerSymbolUses(op);
1466 }
1467
1468 uint64_t match(Operation *op) override {
1469 // Only match ops with an inner symbol.
1470 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1471 if (!sym || sym.empty())
1472 return 0;
1473
1474 // Only match ops that have no references to their inner symbol.
1475 if (innerSymUses.hasInnerRef(op))
1476 return 0;
1477 return 1;
1478 }
1479
1480 LogicalResult rewrite(Operation *op) override {
1481 op->removeAttr("inner_sym");
1482 return success();
1483 }
1484
1485 std::string getName() const override { return "node-symbol-remover"; }
1486 bool acceptSizeIncrease() const override { return true; }
1487
1488 reduce::InnerSymbolUses innerSymUses;
1489};
1490
1491/// Check if inlining the referenced operation into the parent operation would
1492/// cause inner symbol collisions.
1493static bool
1494hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1495 hw::InnerSymbolTableCollection &innerSymTables) {
1496 // Get the inner symbol tables for both operations
1497 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1498 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1499
1500 // Check if any inner symbol name in the target operation already exists
1501 // in the parent operation. Return failure() if a collision is found to stop
1502 // the walk early.
1503 LogicalResult walkResult = targetTable.walkSymbols(
1504 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1505 // Check if this symbol name exists in the parent operation
1506 if (parentTable.lookup(name)) {
1507 // Collision found, return failure to stop the walk
1508 return failure();
1509 }
1510 return success();
1511 });
1512
1513 // If the walk failed, it means we found a collision
1514 return failed(walkResult);
1515}
1516
1517/// A sample reduction pattern that eagerly inlines instances.
1518struct EagerInliner : public OpReduction<InstanceOp> {
1519 void beforeReduction(mlir::ModuleOp op) override {
1520 symbols.clear();
1521 nlaRemover.clear();
1522 nlaTables.clear();
1523 for (auto circuitOp : op.getOps<CircuitOp>())
1524 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1525 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1526 }
1527 void afterReduction(mlir::ModuleOp op) override {
1528 nlaRemover.remove(op);
1529 nlaTables.clear();
1530 innerSymTables.reset();
1531 }
1532
1533 uint64_t match(InstanceOp instOp) override {
1534 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1535 auto *moduleOp =
1536 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1537
1538 // Only inline FModuleOp instances
1539 if (!isa<FModuleOp>(moduleOp))
1540 return 0;
1541
1542 // Skip instances that participate in any NLAs
1543 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1544 if (!circuitOp)
1545 return 0;
1546 auto it = nlaTables.find(circuitOp);
1547 if (it == nlaTables.end() || !it->second)
1548 return 0;
1549 DenseSet<hw::HierPathOp> nlas;
1550 it->second->getInstanceNLAs(instOp, nlas);
1551 if (!nlas.empty())
1552 return 0;
1553
1554 // Check for inner symbol collisions between the referenced module and the
1555 // instance's parent module
1556 auto parentOp = instOp->getParentOfType<FModuleLike>();
1557 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1558 return 0;
1559
1560 return 1;
1561 }
1562
1563 LogicalResult rewrite(InstanceOp instOp) override {
1564 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1565 auto moduleOp = cast<FModuleOp>(
1566 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1567 bool isLastUse =
1568 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1569 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1570
1571 // Create wires to replace the instance results.
1572 IRRewriter rewriter(instOp);
1573 SmallVector<Value> argWires;
1574 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1575 auto result = instOp.getResult(i);
1576 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1577 instOp.getPortName(i));
1578 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1579 name, NameKindEnum::DroppableName,
1580 instOp.getPortAnnotation(i), StringAttr{})
1581 .getResult();
1582 result.replaceAllUsesWith(wire);
1583 argWires.push_back(wire);
1584 }
1585
1586 // Splice in the cloned module body.
1587 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1588
1589 // Make sure we remove any NLAs that go through this instance, and the
1590 // module if we're about the delete the module.
1591 nlaRemover.markNLAsInOperation(instOp);
1592 if (isLastUse)
1593 nlaRemover.markNLAsInOperation(moduleOp);
1594
1595 instOp.erase();
1596 clonedModuleOp.erase();
1597 return success();
1598 }
1599
1600 std::string getName() const override { return "firrtl-eager-inliner"; }
1601 bool acceptSizeIncrease() const override { return true; }
1602
1603 ::detail::SymbolCache symbols;
1604 NLARemover nlaRemover;
1605 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1606 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1607};
1608
1609/// A reduction pattern that eagerly inlines `ObjectOp`s.
1610struct ObjectInliner : public OpReduction<ObjectOp> {
1611 void beforeReduction(mlir::ModuleOp op) override {
1612 blocksToSort.clear();
1613 symbols.clear();
1614 nlaRemover.clear();
1615 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1616 }
1617 void afterReduction(mlir::ModuleOp op) override {
1618 for (auto *block : blocksToSort)
1619 mlir::sortTopologically(block);
1620 blocksToSort.clear();
1621 nlaRemover.remove(op);
1622 innerSymTables.reset();
1623 }
1624
1625 uint64_t match(ObjectOp objOp) override {
1626 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1627 auto *classOp =
1628 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1629
1630 // Only inline `ClassOp`s.
1631 if (!isa<ClassOp>(classOp))
1632 return 0;
1633
1634 // Check for inner symbol collisions between the referenced class and the
1635 // object's parent module.
1636 auto parentOp = objOp->getParentOfType<FModuleLike>();
1637 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1638 return 0;
1639
1640 // Verify all uses are ObjectSubfieldOp.
1641 for (auto *user : objOp.getResult().getUsers())
1642 if (!isa<ObjectSubfieldOp>(user))
1643 return 0;
1644
1645 return 1;
1646 }
1647
1648 LogicalResult rewrite(ObjectOp objOp) override {
1649 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1650 auto classOp = cast<ClassOp>(
1651 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1652 auto clonedClassOp = classOp.clone();
1653
1654 // Create wires to replace the ObjectSubfieldOp results.
1655 IRRewriter rewriter(objOp);
1656 SmallVector<Value> portWires;
1657 auto classType = objOp.getType();
1658
1659 // Create a wire for each port in the class
1660 for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1661 auto element = classType.getElement(i);
1662 auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
1663 element.name.getValue());
1664 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1665 NameKindEnum::DroppableName,
1666 rewriter.getArrayAttr({}), StringAttr{})
1667 .getResult();
1668 portWires.push_back(wire);
1669 }
1670
1671 // Replace all ObjectSubfieldOp uses with corresponding wires
1672 SmallVector<ObjectSubfieldOp> subfieldOps;
1673 for (auto *user : objOp.getResult().getUsers()) {
1674 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1675 subfieldOps.push_back(subfieldOp);
1676 auto index = subfieldOp.getIndex();
1677 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1678 }
1679
1680 // Splice in the cloned class body.
1681 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1682
1683 // After inlining the class body, we need to eliminate `WireOps` since
1684 // `ClassOps` cannot contain wires. For each port wire, find its single
1685 // connect, remove it, and replace all uses of the wire with the assigned
1686 // value.
1687 SmallVector<FConnectLike> connectsToErase;
1688 for (auto portWire : portWires) {
1689 // Find a single value to replace the wire with, and collect all connects
1690 // to the wire such that we can erase them later.
1691 Value value;
1692 for (auto *user : portWire.getUsers()) {
1693 if (auto connect = dyn_cast<FConnectLike>(user)) {
1694 if (connect.getDest() == portWire) {
1695 value = connect.getSrc();
1696 connectsToErase.push_back(connect);
1697 }
1698 }
1699 }
1700
1701 // Be very conservative about deleting these wires. Other reductions may
1702 // leave class ports unconnected, which means that there isn't always a
1703 // clean replacement available here. Better to just leave the wires in the
1704 // IR and let the verifier fail later.
1705 if (value)
1706 portWire.replaceAllUsesWith(value);
1707 for (auto connect : connectsToErase)
1708 connect.erase();
1709 if (portWire.use_empty())
1710 portWire.getDefiningOp()->erase();
1711 connectsToErase.clear();
1712 }
1713
1714 // Make sure we remove any NLAs that go through this object.
1715 nlaRemover.markNLAsInOperation(objOp);
1716
1717 // Since the above forwarding of SSA values through wires can create
1718 // dominance issues, mark the region containing the object to be sorted
1719 // topologically.
1720 blocksToSort.insert(objOp->getBlock());
1721
1722 // Erase the object and cloned class.
1723 for (auto subfieldOp : subfieldOps)
1724 subfieldOp.erase();
1725 objOp.erase();
1726 clonedClassOp.erase();
1727 return success();
1728 }
1729
1730 std::string getName() const override { return "firrtl-object-inliner"; }
1731 bool acceptSizeIncrease() const override { return true; }
1732
1733 SetVector<Block *> blocksToSort;
1734 ::detail::SymbolCache symbols;
1735 NLARemover nlaRemover;
1736 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1737};
1738
1739/// Reduction that converts `regreset` to `reg` by dropping reset and init
1740/// value.
1741struct ResetDisconnector : public OpReduction<RegResetOp> {
1742 uint64_t match(RegResetOp op) override { return 1; }
1743
1744 LogicalResult rewrite(RegResetOp regResetOp) override {
1745 ImplicitLocOpBuilder builder(regResetOp.getLoc(), regResetOp);
1746 auto regOp = RegOp::create(
1747 builder, regResetOp.getResult().getType(), regResetOp.getClockVal(),
1748 regResetOp.getNameAttr(), regResetOp.getNameKindAttr(),
1749 regResetOp.getAnnotationsAttr(), regResetOp.getInnerSymAttr(),
1750 regResetOp.getForceableAttr());
1751
1752 regResetOp.getResult().replaceAllUsesWith(regOp.getResult());
1753 if (regResetOp.getForceable())
1754 regResetOp.getRef().replaceAllUsesWith(regOp.getRef());
1755 regResetOp.erase();
1756
1757 return success();
1758 }
1759
1760 std::string getName() const override { return "reset-disconnector"; }
1761};
1762
1763/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1764/// not an actual reduction, but often removes extraneous information that has
1765/// no bearing on the actual reduction (and would likely be removed before
1766/// sharing the reduction). This makes the following changes:
1767///
1768/// - All wires are renamed to "wire"
1769/// - All registers are renamed to "reg"
1770/// - All nodes are renamed to "node"
1771/// - All memories are renamed to "mem"
1772/// - All verification messages and labels are dropped
1773///
1774struct ModuleInternalNameSanitizer : public Reduction {
1775 uint64_t match(Operation *op) override {
1776 // Only match operations with names.
1777 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1778 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1779 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1780 firrtl::CoverOp>(op);
1781 }
1782 LogicalResult rewrite(Operation *op) override {
1783 TypeSwitch<Operation *, void>(op)
1784 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1785 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1786 [](auto op) { op.setName("reg"); })
1787 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1788 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1789 [](auto op) { op.setName("mem"); })
1790 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1791 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1792 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1793 });
1794 return success();
1795 }
1796
1797 std::string getName() const override {
1798 return "module-internal-name-sanitizer";
1799 }
1800
1801 bool acceptSizeIncrease() const override { return true; }
1802
1803 bool isOneShot() const override { return true; }
1804};
1805
/// Pseudo-reduction that sanitizes module, instance, and port names. This
/// makes the following changes:
///
/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
/// - All instances are renamed to match the new module name
/// - The ports of all regular (non-external) modules are renamed as follows:
///   - Clocks are renamed to "clk"
///   - Abstract and asynchronous resets are renamed to "rst"
///   - References are renamed to "ref"
///   - Anything else gets a single-letter name ("a", "b", ...)
///
1817struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1818
1820 size_t portNameIndex = 0;
1821
1822 char getPortName() {
1823 if (portNameIndex >= 26)
1824 portNameIndex = 0;
1825 return 'a' + portNameIndex++;
1826 }
1827
1828 void beforeReduction(mlir::ModuleOp op) override { nameGenerator.reset(); }
1829
1830 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1831
1832 // Analyses used to aid the rewrite.
1833 firrtl::InstanceGraph iGraph(circuitOp);
1834 NLATable nlaTable(circuitOp);
1835 SymbolTable symTable(circuitOp);
1836 CircuitNamespace ns(circuitOp);
1837
1838 // Rename symbols and NLAs.
1839 auto renameModule = [&](firrtl::FModuleLike mod,
1840 StringAttr newName) -> LogicalResult {
1841 StringAttr oldName = mod.getModuleNameAttr();
1842 if (failed(symTable.rename(mod, newName)))
1843 return failure();
1844 nlaTable.renameModule(oldName, newName);
1845 return success();
1846 };
1847
    // Rename the top module first so that the circuit gets the first
    // metasyntactic name, i.e., "Foo".
1850 auto topModule = iGraph.getTopLevelModule();
1851 auto *ctx = circuitOp.getContext();
1853 topModule.getModuleName())) {
1854 auto newTopName = StringAttr::get(ctx, nameGenerator.getNextName(ns));
1855 if (failed(renameModule(topModule, newTopName)))
1856 return failure();
1857 circuitOp.setName(newTopName.getValue());
1858 }
1859
1860 for (auto *node : iGraph) {
1861 auto module = node->getModule<firrtl::FModuleLike>();
1862
1863 bool shouldReplacePorts = false;
1864 SmallVector<Attribute> newPortNames;
1865 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1866 portNameIndex = 0;
1867 // TODO: The namespace should be unnecessary. However, some FIRRTL
1868 // passes expect that port names are unique.
1870 auto oldPorts = fmodule.getPorts();
1871 shouldReplacePorts = !oldPorts.empty();
1872 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1873 auto port = oldPorts[i];
1874 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1875 .Case<firrtl::ClockType>(
1876 [&](auto a) { return ns.newName("clk"); })
1877 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1878 [&](auto a) { return ns.newName("rst"); })
1879 .Case<firrtl::RefType>(
1880 [&](auto a) { return ns.newName("ref"); })
1881 .Default([&](auto a) {
1882 return ns.newName(Twine(getPortName()));
1883 });
1884 newPortNames.push_back(StringAttr::get(ctx, newName));
1885 }
1886 fmodule->setAttr("portNames",
1887 ArrayAttr::get(fmodule.getContext(), newPortNames));
1888 }
1889
1890 if (module == iGraph.getTopLevelModule())
1891 continue;
1892 // Skip renaming if the module already has a metasyntactic name.
1894 module.getModuleName()))
1895 continue;
1896 auto newName = StringAttr::get(ctx, nameGenerator.getNextName(ns));
1897 if (failed(renameModule(module, newName)))
1898 return failure();
1899 for (auto *use : node->uses()) {
1900 auto useOp = use->getInstance();
1901 if (auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1902 // SymbolTable::rename already updated the moduleName
1903 // FlatSymbolRefAttr on all InstanceOps. Only the debug instance name
1904 // and port names need manual fixup here.
1905 instanceOp.setName(newName);
1906 if (shouldReplacePorts)
1907 instanceOp.setPortNamesAttr(ArrayAttr::get(ctx, newPortNames));
1908 } else if (auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1909 // ObjectOp stores the class name in its result type. Result types
1910 // are not updated by SymbolTable::rename (AttrTypeReplacer is called
1911 // with replaceTypes=false), so we must patch the ClassType manually.
1912 auto oldClassType = objectOp.getType();
1913 auto newClassType = firrtl::ClassType::get(
1914 ctx, FlatSymbolRefAttr::get(newName), oldClassType.getElements());
1915 objectOp.getResult().setType(newClassType);
1916 objectOp.setName(newName);
1917 }
1918 }
1919 }
1920
1921 return success();
1922 }
1923
1924 std::string getName() const override { return "module-name-sanitizer"; }
1925
1926 bool acceptSizeIncrease() const override { return true; }
1927
1928 bool isOneShot() const override { return true; }
1929};
1930
1931/// A reduction pattern that groups modules by their port signature (types and
1932/// directions) and replaces instances with the smallest module in each group.
1933/// This helps reduce the IR by consolidating functionally equivalent modules
1934/// based on their interface.
1935///
1936/// The pattern works by:
1937/// 1. Grouping all modules by their port signature (port types and directions)
1938/// 2. For each group with multiple modules, finding the smallest module using
1939/// the module size cache
1940/// 3. Replacing all instances of larger modules with instances of the smallest
1941/// module in the same group
1942/// 4. Removing the larger modules from the circuit
1943///
1944/// This reduction is useful for reducing circuits where multiple modules have
1945/// the same interface but different implementations, allowing the reducer to
1946/// try the smallest implementation first.
1947struct ModuleSwapper : public OpReduction<InstanceOp> {
1948 // Per-circuit state containing all the information needed for module swapping
1949 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1950 struct CircuitState {
1951 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1952 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1953 std::unique_ptr<NLATable> nlaTable;
1954 };
1955
1956 void beforeReduction(mlir::ModuleOp op) override {
1957 symbols.clear();
1958 nlaRemover.clear();
1959 moduleSizes.clear();
1960 circuitStates.clear();
1961
1962 // Collect module type groups and NLA tables for all circuits up front
1963 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1964 auto &state = circuitStates[circuitOp];
1965 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1966 buildModuleTypeGroups(circuitOp, state);
1967 return WalkResult::skip();
1968 });
1969 }
1970 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1971
1972 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1973 /// This ignores port names, allowing modules with the same port types and
1974 /// directions but different port names to be considered equivalent for
1975 /// swapping.
1976 PortSignature getModulePortSignature(FModuleLike module) {
1977 PortSignature signature;
1978 signature.reserve(module.getNumPorts());
1979 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1980 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1981 return signature;
1982 }
1983
1984 /// Group modules by their port signature and find the smallest in each group.
1985 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1986 // Group modules by their port signature
1987 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1988 auto signature = getModulePortSignature(module);
1989 state.moduleTypeGroups[signature].push_back(module);
1990 }
1991
1992 // For each group, find the smallest module
1993 for (auto &[signature, modules] : state.moduleTypeGroups) {
1994 if (modules.size() <= 1)
1995 continue;
1996
1997 FModuleLike smallestModule = nullptr;
1998 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1999
2000 for (auto module : modules) {
2001 uint64_t size = moduleSizes.getModuleSize(module, symbols);
2002 if (size < smallestSize) {
2003 smallestSize = size;
2004 smallestModule = module;
2005 }
2006 }
2007
2008 // Map all modules in this group to the smallest one
2009 for (auto module : modules) {
2010 if (module != smallestModule) {
2011 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
2012 smallestModule;
2013 }
2014 }
2015 }
2016 }
2017
2018 uint64_t match(InstanceOp instOp) override {
2019 // Get the circuit this instance belongs to
2020 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2021 assert(circuitOp);
2022 const auto &state = circuitStates.at(circuitOp);
2023
2024 // Skip instances that participate in any NLAs
2025 DenseSet<hw::HierPathOp> nlas;
2026 state.nlaTable->getInstanceNLAs(instOp, nlas);
2027 if (!nlas.empty())
2028 return 0;
2029
2030 // Check if this instance can be redirected to a smaller module
2031 auto moduleName = instOp.getModuleNameAttr().getAttr();
2032 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
2033 if (!canonicalModule)
2034 return 0;
2035
2036 // Benefit is the size difference
2037 auto currentModule = cast<FModuleLike>(
2038 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
2039 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
2040 uint64_t canonicalSize =
2041 moduleSizes.getModuleSize(canonicalModule, symbols);
2042 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
2043 }
2044
2045 LogicalResult rewrite(InstanceOp instOp) override {
2046 // Get the circuit this instance belongs to
2047 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2048 assert(circuitOp);
2049 const auto &state = circuitStates.at(circuitOp);
2050
2051 // Replace the instantiated module with the canonical module.
2052 auto canonicalModule = state.instanceToCanonicalModule.at(
2053 instOp.getModuleNameAttr().getAttr());
2054 auto canonicalName = canonicalModule.getModuleNameAttr();
2055 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2056
2057 // Update port names to match the canonical module
2058 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2059
2060 return success();
2061 }
2062
2063 std::string getName() const override { return "firrtl-module-swapper"; }
2064 bool acceptSizeIncrease() const override { return true; }
2065
2066private:
2067 ::detail::SymbolCache symbols;
2068 NLARemover nlaRemover;
2069 ModuleSizeCache moduleSizes;
2070
2071 // Per-circuit state containing all module swapping information
2072 DenseMap<CircuitOp, CircuitState> circuitStates;
2073};
2074
2075/// A reduction pattern that handles MustDedup annotations by replacing all
2076/// module names in a dedup group with a single module name. This helps reduce
2077/// the IR by consolidating module references that are required to be identical.
2078///
2079/// The pattern works by:
2080/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2081/// 2. For each dedup group, using the first module as the canonical name
2082/// 3. Replacing all instance references to other modules in the group with
2083/// references to the canonical module
2084/// 4. Removing the non-canonical modules from the circuit
2085/// 5. Removing the processed MustDedup annotation
2086///
2087/// This reduction is particularly useful for reducing large circuits where
2088/// multiple modules are known to be identical but haven't been deduplicated
2089/// yet.
2090struct ForceDedup : public OpReduction<CircuitOp> {
2091 void beforeReduction(mlir::ModuleOp op) override {
2092 symbols.clear();
2093 nlaRemover.clear();
2094 modulesToErase.clear();
2095 moduleSizes.clear();
2096 }
2097 void afterReduction(mlir::ModuleOp op) override {
2098 nlaRemover.remove(op);
2099 for (auto mod : modulesToErase)
2100 mod->erase();
2101 }
2102
2103 /// Collect all MustDedup annotations and create matches for each dedup group.
2104 void matches(CircuitOp circuitOp,
2105 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2106 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2107 auto annotations = AnnotationSet(circuitOp);
2108 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2109 if (!anno.isClass(mustDeduplicateAnnoClass))
2110 continue;
2111
2112 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2113 if (!modulesAttr || modulesAttr.size() < 2)
2114 continue;
2115
2116 // Check that all modules have the same port signature. Malformed inputs
2117 // may have modules listed in a MustDedup annotation that have distinct
2118 // port types.
2119 uint64_t totalSize = 0;
2120 ArrayAttr portTypes;
2121 DenseBoolArrayAttr portDirections;
2122 bool allSame = true;
2123 for (auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2124 auto target = tokenizePath(moduleName);
2125 if (!target) {
2126 allSame = false;
2127 break;
2128 }
2129 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2130 if (!mod) {
2131 allSame = false;
2132 break;
2133 }
2134 totalSize += moduleSizes.getModuleSize(mod, symbols);
2135 if (!portTypes) {
2136 portTypes = mod.getPortTypesAttr();
2137 portDirections = mod.getPortDirectionsAttr();
2138 } else if (portTypes != mod.getPortTypesAttr() ||
2139 portDirections != mod.getPortDirectionsAttr()) {
2140 allSame = false;
2141 break;
2142 }
2143 }
2144 if (!allSame)
2145 continue;
2146
2147 // Each dedup group gets its own match with benefit proportional to group
2148 // size.
2149 addMatch(totalSize, annoIdx);
2150 }
2151 }
2152
2153 LogicalResult rewriteMatches(CircuitOp circuitOp,
2154 ArrayRef<uint64_t> matches) override {
2155 auto *context = circuitOp->getContext();
2156 NLATable nlaTable(circuitOp);
2157 hw::InnerSymbolTableCollection innerSymTables;
2158 auto annotations = AnnotationSet(circuitOp);
2159 SmallVector<Annotation> newAnnotations;
2160
2161 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2162 // Check if this annotation was selected.
2163 if (!llvm::is_contained(matches, annoIdx)) {
2164 newAnnotations.push_back(anno);
2165 continue;
2166 }
2167 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2168 assert(anno.isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2169 modulesAttr.size() >= 2);
2170
2171 // Extract module names from the dedup group.
2172 SmallVector<StringAttr> moduleNames;
2173 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2174 // Parse "~CircuitName|ModuleName" format.
2175 auto refStr = moduleRef.getValue();
2176 auto pipePos = refStr.find('|');
2177 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2178 auto moduleName = refStr.substr(pipePos + 1);
2179 moduleNames.push_back(StringAttr::get(context, moduleName));
2180 }
2181 }
2182
2183 // Simply drop the annotation if there's only one module.
2184 if (moduleNames.size() < 2)
2185 continue;
2186
2187 // Replace all instances and references to other modules with the
2188 // first module.
2189 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2190 nlaRemover.markNLAsInAnnotation(anno.getAttr());
2191 }
2192 if (newAnnotations.size() == annotations.size())
2193 return failure();
2194
2195 // Update circuit annotations.
2196 AnnotationSet newAnnoSet(newAnnotations, context);
2197 newAnnoSet.applyToOperation(circuitOp);
2198 return success();
2199 }
2200
2201 std::string getName() const override { return "firrtl-force-dedup"; }
2202 bool acceptSizeIncrease() const override { return true; }
2203
2204private:
2205 /// Replace all references to modules in the dedup group with the canonical
2206 /// module name
2207 void replaceModuleReferences(CircuitOp circuitOp,
2208 ArrayRef<StringAttr> moduleNames,
2209 NLATable &nlaTable,
2210 hw::InnerSymbolTableCollection &innerSymTables) {
2211 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2212 auto &symbolTable = symbols.getSymbolTable(tableOp);
2213 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2214 auto *context = circuitOp->getContext();
2215 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
2216
2217 // Collect the modules.
2218 FModuleLike canonicalModule;
2219 SmallVector<FModuleLike> modulesToReplace;
2220 for (auto name : moduleNames) {
2221 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
2222 if (!canonicalModule)
2223 canonicalModule = mod;
2224 else
2225 modulesToReplace.push_back(mod);
2226 }
2227 }
2228 if (modulesToReplace.empty())
2229 return;
2230
2231 // Replace all instance references.
2232 auto canonicalName = canonicalModule.getModuleNameAttr();
2233 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2234 for (auto moduleName : moduleNames) {
2235 if (moduleName == canonicalName)
2236 continue;
2237 auto *symbolOp = symbolTable.lookup(moduleName);
2238 if (!symbolOp)
2239 continue;
2240 for (auto *user : symbolUserMap.getUsers(symbolOp)) {
2241 auto instOp = dyn_cast<InstanceOp>(user);
2242 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2243 continue;
2244 instOp.setModuleNameAttr(canonicalRef);
2245 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2246 }
2247 }
2248
2249 // Update NLAs to reference the canonical module instead of modules being
2250 // removed using NLATable for better performance.
2251 for (auto oldMod : modulesToReplace) {
2252 SmallVector<hw::HierPathOp> nlaOps(
2253 nlaTable.lookup(oldMod.getModuleNameAttr()));
2254 for (auto nlaOp : nlaOps) {
2255 nlaTable.erase(nlaOp);
2256 StringAttr oldModName = oldMod.getModuleNameAttr();
2257 StringAttr newModName = canonicalName;
2258 SmallVector<Attribute, 4> newPath;
2259 for (auto nameRef : nlaOp.getNamepath()) {
2260 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2261 if (ref.getModule() == oldModName) {
2262 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2263 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2264 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2265 if (oldInst && newInst) {
2266 // Get the first module name from the list (for
2267 // InstanceOp/ObjectOp, there's only one)
2268 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2269 auto newModNames = newInst.getReferencedModuleNamesAttr();
2270 if (!oldModNames.empty() && !newModNames.empty()) {
2271 oldModName = cast<StringAttr>(oldModNames[0]);
2272 newModName = cast<StringAttr>(newModNames[0]);
2273 }
2274 }
2275 }
2276 newPath.push_back(ref);
2277 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2278 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2279 } else {
2280 newPath.push_back(nameRef);
2281 }
2282 }
2283 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
2284 nlaTable.addNLA(nlaOp);
2285 }
2286 }
2287
2288 // Mark NLAs in modules to be removed.
2289 for (auto module : modulesToReplace) {
2290 nlaRemover.markNLAsInOperation(module);
2291 modulesToErase.insert(module);
2292 }
2293 }
2294
2295 ::detail::SymbolCache symbols;
2296 NLARemover nlaRemover;
2297 SetVector<FModuleLike> modulesToErase;
2298 ModuleSizeCache moduleSizes;
2299};
2300
/// A reduction pattern that propagates `MustDedup` annotations from a module
/// onto its child modules. This pattern iterates over all MustDedup
/// annotations, collects all `FInstanceLike` ops in each module of the dedup
/// group, and creates new MustDedup annotations for corresponding instances
/// across the modules. Each set of corresponding instances becomes a separate
/// match of the reduction. The original MustDedup annotation on the parent
/// modules is kept.
///
/// The pattern works by:
/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
/// 2. For each dedup group, collecting all FInstanceLike operations in each
///    module
/// 3. Grouping corresponding instances across modules by their instance name
/// 4. Creating new MustDedup annotations for each group of corresponding
///    instances
/// 5. Keeping the original MustDedup annotation on the circuit
2317struct MustDedupChildren : public OpReduction<CircuitOp> {
2318 void beforeReduction(mlir::ModuleOp op) override {
2319 symbols.clear();
2320 nlaRemover.clear();
2321 }
2322 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
2323
2324 /// Collect all MustDedup annotations and create matches for each instance
2325 /// group.
2326 void matches(CircuitOp circuitOp,
2327 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2328 auto annotations = AnnotationSet(circuitOp);
2329 uint64_t matchId = 0;
2330
2331 DenseSet<StringRef> modulesAlreadyInMustDedup;
2332 for (auto [annoIdx, anno] : llvm::enumerate(annotations))
2333 if (anno.isClass(mustDeduplicateAnnoClass))
2334 if (auto modulesAttr = anno.getMember<ArrayAttr>("modules"))
2335 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2336 if (auto target = tokenizePath(moduleRef))
2337 modulesAlreadyInMustDedup.insert(target->module);
2338
2339 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2340 if (!anno.isClass(mustDeduplicateAnnoClass))
2341 continue;
2342
2343 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2344 if (!modulesAttr || modulesAttr.size() < 2)
2345 continue;
2346
2347 // Process each group of corresponding instances
2348 processInstanceGroups(
2349 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2350 matchId++;
2351
2352 // Make sure there are at least two distinct modules.
2353 SmallDenseSet<StringAttr, 4> moduleTargets;
2354 for (auto instOp : instanceGroup) {
2355 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2356 for (auto moduleName : moduleNames)
2357 moduleTargets.insert(cast<StringAttr>(moduleName));
2358 }
2359 if (moduleTargets.size() < 2)
2360 return;
2361
2362 // Make sure none of the modules are not yet in a must dedup
2363 // annotation.
2364 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2365 auto moduleNames = inst.getReferencedModuleNames();
2366 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2367 return modulesAlreadyInMustDedup.contains(moduleName);
2368 });
2369 }))
2370 return;
2371
2372 addMatch(1, matchId - 1);
2373 });
2374 }
2375 }
2376
2377 LogicalResult rewriteMatches(CircuitOp circuitOp,
2378 ArrayRef<uint64_t> matches) override {
2379 auto *context = circuitOp->getContext();
2380 auto annotations = AnnotationSet(circuitOp);
2381 SmallVector<Annotation> newAnnotations;
2382 uint64_t matchId = 0;
2383
2384 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2385 if (!anno.isClass(mustDeduplicateAnnoClass)) {
2386 newAnnotations.push_back(anno);
2387 continue;
2388 }
2389
2390 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2391 if (!modulesAttr || modulesAttr.size() < 2) {
2392 newAnnotations.push_back(anno);
2393 continue;
2394 }
2395
2396 processInstanceGroups(
2397 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2398 // Check if this instance group was selected
2399 if (!llvm::is_contained(matches, matchId++))
2400 return;
2401
2402 // Create the list of modules to put into this new annotation.
2403 SmallSetVector<StringAttr, 4> moduleTargets;
2404 for (auto instOp : instanceGroup) {
2405 auto moduleNames = instOp.getReferencedModuleNames();
2406 for (auto moduleName : moduleNames) {
2407 auto target = TokenAnnoTarget();
2408 target.circuit = circuitOp.getName();
2409 target.module = moduleName;
2410 moduleTargets.insert(target.toStringAttr(context));
2411 }
2412 }
2413
2414 // Create a new MustDedup annotation for this list of modules.
2415 SmallVector<NamedAttribute> newAnnoAttrs;
2416 newAnnoAttrs.emplace_back(
2417 StringAttr::get(context, "class"),
2418 StringAttr::get(context, mustDeduplicateAnnoClass));
2419 newAnnoAttrs.emplace_back(
2420 StringAttr::get(context, "modules"),
2421 ArrayAttr::get(context,
2422 SmallVector<Attribute>(moduleTargets.begin(),
2423 moduleTargets.end())));
2424
2425 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2426 newAnnotations.emplace_back(newAnnoDict);
2427 });
2428
2429 // Keep the original annotation around.
2430 newAnnotations.push_back(anno);
2431 }
2432
2433 // Update circuit annotations
2434 AnnotationSet newAnnoSet(newAnnotations, context);
2435 newAnnoSet.applyToOperation(circuitOp);
2436 return success();
2437 }
2438
2439 std::string getName() const override { return "must-dedup-children"; }
2440 bool acceptSizeIncrease() const override { return true; }
2441
2442private:
2443 /// Helper function to process groups of corresponding instances from a
2444 /// MustDedup annotation. Calls the provided lambda for each group of
2445 /// corresponding instances across the modules. Only calls the lambda if there
2446 /// are at least 2 modules.
2447 void processInstanceGroups(
2448 CircuitOp circuitOp, ArrayAttr modulesAttr,
2449 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
2450 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2451
2452 // Extract module names and get the actual modules
2453 SmallVector<FModuleLike> modules;
2454 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2455 if (auto target = tokenizePath(moduleRef))
2456 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2457 modules.push_back(mod);
2458
2459 // Need at least 2 modules for deduplication
2460 if (modules.size() < 2)
2461 return;
2462
2463 // Collect all FInstanceLike operations from each module and group them by
2464 // name. Instance names are a good key for matching instances across
2465 // modules. But they may not be unique, so we need to be careful to only
2466 // match up instances that are uniquely named within every module.
2467 struct InstanceGroup {
2468 SmallVector<FInstanceLike> instances;
2469 bool nameIsUnique = true;
2470 };
2472 for (auto module : modules) {
2474 module.walk([&](FInstanceLike instOp) {
2475 if (isa<ObjectOp>(instOp.getOperation()))
2476 return;
2477 auto name = instOp.getInstanceNameAttr();
2478 auto &group = instanceGroups[name];
2479 if (nameCounts[name]++ > 1)
2480 group.nameIsUnique = false;
2481 group.instances.push_back(instOp);
2482 });
2483 }
2484
2485 // Call the callback for each group of instances that are uniquely named and
2486 // consist of at least 2 instances.
2487 for (auto &[name, group] : instanceGroups)
2488 if (group.nameIsUnique && group.instances.size() >= 2)
2489 callback(group.instances);
2490 }
2491
2492 ::detail::SymbolCache symbols;
2493 NLARemover nlaRemover;
2494};
2495
2496struct LayerDisable : public OpReduction<CircuitOp> {
2497 LayerDisable(MLIRContext *context) {
2498 pm = std::make_unique<mlir::PassManager>(
2499 context, "builtin.module", mlir::OpPassManager::Nesting::Explicit);
2500 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2501 };
2502
2503 void beforeReduction(mlir::ModuleOp op) override { symbolRefAttrMap.clear(); }
2504
2505 void afterReduction(mlir::ModuleOp op) override { (void)pm->run(op); };
2506
2507 void matches(CircuitOp circuitOp,
2508 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2509 uint64_t matchId = 0;
2510
2511 SmallVector<FlatSymbolRefAttr> nestedRefs;
2512 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2513 LayerOp layerOp) {
2514 if (!rootRef)
2515 rootRef = layerOp.getSymNameAttr();
2516 else
2517 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2518
2519 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2520 addMatch(1, matchId++);
2521
2522 for (auto nestedLayerOp : layerOp.getOps<LayerOp>())
2523 addLayer(rootRef, nestedLayerOp);
2524
2525 if (!nestedRefs.empty())
2526 nestedRefs.pop_back();
2527 };
2528
2529 for (auto layerOp : circuitOp.getOps<LayerOp>())
2530 addLayer({}, layerOp);
2531 }
2532
2533 LogicalResult rewriteMatches(CircuitOp circuitOp,
2534 ArrayRef<uint64_t> matches) override {
2535 SmallVector<Attribute> disableLayers;
2536 if (auto existingDisables = circuitOp.getDisableLayersAttr()) {
2537 auto disableRange = existingDisables.getAsRange<Attribute>();
2538 disableLayers.append(disableRange.begin(), disableRange.end());
2539 }
2540 for (auto match : matches)
2541 disableLayers.push_back(symbolRefAttrMap.at(match));
2542
2543 circuitOp.setDisableLayersAttr(
2544 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2545
2546 return success();
2547 }
2548
2549 std::string getName() const override { return "firrtl-layer-disable"; }
2550
2551 std::unique_ptr<mlir::PassManager> pm;
2552 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2553};
2554
2555} // namespace
2556
2557/// A reduction pattern that removes elements from FIRRTL list create
2558/// operations. This generates one match per element in each list, allowing
2559/// selective removal of individual elements.
2560struct ListCreateElementRemover : public OpReduction<ListCreateOp> {
2561 void matches(ListCreateOp listOp,
2562 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2563 // Create one match for each element in the list
2564 auto elements = listOp.getElements();
2565 for (size_t i = 0; i < elements.size(); ++i)
2566 addMatch(1, i);
2567 }
2568
2569 LogicalResult rewriteMatches(ListCreateOp listOp,
2570 ArrayRef<uint64_t> matches) override {
2571 // Convert matches to a set for fast lookup
2572 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
2573
2574 // Collect elements that should be kept (not in matches)
2575 SmallVector<Value> newElements;
2576 auto elements = listOp.getElements();
2577 for (size_t i = 0; i < elements.size(); ++i) {
2578 if (!matchesSet.contains(i))
2579 newElements.push_back(elements[i]);
2580 }
2581
2582 // Create a new list with the remaining elements
2583 OpBuilder builder(listOp);
2584 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2585 listOp.getType(), newElements);
2586 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2587 listOp.erase();
2588
2589 return success();
2590 }
2591
2592 std::string getName() const override {
2593 return "firrtl-list-create-element-remover";
2594 }
2595};
2596
2597/// Reduction that removes the `convention` attribute from regular modules.
2598struct ModuleConventionRemover : public OpReduction<FModuleOp> {
2599 uint64_t match(FModuleOp module) override {
2600 return module.getConvention() != Convention::Internal;
2601 }
2602
2603 LogicalResult rewrite(FModuleOp module) override {
2604 module.setConvention(Convention::Internal);
2605 return success();
2606 }
2607
2608 std::string getName() const override { return "module-convention-remover"; }
2609 bool acceptSizeIncrease() const override { return true; }
2610 bool isOneShot() const override { return true; }
2611};
2612
2613/// Reduction that removes the `convention` attribute from external modules.
2614struct ExtmoduleConventionRemover : public OpReduction<FExtModuleOp> {
2615 uint64_t match(FExtModuleOp extmodule) override {
2616 return extmodule.getConvention() != Convention::Internal;
2617 }
2618
2619 LogicalResult rewrite(FExtModuleOp extmodule) override {
2620 extmodule.setConvention(Convention::Internal);
2621 return success();
2622 }
2623
2624 std::string getName() const override {
2625 return "extmodule-convention-remover";
2626 }
2627 bool acceptSizeIncrease() const override { return true; }
2628 bool isOneShot() const override { return true; }
2629};
2630
2631//===----------------------------------------------------------------------===//
2632// Reduction Registration
2633//===----------------------------------------------------------------------===//
2634
2637 // Gather a list of reduction patterns that we should try. Ideally these are
2638 // assigned reasonable benefit indicators (higher benefit patterns are
2639 // prioritized). For example, things that can knock out entire modules while
2640 // being cheap should be tried first (and thus have higher benefit), before
2641 // trying to tweak operands of individual arithmetic ops.
2642 patterns.add<SimplifyResets, 35>();
2643 patterns.add<ForceDedup, 34>();
2644 patterns.add<MustDedupChildren, 33>();
2645 patterns.add<AnnotationRemover, 32>();
2646 patterns.add<ModuleSwapper, 31>();
2647 patterns.add<LayerDisable, 30>(getContext());
2648 patterns.add<PassReduction, 29>(
2649 getContext(),
2650 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
2651 true);
2652 patterns.add<PassReduction, 28>(getContext(),
2653 firrtl::createLowerCHIRRTLPass(), true, true);
2654 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
2655 true, true);
2656 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
2657 true, true);
2658 patterns.add<FIRRTLModuleExternalizer, 25>();
2659 patterns.add<InstanceStubber, 24>();
2660 patterns.add<MemoryStubber, 23>();
2661 patterns.add<EagerInliner, 22>();
2662 patterns.add<ObjectInliner, 22>();
2663 patterns.add<PassReduction, 21>(getContext(),
2664 firrtl::createLowerFIRRTLTypes(), true, true);
2665 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
2666 true, true);
2667 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
2668 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
2669 patterns.add<PassReduction, 17>(
2670 getContext(),
2671 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
2672 patterns.add<NodeSymbolRemover, 16>();
2673 patterns.add<PassReduction, 15>(getContext(), firrtl::createIMDeadCodeElim());
2674 patterns.add<ConnectForwarder, 14>();
2675 patterns.add<ConnectInvalidator, 13>();
2676 patterns.add<Constantifier, 12>();
2677 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2678 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2679 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2681 patterns.add<ResetDisconnector, 8>();
2682 patterns.add<DetachSubaccesses, 7>();
2683 patterns.add<ModulePortPruner, 7>();
2684 patterns.add<ExtmodulePortPruner, 6>();
2685 patterns.add<RootPortPruner, 5>();
2686 patterns.add<RootExtmodulePortPruner, 5>();
2687 patterns.add<ExtmoduleInstanceRemover, 4>();
2688 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2689 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2690 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2692 patterns.add<ModuleNameSanitizer, 0>();
2695}
2696
2698 mlir::DialectRegistry &registry) {
2699 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
2700 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
2701 });
2702}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
Helper class to cache tie-off values for different FIRRTL types.
Definition FIRRTLUtils.h:63
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
A utility class that generates metasyntactic variable names for use in reductions.
static bool isMetasyntacticName(StringRef name)
Return true if name already has a metasyntactic prefix, i.e.
void reset()
Reset the generator to start from the beginning of the sequence.
const char * getNextName()
Get the next metasyntactic name in the sequence.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:43
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(SmallVectorImpl< Operation * > &worklist, Reduction &reduction)
Starting from an initial worklist of operations, traverse through it and its operands and erase opera...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Reduction that removes the convention attribute from external modules.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(FExtModuleOp extmodule) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(FExtModuleOp extmodule) override
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Reduction that removes the convention attribute from regular modules.
uint64_t match(FModuleOp module) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(FModuleOp module) override
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
The namespace of a CircuitOp, generally inhabited by modules.
Definition Namespace.h:24
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)