CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APSInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/Debug.h"
31
32#define DEBUG_TYPE "firrtl-reductions"
33
34using namespace mlir;
35using namespace circt;
36using namespace firrtl;
37using llvm::MapVector;
38using llvm::SmallSetVector;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44namespace detail {
45/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
46/// which is handy for reductions that need to look up a lot of symbols.
48 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
49
50 SymbolTable &getSymbolTable(Operation *op) {
51 return tables->getSymbolTable(op);
52 }
53 SymbolTable &getNearestSymbolTable(Operation *op) {
54 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
55 }
56
57 SymbolUserMap &getSymbolUserMap(Operation *op) {
58 auto it = userMaps.find(op);
59 if (it != userMaps.end())
60 return it->second;
61 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
62 }
63 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
64 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
65 }
66
67 void clear() {
68 tables = std::make_unique<SymbolTableCollection>();
69 userMaps.clear();
70 }
71
72private:
73 std::unique_ptr<SymbolTableCollection> tables;
75};
76} // namespace detail
77
78/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
79/// optional in case another type of module is instantiated.
80static std::optional<firrtl::FModuleOp>
81findInstantiatedModule(firrtl::InstanceOp instOp,
82 ::detail::SymbolCache &symbols) {
83 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
84 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
85 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
86 return moduleOp ? std::optional(moduleOp) : std::nullopt;
87}
88
89/// Utility to track the transitive size of modules.
91 void clear() { moduleSizes.clear(); }
92
93 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
94 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
95 return it->second;
96 uint64_t size = 1;
97 module->walk([&](Operation *op) {
98 size += 1;
99 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
100 if (auto instModule = findInstantiatedModule(instOp, symbols))
101 size += getModuleSize(*instModule, symbols);
102 });
103 moduleSizes.insert({module, size});
104 return size;
105 }
106
107private:
108 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
109};
110
111/// Check that all connections to a value are invalids.
112static bool onlyInvalidated(Value arg) {
113 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
114 auto *op = use.getOwner();
115 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
116 return false;
117 if (use.getOperandNumber() != 0)
118 return false;
119 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
120 return false;
121 return true;
122 });
123}
124
125/// A tracker for track NLAs affected by a reduction. Performs the necessary
126/// cleanup steps in order to maintain IR validity after the reduction has
127/// applied. For example, removing an instance that forms part of an NLA path
128/// requires that NLA to be removed as well.
130 /// Clear the set of marked NLAs. Call this before attempting a reduction.
131 void clear() { nlasToRemove.clear(); }
132
133 /// Remove all marked annotations. Call this after applying a reduction in
134 /// order to validate the IR.
135 void remove(mlir::ModuleOp module) {
136 unsigned numRemoved = 0;
137 (void)numRemoved;
138 SymbolTableCollection symbolTables;
139 for (Operation &rootOp : *module.getBody()) {
140 if (!isa<firrtl::CircuitOp>(&rootOp))
141 continue;
142 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
143 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
144 for (auto sym : nlasToRemove) {
145 if (auto *op = symbolTable.lookup(sym)) {
146 if (symbolUserMap.useEmpty(op)) {
147 ++numRemoved;
148 op->erase();
149 }
150 }
151 }
152 }
153 LLVM_DEBUG({
154 unsigned numLost = nlasToRemove.size() - numRemoved;
155 if (numRemoved > 0 || numLost > 0) {
156 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
157 if (numLost > 0)
158 llvm::dbgs() << " (" << numLost << " no longer there)";
159 llvm::dbgs() << "\n";
160 }
161 });
162 }
163
164 /// Mark all NLAs referenced in the given annotation as to be removed. This
165 /// can be an entire array or dictionary of annotations, and the function will
166 /// descend into child annotations appropriately.
167 void markNLAsInAnnotation(Attribute anno) {
168 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
169 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
170 nlasToRemove.insert(field.getAttr());
171 for (auto namedAttr : dict)
172 markNLAsInAnnotation(namedAttr.getValue());
173 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
174 for (auto attr : array)
175 markNLAsInAnnotation(attr);
176 }
177 }
178
179 /// Mark all NLAs referenced in an operation. Also traverses all nested
180 /// operations. Call this before removing an operation, to mark any associated
181 /// NLAs as to be removed as well.
182 void markNLAsInOperation(Operation *op) {
183 op->walk([&](Operation *op) {
184 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
185 markNLAsInAnnotation(annos);
186 });
187 }
188
189 /// The set of NLAs to remove, identified by their symbol.
190 llvm::DenseSet<StringAttr> nlasToRemove;
191};
192
193//===----------------------------------------------------------------------===//
194// Reduction patterns
195//===----------------------------------------------------------------------===//
196
197namespace {
198
199/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
200struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
201 void beforeReduction(mlir::ModuleOp op) override {
202 nlaRemover.clear();
203 symbols.clear();
204 moduleSizes.clear();
205 innerSymUses = reduce::InnerSymbolUses(op);
206 }
207 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
208
209 uint64_t match(FModuleOp module) override {
210 if (innerSymUses.hasInnerRef(module))
211 return 0;
212 return moduleSizes.getModuleSize(module, symbols);
213 }
214
215 LogicalResult rewrite(FModuleOp module) override {
216 // Hack up a list of known layers.
217 LayerSet layers;
218 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
219 for (auto attr : module.getPortTypes()) {
220 auto type = cast<TypeAttr>(attr).getValue();
221 if (auto refType = type_dyn_cast<RefType>(type))
222 if (auto layer = refType.getLayer())
223 layers.insert(layer);
224 }
225 SmallVector<Attribute, 4> layersArray;
226 layersArray.reserve(layers.size());
227 for (auto layer : layers)
228 layersArray.push_back(layer);
229
230 nlaRemover.markNLAsInOperation(module);
231 OpBuilder builder(module);
232 auto extmodule = FExtModuleOp::create(
233 builder, module->getLoc(),
234 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
235 module.getConventionAttr(), module.getPorts(),
236 builder.getArrayAttr(layersArray), StringRef(),
237 module.getAnnotationsAttr());
238 SymbolTable::setSymbolVisibility(extmodule,
239 SymbolTable::getSymbolVisibility(module));
240 module->erase();
241 return success();
242 }
243
244 std::string getName() const override { return "firrtl-module-externalizer"; }
245
246 ::detail::SymbolCache symbols;
247 NLARemover nlaRemover;
248 reduce::InnerSymbolUses innerSymUses;
249 ModuleSizeCache moduleSizes;
250};
251
252/// Invalidate all the leaf fields of a value with a given flippedness by
253/// connecting an invalid value to them. This is useful for ensuring that all
254/// output ports of an instance or memory (including those nested in bundles)
255/// are properly invalidated.
256static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
257 SmallDenseMap<Type, Value, 8> &invalidCache,
258 bool flip = false) {
259 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
260 if (!type)
261 return;
262
263 // Descend into bundles by creating subfield ops.
264 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
265 for (auto element : llvm::enumerate(bundleType.getElements())) {
266 auto subfield =
267 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
268 invalidateOutputs(builder, subfield, invalidCache,
269 flip ^ element.value().isFlip);
270 if (subfield.use_empty())
271 subfield.getDefiningOp()->erase();
272 }
273 return;
274 }
275
276 // Descend into vectors by creating subindex ops.
277 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
278 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
279 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
280 invalidateOutputs(builder, subindex, invalidCache, flip);
281 if (subindex.use_empty())
282 subindex.getDefiningOp()->erase();
283 }
284 return;
285 }
286
287 // Only drive outputs.
288 if (flip)
289 return;
290 Value invalid = invalidCache.lookup(type);
291 if (!invalid) {
292 invalid = firrtl::InvalidValueOp::create(builder, type);
293 invalidCache.insert({type, invalid});
294 }
295 firrtl::ConnectOp::create(builder, value, invalid);
296}
297
298/// Connect a value to every leave of a destination value.
299static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
300 Value value) {
301 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
302 if (!type)
303 return;
304 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
305 for (auto element : llvm::enumerate(bundleType.getElements()))
306 connectToLeafs(builder,
307 firrtl::SubfieldOp::create(builder, dest, element.index()),
308 value);
309 return;
310 }
311 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
312 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
313 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
314 value);
315 return;
316 }
317 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
318 if (!valueType)
319 return;
320 auto destWidth = type.getBitWidthOrSentinel();
321 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
322 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
323 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
324 if (!isa<firrtl::UIntType>(type)) {
325 if (isa<firrtl::SIntType>(type))
326 value = firrtl::AsSIntPrimOp::create(builder, value);
327 else
328 return;
329 }
330 firrtl::ConnectOp::create(builder, dest, value);
331}
332
333/// Reduce all leaf fields of a value through an XOR tree.
334static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
335 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
336 if (!type)
337 return;
338 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
339 for (auto element : llvm::enumerate(bundleType.getElements()))
340 reduceXor(
341 builder, into,
342 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
343 return;
344 }
345 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
346 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
347 reduceXor(builder, into,
348 builder.createOrFold<firrtl::SubindexOp>(value, i));
349 return;
350 }
351 if (!isa<firrtl::UIntType>(type)) {
352 if (isa<firrtl::SIntType>(type))
353 value = firrtl::AsUIntPrimOp::create(builder, value);
354 else
355 return;
356 }
357 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
358}
359
360/// A sample reduction pattern that maps `firrtl.instance` to a set of
361/// invalidated wires. This often shortcuts a long iterative process of connect
362/// invalidation, module externalization, and wire stripping
363struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
364 void beforeReduction(mlir::ModuleOp op) override {
365 erasedInsts.clear();
366 erasedModules.clear();
367 symbols.clear();
368 nlaRemover.clear();
369 moduleSizes.clear();
370 }
371 void afterReduction(mlir::ModuleOp op) override {
372 // Look into deleted modules to find additional instances that are no longer
373 // instantiated anywhere.
374 SmallVector<Operation *> worklist;
375 auto deadInsts = erasedInsts;
376 for (auto *op : erasedModules)
377 worklist.push_back(op);
378 while (!worklist.empty()) {
379 auto *op = worklist.pop_back_val();
380 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
381 op->walk([&](firrtl::InstanceOp instOp) {
382 auto moduleOp = cast<firrtl::FModuleLike>(
383 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
384 deadInsts.insert(instOp);
385 if (llvm::all_of(
386 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
387 [&](Operation *user) { return deadInsts.contains(user); })) {
388 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
389 << moduleOp.getModuleName() << "`\n");
390 erasedModules.insert(moduleOp);
391 worklist.push_back(moduleOp);
392 }
393 });
394 }
395
396 for (auto *op : erasedInsts)
397 op->erase();
398 for (auto *op : erasedModules)
399 op->erase();
400 nlaRemover.remove(op);
401 }
402
403 uint64_t match(firrtl::InstanceOp instOp) override {
404 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
405 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
406 return 0;
407 }
408
409 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
410 LLVM_DEBUG(llvm::dbgs()
411 << "Stubbing instance `" << instOp.getName() << "`\n");
412 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
414 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
415 auto result = instOp.getResult(i);
416 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
417 instOp.getPortNameStr(i));
418 auto wire =
419 firrtl::WireOp::create(builder, result.getType(), name,
420 firrtl::NameKindEnum::DroppableName,
421 instOp.getPortAnnotation(i), StringAttr{})
422 .getResult();
423 invalidateOutputs(builder, wire, invalidCache,
424 instOp.getPortDirection(i) == firrtl::Direction::In);
425 result.replaceAllUsesWith(wire);
426 }
427 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
428 auto moduleOp = cast<firrtl::FModuleLike>(
429 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
430 nlaRemover.markNLAsInOperation(instOp);
431 erasedInsts.insert(instOp);
432 if (llvm::all_of(
433 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
434 [&](Operation *user) { return erasedInsts.contains(user); })) {
435 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
436 << moduleOp.getModuleName() << "`\n");
437 erasedModules.insert(moduleOp);
438 }
439 return success();
440 }
441
442 std::string getName() const override { return "instance-stubber"; }
443 bool acceptSizeIncrease() const override { return true; }
444
445 ::detail::SymbolCache symbols;
446 NLARemover nlaRemover;
447 llvm::DenseSet<Operation *> erasedInsts;
448 llvm::DenseSet<Operation *> erasedModules;
449 ModuleSizeCache moduleSizes;
450};
451
452/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
453/// wires.
454struct MemoryStubber : public OpReduction<firrtl::MemOp> {
455 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
456 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
457 LogicalResult rewrite(firrtl::MemOp memOp) override {
458 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
459 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
461 Value xorInputs;
462 SmallVector<Value> outputs;
463 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
464 auto result = memOp.getResult(i);
465 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
466 memOp.getPortNameStr(i));
467 auto wire =
468 firrtl::WireOp::create(builder, result.getType(), name,
469 firrtl::NameKindEnum::DroppableName,
470 memOp.getPortAnnotation(i), StringAttr{})
471 .getResult();
472 invalidateOutputs(builder, wire, invalidCache, true);
473 result.replaceAllUsesWith(wire);
474
475 // Isolate the input and output data fields of the port.
476 Value input, output;
477 switch (memOp.getPortKind(i)) {
478 case firrtl::MemOp::PortKind::Read:
479 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
480 break;
481 case firrtl::MemOp::PortKind::Write:
482 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
483 break;
484 case firrtl::MemOp::PortKind::ReadWrite:
485 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
486 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
487 break;
488 case firrtl::MemOp::PortKind::Debug:
489 output = wire;
490 break;
491 }
492
493 if (!isa<firrtl::RefType>(result.getType())) {
494 // Reduce all input ports to a single one through an XOR tree.
495 unsigned numFields =
496 cast<firrtl::BundleType>(wire.getType()).getNumElements();
497 for (unsigned i = 0; i != numFields; ++i) {
498 if (i != 2 && i != 3 && i != 5)
499 reduceXor(builder, xorInputs,
500 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
501 }
502 if (input)
503 reduceXor(builder, xorInputs, input);
504 }
505
506 // Track the output port to hook it up to the XORd input later.
507 if (output)
508 outputs.push_back(output);
509 }
510
511 // Hook up the outputs.
512 for (auto output : outputs)
513 connectToLeafs(builder, output, xorInputs);
514
515 nlaRemover.markNLAsInOperation(memOp);
516 memOp->erase();
517 return success();
518 }
519 std::string getName() const override { return "memory-stubber"; }
520 bool acceptSizeIncrease() const override { return true; }
521 NLARemover nlaRemover;
522};
523
524/// Check whether an operation interacts with flows in any way, which can make
525/// replacement and operand forwarding harder in some cases.
526static bool isFlowSensitiveOp(Operation *op) {
527 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
528 SubaccessOp, ObjectSubfieldOp>(op);
529}
530
531/// A sample reduction pattern that replaces all uses of an operation with one
532/// of its operands. This can help pruning large parts of the expression tree
533/// rapidly.
534template <unsigned OpNum>
535struct FIRRTLOperandForwarder : public Reduction {
536 uint64_t match(Operation *op) override {
537 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
538 return 0;
539 if (isFlowSensitiveOp(op))
540 return 0;
541 auto resultTy =
542 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
543 auto opTy =
544 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
545 return resultTy && opTy &&
546 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
547 (resultTy.getBitWidthOrSentinel() == -1) ==
548 (opTy.getBitWidthOrSentinel() == -1) &&
549 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
550 }
551 LogicalResult rewrite(Operation *op) override {
552 assert(match(op));
553 ImplicitLocOpBuilder builder(op->getLoc(), op);
554 auto result = op->getResult(0);
555 auto operand = op->getOperand(OpNum);
556 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
557 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
558 auto resultWidth = resultTy.getBitWidthOrSentinel();
559 auto operandWidth = operandTy.getBitWidthOrSentinel();
560 Value newOp;
561 if (resultWidth < operandWidth)
562 newOp =
563 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
564 else if (resultWidth > operandWidth)
565 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
566 else
567 newOp = operand;
568 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
569 result.replaceAllUsesWith(newOp);
570 reduce::pruneUnusedOps(op, *this);
571 return success();
572 }
573 std::string getName() const override {
574 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
575 }
576};
577
578/// A sample reduction pattern that replaces FIRRTL operations with a constant
579/// zero of their type.
580struct Constantifier : public Reduction {
581 void beforeReduction(mlir::ModuleOp op) override {
582 symbols.clear();
583
584 // Find valid dummy classes that we can use for anyref casts.
585 anyrefCastDummy.clear();
586 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
587 for (auto classOp : circuitOp.getOps<ClassOp>()) {
588 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
589 anyrefCastDummy.insert({circuitOp, classOp});
590 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
591 }
592 }
593 return WalkResult::skip();
594 });
595 }
596
597 uint64_t match(Operation *op) override {
598 if (op->hasTrait<OpTrait::ConstantLike>()) {
599 Attribute attr;
600 if (!matchPattern(op, m_Constant(&attr)))
601 return 0;
602 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
603 if (intAttr.getValue().isZero())
604 return 0;
605 if (auto strAttr = dyn_cast<StringAttr>(attr))
606 if (strAttr.empty())
607 return 0;
608 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
609 if (floatAttr.getValue().isZero())
610 return 0;
611 }
612 if (auto listOp = dyn_cast<ListCreateOp>(op))
613 if (listOp.getElements().empty())
614 return 0;
615 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
616 if (pathOp.getTarget().empty())
617 return 0;
618
619 // Don't replace anyref casts that already target a dummy class.
620 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
621 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
622 auto className =
623 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
624 if (anyrefCastDummyNames[circuitOp].contains(className))
625 return 0;
626 }
627
628 if (op->getNumResults() != 1)
629 return 0;
630 if (op->hasAttr("inner_sym"))
631 return 0;
632 if (isFlowSensitiveOp(op))
633 return 0;
634 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
635 DoubleType, ListType, PathType, AnyRefType>(
636 op->getResult(0).getType());
637 }
638
639 LogicalResult rewrite(Operation *op) override {
640 OpBuilder builder(op);
641 auto type = op->getResult(0).getType();
642
643 // Handle UInt/SInt types.
644 if (isa<UIntType, SIntType>(type)) {
645 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
646 if (width == -1)
647 width = 64;
648 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
649 APSInt(width, isa<UIntType>(type)));
650 op->replaceAllUsesWith(newOp);
651 reduce::pruneUnusedOps(op, *this);
652 return success();
653 }
654
655 // Handle property string types.
656 if (isa<StringType>(type)) {
657 auto attr = builder.getStringAttr("");
658 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
659 op->replaceAllUsesWith(newOp);
660 reduce::pruneUnusedOps(op, *this);
661 return success();
662 }
663
664 // Handle property integer types.
665 if (isa<FIntegerType>(type)) {
666 auto attr = builder.getIntegerAttr(builder.getI64Type(), 0);
667 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
668 op->replaceAllUsesWith(newOp);
669 reduce::pruneUnusedOps(op, *this);
670 return success();
671 }
672
673 // Handle property boolean types.
674 if (isa<BoolType>(type)) {
675 auto attr = builder.getBoolAttr(false);
676 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
677 op->replaceAllUsesWith(newOp);
678 reduce::pruneUnusedOps(op, *this);
679 return success();
680 }
681
682 // Handle property double types.
683 if (isa<DoubleType>(type)) {
684 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
685 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
686 op->replaceAllUsesWith(newOp);
687 reduce::pruneUnusedOps(op, *this);
688 return success();
689 }
690
691 // Handle property list types.
692 if (isa<ListType>(type)) {
693 auto newOp =
694 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
695 op->replaceAllUsesWith(newOp);
696 reduce::pruneUnusedOps(op, *this);
697 return success();
698 }
699
700 // Handle property path types.
701 if (isa<PathType>(type)) {
702 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
703 op->replaceAllUsesWith(newOp);
704 reduce::pruneUnusedOps(op, *this);
705 return success();
706 }
707
708 // Handle anyref types.
709 if (isa<AnyRefType>(type)) {
710 auto circuitOp = op->getParentOfType<CircuitOp>();
711 auto &dummy = anyrefCastDummy[circuitOp];
712 if (!dummy) {
713 OpBuilder::InsertionGuard guard(builder);
714 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
715 auto &symbolTable = symbols.getNearestSymbolTable(op);
716 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
717 symbolTable.insert(dummy);
718 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
719 }
720 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
721 auto anyrefOp =
722 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
723 op->replaceAllUsesWith(anyrefOp);
724 reduce::pruneUnusedOps(op, *this);
725 return success();
726 }
727
728 return failure();
729 }
730
731 std::string getName() const override { return "firrtl-constantifier"; }
732 bool acceptSizeIncrease() const override { return true; }
733
734 ::detail::SymbolCache symbols;
736 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
737};
738
739/// A sample reduction pattern that replaces the right-hand-side of
740/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
741/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
742/// connects and creates opportunities for reduction in DCE/CSE.
743struct ConnectInvalidator : public Reduction {
744 uint64_t match(Operation *op) override {
745 if (!isa<FConnectLike>(op))
746 return 0;
747 if (auto *srcOp = op->getOperand(1).getDefiningOp())
748 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
749 isa<InvalidValueOp>(srcOp))
750 return 0;
751 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
752 return type && type.isPassive();
753 }
754 LogicalResult rewrite(Operation *op) override {
755 assert(match(op));
756 auto rhs = op->getOperand(1);
757 OpBuilder builder(op);
758 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
759 auto *rhsOp = rhs.getDefiningOp();
760 op->setOperand(1, invOp);
761 if (rhsOp)
762 reduce::pruneUnusedOps(rhsOp, *this);
763 return success();
764 }
765 std::string getName() const override { return "connect-invalidator"; }
766 bool acceptSizeIncrease() const override { return true; }
767};
768
769/// A reduction pattern that removes FIRRTL annotations from ports and
770/// operations. This generates one match per annotation and port annotation,
771/// allowing selective removal of individual annotations.
772struct AnnotationRemover : public Reduction {
773 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
774 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
775
776 void matches(Operation *op,
777 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
778 uint64_t matchId = 0;
779
780 // Generate matches for regular annotations
781 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
782 for (unsigned i = 0; i < annos.size(); ++i)
783 addMatch(1, matchId++);
784
785 // Generate matches for port annotations
786 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
787 for (auto portAnnoArray : portAnnos)
788 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
789 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
790 addMatch(1, matchId++);
791 }
792
793 LogicalResult rewriteMatches(Operation *op,
794 ArrayRef<uint64_t> matches) override {
795 // Convert matches to a set for fast lookup
796 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
797
798 // Lambda to process annotations and filter out matched ones
799 uint64_t matchId = 0;
800 auto processAnnotations =
801 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
802 SmallVector<Attribute> newAnnotations;
803 for (auto anno : annotations) {
804 if (!matchesSet.contains(matchId)) {
805 newAnnotations.push_back(anno);
806 } else {
807 // Mark NLAs in the removed annotation for cleanup
808 nlaRemover.markNLAsInAnnotation(anno);
809 }
810 matchId++;
811 }
812 return ArrayAttr::get(op->getContext(), newAnnotations);
813 };
814
815 // Remove regular annotations
816 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
817 op->setAttr("annotations", processAnnotations(annos.getValue()));
818 }
819
820 // Remove port annotations
821 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
822 SmallVector<Attribute> newPortAnnos;
823 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
824 newPortAnnos.push_back(
825 processAnnotations(portAnnoArrayAttr.getValue()));
826 }
827 op->setAttr("portAnnotations",
828 ArrayAttr::get(op->getContext(), newPortAnnos));
829 }
830
831 return success();
832 }
833
834 std::string getName() const override { return "annotation-remover"; }
835 NLARemover nlaRemover;
836};
837
838/// A sample reduction pattern that removes ports from the root `firrtl.module`
839/// if the port is not used or just invalidated.
840struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
841 uint64_t match(firrtl::FModuleOp module) override {
842 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
843 if (!circuit)
844 return 0;
845 return circuit.getNameAttr() == module.getNameAttr();
846 }
847 LogicalResult rewrite(firrtl::FModuleOp module) override {
848 assert(match(module));
849 size_t numPorts = module.getNumPorts();
850 llvm::BitVector dropPorts(numPorts);
851 for (unsigned i = 0; i != numPorts; ++i) {
852 if (onlyInvalidated(module.getArgument(i))) {
853 dropPorts.set(i);
854 for (auto *user :
855 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
856 user->erase();
857 }
858 }
859 module.erasePorts(dropPorts);
860 return success();
861 }
862 std::string getName() const override { return "root-port-pruner"; }
863};
864
865/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
866/// with wires.
867struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
868 void beforeReduction(mlir::ModuleOp op) override {
869 symbols.clear();
870 nlaRemover.clear();
871 }
872 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
873
874 uint64_t match(firrtl::InstanceOp instOp) override {
875 return isa<firrtl::FExtModuleOp>(
876 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
877 }
878 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
879 auto portInfo =
880 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
881 symbols.getNearestSymbolTable(instOp)))
882 .getPorts();
883 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
884 SmallVector<Value> replacementWires;
885 for (firrtl::PortInfo info : portInfo) {
886 auto wire = firrtl::WireOp::create(
887 builder, info.type,
888 (Twine(instOp.getName()) + "_" + info.getName()).str())
889 .getResult();
890 if (info.isOutput()) {
891 auto inv = firrtl::InvalidValueOp::create(builder, info.type);
892 firrtl::ConnectOp::create(builder, wire, inv);
893 }
894 replacementWires.push_back(wire);
895 }
896 nlaRemover.markNLAsInOperation(instOp);
897 instOp.replaceAllUsesWith(std::move(replacementWires));
898 instOp->erase();
899 return success();
900 }
901 std::string getName() const override { return "extmodule-instance-remover"; }
902 bool acceptSizeIncrease() const override { return true; }
903
904 ::detail::SymbolCache symbols;
905 NLARemover nlaRemover;
906};
907
908/// A sample reduction pattern that pushes connected values through wires.
909struct ConnectForwarder : public Reduction {
910 uint64_t match(Operation *op) override {
911 if (!isa<firrtl::FConnectLike>(op))
912 return 0;
913 auto dest = op->getOperand(0);
914 auto src = op->getOperand(1);
915 auto *destOp = dest.getDefiningOp();
916 auto *srcOp = src.getDefiningOp();
917 if (dest == src)
918 return 0;
919
920 // Ensure that the destination is something we should be able to forward
921 // through.
922 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
923 destOp))
924 return 0;
925
926 // Ensure that the destination is connected to only once, and all uses of
927 // the connection occur after the definition of the source.
928 unsigned numConnects = 0;
929 for (auto &use : dest.getUses()) {
930 auto *op = use.getOwner();
931 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
932 if (++numConnects > 1)
933 return 0;
934 continue;
935 }
936 if (srcOp && !srcOp->isBeforeInBlock(op))
937 return 0;
938 }
939
940 return 1;
941 }
942
943 LogicalResult rewrite(Operation *op) override {
944 auto dst = op->getOperand(0);
945 auto src = op->getOperand(1);
946 dst.replaceAllUsesWith(src);
947 op->erase();
948 if (auto *dstOp = dst.getDefiningOp())
949 reduce::pruneUnusedOps(dstOp, *this);
950 if (auto *srcOp = src.getDefiningOp())
951 reduce::pruneUnusedOps(srcOp, *this);
952 return success();
953 }
954
955 std::string getName() const override { return "connect-forwarder"; }
956};
957
958/// A sample reduction pattern that replaces a single-use wire and register with
959/// an operand of the source value of the connection.
960template <unsigned OpNum>
961struct ConnectSourceOperandForwarder : public Reduction {
962 uint64_t match(Operation *op) override {
963 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
964 return 0;
965 auto dest = op->getOperand(0);
966 auto *destOp = dest.getDefiningOp();
967
968 // Ensure that the destination is used only once.
969 if (!destOp || !destOp->hasOneUse() ||
970 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
971 return 0;
972
973 auto *srcOp = op->getOperand(1).getDefiningOp();
974 if (!srcOp || OpNum >= srcOp->getNumOperands())
975 return 0;
976
977 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
978 auto opTy =
979 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
980
981 return resultTy && opTy &&
982 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
983 ((resultTy.getBitWidthOrSentinel() == -1) ==
984 (opTy.getBitWidthOrSentinel() == -1)) &&
985 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
986 }
987
988 LogicalResult rewrite(Operation *op) override {
989 auto *destOp = op->getOperand(0).getDefiningOp();
990 auto *srcOp = op->getOperand(1).getDefiningOp();
991 auto forwardedOperand = srcOp->getOperand(OpNum);
992 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
993 Value newDest;
994 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
995 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
996 wire.getName())
997 .getResult();
998 else {
999 auto regName = destOp->getAttrOfType<StringAttr>("name");
1000 // We can promote the register into a wire but we wouldn't do here because
1001 // the error might be caused by the register.
1002 auto clock = destOp->getOperand(0);
1003 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1004 clock, regName ? regName.str() : "")
1005 .getResult();
1006 }
1007
1008 // Create new connection between a new wire and the forwarded operand.
1009 builder.setInsertionPointAfter(op);
1010 if (isa<firrtl::ConnectOp>(op))
1011 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1012 else
1013 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1014
1015 // Remove the old connection and destination. We don't have to replace them
1016 // because destination has only one use.
1017 op->erase();
1018 destOp->erase();
1019 reduce::pruneUnusedOps(srcOp, *this);
1020
1021 return success();
1022 }
1023 std::string getName() const override {
1024 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1025 }
1026};
1027
1028/// A sample reduction pattern that tries to remove aggregate wires by replacing
1029/// all subaccesses with new independent wires. This can disentangle large
1030/// unused wires that are otherwise difficult to collect due to the subaccesses.
1031struct DetachSubaccesses : public Reduction {
1032 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1033 void afterReduction(mlir::ModuleOp op) override {
1034 for (auto *op : opsToErase)
1035 op->dropAllReferences();
1036 for (auto *op : opsToErase)
1037 op->erase();
1038 }
1039 uint64_t match(Operation *op) override {
1040 // Only applies to wires and registers that are purely used in subaccess
1041 // operations.
1042 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1043 llvm::all_of(op->getUses(), [](auto &use) {
1044 return use.getOperandNumber() == 0 &&
1045 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1046 firrtl::SubaccessOp>(use.getOwner());
1047 });
1048 }
1049 LogicalResult rewrite(Operation *op) override {
1050 assert(match(op));
1051 OpBuilder builder(op);
1052 bool isWire = isa<firrtl::WireOp>(op);
1053 Value invalidClock;
1054 if (!isWire)
1055 invalidClock = firrtl::InvalidValueOp::create(
1056 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1057 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1058 builder.setInsertionPoint(user);
1059 auto type = user->getResult(0).getType();
1060 Operation *replOp;
1061 if (isWire)
1062 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1063 else
1064 replOp =
1065 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1066 user->replaceAllUsesWith(replOp);
1067 opsToErase.insert(user);
1068 }
1069 opsToErase.insert(op);
1070 return success();
1071 }
1072 std::string getName() const override { return "detach-subaccesses"; }
1073 llvm::DenseSet<Operation *> opsToErase;
1074};
1075
1076/// This reduction removes inner symbols on ops. Name preservation creates a lot
1077/// of node ops with symbols to keep name information but it also prevents
1078/// normal canonicalizations.
1079struct NodeSymbolRemover : public Reduction {
1080 void beforeReduction(mlir::ModuleOp op) override {
1081 innerSymUses = reduce::InnerSymbolUses(op);
1082 }
1083
1084 uint64_t match(Operation *op) override {
1085 // Only match ops with an inner symbol.
1086 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1087 if (!sym || sym.empty())
1088 return 0;
1089
1090 // Only match ops that have no references to their inner symbol.
1091 if (innerSymUses.hasInnerRef(op))
1092 return 0;
1093 return 1;
1094 }
1095
1096 LogicalResult rewrite(Operation *op) override {
1097 op->removeAttr("inner_sym");
1098 return success();
1099 }
1100
1101 std::string getName() const override { return "node-symbol-remover"; }
1102 bool acceptSizeIncrease() const override { return true; }
1103
1104 reduce::InnerSymbolUses innerSymUses;
1105};
1106
1107/// Check if inlining the referenced operation into the parent operation would
1108/// cause inner symbol collisions.
1109static bool
1110hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1111 hw::InnerSymbolTableCollection &innerSymTables) {
1112 // Get the inner symbol tables for both operations
1113 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1114 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1115
1116 // Check if any inner symbol name in the target operation already exists
1117 // in the parent operation. Return failure() if a collision is found to stop
1118 // the walk early.
1119 LogicalResult walkResult = targetTable.walkSymbols(
1120 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1121 // Check if this symbol name exists in the parent operation
1122 if (parentTable.lookup(name)) {
1123 // Collision found, return failure to stop the walk
1124 return failure();
1125 }
1126 return success();
1127 });
1128
1129 // If the walk failed, it means we found a collision
1130 return failed(walkResult);
1131}
1132
1133/// A sample reduction pattern that eagerly inlines instances.
1134struct EagerInliner : public OpReduction<InstanceOp> {
1135 void beforeReduction(mlir::ModuleOp op) override {
1136 symbols.clear();
1137 nlaRemover.clear();
1138 nlaTables.clear();
1139 for (auto circuitOp : op.getOps<CircuitOp>())
1140 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1141 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1142 }
1143 void afterReduction(mlir::ModuleOp op) override {
1144 nlaRemover.remove(op);
1145 nlaTables.clear();
1146 innerSymTables.reset();
1147 }
1148
1149 uint64_t match(InstanceOp instOp) override {
1150 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1151 auto *moduleOp =
1152 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1153
1154 // Only inline FModuleOp instances
1155 if (!isa<FModuleOp>(moduleOp))
1156 return 0;
1157
1158 // Skip instances that participate in any NLAs
1159 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1160 if (!circuitOp)
1161 return 0;
1162 auto it = nlaTables.find(circuitOp);
1163 if (it == nlaTables.end() || !it->second)
1164 return 0;
1165 DenseSet<hw::HierPathOp> nlas;
1166 it->second->getInstanceNLAs(instOp, nlas);
1167 if (!nlas.empty())
1168 return 0;
1169
1170 // Check for inner symbol collisions between the referenced module and the
1171 // instance's parent module
1172 auto parentOp = instOp->getParentOfType<FModuleLike>();
1173 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1174 return 0;
1175
1176 return 1;
1177 }
1178
1179 LogicalResult rewrite(InstanceOp instOp) override {
1180 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1181 auto moduleOp = cast<FModuleOp>(
1182 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1183 bool isLastUse =
1184 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1185 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1186
1187 // Create wires to replace the instance results.
1188 IRRewriter rewriter(instOp);
1189 SmallVector<Value> argWires;
1190 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1191 auto result = instOp.getResult(i);
1192 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1193 instOp.getPortNameStr(i));
1194 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1195 name, NameKindEnum::DroppableName,
1196 instOp.getPortAnnotation(i), StringAttr{})
1197 .getResult();
1198 result.replaceAllUsesWith(wire);
1199 argWires.push_back(wire);
1200 }
1201
1202 // Splice in the cloned module body.
1203 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1204
1205 // Make sure we remove any NLAs that go through this instance, and the
1206 // module if we're about the delete the module.
1207 nlaRemover.markNLAsInOperation(instOp);
1208 if (isLastUse)
1209 nlaRemover.markNLAsInOperation(moduleOp);
1210
1211 instOp.erase();
1212 clonedModuleOp.erase();
1213 return success();
1214 }
1215
1216 std::string getName() const override { return "firrtl-eager-inliner"; }
1217 bool acceptSizeIncrease() const override { return true; }
1218
1219 ::detail::SymbolCache symbols;
1220 NLARemover nlaRemover;
1221 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1222 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1223};
1224
1225/// A reduction pattern that eagerly inlines `ObjectOp`s.
1226struct ObjectInliner : public OpReduction<ObjectOp> {
1227 void beforeReduction(mlir::ModuleOp op) override {
1228 blocksToSort.clear();
1229 symbols.clear();
1230 nlaRemover.clear();
1231 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1232 }
1233 void afterReduction(mlir::ModuleOp op) override {
1234 for (auto *block : blocksToSort)
1235 mlir::sortTopologically(block);
1236 blocksToSort.clear();
1237 nlaRemover.remove(op);
1238 innerSymTables.reset();
1239 }
1240
1241 uint64_t match(ObjectOp objOp) override {
1242 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1243 auto *classOp =
1244 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1245
1246 // Only inline `ClassOp`s.
1247 if (!isa<ClassOp>(classOp))
1248 return 0;
1249
1250 // Check for inner symbol collisions between the referenced class and the
1251 // object's parent module.
1252 auto parentOp = objOp->getParentOfType<FModuleLike>();
1253 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1254 return 0;
1255
1256 // Verify all uses are ObjectSubfieldOp.
1257 for (auto *user : objOp.getResult().getUsers())
1258 if (!isa<ObjectSubfieldOp>(user))
1259 return 0;
1260
1261 return 1;
1262 }
1263
1264 LogicalResult rewrite(ObjectOp objOp) override {
1265 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1266 auto classOp = cast<ClassOp>(
1267 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1268 auto clonedClassOp = classOp.clone();
1269
1270 // Create wires to replace the ObjectSubfieldOp results.
1271 IRRewriter rewriter(objOp);
1272 SmallVector<Value> portWires;
1273 auto classType = objOp.getType();
1274
1275 // Create a wire for each port in the class
1276 for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1277 auto element = classType.getElement(i);
1278 auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
1279 element.name.getValue());
1280 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1281 NameKindEnum::DroppableName,
1282 rewriter.getArrayAttr({}), StringAttr{})
1283 .getResult();
1284 portWires.push_back(wire);
1285 }
1286
1287 // Replace all ObjectSubfieldOp uses with corresponding wires
1288 SmallVector<ObjectSubfieldOp> subfieldOps;
1289 for (auto *user : objOp.getResult().getUsers()) {
1290 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1291 subfieldOps.push_back(subfieldOp);
1292 auto index = subfieldOp.getIndex();
1293 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1294 }
1295
1296 // Splice in the cloned class body.
1297 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1298
1299 // After inlining the class body, we need to eliminate `WireOps` since
1300 // `ClassOps` cannot contain wires. For each port wire, find its single
1301 // connect, remove it, and replace all uses of the wire with the assigned
1302 // value.
1303 SmallVector<FConnectLike> connectsToErase;
1304 for (auto portWire : portWires) {
1305 // Find a single value to replace the wire with, and collect all connects
1306 // to the wire such that we can erase them later.
1307 Value value;
1308 for (auto *user : portWire.getUsers()) {
1309 if (auto connect = dyn_cast<FConnectLike>(user)) {
1310 if (connect.getDest() == portWire) {
1311 value = connect.getSrc();
1312 connectsToErase.push_back(connect);
1313 }
1314 }
1315 }
1316
1317 // Be very conservative about deleting these wires. Other reductions may
1318 // leave class ports unconnected, which means that there isn't always a
1319 // clean replacement available here. Better to just leave the wires in the
1320 // IR and let the verifier fail later.
1321 if (value)
1322 portWire.replaceAllUsesWith(value);
1323 for (auto connect : connectsToErase)
1324 connect.erase();
1325 if (portWire.use_empty())
1326 portWire.getDefiningOp()->erase();
1327 connectsToErase.clear();
1328 }
1329
1330 // Make sure we remove any NLAs that go through this object.
1331 nlaRemover.markNLAsInOperation(objOp);
1332
1333 // Since the above forwarding of SSA values through wires can create
1334 // dominance issues, mark the region containing the object to be sorted
1335 // topologically.
1336 blocksToSort.insert(objOp->getBlock());
1337
1338 // Erase the object and cloned class.
1339 for (auto subfieldOp : subfieldOps)
1340 subfieldOp.erase();
1341 objOp.erase();
1342 clonedClassOp.erase();
1343 return success();
1344 }
1345
1346 std::string getName() const override { return "firrtl-object-inliner"; }
1347 bool acceptSizeIncrease() const override { return true; }
1348
1349 SetVector<Block *> blocksToSort;
1350 ::detail::SymbolCache symbols;
1351 NLARemover nlaRemover;
1352 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1353};
1354
1355/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1356/// not an actual reduction, but often removes extraneous information that has
1357/// no bearing on the actual reduction (and would likely be removed before
1358/// sharing the reduction). This makes the following changes:
1359///
1360/// - All wires are renamed to "wire"
1361/// - All registers are renamed to "reg"
1362/// - All nodes are renamed to "node"
1363/// - All memories are renamed to "mem"
1364/// - All verification messages and labels are dropped
1365///
1366struct ModuleInternalNameSanitizer : public Reduction {
1367 uint64_t match(Operation *op) override {
1368 // Only match operations with names.
1369 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1370 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1371 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1372 firrtl::CoverOp>(op);
1373 }
1374 LogicalResult rewrite(Operation *op) override {
1375 TypeSwitch<Operation *, void>(op)
1376 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1377 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1378 [](auto op) { op.setName("reg"); })
1379 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1380 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1381 [](auto op) { op.setName("mem"); })
1382 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1383 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1384 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1385 });
1386 return success();
1387 }
1388
1389 std::string getName() const override {
1390 return "module-internal-name-sanitizer";
1391 }
1392
1393 bool acceptSizeIncrease() const override { return true; }
1394
1395 bool isOneShot() const override { return true; }
1396};
1397
1398/// Psuedo-reduction that sanitizes module, instance, and port names. This
1399/// makes the following changes:
1400///
1401/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1402/// - All instances are renamed to match the new module name
1403/// - All module ports are renamed in the following way:
1404/// - All clocks are reanemd to "clk"
1405/// - All resets are renamed to "rst"
1406/// - All references are renamed to "ref"
1407/// - Anything else is renamed to "port"
1408///
1409struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1410
1411 const char *names[48] = {
1412 "Foo", "Bar", "Baz", "Qux", "Quux", "Quuux", "Quuuux",
1413 "Quz", "Corge", "Grault", "Bazola", "Ztesch", "Thud", "Grunt",
1414 "Bletch", "Fum", "Fred", "Jim", "Sheila", "Barney", "Flarp",
1415 "Zxc", "Spqr", "Wombat", "Shme", "Bongo", "Spam", "Eggs",
1416 "Snork", "Zot", "Blarg", "Wibble", "Toto", "Titi", "Tata",
1417 "Tutu", "Pippo", "Pluto", "Paperino", "Aap", "Noot", "Mies",
1418 "Oogle", "Foogle", "Boogle", "Zork", "Gork", "Bork"};
1419
1420 size_t nameIndex = 0;
1421
1422 const char *getName() {
1423 if (nameIndex >= 48)
1424 nameIndex = 0;
1425 return names[nameIndex++];
1426 };
1427
1428 size_t portNameIndex = 0;
1429
1430 char getPortName() {
1431 if (portNameIndex >= 26)
1432 portNameIndex = 0;
1433 return 'a' + portNameIndex++;
1434 }
1435
1436 void beforeReduction(mlir::ModuleOp op) override { nameIndex = 0; }
1437
1438 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1439
1440 firrtl::InstanceGraph iGraph(circuitOp);
1441
1442 auto *circuitName = getName();
1443 iGraph.getTopLevelModule().setName(circuitName);
1444 circuitOp.setName(circuitName);
1445
1446 for (auto *node : iGraph) {
1447 auto module = node->getModule<firrtl::FModuleLike>();
1448
1449 bool shouldReplacePorts = false;
1450 SmallVector<Attribute> newNames;
1451 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1452 portNameIndex = 0;
1453 // TODO: The namespace should be unnecessary. However, some FIRRTL
1454 // passes expect that port names are unique.
1456 auto oldPorts = fmodule.getPorts();
1457 shouldReplacePorts = !oldPorts.empty();
1458 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1459 auto port = oldPorts[i];
1460 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1461 .Case<firrtl::ClockType>(
1462 [&](auto a) { return ns.newName("clk"); })
1463 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1464 [&](auto a) { return ns.newName("rst"); })
1465 .Case<firrtl::RefType>(
1466 [&](auto a) { return ns.newName("ref"); })
1467 .Default([&](auto a) {
1468 return ns.newName(Twine(getPortName()));
1469 });
1470 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1471 }
1472 fmodule->setAttr("portNames",
1473 ArrayAttr::get(fmodule.getContext(), newNames));
1474 }
1475
1476 if (module == iGraph.getTopLevelModule())
1477 continue;
1478 auto newName = StringAttr::get(circuitOp.getContext(), getName());
1479 module.setName(newName);
1480 for (auto *use : node->uses()) {
1481 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1482 instanceOp.setModuleName(newName);
1483 instanceOp.setName(newName);
1484 if (shouldReplacePorts)
1485 instanceOp.setPortNamesAttr(
1486 ArrayAttr::get(circuitOp.getContext(), newNames));
1487 }
1488 }
1489
1490 circuitOp->dump();
1491
1492 return success();
1493 }
1494
1495 std::string getName() const override { return "module-name-sanitizer"; }
1496
1497 bool acceptSizeIncrease() const override { return true; }
1498
1499 bool isOneShot() const override { return true; }
1500};
1501
1502/// A reduction pattern that groups modules by their port signature (types and
1503/// directions) and replaces instances with the smallest module in each group.
1504/// This helps reduce the IR by consolidating functionally equivalent modules
1505/// based on their interface.
1506///
1507/// The pattern works by:
1508/// 1. Grouping all modules by their port signature (port types and directions)
1509/// 2. For each group with multiple modules, finding the smallest module using
1510/// the module size cache
1511/// 3. Replacing all instances of larger modules with instances of the smallest
1512/// module in the same group
1513/// 4. Removing the larger modules from the circuit
1514///
1515/// This reduction is useful for reducing circuits where multiple modules have
1516/// the same interface but different implementations, allowing the reducer to
1517/// try the smallest implementation first.
1518struct ModuleSwapper : public OpReduction<InstanceOp> {
1519 // Per-circuit state containing all the information needed for module swapping
1520 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1521 struct CircuitState {
1522 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1523 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1524 std::unique_ptr<NLATable> nlaTable;
1525 };
1526
1527 void beforeReduction(mlir::ModuleOp op) override {
1528 symbols.clear();
1529 nlaRemover.clear();
1530 moduleSizes.clear();
1531 circuitStates.clear();
1532
1533 // Collect module type groups and NLA tables for all circuits up front
1534 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1535 auto &state = circuitStates[circuitOp];
1536 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1537 buildModuleTypeGroups(circuitOp, state);
1538 return WalkResult::skip();
1539 });
1540 }
1541 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1542
1543 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1544 /// This ignores port names, allowing modules with the same port types and
1545 /// directions but different port names to be considered equivalent for
1546 /// swapping.
1547 PortSignature getModulePortSignature(FModuleLike module) {
1548 PortSignature signature;
1549 signature.reserve(module.getNumPorts());
1550 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1551 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1552 return signature;
1553 }
1554
1555 /// Group modules by their port signature and find the smallest in each group.
1556 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1557 // Group modules by their port signature
1558 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1559 auto signature = getModulePortSignature(module);
1560 state.moduleTypeGroups[signature].push_back(module);
1561 }
1562
1563 // For each group, find the smallest module
1564 for (auto &[signature, modules] : state.moduleTypeGroups) {
1565 if (modules.size() <= 1)
1566 continue;
1567
1568 FModuleLike smallestModule = nullptr;
1569 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1570
1571 for (auto module : modules) {
1572 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1573 if (size < smallestSize) {
1574 smallestSize = size;
1575 smallestModule = module;
1576 }
1577 }
1578
1579 // Map all modules in this group to the smallest one
1580 for (auto module : modules) {
1581 if (module != smallestModule) {
1582 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1583 smallestModule;
1584 }
1585 }
1586 }
1587 }
1588
1589 uint64_t match(InstanceOp instOp) override {
1590 // Get the circuit this instance belongs to
1591 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1592 assert(circuitOp);
1593 const auto &state = circuitStates.at(circuitOp);
1594
1595 // Skip instances that participate in any NLAs
1596 DenseSet<hw::HierPathOp> nlas;
1597 state.nlaTable->getInstanceNLAs(instOp, nlas);
1598 if (!nlas.empty())
1599 return 0;
1600
1601 // Check if this instance can be redirected to a smaller module
1602 auto moduleName = instOp.getModuleNameAttr().getAttr();
1603 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1604 if (!canonicalModule)
1605 return 0;
1606
1607 // Benefit is the size difference
1608 auto currentModule = cast<FModuleLike>(
1609 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1610 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1611 uint64_t canonicalSize =
1612 moduleSizes.getModuleSize(canonicalModule, symbols);
1613 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1614 }
1615
1616 LogicalResult rewrite(InstanceOp instOp) override {
1617 // Get the circuit this instance belongs to
1618 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1619 assert(circuitOp);
1620 const auto &state = circuitStates.at(circuitOp);
1621
1622 // Replace the instantiated module with the canonical module.
1623 auto canonicalModule = state.instanceToCanonicalModule.at(
1624 instOp.getModuleNameAttr().getAttr());
1625 auto canonicalName = canonicalModule.getModuleNameAttr();
1626 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1627
1628 // Update port names to match the canonical module
1629 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1630
1631 return success();
1632 }
1633
1634 std::string getName() const override { return "firrtl-module-swapper"; }
1635 bool acceptSizeIncrease() const override { return true; }
1636
1637private:
1638 ::detail::SymbolCache symbols;
1639 NLARemover nlaRemover;
1640 ModuleSizeCache moduleSizes;
1641
1642 // Per-circuit state containing all module swapping information
1643 DenseMap<CircuitOp, CircuitState> circuitStates;
1644};
1645
1646/// A reduction pattern that handles MustDedup annotations by replacing all
1647/// module names in a dedup group with a single module name. This helps reduce
1648/// the IR by consolidating module references that are required to be identical.
1649///
1650/// The pattern works by:
1651/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1652/// 2. For each dedup group, using the first module as the canonical name
1653/// 3. Replacing all instance references to other modules in the group with
1654/// references to the canonical module
1655/// 4. Removing the non-canonical modules from the circuit
1656/// 5. Removing the processed MustDedup annotation
1657///
1658/// This reduction is particularly useful for reducing large circuits where
1659/// multiple modules are known to be identical but haven't been deduplicated
1660/// yet.
1661struct ForceDedup : public OpReduction<CircuitOp> {
1662 void beforeReduction(mlir::ModuleOp op) override {
1663 symbols.clear();
1664 nlaRemover.clear();
1665 }
1666 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1667
1668 /// Collect all MustDedup annotations and create matches for each dedup group.
1669 void matches(CircuitOp circuitOp,
1670 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1671 auto annotations = AnnotationSet(circuitOp);
1672 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1673 if (!anno.isClass(mustDedupAnnoClass))
1674 continue;
1675
1676 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1677 if (!modulesAttr)
1678 continue;
1679
1680 // Each dedup group gets its own match with benefit proportional to group
1681 // size.
1682 uint64_t benefit = modulesAttr.size();
1683 addMatch(benefit, annoIdx);
1684 }
1685 }
1686
1687 LogicalResult rewriteMatches(CircuitOp circuitOp,
1688 ArrayRef<uint64_t> matches) override {
1689 auto *context = circuitOp->getContext();
1690 NLATable nlaTable(circuitOp);
1691 hw::InnerSymbolTableCollection innerSymTables;
1692 auto annotations = AnnotationSet(circuitOp);
1693 SmallVector<Annotation> newAnnotations;
1694
1695 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1696 // Check if this annotation was selected.
1697 if (!llvm::is_contained(matches, annoIdx)) {
1698 newAnnotations.push_back(anno);
1699 continue;
1700 }
1701 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1702 assert(anno.isClass(mustDedupAnnoClass) && modulesAttr &&
1703 modulesAttr.size() >= 2);
1704
1705 // Extract module names from the dedup group.
1706 SmallVector<StringAttr> moduleNames;
1707 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1708 // Parse "~CircuitName|ModuleName" format.
1709 auto refStr = moduleRef.getValue();
1710 auto pipePos = refStr.find('|');
1711 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1712 auto moduleName = refStr.substr(pipePos + 1);
1713 moduleNames.push_back(StringAttr::get(context, moduleName));
1714 }
1715 }
1716
1717 // Simply drop the annotation if there's only one module.
1718 if (moduleNames.size() < 2)
1719 continue;
1720
1721 // Replace all instances and references to other modules with the
1722 // first module.
1723 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1724 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1725 }
1726 if (newAnnotations.size() == annotations.size())
1727 return failure();
1728
1729 // Update circuit annotations.
1730 AnnotationSet newAnnoSet(newAnnotations, context);
1731 newAnnoSet.applyToOperation(circuitOp);
1732 return success();
1733 }
1734
1735 std::string getName() const override { return "firrtl-force-dedup"; }
1736 bool acceptSizeIncrease() const override { return true; }
1737
1738private:
1739 /// Replace all references to modules in the dedup group with the canonical
1740 /// module name
1741 void replaceModuleReferences(CircuitOp circuitOp,
1742 ArrayRef<StringAttr> moduleNames,
1743 NLATable &nlaTable,
1744 hw::InnerSymbolTableCollection &innerSymTables) {
1745 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1746 auto &symbolTable = symbols.getSymbolTable(tableOp);
1747 auto *context = circuitOp->getContext();
1748 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
1749
1750 // Collect the modules.
1751 FModuleLike canonicalModule;
1752 SmallVector<FModuleLike> modulesToReplace;
1753 for (auto name : moduleNames) {
1754 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
1755 if (!canonicalModule)
1756 canonicalModule = mod;
1757 else
1758 modulesToReplace.push_back(mod);
1759 }
1760 }
1761 if (modulesToReplace.empty())
1762 return;
1763
1764 // Replace all instance references.
1765 auto canonicalName = canonicalModule.getModuleNameAttr();
1766 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1767 circuitOp.walk([&](InstanceOp instOp) {
1768 auto moduleName = instOp.getModuleNameAttr().getAttr();
1769 if (llvm::is_contained(moduleNames, moduleName) &&
1770 moduleName != canonicalName) {
1771 instOp.setModuleNameAttr(canonicalRef);
1772 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1773 }
1774 });
1775
1776 // Update NLAs to reference the canonical module instead of modules being
1777 // removed using NLATable for better performance.
1778 for (auto oldMod : modulesToReplace) {
1779 SmallVector<hw::HierPathOp> nlaOps(
1780 nlaTable.lookup(oldMod.getModuleNameAttr()));
1781 for (auto nlaOp : nlaOps) {
1782 nlaTable.erase(nlaOp);
1783 StringAttr oldModName = oldMod.getModuleNameAttr();
1784 StringAttr newModName = canonicalName;
1785 SmallVector<Attribute, 4> newPath;
1786 for (auto nameRef : nlaOp.getNamepath()) {
1787 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1788 if (ref.getModule() == oldModName) {
1789 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1790 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1791 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1792 if (oldInst && newInst) {
1793 oldModName = oldInst.getReferencedModuleNameAttr();
1794 newModName = newInst.getReferencedModuleNameAttr();
1795 }
1796 }
1797 newPath.push_back(ref);
1798 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1799 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1800 } else {
1801 newPath.push_back(nameRef);
1802 }
1803 }
1804 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1805 nlaTable.addNLA(nlaOp);
1806 }
1807 }
1808
1809 // Mark NLAs in modules to be removed.
1810 for (auto module : modulesToReplace) {
1811 nlaRemover.markNLAsInOperation(module);
1812 module->erase();
1813 }
1814 }
1815
1816 ::detail::SymbolCache symbols;
1817 NLARemover nlaRemover;
1818};
1819
1820/// A reduction pattern that moves `MustDedup` annotations from a module onto
1821/// its child modules. This pattern iterates over all MustDedup annotations,
1822/// collects all `FInstanceLike` ops in each module of the dedup group, and
1823/// creates new MustDedup annotations for corresponding instances across the
1824/// modules. Each set of corresponding instances becomes a separate match of the
1825/// reduction. The reduction also removes the original MustDedup annotation on
1826/// the parent module.
1827///
1828/// The pattern works by:
1829/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1830/// 2. For each dedup group, collecting all FInstanceLike operations in each
1831/// module
1832/// 3. Grouping corresponding instances across modules by their position/name
1833/// 4. Creating new MustDedup annotations for each group of corresponding
1834/// instances
1835/// 5. Removing the original MustDedup annotation from the circuit
1836struct MustDedupChildren : public OpReduction<CircuitOp> {
1837 void beforeReduction(mlir::ModuleOp op) override {
1838 symbols.clear();
1839 nlaRemover.clear();
1840 }
1841 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1842
1843 /// Collect all MustDedup annotations and create matches for each instance
1844 /// group.
1845 void matches(CircuitOp circuitOp,
1846 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1847 auto annotations = AnnotationSet(circuitOp);
1848 uint64_t matchId = 0;
1849
1850 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1851 if (!anno.isClass(mustDedupAnnoClass))
1852 continue;
1853
1854 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1855 if (!modulesAttr || modulesAttr.size() < 2)
1856 continue;
1857
1858 // Process each group of corresponding instances
1859 processInstanceGroups(
1860 circuitOp, modulesAttr,
1861 [&](ArrayRef<FInstanceLike>) { addMatch(1, matchId++); });
1862 }
1863 }
1864
1865 LogicalResult rewriteMatches(CircuitOp circuitOp,
1866 ArrayRef<uint64_t> matches) override {
1867 auto *context = circuitOp->getContext();
1868 auto annotations = AnnotationSet(circuitOp);
1869 SmallVector<Annotation> newAnnotations;
1870 uint64_t matchId = 0;
1871
1872 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1873 if (!anno.isClass(mustDedupAnnoClass)) {
1874 newAnnotations.push_back(anno);
1875 continue;
1876 }
1877
1878 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1879 if (!modulesAttr || modulesAttr.size() < 2) {
1880 newAnnotations.push_back(anno);
1881 continue;
1882 }
1883
1884 // Track whether any matches were selected for this annotation
1885 bool anyMatchSelected = false;
1886 processInstanceGroups(
1887 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1888 // Check if this instance group was selected
1889 if (!llvm::is_contained(matches, matchId++))
1890 return;
1891 anyMatchSelected = true;
1892
1893 // Create the list of modules to put into this new annotation.
1894 SmallSetVector<StringAttr, 4> moduleTargets;
1895 for (auto instOp : instanceGroup) {
1896 auto target = TokenAnnoTarget();
1897 target.circuit = circuitOp.getName();
1898 target.module = instOp.getReferencedModuleName();
1899 moduleTargets.insert(target.toStringAttr(context));
1900 }
1901 if (moduleTargets.size() < 2)
1902 return;
1903
1904 // Create a new MustDedup annotation for this list of modules.
1905 SmallVector<NamedAttribute> newAnnoAttrs;
1906 newAnnoAttrs.emplace_back(
1907 StringAttr::get(context, "class"),
1908 StringAttr::get(context, mustDedupAnnoClass));
1909 newAnnoAttrs.emplace_back(
1910 StringAttr::get(context, "modules"),
1911 ArrayAttr::get(context,
1912 SmallVector<Attribute>(moduleTargets.begin(),
1913 moduleTargets.end())));
1914
1915 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
1916 newAnnotations.emplace_back(newAnnoDict);
1917 });
1918
1919 // If any matches were selected, mark the original annotation for removal
1920 // since we're replacing it with new MustDedup annotations on the child
1921 // modules. Otherwise keep the original annotation around.
1922 if (anyMatchSelected)
1923 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1924 else
1925 newAnnotations.push_back(anno);
1926 }
1927
1928 // Update circuit annotations
1929 AnnotationSet newAnnoSet(newAnnotations, context);
1930 newAnnoSet.applyToOperation(circuitOp);
1931 return success();
1932 }
1933
1934 std::string getName() const override { return "must-dedup-children"; }
1935 bool acceptSizeIncrease() const override { return true; }
1936
1937private:
1938 /// Helper function to process groups of corresponding instances from a
1939 /// MustDedup annotation. Calls the provided lambda for each group of
1940 /// corresponding instances across the modules. Only calls the lambda if there
1941 /// are at least 2 modules.
1942 void processInstanceGroups(
1943 CircuitOp circuitOp, ArrayAttr modulesAttr,
1944 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
1945 auto &symbolTable = symbols.getSymbolTable(circuitOp);
1946
1947 // Extract module names and get the actual modules
1948 SmallVector<FModuleLike> modules;
1949 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1950 if (auto target = tokenizePath(moduleRef))
1951 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
1952 modules.push_back(mod);
1953
1954 // Need at least 2 modules for deduplication
1955 if (modules.size() < 2)
1956 return;
1957
1958 // Collect all FInstanceLike operations from each module and group them by
1959 // name. Instance names are a good key for matching instances across
1960 // modules. But they may not be unique, so we need to be careful to only
1961 // match up instances that are uniquely named within every module.
1962 struct InstanceGroup {
1963 SmallVector<FInstanceLike> instances;
1964 bool nameIsUnique = true;
1965 };
1966 MapVector<StringAttr, InstanceGroup> instanceGroups;
1967 for (auto module : modules) {
1969 module.walk([&](FInstanceLike instOp) {
1970 auto name = instOp.getInstanceNameAttr();
1971 auto &group = instanceGroups[name];
1972 if (nameCounts[name]++ > 1)
1973 group.nameIsUnique = false;
1974 group.instances.push_back(instOp);
1975 });
1976 }
1977
1978 // Call the callback for each group of instances that are uniquely named and
1979 // consist of at least 2 instances.
1980 for (auto &[name, group] : instanceGroups)
1981 if (group.nameIsUnique && group.instances.size() >= 2)
1982 callback(group.instances);
1983 }
1984
1985 ::detail::SymbolCache symbols;
1986 NLARemover nlaRemover;
1987};
1988
1989} // namespace
1990
1991//===----------------------------------------------------------------------===//
1992// Reduction Registration
1993//===----------------------------------------------------------------------===//
1994
// Gather a list of reduction patterns that we should try. Ideally these are
// assigned reasonable benefit indicators (higher benefit patterns are
// prioritized). For example, things that can knock out entire modules while
// being cheap should be tried first (and thus have higher benefit), before
// trying to tweak operands of individual arithmetic ops.
// NOTE(review): some benefits are shared (22 by the two inliners, 0 by the two
// sanitizers); registration order below may matter for ties — confirm against
// the reducer driver before reordering.
patterns.add<AnnotationRemover, 33>();
patterns.add<ModuleSwapper, 32>();
patterns.add<ForceDedup, 31>();
patterns.add<MustDedupChildren, 30>();
// `PassReduction` entries run a complete FIRRTL pass as a single reduction.
// NOTE(review): the trailing boolean arguments are PassReduction constructor
// flags defined in Reduction.h; their meaning is not visible here.
patterns.add<PassReduction, 29>(
    getContext(),
    firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
    true);
