CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
32
33#define DEBUG_TYPE "firrtl-reductions"
34
35using namespace mlir;
36using namespace circt;
37using namespace firrtl;
38using llvm::MapVector;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
41
42//===----------------------------------------------------------------------===//
43// Utilities
44//===----------------------------------------------------------------------===//
45
46namespace detail {
47/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
48/// which is handy for reductions that need to look up a lot of symbols.
/// Start out with a fresh, empty `SymbolTableCollection`.
50 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
51
/// Return the symbol table for `op`, as provided by the owned
/// `SymbolTableCollection`.
52 SymbolTable &getSymbolTable(Operation *op) {
53 return tables->getSymbolTable(op);
54 }
/// Return the symbol table of the operation that
/// `SymbolTable::getNearestSymbolTable(op)` resolves to.
55 SymbolTable &getNearestSymbolTable(Operation *op) {
56 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
57 }
58
/// Return the `SymbolUserMap` for `op`. The map is built on first request
/// and the stored instance is returned on every later request for the same
/// operation.
59 SymbolUserMap &getSymbolUserMap(Operation *op) {
60 auto it = userMaps.find(op);
61 if (it != userMaps.end())
62 return it->second;
63 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
64 }
/// Return the `SymbolUserMap` of the operation that
/// `SymbolTable::getNearestSymbolTable(op)` resolves to.
65 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
66 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
67 }
68
/// Discard everything cached so far: replace the table collection with a new
/// empty one and drop all stored user maps.
69 void clear() {
70 tables = std::make_unique<SymbolTableCollection>();
71 userMaps.clear();
72 }
73
74private:
/// Owns the symbol tables handed out by `getSymbolTable`; also passed to every
/// `SymbolUserMap` this cache constructs.
75 std::unique_ptr<SymbolTableCollection> tables;
77};
78} // namespace detail
79
80/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
81/// optional in case another type of module is instantiated.
82static std::optional<firrtl::FModuleOp>
83findInstantiatedModule(firrtl::InstanceOp instOp,
84 ::detail::SymbolCache &symbols) {
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
87 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
89}
90
91/// Utility to track the transitive size of modules.
93 void clear() { moduleSizes.clear(); }
94
95 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
96 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
97 return it->second;
98 uint64_t size = 1;
99 module->walk([&](Operation *op) {
100 size += 1;
101 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
102 if (auto instModule = findInstantiatedModule(instOp, symbols))
103 size += getModuleSize(*instModule, symbols);
104 });
105 moduleSizes.insert({module, size});
106 return size;
107 }
108
109private:
110 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
111};
112
113/// Check that all connections to a value are invalids.
114static bool onlyInvalidated(Value arg) {
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
118 return false;
119 if (use.getOperandNumber() != 0)
120 return false;
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
122 return false;
123 return true;
124 });
125}
126
127/// A tracker for NLAs affected by a reduction. Performs the necessary
128/// cleanup steps in order to maintain IR validity after the reduction has
129/// applied. For example, removing an instance that forms part of an NLA path
130/// requires that NLA to be removed as well.
132 /// Clear the set of marked NLAs. Call this before attempting a reduction.
133 void clear() { nlasToRemove.clear(); }
134
135 /// Remove all marked annotations. Call this after applying a reduction in
136 /// order to validate the IR.
137 void remove(mlir::ModuleOp module) {
138 unsigned numRemoved = 0;
139 (void)numRemoved;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
143 continue;
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
146 for (auto sym : nlasToRemove) {
147 if (auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
149 ++numRemoved;
150 op->erase();
151 }
152 }
153 }
154 }
155 LLVM_DEBUG({
156 unsigned numLost = nlasToRemove.size() - numRemoved;
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
159 if (numLost > 0)
160 llvm::dbgs() << " (" << numLost << " no longer there)";
161 llvm::dbgs() << "\n";
162 }
163 });
164 }
165
166 /// Mark all NLAs referenced in the given annotation as to be removed. This
167 /// can be an entire array or dictionary of annotations, and the function will
168 /// descend into child annotations appropriately.
169 void markNLAsInAnnotation(Attribute anno) {
170 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (auto attr : array)
177 markNLAsInAnnotation(attr);
178 }
179 }
180
181 /// Mark all NLAs referenced in an operation. Also traverses all nested
182 /// operations. Call this before removing an operation, to mark any associated
183 /// NLAs as to be removed as well.
184 void markNLAsInOperation(Operation *op) {
185 op->walk([&](Operation *op) {
186 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
187 markNLAsInAnnotation(annos);
188 });
189 }
190
191 /// The set of NLAs to remove, identified by their symbol.
192 llvm::DenseSet<StringAttr> nlasToRemove;
193};
194
195//===----------------------------------------------------------------------===//
196// Reduction patterns
197//===----------------------------------------------------------------------===//
198
199namespace {
200
201/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
202struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
203 void beforeReduction(mlir::ModuleOp op) override {
204 nlaRemover.clear();
205 symbols.clear();
206 moduleSizes.clear();
207 innerSymUses = reduce::InnerSymbolUses(op);
208 }
209 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
210
211 uint64_t match(FModuleOp module) override {
212 if (innerSymUses.hasInnerRef(module))
213 return 0;
214 return moduleSizes.getModuleSize(module, symbols);
215 }
216
217 LogicalResult rewrite(FModuleOp module) override {
218 // Hack up a list of known layers.
219 LayerSet layers;
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (auto attr : module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (auto refType = type_dyn_cast<RefType>(type))
224 if (auto layer = refType.getLayer())
225 layers.insert(layer);
226 }
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (auto layer : layers)
230 layersArray.push_back(layer);
231
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
242 module->erase();
243 return success();
244 }
245
246 std::string getName() const override { return "firrtl-module-externalizer"; }
247
248 ::detail::SymbolCache symbols;
249 NLARemover nlaRemover;
250 reduce::InnerSymbolUses innerSymUses;
251 ModuleSizeCache moduleSizes;
252};
253
254/// Invalidate all the leaf fields of a value with a given flippedness by
255/// connecting an invalid value to them. This function handles different FIRRTL
256/// types appropriately:
257/// - Ref types (probes): Creates wire infrastructure with ref.send/ref.define
258/// and invalidates the underlying wire.
259/// - Bundle/Vector types: Recursively descends into elements.
260/// - Base types: Creates InvalidValueOp and connects it.
261/// - Property types: Creates UnknownValueOp and assigns it.
262///
263/// This is useful for ensuring that all output ports of an instance or memory
264/// (including those nested in bundles) are properly invalidated.
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
266 TieOffCache &tieOffCache, bool flip = false) {
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
268 if (!type)
269 return;
270
271 // Handle ref types (probes) by creating wires and defining them properly.
272 if (auto refType = type_dyn_cast<RefType>(type)) {
273 // Input probes are illegal in FIRRTL.
274 assert(!flip && "input probes are not allowed");
275
276 auto underlyingType = refType.getType();
277
278 if (!refType.getForceable()) {
279 // For probe types: create underlying wire, ref.send, ref.define, and
280 // invalidate.
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
284
285 // Invalidate the underlying wire.
286 auto invalid = tieOffCache.getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
288 return;
289 }
290
291 // For rwprobe types: create forceable wire, ref.define, and invalidate.
292 auto forceableWire =
293 WireOp::create(builder, underlyingType,
294 /*name=*/"", NameKindEnum::DroppableName,
295 /*annotations=*/ArrayRef<Attribute>{},
296 /*innerSym=*/StringAttr{},
297 /*forceable=*/true);
298
299 // The forceable wire returns both the wire and the rwprobe.
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
302
303 builder.create<RefDefineOp>(value, forceableRef);
304
305 // Invalidate the underlying wire.
306 auto invalid = tieOffCache.getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
308 return;
309 }
310
311 // Descend into bundles by creating subfield ops.
312 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (auto element : llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
319 }
320 return;
321 }
322
323 // Descend into vectors by creating subindex ops.
324 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache, flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
330 }
331 return;
332 }
333
334 // Only drive outputs.
335 if (flip)
336 return;
337
338 // Create InvalidValueOp for FIRRTLBaseType.
339 if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
342 return;
343 }
344
345 // For property types, use UnknownValueOp to tie off the connection.
346 if (auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
349 }
350}
351
352/// Connect a value to every leave of a destination value.
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
354 Value value) {
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
356 if (!type)
357 return;
358 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (auto element : llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
362 value);
363 return;
364 }
365 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
368 value);
369 return;
370 }
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
372 if (!valueType)
373 return;
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
381 else
382 return;
383 }
384 firrtl::ConnectOp::create(builder, dest, value);
385}
386
387/// Reduce all leaf fields of a value through an XOR tree.
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
390 if (!type)
391 return;
392 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (auto element : llvm::enumerate(bundleType.getElements()))
394 reduceXor(
395 builder, into,
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
397 return;
398 }
399 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
403 return;
404 }
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
408 else
409 return;
410 }
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
412}
413
414/// A sample reduction pattern that maps `firrtl.instance` to a set of
415/// invalidated wires. This often shortcuts a long iterative process of connect
416/// invalidation, module externalization, and wire stripping
417struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
418 void beforeReduction(mlir::ModuleOp op) override {
419 erasedInsts.clear();
420 erasedModules.clear();
421 symbols.clear();
422 nlaRemover.clear();
423 moduleSizes.clear();
424 }
425 void afterReduction(mlir::ModuleOp op) override {
426 // Look into deleted modules to find additional instances that are no longer
427 // instantiated anywhere.
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
439 if (llvm::all_of(
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
443 << moduleOp.getModuleName() << "`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
446 }
447 });
448 }
449
450 for (auto *op : erasedInsts)
451 op->erase();
452 for (auto *op : erasedModules)
453 op->erase();
454 nlaRemover.remove(op);
455 }
456
457 uint64_t match(firrtl::InstanceOp instOp) override {
458 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
460 return 0;
461 }
462
463 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
464 LLVM_DEBUG(llvm::dbgs()
465 << "Stubbing instance `" << instOp.getName() << "`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
467 TieOffCache tieOffCache(builder);
468 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
471 instOp.getPortName(i));
472 auto wire =
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
476 .getResult();
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
480 }
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
486 if (llvm::all_of(
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
490 << moduleOp.getModuleName() << "`\n");
491 erasedModules.insert(moduleOp);
492 }
493 return success();
494 }
495
496 std::string getName() const override { return "instance-stubber"; }
497 bool acceptSizeIncrease() const override { return true; }
498
499 ::detail::SymbolCache symbols;
500 NLARemover nlaRemover;
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
503 ModuleSizeCache moduleSizes;
504};
505
506/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
507/// wires.
508struct MemoryStubber : public OpReduction<firrtl::MemOp> {
509 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
510 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
511 LogicalResult rewrite(firrtl::MemOp memOp) override {
512 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
514 TieOffCache tieOffCache(builder);
515 Value xorInputs;
516 SmallVector<Value> outputs;
517 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
520 memOp.getPortName(i));
521 auto wire =
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
525 .getResult();
526 invalidateOutputs(builder, wire, tieOffCache, true);
527 result.replaceAllUsesWith(wire);
528
529 // Isolate the input and output data fields of the port.
530 Value input, output;
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
534 break;
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
537 break;
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
541 break;
542 case firrtl::MemOp::PortKind::Debug:
543 output = wire;
544 break;
545 }
546
547 if (!isa<firrtl::RefType>(result.getType())) {
548 // Reduce all input ports to a single one through an XOR tree.
549 unsigned numFields =
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
555 }
556 if (input)
557 reduceXor(builder, xorInputs, input);
558 }
559
560 // Track the output port to hook it up to the XORd input later.
561 if (output)
562 outputs.push_back(output);
563 }
564
565 // Hook up the outputs.
566 for (auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
568
569 nlaRemover.markNLAsInOperation(memOp);
570 memOp->erase();
571 return success();
572 }
573 std::string getName() const override { return "memory-stubber"; }
574 bool acceptSizeIncrease() const override { return true; }
575 NLARemover nlaRemover;
576};
577
578/// Check whether an operation interacts with flows in any way, which can make
579/// replacement and operand forwarding harder in some cases.
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
583}
584
585/// A sample reduction pattern that replaces all uses of an operation with one
586/// of its operands. This can help pruning large parts of the expression tree
587/// rapidly.
588template <unsigned OpNum>
589struct FIRRTLOperandForwarder : public Reduction {
590 uint64_t match(Operation *op) override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
592 return 0;
593 if (isFlowSensitiveOp(op))
594 return 0;
595 auto resultTy =
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
597 auto opTy =
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
604 }
605 LogicalResult rewrite(Operation *op) override {
606 assert(match(op));
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
614 Value newOp;
615 if (resultWidth < operandWidth)
616 newOp =
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
620 else
621 newOp = operand;
622 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
623 result.replaceAllUsesWith(newOp);
624 reduce::pruneUnusedOps(op, *this);
625 return success();
626 }
627 std::string getName() const override {
628 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
629 }
630};
631
632/// A sample reduction pattern that replaces FIRRTL operations with a constant
633/// zero of their type.
634struct Constantifier : public Reduction {
635 void beforeReduction(mlir::ModuleOp op) override {
636 symbols.clear();
637
638 // Find valid dummy classes that we can use for anyref casts.
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
645 }
646 }
647 return WalkResult::skip();
648 });
649 }
650
651 uint64_t match(Operation *op) override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
653 Attribute attr;
654 if (!matchPattern(op, m_Constant(&attr)))
655 return 0;
656 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
658 return 0;
659 if (auto strAttr = dyn_cast<StringAttr>(attr))
660 if (strAttr.empty())
661 return 0;
662 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
664 return 0;
665 }
666 if (auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
668 return 0;
669 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
671 return 0;
672
673 // Don't replace anyref casts that already target a dummy class.
674 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
676 auto className =
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
679 return 0;
680 }
681
682 if (op->getNumResults() != 1)
683 return 0;
684 if (op->hasAttr("inner_sym"))
685 return 0;
686 if (isFlowSensitiveOp(op))
687 return 0;
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
691 }
692
693 LogicalResult rewrite(Operation *op) override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
696
697 // Handle UInt/SInt types.
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
700 if (width == -1)
701 width = 64;
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
705 reduce::pruneUnusedOps(op, *this);
706 return success();
707 }
708
709 // Handle property string types.
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr("");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
714 reduce::pruneUnusedOps(op, *this);
715 return success();
716 }
717
718 // Handle property integer types.
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64, true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
723 reduce::pruneUnusedOps(op, *this);
724 return success();
725 }
726
727 // Handle property boolean types.
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
732 reduce::pruneUnusedOps(op, *this);
733 return success();
734 }
735
736 // Handle property double types.
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
741 reduce::pruneUnusedOps(op, *this);
742 return success();
743 }
744
745 // Handle property list types.
746 if (isa<ListType>(type)) {
747 auto newOp =
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
750 reduce::pruneUnusedOps(op, *this);
751 return success();
752 }
753
754 // Handle property path types.
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
757 op->replaceAllUsesWith(newOp);
758 reduce::pruneUnusedOps(op, *this);
759 return success();
760 }
761
762 // Handle anyref types.
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
766 if (!dummy) {
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
773 }
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
775 auto anyrefOp =
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
778 reduce::pruneUnusedOps(op, *this);
779 return success();
780 }
781
782 return failure();
783 }
784
/// Return the identifier string of this reduction pattern.
785 std::string getName() const override { return "firrtl-constantifier"; }
/// Always returns true.
/// NOTE(review): presumably this tells the reduction driver that applying the
/// pattern is acceptable even if the IR grows -- confirm against `Reduction`.
786 bool acceptSizeIncrease() const override { return true; }
787
788 ::detail::SymbolCache symbols;
790 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
791};
792
793/// A sample reduction pattern that replaces the right-hand-side of
794/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
795/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
796/// connects and creates opportunities for reduction in DCE/CSE.
797struct ConnectInvalidator : public Reduction {
798 uint64_t match(Operation *op) override {
799 if (!isa<FConnectLike>(op))
800 return 0;
801 if (auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
804 return 0;
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
807 }
808 LogicalResult rewrite(Operation *op) override {
809 assert(match(op));
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
815 if (rhsOp)
816 reduce::pruneUnusedOps(rhsOp, *this);
817 return success();
818 }
819 std::string getName() const override { return "connect-invalidator"; }
820 bool acceptSizeIncrease() const override { return true; }
821};
822
823/// A reduction pattern that removes FIRRTL annotations from ports and
824/// operations. This generates one match per annotation and port annotation,
825/// allowing selective removal of individual annotations.
826struct AnnotationRemover : public Reduction {
827 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
828 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
829
830 void matches(Operation *op,
831 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
832 uint64_t matchId = 0;
833
834 // Generate matches for regular annotations
835 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
836 for (unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
838
839 // Generate matches for port annotations
840 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
841 for (auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
845 }
846
847 LogicalResult rewriteMatches(Operation *op,
848 ArrayRef<uint64_t> matches) override {
849 // Convert matches to a set for fast lookup
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
851
852 // Lambda to process annotations and filter out matched ones
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
860 } else {
861 // Mark NLAs in the removed annotation for cleanup
862 nlaRemover.markNLAsInAnnotation(anno);
863 }
864 matchId++;
865 }
866 return ArrayAttr::get(op->getContext(), newAnnotations);
867 };
868
869 // Remove regular annotations
870 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
871 op->setAttr("annotations", processAnnotations(annos.getValue()));
872 }
873
874 // Remove port annotations
875 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
880 }
881 op->setAttr("portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
883 }
884
885 return success();
886 }
887
888 std::string getName() const override { return "annotation-remover"; }
889 NLARemover nlaRemover;
890};
891
892/// A reduction pattern that replaces ResetType with UInt<1> across an entire
893/// circuit. This walks all operations in the circuit and replaces ResetType in
894/// results, block arguments, and attributes.
895struct SimplifyResets : public OpReduction<CircuitOp> {
896 uint64_t match(CircuitOp circuit) override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
900
901 circuit.walk([&](Operation *op) {
902 for (auto result : op->getResults())
903 walker.walk(result.getType());
904
905 for (auto &region : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
909
910 walker.walk(op->getAttrDictionary());
911 });
912
913 return numResets;
914 }
915
916 LogicalResult rewrite(CircuitOp circuit) override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1, false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1, true);
919
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
923 });
924 replacer.recursivelyReplaceElementsIn(circuit, /*replaceAttrs=*/true,
925 /*replaceLocs=*/false,
926 /*replaceTypes=*/true);
927
928 // Remove annotations related to InferResets pass
929 circuit.walk([&](Operation *op) {
930 // Remove operation annotations
932 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
935 });
936
937 // Remove port annotations for module-like operations
938 if (auto module = dyn_cast<FModuleLike>(op)) {
939 AnnotationSet::removePortAnnotations(module, [&](unsigned portIdx,
940 Annotation anno) {
941 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
944 });
945 }
946 });
947
948 return success();
949 }
950
951 std::string getName() const override { return "firrtl-simplify-resets"; }
952 bool acceptSizeIncrease() const override { return true; }
953};
954
955/// A sample reduction pattern that removes ports from the root `firrtl.module`
956/// if the port is not used or just invalidated.
957struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
958 void matches(firrtl::FModuleOp module,
959 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
960 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
961 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
962 return;
963
964 // Generate one match per port that can be removed
965 size_t numPorts = module.getNumPorts();
966 for (unsigned i = 0; i != numPorts; ++i) {
967 if (onlyInvalidated(module.getArgument(i)))
968 addMatch(1, i);
969 }
970 }
971
972 LogicalResult rewriteMatches(firrtl::FModuleOp module,
973 ArrayRef<uint64_t> matches) override {
974 // Build a BitVector of ports to remove
975 llvm::BitVector dropPorts(module.getNumPorts());
976 for (auto portIdx : matches)
977 dropPorts.set(portIdx);
978
979 // Erase users of the ports being removed
980 for (auto portIdx : matches) {
981 for (auto *user :
982 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
983 user->erase();
984 }
985
986 // Remove the ports from the module
987 module.erasePorts(dropPorts);
988 return success();
989 }
990
991 std::string getName() const override { return "root-port-pruner"; }
992};
993
994/// A reduction pattern that removes all ports from the root `firrtl.extmodule`.
995/// Since extmodules have no body, all ports can be safely removed for reduction
996/// purposes.
997struct RootExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
998 void matches(firrtl::FExtModuleOp module,
999 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1000 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
1001 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1002 return;
1003
1004 // Generate one match per port (all ports can be removed from root
1005 // extmodule)
1006 size_t numPorts = module.getNumPorts();
1007 for (unsigned i = 0; i != numPorts; ++i)
1008 addMatch(1, i);
1009 }
1010
1011 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1012 ArrayRef<uint64_t> matches) override {
1013 if (matches.empty())
1014 return failure();
1015
1016 // Build a BitVector of ports to remove
1017 llvm::BitVector dropPorts(module.getNumPorts());
1018 for (auto portIdx : matches)
1019 dropPorts.set(portIdx);
1020
1021 // Remove the ports from the module
1022 module.erasePorts(dropPorts);
1023 return success();
1024 }
1025
1026 std::string getName() const override { return "root-extmodule-port-pruner"; }
1027};
1028
1029/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
1030/// with wires.
1031struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
1032 void beforeReduction(mlir::ModuleOp op) override {
1033 symbols.clear();
1034 nlaRemover.clear();
1035 }
1036 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1037
1038 uint64_t match(firrtl::InstanceOp instOp) override {
1039 return isa<firrtl::FExtModuleOp>(
1040 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1041 }
1042 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
1043 auto portInfo =
1044 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1045 symbols.getNearestSymbolTable(instOp)))
1046 .getPorts();
1047 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1048 TieOffCache tieOffCache(builder);
1049 SmallVector<Value> replacementWires;
1050 for (firrtl::PortInfo info : portInfo) {
1051 auto wire = firrtl::WireOp::create(
1052 builder, info.type,
1053 (Twine(instOp.getName()) + "_" + info.getName()).str())
1054 .getResult();
1055 if (info.isOutput()) {
1056 // Tie off output ports using TieOffCache.
1057 if (auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(info.type)) {
1058 auto inv = tieOffCache.getInvalid(baseType);
1059 firrtl::ConnectOp::create(builder, wire, inv);
1060 } else if (auto propType = dyn_cast<firrtl::PropertyType>(info.type)) {
1061 auto unknown = tieOffCache.getUnknown(propType);
1062 builder.create<firrtl::PropAssignOp>(wire, unknown);
1063 }
1064 }
1065 replacementWires.push_back(wire);
1066 }
1067 nlaRemover.markNLAsInOperation(instOp);
1068 instOp.replaceAllUsesWith(std::move(replacementWires));
1069 instOp->erase();
1070 return success();
1071 }
1072 std::string getName() const override { return "extmodule-instance-remover"; }
1073 bool acceptSizeIncrease() const override { return true; }
1074
1075 ::detail::SymbolCache symbols;
1076 NLARemover nlaRemover;
1077};
1078
1079/// A reduction pattern that removes unused ports from extmodules and regular
1080/// modules. This is particularly useful for reducing test cases with many probe
1081/// ports or other unused ports.
1082///
1083/// Shared helper functions for port pruning reductions.
1084struct PortPrunerHelpers {
1085 /// Compute which ports are unused across all instances of a module.
1086 template <typename ModuleOpType>
1087 static void computeUnusedInstancePorts(ModuleOpType module,
1088 ArrayRef<Operation *> users,
1089 llvm::BitVector &portsToRemove) {
1090 auto ports = module.getPorts();
1091 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1092 bool portUsed = false;
1093 for (auto *user : users) {
1094 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1095 auto result = instOp.getResult(portIdx);
1096 if (!result.use_empty()) {
1097 portUsed = true;
1098 break;
1099 }
1100 }
1101 }
1102 if (!portUsed)
1103 portsToRemove.set(portIdx);
1104 }
1105 }
1106
1107 /// Update all instances of a module to remove the specified ports.
1108 static void
1109 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1110 const llvm::BitVector &portsToRemove) {
1111 // Update all instances to remove the corresponding results
1112 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1113 for (auto *user : users) {
1114 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1115 instancesToUpdate.push_back(instOp);
1116 }
1117
1118 for (auto instOp : instancesToUpdate) {
1119 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1120
1121 // Manually replace uses, skipping erased ports
1122 size_t newResultIdx = 0;
1123 for (size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1124 ++oldResultIdx) {
1125 if (portsToRemove[oldResultIdx]) {
1126 // This port is being removed, assert it has no uses
1127 assert(instOp.getResult(oldResultIdx).use_empty() &&
1128 "removing port with uses");
1129 } else {
1130 // Replace uses of the old result with the new result
1131 instOp.getResult(oldResultIdx)
1132 .replaceAllUsesWith(newInst->getResult(newResultIdx));
1133 ++newResultIdx;
1134 }
1135 }
1136
1137 instOp->erase();
1138 }
1139 }
1140};
1141
1142/// Reduction to remove unused ports from regular modules.
1143struct ModulePortPruner : public OpReduction<firrtl::FModuleOp> {
1144 void beforeReduction(mlir::ModuleOp op) override {
1145 symbols.clear();
1146 nlaRemover.clear();
1147 }
1148 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1149
1150 void matches(firrtl::FModuleOp module,
1151 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1152 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1153 auto &userMap = symbols.getSymbolUserMap(tableOp);
1154 auto ports = module.getPorts();
1155 auto users = userMap.getUsers(module);
1156
1157 // Compute which ports can be removed. A port can only be removed if it
1158 // is unused in both the module body and across all instances.
1159 llvm::BitVector portsToRemove(ports.size());
1160
1161 // Check if ports are unused across all instances.
1162 if (!users.empty())
1163 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1164 portsToRemove);
1165 else
1166 // If there are no instances, all ports are candidates for removal.
1167 portsToRemove.set();
1168
1169 // Additionally check if ports are unused within the module body itself.
1170 // A port must be unused in both instances and the module body to be
1171 // removable.
1172 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1173 if (!portsToRemove[portIdx])
1174 continue;
1175 if (!module.getArgument(portIdx).use_empty())
1176 portsToRemove.reset(portIdx);
1177 }
1178
1179 // Generate one match per removable port.
1180 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1181 if (portsToRemove[portIdx])
1182 addMatch(1, portIdx);
1183 }
1184
1185 LogicalResult rewriteMatches(firrtl::FModuleOp module,
1186 ArrayRef<uint64_t> matches) override {
1187 if (matches.empty())
1188 return failure();
1189
1190 // Build a BitVector of ports to remove
1191 llvm::BitVector portsToRemove(module.getNumPorts());
1192 for (auto portIdx : matches)
1193 portsToRemove.set(portIdx);
1194
1195 // Get users for updating instances
1196 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1197 auto &userMap = symbols.getSymbolUserMap(tableOp);
1198 auto users = userMap.getUsers(module);
1199
1200 // Update all instances
1201 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1202 portsToRemove);
1203
1204 // Remove the ports from the module. We don't need to erase users because
1205 // matches() already ensured that these ports have no users.
1206 module.erasePorts(portsToRemove);
1207
1208 return success();
1209 }
1210
1211 std::string getName() const override { return "module-port-pruner"; }
1212
1213 ::detail::SymbolCache symbols;
1214 NLARemover nlaRemover;
1215};
1216
1217/// Reduction to remove unused ports from extmodules.
1218struct ExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
1219 void beforeReduction(mlir::ModuleOp op) override {
1220 symbols.clear();
1221 nlaRemover.clear();
1222 }
1223 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1224
1225 void matches(firrtl::FExtModuleOp module,
1226 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1227 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1228 auto &userMap = symbols.getSymbolUserMap(tableOp);
1229 auto ports = module.getPorts();
1230 auto users = userMap.getUsers(module);
1231
1232 // Compute which ports can be removed
1233 llvm::BitVector portsToRemove(ports.size());
1234
1235 if (users.empty()) {
1236 // If the extmodule has no instances, aggressively remove all ports
1237 portsToRemove.set();
1238 } else {
1239 // For extmodules with instances, check if ports are unused across all
1240 // instances
1241 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1242 portsToRemove);
1243 }
1244
1245 // Generate one match per removable port
1246 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1247 if (portsToRemove[portIdx])
1248 addMatch(1, portIdx);
1249 }
1250
1251 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1252 ArrayRef<uint64_t> matches) override {
1253 if (matches.empty())
1254 return failure();
1255
1256 // Build a BitVector of ports to remove
1257 llvm::BitVector portsToRemove(module.getNumPorts());
1258 for (auto portIdx : matches)
1259 portsToRemove.set(portIdx);
1260
1261 // Get users for updating instances
1262 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1263 auto &userMap = symbols.getSymbolUserMap(tableOp);
1264 auto users = userMap.getUsers(module);
1265
1266 // Update all instances.
1267 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1268 portsToRemove);
1269
1270 // Remove the ports from the module (no body to clean up for extmodules).
1271 module.erasePorts(portsToRemove);
1272
1273 return success();
1274 }
1275
1276 std::string getName() const override { return "extmodule-port-pruner"; }
1277
1278 ::detail::SymbolCache symbols;
1279 NLARemover nlaRemover;
1280};
1281
1282/// A sample reduction pattern that pushes connected values through wires.
1283struct ConnectForwarder : public Reduction {
1284 void beforeReduction(mlir::ModuleOp op) override {
1285 domInfo = std::make_unique<DominanceInfo>(op);
1286 }
1287
1288 uint64_t match(Operation *op) override {
1289 if (!isa<firrtl::FConnectLike>(op))
1290 return 0;
1291 auto dest = op->getOperand(0);
1292 auto src = op->getOperand(1);
1293 auto *destOp = dest.getDefiningOp();
1294 auto *srcOp = src.getDefiningOp();
1295 if (dest == src)
1296 return 0;
1297
1298 // Ensure that the destination is something we should be able to forward
1299 // through.
1300 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1301 destOp))
1302 return 0;
1303
1304 // Ensure that the destination is connected to only once, and all uses of
1305 // the connection occur after the definition of the source.
1306 unsigned numConnects = 0;
1307 for (auto &use : dest.getUses()) {
1308 auto *op = use.getOwner();
1309 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1310 if (++numConnects > 1)
1311 return 0;
1312 continue;
1313 }
1314 // Check if srcOp properly dominates op, but op is not enclosed in srcOp.
1315 // This handles cross-block cases (e.g., layerblocks).
1316 if (srcOp &&
1317 !domInfo->properlyDominates(srcOp, op, /*enclosingOpOk=*/false))
1318 return 0;
1319 }
1320
1321 return 1;
1322 }
1323
1324 LogicalResult rewrite(Operation *op) override {
1325 auto dst = op->getOperand(0);
1326 auto src = op->getOperand(1);
1327 dst.replaceAllUsesExcept(src, op);
1328 op->erase();
1329 SmallVector<Operation *> worklist(
1330 {dst.getDefiningOp(), src.getDefiningOp()});
1331 reduce::pruneUnusedOps(worklist, *this);
1332 return success();
1333 }
1334
1335 std::string getName() const override { return "connect-forwarder"; }
1336
1337private:
1338 std::unique_ptr<DominanceInfo> domInfo;
1339};
1340
1341/// A sample reduction pattern that replaces a single-use wire and register with
1342/// an operand of the source value of the connection.
1343template <unsigned OpNum>
1344struct ConnectSourceOperandForwarder : public Reduction {
1345 uint64_t match(Operation *op) override {
1346 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1347 return 0;
1348 auto dest = op->getOperand(0);
1349 auto *destOp = dest.getDefiningOp();
1350
1351 // Ensure that the destination is used only once.
1352 if (!destOp || !destOp->hasOneUse() ||
1353 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1354 return 0;
1355
1356 auto *srcOp = op->getOperand(1).getDefiningOp();
1357 if (!srcOp || OpNum >= srcOp->getNumOperands())
1358 return 0;
1359
1360 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1361 auto opTy =
1362 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1363
1364 return resultTy && opTy &&
1365 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1366 ((resultTy.getBitWidthOrSentinel() == -1) ==
1367 (opTy.getBitWidthOrSentinel() == -1)) &&
1368 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1369 }
1370
1371 LogicalResult rewrite(Operation *op) override {
1372 auto *destOp = op->getOperand(0).getDefiningOp();
1373 auto *srcOp = op->getOperand(1).getDefiningOp();
1374 auto forwardedOperand = srcOp->getOperand(OpNum);
1375 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1376 Value newDest;
1377 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
1378 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1379 wire.getName())
1380 .getResult();
1381 else {
1382 auto regName = destOp->getAttrOfType<StringAttr>("name");
1383 // We can promote the register into a wire but we wouldn't do here because
1384 // the error might be caused by the register.
1385 auto clock = destOp->getOperand(0);
1386 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1387 clock, regName ? regName.str() : "")
1388 .getResult();
1389 }
1390
1391 // Create new connection between a new wire and the forwarded operand.
1392 builder.setInsertionPointAfter(op);
1393 if (isa<firrtl::ConnectOp>(op))
1394 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1395 else
1396 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1397
1398 // Remove the old connection and destination. We don't have to replace them
1399 // because destination has only one use.
1400 op->erase();
1401 destOp->erase();
1402 reduce::pruneUnusedOps(srcOp, *this);
1403
1404 return success();
1405 }
1406 std::string getName() const override {
1407 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1408 }
1409};
1410
1411/// A sample reduction pattern that tries to remove aggregate wires by replacing
1412/// all subaccesses with new independent wires. This can disentangle large
1413/// unused wires that are otherwise difficult to collect due to the subaccesses.
1414struct DetachSubaccesses : public Reduction {
1415 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1416 void afterReduction(mlir::ModuleOp op) override {
1417 for (auto *op : opsToErase)
1418 op->dropAllReferences();
1419 for (auto *op : opsToErase)
1420 op->erase();
1421 }
1422 uint64_t match(Operation *op) override {
1423 // Only applies to wires and registers that are purely used in subaccess
1424 // operations.
1425 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1426 llvm::all_of(op->getUses(), [](auto &use) {
1427 return use.getOperandNumber() == 0 &&
1428 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1429 firrtl::SubaccessOp>(use.getOwner());
1430 });
1431 }
1432 LogicalResult rewrite(Operation *op) override {
1433 assert(match(op));
1434 OpBuilder builder(op);
1435 bool isWire = isa<firrtl::WireOp>(op);
1436 Value invalidClock;
1437 if (!isWire)
1438 invalidClock = firrtl::InvalidValueOp::create(
1439 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1440 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1441 builder.setInsertionPoint(user);
1442 auto type = user->getResult(0).getType();
1443 Operation *replOp;
1444 if (isWire)
1445 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1446 else
1447 replOp =
1448 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1449 user->replaceAllUsesWith(replOp);
1450 opsToErase.insert(user);
1451 }
1452 opsToErase.insert(op);
1453 return success();
1454 }
1455 std::string getName() const override { return "detach-subaccesses"; }
1456 llvm::DenseSet<Operation *> opsToErase;
1457};
1458
1459/// This reduction removes inner symbols on ops. Name preservation creates a lot
1460/// of node ops with symbols to keep name information but it also prevents
1461/// normal canonicalizations.
1462struct NodeSymbolRemover : public Reduction {
1463 void beforeReduction(mlir::ModuleOp op) override {
1464 innerSymUses = reduce::InnerSymbolUses(op);
1465 }
1466
1467 uint64_t match(Operation *op) override {
1468 // Only match ops with an inner symbol.
1469 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1470 if (!sym || sym.empty())
1471 return 0;
1472
1473 // Only match ops that have no references to their inner symbol.
1474 if (innerSymUses.hasInnerRef(op))
1475 return 0;
1476 return 1;
1477 }
1478
1479 LogicalResult rewrite(Operation *op) override {
1480 op->removeAttr("inner_sym");
1481 return success();
1482 }
1483
1484 std::string getName() const override { return "node-symbol-remover"; }
1485 bool acceptSizeIncrease() const override { return true; }
1486
1487 reduce::InnerSymbolUses innerSymUses;
1488};
1489
1490/// Check if inlining the referenced operation into the parent operation would
1491/// cause inner symbol collisions.
1492static bool
1493hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1494 hw::InnerSymbolTableCollection &innerSymTables) {
1495 // Get the inner symbol tables for both operations
1496 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1497 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1498
1499 // Check if any inner symbol name in the target operation already exists
1500 // in the parent operation. Return failure() if a collision is found to stop
1501 // the walk early.
1502 LogicalResult walkResult = targetTable.walkSymbols(
1503 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1504 // Check if this symbol name exists in the parent operation
1505 if (parentTable.lookup(name)) {
1506 // Collision found, return failure to stop the walk
1507 return failure();
1508 }
1509 return success();
1510 });
1511
1512 // If the walk failed, it means we found a collision
1513 return failed(walkResult);
1514}
1515
1516/// A sample reduction pattern that eagerly inlines instances.
1517struct EagerInliner : public OpReduction<InstanceOp> {
1518 void beforeReduction(mlir::ModuleOp op) override {
1519 symbols.clear();
1520 nlaRemover.clear();
1521 nlaTables.clear();
1522 for (auto circuitOp : op.getOps<CircuitOp>())
1523 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1524 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1525 }
1526 void afterReduction(mlir::ModuleOp op) override {
1527 nlaRemover.remove(op);
1528 nlaTables.clear();
1529 innerSymTables.reset();
1530 }
1531
1532 uint64_t match(InstanceOp instOp) override {
1533 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1534 auto *moduleOp =
1535 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1536
1537 // Only inline FModuleOp instances
1538 if (!isa<FModuleOp>(moduleOp))
1539 return 0;
1540
1541 // Skip instances that participate in any NLAs
1542 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1543 if (!circuitOp)
1544 return 0;
1545 auto it = nlaTables.find(circuitOp);
1546 if (it == nlaTables.end() || !it->second)
1547 return 0;
1548 DenseSet<hw::HierPathOp> nlas;
1549 it->second->getInstanceNLAs(instOp, nlas);
1550 if (!nlas.empty())
1551 return 0;
1552
1553 // Check for inner symbol collisions between the referenced module and the
1554 // instance's parent module
1555 auto parentOp = instOp->getParentOfType<FModuleLike>();
1556 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1557 return 0;
1558
1559 return 1;
1560 }
1561
1562 LogicalResult rewrite(InstanceOp instOp) override {
1563 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1564 auto moduleOp = cast<FModuleOp>(
1565 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1566 bool isLastUse =
1567 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1568 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1569
1570 // Create wires to replace the instance results.
1571 IRRewriter rewriter(instOp);
1572 SmallVector<Value> argWires;
1573 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1574 auto result = instOp.getResult(i);
1575 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1576 instOp.getPortName(i));
1577 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1578 name, NameKindEnum::DroppableName,
1579 instOp.getPortAnnotation(i), StringAttr{})
1580 .getResult();
1581 result.replaceAllUsesWith(wire);
1582 argWires.push_back(wire);
1583 }
1584
1585 // Splice in the cloned module body.
1586 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1587
1588 // Make sure we remove any NLAs that go through this instance, and the
1589 // module if we're about the delete the module.
1590 nlaRemover.markNLAsInOperation(instOp);
1591 if (isLastUse)
1592 nlaRemover.markNLAsInOperation(moduleOp);
1593
1594 instOp.erase();
1595 clonedModuleOp.erase();
1596 return success();
1597 }
1598
1599 std::string getName() const override { return "firrtl-eager-inliner"; }
1600 bool acceptSizeIncrease() const override { return true; }
1601
1602 ::detail::SymbolCache symbols;
1603 NLARemover nlaRemover;
1604 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1605 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1606};
1607
1608/// A reduction pattern that eagerly inlines `ObjectOp`s.
1609struct ObjectInliner : public OpReduction<ObjectOp> {
1610 void beforeReduction(mlir::ModuleOp op) override {
1611 blocksToSort.clear();
1612 symbols.clear();
1613 nlaRemover.clear();
1614 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1615 }
1616 void afterReduction(mlir::ModuleOp op) override {
1617 for (auto *block : blocksToSort)
1618 mlir::sortTopologically(block);
1619 blocksToSort.clear();
1620 nlaRemover.remove(op);
1621 innerSymTables.reset();
1622 }
1623
1624 uint64_t match(ObjectOp objOp) override {
1625 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1626 auto *classOp =
1627 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1628
1629 // Only inline `ClassOp`s.
1630 if (!isa<ClassOp>(classOp))
1631 return 0;
1632
1633 // Check for inner symbol collisions between the referenced class and the
1634 // object's parent module.
1635 auto parentOp = objOp->getParentOfType<FModuleLike>();
1636 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1637 return 0;
1638
1639 // Verify all uses are ObjectSubfieldOp.
1640 for (auto *user : objOp.getResult().getUsers())
1641 if (!isa<ObjectSubfieldOp>(user))
1642 return 0;
1643
1644 return 1;
1645 }
1646
1647 LogicalResult rewrite(ObjectOp objOp) override {
1648 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1649 auto classOp = cast<ClassOp>(
1650 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1651 auto clonedClassOp = classOp.clone();
1652
1653 // Create wires to replace the ObjectSubfieldOp results.
1654 IRRewriter rewriter(objOp);
1655 SmallVector<Value> portWires;
1656 auto classType = objOp.getType();
1657
1658 // Create a wire for each port in the class
1659 for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1660 auto element = classType.getElement(i);
1661 auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
1662 element.name.getValue());
1663 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1664 NameKindEnum::DroppableName,
1665 rewriter.getArrayAttr({}), StringAttr{})
1666 .getResult();
1667 portWires.push_back(wire);
1668 }
1669
1670 // Replace all ObjectSubfieldOp uses with corresponding wires
1671 SmallVector<ObjectSubfieldOp> subfieldOps;
1672 for (auto *user : objOp.getResult().getUsers()) {
1673 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1674 subfieldOps.push_back(subfieldOp);
1675 auto index = subfieldOp.getIndex();
1676 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1677 }
1678
1679 // Splice in the cloned class body.
1680 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1681
1682 // After inlining the class body, we need to eliminate `WireOps` since
1683 // `ClassOps` cannot contain wires. For each port wire, find its single
1684 // connect, remove it, and replace all uses of the wire with the assigned
1685 // value.
1686 SmallVector<FConnectLike> connectsToErase;
1687 for (auto portWire : portWires) {
1688 // Find a single value to replace the wire with, and collect all connects
1689 // to the wire such that we can erase them later.
1690 Value value;
1691 for (auto *user : portWire.getUsers()) {
1692 if (auto connect = dyn_cast<FConnectLike>(user)) {
1693 if (connect.getDest() == portWire) {
1694 value = connect.getSrc();
1695 connectsToErase.push_back(connect);
1696 }
1697 }
1698 }
1699
1700 // Be very conservative about deleting these wires. Other reductions may
1701 // leave class ports unconnected, which means that there isn't always a
1702 // clean replacement available here. Better to just leave the wires in the
1703 // IR and let the verifier fail later.
1704 if (value)
1705 portWire.replaceAllUsesWith(value);
1706 for (auto connect : connectsToErase)
1707 connect.erase();
1708 if (portWire.use_empty())
1709 portWire.getDefiningOp()->erase();
1710 connectsToErase.clear();
1711 }
1712
1713 // Make sure we remove any NLAs that go through this object.
1714 nlaRemover.markNLAsInOperation(objOp);
1715
1716 // Since the above forwarding of SSA values through wires can create
1717 // dominance issues, mark the region containing the object to be sorted
1718 // topologically.
1719 blocksToSort.insert(objOp->getBlock());
1720
1721 // Erase the object and cloned class.
1722 for (auto subfieldOp : subfieldOps)
1723 subfieldOp.erase();
1724 objOp.erase();
1725 clonedClassOp.erase();
1726 return success();
1727 }
1728
1729 std::string getName() const override { return "firrtl-object-inliner"; }
1730 bool acceptSizeIncrease() const override { return true; }
1731
1732 SetVector<Block *> blocksToSort;
1733 ::detail::SymbolCache symbols;
1734 NLARemover nlaRemover;
1735 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1736};
1737
1738/// Reduction that converts `regreset` to `reg` by dropping reset and init
1739/// value.
1740struct ResetDisconnector : public OpReduction<RegResetOp> {
1741 uint64_t match(RegResetOp op) override { return 1; }
1742
1743 LogicalResult rewrite(RegResetOp regResetOp) override {
1744 ImplicitLocOpBuilder builder(regResetOp.getLoc(), regResetOp);
1745 auto regOp = RegOp::create(
1746 builder, regResetOp.getResult().getType(), regResetOp.getClockVal(),
1747 regResetOp.getNameAttr(), regResetOp.getNameKindAttr(),
1748 regResetOp.getAnnotationsAttr(), regResetOp.getInnerSymAttr(),
1749 regResetOp.getForceableAttr());
1750
1751 regResetOp.getResult().replaceAllUsesWith(regOp.getResult());
1752 if (regResetOp.getForceable())
1753 regResetOp.getRef().replaceAllUsesWith(regOp.getRef());
1754 regResetOp.erase();
1755
1756 return success();
1757 }
1758
1759 std::string getName() const override { return "reset-disconnector"; }
1760};
1761
1762/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1763/// not an actual reduction, but often removes extraneous information that has
1764/// no bearing on the actual reduction (and would likely be removed before
1765/// sharing the reduction). This makes the following changes:
1766///
1767/// - All wires are renamed to "wire"
1768/// - All registers are renamed to "reg"
1769/// - All nodes are renamed to "node"
1770/// - All memories are renamed to "mem"
1771/// - All verification messages and labels are dropped
1772///
1773struct ModuleInternalNameSanitizer : public Reduction {
1774 uint64_t match(Operation *op) override {
1775 // Only match operations with names.
1776 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1777 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1778 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1779 firrtl::CoverOp>(op);
1780 }
1781 LogicalResult rewrite(Operation *op) override {
1782 TypeSwitch<Operation *, void>(op)
1783 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1784 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1785 [](auto op) { op.setName("reg"); })
1786 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1787 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1788 [](auto op) { op.setName("mem"); })
1789 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1790 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1791 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1792 });
1793 return success();
1794 }
1795
1796 std::string getName() const override {
1797 return "module-internal-name-sanitizer";
1798 }
1799
1800 bool acceptSizeIncrease() const override { return true; }
1801
1802 bool isOneShot() const override { return true; }
1803};
1804
1805/// Pseudo-reduction that sanitizes module, instance, and port names. This
1806/// makes the following changes:
1807///
1808/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1809/// - All instances are renamed to match the new module name
1810/// - All module ports are renamed in the following way:
1811/// - All clocks are renamed to "clk"
1812/// - All resets are renamed to "rst"
1813/// - All references are renamed to "ref"
1814/// - Anything else is renamed to a lowercase letter ("a", "b", ...)
1815///
// NOTE(review): `nameGenerator` and `ns` are used below but not declared in
// this listing; the extraction skips embedded lines 1818 and 1846, which
// presumably hold those declarations -- confirm against the real file.
1816struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1817
// Index of the next letter handed out by `getPortName`; reset to 0 for every
// `FModuleOp` visited in `rewrite`.
1819 size_t portNameIndex = 0;
1820
// Return the next letter in 'a'..'z', wrapping back to 'a' after 26 calls.
1821 char getPortName() {
1822 if (portNameIndex >= 26)
1823 portNameIndex = 0;
1824 return 'a' + portNameIndex++;
1825 }
1826
// Restart name generation at the beginning of every reduction run.
1827 void beforeReduction(mlir::ModuleOp op) override { nameGenerator.reset(); }
1828
// Rename the circuit and its top-level module to the first generated name,
// then walk the instance graph: ports of every `FModuleOp` get new names, and
// every non-top module gets a fresh generated name that is propagated to the
// instances and objects referring to it. Always succeeds.
1829 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1830
1831 firrtl::InstanceGraph iGraph(circuitOp);
1832
// The circuit and its top-level module share the same new name.
1833 auto *circuitName = nameGenerator.getNextName();
1834 iGraph.getTopLevelModule().setName(circuitName);
1835 circuitOp.setName(circuitName);
1836
1837 for (auto *node : iGraph) {
1838 auto module = node->getModule<firrtl::FModuleLike>();
1839
// Port names are only rewritten for `FModuleOp`s that have at least one port;
// `newNames` stays empty for every other kind of module.
1840 bool shouldReplacePorts = false;
1841 SmallVector<Attribute> newNames;
1842 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1843 portNameIndex = 0;
1844 // TODO: The namespace should be unnecessary. However, some FIRRTL
1845 // passes expect that port names are unique.
1847 auto oldPorts = fmodule.getPorts();
1848 shouldReplacePorts = !oldPorts.empty();
1849 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1850 auto port = oldPorts[i];
// Pick the base name from the port type; every candidate goes through
// `ns.newName`, presumably to make it unique within the module (see the TODO
// above).
1851 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1852 .Case<firrtl::ClockType>(
1853 [&](auto a) { return ns.newName("clk"); })
1854 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1855 [&](auto a) { return ns.newName("rst"); })
1856 .Case<firrtl::RefType>(
1857 [&](auto a) { return ns.newName("ref"); })
1858 .Default([&](auto a) {
1859 return ns.newName(Twine(getPortName()));
1860 });
1861 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1862 }
1863 fmodule->setAttr("portNames",
1864 ArrayAttr::get(fmodule.getContext(), newNames));
1865 }
1866
// The top-level module was already renamed together with the circuit above.
1867 if (module == iGraph.getTopLevelModule())
1868 continue;
1869 auto newName =
1870 StringAttr::get(circuitOp.getContext(), nameGenerator.getNextName());
1871 module.setName(newName);
// Propagate the new module name (and port names, if they changed) to every
// user recorded in the instance graph.
1872 for (auto *use : node->uses()) {
1873 auto useOp = use->getInstance();
1874 if (auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1875 instanceOp.setModuleName(newName);
1876 instanceOp.setName(newName);
1877 if (shouldReplacePorts)
1878 instanceOp.setPortNamesAttr(
1879 ArrayAttr::get(circuitOp.getContext(), newNames));
1880 } else if (auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1881 // ObjectOp stores the class name in its result type, so we need to
1882 // create a new ClassType with the new name and set it on the result.
1883 auto oldClassType = objectOp.getType();
1884 auto newClassType = firrtl::ClassType::get(
1885 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1886 oldClassType.getElements());
1887 objectOp.getResult().setType(newClassType);
1888 objectOp.setName(newName);
1889 }
1890 }
1891 }
1892
1893 return success();
1894 }
1895
1896 std::string getName() const override { return "module-name-sanitizer"; }
1897
1898 bool acceptSizeIncrease() const override { return true; }
1899
1900 bool isOneShot() const override { return true; }
1901};
1902
1903/// A reduction pattern that groups modules by their port signature (types and
1904/// directions) and replaces instances with the smallest module in each group.
1905/// This helps reduce the IR by consolidating functionally equivalent modules
1906/// based on their interface.
1907///
1908/// The pattern works by:
1909/// 1. Grouping all modules by their port signature (port types and directions)
1910/// 2. For each group with multiple modules, finding the smallest module using
1911/// the module size cache
1912/// 3. Replacing all instances of larger modules with instances of the smallest
1913/// module in the same group
1914/// 4. Removing the larger modules from the circuit
1915///
1916/// This reduction is useful for reducing circuits where multiple modules have
1917/// the same interface but different implementations, allowing the reducer to
1918/// try the smallest implementation first.
1919struct ModuleSwapper : public OpReduction<InstanceOp> {
1920 // Per-circuit state containing all the information needed for module swapping
1921 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1922 struct CircuitState {
1923 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1924 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1925 std::unique_ptr<NLATable> nlaTable;
1926 };
1927
1928 void beforeReduction(mlir::ModuleOp op) override {
1929 symbols.clear();
1930 nlaRemover.clear();
1931 moduleSizes.clear();
1932 circuitStates.clear();
1933
1934 // Collect module type groups and NLA tables for all circuits up front
1935 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1936 auto &state = circuitStates[circuitOp];
1937 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1938 buildModuleTypeGroups(circuitOp, state);
1939 return WalkResult::skip();
1940 });
1941 }
1942 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1943
1944 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1945 /// This ignores port names, allowing modules with the same port types and
1946 /// directions but different port names to be considered equivalent for
1947 /// swapping.
1948 PortSignature getModulePortSignature(FModuleLike module) {
1949 PortSignature signature;
1950 signature.reserve(module.getNumPorts());
1951 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1952 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1953 return signature;
1954 }
1955
1956 /// Group modules by their port signature and find the smallest in each group.
1957 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1958 // Group modules by their port signature
1959 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1960 auto signature = getModulePortSignature(module);
1961 state.moduleTypeGroups[signature].push_back(module);
1962 }
1963
1964 // For each group, find the smallest module
1965 for (auto &[signature, modules] : state.moduleTypeGroups) {
1966 if (modules.size() <= 1)
1967 continue;
1968
1969 FModuleLike smallestModule = nullptr;
1970 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1971
1972 for (auto module : modules) {
1973 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1974 if (size < smallestSize) {
1975 smallestSize = size;
1976 smallestModule = module;
1977 }
1978 }
1979
1980 // Map all modules in this group to the smallest one
1981 for (auto module : modules) {
1982 if (module != smallestModule) {
1983 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1984 smallestModule;
1985 }
1986 }
1987 }
1988 }
1989
1990 uint64_t match(InstanceOp instOp) override {
1991 // Get the circuit this instance belongs to
1992 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1993 assert(circuitOp);
1994 const auto &state = circuitStates.at(circuitOp);
1995
1996 // Skip instances that participate in any NLAs
1997 DenseSet<hw::HierPathOp> nlas;
1998 state.nlaTable->getInstanceNLAs(instOp, nlas);
1999 if (!nlas.empty())
2000 return 0;
2001
2002 // Check if this instance can be redirected to a smaller module
2003 auto moduleName = instOp.getModuleNameAttr().getAttr();
2004 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
2005 if (!canonicalModule)
2006 return 0;
2007
2008 // Benefit is the size difference
2009 auto currentModule = cast<FModuleLike>(
2010 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
2011 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
2012 uint64_t canonicalSize =
2013 moduleSizes.getModuleSize(canonicalModule, symbols);
2014 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
2015 }
2016
2017 LogicalResult rewrite(InstanceOp instOp) override {
2018 // Get the circuit this instance belongs to
2019 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2020 assert(circuitOp);
2021 const auto &state = circuitStates.at(circuitOp);
2022
2023 // Replace the instantiated module with the canonical module.
2024 auto canonicalModule = state.instanceToCanonicalModule.at(
2025 instOp.getModuleNameAttr().getAttr());
2026 auto canonicalName = canonicalModule.getModuleNameAttr();
2027 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2028
2029 // Update port names to match the canonical module
2030 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2031
2032 return success();
2033 }
2034
2035 std::string getName() const override { return "firrtl-module-swapper"; }
2036 bool acceptSizeIncrease() const override { return true; }
2037
2038private:
2039 ::detail::SymbolCache symbols;
2040 NLARemover nlaRemover;
2041 ModuleSizeCache moduleSizes;
2042
2043 // Per-circuit state containing all module swapping information
2044 DenseMap<CircuitOp, CircuitState> circuitStates;
2045};
2046
2047/// A reduction pattern that handles MustDedup annotations by replacing all
2048/// module names in a dedup group with a single module name. This helps reduce
2049/// the IR by consolidating module references that are required to be identical.
2050///
2051/// The pattern works by:
2052/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2053/// 2. For each dedup group, using the first module as the canonical name
2054/// 3. Replacing all instance references to other modules in the group with
2055/// references to the canonical module
2056/// 4. Removing the non-canonical modules from the circuit
2057/// 5. Removing the processed MustDedup annotation
2058///
2059/// This reduction is particularly useful for reducing large circuits where
2060/// multiple modules are known to be identical but haven't been deduplicated
2061/// yet.
2062struct ForceDedup : public OpReduction<CircuitOp> {
2063 void beforeReduction(mlir::ModuleOp op) override {
2064 symbols.clear();
2065 nlaRemover.clear();
2066 modulesToErase.clear();
2067 moduleSizes.clear();
2068 }
2069 void afterReduction(mlir::ModuleOp op) override {
2070 nlaRemover.remove(op);
2071 for (auto mod : modulesToErase)
2072 mod->erase();
2073 }
2074
2075 /// Collect all MustDedup annotations and create matches for each dedup group.
2076 void matches(CircuitOp circuitOp,
2077 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2078 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2079 auto annotations = AnnotationSet(circuitOp);
2080 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2081 if (!anno.isClass(mustDeduplicateAnnoClass))
2082 continue;
2083
2084 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2085 if (!modulesAttr || modulesAttr.size() < 2)
2086 continue;
2087
2088 // Check that all modules have the same port signature. Malformed inputs
2089 // may have modules listed in a MustDedup annotation that have distinct
2090 // port types.
2091 uint64_t totalSize = 0;
2092 ArrayAttr portTypes;
2093 DenseBoolArrayAttr portDirections;
2094 bool allSame = true;
2095 for (auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2096 auto target = tokenizePath(moduleName);
2097 if (!target) {
2098 allSame = false;
2099 break;
2100 }
2101 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2102 if (!mod) {
2103 allSame = false;
2104 break;
2105 }
2106 totalSize += moduleSizes.getModuleSize(mod, symbols);
2107 if (!portTypes) {
2108 portTypes = mod.getPortTypesAttr();
2109 portDirections = mod.getPortDirectionsAttr();
2110 } else if (portTypes != mod.getPortTypesAttr() ||
2111 portDirections != mod.getPortDirectionsAttr()) {
2112 allSame = false;
2113 break;
2114 }
2115 }
2116 if (!allSame)
2117 continue;
2118
2119 // Each dedup group gets its own match with benefit proportional to group
2120 // size.
2121 addMatch(totalSize, annoIdx);
2122 }
2123 }
2124
2125 LogicalResult rewriteMatches(CircuitOp circuitOp,
2126 ArrayRef<uint64_t> matches) override {
2127 auto *context = circuitOp->getContext();
2128 NLATable nlaTable(circuitOp);
2129 hw::InnerSymbolTableCollection innerSymTables;
2130 auto annotations = AnnotationSet(circuitOp);
2131 SmallVector<Annotation> newAnnotations;
2132
2133 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2134 // Check if this annotation was selected.
2135 if (!llvm::is_contained(matches, annoIdx)) {
2136 newAnnotations.push_back(anno);
2137 continue;
2138 }
2139 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2140 assert(anno.isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2141 modulesAttr.size() >= 2);
2142
2143 // Extract module names from the dedup group.
2144 SmallVector<StringAttr> moduleNames;
2145 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2146 // Parse "~CircuitName|ModuleName" format.
2147 auto refStr = moduleRef.getValue();
2148 auto pipePos = refStr.find('|');
2149 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2150 auto moduleName = refStr.substr(pipePos + 1);
2151 moduleNames.push_back(StringAttr::get(context, moduleName));
2152 }
2153 }
2154
2155 // Simply drop the annotation if there's only one module.
2156 if (moduleNames.size() < 2)
2157 continue;
2158
2159 // Replace all instances and references to other modules with the
2160 // first module.
2161 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2162 nlaRemover.markNLAsInAnnotation(anno.getAttr());
2163 }
2164 if (newAnnotations.size() == annotations.size())
2165 return failure();
2166
2167 // Update circuit annotations.
2168 AnnotationSet newAnnoSet(newAnnotations, context);
2169 newAnnoSet.applyToOperation(circuitOp);
2170 return success();
2171 }
2172
2173 std::string getName() const override { return "firrtl-force-dedup"; }
2174 bool acceptSizeIncrease() const override { return true; }
2175
2176private:
2177 /// Replace all references to modules in the dedup group with the canonical
2178 /// module name
2179 void replaceModuleReferences(CircuitOp circuitOp,
2180 ArrayRef<StringAttr> moduleNames,
2181 NLATable &nlaTable,
2182 hw::InnerSymbolTableCollection &innerSymTables) {
2183 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2184 auto &symbolTable = symbols.getSymbolTable(tableOp);
2185 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2186 auto *context = circuitOp->getContext();
2187 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
2188
2189 // Collect the modules.
2190 FModuleLike canonicalModule;
2191 SmallVector<FModuleLike> modulesToReplace;
2192 for (auto name : moduleNames) {
2193 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
2194 if (!canonicalModule)
2195 canonicalModule = mod;
2196 else
2197 modulesToReplace.push_back(mod);
2198 }
2199 }
2200 if (modulesToReplace.empty())
2201 return;
2202
2203 // Replace all instance references.
2204 auto canonicalName = canonicalModule.getModuleNameAttr();
2205 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2206 for (auto moduleName : moduleNames) {
2207 if (moduleName == canonicalName)
2208 continue;
2209 auto *symbolOp = symbolTable.lookup(moduleName);
2210 if (!symbolOp)
2211 continue;
2212 for (auto *user : symbolUserMap.getUsers(symbolOp)) {
2213 auto instOp = dyn_cast<InstanceOp>(user);
2214 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2215 continue;
2216 instOp.setModuleNameAttr(canonicalRef);
2217 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2218 }
2219 }
2220
2221 // Update NLAs to reference the canonical module instead of modules being
2222 // removed using NLATable for better performance.
2223 for (auto oldMod : modulesToReplace) {
2224 SmallVector<hw::HierPathOp> nlaOps(
2225 nlaTable.lookup(oldMod.getModuleNameAttr()));
2226 for (auto nlaOp : nlaOps) {
2227 nlaTable.erase(nlaOp);
2228 StringAttr oldModName = oldMod.getModuleNameAttr();
2229 StringAttr newModName = canonicalName;
2230 SmallVector<Attribute, 4> newPath;
2231 for (auto nameRef : nlaOp.getNamepath()) {
2232 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2233 if (ref.getModule() == oldModName) {
2234 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2235 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2236 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2237 if (oldInst && newInst) {
2238 // Get the first module name from the list (for
2239 // InstanceOp/ObjectOp, there's only one)
2240 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2241 auto newModNames = newInst.getReferencedModuleNamesAttr();
2242 if (!oldModNames.empty() && !newModNames.empty()) {
2243 oldModName = cast<StringAttr>(oldModNames[0]);
2244 newModName = cast<StringAttr>(newModNames[0]);
2245 }
2246 }
2247 }
2248 newPath.push_back(ref);
2249 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2250 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2251 } else {
2252 newPath.push_back(nameRef);
2253 }
2254 }
2255 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
2256 nlaTable.addNLA(nlaOp);
2257 }
2258 }
2259
2260 // Mark NLAs in modules to be removed.
2261 for (auto module : modulesToReplace) {
2262 nlaRemover.markNLAsInOperation(module);
2263 modulesToErase.insert(module);
2264 }
2265 }
2266
2267 ::detail::SymbolCache symbols;
2268 NLARemover nlaRemover;
2269 SetVector<FModuleLike> modulesToErase;
2270 ModuleSizeCache moduleSizes;
2271};
2272
2273/// A reduction pattern that moves `MustDedup` annotations from a module onto
2274/// its child modules. This pattern iterates over all MustDedup annotations,
2275/// collects all `FInstanceLike` ops in each module of the dedup group, and
2276/// creates new MustDedup annotations for corresponding instances across the
2277/// modules. Each set of corresponding instances becomes a separate match of the
2278/// reduction. The reduction also removes the original MustDedup annotation on
2279/// the parent module.
2280///
2281/// The pattern works by:
2282/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2283/// 2. For each dedup group, collecting all FInstanceLike operations in each
2284/// module
2285/// 3. Grouping corresponding instances across modules by their position/name
2286/// 4. Creating new MustDedup annotations for each group of corresponding
2287/// instances
2288/// 5. Removing the original MustDedup annotation from the circuit
2289struct MustDedupChildren : public OpReduction<CircuitOp> {
2290 void beforeReduction(mlir::ModuleOp op) override {
2291 symbols.clear();
2292 nlaRemover.clear();
2293 }
2294 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
2295
2296 /// Collect all MustDedup annotations and create matches for each instance
2297 /// group.
2298 void matches(CircuitOp circuitOp,
2299 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2300 auto annotations = AnnotationSet(circuitOp);
2301 uint64_t matchId = 0;
2302
2303 DenseSet<StringRef> modulesAlreadyInMustDedup;
2304 for (auto [annoIdx, anno] : llvm::enumerate(annotations))
2305 if (anno.isClass(mustDeduplicateAnnoClass))
2306 if (auto modulesAttr = anno.getMember<ArrayAttr>("modules"))
2307 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2308 if (auto target = tokenizePath(moduleRef))
2309 modulesAlreadyInMustDedup.insert(target->module);
2310
2311 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2312 if (!anno.isClass(mustDeduplicateAnnoClass))
2313 continue;
2314
2315 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2316 if (!modulesAttr || modulesAttr.size() < 2)
2317 continue;
2318
2319 // Process each group of corresponding instances
2320 processInstanceGroups(
2321 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2322 matchId++;
2323
2324 // Make sure there are at least two distinct modules.
2325 SmallDenseSet<StringAttr, 4> moduleTargets;
2326 for (auto instOp : instanceGroup) {
2327 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2328 for (auto moduleName : moduleNames)
2329 moduleTargets.insert(cast<StringAttr>(moduleName));
2330 }
2331 if (moduleTargets.size() < 2)
2332 return;
2333
2334 // Make sure none of the modules are not yet in a must dedup
2335 // annotation.
2336 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2337 auto moduleNames = inst.getReferencedModuleNames();
2338 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2339 return modulesAlreadyInMustDedup.contains(moduleName);
2340 });
2341 }))
2342 return;
2343
2344 addMatch(1, matchId - 1);
2345 });
2346 }
2347 }
2348
2349 LogicalResult rewriteMatches(CircuitOp circuitOp,
2350 ArrayRef<uint64_t> matches) override {
2351 auto *context = circuitOp->getContext();
2352 auto annotations = AnnotationSet(circuitOp);
2353 SmallVector<Annotation> newAnnotations;
2354 uint64_t matchId = 0;
2355
2356 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2357 if (!anno.isClass(mustDeduplicateAnnoClass)) {
2358 newAnnotations.push_back(anno);
2359 continue;
2360 }
2361
2362 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2363 if (!modulesAttr || modulesAttr.size() < 2) {
2364 newAnnotations.push_back(anno);
2365 continue;
2366 }
2367
2368 processInstanceGroups(
2369 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2370 // Check if this instance group was selected
2371 if (!llvm::is_contained(matches, matchId++))
2372 return;
2373
2374 // Create the list of modules to put into this new annotation.
2375 SmallSetVector<StringAttr, 4> moduleTargets;
2376 for (auto instOp : instanceGroup) {
2377 auto moduleNames = instOp.getReferencedModuleNames();
2378 for (auto moduleName : moduleNames) {
2379 auto target = TokenAnnoTarget();
2380 target.circuit = circuitOp.getName();
2381 target.module = moduleName;
2382 moduleTargets.insert(target.toStringAttr(context));
2383 }
2384 }
2385
2386 // Create a new MustDedup annotation for this list of modules.
2387 SmallVector<NamedAttribute> newAnnoAttrs;
2388 newAnnoAttrs.emplace_back(
2389 StringAttr::get(context, "class"),
2390 StringAttr::get(context, mustDeduplicateAnnoClass));
2391 newAnnoAttrs.emplace_back(
2392 StringAttr::get(context, "modules"),
2393 ArrayAttr::get(context,
2394 SmallVector<Attribute>(moduleTargets.begin(),
2395 moduleTargets.end())));
2396
2397 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2398 newAnnotations.emplace_back(newAnnoDict);
2399 });
2400
2401 // Keep the original annotation around.
2402 newAnnotations.push_back(anno);
2403 }
2404
2405 // Update circuit annotations
2406 AnnotationSet newAnnoSet(newAnnotations, context);
2407 newAnnoSet.applyToOperation(circuitOp);
2408 return success();
2409 }
2410
2411 std::string getName() const override { return "must-dedup-children"; }
2412 bool acceptSizeIncrease() const override { return true; }
2413
2414private:
2415 /// Helper function to process groups of corresponding instances from a
2416 /// MustDedup annotation. Calls the provided lambda for each group of
2417 /// corresponding instances across the modules. Only calls the lambda if there
2418 /// are at least 2 modules.
2419 void processInstanceGroups(
2420 CircuitOp circuitOp, ArrayAttr modulesAttr,
2421 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
2422 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2423
2424 // Extract module names and get the actual modules
2425 SmallVector<FModuleLike> modules;
2426 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2427 if (auto target = tokenizePath(moduleRef))
2428 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2429 modules.push_back(mod);
2430
2431 // Need at least 2 modules for deduplication
2432 if (modules.size() < 2)
2433 return;
2434
2435 // Collect all FInstanceLike operations from each module and group them by
2436 // name. Instance names are a good key for matching instances across
2437 // modules. But they may not be unique, so we need to be careful to only
2438 // match up instances that are uniquely named within every module.
2439 struct InstanceGroup {
2440 SmallVector<FInstanceLike> instances;
2441 bool nameIsUnique = true;
2442 };
2443 MapVector<StringAttr, InstanceGroup> instanceGroups;
2444 for (auto module : modules) {
2446 module.walk([&](FInstanceLike instOp) {
2447 if (isa<ObjectOp>(instOp.getOperation()))
2448 return;
2449 auto name = instOp.getInstanceNameAttr();
2450 auto &group = instanceGroups[name];
2451 if (nameCounts[name]++ > 1)
2452 group.nameIsUnique = false;
2453 group.instances.push_back(instOp);
2454 });
2455 }
2456
2457 // Call the callback for each group of instances that are uniquely named and
2458 // consist of at least 2 instances.
2459 for (auto &[name, group] : instanceGroups)
2460 if (group.nameIsUnique && group.instances.size() >= 2)
2461 callback(group.instances);
2462 }
2463
2464 ::detail::SymbolCache symbols;
2465 NLARemover nlaRemover;
2466};
2467
2468struct LayerDisable : public OpReduction<CircuitOp> {
2469 LayerDisable(MLIRContext *context) {
2470 pm = std::make_unique<mlir::PassManager>(
2471 context, "builtin.module", mlir::OpPassManager::Nesting::Explicit);
2472 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2473 };
2474
2475 void beforeReduction(mlir::ModuleOp op) override { symbolRefAttrMap.clear(); }
2476
2477 void afterReduction(mlir::ModuleOp op) override { (void)pm->run(op); };
2478
2479 void matches(CircuitOp circuitOp,
2480 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2481 uint64_t matchId = 0;
2482
2483 SmallVector<FlatSymbolRefAttr> nestedRefs;
2484 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2485 LayerOp layerOp) {
2486 if (!rootRef)
2487 rootRef = layerOp.getSymNameAttr();
2488 else
2489 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2490
2491 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2492 addMatch(1, matchId++);
2493
2494 for (auto nestedLayerOp : layerOp.getOps<LayerOp>())
2495 addLayer(rootRef, nestedLayerOp);
2496
2497 if (!nestedRefs.empty())
2498 nestedRefs.pop_back();
2499 };
2500
2501 for (auto layerOp : circuitOp.getOps<LayerOp>())
2502 addLayer({}, layerOp);
2503 }
2504
2505 LogicalResult rewriteMatches(CircuitOp circuitOp,
2506 ArrayRef<uint64_t> matches) override {
2507 SmallVector<Attribute> disableLayers;
2508 if (auto existingDisables = circuitOp.getDisableLayersAttr()) {
2509 auto disableRange = existingDisables.getAsRange<Attribute>();
2510 disableLayers.append(disableRange.begin(), disableRange.end());
2511 }
2512 for (auto match : matches)
2513 disableLayers.push_back(symbolRefAttrMap.at(match));
2514
2515 circuitOp.setDisableLayersAttr(
2516 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2517
2518 return success();
2519 }
2520
2521 std::string getName() const override { return "firrtl-layer-disable"; }
2522
2523 std::unique_ptr<mlir::PassManager> pm;
2524 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2525};
2526
2527} // namespace
2528
2529/// A reduction pattern that removes elements from FIRRTL list create
2530/// operations. This generates one match per element in each list, allowing
2531/// selective removal of individual elements.
2532struct ListCreateElementRemover : public OpReduction<ListCreateOp> {
2533 void matches(ListCreateOp listOp,
2534 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2535 // Create one match for each element in the list
2536 auto elements = listOp.getElements();
2537 for (size_t i = 0; i < elements.size(); ++i)
2538 addMatch(1, i);
2539 }
2540
2541 LogicalResult rewriteMatches(ListCreateOp listOp,
2542 ArrayRef<uint64_t> matches) override {
2543 // Convert matches to a set for fast lookup
2544 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
2545
2546 // Collect elements that should be kept (not in matches)
2547 SmallVector<Value> newElements;
2548 auto elements = listOp.getElements();
2549 for (size_t i = 0; i < elements.size(); ++i) {
2550 if (!matchesSet.contains(i))
2551 newElements.push_back(elements[i]);
2552 }
2553
2554 // Create a new list with the remaining elements
2555 OpBuilder builder(listOp);
2556 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2557 listOp.getType(), newElements);
2558 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2559 listOp.erase();
2560
2561 return success();
2562 }
2563
2564 std::string getName() const override {
2565 return "firrtl-list-create-element-remover";
2566 }
2567};
2568
2569/// Reduction that removes the `convention` attribute from regular modules.
2570struct ModuleConventionRemover : public OpReduction<FModuleOp> {
2571 uint64_t match(FModuleOp module) override {
2572 return module.getConvention() != Convention::Internal;
2573 }
2574
2575 LogicalResult rewrite(FModuleOp module) override {
2576 module.setConvention(Convention::Internal);
2577 return success();
2578 }
2579
2580 std::string getName() const override { return "module-convention-remover"; }
2581 bool acceptSizeIncrease() const override { return true; }
2582 bool isOneShot() const override { return true; }
2583};
2584
2585/// Reduction that removes the `convention` attribute from external modules.
2586struct ExtmoduleConventionRemover : public OpReduction<FExtModuleOp> {
2587 uint64_t match(FExtModuleOp extmodule) override {
2588 return extmodule.getConvention() != Convention::Internal;
2589 }
2590
2591 LogicalResult rewrite(FExtModuleOp extmodule) override {
2592 extmodule.setConvention(Convention::Internal);
2593 return success();
2594 }
2595
2596 std::string getName() const override {
2597 return "extmodule-convention-remover";
2598 }
2599 bool acceptSizeIncrease() const override { return true; }
2600 bool isOneShot() const override { return true; }
2601};
2602
2603//===----------------------------------------------------------------------===//
2604// Reduction Registration
2605//===----------------------------------------------------------------------===//
2606
2609 // Gather a list of reduction patterns that we should try. Ideally these are
2610 // assigned reasonable benefit indicators (higher benefit patterns are
2611 // prioritized). For example, things that can knock out entire modules while
2612 // being cheap should be tried first (and thus have higher benefit), before
2613 // trying to tweak operands of individual arithmetic ops.
2614 patterns.add<SimplifyResets, 35>();
2615 patterns.add<ForceDedup, 34>();
2616 patterns.add<MustDedupChildren, 33>();
2617 patterns.add<AnnotationRemover, 32>();
2618 patterns.add<ModuleSwapper, 31>();
2619 patterns.add<LayerDisable, 30>(getContext());
2620 patterns.add<PassReduction, 29>(
2621 getContext(),
2622 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
2623 true);
2624 patterns.add<PassReduction, 28>(getContext(),
2625 firrtl::createLowerCHIRRTLPass(), true, true);
2626 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
2627 true, true);
2628 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
2629 true, true);
2630 patterns.add<FIRRTLModuleExternalizer, 25>();
2631 patterns.add<InstanceStubber, 24>();
2632 patterns.add<MemoryStubber, 23>();
2633 patterns.add<EagerInliner, 22>();
2634 patterns.add<ObjectInliner, 22>();
2635 patterns.add<PassReduction, 21>(getContext(),
2636 firrtl::createLowerFIRRTLTypes(), true, true);
2637 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
2638 true, true);
2639 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
2640 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
2641 patterns.add<PassReduction, 17>(
2642 getContext(),
2643 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
2644 patterns.add<NodeSymbolRemover, 16>();
2645 patterns.add<PassReduction, 15>(getContext(), firrtl::createIMDeadCodeElim());
2646 patterns.add<ConnectForwarder, 14>();
2647 patterns.add<ConnectInvalidator, 13>();
2648 patterns.add<Constantifier, 12>();
2649 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2650 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2651 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2653 patterns.add<ResetDisconnector, 8>();
2654 patterns.add<DetachSubaccesses, 7>();
2655 patterns.add<ModulePortPruner, 7>();
2656 patterns.add<ExtmodulePortPruner, 6>();
2657 patterns.add<RootPortPruner, 5>();
2658 patterns.add<RootExtmodulePortPruner, 5>();
2659 patterns.add<ExtmoduleInstanceRemover, 4>();
2660 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2661 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2662 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2664 patterns.add<ModuleNameSanitizer, 0>();
2667}
2668
2670 mlir::DialectRegistry &registry) {
2671 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
2672 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
2673 });
2674}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalid values.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
Helper class to cache tie-off values for different FIRRTL types.
Definition FIRRTLUtils.h:63
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
A utility class that generates metasyntactic variable names for use in reductions.
void reset()
Reset the generator to start from the beginning of the sequence.
const char * getNextName()
Get the next metasyntactic name in the sequence.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:43
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(SmallVectorImpl< Operation * > &worklist, Reduction &reduction)
Starting from an initial worklist of operations, traverse through it and its operands and erase opera...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Reduction that removes the convention attribute from external modules.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(FExtModuleOp extmodule) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(FExtModuleOp extmodule) override
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Reduction that removes the convention attribute from regular modules.
uint64_t match(FModuleOp module) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(FModuleOp module) override
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all the ways in which this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all the ways in which this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)