CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/Analysis/TopologicalSortUtils.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/Support/Debug.h"
32
33#define DEBUG_TYPE "firrtl-reductions"
34
35using namespace mlir;
36using namespace circt;
37using namespace firrtl;
38using llvm::MapVector;
39using llvm::SmallDenseSet;
40using llvm::SmallSetVector;
41
42//===----------------------------------------------------------------------===//
43// Utilities
44//===----------------------------------------------------------------------===//
45
46namespace detail {
47/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
48/// which is handy for reductions that need to look up a lot of symbols.
50 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
51
52 SymbolTable &getSymbolTable(Operation *op) {
53 return tables->getSymbolTable(op);
54 }
55 SymbolTable &getNearestSymbolTable(Operation *op) {
56 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
57 }
58
59 SymbolUserMap &getSymbolUserMap(Operation *op) {
60 auto it = userMaps.find(op);
61 if (it != userMaps.end())
62 return it->second;
63 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
64 }
65 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
66 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
67 }
68
69 void clear() {
70 tables = std::make_unique<SymbolTableCollection>();
71 userMaps.clear();
72 }
73
74private:
75 std::unique_ptr<SymbolTableCollection> tables;
77};
78} // namespace detail
79
80/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
81/// optional in case another type of module is instantiated.
82static std::optional<firrtl::FModuleOp>
83findInstantiatedModule(firrtl::InstanceOp instOp,
84 ::detail::SymbolCache &symbols) {
85 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
86 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
87 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
88 return moduleOp ? std::optional(moduleOp) : std::nullopt;
89}
90
91/// Utility to track the transitive size of modules.
93 void clear() { moduleSizes.clear(); }
94
95 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
96 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
97 return it->second;
98 uint64_t size = 1;
99 module->walk([&](Operation *op) {
100 size += 1;
101 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
102 if (auto instModule = findInstantiatedModule(instOp, symbols))
103 size += getModuleSize(*instModule, symbols);
104 });
105 moduleSizes.insert({module, size});
106 return size;
107 }
108
109private:
110 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
111};
112
113/// Check that all connections to a value are invalids.
114static bool onlyInvalidated(Value arg) {
115 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
116 auto *op = use.getOwner();
117 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
118 return false;
119 if (use.getOperandNumber() != 0)
120 return false;
121 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
122 return false;
123 return true;
124 });
125}
126
/// A tracker for tracking NLAs affected by a reduction. Performs the necessary
/// cleanup steps in order to maintain IR validity after the reduction has
/// been applied. For example, removing an instance that forms part of an NLA
/// path requires that NLA to be removed as well.
132 /// Clear the set of marked NLAs. Call this before attempting a reduction.
133 void clear() { nlasToRemove.clear(); }
134
135 /// Remove all marked annotations. Call this after applying a reduction in
136 /// order to validate the IR.
137 void remove(mlir::ModuleOp module) {
138 unsigned numRemoved = 0;
139 (void)numRemoved;
140 SymbolTableCollection symbolTables;
141 for (Operation &rootOp : *module.getBody()) {
142 if (!isa<firrtl::CircuitOp>(&rootOp))
143 continue;
144 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
145 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
146 for (auto sym : nlasToRemove) {
147 if (auto *op = symbolTable.lookup(sym)) {
148 if (symbolUserMap.useEmpty(op)) {
149 ++numRemoved;
150 op->erase();
151 }
152 }
153 }
154 }
155 LLVM_DEBUG({
156 unsigned numLost = nlasToRemove.size() - numRemoved;
157 if (numRemoved > 0 || numLost > 0) {
158 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
159 if (numLost > 0)
160 llvm::dbgs() << " (" << numLost << " no longer there)";
161 llvm::dbgs() << "\n";
162 }
163 });
164 }
165
166 /// Mark all NLAs referenced in the given annotation as to be removed. This
167 /// can be an entire array or dictionary of annotations, and the function will
168 /// descend into child annotations appropriately.
169 void markNLAsInAnnotation(Attribute anno) {
170 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
171 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
172 nlasToRemove.insert(field.getAttr());
173 for (auto namedAttr : dict)
174 markNLAsInAnnotation(namedAttr.getValue());
175 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
176 for (auto attr : array)
177 markNLAsInAnnotation(attr);
178 }
179 }
180
181 /// Mark all NLAs referenced in an operation. Also traverses all nested
182 /// operations. Call this before removing an operation, to mark any associated
183 /// NLAs as to be removed as well.
184 void markNLAsInOperation(Operation *op) {
185 op->walk([&](Operation *op) {
186 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
187 markNLAsInAnnotation(annos);
188 });
189 }
190
191 /// The set of NLAs to remove, identified by their symbol.
192 llvm::DenseSet<StringAttr> nlasToRemove;
193};
194
195//===----------------------------------------------------------------------===//
196// Reduction patterns
197//===----------------------------------------------------------------------===//
198
199namespace {
200
201/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
202struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
203 void beforeReduction(mlir::ModuleOp op) override {
204 nlaRemover.clear();
205 symbols.clear();
206 moduleSizes.clear();
207 innerSymUses = reduce::InnerSymbolUses(op);
208 }
209 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
210
211 uint64_t match(FModuleOp module) override {
212 if (innerSymUses.hasInnerRef(module))
213 return 0;
214 return moduleSizes.getModuleSize(module, symbols);
215 }
216
217 LogicalResult rewrite(FModuleOp module) override {
218 // Hack up a list of known layers.
219 LayerSet layers;
220 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
221 for (auto attr : module.getPortTypes()) {
222 auto type = cast<TypeAttr>(attr).getValue();
223 if (auto refType = type_dyn_cast<RefType>(type))
224 if (auto layer = refType.getLayer())
225 layers.insert(layer);
226 }
227 SmallVector<Attribute, 4> layersArray;
228 layersArray.reserve(layers.size());
229 for (auto layer : layers)
230 layersArray.push_back(layer);
231
232 nlaRemover.markNLAsInOperation(module);
233 OpBuilder builder(module);
234 auto extmodule = FExtModuleOp::create(
235 builder, module->getLoc(),
236 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
237 module.getConventionAttr(), module.getPorts(),
238 builder.getArrayAttr(layersArray), StringRef(),
239 module.getAnnotationsAttr());
240 SymbolTable::setSymbolVisibility(extmodule,
241 SymbolTable::getSymbolVisibility(module));
242 module->erase();
243 return success();
244 }
245
246 std::string getName() const override { return "firrtl-module-externalizer"; }
247
248 ::detail::SymbolCache symbols;
249 NLARemover nlaRemover;
250 reduce::InnerSymbolUses innerSymUses;
251 ModuleSizeCache moduleSizes;
252};
253
254/// Invalidate all the leaf fields of a value with a given flippedness by
255/// connecting an invalid value to them. This function handles different FIRRTL
256/// types appropriately:
257/// - Ref types (probes): Creates wire infrastructure with ref.send/ref.define
258/// and invalidates the underlying wire.
259/// - Bundle/Vector types: Recursively descends into elements.
260/// - Base types: Creates InvalidValueOp and connects it.
261/// - Property types: Creates UnknownValueOp and assigns it.
262///
263/// This is useful for ensuring that all output ports of an instance or memory
264/// (including those nested in bundles) are properly invalidated.
265static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
266 TieOffCache &tieOffCache, bool flip = false) {
267 auto type = type_dyn_cast<FIRRTLType>(value.getType());
268 if (!type)
269 return;
270
271 // Handle ref types (probes) by creating wires and defining them properly.
272 if (auto refType = type_dyn_cast<RefType>(type)) {
273 // Input probes are illegal in FIRRTL.
274 assert(!flip && "input probes are not allowed");
275
276 auto underlyingType = refType.getType();
277
278 if (!refType.getForceable()) {
279 // For probe types: create underlying wire, ref.send, ref.define, and
280 // invalidate.
281 auto targetWire = WireOp::create(builder, underlyingType);
282 auto refSend = builder.create<RefSendOp>(targetWire.getResult());
283 builder.create<RefDefineOp>(value, refSend.getResult());
284
285 // Invalidate the underlying wire.
286 auto invalid = tieOffCache.getInvalid(underlyingType);
287 MatchingConnectOp::create(builder, targetWire.getResult(), invalid);
288 return;
289 }
290
291 // For rwprobe types: create forceable wire, ref.define, and invalidate.
292 auto forceableWire =
293 WireOp::create(builder, underlyingType,
294 /*name=*/"", NameKindEnum::DroppableName,
295 /*annotations=*/ArrayRef<Attribute>{},
296 /*innerSym=*/StringAttr{},
297 /*forceable=*/true);
298
299 // The forceable wire returns both the wire and the rwprobe.
300 auto targetWire = forceableWire.getResult();
301 auto forceableRef = forceableWire.getDataRef();
302
303 builder.create<RefDefineOp>(value, forceableRef);
304
305 // Invalidate the underlying wire.
306 auto invalid = tieOffCache.getInvalid(underlyingType);
307 MatchingConnectOp::create(builder, targetWire, invalid);
308 return;
309 }
310
311 // Descend into bundles by creating subfield ops.
312 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
313 for (auto element : llvm::enumerate(bundleType.getElements())) {
314 auto subfield = builder.createOrFold<SubfieldOp>(value, element.index());
315 invalidateOutputs(builder, subfield, tieOffCache,
316 flip ^ element.value().isFlip);
317 if (subfield.use_empty())
318 subfield.getDefiningOp()->erase();
319 }
320 return;
321 }
322
323 // Descend into vectors by creating subindex ops.
324 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
325 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
326 auto subindex = builder.createOrFold<SubindexOp>(value, i);
327 invalidateOutputs(builder, subindex, tieOffCache, flip);
328 if (subindex.use_empty())
329 subindex.getDefiningOp()->erase();
330 }
331 return;
332 }
333
334 // Only drive outputs.
335 if (flip)
336 return;
337
338 // Create InvalidValueOp for FIRRTLBaseType.
339 if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
340 auto invalid = tieOffCache.getInvalid(baseType);
341 ConnectOp::create(builder, value, invalid);
342 return;
343 }
344
345 // For property types, use UnknownValueOp to tie off the connection.
346 if (auto propType = type_dyn_cast<PropertyType>(type)) {
347 auto unknown = tieOffCache.getUnknown(propType);
348 builder.create<PropAssignOp>(value, unknown);
349 }
350}
351
352/// Connect a value to every leave of a destination value.
353static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
354 Value value) {
355 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
356 if (!type)
357 return;
358 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
359 for (auto element : llvm::enumerate(bundleType.getElements()))
360 connectToLeafs(builder,
361 firrtl::SubfieldOp::create(builder, dest, element.index()),
362 value);
363 return;
364 }
365 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
366 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
367 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
368 value);
369 return;
370 }
371 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
372 if (!valueType)
373 return;
374 auto destWidth = type.getBitWidthOrSentinel();
375 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
376 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
377 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
378 if (!isa<firrtl::UIntType>(type)) {
379 if (isa<firrtl::SIntType>(type))
380 value = firrtl::AsSIntPrimOp::create(builder, value);
381 else
382 return;
383 }
384 firrtl::ConnectOp::create(builder, dest, value);
385}
386
387/// Reduce all leaf fields of a value through an XOR tree.
388static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
389 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
390 if (!type)
391 return;
392 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
393 for (auto element : llvm::enumerate(bundleType.getElements()))
394 reduceXor(
395 builder, into,
396 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
397 return;
398 }
399 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
400 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
401 reduceXor(builder, into,
402 builder.createOrFold<firrtl::SubindexOp>(value, i));
403 return;
404 }
405 if (!isa<firrtl::UIntType>(type)) {
406 if (isa<firrtl::SIntType>(type))
407 value = firrtl::AsUIntPrimOp::create(builder, value);
408 else
409 return;
410 }
411 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
412}
413
414/// A sample reduction pattern that maps `firrtl.instance` to a set of
415/// invalidated wires. This often shortcuts a long iterative process of connect
416/// invalidation, module externalization, and wire stripping
417struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
418 void beforeReduction(mlir::ModuleOp op) override {
419 erasedInsts.clear();
420 erasedModules.clear();
421 symbols.clear();
422 nlaRemover.clear();
423 moduleSizes.clear();
424 }
425 void afterReduction(mlir::ModuleOp op) override {
426 // Look into deleted modules to find additional instances that are no longer
427 // instantiated anywhere.
428 SmallVector<Operation *> worklist;
429 auto deadInsts = erasedInsts;
430 for (auto *op : erasedModules)
431 worklist.push_back(op);
432 while (!worklist.empty()) {
433 auto *op = worklist.pop_back_val();
434 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
435 op->walk([&](firrtl::InstanceOp instOp) {
436 auto moduleOp = cast<firrtl::FModuleLike>(
437 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
438 deadInsts.insert(instOp);
439 if (llvm::all_of(
440 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
441 [&](Operation *user) { return deadInsts.contains(user); })) {
442 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
443 << moduleOp.getModuleName() << "`\n");
444 erasedModules.insert(moduleOp);
445 worklist.push_back(moduleOp);
446 }
447 });
448 }
449
450 for (auto *op : erasedInsts)
451 op->erase();
452 for (auto *op : erasedModules)
453 op->erase();
454 nlaRemover.remove(op);
455 }
456
457 uint64_t match(firrtl::InstanceOp instOp) override {
458 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
459 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
460 return 0;
461 }
462
463 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
464 LLVM_DEBUG(llvm::dbgs()
465 << "Stubbing instance `" << instOp.getName() << "`\n");
466 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
467 TieOffCache tieOffCache(builder);
468 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
469 auto result = instOp.getResult(i);
470 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
471 instOp.getPortName(i));
472 auto wire =
473 firrtl::WireOp::create(builder, result.getType(), name,
474 firrtl::NameKindEnum::DroppableName,
475 instOp.getPortAnnotation(i), StringAttr{})
476 .getResult();
477 invalidateOutputs(builder, wire, tieOffCache,
478 instOp.getPortDirection(i) == firrtl::Direction::In);
479 result.replaceAllUsesWith(wire);
480 }
481 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
482 auto moduleOp = cast<firrtl::FModuleLike>(
483 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
484 nlaRemover.markNLAsInOperation(instOp);
485 erasedInsts.insert(instOp);
486 if (llvm::all_of(
487 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
488 [&](Operation *user) { return erasedInsts.contains(user); })) {
489 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
490 << moduleOp.getModuleName() << "`\n");
491 erasedModules.insert(moduleOp);
492 }
493 return success();
494 }
495
496 std::string getName() const override { return "instance-stubber"; }
497 bool acceptSizeIncrease() const override { return true; }
498
499 ::detail::SymbolCache symbols;
500 NLARemover nlaRemover;
501 llvm::DenseSet<Operation *> erasedInsts;
502 llvm::DenseSet<Operation *> erasedModules;
503 ModuleSizeCache moduleSizes;
504};
505
506/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
507/// wires.
508struct MemoryStubber : public OpReduction<firrtl::MemOp> {
509 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
510 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
511 LogicalResult rewrite(firrtl::MemOp memOp) override {
512 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
513 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
514 TieOffCache tieOffCache(builder);
515 Value xorInputs;
516 SmallVector<Value> outputs;
517 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
518 auto result = memOp.getResult(i);
519 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
520 memOp.getPortName(i));
521 auto wire =
522 firrtl::WireOp::create(builder, result.getType(), name,
523 firrtl::NameKindEnum::DroppableName,
524 memOp.getPortAnnotation(i), StringAttr{})
525 .getResult();
526 invalidateOutputs(builder, wire, tieOffCache, true);
527 result.replaceAllUsesWith(wire);
528
529 // Isolate the input and output data fields of the port.
530 Value input, output;
531 switch (memOp.getPortKind(i)) {
532 case firrtl::MemOp::PortKind::Read:
533 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
534 break;
535 case firrtl::MemOp::PortKind::Write:
536 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
537 break;
538 case firrtl::MemOp::PortKind::ReadWrite:
539 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
540 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
541 break;
542 case firrtl::MemOp::PortKind::Debug:
543 output = wire;
544 break;
545 }
546
547 if (!isa<firrtl::RefType>(result.getType())) {
548 // Reduce all input ports to a single one through an XOR tree.
549 unsigned numFields =
550 cast<firrtl::BundleType>(wire.getType()).getNumElements();
551 for (unsigned i = 0; i != numFields; ++i) {
552 if (i != 2 && i != 3 && i != 5)
553 reduceXor(builder, xorInputs,
554 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
555 }
556 if (input)
557 reduceXor(builder, xorInputs, input);
558 }
559
560 // Track the output port to hook it up to the XORd input later.
561 if (output)
562 outputs.push_back(output);
563 }
564
565 // Hook up the outputs.
566 for (auto output : outputs)
567 connectToLeafs(builder, output, xorInputs);
568
569 nlaRemover.markNLAsInOperation(memOp);
570 memOp->erase();
571 return success();
572 }
573 std::string getName() const override { return "memory-stubber"; }
574 bool acceptSizeIncrease() const override { return true; }
575 NLARemover nlaRemover;
576};
577
578/// Check whether an operation interacts with flows in any way, which can make
579/// replacement and operand forwarding harder in some cases.
580static bool isFlowSensitiveOp(Operation *op) {
581 return isa<WireOp, RegOp, RegResetOp, InstanceOp, SubfieldOp, SubindexOp,
582 SubaccessOp, ObjectSubfieldOp>(op);
583}
584
585/// A sample reduction pattern that replaces all uses of an operation with one
586/// of its operands. This can help pruning large parts of the expression tree
587/// rapidly.
588template <unsigned OpNum>
589struct FIRRTLOperandForwarder : public Reduction {
590 uint64_t match(Operation *op) override {
591 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
592 return 0;
593 if (isFlowSensitiveOp(op))
594 return 0;
595 auto resultTy =
596 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
597 auto opTy =
598 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
599 return resultTy && opTy &&
600 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
601 (resultTy.getBitWidthOrSentinel() == -1) ==
602 (opTy.getBitWidthOrSentinel() == -1) &&
603 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
604 }
605 LogicalResult rewrite(Operation *op) override {
606 assert(match(op));
607 ImplicitLocOpBuilder builder(op->getLoc(), op);
608 auto result = op->getResult(0);
609 auto operand = op->getOperand(OpNum);
610 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
611 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
612 auto resultWidth = resultTy.getBitWidthOrSentinel();
613 auto operandWidth = operandTy.getBitWidthOrSentinel();
614 Value newOp;
615 if (resultWidth < operandWidth)
616 newOp =
617 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
618 else if (resultWidth > operandWidth)
619 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
620 else
621 newOp = operand;
622 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
623 result.replaceAllUsesWith(newOp);
624 reduce::pruneUnusedOps(op, *this);
625 return success();
626 }
627 std::string getName() const override {
628 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
629 }
630};
631
632/// A sample reduction pattern that replaces FIRRTL operations with a constant
633/// zero of their type.
634struct Constantifier : public Reduction {
635 void beforeReduction(mlir::ModuleOp op) override {
636 symbols.clear();
637
638 // Find valid dummy classes that we can use for anyref casts.
639 anyrefCastDummy.clear();
640 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
641 for (auto classOp : circuitOp.getOps<ClassOp>()) {
642 if (classOp.getArguments().empty() && classOp.getBodyBlock()->empty()) {
643 anyrefCastDummy.insert({circuitOp, classOp});
644 anyrefCastDummyNames[circuitOp].insert(classOp.getNameAttr());
645 }
646 }
647 return WalkResult::skip();
648 });
649 }
650
651 uint64_t match(Operation *op) override {
652 if (op->hasTrait<OpTrait::ConstantLike>()) {
653 Attribute attr;
654 if (!matchPattern(op, m_Constant(&attr)))
655 return 0;
656 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
657 if (intAttr.getValue().isZero())
658 return 0;
659 if (auto strAttr = dyn_cast<StringAttr>(attr))
660 if (strAttr.empty())
661 return 0;
662 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
663 if (floatAttr.getValue().isZero())
664 return 0;
665 }
666 if (auto listOp = dyn_cast<ListCreateOp>(op))
667 if (listOp.getElements().empty())
668 return 0;
669 if (auto pathOp = dyn_cast<UnresolvedPathOp>(op))
670 if (pathOp.getTarget().empty())
671 return 0;
672
673 // Don't replace anyref casts that already target a dummy class.
674 if (auto anyrefCastOp = dyn_cast<ObjectAnyRefCastOp>(op)) {
675 auto circuitOp = anyrefCastOp->getParentOfType<CircuitOp>();
676 auto className =
677 anyrefCastOp.getInput().getType().getNameAttr().getAttr();
678 if (anyrefCastDummyNames[circuitOp].contains(className))
679 return 0;
680 }
681
682 if (op->getNumResults() != 1)
683 return 0;
684 if (op->hasAttr("inner_sym"))
685 return 0;
686 if (isFlowSensitiveOp(op))
687 return 0;
688 return isa<UIntType, SIntType, StringType, FIntegerType, BoolType,
689 DoubleType, ListType, PathType, AnyRefType>(
690 op->getResult(0).getType());
691 }
692
693 LogicalResult rewrite(Operation *op) override {
694 OpBuilder builder(op);
695 auto type = op->getResult(0).getType();
696
697 // Handle UInt/SInt types.
698 if (isa<UIntType, SIntType>(type)) {
699 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
700 if (width == -1)
701 width = 64;
702 auto newOp = ConstantOp::create(builder, op->getLoc(), type,
703 APSInt(width, isa<UIntType>(type)));
704 op->replaceAllUsesWith(newOp);
705 reduce::pruneUnusedOps(op, *this);
706 return success();
707 }
708
709 // Handle property string types.
710 if (isa<StringType>(type)) {
711 auto attr = builder.getStringAttr("");
712 auto newOp = StringConstantOp::create(builder, op->getLoc(), attr);
713 op->replaceAllUsesWith(newOp);
714 reduce::pruneUnusedOps(op, *this);
715 return success();
716 }
717
718 // Handle property integer types.
719 if (isa<FIntegerType>(type)) {
720 auto attr = builder.getIntegerAttr(builder.getIntegerType(64, true), 0);
721 auto newOp = FIntegerConstantOp::create(builder, op->getLoc(), attr);
722 op->replaceAllUsesWith(newOp);
723 reduce::pruneUnusedOps(op, *this);
724 return success();
725 }
726
727 // Handle property boolean types.
728 if (isa<BoolType>(type)) {
729 auto attr = builder.getBoolAttr(false);
730 auto newOp = BoolConstantOp::create(builder, op->getLoc(), attr);
731 op->replaceAllUsesWith(newOp);
732 reduce::pruneUnusedOps(op, *this);
733 return success();
734 }
735
736 // Handle property double types.
737 if (isa<DoubleType>(type)) {
738 auto attr = builder.getFloatAttr(builder.getF64Type(), 0.0);
739 auto newOp = DoubleConstantOp::create(builder, op->getLoc(), attr);
740 op->replaceAllUsesWith(newOp);
741 reduce::pruneUnusedOps(op, *this);
742 return success();
743 }
744
745 // Handle property list types.
746 if (isa<ListType>(type)) {
747 auto newOp =
748 ListCreateOp::create(builder, op->getLoc(), type, ValueRange{});
749 op->replaceAllUsesWith(newOp);
750 reduce::pruneUnusedOps(op, *this);
751 return success();
752 }
753
754 // Handle property path types.
755 if (isa<PathType>(type)) {
756 auto newOp = UnresolvedPathOp::create(builder, op->getLoc(), "");
757 op->replaceAllUsesWith(newOp);
758 reduce::pruneUnusedOps(op, *this);
759 return success();
760 }
761
762 // Handle anyref types.
763 if (isa<AnyRefType>(type)) {
764 auto circuitOp = op->getParentOfType<CircuitOp>();
765 auto &dummy = anyrefCastDummy[circuitOp];
766 if (!dummy) {
767 OpBuilder::InsertionGuard guard(builder);
768 builder.setInsertionPointToStart(circuitOp.getBodyBlock());
769 auto &symbolTable = symbols.getNearestSymbolTable(op);
770 dummy = ClassOp::create(builder, op->getLoc(), "Dummy", {}, {});
771 symbolTable.insert(dummy);
772 anyrefCastDummyNames[circuitOp].insert(dummy.getNameAttr());
773 }
774 auto objectOp = ObjectOp::create(builder, op->getLoc(), dummy, "dummy");
775 auto anyrefOp =
776 ObjectAnyRefCastOp::create(builder, op->getLoc(), objectOp);
777 op->replaceAllUsesWith(anyrefOp);
778 reduce::pruneUnusedOps(op, *this);
779 return success();
780 }
781
782 return failure();
783 }
784
785 std::string getName() const override { return "firrtl-constantifier"; }
786 bool acceptSizeIncrease() const override { return true; }
787
788 ::detail::SymbolCache symbols;
790 SmallDenseMap<CircuitOp, DenseSet<StringAttr>, 2> anyrefCastDummyNames;
791};
792
793/// A sample reduction pattern that replaces the right-hand-side of
794/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
795/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
796/// connects and creates opportunities for reduction in DCE/CSE.
797struct ConnectInvalidator : public Reduction {
798 uint64_t match(Operation *op) override {
799 if (!isa<FConnectLike>(op))
800 return 0;
801 if (auto *srcOp = op->getOperand(1).getDefiningOp())
802 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
803 isa<InvalidValueOp>(srcOp))
804 return 0;
805 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
806 return type && type.isPassive();
807 }
808 LogicalResult rewrite(Operation *op) override {
809 assert(match(op));
810 auto rhs = op->getOperand(1);
811 OpBuilder builder(op);
812 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
813 auto *rhsOp = rhs.getDefiningOp();
814 op->setOperand(1, invOp);
815 if (rhsOp)
816 reduce::pruneUnusedOps(rhsOp, *this);
817 return success();
818 }
819 std::string getName() const override { return "connect-invalidator"; }
820 bool acceptSizeIncrease() const override { return true; }
821};
822
823/// A reduction pattern that removes FIRRTL annotations from ports and
824/// operations. This generates one match per annotation and port annotation,
825/// allowing selective removal of individual annotations.
826struct AnnotationRemover : public Reduction {
827 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
828 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
829
830 void matches(Operation *op,
831 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
832 uint64_t matchId = 0;
833
834 // Generate matches for regular annotations
835 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
836 for (unsigned i = 0; i < annos.size(); ++i)
837 addMatch(1, matchId++);
838
839 // Generate matches for port annotations
840 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
841 for (auto portAnnoArray : portAnnos)
842 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
843 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
844 addMatch(1, matchId++);
845 }
846
847 LogicalResult rewriteMatches(Operation *op,
848 ArrayRef<uint64_t> matches) override {
849 // Convert matches to a set for fast lookup
850 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
851
852 // Lambda to process annotations and filter out matched ones
853 uint64_t matchId = 0;
854 auto processAnnotations =
855 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
856 SmallVector<Attribute> newAnnotations;
857 for (auto anno : annotations) {
858 if (!matchesSet.contains(matchId)) {
859 newAnnotations.push_back(anno);
860 } else {
861 // Mark NLAs in the removed annotation for cleanup
862 nlaRemover.markNLAsInAnnotation(anno);
863 }
864 matchId++;
865 }
866 return ArrayAttr::get(op->getContext(), newAnnotations);
867 };
868
869 // Remove regular annotations
870 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
871 op->setAttr("annotations", processAnnotations(annos.getValue()));
872 }
873
874 // Remove port annotations
875 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
876 SmallVector<Attribute> newPortAnnos;
877 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
878 newPortAnnos.push_back(
879 processAnnotations(portAnnoArrayAttr.getValue()));
880 }
881 op->setAttr("portAnnotations",
882 ArrayAttr::get(op->getContext(), newPortAnnos));
883 }
884
885 return success();
886 }
887
888 std::string getName() const override { return "annotation-remover"; }
889 NLARemover nlaRemover;
890};
891
892/// A reduction pattern that replaces ResetType with UInt<1> across an entire
893/// circuit. This walks all operations in the circuit and replaces ResetType in
894/// results, block arguments, and attributes.
895struct SimplifyResets : public OpReduction<CircuitOp> {
896 uint64_t match(CircuitOp circuit) override {
897 uint64_t numResets = 0;
898 AttrTypeWalker walker;
899 walker.addWalk([&](ResetType type) { ++numResets; });
900
901 circuit.walk([&](Operation *op) {
902 for (auto result : op->getResults())
903 walker.walk(result.getType());
904
905 for (auto &region : op->getRegions())
906 for (auto &block : region)
907 for (auto arg : block.getArguments())
908 walker.walk(arg.getType());
909
910 walker.walk(op->getAttrDictionary());
911 });
912
913 return numResets;
914 }
915
916 LogicalResult rewrite(CircuitOp circuit) override {
917 auto uint1Type = UIntType::get(circuit->getContext(), 1, false);
918 auto constUint1Type = UIntType::get(circuit->getContext(), 1, true);
919
920 AttrTypeReplacer replacer;
921 replacer.addReplacement([&](ResetType type) {
922 return type.isConst() ? constUint1Type : uint1Type;
923 });
924 replacer.recursivelyReplaceElementsIn(circuit, /*replaceAttrs=*/true,
925 /*replaceLocs=*/false,
926 /*replaceTypes=*/true);
927
928 // Remove annotations related to InferResets pass
929 circuit.walk([&](Operation *op) {
930 // Remove operation annotations
932 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
933 fullAsyncResetAnnoClass,
934 ignoreFullAsyncResetAnnoClass);
935 });
936
937 // Remove port annotations for module-like operations
938 if (auto module = dyn_cast<FModuleLike>(op)) {
939 AnnotationSet::removePortAnnotations(module, [&](unsigned portIdx,
940 Annotation anno) {
941 return anno.isClass(fullResetAnnoClass, excludeFromFullResetAnnoClass,
942 fullAsyncResetAnnoClass,
943 ignoreFullAsyncResetAnnoClass);
944 });
945 }
946 });
947
948 return success();
949 }
950
951 std::string getName() const override { return "firrtl-simplify-resets"; }
952 bool acceptSizeIncrease() const override { return true; }
953};
954
955/// A sample reduction pattern that removes ports from the root `firrtl.module`
956/// if the port is not used or just invalidated.
957struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
958 void matches(firrtl::FModuleOp module,
959 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
960 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
961 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
962 return;
963
964 // Generate one match per port that can be removed
965 size_t numPorts = module.getNumPorts();
966 for (unsigned i = 0; i != numPorts; ++i) {
967 if (onlyInvalidated(module.getArgument(i)))
968 addMatch(1, i);
969 }
970 }
971
972 LogicalResult rewriteMatches(firrtl::FModuleOp module,
973 ArrayRef<uint64_t> matches) override {
974 // Build a BitVector of ports to remove
975 llvm::BitVector dropPorts(module.getNumPorts());
976 for (auto portIdx : matches)
977 dropPorts.set(portIdx);
978
979 // Erase users of the ports being removed
980 for (auto portIdx : matches) {
981 for (auto *user :
982 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
983 user->erase();
984 }
985
986 // Remove the ports from the module
987 module.erasePorts(dropPorts);
988 return success();
989 }
990
991 std::string getName() const override { return "root-port-pruner"; }
992};
993
994/// A reduction pattern that removes all ports from the root `firrtl.extmodule`.
995/// Since extmodules have no body, all ports can be safely removed for reduction
996/// purposes.
997struct RootExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
998 void matches(firrtl::FExtModuleOp module,
999 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1000 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
1001 if (!circuit || circuit.getNameAttr() != module.getNameAttr())
1002 return;
1003
1004 // Generate one match per port (all ports can be removed from root
1005 // extmodule)
1006 size_t numPorts = module.getNumPorts();
1007 for (unsigned i = 0; i != numPorts; ++i)
1008 addMatch(1, i);
1009 }
1010
1011 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1012 ArrayRef<uint64_t> matches) override {
1013 if (matches.empty())
1014 return failure();
1015
1016 // Build a BitVector of ports to remove
1017 llvm::BitVector dropPorts(module.getNumPorts());
1018 for (auto portIdx : matches)
1019 dropPorts.set(portIdx);
1020
1021 // Remove the ports from the module
1022 module.erasePorts(dropPorts);
1023 return success();
1024 }
1025
1026 std::string getName() const override { return "root-extmodule-port-pruner"; }
1027};
1028
1029/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
1030/// with wires.
1031struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
1032 void beforeReduction(mlir::ModuleOp op) override {
1033 symbols.clear();
1034 nlaRemover.clear();
1035 }
1036 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1037
1038 uint64_t match(firrtl::InstanceOp instOp) override {
1039 return isa<firrtl::FExtModuleOp>(
1040 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1041 }
1042 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
1043 auto portInfo =
1044 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
1045 symbols.getNearestSymbolTable(instOp)))
1046 .getPorts();
1047 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
1048 TieOffCache tieOffCache(builder);
1049 SmallVector<Value> replacementWires;
1050 for (firrtl::PortInfo info : portInfo) {
1051 auto wire = firrtl::WireOp::create(
1052 builder, info.type,
1053 (Twine(instOp.getName()) + "_" + info.getName()).str())
1054 .getResult();
1055 if (info.isOutput()) {
1056 // Tie off output ports using TieOffCache.
1057 if (auto baseType = dyn_cast<firrtl::FIRRTLBaseType>(info.type)) {
1058 auto inv = tieOffCache.getInvalid(baseType);
1059 firrtl::ConnectOp::create(builder, wire, inv);
1060 } else if (auto propType = dyn_cast<firrtl::PropertyType>(info.type)) {
1061 auto unknown = tieOffCache.getUnknown(propType);
1062 builder.create<firrtl::PropAssignOp>(wire, unknown);
1063 }
1064 }
1065 replacementWires.push_back(wire);
1066 }
1067 nlaRemover.markNLAsInOperation(instOp);
1068 instOp.replaceAllUsesWith(std::move(replacementWires));
1069 instOp->erase();
1070 return success();
1071 }
1072 std::string getName() const override { return "extmodule-instance-remover"; }
1073 bool acceptSizeIncrease() const override { return true; }
1074
1075 ::detail::SymbolCache symbols;
1076 NLARemover nlaRemover;
1077};
1078
1079/// A reduction pattern that removes unused ports from extmodules and regular
1080/// modules. This is particularly useful for reducing test cases with many probe
1081/// ports or other unused ports.
1082///
1083/// Shared helper functions for port pruning reductions.
1084struct PortPrunerHelpers {
1085 /// Compute which ports are unused across all instances of a module.
1086 template <typename ModuleOpType>
1087 static void computeUnusedInstancePorts(ModuleOpType module,
1088 ArrayRef<Operation *> users,
1089 llvm::BitVector &portsToRemove) {
1090 auto ports = module.getPorts();
1091 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1092 bool portUsed = false;
1093 for (auto *user : users) {
1094 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user)) {
1095 auto result = instOp.getResult(portIdx);
1096 if (!result.use_empty()) {
1097 portUsed = true;
1098 break;
1099 }
1100 }
1101 }
1102 if (!portUsed)
1103 portsToRemove.set(portIdx);
1104 }
1105 }
1106
1107 /// Update all instances of a module to remove the specified ports.
1108 static void
1109 updateInstancesAndErasePorts(Operation *module, ArrayRef<Operation *> users,
1110 const llvm::BitVector &portsToRemove) {
1111 // Update all instances to remove the corresponding results
1112 SmallVector<firrtl::InstanceOp> instancesToUpdate;
1113 for (auto *user : users) {
1114 if (auto instOp = dyn_cast<firrtl::InstanceOp>(user))
1115 instancesToUpdate.push_back(instOp);
1116 }
1117
1118 for (auto instOp : instancesToUpdate) {
1119 auto newInst = instOp.cloneWithErasedPorts(portsToRemove);
1120
1121 // Manually replace uses, skipping erased ports
1122 size_t newResultIdx = 0;
1123 for (size_t oldResultIdx = 0; oldResultIdx < instOp.getNumResults();
1124 ++oldResultIdx) {
1125 if (portsToRemove[oldResultIdx]) {
1126 // This port is being removed, assert it has no uses
1127 assert(instOp.getResult(oldResultIdx).use_empty() &&
1128 "removing port with uses");
1129 } else {
1130 // Replace uses of the old result with the new result
1131 instOp.getResult(oldResultIdx)
1132 .replaceAllUsesWith(newInst.getResult(newResultIdx));
1133 ++newResultIdx;
1134 }
1135 }
1136
1137 instOp->erase();
1138 }
1139 }
1140};
1141
1142/// Reduction to remove unused ports from regular modules.
1143struct ModulePortPruner : public OpReduction<firrtl::FModuleOp> {
1144 void beforeReduction(mlir::ModuleOp op) override {
1145 symbols.clear();
1146 nlaRemover.clear();
1147 }
1148 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1149
1150 void matches(firrtl::FModuleOp module,
1151 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1152 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1153 auto &userMap = symbols.getSymbolUserMap(tableOp);
1154 auto ports = module.getPorts();
1155 auto users = userMap.getUsers(module);
1156
1157 // Compute which ports can be removed
1158 llvm::BitVector portsToRemove(ports.size());
1159
1160 // If the module has no instances, aggressively remove ports that aren't
1161 // used within the module body itself
1162 if (users.empty()) {
1163 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx) {
1164 auto arg = module.getArgument(portIdx);
1165 if (arg.use_empty())
1166 portsToRemove.set(portIdx);
1167 }
1168 } else {
1169 // For modules with instances, check if ports are unused across all
1170 // instances
1171 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1172 portsToRemove);
1173 }
1174
1175 // Generate one match per removable port
1176 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1177 if (portsToRemove[portIdx])
1178 addMatch(1, portIdx);
1179 }
1180
1181 LogicalResult rewriteMatches(firrtl::FModuleOp module,
1182 ArrayRef<uint64_t> matches) override {
1183 if (matches.empty())
1184 return failure();
1185
1186 // Build a BitVector of ports to remove
1187 llvm::BitVector portsToRemove(module.getNumPorts());
1188 for (auto portIdx : matches)
1189 portsToRemove.set(portIdx);
1190
1191 // Get users for updating instances
1192 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1193 auto &userMap = symbols.getSymbolUserMap(tableOp);
1194 auto users = userMap.getUsers(module);
1195
1196 // Update all instances
1197 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1198 portsToRemove);
1199
1200 // Erase uses of port arguments within the module body.
1201 for (auto portIdx : matches) {
1202 // Recursively erase each user and its dependent operations
1203 for (auto *user :
1204 llvm::make_early_inc_range(module.getArgument(portIdx).getUsers()))
1205 reduce::pruneUnusedOps(user, *this);
1206 }
1207
1208 // Remove the ports from the module
1209 module.erasePorts(portsToRemove);
1210
1211 return success();
1212 }
1213
1214 std::string getName() const override { return "module-port-pruner"; }
1215
1216 ::detail::SymbolCache symbols;
1217 NLARemover nlaRemover;
1218};
1219
1220/// Reduction to remove unused ports from extmodules.
1221struct ExtmodulePortPruner : public OpReduction<firrtl::FExtModuleOp> {
1222 void beforeReduction(mlir::ModuleOp op) override {
1223 symbols.clear();
1224 nlaRemover.clear();
1225 }
1226 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1227
1228 void matches(firrtl::FExtModuleOp module,
1229 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1230 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1231 auto &userMap = symbols.getSymbolUserMap(tableOp);
1232 auto ports = module.getPorts();
1233 auto users = userMap.getUsers(module);
1234
1235 // Compute which ports can be removed
1236 llvm::BitVector portsToRemove(ports.size());
1237
1238 if (users.empty()) {
1239 // If the extmodule has no instances, aggressively remove all ports
1240 portsToRemove.set();
1241 } else {
1242 // For extmodules with instances, check if ports are unused across all
1243 // instances
1244 PortPrunerHelpers::computeUnusedInstancePorts(module, users,
1245 portsToRemove);
1246 }
1247
1248 // Generate one match per removable port
1249 for (size_t portIdx = 0; portIdx < ports.size(); ++portIdx)
1250 if (portsToRemove[portIdx])
1251 addMatch(1, portIdx);
1252 }
1253
1254 LogicalResult rewriteMatches(firrtl::FExtModuleOp module,
1255 ArrayRef<uint64_t> matches) override {
1256 if (matches.empty())
1257 return failure();
1258
1259 // Build a BitVector of ports to remove
1260 llvm::BitVector portsToRemove(module.getNumPorts());
1261 for (auto portIdx : matches)
1262 portsToRemove.set(portIdx);
1263
1264 // Get users for updating instances
1265 auto *tableOp = SymbolTable::getNearestSymbolTable(module);
1266 auto &userMap = symbols.getSymbolUserMap(tableOp);
1267 auto users = userMap.getUsers(module);
1268
1269 // Update all instances.
1270 PortPrunerHelpers::updateInstancesAndErasePorts(module, users,
1271 portsToRemove);
1272
1273 // Remove the ports from the module (no body to clean up for extmodules).
1274 module.erasePorts(portsToRemove);
1275
1276 return success();
1277 }
1278
1279 std::string getName() const override { return "extmodule-port-pruner"; }
1280
1281 ::detail::SymbolCache symbols;
1282 NLARemover nlaRemover;
1283};
1284
1285/// A sample reduction pattern that pushes connected values through wires.
1286struct ConnectForwarder : public Reduction {
1287 void beforeReduction(mlir::ModuleOp op) override {
1288 domInfo = std::make_unique<DominanceInfo>(op);
1289 }
1290
1291 uint64_t match(Operation *op) override {
1292 if (!isa<firrtl::FConnectLike>(op))
1293 return 0;
1294 auto dest = op->getOperand(0);
1295 auto src = op->getOperand(1);
1296 auto *destOp = dest.getDefiningOp();
1297 auto *srcOp = src.getDefiningOp();
1298 if (dest == src)
1299 return 0;
1300
1301 // Ensure that the destination is something we should be able to forward
1302 // through.
1303 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
1304 destOp))
1305 return 0;
1306
1307 // Ensure that the destination is connected to only once, and all uses of
1308 // the connection occur after the definition of the source.
1309 unsigned numConnects = 0;
1310 for (auto &use : dest.getUses()) {
1311 auto *op = use.getOwner();
1312 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
1313 if (++numConnects > 1)
1314 return 0;
1315 continue;
1316 }
1317 // Check if srcOp properly dominates op, but op is not enclosed in srcOp.
1318 // This handles cross-block cases (e.g., layerblocks).
1319 if (srcOp &&
1320 !domInfo->properlyDominates(srcOp, op, /*enclosingOpOk=*/false))
1321 return 0;
1322 }
1323
1324 return 1;
1325 }
1326
1327 LogicalResult rewrite(Operation *op) override {
1328 auto dst = op->getOperand(0);
1329 auto src = op->getOperand(1);
1330 dst.replaceAllUsesWith(src);
1331 op->erase();
1332 if (auto *dstOp = dst.getDefiningOp())
1333 reduce::pruneUnusedOps(dstOp, *this);
1334 if (auto *srcOp = src.getDefiningOp())
1335 reduce::pruneUnusedOps(srcOp, *this);
1336 return success();
1337 }
1338
1339 std::string getName() const override { return "connect-forwarder"; }
1340
1341private:
1342 std::unique_ptr<DominanceInfo> domInfo;
1343};
1344
1345/// A sample reduction pattern that replaces a single-use wire and register with
1346/// an operand of the source value of the connection.
1347template <unsigned OpNum>
1348struct ConnectSourceOperandForwarder : public Reduction {
1349 uint64_t match(Operation *op) override {
1350 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
1351 return 0;
1352 auto dest = op->getOperand(0);
1353 auto *destOp = dest.getDefiningOp();
1354
1355 // Ensure that the destination is used only once.
1356 if (!destOp || !destOp->hasOneUse() ||
1357 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
1358 return 0;
1359
1360 auto *srcOp = op->getOperand(1).getDefiningOp();
1361 if (!srcOp || OpNum >= srcOp->getNumOperands())
1362 return 0;
1363
1364 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
1365 auto opTy =
1366 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
1367
1368 return resultTy && opTy &&
1369 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
1370 ((resultTy.getBitWidthOrSentinel() == -1) ==
1371 (opTy.getBitWidthOrSentinel() == -1)) &&
1372 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
1373 }
1374
1375 LogicalResult rewrite(Operation *op) override {
1376 auto *destOp = op->getOperand(0).getDefiningOp();
1377 auto *srcOp = op->getOperand(1).getDefiningOp();
1378 auto forwardedOperand = srcOp->getOperand(OpNum);
1379 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
1380 Value newDest;
1381 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
1382 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
1383 wire.getName())
1384 .getResult();
1385 else {
1386 auto regName = destOp->getAttrOfType<StringAttr>("name");
1387 // We can promote the register into a wire but we wouldn't do here because
1388 // the error might be caused by the register.
1389 auto clock = destOp->getOperand(0);
1390 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
1391 clock, regName ? regName.str() : "")
1392 .getResult();
1393 }
1394
1395 // Create new connection between a new wire and the forwarded operand.
1396 builder.setInsertionPointAfter(op);
1397 if (isa<firrtl::ConnectOp>(op))
1398 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
1399 else
1400 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
1401
1402 // Remove the old connection and destination. We don't have to replace them
1403 // because destination has only one use.
1404 op->erase();
1405 destOp->erase();
1406 reduce::pruneUnusedOps(srcOp, *this);
1407
1408 return success();
1409 }
1410 std::string getName() const override {
1411 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
1412 }
1413};
1414
1415/// A sample reduction pattern that tries to remove aggregate wires by replacing
1416/// all subaccesses with new independent wires. This can disentangle large
1417/// unused wires that are otherwise difficult to collect due to the subaccesses.
1418struct DetachSubaccesses : public Reduction {
1419 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
1420 void afterReduction(mlir::ModuleOp op) override {
1421 for (auto *op : opsToErase)
1422 op->dropAllReferences();
1423 for (auto *op : opsToErase)
1424 op->erase();
1425 }
1426 uint64_t match(Operation *op) override {
1427 // Only applies to wires and registers that are purely used in subaccess
1428 // operations.
1429 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
1430 llvm::all_of(op->getUses(), [](auto &use) {
1431 return use.getOperandNumber() == 0 &&
1432 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
1433 firrtl::SubaccessOp>(use.getOwner());
1434 });
1435 }
1436 LogicalResult rewrite(Operation *op) override {
1437 assert(match(op));
1438 OpBuilder builder(op);
1439 bool isWire = isa<firrtl::WireOp>(op);
1440 Value invalidClock;
1441 if (!isWire)
1442 invalidClock = firrtl::InvalidValueOp::create(
1443 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
1444 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
1445 builder.setInsertionPoint(user);
1446 auto type = user->getResult(0).getType();
1447 Operation *replOp;
1448 if (isWire)
1449 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
1450 else
1451 replOp =
1452 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
1453 user->replaceAllUsesWith(replOp);
1454 opsToErase.insert(user);
1455 }
1456 opsToErase.insert(op);
1457 return success();
1458 }
1459 std::string getName() const override { return "detach-subaccesses"; }
1460 llvm::DenseSet<Operation *> opsToErase;
1461};
1462
1463/// This reduction removes inner symbols on ops. Name preservation creates a lot
1464/// of node ops with symbols to keep name information but it also prevents
1465/// normal canonicalizations.
1466struct NodeSymbolRemover : public Reduction {
1467 void beforeReduction(mlir::ModuleOp op) override {
1468 innerSymUses = reduce::InnerSymbolUses(op);
1469 }
1470
1471 uint64_t match(Operation *op) override {
1472 // Only match ops with an inner symbol.
1473 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
1474 if (!sym || sym.empty())
1475 return 0;
1476
1477 // Only match ops that have no references to their inner symbol.
1478 if (innerSymUses.hasInnerRef(op))
1479 return 0;
1480 return 1;
1481 }
1482
1483 LogicalResult rewrite(Operation *op) override {
1484 op->removeAttr("inner_sym");
1485 return success();
1486 }
1487
1488 std::string getName() const override { return "node-symbol-remover"; }
1489 bool acceptSizeIncrease() const override { return true; }
1490
1491 reduce::InnerSymbolUses innerSymUses;
1492};
1493
1494/// Check if inlining the referenced operation into the parent operation would
1495/// cause inner symbol collisions.
1496static bool
1497hasInnerSymbolCollision(Operation *referencedOp, Operation *parentOp,
1498 hw::InnerSymbolTableCollection &innerSymTables) {
1499 // Get the inner symbol tables for both operations
1500 auto &targetTable = innerSymTables.getInnerSymbolTable(referencedOp);
1501 auto &parentTable = innerSymTables.getInnerSymbolTable(parentOp);
1502
1503 // Check if any inner symbol name in the target operation already exists
1504 // in the parent operation. Return failure() if a collision is found to stop
1505 // the walk early.
1506 LogicalResult walkResult = targetTable.walkSymbols(
1507 [&](StringAttr name, const hw::InnerSymTarget &target) -> LogicalResult {
1508 // Check if this symbol name exists in the parent operation
1509 if (parentTable.lookup(name)) {
1510 // Collision found, return failure to stop the walk
1511 return failure();
1512 }
1513 return success();
1514 });
1515
1516 // If the walk failed, it means we found a collision
1517 return failed(walkResult);
1518}
1519
1520/// A sample reduction pattern that eagerly inlines instances.
1521struct EagerInliner : public OpReduction<InstanceOp> {
1522 void beforeReduction(mlir::ModuleOp op) override {
1523 symbols.clear();
1524 nlaRemover.clear();
1525 nlaTables.clear();
1526 for (auto circuitOp : op.getOps<CircuitOp>())
1527 nlaTables.insert({circuitOp, std::make_unique<NLATable>(circuitOp)});
1528 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1529 }
1530 void afterReduction(mlir::ModuleOp op) override {
1531 nlaRemover.remove(op);
1532 nlaTables.clear();
1533 innerSymTables.reset();
1534 }
1535
1536 uint64_t match(InstanceOp instOp) override {
1537 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1538 auto *moduleOp =
1539 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1540
1541 // Only inline FModuleOp instances
1542 if (!isa<FModuleOp>(moduleOp))
1543 return 0;
1544
1545 // Skip instances that participate in any NLAs
1546 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1547 if (!circuitOp)
1548 return 0;
1549 auto it = nlaTables.find(circuitOp);
1550 if (it == nlaTables.end() || !it->second)
1551 return 0;
1552 DenseSet<hw::HierPathOp> nlas;
1553 it->second->getInstanceNLAs(instOp, nlas);
1554 if (!nlas.empty())
1555 return 0;
1556
1557 // Check for inner symbol collisions between the referenced module and the
1558 // instance's parent module
1559 auto parentOp = instOp->getParentOfType<FModuleLike>();
1560 if (hasInnerSymbolCollision(moduleOp, parentOp, *innerSymTables))
1561 return 0;
1562
1563 return 1;
1564 }
1565
1566 LogicalResult rewrite(InstanceOp instOp) override {
1567 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1568 auto moduleOp = cast<FModuleOp>(
1569 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1570 bool isLastUse =
1571 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1572 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1573
1574 // Create wires to replace the instance results.
1575 IRRewriter rewriter(instOp);
1576 SmallVector<Value> argWires;
1577 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1578 auto result = instOp.getResult(i);
1579 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1580 instOp.getPortName(i));
1581 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1582 name, NameKindEnum::DroppableName,
1583 instOp.getPortAnnotation(i), StringAttr{})
1584 .getResult();
1585 result.replaceAllUsesWith(wire);
1586 argWires.push_back(wire);
1587 }
1588
1589 // Splice in the cloned module body.
1590 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1591
1592 // Make sure we remove any NLAs that go through this instance, and the
1593 // module if we're about the delete the module.
1594 nlaRemover.markNLAsInOperation(instOp);
1595 if (isLastUse)
1596 nlaRemover.markNLAsInOperation(moduleOp);
1597
1598 instOp.erase();
1599 clonedModuleOp.erase();
1600 return success();
1601 }
1602
1603 std::string getName() const override { return "firrtl-eager-inliner"; }
1604 bool acceptSizeIncrease() const override { return true; }
1605
1606 ::detail::SymbolCache symbols;
1607 NLARemover nlaRemover;
1608 DenseMap<CircuitOp, std::unique_ptr<NLATable>> nlaTables;
1609 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1610};
1611
1612/// A reduction pattern that eagerly inlines `ObjectOp`s.
1613struct ObjectInliner : public OpReduction<ObjectOp> {
1614 void beforeReduction(mlir::ModuleOp op) override {
1615 blocksToSort.clear();
1616 symbols.clear();
1617 nlaRemover.clear();
1618 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
1619 }
1620 void afterReduction(mlir::ModuleOp op) override {
1621 for (auto *block : blocksToSort)
1622 mlir::sortTopologically(block);
1623 blocksToSort.clear();
1624 nlaRemover.remove(op);
1625 innerSymTables.reset();
1626 }
1627
1628 uint64_t match(ObjectOp objOp) override {
1629 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1630 auto *classOp =
1631 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
1632
1633 // Only inline `ClassOp`s.
1634 if (!isa<ClassOp>(classOp))
1635 return 0;
1636
1637 // Check for inner symbol collisions between the referenced class and the
1638 // object's parent module.
1639 auto parentOp = objOp->getParentOfType<FModuleLike>();
1640 if (hasInnerSymbolCollision(classOp, parentOp, *innerSymTables))
1641 return 0;
1642
1643 // Verify all uses are ObjectSubfieldOp.
1644 for (auto *user : objOp.getResult().getUsers())
1645 if (!isa<ObjectSubfieldOp>(user))
1646 return 0;
1647
1648 return 1;
1649 }
1650
1651 LogicalResult rewrite(ObjectOp objOp) override {
1652 auto *tableOp = SymbolTable::getNearestSymbolTable(objOp);
1653 auto classOp = cast<ClassOp>(
1654 objOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1655 auto clonedClassOp = classOp.clone();
1656
1657 // Create wires to replace the ObjectSubfieldOp results.
1658 IRRewriter rewriter(objOp);
1659 SmallVector<Value> portWires;
1660 auto classType = objOp.getType();
1661
1662 // Create a wire for each port in the class
1663 for (unsigned i = 0, e = classType.getNumElements(); i != e; ++i) {
1664 auto element = classType.getElement(i);
1665 auto name = rewriter.getStringAttr(Twine(objOp.getName()) + "_" +
1666 element.name.getValue());
1667 auto wire = WireOp::create(rewriter, objOp.getLoc(), element.type, name,
1668 NameKindEnum::DroppableName,
1669 rewriter.getArrayAttr({}), StringAttr{})
1670 .getResult();
1671 portWires.push_back(wire);
1672 }
1673
1674 // Replace all ObjectSubfieldOp uses with corresponding wires
1675 SmallVector<ObjectSubfieldOp> subfieldOps;
1676 for (auto *user : objOp.getResult().getUsers()) {
1677 auto subfieldOp = cast<ObjectSubfieldOp>(user);
1678 subfieldOps.push_back(subfieldOp);
1679 auto index = subfieldOp.getIndex();
1680 subfieldOp.getResult().replaceAllUsesWith(portWires[index]);
1681 }
1682
1683 // Splice in the cloned class body.
1684 rewriter.inlineBlockBefore(clonedClassOp.getBodyBlock(), objOp, portWires);
1685
1686 // After inlining the class body, we need to eliminate `WireOps` since
1687 // `ClassOps` cannot contain wires. For each port wire, find its single
1688 // connect, remove it, and replace all uses of the wire with the assigned
1689 // value.
1690 SmallVector<FConnectLike> connectsToErase;
1691 for (auto portWire : portWires) {
1692 // Find a single value to replace the wire with, and collect all connects
1693 // to the wire such that we can erase them later.
1694 Value value;
1695 for (auto *user : portWire.getUsers()) {
1696 if (auto connect = dyn_cast<FConnectLike>(user)) {
1697 if (connect.getDest() == portWire) {
1698 value = connect.getSrc();
1699 connectsToErase.push_back(connect);
1700 }
1701 }
1702 }
1703
1704 // Be very conservative about deleting these wires. Other reductions may
1705 // leave class ports unconnected, which means that there isn't always a
1706 // clean replacement available here. Better to just leave the wires in the
1707 // IR and let the verifier fail later.
1708 if (value)
1709 portWire.replaceAllUsesWith(value);
1710 for (auto connect : connectsToErase)
1711 connect.erase();
1712 if (portWire.use_empty())
1713 portWire.getDefiningOp()->erase();
1714 connectsToErase.clear();
1715 }
1716
1717 // Make sure we remove any NLAs that go through this object.
1718 nlaRemover.markNLAsInOperation(objOp);
1719
1720 // Since the above forwarding of SSA values through wires can create
1721 // dominance issues, mark the region containing the object to be sorted
1722 // topologically.
1723 blocksToSort.insert(objOp->getBlock());
1724
1725 // Erase the object and cloned class.
1726 for (auto subfieldOp : subfieldOps)
1727 subfieldOp.erase();
1728 objOp.erase();
1729 clonedClassOp.erase();
1730 return success();
1731 }
1732
1733 std::string getName() const override { return "firrtl-object-inliner"; }
1734 bool acceptSizeIncrease() const override { return true; }
1735
1736 SetVector<Block *> blocksToSort;
1737 ::detail::SymbolCache symbols;
1738 NLARemover nlaRemover;
1739 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1740};
1741
1742/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1743/// not an actual reduction, but often removes extraneous information that has
1744/// no bearing on the actual reduction (and would likely be removed before
1745/// sharing the reduction). This makes the following changes:
1746///
1747/// - All wires are renamed to "wire"
1748/// - All registers are renamed to "reg"
1749/// - All nodes are renamed to "node"
1750/// - All memories are renamed to "mem"
1751/// - All verification messages and labels are dropped
1752///
1753struct ModuleInternalNameSanitizer : public Reduction {
1754 uint64_t match(Operation *op) override {
1755 // Only match operations with names.
1756 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1757 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1758 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1759 firrtl::CoverOp>(op);
1760 }
1761 LogicalResult rewrite(Operation *op) override {
1762 TypeSwitch<Operation *, void>(op)
1763 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1764 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1765 [](auto op) { op.setName("reg"); })
1766 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1767 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1768 [](auto op) { op.setName("mem"); })
1769 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1770 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1771 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1772 });
1773 return success();
1774 }
1775
1776 std::string getName() const override {
1777 return "module-internal-name-sanitizer";
1778 }
1779
1780 bool acceptSizeIncrease() const override { return true; }
1781
1782 bool isOneShot() const override { return true; }
1783};
1784
1785/// Pseudo-reduction that sanitizes module, instance, and port names. This
1786/// makes the following changes:
1787///
1788/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1789/// - All instances are renamed to match the new module name
1790/// - All ports of regular modules are renamed in the following way:
1791/// - All clocks are renamed to "clk"
1792/// - All resets are renamed to "rst"
1793/// - All references are renamed to "ref"
1794/// - Anything else gets a short letter-based name ("a", "b", ...)
1795///
1796struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1797
1799 size_t portNameIndex = 0;
1800
1801 char getPortName() {
1802 if (portNameIndex >= 26)
1803 portNameIndex = 0;
1804 return 'a' + portNameIndex++;
1805 }
1806
1807 void beforeReduction(mlir::ModuleOp op) override { nameGenerator.reset(); }
1808
1809 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1810
1811 firrtl::InstanceGraph iGraph(circuitOp);
1812
1813 auto *circuitName = nameGenerator.getNextName();
1814 iGraph.getTopLevelModule().setName(circuitName);
1815 circuitOp.setName(circuitName);
1816
1817 for (auto *node : iGraph) {
1818 auto module = node->getModule<firrtl::FModuleLike>();
1819
1820 bool shouldReplacePorts = false;
1821 SmallVector<Attribute> newNames;
1822 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1823 portNameIndex = 0;
1824 // TODO: The namespace should be unnecessary. However, some FIRRTL
1825 // passes expect that port names are unique.
1827 auto oldPorts = fmodule.getPorts();
1828 shouldReplacePorts = !oldPorts.empty();
1829 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1830 auto port = oldPorts[i];
1831 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1832 .Case<firrtl::ClockType>(
1833 [&](auto a) { return ns.newName("clk"); })
1834 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1835 [&](auto a) { return ns.newName("rst"); })
1836 .Case<firrtl::RefType>(
1837 [&](auto a) { return ns.newName("ref"); })
1838 .Default([&](auto a) {
1839 return ns.newName(Twine(getPortName()));
1840 });
1841 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1842 }
1843 fmodule->setAttr("portNames",
1844 ArrayAttr::get(fmodule.getContext(), newNames));
1845 }
1846
1847 if (module == iGraph.getTopLevelModule())
1848 continue;
1849 auto newName =
1850 StringAttr::get(circuitOp.getContext(), nameGenerator.getNextName());
1851 module.setName(newName);
1852 for (auto *use : node->uses()) {
1853 auto useOp = use->getInstance();
1854 if (auto instanceOp = dyn_cast<firrtl::InstanceOp>(*useOp)) {
1855 instanceOp.setModuleName(newName);
1856 instanceOp.setName(newName);
1857 if (shouldReplacePorts)
1858 instanceOp.setPortNamesAttr(
1859 ArrayAttr::get(circuitOp.getContext(), newNames));
1860 } else if (auto objectOp = dyn_cast<firrtl::ObjectOp>(*useOp)) {
1861 // ObjectOp stores the class name in its result type, so we need to
1862 // create a new ClassType with the new name and set it on the result.
1863 auto oldClassType = objectOp.getType();
1864 auto newClassType = firrtl::ClassType::get(
1865 circuitOp.getContext(), FlatSymbolRefAttr::get(newName),
1866 oldClassType.getElements());
1867 objectOp.getResult().setType(newClassType);
1868 objectOp.setName(newName);
1869 }
1870 }
1871 }
1872
1873 return success();
1874 }
1875
1876 std::string getName() const override { return "module-name-sanitizer"; }
1877
1878 bool acceptSizeIncrease() const override { return true; }
1879
1880 bool isOneShot() const override { return true; }
1881};
1882
1883/// A reduction pattern that groups modules by their port signature (types and
1884/// directions) and replaces instances with the smallest module in each group.
1885/// This helps reduce the IR by consolidating functionally equivalent modules
1886/// based on their interface.
1887///
1888/// The pattern works by:
1889/// 1. Grouping all modules by their port signature (port types and directions)
1890/// 2. For each group with multiple modules, finding the smallest module using
1891/// the module size cache
1892/// 3. Replacing all instances of larger modules with instances of the smallest
1893/// module in the same group
1894/// 4. Removing the larger modules from the circuit
1895///
1896/// This reduction is useful for reducing circuits where multiple modules have
1897/// the same interface but different implementations, allowing the reducer to
1898/// try the smallest implementation first.
1899struct ModuleSwapper : public OpReduction<InstanceOp> {
1900 // Per-circuit state containing all the information needed for module swapping
1901 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1902 struct CircuitState {
1903 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1904 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1905 std::unique_ptr<NLATable> nlaTable;
1906 };
1907
1908 void beforeReduction(mlir::ModuleOp op) override {
1909 symbols.clear();
1910 nlaRemover.clear();
1911 moduleSizes.clear();
1912 circuitStates.clear();
1913
1914 // Collect module type groups and NLA tables for all circuits up front
1915 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1916 auto &state = circuitStates[circuitOp];
1917 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1918 buildModuleTypeGroups(circuitOp, state);
1919 return WalkResult::skip();
1920 });
1921 }
1922 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1923
1924 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1925 /// This ignores port names, allowing modules with the same port types and
1926 /// directions but different port names to be considered equivalent for
1927 /// swapping.
1928 PortSignature getModulePortSignature(FModuleLike module) {
1929 PortSignature signature;
1930 signature.reserve(module.getNumPorts());
1931 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1932 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1933 return signature;
1934 }
1935
1936 /// Group modules by their port signature and find the smallest in each group.
1937 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1938 // Group modules by their port signature
1939 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1940 auto signature = getModulePortSignature(module);
1941 state.moduleTypeGroups[signature].push_back(module);
1942 }
1943
1944 // For each group, find the smallest module
1945 for (auto &[signature, modules] : state.moduleTypeGroups) {
1946 if (modules.size() <= 1)
1947 continue;
1948
1949 FModuleLike smallestModule = nullptr;
1950 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1951
1952 for (auto module : modules) {
1953 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1954 if (size < smallestSize) {
1955 smallestSize = size;
1956 smallestModule = module;
1957 }
1958 }
1959
1960 // Map all modules in this group to the smallest one
1961 for (auto module : modules) {
1962 if (module != smallestModule) {
1963 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1964 smallestModule;
1965 }
1966 }
1967 }
1968 }
1969
1970 uint64_t match(InstanceOp instOp) override {
1971 // Get the circuit this instance belongs to
1972 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1973 assert(circuitOp);
1974 const auto &state = circuitStates.at(circuitOp);
1975
1976 // Skip instances that participate in any NLAs
1977 DenseSet<hw::HierPathOp> nlas;
1978 state.nlaTable->getInstanceNLAs(instOp, nlas);
1979 if (!nlas.empty())
1980 return 0;
1981
1982 // Check if this instance can be redirected to a smaller module
1983 auto moduleName = instOp.getModuleNameAttr().getAttr();
1984 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1985 if (!canonicalModule)
1986 return 0;
1987
1988 // Benefit is the size difference
1989 auto currentModule = cast<FModuleLike>(
1990 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1991 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1992 uint64_t canonicalSize =
1993 moduleSizes.getModuleSize(canonicalModule, symbols);
1994 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1995 }
1996
1997 LogicalResult rewrite(InstanceOp instOp) override {
1998 // Get the circuit this instance belongs to
1999 auto circuitOp = instOp->getParentOfType<CircuitOp>();
2000 assert(circuitOp);
2001 const auto &state = circuitStates.at(circuitOp);
2002
2003 // Replace the instantiated module with the canonical module.
2004 auto canonicalModule = state.instanceToCanonicalModule.at(
2005 instOp.getModuleNameAttr().getAttr());
2006 auto canonicalName = canonicalModule.getModuleNameAttr();
2007 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
2008
2009 // Update port names to match the canonical module
2010 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2011
2012 return success();
2013 }
2014
2015 std::string getName() const override { return "firrtl-module-swapper"; }
2016 bool acceptSizeIncrease() const override { return true; }
2017
2018private:
2019 ::detail::SymbolCache symbols;
2020 NLARemover nlaRemover;
2021 ModuleSizeCache moduleSizes;
2022
2023 // Per-circuit state containing all module swapping information
2024 DenseMap<CircuitOp, CircuitState> circuitStates;
2025};
2026
2027/// A reduction pattern that handles MustDedup annotations by replacing all
2028/// module names in a dedup group with a single module name. This helps reduce
2029/// the IR by consolidating module references that are required to be identical.
2030///
2031/// The pattern works by:
2032/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2033/// 2. For each dedup group, using the first module as the canonical name
2034/// 3. Replacing all instance references to other modules in the group with
2035/// references to the canonical module
2036/// 4. Removing the non-canonical modules from the circuit
2037/// 5. Removing the processed MustDedup annotation
2038///
2039/// This reduction is particularly useful for reducing large circuits where
2040/// multiple modules are known to be identical but haven't been deduplicated
2041/// yet.
2042struct ForceDedup : public OpReduction<CircuitOp> {
2043 void beforeReduction(mlir::ModuleOp op) override {
2044 symbols.clear();
2045 nlaRemover.clear();
2046 modulesToErase.clear();
2047 moduleSizes.clear();
2048 }
2049 void afterReduction(mlir::ModuleOp op) override {
2050 nlaRemover.remove(op);
2051 for (auto mod : modulesToErase)
2052 mod->erase();
2053 }
2054
2055 /// Collect all MustDedup annotations and create matches for each dedup group.
2056 void matches(CircuitOp circuitOp,
2057 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2058 auto &symbolTable = symbols.getNearestSymbolTable(circuitOp);
2059 auto annotations = AnnotationSet(circuitOp);
2060 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2061 if (!anno.isClass(mustDeduplicateAnnoClass))
2062 continue;
2063
2064 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2065 if (!modulesAttr || modulesAttr.size() < 2)
2066 continue;
2067
2068 // Check that all modules have the same port signature. Malformed inputs
2069 // may have modules listed in a MustDedup annotation that have distinct
2070 // port types.
2071 uint64_t totalSize = 0;
2072 ArrayAttr portTypes;
2073 DenseBoolArrayAttr portDirections;
2074 bool allSame = true;
2075 for (auto moduleName : modulesAttr.getAsRange<StringAttr>()) {
2076 auto target = tokenizePath(moduleName);
2077 if (!target) {
2078 allSame = false;
2079 break;
2080 }
2081 auto mod = symbolTable.lookup<FModuleLike>(target->module);
2082 if (!mod) {
2083 allSame = false;
2084 break;
2085 }
2086 totalSize += moduleSizes.getModuleSize(mod, symbols);
2087 if (!portTypes) {
2088 portTypes = mod.getPortTypesAttr();
2089 portDirections = mod.getPortDirectionsAttr();
2090 } else if (portTypes != mod.getPortTypesAttr() ||
2091 portDirections != mod.getPortDirectionsAttr()) {
2092 allSame = false;
2093 break;
2094 }
2095 }
2096 if (!allSame)
2097 continue;
2098
2099 // Each dedup group gets its own match with benefit proportional to group
2100 // size.
2101 addMatch(totalSize, annoIdx);
2102 }
2103 }
2104
2105 LogicalResult rewriteMatches(CircuitOp circuitOp,
2106 ArrayRef<uint64_t> matches) override {
2107 auto *context = circuitOp->getContext();
2108 NLATable nlaTable(circuitOp);
2109 hw::InnerSymbolTableCollection innerSymTables;
2110 auto annotations = AnnotationSet(circuitOp);
2111 SmallVector<Annotation> newAnnotations;
2112
2113 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2114 // Check if this annotation was selected.
2115 if (!llvm::is_contained(matches, annoIdx)) {
2116 newAnnotations.push_back(anno);
2117 continue;
2118 }
2119 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2120 assert(anno.isClass(mustDeduplicateAnnoClass) && modulesAttr &&
2121 modulesAttr.size() >= 2);
2122
2123 // Extract module names from the dedup group.
2124 SmallVector<StringAttr> moduleNames;
2125 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
2126 // Parse "~CircuitName|ModuleName" format.
2127 auto refStr = moduleRef.getValue();
2128 auto pipePos = refStr.find('|');
2129 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
2130 auto moduleName = refStr.substr(pipePos + 1);
2131 moduleNames.push_back(StringAttr::get(context, moduleName));
2132 }
2133 }
2134
2135 // Simply drop the annotation if there's only one module.
2136 if (moduleNames.size() < 2)
2137 continue;
2138
2139 // Replace all instances and references to other modules with the
2140 // first module.
2141 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
2142 nlaRemover.markNLAsInAnnotation(anno.getAttr());
2143 }
2144 if (newAnnotations.size() == annotations.size())
2145 return failure();
2146
2147 // Update circuit annotations.
2148 AnnotationSet newAnnoSet(newAnnotations, context);
2149 newAnnoSet.applyToOperation(circuitOp);
2150 return success();
2151 }
2152
2153 std::string getName() const override { return "firrtl-force-dedup"; }
2154 bool acceptSizeIncrease() const override { return true; }
2155
2156private:
2157 /// Replace all references to modules in the dedup group with the canonical
2158 /// module name
2159 void replaceModuleReferences(CircuitOp circuitOp,
2160 ArrayRef<StringAttr> moduleNames,
2161 NLATable &nlaTable,
2162 hw::InnerSymbolTableCollection &innerSymTables) {
2163 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
2164 auto &symbolTable = symbols.getSymbolTable(tableOp);
2165 auto &symbolUserMap = symbols.getSymbolUserMap(tableOp);
2166 auto *context = circuitOp->getContext();
2167 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
2168
2169 // Collect the modules.
2170 FModuleLike canonicalModule;
2171 SmallVector<FModuleLike> modulesToReplace;
2172 for (auto name : moduleNames) {
2173 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
2174 if (!canonicalModule)
2175 canonicalModule = mod;
2176 else
2177 modulesToReplace.push_back(mod);
2178 }
2179 }
2180 if (modulesToReplace.empty())
2181 return;
2182
2183 // Replace all instance references.
2184 auto canonicalName = canonicalModule.getModuleNameAttr();
2185 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
2186 for (auto moduleName : moduleNames) {
2187 if (moduleName == canonicalName)
2188 continue;
2189 auto *symbolOp = symbolTable.lookup(moduleName);
2190 if (!symbolOp)
2191 continue;
2192 for (auto *user : symbolUserMap.getUsers(symbolOp)) {
2193 auto instOp = dyn_cast<InstanceOp>(user);
2194 if (!instOp || instOp.getModuleNameAttr().getAttr() != moduleName)
2195 continue;
2196 instOp.setModuleNameAttr(canonicalRef);
2197 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
2198 }
2199 }
2200
2201 // Update NLAs to reference the canonical module instead of modules being
2202 // removed using NLATable for better performance.
2203 for (auto oldMod : modulesToReplace) {
2204 SmallVector<hw::HierPathOp> nlaOps(
2205 nlaTable.lookup(oldMod.getModuleNameAttr()));
2206 for (auto nlaOp : nlaOps) {
2207 nlaTable.erase(nlaOp);
2208 StringAttr oldModName = oldMod.getModuleNameAttr();
2209 StringAttr newModName = canonicalName;
2210 SmallVector<Attribute, 4> newPath;
2211 for (auto nameRef : nlaOp.getNamepath()) {
2212 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2213 if (ref.getModule() == oldModName) {
2214 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
2215 ref = hw::InnerRefAttr::get(newModName, ref.getName());
2216 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
2217 if (oldInst && newInst) {
2218 // Get the first module name from the list (for
2219 // InstanceOp/ObjectOp, there's only one)
2220 auto oldModNames = oldInst.getReferencedModuleNamesAttr();
2221 auto newModNames = newInst.getReferencedModuleNamesAttr();
2222 if (!oldModNames.empty() && !newModNames.empty()) {
2223 oldModName = cast<StringAttr>(oldModNames[0]);
2224 newModName = cast<StringAttr>(newModNames[0]);
2225 }
2226 }
2227 }
2228 newPath.push_back(ref);
2229 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
2230 newPath.push_back(FlatSymbolRefAttr::get(newModName));
2231 } else {
2232 newPath.push_back(nameRef);
2233 }
2234 }
2235 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
2236 nlaTable.addNLA(nlaOp);
2237 }
2238 }
2239
2240 // Mark NLAs in modules to be removed.
2241 for (auto module : modulesToReplace) {
2242 nlaRemover.markNLAsInOperation(module);
2243 modulesToErase.insert(module);
2244 }
2245 }
2246
2247 ::detail::SymbolCache symbols;
2248 NLARemover nlaRemover;
2249 SetVector<FModuleLike> modulesToErase;
2250 ModuleSizeCache moduleSizes;
2251};
2252
2253/// A reduction pattern that moves `MustDedup` annotations from a module onto
2254/// its child modules. This pattern iterates over all MustDedup annotations,
2255/// collects all `FInstanceLike` ops in each module of the dedup group, and
2256/// creates new MustDedup annotations for corresponding instances across the
2257/// modules. Each set of corresponding instances becomes a separate match of the
2258/// reduction. The reduction also removes the original MustDedup annotation on
2259/// the parent module.
2260///
2261/// The pattern works by:
2262/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
2263/// 2. For each dedup group, collecting all FInstanceLike operations in each
2264/// module
2265/// 3. Grouping corresponding instances across modules by their position/name
2266/// 4. Creating new MustDedup annotations for each group of corresponding
2267/// instances
2268/// 5. Removing the original MustDedup annotation from the circuit
2269struct MustDedupChildren : public OpReduction<CircuitOp> {
2270 void beforeReduction(mlir::ModuleOp op) override {
2271 symbols.clear();
2272 nlaRemover.clear();
2273 }
2274 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
2275
2276 /// Collect all MustDedup annotations and create matches for each instance
2277 /// group.
2278 void matches(CircuitOp circuitOp,
2279 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2280 auto annotations = AnnotationSet(circuitOp);
2281 uint64_t matchId = 0;
2282
2283 DenseSet<StringRef> modulesAlreadyInMustDedup;
2284 for (auto [annoIdx, anno] : llvm::enumerate(annotations))
2285 if (anno.isClass(mustDeduplicateAnnoClass))
2286 if (auto modulesAttr = anno.getMember<ArrayAttr>("modules"))
2287 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2288 if (auto target = tokenizePath(moduleRef))
2289 modulesAlreadyInMustDedup.insert(target->module);
2290
2291 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2292 if (!anno.isClass(mustDeduplicateAnnoClass))
2293 continue;
2294
2295 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2296 if (!modulesAttr || modulesAttr.size() < 2)
2297 continue;
2298
2299 // Process each group of corresponding instances
2300 processInstanceGroups(
2301 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2302 matchId++;
2303
2304 // Make sure there are at least two distinct modules.
2305 SmallDenseSet<StringAttr, 4> moduleTargets;
2306 for (auto instOp : instanceGroup) {
2307 auto moduleNames = instOp.getReferencedModuleNamesAttr();
2308 for (auto moduleName : moduleNames)
2309 moduleTargets.insert(cast<StringAttr>(moduleName));
2310 }
2311 if (moduleTargets.size() < 2)
2312 return;
2313
2314 // Make sure none of the modules are not yet in a must dedup
2315 // annotation.
2316 if (llvm::any_of(instanceGroup, [&](FInstanceLike inst) {
2317 auto moduleNames = inst.getReferencedModuleNames();
2318 return llvm::any_of(moduleNames, [&](StringRef moduleName) {
2319 return modulesAlreadyInMustDedup.contains(moduleName);
2320 });
2321 }))
2322 return;
2323
2324 addMatch(1, matchId - 1);
2325 });
2326 }
2327 }
2328
2329 LogicalResult rewriteMatches(CircuitOp circuitOp,
2330 ArrayRef<uint64_t> matches) override {
2331 auto *context = circuitOp->getContext();
2332 auto annotations = AnnotationSet(circuitOp);
2333 SmallVector<Annotation> newAnnotations;
2334 uint64_t matchId = 0;
2335
2336 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
2337 if (!anno.isClass(mustDeduplicateAnnoClass)) {
2338 newAnnotations.push_back(anno);
2339 continue;
2340 }
2341
2342 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
2343 if (!modulesAttr || modulesAttr.size() < 2) {
2344 newAnnotations.push_back(anno);
2345 continue;
2346 }
2347
2348 processInstanceGroups(
2349 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
2350 // Check if this instance group was selected
2351 if (!llvm::is_contained(matches, matchId++))
2352 return;
2353
2354 // Create the list of modules to put into this new annotation.
2355 SmallSetVector<StringAttr, 4> moduleTargets;
2356 for (auto instOp : instanceGroup) {
2357 auto moduleNames = instOp.getReferencedModuleNames();
2358 for (auto moduleName : moduleNames) {
2359 auto target = TokenAnnoTarget();
2360 target.circuit = circuitOp.getName();
2361 target.module = moduleName;
2362 moduleTargets.insert(target.toStringAttr(context));
2363 }
2364 }
2365
2366 // Create a new MustDedup annotation for this list of modules.
2367 SmallVector<NamedAttribute> newAnnoAttrs;
2368 newAnnoAttrs.emplace_back(
2369 StringAttr::get(context, "class"),
2370 StringAttr::get(context, mustDeduplicateAnnoClass));
2371 newAnnoAttrs.emplace_back(
2372 StringAttr::get(context, "modules"),
2373 ArrayAttr::get(context,
2374 SmallVector<Attribute>(moduleTargets.begin(),
2375 moduleTargets.end())));
2376
2377 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
2378 newAnnotations.emplace_back(newAnnoDict);
2379 });
2380
2381 // Keep the original annotation around.
2382 newAnnotations.push_back(anno);
2383 }
2384
2385 // Update circuit annotations
2386 AnnotationSet newAnnoSet(newAnnotations, context);
2387 newAnnoSet.applyToOperation(circuitOp);
2388 return success();
2389 }
2390
2391 std::string getName() const override { return "must-dedup-children"; }
2392 bool acceptSizeIncrease() const override { return true; }
2393
2394private:
2395 /// Helper function to process groups of corresponding instances from a
2396 /// MustDedup annotation. Calls the provided lambda for each group of
2397 /// corresponding instances across the modules. Only calls the lambda if there
2398 /// are at least 2 modules.
2399 void processInstanceGroups(
2400 CircuitOp circuitOp, ArrayAttr modulesAttr,
2401 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
2402 auto &symbolTable = symbols.getSymbolTable(circuitOp);
2403
2404 // Extract module names and get the actual modules
2405 SmallVector<FModuleLike> modules;
2406 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
2407 if (auto target = tokenizePath(moduleRef))
2408 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
2409 modules.push_back(mod);
2410
2411 // Need at least 2 modules for deduplication
2412 if (modules.size() < 2)
2413 return;
2414
2415 // Collect all FInstanceLike operations from each module and group them by
2416 // name. Instance names are a good key for matching instances across
2417 // modules. But they may not be unique, so we need to be careful to only
2418 // match up instances that are uniquely named within every module.
2419 struct InstanceGroup {
2420 SmallVector<FInstanceLike> instances;
2421 bool nameIsUnique = true;
2422 };
2423 MapVector<StringAttr, InstanceGroup> instanceGroups;
2424 for (auto module : modules) {
2426 module.walk([&](FInstanceLike instOp) {
2427 if (isa<ObjectOp>(instOp.getOperation()))
2428 return;
2429 auto name = instOp.getInstanceNameAttr();
2430 auto &group = instanceGroups[name];
2431 if (nameCounts[name]++ > 1)
2432 group.nameIsUnique = false;
2433 group.instances.push_back(instOp);
2434 });
2435 }
2436
2437 // Call the callback for each group of instances that are uniquely named and
2438 // consist of at least 2 instances.
2439 for (auto &[name, group] : instanceGroups)
2440 if (group.nameIsUnique && group.instances.size() >= 2)
2441 callback(group.instances);
2442 }
2443
2444 ::detail::SymbolCache symbols;
2445 NLARemover nlaRemover;
2446};
2447
2448struct LayerDisable : public OpReduction<CircuitOp> {
2449 LayerDisable(MLIRContext *context) {
2450 pm = std::make_unique<mlir::PassManager>(
2451 context, "builtin.module", mlir::OpPassManager::Nesting::Explicit);
2452 pm->nest<firrtl::CircuitOp>().addPass(firrtl::createSpecializeLayers());
2453 };
2454
2455 void beforeReduction(mlir::ModuleOp op) override { symbolRefAttrMap.clear(); }
2456
2457 void afterReduction(mlir::ModuleOp op) override { (void)pm->run(op); };
2458
2459 void matches(CircuitOp circuitOp,
2460 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2461 uint64_t matchId = 0;
2462
2463 SmallVector<FlatSymbolRefAttr> nestedRefs;
2464 std::function<void(StringAttr, LayerOp)> addLayer = [&](StringAttr rootRef,
2465 LayerOp layerOp) {
2466 if (!rootRef)
2467 rootRef = layerOp.getSymNameAttr();
2468 else
2469 nestedRefs.push_back(FlatSymbolRefAttr::get(layerOp));
2470
2471 symbolRefAttrMap[matchId] = SymbolRefAttr::get(rootRef, nestedRefs);
2472 addMatch(1, matchId++);
2473
2474 for (auto nestedLayerOp : layerOp.getOps<LayerOp>())
2475 addLayer(rootRef, nestedLayerOp);
2476
2477 if (!nestedRefs.empty())
2478 nestedRefs.pop_back();
2479 };
2480
2481 for (auto layerOp : circuitOp.getOps<LayerOp>())
2482 addLayer({}, layerOp);
2483 }
2484
2485 LogicalResult rewriteMatches(CircuitOp circuitOp,
2486 ArrayRef<uint64_t> matches) override {
2487 SmallVector<Attribute> disableLayers;
2488 if (auto existingDisables = circuitOp.getDisableLayersAttr()) {
2489 auto disableRange = existingDisables.getAsRange<Attribute>();
2490 disableLayers.append(disableRange.begin(), disableRange.end());
2491 }
2492 for (auto match : matches)
2493 disableLayers.push_back(symbolRefAttrMap.at(match));
2494
2495 circuitOp.setDisableLayersAttr(
2496 ArrayAttr::get(circuitOp.getContext(), disableLayers));
2497
2498 return success();
2499 }
2500
2501 std::string getName() const override { return "firrtl-layer-disable"; }
2502
2503 std::unique_ptr<mlir::PassManager> pm;
2504 DenseMap<uint64_t, SymbolRefAttr> symbolRefAttrMap;
2505};
2506
2507} // namespace
2508
2509/// A reduction pattern that removes elements from FIRRTL list create
2510/// operations. This generates one match per element in each list, allowing
2511/// selective removal of individual elements.
2512struct ListCreateElementRemover : public OpReduction<ListCreateOp> {
2513 void matches(ListCreateOp listOp,
2514 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
2515 // Create one match for each element in the list
2516 auto elements = listOp.getElements();
2517 for (size_t i = 0; i < elements.size(); ++i)
2518 addMatch(1, i);
2519 }
2520
2521 LogicalResult rewriteMatches(ListCreateOp listOp,
2522 ArrayRef<uint64_t> matches) override {
2523 // Convert matches to a set for fast lookup
2524 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
2525
2526 // Collect elements that should be kept (not in matches)
2527 SmallVector<Value> newElements;
2528 auto elements = listOp.getElements();
2529 for (size_t i = 0; i < elements.size(); ++i) {
2530 if (!matchesSet.contains(i))
2531 newElements.push_back(elements[i]);
2532 }
2533
2534 // Create a new list with the remaining elements
2535 OpBuilder builder(listOp);
2536 auto newListOp = ListCreateOp::create(builder, listOp.getLoc(),
2537 listOp.getType(), newElements);
2538 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
2539 listOp.erase();
2540
2541 return success();
2542 }
2543
2544 std::string getName() const override {
2545 return "firrtl-list-create-element-remover";
2546 }
2547};
2548
2549//===----------------------------------------------------------------------===//
2550// Reduction Registration
2551//===----------------------------------------------------------------------===//
2552
// Populate `patterns` with the FIRRTL reductions, each paired with a benefit.
2555 // Gather a list of reduction patterns that we should try. Ideally these are
2556 // assigned reasonable benefit indicators (higher benefit patterns are
2557 // prioritized). For example, things that can knock out entire modules while
2558 // being cheap should be tried first (and thus have higher benefit), before
2559 // trying to tweak operands of individual arithmetic ops.
2560 patterns.add<SimplifyResets, 35>();
2561 patterns.add<ForceDedup, 34>();
2562 patterns.add<MustDedupChildren, 33>();
2563 patterns.add<AnnotationRemover, 32>();
2564 patterns.add<ModuleSwapper, 31>();
2565 patterns.add<LayerDisable, 30>(getContext());
// Existing FIRRTL passes wrapped as reductions. NOTE(review): the trailing
// boolean arguments of PassReduction are not explained here; confirm their
// meaning in PassReduction.
2566 patterns.add<PassReduction, 29>(
2567 getContext(),
2568 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
2569 true);
2570 patterns.add<PassReduction, 28>(getContext(),
2571 firrtl::createLowerCHIRRTLPass(), true, true);
2572 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
2573 true, true);
2574 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
2575 true, true);
2576 patterns.add<FIRRTLModuleExternalizer, 25>();
2577 patterns.add<InstanceStubber, 24>();
2578 patterns.add<MemoryStubber, 23>();
// NOTE(review): EagerInliner and ObjectInliner share benefit 22 (as do other
// pairs below); their relative order presumably depends on ReducePatternSet.
2579 patterns.add<EagerInliner, 22>();
2580 patterns.add<ObjectInliner, 22>();
2581 patterns.add<PassReduction, 21>(getContext(),
2582 firrtl::createLowerFIRRTLTypes(), true, true);
2583 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
2584 true, true);
2585 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
2586 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
2587 patterns.add<PassReduction, 17>(
2588 getContext(),
2589 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
// Fine-grained, per-operation reductions.
2590 patterns.add<NodeSymbolRemover, 15>();
2591 patterns.add<ConnectForwarder, 14>();
2592 patterns.add<ConnectInvalidator, 13>();
2593 patterns.add<Constantifier, 12>();
2594 patterns.add<FIRRTLOperandForwarder<0>, 11>();
2595 patterns.add<FIRRTLOperandForwarder<1>, 10>();
2596 patterns.add<FIRRTLOperandForwarder<2>, 9>();
2598 patterns.add<DetachSubaccesses, 7>();
2599 patterns.add<ModulePortPruner, 7>();
2600 patterns.add<ExtmodulePortPruner, 6>();
2601 patterns.add<RootPortPruner, 5>();
2602 patterns.add<RootExtmodulePortPruner, 5>();
2603 patterns.add<ExtmoduleInstanceRemover, 4>();
2604 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
2605 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
2606 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
// Benefit 0: the one-shot name sanitizer is tried last.
2608 patterns.add<ModuleNameSanitizer, 0>();
2609}
2610
2612 mlir::DialectRegistry &registry) {
// Attach the FIRRTL reduce-pattern interface to FIRRTLDialect through a
// registry extension, so contexts using this registry get the interface once
// the dialect is loaded.
2613 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
2614 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
2615 });
2616}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalid.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
Attribute getAttr() const
Get the underlying attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
Helper class to cache tie-off values for different FIRRTL types.
Definition FIRRTLUtils.h:62
Value getInvalid(FIRRTLBaseType type)
Get or create an InvalidValueOp for the given base type.
Value getUnknown(PropertyType type)
Get or create an UnknownValueOp for the given property type.
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
InnerSymbolTable & getInnerSymbolTable(Operation *op)
Get or create the InnerSymbolTable for the specified operation.
static RetTy walkSymbols(Operation *op, FuncTy &&callback)
Walk the given IST operation and invoke the callback for all encountered inner symbols.
A utility class that generates metasyntactic variable names for use in reductions.
void reset()
Reset the generator to start from the beginning of the sequence.
const char * getNextName()
Get the next metasyntactic name in the sequence.
connect(destination, source)
Definition support.py:39
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:43
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern that removes elements from FIRRTL list create operations.
LogicalResult rewriteMatches(ListCreateOp listOp, ArrayRef< uint64_t > matches) override
void matches(ListCreateOp listOp, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
std::string getName() const override
Return a human-readable name for this reduction pattern.
Pseudo-reduction that sanitizes the names of operations inside modules.
Pseudo-reduction that sanitizes module and port names.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions t...
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)