CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APSInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/Debug.h"
31
32#define DEBUG_TYPE "firrtl-reductions"
33
34using namespace mlir;
35using namespace circt;
36using namespace firrtl;
37using llvm::MapVector;
38using llvm::SmallDenseSet;
39using llvm::SmallSetVector;
40
41//===----------------------------------------------------------------------===//
42// Utilities
43//===----------------------------------------------------------------------===//
44
45namespace detail {
46/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
47/// which is handy for reductions that need to look up a lot of symbols.
49 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
50
51 SymbolTable &getSymbolTable(Operation *op) {
52 return tables->getSymbolTable(op);
53 }
54 SymbolTable &getNearestSymbolTable(Operation *op) {
55 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
56 }
57
58 SymbolUserMap &getSymbolUserMap(Operation *op) {
59 auto it = userMaps.find(op);
60 if (it != userMaps.end())
61 return it->second;
62 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
63 }
64 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
65 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
66 }
67
68 void clear() {
69 tables = std::make_unique<SymbolTableCollection>();
70 userMaps.clear();
71 }
72
73private:
74 std::unique_ptr<SymbolTableCollection> tables;
76};
77} // namespace detail
78
79/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
80/// optional in case another type of module is instantiated.
81static std::optional<firrtl::FModuleOp>
82findInstantiatedModule(firrtl::InstanceOp instOp,
83 ::detail::SymbolCache &symbols) {
84 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
85 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
86 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
87 return moduleOp ? std::optional(moduleOp) : std::nullopt;
88}
89
90/// Utility to track the transitive size of modules.
92 void clear() { moduleSizes.clear(); }
93
94 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
95 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
96 return it->second;
97 uint64_t size = 1;
98 module->walk([&](Operation *op) {
99 size += 1;
100 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
101 if (auto instModule = findInstantiatedModule(instOp, symbols))
102 size += getModuleSize(*instModule, symbols);
103 });
104 moduleSizes.insert({module, size});
105 return size;
106 }
107
108private:
109 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
110};
111
112/// Check that all connections to a value are invalids.
113static bool onlyInvalidated(Value arg) {
114 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
115 auto *op = use.getOwner();
116 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
117 return false;
118 if (use.getOperandNumber() != 0)
119 return false;
120 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
121 return false;
122 return true;
123 });
124}
125
126/// A tracker for track NLAs affected by a reduction. Performs the necessary
127/// cleanup steps in order to maintain IR validity after the reduction has
128/// applied. For example, removing an instance that forms part of an NLA path
129/// requires that NLA to be removed as well.
131 /// Clear the set of marked NLAs. Call this before attempting a reduction.
132 void clear() { nlasToRemove.clear(); }
133
134 /// Remove all marked annotations. Call this after applying a reduction in
135 /// order to validate the IR.
136 void remove(mlir::ModuleOp module) {
137 unsigned numRemoved = 0;
138 (void)numRemoved;
139 SymbolTableCollection symbolTables;
140 for (Operation &rootOp : *module.getBody()) {
141 if (!isa<firrtl::CircuitOp>(&rootOp))
142 continue;
143 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
144 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
145 for (auto sym : nlasToRemove) {
146 if (auto *op = symbolTable.lookup(sym)) {
147 if (symbolUserMap.useEmpty(op)) {
148 ++numRemoved;
149 op->erase();
150 }
151 }
152 }
153 }
154 LLVM_DEBUG({
155 unsigned numLost = nlasToRemove.size() - numRemoved;
156 if (numRemoved > 0 || numLost > 0) {
157 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
158 if (numLost > 0)
159 llvm::dbgs() << " (" << numLost << " no longer there)";
160 llvm::dbgs() << "\n";
161 }
162 });
163 }
164
165 /// Mark all NLAs referenced in the given annotation as to be removed. This
166 /// can be an entire array or dictionary of annotations, and the function will
167 /// descend into child annotations appropriately.
168 void markNLAsInAnnotation(Attribute anno) {
169 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
170 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
171 nlasToRemove.insert(field.getAttr());
172 for (auto namedAttr : dict)
173 markNLAsInAnnotation(namedAttr.getValue());
174 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
175 for (auto attr : array)
176 markNLAsInAnnotation(attr);
177 }
178 }
179
180 /// Mark all NLAs referenced in an operation. Also traverses all nested
181 /// operations. Call this before removing an operation, to mark any associated
182 /// NLAs as to be removed as well.
183 void markNLAsInOperation(Operation *op) {
184 op->walk([&](Operation *op) {
185 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
186 markNLAsInAnnotation(annos);
187 });
188 }
189
190 /// The set of NLAs to remove, identified by their symbol.
191 llvm::DenseSet<StringAttr> nlasToRemove;
192};
193
194//===----------------------------------------------------------------------===//
195// Reduction patterns
196//===----------------------------------------------------------------------===//
197
198namespace {
199
200/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
201struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
202 void beforeReduction(mlir::ModuleOp op) override {
203 nlaRemover.clear();
204 symbols.clear();
205 moduleSizes.clear();
206 innerSymUses = reduce::InnerSymbolUses(op);
207 }
208 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
209
210 uint64_t match(FModuleOp module) override {
211 if (innerSymUses.hasInnerRef(module))
212 return 0;
213 return moduleSizes.getModuleSize(module, symbols);
214 }
215
216 LogicalResult rewrite(FModuleOp module) override {
217 // Hack up a list of known layers.
218 LayerSet layers;
219 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
220 for (auto attr : module.getPortTypes()) {
221 auto type = cast<TypeAttr>(attr).getValue();
222 if (auto refType = type_dyn_cast<RefType>(type))
223 if (auto layer = refType.getLayer())
224 layers.insert(layer);
225 }
226 SmallVector<Attribute, 4> layersArray;
227 layersArray.reserve(layers.size());
228 for (auto layer : layers)
229 layersArray.push_back(layer);
230
231 nlaRemover.markNLAsInOperation(module);
232 OpBuilder builder(module);
233 auto extmodule = FExtModuleOp::create(
234 builder, module->getLoc(),
235 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
236 module.getConventionAttr(), module.getPorts(),
237 builder.getArrayAttr(layersArray), StringRef(),
238 module.getAnnotationsAttr());
239 SymbolTable::setSymbolVisibility(extmodule,
240 SymbolTable::getSymbolVisibility(module));
241 module->erase();
242 return success();
243 }
244
245 std::string getName() const override { return "firrtl-module-externalizer"; }
246
247 ::detail::SymbolCache symbols;
248 NLARemover nlaRemover;
249 reduce::InnerSymbolUses innerSymUses;
250 ModuleSizeCache moduleSizes;
251};
252
253/// Invalidate all the leaf fields of a value with a given flippedness by
254/// connecting an invalid value to them. This is useful for ensuring that all
255/// output ports of an instance or memory (including those nested in bundles)
256/// are properly invalidated.
257static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
258 SmallDenseMap<Type, Value, 8> &invalidCache,
259 bool flip = false) {
260 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
261 if (!type)
262 return;
263
264 // Descend into bundles by creating subfield ops.
265 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
266 for (auto element : llvm::enumerate(bundleType.getElements())) {
267 auto subfield =
268 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
269 invalidateOutputs(builder, subfield, invalidCache,
270 flip ^ element.value().isFlip);
271 if (subfield.use_empty())
272 subfield.getDefiningOp()->erase();
273 }
274 return;
275 }
276
277 // Descend into vectors by creating subindex ops.
278 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
279 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
280 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
281 invalidateOutputs(builder, subindex, invalidCache, flip);
282 if (subindex.use_empty())
283 subindex.getDefiningOp()->erase();
284 }
285 return;
286 }
287
288 // Only drive outputs.
289 if (flip)
290 return;
291 Value invalid = invalidCache.lookup(type);
292 if (!invalid) {
293 invalid = firrtl::InvalidValueOp::create(builder, type);
294 invalidCache.insert({type, invalid});
295 }
296 firrtl::ConnectOp::create(builder, value, invalid);
297}
298
299/// Connect a value to every leave of a destination value.
300static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
301 Value value) {
302 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
303 if (!type)
304 return;
305 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
306 for (auto element : llvm::enumerate(bundleType.getElements()))
307 connectToLeafs(builder,
308 firrtl::SubfieldOp::create(builder, dest, element.index()),
309 value);
310 return;
311 }
312 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
313 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
314 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
315 value);
316 return;
317 }
318 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
319 if (!valueType)
320 return;
321 auto destWidth = type.getBitWidthOrSentinel();
322 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
323 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
324 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
325 if (!isa<firrtl::UIntType>(type)) {
326 if (isa<firrtl::SIntType>(type))
327 value = firrtl::AsSIntPrimOp::create(builder, value);
328 else
329 return;
330 }
331 firrtl::ConnectOp::create(builder, dest, value);
332}
333
334/// Reduce all leaf fields of a value through an XOR tree.
335static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
336 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
337 if (!type)
338 return;
339 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
340 for (auto element : llvm::enumerate(bundleType.getElements()))
341 reduceXor(
342 builder, into,
343 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
344 return;
345 }
346 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
347 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
348 reduceXor(builder, into,
349 builder.createOrFold<firrtl::SubindexOp>(value, i));
350 return;
351 }
352 if (!isa<firrtl::UIntType>(type)) {
353 if (isa<firrtl::SIntType>(type))
354 value = firrtl::AsUIntPrimOp::create(builder, value);
355 else
356 return;
357 }
358 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
359}
360
361/// A sample reduction pattern that maps `firrtl.instance` to a set of
362/// invalidated wires. This often shortcuts a long iterative process of connect
363/// invalidation, module externalization, and wire stripping
364struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
365 void beforeReduction(mlir::ModuleOp op) override {
366 erasedInsts.clear();
367 erasedModules.clear();
368 symbols.clear();
369 nlaRemover.clear();
370 moduleSizes.clear();
371 }
372 void afterReduction(mlir::ModuleOp op) override {
373 // Look into deleted modules to find additional instances that are no longer
374 // instantiated anywhere.
375 SmallVector<Operation *> worklist;
376 auto deadInsts = erasedInsts;
377 for (auto *op : erasedModules)
378 worklist.push_back(op);
379 while (!worklist.empty()) {
380 auto *op = worklist.pop_back_val();
381 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
382 op->walk([&](firrtl::InstanceOp instOp) {
383 auto moduleOp = cast<firrtl::FModuleLike>(
384 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
385 deadInsts.insert(instOp);
386 if (llvm::all_of(
387 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
388 [&](Operation *user) { return deadInsts.contains(user); })) {
389 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
390 << moduleOp.getModuleName() << "`\n");
391 erasedModules.insert(moduleOp);
392 worklist.push_back(moduleOp);
393 }
394 });
395 }
396
397 for (auto *op : erasedInsts)
398 op->erase();
399 for (auto *op : erasedModules)
400 op->erase();
401 nlaRemover.remove(op);
402 }
403
404 uint64_t match(firrtl::InstanceOp instOp) override {
405 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
406 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
407 return 0;
408 }
409
410 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
411 LLVM_DEBUG(llvm::dbgs()
412 << "Stubbing instance `" << instOp.getName() << "`\n");
413 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
415 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
416 auto result = instOp.getResult(i);
417 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
418 instOp.getPortName(i));
419 auto wire =
420 firrtl::WireOp::create(builder, result.getType(), name,
421 firrtl::NameKindEnum::DroppableName,
422 instOp.getPortAnnotation(i), StringAttr{})
423 .getResult();
424 invalidateOutputs(builder, wire, invalidCache,
425 instOp.getPortDirection(i) == firrtl::Direction::In);
426 result.replaceAllUsesWith(wire);
427 }
428 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
429 auto moduleOp = cast<firrtl::FModuleLike>(
430 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
431 nlaRemover.markNLAsInOperation(instOp);
432 erasedInsts.insert(instOp);
433 if (llvm::all_of(
434 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
435 [&](Operation *user) { return erasedInsts.contains(user); })) {
436 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
437 << moduleOp.getModuleName() << "`\n");
438 erasedModules.insert(moduleOp);
439 }
440 return success();
441 }
442
443 std::string getName() const override { return "instance-stubber"; }
444 bool acceptSizeIncrease() const override { return true; }
445
446 ::detail::SymbolCache symbols;
447 NLARemover nlaRemover;
448 llvm::DenseSet<Operation *> erasedInsts;
449 llvm::DenseSet<Operation *> erasedModules;
450 ModuleSizeCache moduleSizes;
451};
452
453/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
454/// wires.
455struct MemoryStubber : public OpReduction<firrtl::MemOp> {
456 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
457 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
458 LogicalResult rewrite(firrtl::MemOp memOp) override {
459 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
460 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
462 Value xorInputs;
463 SmallVector<Value> outputs;
464 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
465 auto result = memOp.getResult(i);
466 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
467 memOp.getPortName(i));
468 auto wire =
469 firrtl::WireOp::create(builder, result.getType(), name,
470 firrtl::NameKindEnum::DroppableName,
471 memOp.getPortAnnotation(i), StringAttr{})
472 .getResult();
473 invalidateOutputs(builder, wire, invalidCache, true);
474 result.replaceAllUsesWith(wire);
475
476 // Isolate the input and output data fields of the port.
477 Value input, output;
478 switch (memOp.getPortKind(i)) {
479 case firrtl::MemOp::PortKind::Read:
480 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
481 break;
482 case firrtl::MemOp::PortKind::Write:
483 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
484 break;
485 case firrtl::MemOp::PortKind::ReadWrite:
486 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
487 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
488 break;
489 case firrtl::MemOp::PortKind::Debug:
490 output = wire;
491 break;
492 }
493
494 if (!isa<firrtl::RefType>(result.getType())) {
495 // Reduce all input ports to a single one through an XOR tree.
496 unsigned numFields =
497 cast<firrtl::BundleType>(wire.getType()).getNumElements();
498 for (unsigned i = 0; i != numFields; ++i) {
499 if (i != 2 && i != 3 && i != 5)
500 reduceXor(builder, xorInputs,
501 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
502 }
503 if (input)
504 reduceXor(builder, xorInputs, input);
505 }
506
507 // Track the output port to hook it up to the XORd input later.
508 if (output)
509 outputs.push_back(output);
510 }
511
512 // Hook up the outputs.
513 for (auto output : outputs)
514 connectToLeafs(builder, output, xorInputs);
515
516 nlaRemover.markNLAsInOperation(memOp);
517 memOp->erase();
518 return success();
519 }
520 std::string getName() const override { return "memory-stubber"; }
521 bool acceptSizeIncrease() const override { return true; }
522 NLARemover nlaRemover;
523};
524
525/// Check whether an operation interacts with flows in any way, which can make
526/// replacement and operand forwarding harder in some cases.
527static bool isFlowSensitiveOp(Operation *op) {
528 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
529 SubaccessOp, ObjectSubfieldOp>(op);
530}
531
532/// A sample reduction pattern that replaces all uses of an operation with one
533/// of its operands. This can help pruning large parts of the expression tree
534/// rapidly.
535template <unsigned OpNum>
536struct FIRRTLOperandForwarder : public Reduction {
537 uint64_t match(Operation *op) override {
538 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
539 return 0;
540 if (isFlowSensitiveOp(op))
541 return 0;
542 auto resultTy =
543 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
544 auto opTy =
545 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
546 return resultTy && opTy &&
547 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
548 (resultTy.getBitWidthOrSentinel() == -1) ==
549 (opTy.getBitWidthOrSentinel() == -1) &&
550 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
551 }
552 LogicalResult rewrite(Operation *op) override {
553 assert(match(op));
554 ImplicitLocOpBuilder builder(op->getLoc(), op);
555 auto result = op->getResult(0);
556 auto operand = op->getOperand(OpNum);
557 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
558 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
559 auto resultWidth = resultTy.getBitWidthOrSentinel();
560 auto operandWidth = operandTy.getBitWidthOrSentinel();
561 Value newOp;
562 if (resultWidth < operandWidth)
563 newOp =
564 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
565 else if (resultWidth > operandWidth)
566 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
567 else
568 newOp = operand;
569 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
570 result.replaceAllUsesWith(newOp);
571 reduce::pruneUnusedOps(op, *this);
572 return success();
573 }
574 std::string getName() const override {
575 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
576 }
577};
578
579/// A sample reduction pattern that replaces FIRRTL operations with a constant
580/// zero of their type.
581struct Constantifier : public Reduction {
582 void beforeReduction(mlir::ModuleOp op) override {
583 symbols.clear();
584
585 // Find valid dummy classes that we can use for anyref casts.
586 anyrefCastDummy.clear();
587 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
588 for (auto classOp : circuitOp.getOps<ClassOp>()) {
589 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
590 anyrefCastDummy.insert({circuitOp, classOp});
591 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
592 }
593 }
594 return WalkResult::skip();
595 });
596 }
597
598 uint64_t match(Operation *op) override {
599 if (op->hasTrait<OpTrait::ConstantLike>()) {
600 Attribute attr;
601 if (!matchPattern(op, m_Constant(&attr)))
602 return 0;
603 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
604 if (intAttr.getValue().isZero())
605 return 0;
606 if (auto strAttr = dyn_cast<StringAttr>(attr))
607 if (strAttr.empty())
608 return 0;
609 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
610 if (floatAttr.getValue().isZero())
611 return 0;
612 }
613 if (auto listOp = dyn_cast<ListCreateOp>(op))
614 if (listOp.getElements().empty())
615 return 0;
616 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
617 if (pathOp.getTarget().empty())
618 return 0;
619
620 // Don't replace anyref casts that already target a dummy class.
621 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
622 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
623 auto className =
624 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
625 if (anyrefCastDummyNames[circuitOp].contains(className))
626 return 0;
627 }
628
629 if (op->getNumResults() != 1)
630 return 0;
631 if (op->hasAttr("inner_sym"))
632 return 0;
633 if (isFlowSensitiveOp(op))
634 return 0;
635 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
636 DoubleType, ListType, PathType, AnyRefType>(
637 op->getResult(0).getType());
638 }
639
640 LogicalResult rewrite(Operation *op) override {
641 OpBuilder builder(op);
642 auto type = op->getResult(0).getType();
643
644 // Handle UInt/SInt types.
645 if (isa<UIntType, SIntType>(type)) {
646 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
647 if (width == -1)
648 width = 64;
649 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
650 APSInt(width, isa<UIntType>(type)));
651 op->replaceAllUsesWith(newOp);
652 reduce::pruneUnusedOps(op, *this);
653 return success();
654 }
655
656 // Handle property string types.
657 if (isa<StringType>(type)) {
658 auto attr = builder.getStringAttr("");
659 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
660 op->replaceAllUsesWith(newOp);
661 reduce::pruneUnusedOps(op, *this);
662 return success();
663 }
664
665 // Handle property integer types.
666 if (isa<FIntegerType>(type)) {
667 auto attr = builder.getIntegerAttr(builder.getI64Type(), 0);
668 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
669 op->replaceAllUsesWith(newOp);
670 reduce::pruneUnusedOps(op, *this);
671 return success();
672 }
673
674 // Handle property boolean types.
675 if (isa<BoolType>(type)) {
676 auto attr = builder.getBoolAttr(false);
677 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
678 op->replaceAllUsesWith(newOp);
679 reduce::pruneUnusedOps(op, *this);
680 return success();
681 }
682
683 // Handle property double types.
684 if (isa<DoubleType>(type)) {
685 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
686 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
687 op->replaceAllUsesWith(newOp);
688 reduce::pruneUnusedOps(op, *this);
689 return success();
690 }
691
692 // Handle property list types.
693 if (isa<ListType>(type)) {
694 auto newOp =
695 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
696 op->replaceAllUsesWith(newOp);
697 reduce::pruneUnusedOps(op, *this);
698 return success();
699 }
700
701 // Handle property path types.
702 if (isa<PathType>(type)) {
703 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
704 op->replaceAllUsesWith(newOp);
705 reduce::pruneUnusedOps(op, *this);
706 return success();
707 }
708
709 // Handle anyref types.
710 if (isa<AnyRefType>(type)) {
711 auto circuitOp = op->getParentOfType<CircuitOp>();
712 auto &dummy = anyrefCastDummy[circuitOp];
713 if (!dummy) {
714 OpBuilder::InsertionGuard guard(builder);
715 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
716 auto &symbolTable = symbols.getNearestSymbolTable(op);
717 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
718 symbolTable.insert(dummy);
719 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
720 }
721 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
722 auto anyrefOp =
723 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
724 op->replaceAllUsesWith(anyrefOp);
725 reduce::pruneUnusedOps(op, *this);
726 return success();
727 }
728
729 return failure();
730 }
731
732 std::string getName() const override { return "firrtl-constantifier"; }
733 bool acceptSizeIncrease() const override { return true; }
734
735 ::detail::SymbolCache symbols;
737 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
738};
739
740/// A sample reduction pattern that replaces the right-hand-side of
741/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
742/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
743/// connects and creates opportunities for reduction in DCE/CSE.
744struct ConnectInvalidator : public Reduction {
745 uint64_t match(Operation *op) override {
746 if (!isa<FConnectLike>(op))
747 return 0;
748 if (auto *srcOp = op->getOperand(1).getDefiningOp())
749 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
750 isa<InvalidValueOp>(srcOp))
751 return 0;
752 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
753 return type && type.isPassive();
754 }
755 LogicalResult rewrite(Operation *op) override {
756 assert(match(op));
757 auto rhs = op->getOperand(1);
758 OpBuilder builder(op);
759 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
760 auto *rhsOp = rhs.getDefiningOp();
761 op->setOperand(1, invOp);
762 if (rhsOp)
763 reduce::pruneUnusedOps(rhsOp, *this);
764 return success();
765 }
766 std::string getName() const override { return "connect-invalidator"; }
767 bool acceptSizeIncrease() const override { return true; }
768};
769
770/// A reduction pattern that removes FIRRTL annotations from ports and
771/// operations. This generates one match per annotation and port annotation,
772/// allowing selective removal of individual annotations.
773struct AnnotationRemover : public Reduction {
774 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
775 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
776
777 void matches(Operation *op,
778 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
779 uint64_t matchId = 0;
780
781 // Generate matches for regular annotations
782 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
783 for (unsigned i = 0; i < annos.size(); ++i)
784 addMatch(1, matchId++);
785
786 // Generate matches for port annotations
787 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
788 for (auto portAnnoArray : portAnnos)
789 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
790 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
791 addMatch(1, matchId++);
792 }
793
794 LogicalResult rewriteMatches(Operation *op,
795 ArrayRef<uint64_t> matches) override {
796 // Convert matches to a set for fast lookup
797 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
798
799 // Lambda to process annotations and filter out matched ones
800 uint64_t matchId = 0;
801 auto processAnnotations =
802 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
803 SmallVector<Attribute> newAnnotations;
804 for (auto anno : annotations) {
805 if (!matchesSet.contains(matchId)) {
806 newAnnotations.push_back(anno);
807 } else {
808 // Mark NLAs in the removed annotation for cleanup
809 nlaRemover.markNLAsInAnnotation(anno);
810 }
811 matchId++;
812 }
813 return ArrayAttr::get(op->getContext(), newAnnotations);
814 };
815
816 // Remove regular annotations
817 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
818 op->setAttr("annotations", processAnnotations(annos.getValue()));
819 }
820
821 // Remove port annotations
822 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
823 SmallVector<Attribute> newPortAnnos;
824 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
825 newPortAnnos.push_back(
826 processAnnotations(portAnnoArrayAttr.getValue()));
827 }
828 op->setAttr("portAnnotations",
829 ArrayAttr::get(op->getContext(), newPortAnnos));
830 }
831
832 return success();
833 }
834
835 std::string getName() const override { return "annotation-remover"; }
836 NLARemover nlaRemover;
837};
838
839/// A reduction pattern that replaces ResetType with UInt<1> across an entire
840/// circuit. This walks all operations in the circuit and replaces ResetType in
841/// results, block arguments, and attributes.
// NOTE(review): this listing is incomplete. The numbering embedded in the lines
// jumps from 877 to 882 and from 887 to 891, so the bodies of both annotation
// predicates in `rewrite` are absent. Restore them from the original file
// before building; nothing below documents what they select.
842struct SimplifyResets : public OpReduction<CircuitOp> {
// Counts every ResetType reached by the walker. A circuit without any
// ResetType therefore reports 0.
843 uint64_t match(CircuitOp circuit) override {
844 uint64_t numResets = 0;
845 AttrTypeWalker walker;
846 walker.addWalk([&](ResetType type) { ++numResets; });
847
// Feed result types, block argument types, and the attribute dictionary of
// every operation in the circuit to the walker.
848 circuit.walk([&](Operation *op) {
849 for (auto result : op->getResults())
850 walker.walk(result.getType());
851
852 for (auto &region : op->getRegions())
853 for (auto &block : region)
854 for (auto arg : block.getArguments())
855 walker.walk(arg.getType());
856
857 walker.walk(op->getAttrDictionary());
858 });
859
860 return numResets;
861 }
862
863 LogicalResult rewrite(CircuitOp circuit) override {
// Two 1-bit UInt replacements; judging by the variable names the last argument
// selects constness (TODO confirm against `UIntType::get`).
864 auto uint1Type = UIntType::get(circuit->getContext(), 1, false);
865 auto constUint1Type = UIntType::get(circuit->getContext(), 1, true);
866
// Swap each ResetType for the replacement chosen by `type.isConst()`, in
// attributes and types but not in locations.
867 AttrTypeReplacer replacer;
868 replacer.addReplacement([&](ResetType type) {
869 return type.isConst() ? constUint1Type : uint1Type;
870 });
871 replacer.recursivelyReplaceElementsIn(circuit, /*replaceAttrs=*/true,
872 /*replaceLocs=*/false,
873 /*replaceTypes=*/true);
874
875 // Remove annotations related to InferResets pass
876 circuit.walk([&](Operation *op) {
877 // Remove operation annotations
// NOTE(review): the call that opens here and its predicate body are missing
// from this listing; only the closing `});` below survived.
882 });
883
884 // Remove port annotations for module-like operations
885 if (auto module = dyn_cast<FModuleLike>(op)) {
886 AnnotationSet::removePortAnnotations(module, [&](unsigned portIdx,
887 Annotation anno) {
// NOTE(review): the predicate body is missing from this listing.
891 });
892 }
893 });
894
895 return success();
896 }
897
898 std::string getName() const override { return "firrtl-simplify-resets"; }
899 bool acceptSizeIncrease() const override { return true; }
900};
901
902/// A sample reduction pattern that removes ports from the root `firrtl.module`
903/// if the port is not used or just invalidated.
904struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
905 uint64_t match(firrtl::FModuleOp module) override {
906 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
907 if (!circuit)
908 return 0;
909 return circuit.getNameAttr() == module.getNameAttr();
910 }
911 LogicalResult rewrite(firrtl::FModuleOp module) override {
912 assert(match(module));
913 size_t numPorts = module.getNumPorts();
914 llvm::BitVector dropPorts(numPorts);
915 for (unsigned i = 0; i != numPorts; ++i) {
916 if (onlyInvalidated(module.getArgument(i))) {
917 dropPorts.set(i);
918 for (auto *user :
919 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
920 user->erase();
921 }
922 }
923 module.erasePorts(dropPorts);
924 return success();
925 }
926 std::string getName() const override { return "root-port-pruner"; }
927};
928
929/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
930/// with wires.
931struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
932 void beforeReduction(mlir::ModuleOp op) override {
933 symbols.clear();
934 nlaRemover.clear();
935 }
936 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
937
938 uint64_t match(firrtl::InstanceOp instOp) override {
939 return isa<firrtl::FExtModuleOp>(
940 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
941 }
942 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
943 auto portInfo =
944 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
945 symbols.getNearestSymbolTable(instOp)))
946 .getPorts();
947 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
948 SmallVector<Value> replacementWires;
949 for (firrtl::PortInfo info : portInfo) {
950 auto wire = firrtl::WireOp::create(
951 builder, info.type,
952 (Twine(instOp.getName()) + "_" + info.getName()).str())
953 .getResult();
954 if (info.isOutput()) {
955 auto inv = firrtl::InvalidValueOp::create(builder, info.type);
956 firrtl::ConnectOp::create(builder, wire, inv);
957 }
958 replacementWires.push_back(wire);
959 }
960 nlaRemover.markNLAsInOperation(instOp);
961 instOp.replaceAllUsesWith(std::move(replacementWires));
962 instOp->erase();
963 return success();
964 }
965 std::string getName() const override { return "extmodule-instance-remover"; }
966 bool acceptSizeIncrease() const override { return true; }
967
968 ::detail::SymbolCache symbols;
969 NLARemover nlaRemover;
970};
971
972/// A sample reduction pattern that pushes connected values through wires.
973struct ConnectForwarder : public Reduction {
974 uint64_t match(Operation *op) override {
975 if (!isa<firrtl::FConnectLike>(op))
976 return 0;
977 auto dest = op->getOperand(0);
978 auto src = op->getOperand(1);
979 auto *destOp = dest.getDefiningOp();
980 auto *srcOp = src.getDefiningOp();
981 if (dest == src)
982 return 0;
983
984 // Ensure that the destination is something we should be able to forward
985 // through.
986 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
987 destOp))
988 return 0;
989
990 // Ensure that the destination is connected to only once, and all uses of
991 // the connection occur after the definition of the source.
992 unsigned numConnects = 0;
993 for (auto &use : dest.getUses()) {
994 auto *op = use.getOwner();
995 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
996 if (++numConnects > 1)
997 return 0;
998 continue;
999 }
1000 if (srcOp && !srcOp->isBeforeInBlock(op))
1001 return 0;
1002 }
1003
1004 return 1;
1005 }
1006
1007 LogicalResult rewrite(Operation *op) override {
1008 auto dst = op->getOperand(0);
1009 auto src = op->getOperand(1);
1010 dst.replaceAllUsesWith(src);
1011 op->erase();
1012 if (auto *dstOp = dst.getDefiningOp())
1013 reduce::pruneUnusedOps(dstOp, *this);
1014 if (auto *srcOp = src.getDefiningOp())
1015 reduce::pruneUnusedOps(srcOp, *this);
1016 return success();
1017 }
1018
1019 std::string getName() const override { return "connect-forwarder"; }
1020};
1021
1022/// A sample reduction pattern that replaces a single-use wire and register with
1023/// an operand of the source value of the connection.
1024template <unsigned OpNum>
1025struct ConnectSourceOperandForwarder : public Reduction {
1026 uint64_t match(Operation *op) override {
1027 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1028 return 0;
1029 auto dest = op->getOperand(0);
1030 auto *destOp = dest.getDefiningOp();
1031
1032 // Ensure that the destination is used only once.
1033 if (!destOp || !destOp->hasOneUse() ||
1034 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1035 return 0;
1036
1037 auto *srcOp = op->getOperand(1).getDefiningOp();
1038 if (!srcOp || OpNum >= srcOp->getNumOperands())
1039 return 0;
1040
1041 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1042 auto opTy =
1043 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1044
1045 return resultTy && opTy &&
1046 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1047 ((resultTy.getBitWidthOrSentinel() == -1) ==
1048 (opTy.getBitWidthOrSentinel() == -1)) &&
1049 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1050 }
1051
1052 LogicalResult rewrite(Operation *op) override {
1053 auto *destOp = op->getOperand(0).getDefiningOp();
1054 auto *srcOp = op->getOperand(1).getDefiningOp();
1055 auto forwardedOperand = srcOp->getOperand(OpNum);
1056 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1057 Value newDest;
1058 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
1059 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1060 wire.getName())
1061 .getResult();
1062 else {
1063 auto regName = destOp->getAttrOfType<StringAttr>("name");
1064 // We can promote the register into a wire but we wouldn't do here because
1065 // the error might be caused by the register.
1066 auto clock = destOp->getOperand(0);
1067 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1068 clock, regName ? regName.str() : "")
1069 .getResult();
1070 }
1071
1072 // Create new connection between a new wire and the forwarded operand.
1073 builder.setInsertionPointAfter(op);
1074 if (isa<firrtl::ConnectOp>(op))
1075 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1076 else
1077 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1078
1079 // Remove the old connection and destination. We don't have to replace them
1080 // because destination has only one use.
1081 op->erase();
1082 destOp->erase();
1083 reduce::pruneUnusedOps(srcOp, *this);
1084
1085 return success();
1086 }
1087 std::string getName() const override {
1088 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1089 }
1090};
1091
1092/// A sample reduction pattern that tries to remove aggregate wires by replacing
1093/// all subaccesses with new independent wires. This can disentangle large
1094/// unused wires that are otherwise difficult to collect due to the subaccesses.
1095struct DetachSubaccesses : public Reduction {
1096 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1097 void afterReduction(mlir::ModuleOp op) override {
1098 for (auto *op : opsToErase)
1099 op->dropAllReferences();
1100 for (auto *op : opsToErase)
1101 op->erase();
1102 }
1103 uint64_t match(Operation *op) override {
1104 // Only applies to wires and registers that are purely used in subaccess
1105 // operations.
1106 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1107 llvm::all_of(op->getUses(), [](auto &use) {
1108 return use.getOperandNumber() == 0 &&
1109 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1110 firrtl::SubaccessOp>(use.getOwner());
1111 });
1112 }
1113 LogicalResult rewrite(Operation *op) override {
1114 assert(match(op));
1115 OpBuilder builder(op);
1116 bool isWire = isa<firrtl::WireOp>(op);
1117 Value invalidClock;
1118 if (!isWire)
1119 invalidClock = firrtl::InvalidValueOp::create(
1120 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1121 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1122 builder.setInsertionPoint(user);
1123 auto type = user->getResult(0).getType();
1124 Operation *replOp;
1125 if (isWire)
1126 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1127 else
1128 replOp =
1129 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1130 user->replaceAllUsesWith(replOp);
1131 opsToErase.insert(user);
1132 }
1133 opsToErase.insert(op);
1134 return success();
1135 }
1136 std::string getName() const override { return "detach-subaccesses"; }
1137 llvm::DenseSet<Operation *> opsToErase;
1138};
1139
1140/// This reduction removes inner symbols on ops. Name preservation creates a lot
1141/// of node ops with symbols to keep name information but it also prevents
1142/// normal canonicalizations.
1143struct NodeSymbolRemover : public Reduction {
1144 void beforeReduction(mlir::ModuleOp op) override {
1145 innerSymUses = reduce::InnerSymbolUses(op);
1146 }
1147
1148 uint64_t match(Operation *op) override {
1149 // Only match ops with an inner symbol.
1150 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1151 if (!sym || sym.empty())
1152 return 0;
1153
1154 // Only match ops that have no references to their inner symbol.
1155 if (innerSymUses.hasInnerRef(op))
1156 return 0;
1157 return 1;
1158 }
1159
1160 LogicalResult rewrite(Operation *op) override {
1161 op->removeAttr("inner_sym");
1162 return success();
1163 }
1164
1165 std::string getName() const override { return "node-symbol-remover"; }
1166 bool acceptSizeIncrease() const override { return true; }
1167
1168 reduce::InnerSymbolUses innerSymUses;
1169};
1170
1171/// Check if inlining the referenced operation into the parent operation would
1172/// cause inner symbol collisions.
1173static bool
1174hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1175 hw::InnerSymbolTableCollection &innerSymTables) {
1176 // Get the inner symbol tables for both operations
1177 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1178 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1179
1180 // Check if any inner symbol name in the target operation already exists
1181 // in the parent operation. Return failure() if a collision is found to stop
1182 // the walk early.
1183 LogicalResult walkResult = targetTable.walkSymbols(
1184 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1185 // Check if this symbol name exists in the parent operation
1186 if (parentTable.lookup(name)) {
1187 // Collision found, return failure to stop the walk
1188 return failure();
1189 }
1190 return success();
1191 });
1192
1193 // If the walk failed, it means we found a collision
1194 return failed(walkResult);
1195}
1196
1197/// A sample reduction pattern that eagerly inlines instances.
1198struct EagerInliner : public OpReduction<InstanceOp> {
1199 void beforeReduction(mlir::ModuleOp op) override {
1200 symbols.clear();
1201 nlaRemover.clear();
1202 nlaTables.clear();
1203 for (auto circuitOp : op.getOps<CircuitOp>())
1204 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1205 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1206 }
1207 void afterReduction(mlir::ModuleOp op) override {
1208 nlaRemover.remove(op);
1209 nlaTables.clear();
1210 innerSymTables.reset();
1211 }
1212
1213 uint64_t match(InstanceOp instOp) override {
1214 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1215 auto *moduleOp =
1216 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1217
1218 // Only inline FModuleOp instances
1219 if (!isa<FModuleOp>(moduleOp))
1220 return 0;
1221
1222 // Skip instances that participate in any NLAs
1223 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1224 if (!circuitOp)
1225 return 0;
1226 auto it = nlaTables.find(circuitOp);
1227 if (it == nlaTables.end() || !it->second)
1228 return 0;
1229 DenseSet<hw::HierPathOp> nlas;
1230 it->second->getInstanceNLAs(instOp, nlas);
1231 if (!nlas.empty())
1232 return 0;
1233
1234 // Check for inner symbol collisions between the referenced module and the
1235 // instance's parent module
1236 auto parentOp = instOp->getParentOfType<FModuleLike>();
1237 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1238 return 0;
1239
1240 return 1;
1241 }
1242
1243 LogicalResult rewrite(InstanceOp instOp) override {
1244 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1245 auto moduleOp = cast<FModuleOp>(
1246 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1247 bool isLastUse =
1248 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1249 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1250
1251 // Create wires to replace the instance results.
1252 IRRewriter rewriter(instOp);
1253 SmallVector<Value> argWires;
1254 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1255 auto result = instOp.getResult(i);
1256 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1257 instOp.getPortName(i));
1258 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1259 name, NameKindEnum::DroppableName,
1260 instOp.getPortAnnotation(i), StringAttr{})
1261 .getResult();
1262 result.replaceAllUsesWith(wire);
1263 argWires.push_back(wire);
1264 }
1265
1266 // Splice in the cloned module body.
1267 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1268
1269 // Make sure we remove any NLAs that go through this instance, and the
1270 // module if we're about the delete the module.
1271 nlaRemover.markNLAsInOperation(instOp);
1272 if (isLastUse)
1273 nlaRemover.markNLAsInOperation(moduleOp);
1274
1275 instOp.erase();
1276 clonedModuleOp.erase();
1277 return success();
1278 }
1279
1280 std::string getName() const override { return "firrtl-eager-inliner"; }
1281 bool acceptSizeIncrease() const override { return true; }
1282
1283 ::detail::SymbolCache symbols;
1284 NLARemover nlaRemover;
1285 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1286 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1287};
1288
1289/// A reduction pattern that eagerly inlines `ObjectOp`s.
1290struct ObjectInliner : public OpReduction<ObjectOp> {
1291 void beforeReduction(mlir::ModuleOp op) override {
1292 blocksToSort.clear();
1293 symbols.clear();
1294 nlaRemover.clear();
1295 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1296 }
1297 void afterReduction(mlir::ModuleOp op) override {
1298 for (auto *block : blocksToSort)
1299 mlir::sortTopologically(block);
1300 blocksToSort.clear();
1301 nlaRemover.remove(op);
1302 innerSymTables.reset();
1303 }
1304
1305 uint64_t match(ObjectOp objOp) override {
1306 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1307 auto *classOp =
1308 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1309
1310 // Only inline `ClassOp`s.
1311 if (!isa<ClassOp>(classOp))
1312 return 0;
1313
1314 // Check for inner symbol collisions between the referenced class and the
1315 // object's parent module.
1316 auto parentOp = objOp->getParentOfType<FModuleLike>();
1317 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1318 return 0;
1319
1320 // Verify all uses are ObjectSubfieldOp.
1321 for (auto *user : objOp.getResult().getUsers())
1322 if (!isa<ObjectSubfieldOp>(user))
1323 return 0;
1324
1325 return 1;
1326 }
1327
1328 LogicalResult rewrite(ObjectOp objOp) override {
1329 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1330 auto classOp = cast<ClassOp>(
1331 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1332 auto clonedClassOp = classOp.clone();
1333
1334 // Create wires to replace the ObjectSubfieldOp results.
1335 IRRewriter rewriter(objOp);
1336 SmallVector<Value> portWires;
1337 auto classType = objOp.getType();
1338
1339 // Create a wire for each port in the class
1340 for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1341 auto element = classType.getElement(i);
1342 auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
1343 element.name.getValue());
1344 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1345 NameKindEnum::DroppableName,
1346 rewriter.getArrayAttr({}), StringAttr{})
1347 .getResult();
1348 portWires.push_back(wire);
1349 }
1350
1351 // Replace all ObjectSubfieldOp uses with corresponding wires
1352 SmallVector<ObjectSubfieldOp> subfieldOps;
1353 for (auto *user : objOp.getResult().getUsers()) {
1354 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1355 subfieldOps.push_back(subfieldOp);
1356 auto index = subfieldOp.getIndex();
1357 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1358 }
1359
1360 // Splice in the cloned class body.
1361 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1362
1363 // After inlining the class body, we need to eliminate `WireOps` since
1364 // `ClassOps` cannot contain wires. For each port wire, find its single
1365 // connect, remove it, and replace all uses of the wire with the assigned
1366 // value.
1367 SmallVector<FConnectLike> connectsToErase;
1368 for (auto portWire : portWires) {
1369 // Find a single value to replace the wire with, and collect all connects
1370 // to the wire such that we can erase them later.
1371 Value value;
1372 for (auto *user : portWire.getUsers()) {
1373 if (auto connect = dyn_cast<FConnectLike>(user)) {
1374 if (connect.getDest() == portWire) {
1375 value = connect.getSrc();
1376 connectsToErase.push_back(connect);
1377 }
1378 }
1379 }
1380
1381 // Be very conservative about deleting these wires. Other reductions may
1382 // leave class ports unconnected, which means that there isn't always a
1383 // clean replacement available here. Better to just leave the wires in the
1384 // IR and let the verifier fail later.
1385 if (value)
1386 portWire.replaceAllUsesWith(value);
1387 for (auto connect : connectsToErase)
1388 connect.erase();
1389 if (portWire.use_empty())
1390 portWire.getDefiningOp()->erase();
1391 connectsToErase.clear();
1392 }
1393
1394 // Make sure we remove any NLAs that go through this object.
1395 nlaRemover.markNLAsInOperation(objOp);
1396
1397 // Since the above forwarding of SSA values through wires can create
1398 // dominance issues, mark the region containing the object to be sorted
1399 // topologically.
1400 blocksToSort.insert(objOp->getBlock());
1401
1402 // Erase the object and cloned class.
1403 for (auto subfieldOp : subfieldOps)
1404 subfieldOp.erase();
1405 objOp.erase();
1406 clonedClassOp.erase();
1407 return success();
1408 }
1409
1410 std::string getName() const override { return "firrtl-object-inliner"; }
1411 bool acceptSizeIncrease() const override { return true; }
1412
1413 SetVector<Block *> blocksToSort;
1414 ::detail::SymbolCache symbols;
1415 NLARemover nlaRemover;
1416 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1417};
1418
1419/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1420/// not an actual reduction, but often removes extraneous information that has
1421/// no bearing on the actual reduction (and would likely be removed before
1422/// sharing the reduction). This makes the following changes:
1423///
1424/// - All wires are renamed to "wire"
1425/// - All registers are renamed to "reg"
1426/// - All nodes are renamed to "node"
1427/// - All memories are renamed to "mem"
1428/// - All verification messages and labels are dropped
1429///
1430struct ModuleInternalNameSanitizer : public Reduction {
1431 uint64_t match(Operation *op) override {
1432 // Only match operations with names.
1433 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1434 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1435 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1436 firrtl::CoverOp>(op);
1437 }
1438 LogicalResult rewrite(Operation *op) override {
1439 TypeSwitch<Operation *, void>(op)
1440 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1441 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1442 [](auto op) { op.setName("reg"); })
1443 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1444 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1445 [](auto op) { op.setName("mem"); })
1446 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1447 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1448 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1449 });
1450 return success();
1451 }
1452
1453 std::string getName() const override {
1454 return "module-internal-name-sanitizer";
1455 }
1456
1457 bool acceptSizeIncrease() const override { return true; }
1458
1459 bool isOneShot() const override { return true; }
1460};
1461
1462/// Pseudo-reduction that sanitizes module, instance, and port names. This
1463/// makes the following changes:
1464///
1465/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1466/// - All instances are renamed to match the new module name
1467/// - All module ports are renamed in the following way:
1468/// - All clocks are renamed to "clk"
1469/// - All resets are renamed to "rst"
1470/// - All references are renamed to "ref"
1471/// - Anything else is named after a letter "a" through "z"
1472///
1473struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1474
// Pool of replacement module names, handed out in order by `getName()`.
1475 const char *names[48] = {
1476 "Foo", "Bar", "Baz", "Qux", "Quux", "Quuux", "Quuuux",
1477 "Quz", "Corge", "Grault", "Bazola", "Ztesch", "Thud", "Grunt",
1478 "Bletch", "Fum", "Fred", "Jim", "Sheila", "Barney", "Flarp",
1479 "Zxc", "Spqr", "Wombat", "Shme", "Bongo", "Spam", "Eggs",
1480 "Snork", "Zot", "Blarg", "Wibble", "Toto", "Titi", "Tata",
1481 "Tutu", "Pippo", "Pluto", "Paperino", "Aap", "Noot", "Mies",
1482 "Oogle", "Foogle", "Boogle", "Zork", "Gork", "Bork"};
1483
// Index of the next entry of `names` to hand out.
1484 size_t nameIndex = 0;
1485
// Returns the next name from `names`, wrapping around to the first entry
// once all 48 have been used. This non-const overload is distinct from the
// `getName() const override` further down, which returns the pattern name.
// NOTE(review): because of the wrap-around, circuits with more than 48
// modules receive repeated names -- confirm whether that is acceptable.
1486 const char *getName() {
1487 if (nameIndex >= 48)
1488 nameIndex = 0;
1489 return names[nameIndex++];
1490 };
1491
// Index of the next port letter to hand out.
1492 size_t portNameIndex = 0;
1493
// Returns the next letter in 'a'..'z', wrapping around after 'z'.
1494 char getPortName() {
1495 if (portNameIndex >= 26)
1496 portNameIndex = 0;
1497 return 'a' + portNameIndex++;
1498 }
1499
// Restart module naming for every reduction run. `portNameIndex` is reset
// per module inside `rewrite` instead.
1500 void beforeReduction(mlir::ModuleOp op) override { nameIndex = 0; }
1501
// Renames the circuit, all modules in its instance graph, the ports of every
// `FModuleOp`, and all instances that refer to the renamed modules.
1502 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1503
1504 firrtl::InstanceGraph iGraph(circuitOp);
1505
// The top-level module and the circuit share the first name.
1506 auto *circuitName = getName();
1507 iGraph.getTopLevelModule().setName(circuitName);
1508 circuitOp.setName(circuitName);
1509
1510 for (auto *node : iGraph) {
1511 auto module = node->getModule<firrtl::FModuleLike>();
1512
1513 bool shouldReplacePorts = false;
1514 SmallVector<Attribute> newNames;
// Port names are only rewritten for `FModuleOp`s; other module-like ops keep
// their port names.
1515 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1516 portNameIndex = 0;
1517 // TODO: The namespace should be unnecessary. However, some FIRRTL
1518 // passes expect that port names are unique.
// NOTE(review): this listing skips source line 1519; `ns`, used below, is
// presumably declared there -- confirm against the original file.
1520 auto oldPorts = fmodule.getPorts();
1521 shouldReplacePorts = !oldPorts.empty();
1522 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1523 auto port = oldPorts[i];
// Pick the base name from the port type and pass it through `ns.newName`.
1524 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1525 .Case<firrtl::ClockType>(
1526 [&](auto a) { return ns.newName("clk"); })
1527 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1528 [&](auto a) { return ns.newName("rst"); })
1529 .Case<firrtl::RefType>(
1530 [&](auto a) { return ns.newName("ref"); })
1531 .Default([&](auto a) {
1532 return ns.newName(Twine(getPortName()));
1533 });
1534 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1535 }
1536 fmodule->setAttr("portNames",
1537 ArrayAttr::get(fmodule.getContext(), newNames));
1538 }
1539
// The top-level module was already renamed together with the circuit.
1540 if (module == iGraph.getTopLevelModule())
1541 continue;
1542 auto newName = StringAttr::get(circuitOp.getContext(), getName());
1543 module.setName(newName);
// Keep every instance of the renamed module in sync with the new module
// name and, where ports were renamed, the new port names.
1544 for (auto *use : node->uses()) {
// NOTE(review): the result of this dyn_cast is used without a null check;
// a use that is not a `firrtl::InstanceOp` would be dereferenced as null.
1545 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1546 instanceOp.setModuleName(newName);
1547 instanceOp.setName(newName);
1548 if (shouldReplacePorts)
1549 instanceOp.setPortNamesAttr(
1550 ArrayAttr::get(circuitOp.getContext(), newNames));
1551 }
1552 }
1553
// NOTE(review): this unconditionally dumps the whole circuit on every
// rewrite; it looks like leftover debugging output -- confirm and remove.
1554 circuitOp->dump();
1555
1556 return success();
1557 }
1558
1559 std::string getName() const override { return "module-name-sanitizer"; }
1560
1561 bool acceptSizeIncrease() const override { return true; }
1562
1563 bool isOneShot() const override { return true; }
1564};
1565
1566/// A reduction pattern that groups modules by their port signature (types and
1567/// directions) and replaces instances with the smallest module in each group.
1568/// This helps reduce the IR by consolidating functionally equivalent modules
1569/// based on their interface.
1570///
1571/// The pattern works by:
1572/// 1. Grouping all modules by their port signature (port types and directions)
1573/// 2. For each group with multiple modules, finding the smallest module using
1574/// the module size cache
1575/// 3. Replacing all instances of larger modules with instances of the smallest
1576/// module in the same group
1577/// 4. Removing the larger modules from the circuit
1578///
1579/// This reduction is useful for reducing circuits where multiple modules have
1580/// the same interface but different implementations, allowing the reducer to
1581/// try the smallest implementation first.
1582struct ModuleSwapper : public OpReduction<InstanceOp> {
1583 // Per-circuit state containing all the information needed for module swapping
1584 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1585 struct CircuitState {
1586 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1587 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1588 std::unique_ptr<NLATable> nlaTable;
1589 };
1590
1591 void beforeReduction(mlir::ModuleOp op) override {
1592 symbols.clear();
1593 nlaRemover.clear();
1594 moduleSizes.clear();
1595 circuitStates.clear();
1596
1597 // Collect module type groups and NLA tables for all circuits up front
1598 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1599 auto &state = circuitStates[circuitOp];
1600 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1601 buildModuleTypeGroups(circuitOp, state);
1602 return WalkResult::skip();
1603 });
1604 }
1605 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1606
1607 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1608 /// This ignores port names, allowing modules with the same port types and
1609 /// directions but different port names to be considered equivalent for
1610 /// swapping.
1611 PortSignature getModulePortSignature(FModuleLike module) {
1612 PortSignature signature;
1613 signature.reserve(module.getNumPorts());
1614 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1615 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1616 return signature;
1617 }
1618
1619 /// Group modules by their port signature and find the smallest in each group.
1620 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1621 // Group modules by their port signature
1622 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1623 auto signature = getModulePortSignature(module);
1624 state.moduleTypeGroups[signature].push_back(module);
1625 }
1626
1627 // For each group, find the smallest module
1628 for (auto &[signature, modules] : state.moduleTypeGroups) {
1629 if (modules.size() <= 1)
1630 continue;
1631
1632 FModuleLike smallestModule = nullptr;
1633 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1634
1635 for (auto module : modules) {
1636 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1637 if (size < smallestSize) {
1638 smallestSize = size;
1639 smallestModule = module;
1640 }
1641 }
1642
1643 // Map all modules in this group to the smallest one
1644 for (auto module : modules) {
1645 if (module != smallestModule) {
1646 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1647 smallestModule;
1648 }
1649 }
1650 }
1651 }
1652
1653 uint64_t match(InstanceOp instOp) override {
1654 // Get the circuit this instance belongs to
1655 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1656 assert(circuitOp);
1657 const auto &state = circuitStates.at(circuitOp);
1658
1659 // Skip instances that participate in any NLAs
1660 DenseSet<hw::HierPathOp> nlas;
1661 state.nlaTable->getInstanceNLAs(instOp, nlas);
1662 if (!nlas.empty())
1663 return 0;
1664
1665 // Check if this instance can be redirected to a smaller module
1666 auto moduleName = instOp.getModuleNameAttr().getAttr();
1667 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1668 if (!canonicalModule)
1669 return 0;
1670
1671 // Benefit is the size difference
1672 auto currentModule = cast<FModuleLike>(
1673 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1674 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1675 uint64_t canonicalSize =
1676 moduleSizes.getModuleSize(canonicalModule, symbols);
1677 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1678 }
1679
1680 LogicalResult rewrite(InstanceOp instOp) override {
1681 // Get the circuit this instance belongs to
1682 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1683 assert(circuitOp);
1684 const auto &state = circuitStates.at(circuitOp);
1685
1686 // Replace the instantiated module with the canonical module.
1687 auto canonicalModule = state.instanceToCanonicalModule.at(
1688 instOp.getModuleNameAttr().getAttr());
1689 auto canonicalName = canonicalModule.getModuleNameAttr();
1690 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1691
1692 // Update port names to match the canonical module
1693 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1694
1695 return success();
1696 }
1697
1698 std::string getName() const override { return "firrtl-module-swapper"; }
1699 bool acceptSizeIncrease() const override { return true; }
1700
1701private:
1702 ::detail::SymbolCache symbols;
1703 NLARemover nlaRemover;
1704 ModuleSizeCache moduleSizes;
1705
1706 // Per-circuit state containing all module swapping information
1707 DenseMap<CircuitOp, CircuitState> circuitStates;
1708};
1709
1710/// A reduction pattern that handles MustDedup annotations by replacing all
1711/// module names in a dedup group with a single module name. This helps reduce
1712/// the IR by consolidating module references that are required to be identical.
1713///
1714/// The pattern works by:
1715/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1716/// 2. For each dedup group, using the first module as the canonical name
1717/// 3. Replacing all instance references to other modules in the group with
1718/// references to the canonical module
1719/// 4. Removing the non-canonical modules from the circuit
1720/// 5. Removing the processed MustDedup annotation
1721///
1722/// This reduction is particularly useful for reducing large circuits where
1723/// multiple modules are known to be identical but haven't been deduplicated
1724/// yet.
1725struct ForceDedup : public OpReduction<CircuitOp> {
1726 void beforeReduction(mlir::ModuleOp op) override {
1727 symbols.clear();
1728 nlaRemover.clear();
1729 modulesToErase.clear();
1730 moduleSizes.clear();
1731 }
1732 void afterReduction(mlir::ModuleOp op) override {
1733 nlaRemover.remove(op);
1734 for (auto mod : modulesToErase)
1735 mod->erase();
1736 }
1737
1738 /// Collect all MustDedup annotations and create matches for each dedup group.
1739 void matches(CircuitOp circuitOp,
1740 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1741 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
1742 auto annotations = AnnotationSet(circuitOp);
1743 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1744 if (!anno.isClass(mustDedupAnnoClass))
1745 continue;
1746
1747 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1748 if (!modulesAttr || modulesAttr.size() < 2)
1749 continue;
1750
1751 // Check that all modules have the same port signature. Malformed inputs
1752 // may have modules listed in a MustDedup annotation that have distinct
1753 // port types.
1754 uint64_t totalSize = 0;
1755 ArrayAttr portTypes;
1756 DenseBoolArrayAttr portDirections;
1757 bool allSame = true;
1758 for (auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
1759 auto target = tokenizePath(moduleName);
1760 if (!target) {
1761 allSame = false;
1762 break;
1763 }
1764 auto mod = symbolTable.lookup<FModuleLike>(target->module);
1765 if (!mod) {
1766 allSame = false;
1767 break;
1768 }
1769 totalSize += moduleSizes.getModuleSize(mod, symbols);
1770 if (!portTypes) {
1771 portTypes = mod.getPortTypesAttr();
1772 portDirections = mod.getPortDirectionsAttr();
1773 } else if (portTypes != mod.getPortTypesAttr() ||
1774 portDirections != mod.getPortDirectionsAttr()) {
1775 allSame = false;
1776 break;
1777 }
1778 }
1779 if (!allSame)
1780 continue;
1781
1782 // Each dedup group gets its own match with benefit proportional to group
1783 // size.
1784 addMatch(totalSize, annoIdx);
1785 }
1786 }
1787
1788 LogicalResult rewriteMatches(CircuitOp circuitOp,
1789 ArrayRef<uint64_t> matches) override {
1790 auto *context = circuitOp->getContext();
1791 NLATable nlaTable(circuitOp);
1792 hw::InnerSymbolTableCollection innerSymTables;
1793 auto annotations = AnnotationSet(circuitOp);
1794 SmallVector<Annotation> newAnnotations;
1795
1796 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1797 // Check if this annotation was selected.
1798 if (!llvm::is_contained(matches, annoIdx)) {
1799 newAnnotations.push_back(anno);
1800 continue;
1801 }
1802 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1803 assert(anno.isClass(mustDedupAnnoClass) && modulesAttr &&
1804 modulesAttr.size() >= 2);
1805
1806 // Extract module names from the dedup group.
1807 SmallVector<StringAttr> moduleNames;
1808 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1809 // Parse "~CircuitName|ModuleName" format.
1810 auto refStr = moduleRef.getValue();
1811 auto pipePos = refStr.find('|');
1812 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1813 auto moduleName = refStr.substr(pipePos + 1);
1814 moduleNames.push_back(StringAttr::get(context, moduleName));
1815 }
1816 }
1817
1818 // Simply drop the annotation if there's only one module.
1819 if (moduleNames.size() < 2)
1820 continue;
1821
1822 // Replace all instances and references to other modules with the
1823 // first module.
1824 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1825 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1826 }
1827 if (newAnnotations.size() == annotations.size())
1828 return failure();
1829
1830 // Update circuit annotations.
1831 AnnotationSet newAnnoSet(newAnnotations, context);
1832 newAnnoSet.applyToOperation(circuitOp);
1833 return success();
1834 }
1835
1836 std::string getName() const override { return "firrtl-force-dedup"; }
1837 bool acceptSizeIncrease() const override { return true; }
1838
1839private:
1840 /// Replace all references to modules in the dedup group with the canonical
1841 /// module name
1842 void replaceModuleReferences(CircuitOp circuitOp,
1843 ArrayRef<StringAttr> moduleNames,
1844 NLATable &nlaTable,
1845 hw::InnerSymbolTableCollection &innerSymTables) {
1846 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1847 auto &symbolTable = symbols.getSymbolTable(tableOp);
1848 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
1849 auto *context = circuitOp->getContext();
1850 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
1851
1852 // Collect the modules.
1853 FModuleLike canonicalModule;
1854 SmallVector<FModuleLike> modulesToReplace;
1855 for (auto name : moduleNames) {
1856 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
1857 if (!canonicalModule)
1858 canonicalModule = mod;
1859 else
1860 modulesToReplace.push_back(mod);
1861 }
1862 }
1863 if (modulesToReplace.empty())
1864 return;
1865
1866 // Replace all instance references.
1867 auto canonicalName = canonicalModule.getModuleNameAttr();
1868 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1869 for (auto moduleName : moduleNames) {
1870 if (moduleName == canonicalName)
1871 continue;
1872 auto *symbolOp = symbolTable.lookup(moduleName);
1873 if (!symbolOp)
1874 continue;
1875 for (auto *user : symbolUserMap.getUsers(symbolOp)) {
1876 auto instOp = dyn_cast<InstanceOp>(user);
1877 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
1878 continue;
1879 instOp.setModuleNameAttr(canonicalRef);
1880 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1881 }
1882 }
1883
1884 // Update NLAs to reference the canonical module instead of modules being
1885 // removed using NLATable for better performance.
1886 for (auto oldMod : modulesToReplace) {
1887 SmallVector<hw::HierPathOp> nlaOps(
1888 nlaTable.lookup(oldMod.getModuleNameAttr()));
1889 for (auto nlaOp : nlaOps) {
1890 nlaTable.erase(nlaOp);
1891 StringAttr oldModName = oldMod.getModuleNameAttr();
1892 StringAttr newModName = canonicalName;
1893 SmallVector<Attribute, 4> newPath;
1894 for (auto nameRef : nlaOp.getNamepath()) {
1895 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1896 if (ref.getModule() == oldModName) {
1897 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1898 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1899 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1900 if (oldInst && newInst) {
1901 oldModName = oldInst.getReferencedModuleNameAttr();
1902 newModName = newInst.getReferencedModuleNameAttr();
1903 }
1904 }
1905 newPath.push_back(ref);
1906 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1907 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1908 } else {
1909 newPath.push_back(nameRef);
1910 }
1911 }
1912 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1913 nlaTable.addNLA(nlaOp);
1914 }
1915 }
1916
1917 // Mark NLAs in modules to be removed.
1918 for (auto module : modulesToReplace) {
1919 nlaRemover.markNLAsInOperation(module);
1920 modulesToErase.insert(module);
1921 }
1922 }
1923
1924 ::detail::SymbolCache symbols;
1925 NLARemover nlaRemover;
1926 SetVector<FModuleLike> modulesToErase;
1927 ModuleSizeCache moduleSizes;
1928};
1929
1930/// A reduction pattern that moves `MustDedup` annotations from a module onto
1931/// its child modules. This pattern iterates over all MustDedup annotations,
1932/// collects all `FInstanceLike` ops in each module of the dedup group, and
1933/// creates new MustDedup annotations for corresponding instances across the
1934/// modules. Each set of corresponding instances becomes a separate match of the
1935/// reduction. The reduction also removes the original MustDedup annotation on
1936/// the parent module.
1937///
1938/// The pattern works by:
1939/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1940/// 2. For each dedup group, collecting all FInstanceLike operations in each
1941/// module
1942/// 3. Grouping corresponding instances across modules by their position/name
1943/// 4. Creating new MustDedup annotations for each group of corresponding
1944/// instances
1945/// 5. Removing the original MustDedup annotation from the circuit
1946struct MustDedupChildren : public OpReduction<CircuitOp> {
1947 void beforeReduction(mlir::ModuleOp op) override {
1948 symbols.clear();
1949 nlaRemover.clear();
1950 }
1951 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1952
1953 /// Collect all MustDedup annotations and create matches for each instance
1954 /// group.
1955 void matches(CircuitOp circuitOp,
1956 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1957 auto annotations = AnnotationSet(circuitOp);
1958 uint64_t matchId = 0;
1959
1960 DenseSet<StringRef> modulesAlreadyInMustDedup;
1961 for (auto [annoIdx, anno] : llvm::enumerate(annotations))
1962 if (anno.isClass(mustDedupAnnoClass))
1963 if (auto modulesAttr = anno.getMember<ArrayAttr>("modules"))
1964 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1965 if (auto target = tokenizePath(moduleRef))
1966 modulesAlreadyInMustDedup.insert(target->module);
1967
1968 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1969 if (!anno.isClass(mustDedupAnnoClass))
1970 continue;
1971
1972 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1973 if (!modulesAttr || modulesAttr.size() < 2)
1974 continue;
1975
1976 // Process each group of corresponding instances
1977 processInstanceGroups(
1978 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1979 matchId++;
1980
1981 // Make sure there are at least two distinct modules.
1982 SmallDenseSet<StringAttr, 4> moduleTargets;
1983 for (auto instOp : instanceGroup)
1984 moduleTargets.insert(instOp.getReferencedModuleNameAttr());
1985 if (moduleTargets.size() < 2)
1986 return;
1987
1988 // Make sure none of the modules are not yet in a must dedup
1989 // annotation.
1990 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
1991 return modulesAlreadyInMustDedup.contains(
1992 inst.getReferencedModuleName());
1993 }))
1994 return;
1995
1996 addMatch(1, matchId - 1);
1997 });
1998 }
1999 }
2000
2001 LogicalResult rewriteMatches(CircuitOp circuitOp,
2002 ArrayRef<uint64_t> matches) override {
2003 auto *context = circuitOp->getContext();
2004 auto annotations = AnnotationSet(circuitOp);
2005 SmallVector<Annotation> newAnnotations;
2006 uint64_t matchId = 0;
2007
2008 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2009 if (!anno.isClass(mustDedupAnnoClass)) {
2010 newAnnotations.push_back(anno);
2011 continue;
2012 }
2013
2014 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2015 if (!modulesAttr || modulesAttr.size() < 2) {
2016 newAnnotations.push_back(anno);
2017 continue;
2018 }
2019
2020 processInstanceGroups(
2021 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2022 // Check if this instance group was selected
2023 if (!llvm::is_contained(matches, matchId++))
2024 return;
2025
2026 // Create the list of modules to put into this new annotation.
2027 SmallSetVector<StringAttr, 4> moduleTargets;
2028 for (auto instOp : instanceGroup) {
2029 auto target = TokenAnnoTarget();
2030 target.circuit = circuitOp.getName();
2031 target.module = instOp.getReferencedModuleName();
2032 moduleTargets.insert(target.toStringAttr(context));
2033 }
2034
2035 // Create a new MustDedup annotation for this list of modules.
2036 SmallVector<NamedAttribute> newAnnoAttrs;
2037 newAnnoAttrs.emplace_back(
2038 StringAttr::get(context, "class"),
2039 StringAttr::get(context, mustDedupAnnoClass));
2040 newAnnoAttrs.emplace_back(
2041 StringAttr::get(context, "modules"),
2042 ArrayAttr::get(context,
2043 SmallVector<Attribute>(moduleTargets.begin(),
2044 moduleTargets.end())));
2045
2046 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2047 newAnnotations.emplace_back(newAnnoDict);
2048 });
2049
2050 // Keep the original annotation around.
2051 newAnnotations.push_back(anno);
2052 }
2053
2054 // Update circuit annotations
2055 AnnotationSet newAnnoSet(newAnnotations, context);
2056 newAnnoSet.applyToOperation(circuitOp);
2057 return success();
2058 }
2059
2060 std::string getName() const override { return "must-dedup-children"; }
2061 bool acceptSizeIncrease() const override { return true; }
2062
2063private:
2064 /// Helper function to process groups of corresponding instances from a
2065 /// MustDedup annotation. Calls the provided lambda for each group of
2066 /// corresponding instances across the modules. Only calls the lambda if there
2067 /// are at least 2 modules.
2068 void processInstanceGroups(
2069 CircuitOp circuitOp, ArrayAttr modulesAttr,
2070 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
2071 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2072
2073 // Extract module names and get the actual modules
2074 SmallVector<FModuleLike> modules;
2075 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2076 if (auto target = tokenizePath(moduleRef))
2077 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2078 modules.push_back(mod);
2079
2080 // Need at least 2 modules for deduplication
2081 if (modules.size() < 2)
2082 return;
2083
2084 // Collect all FInstanceLike operations from each module and group them by
2085 // name. Instance names are a good key for matching instances across
2086 // modules. But they may not be unique, so we need to be careful to only
2087 // match up instances that are uniquely named within every module.
2088 struct InstanceGroup {
2089 SmallVector<FInstanceLike> instances;
2090 bool nameIsUnique = true;
2091 };
2092 MapVector<StringAttr, InstanceGroup> instanceGroups;
2093 for (auto module : modules) {
2095 module.walk([&](FInstanceLike instOp) {
2096 if (isa<ObjectOp>(instOp.getOperation()))
2097 return;
2098 auto name = instOp.getInstanceNameAttr();
2099 auto &group = instanceGroups[name];
2100 if (nameCounts[name]++ > 1)
2101 group.nameIsUnique = false;
2102 group.instances.push_back(instOp);
2103 });
2104 }
2105
2106 // Call the callback for each group of instances that are uniquely named and
2107 // consist of at least 2 instances.
2108 for (auto &[name, group] : instanceGroups)
2109 if (group.nameIsUnique && group.instances.size() >= 2)
2110 callback(group.instances);
2111 }
2112
2113 ::detail::SymbolCache symbols;
2114 NLARemover nlaRemover;
2115};
2116
2117} // namespace
2118
2119//===----------------------------------------------------------------------===//
2120// Reduction Registration
2121//===----------------------------------------------------------------------===//
2122
2125 // Gather a list of reduction patterns that we should try. Ideally these are
2126 // assigned reasonable benefit indicators (higher benefit patterns are
2127 // prioritized). For example, things that can knock out entire modules while
2128 // being cheap should be tried first (and thus have higher benefit), before
2129 // trying to tweak operands of individual arithmetic ops.
2130 patterns.add<SimplifyResets, 34>();
2131 patterns.add<ForceDedup, 33>();
2132 patterns.add<MustDedupChildren, 32>();
2133 patterns.add<AnnotationRemover, 31>();
2134 patterns.add<ModuleSwapper, 30>();
2135 patterns.add<PassReduction, 29>(
2136 getContext(),
2137 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
2138 true);
2139 patterns.add<PassReduction, 28>(getContext(),
2140 firrtl::createLowerCHIRRTLPass(), true, true);
2141 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
2142 true, true);
2143 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
2144 true, true);
2145 patterns.add<FIRRTLModuleExternalizer, 25>();
2146 patterns.add<InstanceStubber, 24>();
2147 patterns.add<MemoryStubber, 23>();
2148 patterns.add<EagerInliner, 22>();
2149 patterns.add<ObjectInliner, 22>();
2150 patterns.add<PassReduction, 21>(getContext(),
2151 firrtl::createLowerFIRRTLTypes(), true, true);
2152 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
2153 true, true);
2154 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
2155 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
2156 patterns.add<PassReduction, 17>(
2157 getContext(),
2158 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
2159 patterns.add<NodeSymbolRemover, 15>();
2160 patterns.add<ConnectForwarder, 14>();
2161 patterns.add<ConnectInvalidator, 13>();
2162 patterns.add<Constantifier, 12>();
2163 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2164 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2165 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2166 patterns.add<DetachSubaccesses, 7>();
2167 patterns.add<RootPortPruner, 5>();
2168 patterns.add<ExtmoduleInstanceRemover, 4>();
2169 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2170 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2171 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
2172 patterns.add<ModuleInternalNameSanitizer, 0>();
2173 patterns.add<ModuleNameSanitizer, 0>();
2174}
2175
2177 mlir::DialectRegistry &registry) {
2178 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
2179 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
2180 });
2181}
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
constexpr const char * excludeFromFullResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
constexpr const char * fullResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
constexpr const char * fullAsyncResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:42
constexpr const char * ignoreFullAsyncResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker for NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)