patterns.add<PassReduction, 28>(getContext(),
                                firrtl::createLowerCHIRRTLPass(), true, true);
patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
                                true, true);
patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
                                true, true);
patterns.add<FIRRTLModuleExternalizer, 25>();
patterns.add<InstanceStubber, 24>();
patterns.add<MemoryStubber, 23>();
patterns.add<EagerInliner, 22>();
patterns.add<ObjectInliner, 22>();
patterns.add<PassReduction, 21>(getContext(),
                                firrtl::createLowerFIRRTLTypes(), true, true);
patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
                                true, true);
// These two pass reductions rely on PassReduction's default flag values.
patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
patterns.add<PassReduction, 17>(
    getContext(),
    firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
// Finer-grained, per-operation reductions.
patterns.add<NodeSymbolRemover, 15>();
patterns.add<ConnectForwarder, 14>();
patterns.add<ConnectInvalidator, 13>();
patterns.add<Constantifier, 12>();
patterns.add<FIRRTLOperandForwarder<0>, 11>();
patterns.add<FIRRTLOperandForwarder<1>, 10>();
patterns.add<FIRRTLOperandForwarder<2>, 9>();
patterns.add<DetachSubaccesses, 7>();
patterns.add<RootPortPruner, 5>();
patterns.add<ExtmoduleInstanceRemover, 4>();
patterns.add<ConnectSourceOperandForwarder<0>, 3>();
patterns.add<ConnectSourceOperandForwarder<1>, 2>();
patterns.add<ConnectSourceOperandForwarder<2>, 1>();
// Name sanitizers have the lowest benefit (0) and are tried last.
patterns.add<ModuleInternalNameSanitizer, 0>();
patterns.add<ModuleNameSanitizer, 0>();
}
2046
mlir::DialectRegistry &registry) {
  // Attach the FIRRTL reduction-pattern dialect interface through a registry
  // extension; MLIR applies the extension to the FIRRTL dialect when that
  // dialect is loaded into a context. The context argument is unused.
  registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
    dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
  });
}
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:42
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describe the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)