CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLReductions.cpp
Go to the documentation of this file.
1//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
24#include "mlir/IR/ImplicitLocOpBuilder.h"
25#include "llvm/ADT/APSInt.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/Support/Debug.h"
28
29#define DEBUG_TYPE "firrtl-reductions"
30
31using namespace mlir;
32using namespace circt;
33using namespace firrtl;
34using llvm::MapVector;
35using llvm::SmallSetVector;
36
37//===----------------------------------------------------------------------===//
38// Utilities
39//===----------------------------------------------------------------------===//
40
41namespace detail {
42/// A utility doing lazy construction of `SymbolTable`s and `SymbolUserMap`s,
43/// which is handy for reductions that need to look up a lot of symbols.
45 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
46
47 SymbolTable &getSymbolTable(Operation *op) {
48 return tables->getSymbolTable(op);
49 }
50 SymbolTable &getNearestSymbolTable(Operation *op) {
51 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
52 }
53
54 SymbolUserMap &getSymbolUserMap(Operation *op) {
55 auto it = userMaps.find(op);
56 if (it != userMaps.end())
57 return it->second;
58 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
59 }
60 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
61 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
62 }
63
64 void clear() {
65 tables = std::make_unique<SymbolTableCollection>();
66 userMaps.clear();
67 }
68
69private:
70 std::unique_ptr<SymbolTableCollection> tables;
72};
73} // namespace detail
74
75/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
76/// optional in case another type of module is instantiated.
77static std::optional<firrtl::FModuleOp>
78findInstantiatedModule(firrtl::InstanceOp instOp,
79 ::detail::SymbolCache &symbols) {
80 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
81 auto moduleOp = dyn_cast<firrtl::FModuleOp>(
82 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
83 return moduleOp ? std::optional(moduleOp) : std::nullopt;
84}
85
86/// Utility to track the transitive size of modules.
88 void clear() { moduleSizes.clear(); }
89
90 uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols) {
91 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
92 return it->second;
93 uint64_t size = 1;
94 module->walk([&](Operation *op) {
95 size += 1;
96 if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
97 if (auto instModule = findInstantiatedModule(instOp, symbols))
98 size += getModuleSize(*instModule, symbols);
99 });
100 moduleSizes.insert({module, size});
101 return size;
102 }
103
104private:
105 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
106};
107
108/// Check that all connections to a value are invalids.
109static bool onlyInvalidated(Value arg) {
110 return llvm::all_of(arg.getUses(), [](OpOperand &use) {
111 auto *op = use.getOwner();
112 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
113 return false;
114 if (use.getOperandNumber() != 0)
115 return false;
116 if (!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>())
117 return false;
118 return true;
119 });
120}
121
122/// A tracker for track NLAs affected by a reduction. Performs the necessary
123/// cleanup steps in order to maintain IR validity after the reduction has
124/// applied. For example, removing an instance that forms part of an NLA path
125/// requires that NLA to be removed as well.
127 /// Clear the set of marked NLAs. Call this before attempting a reduction.
128 void clear() { nlasToRemove.clear(); }
129
130 /// Remove all marked annotations. Call this after applying a reduction in
131 /// order to validate the IR.
132 void remove(mlir::ModuleOp module) {
133 unsigned numRemoved = 0;
134 (void)numRemoved;
135 SymbolTableCollection symbolTables;
136 for (Operation &rootOp : *module.getBody()) {
137 if (!isa<firrtl::CircuitOp>(&rootOp))
138 continue;
139 SymbolUserMap symbolUserMap(symbolTables, &rootOp);
140 auto &symbolTable = symbolTables.getSymbolTable(&rootOp);
141 for (auto sym : nlasToRemove) {
142 if (auto *op = symbolTable.lookup(sym)) {
143 if (symbolUserMap.useEmpty(op)) {
144 ++numRemoved;
145 op->erase();
146 }
147 }
148 }
149 }
150 LLVM_DEBUG({
151 unsigned numLost = nlasToRemove.size() - numRemoved;
152 if (numRemoved > 0 || numLost > 0) {
153 llvm::dbgs() << "Removed " << numRemoved << " NLAs";
154 if (numLost > 0)
155 llvm::dbgs() << " (" << numLost << " no longer there)";
156 llvm::dbgs() << "\n";
157 }
158 });
159 }
160
161 /// Mark all NLAs referenced in the given annotation as to be removed. This
162 /// can be an entire array or dictionary of annotations, and the function will
163 /// descend into child annotations appropriately.
164 void markNLAsInAnnotation(Attribute anno) {
165 if (auto dict = dyn_cast<DictionaryAttr>(anno)) {
166 if (auto field = dict.getAs<FlatSymbolRefAttr>("circt.nonlocal"))
167 nlasToRemove.insert(field.getAttr());
168 for (auto namedAttr : dict)
169 markNLAsInAnnotation(namedAttr.getValue());
170 } else if (auto array = dyn_cast<ArrayAttr>(anno)) {
171 for (auto attr : array)
172 markNLAsInAnnotation(attr);
173 }
174 }
175
176 /// Mark all NLAs referenced in an operation. Also traverses all nested
177 /// operations. Call this before removing an operation, to mark any associated
178 /// NLAs as to be removed as well.
179 void markNLAsInOperation(Operation *op) {
180 op->walk([&](Operation *op) {
181 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
182 markNLAsInAnnotation(annos);
183 });
184 }
185
186 /// The set of NLAs to remove, identified by their symbol.
187 llvm::DenseSet<StringAttr> nlasToRemove;
188};
189
190//===----------------------------------------------------------------------===//
191// Reduction patterns
192//===----------------------------------------------------------------------===//
193
194namespace {
195
196/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
197struct FIRRTLModuleExternalizer : public OpReduction<FModuleOp> {
198 void beforeReduction(mlir::ModuleOp op) override {
199 nlaRemover.clear();
200 symbols.clear();
201 moduleSizes.clear();
202 innerSymUses = reduce::InnerSymbolUses(op);
203 }
204 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
205
206 uint64_t match(FModuleOp module) override {
207 if (innerSymUses.hasInnerRef(module))
208 return 0;
209 return moduleSizes.getModuleSize(module, symbols);
210 }
211
212 LogicalResult rewrite(FModuleOp module) override {
213 // Hack up a list of known layers.
214 LayerSet layers;
215 layers.insert_range(module.getLayersAttr().getAsRange<SymbolRefAttr>());
216 for (auto attr : module.getPortTypes()) {
217 auto type = cast<TypeAttr>(attr).getValue();
218 if (auto refType = type_dyn_cast<RefType>(type))
219 if (auto layer = refType.getLayer())
220 layers.insert(layer);
221 }
222 SmallVector<Attribute, 4> layersArray;
223 layersArray.reserve(layers.size());
224 for (auto layer : layers)
225 layersArray.push_back(layer);
226
227 nlaRemover.markNLAsInOperation(module);
228 OpBuilder builder(module);
229 auto extmodule = FExtModuleOp::create(
230 builder, module->getLoc(),
231 module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()),
232 module.getConventionAttr(), module.getPorts(),
233 builder.getArrayAttr(layersArray), StringRef(),
234 module.getAnnotationsAttr());
235 SymbolTable::setSymbolVisibility(extmodule,
236 SymbolTable::getSymbolVisibility(module));
237 module->erase();
238 return success();
239 }
240
241 std::string getName() const override { return "firrtl-module-externalizer"; }
242
243 ::detail::SymbolCache symbols;
244 NLARemover nlaRemover;
245 reduce::InnerSymbolUses innerSymUses;
246 ModuleSizeCache moduleSizes;
247};
248
249/// Invalidate all the leaf fields of a value with a given flippedness by
250/// connecting an invalid value to them. This is useful for ensuring that all
251/// output ports of an instance or memory (including those nested in bundles)
252/// are properly invalidated.
253static void invalidateOutputs(ImplicitLocOpBuilder &builder, Value value,
254 SmallDenseMap<Type, Value, 8> &invalidCache,
255 bool flip = false) {
256 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
257 if (!type)
258 return;
259
260 // Descend into bundles by creating subfield ops.
261 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
262 for (auto element : llvm::enumerate(bundleType.getElements())) {
263 auto subfield =
264 builder.createOrFold<firrtl::SubfieldOp>(value, element.index());
265 invalidateOutputs(builder, subfield, invalidCache,
266 flip ^ element.value().isFlip);
267 if (subfield.use_empty())
268 subfield.getDefiningOp()->erase();
269 }
270 return;
271 }
272
273 // Descend into vectors by creating subindex ops.
274 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
275 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i) {
276 auto subindex = builder.createOrFold<firrtl::SubindexOp>(value, i);
277 invalidateOutputs(builder, subindex, invalidCache, flip);
278 if (subindex.use_empty())
279 subindex.getDefiningOp()->erase();
280 }
281 return;
282 }
283
284 // Only drive outputs.
285 if (flip)
286 return;
287 Value invalid = invalidCache.lookup(type);
288 if (!invalid) {
289 invalid = firrtl::InvalidValueOp::create(builder, type);
290 invalidCache.insert({type, invalid});
291 }
292 firrtl::ConnectOp::create(builder, value, invalid);
293}
294
295/// Connect a value to every leave of a destination value.
296static void connectToLeafs(ImplicitLocOpBuilder &builder, Value dest,
297 Value value) {
298 auto type = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
299 if (!type)
300 return;
301 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
302 for (auto element : llvm::enumerate(bundleType.getElements()))
303 connectToLeafs(builder,
304 firrtl::SubfieldOp::create(builder, dest, element.index()),
305 value);
306 return;
307 }
308 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
309 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
310 connectToLeafs(builder, firrtl::SubindexOp::create(builder, dest, i),
311 value);
312 return;
313 }
314 auto valueType = dyn_cast<firrtl::FIRRTLBaseType>(value.getType());
315 if (!valueType)
316 return;
317 auto destWidth = type.getBitWidthOrSentinel();
318 auto valueWidth = valueType ? valueType.getBitWidthOrSentinel() : -1;
319 if (destWidth >= 0 && valueWidth >= 0 && destWidth < valueWidth)
320 value = firrtl::HeadPrimOp::create(builder, value, destWidth);
321 if (!isa<firrtl::UIntType>(type)) {
322 if (isa<firrtl::SIntType>(type))
323 value = firrtl::AsSIntPrimOp::create(builder, value);
324 else
325 return;
326 }
327 firrtl::ConnectOp::create(builder, dest, value);
328}
329
330/// Reduce all leaf fields of a value through an XOR tree.
331static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
332 auto type = dyn_cast<firrtl::FIRRTLType>(value.getType());
333 if (!type)
334 return;
335 if (auto bundleType = dyn_cast<firrtl::BundleType>(type)) {
336 for (auto element : llvm::enumerate(bundleType.getElements()))
337 reduceXor(
338 builder, into,
339 builder.createOrFold<firrtl::SubfieldOp>(value, element.index()));
340 return;
341 }
342 if (auto vectorType = dyn_cast<firrtl::FVectorType>(type)) {
343 for (unsigned i = 0, e = vectorType.getNumElements(); i != e; ++i)
344 reduceXor(builder, into,
345 builder.createOrFold<firrtl::SubindexOp>(value, i));
346 return;
347 }
348 if (!isa<firrtl::UIntType>(type)) {
349 if (isa<firrtl::SIntType>(type))
350 value = firrtl::AsUIntPrimOp::create(builder, value);
351 else
352 return;
353 }
354 into = into ? builder.createOrFold<firrtl::XorPrimOp>(into, value) : value;
355}
356
357/// A sample reduction pattern that maps `firrtl.instance` to a set of
358/// invalidated wires. This often shortcuts a long iterative process of connect
359/// invalidation, module externalization, and wire stripping
360struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
361 void beforeReduction(mlir::ModuleOp op) override {
362 erasedInsts.clear();
363 erasedModules.clear();
364 symbols.clear();
365 nlaRemover.clear();
366 moduleSizes.clear();
367 }
368 void afterReduction(mlir::ModuleOp op) override {
369 // Look into deleted modules to find additional instances that are no longer
370 // instantiated anywhere.
371 SmallVector<Operation *> worklist;
372 auto deadInsts = erasedInsts;
373 for (auto *op : erasedModules)
374 worklist.push_back(op);
375 while (!worklist.empty()) {
376 auto *op = worklist.pop_back_val();
377 auto *tableOp = SymbolTable::getNearestSymbolTable(op);
378 op->walk([&](firrtl::InstanceOp instOp) {
379 auto moduleOp = cast<firrtl::FModuleLike>(
380 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
381 deadInsts.insert(instOp);
382 if (llvm::all_of(
383 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
384 [&](Operation *user) { return deadInsts.contains(user); })) {
385 LLVM_DEBUG(llvm::dbgs() << "- Removing transitively unused module `"
386 << moduleOp.getModuleName() << "`\n");
387 erasedModules.insert(moduleOp);
388 worklist.push_back(moduleOp);
389 }
390 });
391 }
392
393 for (auto *op : erasedInsts)
394 op->erase();
395 for (auto *op : erasedModules)
396 op->erase();
397 nlaRemover.remove(op);
398 }
399
400 uint64_t match(firrtl::InstanceOp instOp) override {
401 if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
402 return moduleSizes.getModuleSize(*fmoduleOp, symbols);
403 return 0;
404 }
405
406 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
407 LLVM_DEBUG(llvm::dbgs()
408 << "Stubbing instance `" << instOp.getName() << "`\n");
409 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
411 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
412 auto result = instOp.getResult(i);
413 auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" +
414 instOp.getPortNameStr(i));
415 auto wire =
416 firrtl::WireOp::create(builder, result.getType(), name,
417 firrtl::NameKindEnum::DroppableName,
418 instOp.getPortAnnotation(i), StringAttr{})
419 .getResult();
420 invalidateOutputs(builder, wire, invalidCache,
421 instOp.getPortDirection(i) == firrtl::Direction::In);
422 result.replaceAllUsesWith(wire);
423 }
424 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
425 auto moduleOp = cast<firrtl::FModuleLike>(
426 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
427 nlaRemover.markNLAsInOperation(instOp);
428 erasedInsts.insert(instOp);
429 if (llvm::all_of(
430 symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
431 [&](Operation *user) { return erasedInsts.contains(user); })) {
432 LLVM_DEBUG(llvm::dbgs() << "- Removing now unused module `"
433 << moduleOp.getModuleName() << "`\n");
434 erasedModules.insert(moduleOp);
435 }
436 return success();
437 }
438
439 std::string getName() const override { return "instance-stubber"; }
440 bool acceptSizeIncrease() const override { return true; }
441
442 ::detail::SymbolCache symbols;
443 NLARemover nlaRemover;
444 llvm::DenseSet<Operation *> erasedInsts;
445 llvm::DenseSet<Operation *> erasedModules;
446 ModuleSizeCache moduleSizes;
447};
448
449/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
450/// wires.
451struct MemoryStubber : public OpReduction<firrtl::MemOp> {
452 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
453 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
454 LogicalResult rewrite(firrtl::MemOp memOp) override {
455 LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
456 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
458 Value xorInputs;
459 SmallVector<Value> outputs;
460 for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) {
461 auto result = memOp.getResult(i);
462 auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" +
463 memOp.getPortNameStr(i));
464 auto wire =
465 firrtl::WireOp::create(builder, result.getType(), name,
466 firrtl::NameKindEnum::DroppableName,
467 memOp.getPortAnnotation(i), StringAttr{})
468 .getResult();
469 invalidateOutputs(builder, wire, invalidCache, true);
470 result.replaceAllUsesWith(wire);
471
472 // Isolate the input and output data fields of the port.
473 Value input, output;
474 switch (memOp.getPortKind(i)) {
475 case firrtl::MemOp::PortKind::Read:
476 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
477 break;
478 case firrtl::MemOp::PortKind::Write:
479 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
480 break;
481 case firrtl::MemOp::PortKind::ReadWrite:
482 input = builder.createOrFold<firrtl::SubfieldOp>(wire, 5);
483 output = builder.createOrFold<firrtl::SubfieldOp>(wire, 3);
484 break;
485 case firrtl::MemOp::PortKind::Debug:
486 output = wire;
487 break;
488 }
489
490 if (!isa<firrtl::RefType>(result.getType())) {
491 // Reduce all input ports to a single one through an XOR tree.
492 unsigned numFields =
493 cast<firrtl::BundleType>(wire.getType()).getNumElements();
494 for (unsigned i = 0; i != numFields; ++i) {
495 if (i != 2 && i != 3 && i != 5)
496 reduceXor(builder, xorInputs,
497 builder.createOrFold<firrtl::SubfieldOp>(wire, i));
498 }
499 if (input)
500 reduceXor(builder, xorInputs, input);
501 }
502
503 // Track the output port to hook it up to the XORd input later.
504 if (output)
505 outputs.push_back(output);
506 }
507
508 // Hook up the outputs.
509 for (auto output : outputs)
510 connectToLeafs(builder, output, xorInputs);
511
512 nlaRemover.markNLAsInOperation(memOp);
513 memOp->erase();
514 return success();
515 }
516 std::string getName() const override { return "memory-stubber"; }
517 bool acceptSizeIncrease() const override { return true; }
518 NLARemover nlaRemover;
519};
520
521/// Check whether an operation interacts with flows in any way, which can make
522/// replacement and operand forwarding harder in some cases.
523static bool isFlowSensitiveOp(Operation *op) {
524 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
525 firrtl::InstanceOp, firrtl::SubfieldOp, firrtl::SubindexOp,
526 firrtl::SubaccessOp>(op);
527}
528
529/// A sample reduction pattern that replaces all uses of an operation with one
530/// of its operands. This can help pruning large parts of the expression tree
531/// rapidly.
532template <unsigned OpNum>
533struct FIRRTLOperandForwarder : public Reduction {
534 uint64_t match(Operation *op) override {
535 if (op->getNumResults() != 1 || OpNum >= op->getNumOperands())
536 return 0;
537 if (isFlowSensitiveOp(op))
538 return 0;
539 auto resultTy =
540 dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
541 auto opTy =
542 dyn_cast<firrtl::FIRRTLBaseType>(op->getOperand(OpNum).getType());
543 return resultTy && opTy &&
544 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
545 (resultTy.getBitWidthOrSentinel() == -1) ==
546 (opTy.getBitWidthOrSentinel() == -1) &&
547 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
548 }
549 LogicalResult rewrite(Operation *op) override {
550 assert(match(op));
551 ImplicitLocOpBuilder builder(op->getLoc(), op);
552 auto result = op->getResult(0);
553 auto operand = op->getOperand(OpNum);
554 auto resultTy = cast<firrtl::FIRRTLBaseType>(result.getType());
555 auto operandTy = cast<firrtl::FIRRTLBaseType>(operand.getType());
556 auto resultWidth = resultTy.getBitWidthOrSentinel();
557 auto operandWidth = operandTy.getBitWidthOrSentinel();
558 Value newOp;
559 if (resultWidth < operandWidth)
560 newOp =
561 builder.createOrFold<firrtl::BitsPrimOp>(operand, resultWidth - 1, 0);
562 else if (resultWidth > operandWidth)
563 newOp = builder.createOrFold<firrtl::PadPrimOp>(operand, resultWidth);
564 else
565 newOp = operand;
566 LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
567 result.replaceAllUsesWith(newOp);
568 reduce::pruneUnusedOps(op, *this);
569 return success();
570 }
571 std::string getName() const override {
572 return ("firrtl-operand" + Twine(OpNum) + "-forwarder").str();
573 }
574};
575
576/// A sample reduction pattern that replaces FIRRTL operations with a constant
577/// zero of their type.
578struct FIRRTLConstantifier : public Reduction {
579 uint64_t match(Operation *op) override {
580 if (op->hasTrait<OpTrait::ConstantLike>())
581 return 0;
582 if (op->getNumResults() != 1 || op->getNumOperands() == 0)
583 return 0;
584 if (op->hasAttr("inner_sym"))
585 return 0;
586 if (isFlowSensitiveOp(op))
587 return 0;
588 auto type = dyn_cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
589 return isa_and_nonnull<firrtl::UIntType, firrtl::SIntType>(type);
590 }
591 LogicalResult rewrite(Operation *op) override {
592 assert(match(op));
593 OpBuilder builder(op);
594 auto type = cast<firrtl::FIRRTLBaseType>(op->getResult(0).getType());
595 auto width = type.getBitWidthOrSentinel();
596 if (width == -1)
597 width = 64;
598 auto newOp =
599 firrtl::ConstantOp::create(builder, op->getLoc(), type,
600 APSInt(width, isa<firrtl::UIntType>(type)));
601 op->replaceAllUsesWith(newOp);
602 reduce::pruneUnusedOps(op, *this);
603 return success();
604 }
605 std::string getName() const override { return "firrtl-constantifier"; }
606};
607
608/// A sample reduction pattern that replaces the right-hand-side of
609/// `firrtl.connect` and `firrtl.matchingconnect` operations with a
610/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
611/// connects and creates opportunities for reduction in DCE/CSE.
612struct ConnectInvalidator : public Reduction {
613 uint64_t match(Operation *op) override {
614 if (!isa<FConnectLike>(op))
615 return 0;
616 if (auto *srcOp = op->getOperand(1).getDefiningOp())
617 if (srcOp->hasTrait<OpTrait::ConstantLike>() ||
618 isa<InvalidValueOp>(srcOp))
619 return 0;
620 auto type = dyn_cast<FIRRTLBaseType>(op->getOperand(1).getType());
621 return type && type.isPassive();
622 }
623 LogicalResult rewrite(Operation *op) override {
624 assert(match(op));
625 auto rhs = op->getOperand(1);
626 OpBuilder builder(op);
627 auto invOp = InvalidValueOp::create(builder, rhs.getLoc(), rhs.getType());
628 auto *rhsOp = rhs.getDefiningOp();
629 op->setOperand(1, invOp);
630 if (rhsOp)
631 reduce::pruneUnusedOps(rhsOp, *this);
632 return success();
633 }
634 std::string getName() const override { return "connect-invalidator"; }
635 bool acceptSizeIncrease() const override { return true; }
636};
637
638/// A reduction pattern that removes FIRRTL annotations from ports and
639/// operations. This generates one match per annotation and port annotation,
640/// allowing selective removal of individual annotations.
641struct AnnotationRemover : public Reduction {
642 void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
643 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
644
645 void matches(Operation *op,
646 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
647 uint64_t matchId = 0;
648
649 // Generate matches for regular annotations
650 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
651 for (unsigned i = 0; i < annos.size(); ++i)
652 addMatch(1, matchId++);
653
654 // Generate matches for port annotations
655 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
656 for (auto portAnnoArray : portAnnos)
657 if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray))
658 for (unsigned i = 0; i < portAnnoArrayAttr.size(); ++i)
659 addMatch(1, matchId++);
660 }
661
662 LogicalResult rewriteMatches(Operation *op,
663 ArrayRef<uint64_t> matches) override {
664 // Convert matches to a set for fast lookup
665 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
666
667 // Lambda to process annotations and filter out matched ones
668 uint64_t matchId = 0;
669 auto processAnnotations =
670 [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
671 SmallVector<Attribute> newAnnotations;
672 for (auto anno : annotations) {
673 if (!matchesSet.contains(matchId)) {
674 newAnnotations.push_back(anno);
675 } else {
676 // Mark NLAs in the removed annotation for cleanup
677 nlaRemover.markNLAsInAnnotation(anno);
678 }
679 matchId++;
680 }
681 return ArrayAttr::get(op->getContext(), newAnnotations);
682 };
683
684 // Remove regular annotations
685 if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
686 op->setAttr("annotations", processAnnotations(annos.getValue()));
687 }
688
689 // Remove port annotations
690 if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
691 SmallVector<Attribute> newPortAnnos;
692 for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
693 newPortAnnos.push_back(
694 processAnnotations(portAnnoArrayAttr.getValue()));
695 }
696 op->setAttr("portAnnotations",
697 ArrayAttr::get(op->getContext(), newPortAnnos));
698 }
699
700 return success();
701 }
702
703 std::string getName() const override { return "annotation-remover"; }
704 NLARemover nlaRemover;
705};
706
707/// A sample reduction pattern that removes ports from the root `firrtl.module`
708/// if the port is not used or just invalidated.
709struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
710 uint64_t match(firrtl::FModuleOp module) override {
711 auto circuit = module->getParentOfType<firrtl::CircuitOp>();
712 if (!circuit)
713 return 0;
714 return circuit.getNameAttr() == module.getNameAttr();
715 }
716 LogicalResult rewrite(firrtl::FModuleOp module) override {
717 assert(match(module));
718 size_t numPorts = module.getNumPorts();
719 llvm::BitVector dropPorts(numPorts);
720 for (unsigned i = 0; i != numPorts; ++i) {
721 if (onlyInvalidated(module.getArgument(i))) {
722 dropPorts.set(i);
723 for (auto *user :
724 llvm::make_early_inc_range(module.getArgument(i).getUsers()))
725 user->erase();
726 }
727 }
728 module.erasePorts(dropPorts);
729 return success();
730 }
731 std::string getName() const override { return "root-port-pruner"; }
732};
733
734/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
735/// with wires.
736struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
737 void beforeReduction(mlir::ModuleOp op) override {
738 symbols.clear();
739 nlaRemover.clear();
740 }
741 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
742
743 uint64_t match(firrtl::InstanceOp instOp) override {
744 return isa<firrtl::FExtModuleOp>(
745 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
746 }
747 LogicalResult rewrite(firrtl::InstanceOp instOp) override {
748 auto portInfo =
749 cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
750 symbols.getNearestSymbolTable(instOp)))
751 .getPorts();
752 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
753 SmallVector<Value> replacementWires;
754 for (firrtl::PortInfo info : portInfo) {
755 auto wire = firrtl::WireOp::create(
756 builder, info.type,
757 (Twine(instOp.getName()) + "_" + info.getName()).str())
758 .getResult();
759 if (info.isOutput()) {
760 auto inv = firrtl::InvalidValueOp::create(builder, info.type);
761 firrtl::ConnectOp::create(builder, wire, inv);
762 }
763 replacementWires.push_back(wire);
764 }
765 nlaRemover.markNLAsInOperation(instOp);
766 instOp.replaceAllUsesWith(std::move(replacementWires));
767 instOp->erase();
768 return success();
769 }
770 std::string getName() const override { return "extmodule-instance-remover"; }
771 bool acceptSizeIncrease() const override { return true; }
772
773 ::detail::SymbolCache symbols;
774 NLARemover nlaRemover;
775};
776
777/// A sample reduction pattern that pushes connected values through wires.
778struct ConnectForwarder : public Reduction {
779 uint64_t match(Operation *op) override {
780 if (!isa<firrtl::FConnectLike>(op))
781 return 0;
782 auto dest = op->getOperand(0);
783 auto src = op->getOperand(1);
784 auto *destOp = dest.getDefiningOp();
785 auto *srcOp = src.getDefiningOp();
786 if (dest == src)
787 return 0;
788
789 // Ensure that the destination is something we should be able to forward
790 // through.
791 if (!isa_and_nonnull<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(
792 destOp))
793 return 0;
794
795 // Ensure that the destination is connected to only once, and all uses of
796 // the connection occur after the definition of the source.
797 unsigned numConnects = 0;
798 for (auto &use : dest.getUses()) {
799 auto *op = use.getOwner();
800 if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
801 if (++numConnects > 1)
802 return 0;
803 continue;
804 }
805 if (srcOp && !srcOp->isBeforeInBlock(op))
806 return 0;
807 }
808
809 return 1;
810 }
811
812 LogicalResult rewrite(Operation *op) override {
813 auto dst = op->getOperand(0);
814 auto src = op->getOperand(1);
815 dst.replaceAllUsesWith(src);
816 op->erase();
817 if (auto *dstOp = dst.getDefiningOp())
818 reduce::pruneUnusedOps(dstOp, *this);
819 if (auto *srcOp = src.getDefiningOp())
820 reduce::pruneUnusedOps(srcOp, *this);
821 return success();
822 }
823
824 std::string getName() const override { return "connect-forwarder"; }
825};
826
827/// A sample reduction pattern that replaces a single-use wire and register with
828/// an operand of the source value of the connection.
829template <unsigned OpNum>
830struct ConnectSourceOperandForwarder : public Reduction {
831 uint64_t match(Operation *op) override {
832 if (!isa<firrtl::ConnectOp, firrtl::MatchingConnectOp>(op))
833 return 0;
834 auto dest = op->getOperand(0);
835 auto *destOp = dest.getDefiningOp();
836
837 // Ensure that the destination is used only once.
838 if (!destOp || !destOp->hasOneUse() ||
839 !isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
840 return 0;
841
842 auto *srcOp = op->getOperand(1).getDefiningOp();
843 if (!srcOp || OpNum >= srcOp->getNumOperands())
844 return 0;
845
846 auto resultTy = dyn_cast<firrtl::FIRRTLBaseType>(dest.getType());
847 auto opTy =
848 dyn_cast<firrtl::FIRRTLBaseType>(srcOp->getOperand(OpNum).getType());
849
850 return resultTy && opTy &&
851 resultTy.getWidthlessType() == opTy.getWidthlessType() &&
852 ((resultTy.getBitWidthOrSentinel() == -1) ==
853 (opTy.getBitWidthOrSentinel() == -1)) &&
854 isa<firrtl::UIntType, firrtl::SIntType>(resultTy);
855 }
856
857 LogicalResult rewrite(Operation *op) override {
858 auto *destOp = op->getOperand(0).getDefiningOp();
859 auto *srcOp = op->getOperand(1).getDefiningOp();
860 auto forwardedOperand = srcOp->getOperand(OpNum);
861 ImplicitLocOpBuilder builder(destOp->getLoc(), destOp);
862 Value newDest;
863 if (auto wire = dyn_cast<firrtl::WireOp>(destOp))
864 newDest = firrtl::WireOp::create(builder, forwardedOperand.getType(),
865 wire.getName())
866 .getResult();
867 else {
868 auto regName = destOp->getAttrOfType<StringAttr>("name");
869 // We can promote the register into a wire but we wouldn't do here because
870 // the error might be caused by the register.
871 auto clock = destOp->getOperand(0);
872 newDest = firrtl::RegOp::create(builder, forwardedOperand.getType(),
873 clock, regName ? regName.str() : "")
874 .getResult();
875 }
876
877 // Create new connection between a new wire and the forwarded operand.
878 builder.setInsertionPointAfter(op);
879 if (isa<firrtl::ConnectOp>(op))
880 firrtl::ConnectOp::create(builder, newDest, forwardedOperand);
881 else
882 firrtl::MatchingConnectOp::create(builder, newDest, forwardedOperand);
883
884 // Remove the old connection and destination. We don't have to replace them
885 // because destination has only one use.
886 op->erase();
887 destOp->erase();
888 reduce::pruneUnusedOps(srcOp, *this);
889
890 return success();
891 }
892 std::string getName() const override {
893 return ("connect-source-operand-" + Twine(OpNum) + "-forwarder").str();
894 }
895};
896
897/// A sample reduction pattern that tries to remove aggregate wires by replacing
898/// all subaccesses with new independent wires. This can disentangle large
899/// unused wires that are otherwise difficult to collect due to the subaccesses.
900struct DetachSubaccesses : public Reduction {
901 void beforeReduction(mlir::ModuleOp op) override { opsToErase.clear(); }
902 void afterReduction(mlir::ModuleOp op) override {
903 for (auto *op : opsToErase)
904 op->dropAllReferences();
905 for (auto *op : opsToErase)
906 op->erase();
907 }
908 uint64_t match(Operation *op) override {
909 // Only applies to wires and registers that are purely used in subaccess
910 // operations.
911 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
912 llvm::all_of(op->getUses(), [](auto &use) {
913 return use.getOperandNumber() == 0 &&
914 isa<firrtl::SubfieldOp, firrtl::SubindexOp,
915 firrtl::SubaccessOp>(use.getOwner());
916 });
917 }
918 LogicalResult rewrite(Operation *op) override {
919 assert(match(op));
920 OpBuilder builder(op);
921 bool isWire = isa<firrtl::WireOp>(op);
922 Value invalidClock;
923 if (!isWire)
924 invalidClock = firrtl::InvalidValueOp::create(
925 builder, op->getLoc(), firrtl::ClockType::get(op->getContext()));
926 for (Operation *user : llvm::make_early_inc_range(op->getUsers())) {
927 builder.setInsertionPoint(user);
928 auto type = user->getResult(0).getType();
929 Operation *replOp;
930 if (isWire)
931 replOp = firrtl::WireOp::create(builder, user->getLoc(), type);
932 else
933 replOp =
934 firrtl::RegOp::create(builder, user->getLoc(), type, invalidClock);
935 user->replaceAllUsesWith(replOp);
936 opsToErase.insert(user);
937 }
938 opsToErase.insert(op);
939 return success();
940 }
941 std::string getName() const override { return "detach-subaccesses"; }
942 llvm::DenseSet<Operation *> opsToErase;
943};
944
945/// This reduction removes inner symbols on ops. Name preservation creates a lot
946/// of node ops with symbols to keep name information but it also prevents
947/// normal canonicalizations.
948struct NodeSymbolRemover : public Reduction {
949 void beforeReduction(mlir::ModuleOp op) override {
950 innerSymUses = reduce::InnerSymbolUses(op);
951 }
952
953 uint64_t match(Operation *op) override {
954 // Only match ops with an inner symbol.
955 auto sym = op->getAttrOfType<hw::InnerSymAttr>("inner_sym");
956 if (!sym || sym.empty())
957 return 0;
958
959 // Only match ops that have no references to their inner symbol.
960 if (innerSymUses.hasInnerRef(op))
961 return 0;
962 return 1;
963 }
964
965 LogicalResult rewrite(Operation *op) override {
966 op->removeAttr("inner_sym");
967 return success();
968 }
969
970 std::string getName() const override { return "node-symbol-remover"; }
971 bool acceptSizeIncrease() const override { return true; }
972
973 reduce::InnerSymbolUses innerSymUses;
974};
975
976/// A sample reduction pattern that eagerly inlines instances.
977struct EagerInliner : public OpReduction<InstanceOp> {
978 void beforeReduction(mlir::ModuleOp op) override {
979 symbols.clear();
980 nlaRemover.clear();
981 nlaTable = std::make_unique<NLATable>(op);
982 innerSymTables = std::make_unique<hw::InnerSymbolTableCollection>();
983 }
984 void afterReduction(mlir::ModuleOp op) override {
985 nlaRemover.remove(op);
986 nlaTable.reset();
987 innerSymTables.reset();
988 }
989
990 uint64_t match(InstanceOp instOp) override {
991 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
992 auto *moduleOp =
993 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
994
995 // Only inline FModuleOp instances
996 if (!isa<FModuleOp>(moduleOp))
997 return 0;
998
999 // Skip instances that participate in any NLAs
1000 DenseSet<hw::HierPathOp> nlas;
1001 nlaTable->getInstanceNLAs(instOp, nlas);
1002 if (!nlas.empty())
1003 return 0;
1004
1005 // Check for inner symbol collisions between the referenced module and the
1006 // instance's parent module
1007 auto referencedModule = cast<FModuleOp>(moduleOp);
1008 auto parentModule = instOp->getParentOfType<FModuleOp>();
1009 if (hasInnerSymbolCollisions(referencedModule, parentModule))
1010 return 0;
1011
1012 return 1;
1013 }
1014
1015 LogicalResult rewrite(InstanceOp instOp) override {
1016 auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
1017 auto moduleOp = cast<FModuleOp>(
1018 instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
1019 bool isLastUse =
1020 (symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1);
1021 auto clonedModuleOp = isLastUse ? moduleOp : moduleOp.clone();
1022
1023 // Create wires to replace the instance results.
1024 IRRewriter rewriter(instOp);
1025 SmallVector<Value> argWires;
1026 for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) {
1027 auto result = instOp.getResult(i);
1028 auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" +
1029 instOp.getPortNameStr(i));
1030 auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(),
1031 name, NameKindEnum::DroppableName,
1032 instOp.getPortAnnotation(i), StringAttr{})
1033 .getResult();
1034 result.replaceAllUsesWith(wire);
1035 argWires.push_back(wire);
1036 }
1037
1038 // Splice in the cloned module body.
1039 rewriter.inlineBlockBefore(clonedModuleOp.getBodyBlock(), instOp, argWires);
1040
1041 // Make sure we remove any NLAs that go through this instance, and the
1042 // module if we're about the delete the module.
1043 nlaRemover.markNLAsInOperation(instOp);
1044 if (isLastUse)
1045 nlaRemover.markNLAsInOperation(moduleOp);
1046
1047 instOp.erase();
1048 clonedModuleOp.erase();
1049 return success();
1050 }
1051
1052 /// Check if inlining the referenced module into the parent module would cause
1053 /// inner symbol collisions.
1054 bool hasInnerSymbolCollisions(FModuleOp referencedModule,
1055 FModuleOp parentModule) {
1056 // Get the inner symbol tables for both modules
1057 auto &targetTable = innerSymTables->getInnerSymbolTable(referencedModule);
1058 auto &parentTable = innerSymTables->getInnerSymbolTable(parentModule);
1059
1060 // Check if any inner symbol name in the target module already exists
1061 // in the parent module. Return failure() if a collision is found to stop
1062 // the walk early.
1063 LogicalResult walkResult = targetTable.walkSymbols(
1064 [&](StringAttr name,
1065 const hw::InnerSymTarget &target) -> LogicalResult {
1066 // Check if this symbol name exists in the parent module
1067 if (parentTable.lookup(name)) {
1068 // Collision found, return failure to stop the walk
1069 return failure();
1070 }
1071 return success();
1072 });
1073
1074 // If the walk failed, it means we found a collision
1075 return failed(walkResult);
1076 }
1077
1078 std::string getName() const override { return "firrtl-eager-inliner"; }
1079 bool acceptSizeIncrease() const override { return true; }
1080
1081 ::detail::SymbolCache symbols;
1082 NLARemover nlaRemover;
1083 std::unique_ptr<NLATable> nlaTable;
1084 std::unique_ptr<hw::InnerSymbolTableCollection> innerSymTables;
1085};
1086
1087/// Psuedo-reduction that sanitizes the names of things inside modules. This is
1088/// not an actual reduction, but often removes extraneous information that has
1089/// no bearing on the actual reduction (and would likely be removed before
1090/// sharing the reduction). This makes the following changes:
1091///
1092/// - All wires are renamed to "wire"
1093/// - All registers are renamed to "reg"
1094/// - All nodes are renamed to "node"
1095/// - All memories are renamed to "mem"
1096/// - All verification messages and labels are dropped
1097///
1098struct ModuleInternalNameSanitizer : public Reduction {
1099 uint64_t match(Operation *op) override {
1100 // Only match operations with names.
1101 return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp,
1102 firrtl::NodeOp, firrtl::MemOp, chirrtl::CombMemOp,
1103 chirrtl::SeqMemOp, firrtl::AssertOp, firrtl::AssumeOp,
1104 firrtl::CoverOp>(op);
1105 }
1106 LogicalResult rewrite(Operation *op) override {
1107 TypeSwitch<Operation *, void>(op)
1108 .Case<firrtl::WireOp>([](auto op) { op.setName("wire"); })
1109 .Case<firrtl::RegOp, firrtl::RegResetOp>(
1110 [](auto op) { op.setName("reg"); })
1111 .Case<firrtl::NodeOp>([](auto op) { op.setName("node"); })
1112 .Case<firrtl::MemOp, chirrtl::CombMemOp, chirrtl::SeqMemOp>(
1113 [](auto op) { op.setName("mem"); })
1114 .Case<firrtl::AssertOp, firrtl::AssumeOp, firrtl::CoverOp>([](auto op) {
1115 op->setAttr("message", StringAttr::get(op.getContext(), ""));
1116 op->setAttr("name", StringAttr::get(op.getContext(), ""));
1117 });
1118 return success();
1119 }
1120
1121 std::string getName() const override {
1122 return "module-internal-name-sanitizer";
1123 }
1124
1125 bool acceptSizeIncrease() const override { return true; }
1126
1127 bool isOneShot() const override { return true; }
1128};
1129
1130/// Psuedo-reduction that sanitizes module, instance, and port names. This
1131/// makes the following changes:
1132///
1133/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
1134/// - All instances are renamed to match the new module name
1135/// - All module ports are renamed in the following way:
1136/// - All clocks are reanemd to "clk"
1137/// - All resets are renamed to "rst"
1138/// - All references are renamed to "ref"
1139/// - Anything else is renamed to "port"
1140///
1141struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
1142
1143 const char *names[48] = {
1144 "Foo", "Bar", "Baz", "Qux", "Quux", "Quuux", "Quuuux",
1145 "Quz", "Corge", "Grault", "Bazola", "Ztesch", "Thud", "Grunt",
1146 "Bletch", "Fum", "Fred", "Jim", "Sheila", "Barney", "Flarp",
1147 "Zxc", "Spqr", "Wombat", "Shme", "Bongo", "Spam", "Eggs",
1148 "Snork", "Zot", "Blarg", "Wibble", "Toto", "Titi", "Tata",
1149 "Tutu", "Pippo", "Pluto", "Paperino", "Aap", "Noot", "Mies",
1150 "Oogle", "Foogle", "Boogle", "Zork", "Gork", "Bork"};
1151
1152 size_t nameIndex = 0;
1153
1154 const char *getName() {
1155 if (nameIndex >= 48)
1156 nameIndex = 0;
1157 return names[nameIndex++];
1158 };
1159
1160 size_t portNameIndex = 0;
1161
1162 char getPortName() {
1163 if (portNameIndex >= 26)
1164 portNameIndex = 0;
1165 return 'a' + portNameIndex++;
1166 }
1167
1168 void beforeReduction(mlir::ModuleOp op) override { nameIndex = 0; }
1169
1170 LogicalResult rewrite(firrtl::CircuitOp circuitOp) override {
1171
1172 firrtl::InstanceGraph iGraph(circuitOp);
1173
1174 auto *circuitName = getName();
1175 iGraph.getTopLevelModule().setName(circuitName);
1176 circuitOp.setName(circuitName);
1177
1178 for (auto *node : iGraph) {
1179 auto module = node->getModule<firrtl::FModuleLike>();
1180
1181 bool shouldReplacePorts = false;
1182 SmallVector<Attribute> newNames;
1183 if (auto fmodule = dyn_cast<firrtl::FModuleOp>(*module)) {
1184 portNameIndex = 0;
1185 // TODO: The namespace should be unnecessary. However, some FIRRTL
1186 // passes expect that port names are unique.
1188 auto oldPorts = fmodule.getPorts();
1189 shouldReplacePorts = !oldPorts.empty();
1190 for (unsigned i = 0, e = fmodule.getNumPorts(); i != e; ++i) {
1191 auto port = oldPorts[i];
1192 auto newName = firrtl::FIRRTLTypeSwitch<Type, StringRef>(port.type)
1193 .Case<firrtl::ClockType>(
1194 [&](auto a) { return ns.newName("clk"); })
1195 .Case<firrtl::ResetType, firrtl::AsyncResetType>(
1196 [&](auto a) { return ns.newName("rst"); })
1197 .Case<firrtl::RefType>(
1198 [&](auto a) { return ns.newName("ref"); })
1199 .Default([&](auto a) {
1200 return ns.newName(Twine(getPortName()));
1201 });
1202 newNames.push_back(StringAttr::get(circuitOp.getContext(), newName));
1203 }
1204 fmodule->setAttr("portNames",
1205 ArrayAttr::get(fmodule.getContext(), newNames));
1206 }
1207
1208 if (module == iGraph.getTopLevelModule())
1209 continue;
1210 auto newName = StringAttr::get(circuitOp.getContext(), getName());
1211 module.setName(newName);
1212 for (auto *use : node->uses()) {
1213 auto instanceOp = dyn_cast<firrtl::InstanceOp>(*use->getInstance());
1214 instanceOp.setModuleName(newName);
1215 instanceOp.setName(newName);
1216 if (shouldReplacePorts)
1217 instanceOp.setPortNamesAttr(
1218 ArrayAttr::get(circuitOp.getContext(), newNames));
1219 }
1220 }
1221
1222 circuitOp->dump();
1223
1224 return success();
1225 }
1226
1227 std::string getName() const override { return "module-name-sanitizer"; }
1228
1229 bool acceptSizeIncrease() const override { return true; }
1230
1231 bool isOneShot() const override { return true; }
1232};
1233
1234/// A reduction pattern that groups modules by their port signature (types and
1235/// directions) and replaces instances with the smallest module in each group.
1236/// This helps reduce the IR by consolidating functionally equivalent modules
1237/// based on their interface.
1238///
1239/// The pattern works by:
1240/// 1. Grouping all modules by their port signature (port types and directions)
1241/// 2. For each group with multiple modules, finding the smallest module using
1242/// the module size cache
1243/// 3. Replacing all instances of larger modules with instances of the smallest
1244/// module in the same group
1245/// 4. Removing the larger modules from the circuit
1246///
1247/// This reduction is useful for reducing circuits where multiple modules have
1248/// the same interface but different implementations, allowing the reducer to
1249/// try the smallest implementation first.
1250struct ModuleSwapper : public OpReduction<InstanceOp> {
1251 // Per-circuit state containing all the information needed for module swapping
1252 using PortSignature = SmallVector<std::pair<Type, Direction>>;
1253 struct CircuitState {
1254 DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1255 DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1256 std::unique_ptr<NLATable> nlaTable;
1257 };
1258
1259 void beforeReduction(mlir::ModuleOp op) override {
1260 symbols.clear();
1261 nlaRemover.clear();
1262 moduleSizes.clear();
1263 circuitStates.clear();
1264
1265 // Collect module type groups and NLA tables for all circuits up front
1266 op.walk<WalkOrder::PreOrder>([&](CircuitOp circuitOp) {
1267 auto &state = circuitStates[circuitOp];
1268 state.nlaTable = std::make_unique<NLATable>(circuitOp);
1269 buildModuleTypeGroups(circuitOp, state);
1270 return WalkResult::skip();
1271 });
1272 }
1273 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1274
1275 /// Create a vector of port type-direction pairs for the given FIRRTL module.
1276 /// This ignores port names, allowing modules with the same port types and
1277 /// directions but different port names to be considered equivalent for
1278 /// swapping.
1279 PortSignature getModulePortSignature(FModuleLike module) {
1280 PortSignature signature;
1281 signature.reserve(module.getNumPorts());
1282 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i)
1283 signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1284 return signature;
1285 }
1286
1287 /// Group modules by their port signature and find the smallest in each group.
1288 void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1289 // Group modules by their port signature
1290 for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1291 auto signature = getModulePortSignature(module);
1292 state.moduleTypeGroups[signature].push_back(module);
1293 }
1294
1295 // For each group, find the smallest module
1296 for (auto &[signature, modules] : state.moduleTypeGroups) {
1297 if (modules.size() <= 1)
1298 continue;
1299
1300 FModuleLike smallestModule = nullptr;
1301 uint64_t smallestSize = std::numeric_limits<uint64_t>::max();
1302
1303 for (auto module : modules) {
1304 uint64_t size = moduleSizes.getModuleSize(module, symbols);
1305 if (size < smallestSize) {
1306 smallestSize = size;
1307 smallestModule = module;
1308 }
1309 }
1310
1311 // Map all modules in this group to the smallest one
1312 for (auto module : modules) {
1313 if (module != smallestModule) {
1314 state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1315 smallestModule;
1316 }
1317 }
1318 }
1319 }
1320
1321 uint64_t match(InstanceOp instOp) override {
1322 // Get the circuit this instance belongs to
1323 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1324 assert(circuitOp);
1325 const auto &state = circuitStates.at(circuitOp);
1326
1327 // Skip instances that participate in any NLAs
1328 DenseSet<hw::HierPathOp> nlas;
1329 state.nlaTable->getInstanceNLAs(instOp, nlas);
1330 if (!nlas.empty())
1331 return 0;
1332
1333 // Check if this instance can be redirected to a smaller module
1334 auto moduleName = instOp.getModuleNameAttr().getAttr();
1335 auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1336 if (!canonicalModule)
1337 return 0;
1338
1339 // Benefit is the size difference
1340 auto currentModule = cast<FModuleLike>(
1341 instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1342 uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1343 uint64_t canonicalSize =
1344 moduleSizes.getModuleSize(canonicalModule, symbols);
1345 return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1346 }
1347
1348 LogicalResult rewrite(InstanceOp instOp) override {
1349 // Get the circuit this instance belongs to
1350 auto circuitOp = instOp->getParentOfType<CircuitOp>();
1351 assert(circuitOp);
1352 const auto &state = circuitStates.at(circuitOp);
1353
1354 // Replace the instantiated module with the canonical module.
1355 auto canonicalModule = state.instanceToCanonicalModule.at(
1356 instOp.getModuleNameAttr().getAttr());
1357 auto canonicalName = canonicalModule.getModuleNameAttr();
1358 instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1359
1360 // Update port names to match the canonical module
1361 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1362
1363 return success();
1364 }
1365
1366 std::string getName() const override { return "firrtl-module-swapper"; }
1367 bool acceptSizeIncrease() const override { return true; }
1368
1369private:
1370 ::detail::SymbolCache symbols;
1371 NLARemover nlaRemover;
1372 ModuleSizeCache moduleSizes;
1373
1374 // Per-circuit state containing all module swapping information
1375 DenseMap<CircuitOp, CircuitState> circuitStates;
1376};
1377
1378/// A reduction pattern that handles MustDedup annotations by replacing all
1379/// module names in a dedup group with a single module name. This helps reduce
1380/// the IR by consolidating module references that are required to be identical.
1381///
1382/// The pattern works by:
1383/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1384/// 2. For each dedup group, using the first module as the canonical name
1385/// 3. Replacing all instance references to other modules in the group with
1386/// references to the canonical module
1387/// 4. Removing the non-canonical modules from the circuit
1388/// 5. Removing the processed MustDedup annotation
1389///
1390/// This reduction is particularly useful for reducing large circuits where
1391/// multiple modules are known to be identical but haven't been deduplicated
1392/// yet.
1393struct ForceDedup : public OpReduction<CircuitOp> {
1394 void beforeReduction(mlir::ModuleOp op) override {
1395 symbols.clear();
1396 nlaRemover.clear();
1397 }
1398 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1399
1400 /// Collect all MustDedup annotations and create matches for each dedup group.
1401 void matches(CircuitOp circuitOp,
1402 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1403 auto annotations = AnnotationSet(circuitOp);
1404 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1405 if (!anno.isClass(mustDedupAnnoClass))
1406 continue;
1407
1408 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1409 if (!modulesAttr)
1410 continue;
1411
1412 // Each dedup group gets its own match with benefit proportional to group
1413 // size.
1414 uint64_t benefit = modulesAttr.size();
1415 addMatch(benefit, annoIdx);
1416 }
1417 }
1418
1419 LogicalResult rewriteMatches(CircuitOp circuitOp,
1420 ArrayRef<uint64_t> matches) override {
1421 auto *context = circuitOp->getContext();
1422 NLATable nlaTable(circuitOp);
1423 hw::InnerSymbolTableCollection innerSymTables;
1424 auto annotations = AnnotationSet(circuitOp);
1425 SmallVector<Annotation> newAnnotations;
1426
1427 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1428 // Check if this annotation was selected.
1429 if (!llvm::is_contained(matches, annoIdx)) {
1430 newAnnotations.push_back(anno);
1431 continue;
1432 }
1433 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1434 assert(anno.isClass(mustDedupAnnoClass) && modulesAttr &&
1435 modulesAttr.size() >= 2);
1436
1437 // Extract module names from the dedup group.
1438 SmallVector<StringAttr> moduleNames;
1439 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1440 // Parse "~CircuitName|ModuleName" format.
1441 auto refStr = moduleRef.getValue();
1442 auto pipePos = refStr.find('|');
1443 if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1444 auto moduleName = refStr.substr(pipePos + 1);
1445 moduleNames.push_back(StringAttr::get(context, moduleName));
1446 }
1447 }
1448
1449 // Simply drop the annotation if there's only one module.
1450 if (moduleNames.size() < 2)
1451 continue;
1452
1453 // Replace all instances and references to other modules with the
1454 // first module.
1455 replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1456 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1457 }
1458 if (newAnnotations.size() == annotations.size())
1459 return failure();
1460
1461 // Update circuit annotations.
1462 AnnotationSet newAnnoSet(newAnnotations, context);
1463 newAnnoSet.applyToOperation(circuitOp);
1464 return success();
1465 }
1466
1467 std::string getName() const override { return "firrtl-force-dedup"; }
1468 bool acceptSizeIncrease() const override { return true; }
1469
1470private:
1471 /// Replace all references to modules in the dedup group with the canonical
1472 /// module name
1473 void replaceModuleReferences(CircuitOp circuitOp,
1474 ArrayRef<StringAttr> moduleNames,
1475 NLATable &nlaTable,
1476 hw::InnerSymbolTableCollection &innerSymTables) {
1477 auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1478 auto &symbolTable = symbols.getSymbolTable(tableOp);
1479 auto *context = circuitOp->getContext();
1480 auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
1481
1482 // Collect the modules.
1483 FModuleLike canonicalModule;
1484 SmallVector<FModuleLike> modulesToReplace;
1485 for (auto name : moduleNames) {
1486 if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
1487 if (!canonicalModule)
1488 canonicalModule = mod;
1489 else
1490 modulesToReplace.push_back(mod);
1491 }
1492 }
1493 if (modulesToReplace.empty())
1494 return;
1495
1496 // Replace all instance references.
1497 auto canonicalName = canonicalModule.getModuleNameAttr();
1498 auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1499 circuitOp.walk([&](InstanceOp instOp) {
1500 auto moduleName = instOp.getModuleNameAttr().getAttr();
1501 if (llvm::is_contained(moduleNames, moduleName) &&
1502 moduleName != canonicalName) {
1503 instOp.setModuleNameAttr(canonicalRef);
1504 instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1505 }
1506 });
1507
1508 // Update NLAs to reference the canonical module instead of modules being
1509 // removed using NLATable for better performance.
1510 for (auto oldMod : modulesToReplace) {
1511 SmallVector<hw::HierPathOp> nlaOps(
1512 nlaTable.lookup(oldMod.getModuleNameAttr()));
1513 for (auto nlaOp : nlaOps) {
1514 nlaTable.erase(nlaOp);
1515 StringAttr oldModName = oldMod.getModuleNameAttr();
1516 StringAttr newModName = canonicalName;
1517 SmallVector<Attribute, 4> newPath;
1518 for (auto nameRef : nlaOp.getNamepath()) {
1519 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1520 if (ref.getModule() == oldModName) {
1521 auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1522 ref = hw::InnerRefAttr::get(newModName, ref.getName());
1523 auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1524 if (oldInst && newInst) {
1525 oldModName = oldInst.getReferencedModuleNameAttr();
1526 newModName = newInst.getReferencedModuleNameAttr();
1527 }
1528 }
1529 newPath.push_back(ref);
1530 } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1531 newPath.push_back(FlatSymbolRefAttr::get(newModName));
1532 } else {
1533 newPath.push_back(nameRef);
1534 }
1535 }
1536 nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1537 nlaTable.addNLA(nlaOp);
1538 }
1539 }
1540
1541 // Mark NLAs in modules to be removed.
1542 for (auto module : modulesToReplace) {
1543 nlaRemover.markNLAsInOperation(module);
1544 module->erase();
1545 }
1546 }
1547
1548 ::detail::SymbolCache symbols;
1549 NLARemover nlaRemover;
1550};
1551
1552/// A reduction pattern that moves `MustDedup` annotations from a module onto
1553/// its child modules. This pattern iterates over all MustDedup annotations,
1554/// collects all `FInstanceLike` ops in each module of the dedup group, and
1555/// creates new MustDedup annotations for corresponding instances across the
1556/// modules. Each set of corresponding instances becomes a separate match of the
1557/// reduction. The reduction also removes the original MustDedup annotation on
1558/// the parent module.
1559///
1560/// The pattern works by:
1561/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1562/// 2. For each dedup group, collecting all FInstanceLike operations in each
1563/// module
1564/// 3. Grouping corresponding instances across modules by their position/name
1565/// 4. Creating new MustDedup annotations for each group of corresponding
1566/// instances
1567/// 5. Removing the original MustDedup annotation from the circuit
1568struct MustDedupChildren : public OpReduction<CircuitOp> {
1569 void beforeReduction(mlir::ModuleOp op) override {
1570 symbols.clear();
1571 nlaRemover.clear();
1572 }
1573 void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1574
1575 /// Collect all MustDedup annotations and create matches for each instance
1576 /// group.
1577 void matches(CircuitOp circuitOp,
1578 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1579 auto annotations = AnnotationSet(circuitOp);
1580 uint64_t matchId = 0;
1581
1582 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1583 if (!anno.isClass(mustDedupAnnoClass))
1584 continue;
1585
1586 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1587 if (!modulesAttr || modulesAttr.size() < 2)
1588 continue;
1589
1590 // Process each group of corresponding instances
1591 processInstanceGroups(
1592 circuitOp, modulesAttr,
1593 [&](ArrayRef<FInstanceLike>) { addMatch(1, matchId++); });
1594 }
1595 }
1596
1597 LogicalResult rewriteMatches(CircuitOp circuitOp,
1598 ArrayRef<uint64_t> matches) override {
1599 auto *context = circuitOp->getContext();
1600 auto annotations = AnnotationSet(circuitOp);
1601 SmallVector<Annotation> newAnnotations;
1602 uint64_t matchId = 0;
1603
1604 for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1605 if (!anno.isClass(mustDedupAnnoClass)) {
1606 newAnnotations.push_back(anno);
1607 continue;
1608 }
1609
1610 auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1611 if (!modulesAttr || modulesAttr.size() < 2) {
1612 newAnnotations.push_back(anno);
1613 continue;
1614 }
1615
1616 // Track whether any matches were selected for this annotation
1617 bool anyMatchSelected = false;
1618 processInstanceGroups(
1619 circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1620 // Check if this instance group was selected
1621 if (!llvm::is_contained(matches, matchId++))
1622 return;
1623 anyMatchSelected = true;
1624
1625 // Create the list of modules to put into this new annotation.
1626 SmallSetVector<StringAttr, 4> moduleTargets;
1627 for (auto instOp : instanceGroup) {
1628 auto target = TokenAnnoTarget();
1629 target.circuit = circuitOp.getName();
1630 target.module = instOp.getReferencedModuleName();
1631 moduleTargets.insert(target.toStringAttr(context));
1632 }
1633 if (moduleTargets.size() < 2)
1634 return;
1635
1636 // Create a new MustDedup annotation for this list of modules.
1637 SmallVector<NamedAttribute> newAnnoAttrs;
1638 newAnnoAttrs.emplace_back(
1639 StringAttr::get(context, "class"),
1640 StringAttr::get(context, mustDedupAnnoClass));
1641 newAnnoAttrs.emplace_back(
1642 StringAttr::get(context, "modules"),
1643 ArrayAttr::get(context,
1644 SmallVector<Attribute>(moduleTargets.begin(),
1645 moduleTargets.end())));
1646
1647 auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
1648 newAnnotations.emplace_back(newAnnoDict);
1649 });
1650
1651 // If any matches were selected, mark the original annotation for removal
1652 // since we're replacing it with new MustDedup annotations on the child
1653 // modules. Otherwise keep the original annotation around.
1654 if (anyMatchSelected)
1655 nlaRemover.markNLAsInAnnotation(anno.getAttr());
1656 else
1657 newAnnotations.push_back(anno);
1658 }
1659
1660 // Update circuit annotations
1661 AnnotationSet newAnnoSet(newAnnotations, context);
1662 newAnnoSet.applyToOperation(circuitOp);
1663 return success();
1664 }
1665
1666 std::string getName() const override { return "must-dedup-children"; }
1667 bool acceptSizeIncrease() const override { return true; }
1668
1669private:
1670 /// Helper function to process groups of corresponding instances from a
1671 /// MustDedup annotation. Calls the provided lambda for each group of
1672 /// corresponding instances across the modules. Only calls the lambda if there
1673 /// are at least 2 modules.
1674 void processInstanceGroups(
1675 CircuitOp circuitOp, ArrayAttr modulesAttr,
1676 llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
1677 auto &symbolTable = symbols.getSymbolTable(circuitOp);
1678
1679 // Extract module names and get the actual modules
1680 SmallVector<FModuleLike> modules;
1681 for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1682 if (auto target = tokenizePath(moduleRef))
1683 if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
1684 modules.push_back(mod);
1685
1686 // Need at least 2 modules for deduplication
1687 if (modules.size() < 2)
1688 return;
1689
1690 // Collect all FInstanceLike operations from each module and group them by
1691 // name. Instance names are a good key for matching instances across
1692 // modules. But they may not be unique, so we need to be careful to only
1693 // match up instances that are uniquely named within every module.
1694 struct InstanceGroup {
1695 SmallVector<FInstanceLike> instances;
1696 bool nameIsUnique = true;
1697 };
1698 MapVector<StringAttr, InstanceGroup> instanceGroups;
1699 for (auto module : modules) {
1701 module.walk([&](FInstanceLike instOp) {
1702 auto name = instOp.getInstanceNameAttr();
1703 auto &group = instanceGroups[name];
1704 if (nameCounts[name]++ > 1)
1705 group.nameIsUnique = false;
1706 group.instances.push_back(instOp);
1707 });
1708 }
1709
1710 // Call the callback for each group of instances that are uniquely named and
1711 // consist of at least 2 instances.
1712 for (auto &[name, group] : instanceGroups)
1713 if (group.nameIsUnique && group.instances.size() >= 2)
1714 callback(group.instances);
1715 }
1716
1717 ::detail::SymbolCache symbols;
1718 NLARemover nlaRemover;
1719};
1720
1721} // namespace
1722
1723//===----------------------------------------------------------------------===//
1724// Reduction Registration
1725//===----------------------------------------------------------------------===//
1726
1729 // Gather a list of reduction patterns that we should try. Ideally these are
1730 // assigned reasonable benefit indicators (higher benefit patterns are
1731 // prioritized). For example, things that can knock out entire modules while
1732 // being cheap should be tried first (and thus have higher benefit), before
1733 // trying to tweak operands of individual arithmetic ops.
1734 patterns.add<AnnotationRemover, 33>();
1735 patterns.add<ModuleSwapper, 32>();
1736 patterns.add<ForceDedup, 31>();
1737 patterns.add<MustDedupChildren, 30>();
1738 patterns.add<PassReduction, 29>(
1739 getContext(),
1740 firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
1741 true);
1742 patterns.add<PassReduction, 28>(getContext(),
1743 firrtl::createLowerCHIRRTLPass(), true, true);
1744 patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
1745 true, true);
1746 patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
1747 true, true);
1748 patterns.add<FIRRTLModuleExternalizer, 25>();
1749 patterns.add<InstanceStubber, 24>();
1750 patterns.add<MemoryStubber, 23>();
1751 patterns.add<EagerInliner, 22>();
1752 patterns.add<PassReduction, 21>(getContext(),
1753 firrtl::createLowerFIRRTLTypes(), true, true);
1754 patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
1755 true, true);
1756 patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
1757 patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
1758 patterns.add<PassReduction, 17>(
1759 getContext(),
1760 firrtl::createRemoveUnusedPorts({/*ignoreDontTouch=*/true}));
1761 patterns.add<NodeSymbolRemover, 15>();
1762 patterns.add<ConnectForwarder, 14>();
1763 patterns.add<ConnectInvalidator, 13>();
1764 patterns.add<FIRRTLConstantifier, 12>();
1765 patterns.add<FIRRTLOperandForwarder<0>, 11>();
1766 patterns.add<FIRRTLOperandForwarder<1>, 10>();
1767 patterns.add<FIRRTLOperandForwarder<2>, 9>();
1768 patterns.add<DetachSubaccesses, 7>();
1769 patterns.add<RootPortPruner, 5>();
1770 patterns.add<ExtmoduleInstanceRemover, 4>();
1771 patterns.add<ConnectSourceOperandForwarder<0>, 3>();
1772 patterns.add<ConnectSourceOperandForwarder<1>, 2>();
1773 patterns.add<ConnectSourceOperandForwarder<2>, 1>();
1774 patterns.add<ModuleInternalNameSanitizer, 0>();
1775 patterns.add<ModuleNameSanitizer, 0>();
1776}
1777
1779 mlir::DialectRegistry &registry) {
1780 registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
1781 dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
1782 });
1783}
assert(baseType &&"element must be base type")
static bool onlyInvalidated(Value arg)
Check that all connections to a value are invalids.
static std::optional< firrtl::FModuleOp > findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols)
Utility to easily get the instantiated firrtl::FModuleOp or an empty optional in case another type of...
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
ArrayRef< hw::HierPathOp > lookup(Operation *op)
Lookup all NLAs an operation participates in.
Definition NLATable.cpp:41
void addNLA(hw::HierPathOp nla)
Insert a new NLA.
Definition NLATable.cpp:58
void erase(hw::HierPathOp nlaOp, SymbolTable *symbolTable=nullptr)
Remove the NLA from the analysis.
Definition NLATable.cpp:68
The target of an inner symbol, the entity the symbol is a handle for.
This class represents a collection of InnerSymbolTable's.
@ None
Don't explicitly preserve any named values.
Definition Passes.h:52
constexpr const char * mustDedupAnnoClass
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the FIRRTL Reduction pattern dialect interface to the given registry.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:42
std::optional< TokenAnnoTarget > tokenizePath(StringRef origTarget)
Parse a FIRRTL annotation path into its constituent parts.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition HWOps.cpp:36
void info(Twine message)
Definition LSPUtils.cpp:20
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
A tracker to track NLAs affected by a reduction.
void remove(mlir::ModuleOp module)
Remove all marked annotations.
void clear()
Clear the set of marked NLAs. Call this before attempting a reduction.
llvm::DenseSet< StringAttr > nlasToRemove
The set of NLAs to remove, identified by their symbol.
void markNLAsInAnnotation(Attribute anno)
Mark all NLAs referenced in the given annotation as to be removed.
void markNLAsInOperation(Operation *op)
Mark all NLAs referenced in an operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override
This holds the name and type that describes the module's ports.
The parsed annotation path.
This class represents the namespace in which InnerRef's can be resolved.
A helper struct that scans a root operation and all its nested operations for InnerRefAttrs.
A utility doing lazy construction of SymbolTables and SymbolUserMaps, which is handy for reductions that need to look up a lot of symbols.
std::unique_ptr< SymbolTableCollection > tables
SymbolUserMap & getSymbolUserMap(Operation *op)
SymbolUserMap & getNearestSymbolUserMap(Operation *op)
SymbolTable & getNearestSymbolTable(Operation *op)
SmallDenseMap< Operation *, SymbolUserMap, 2 > userMaps
SymbolTable & getSymbolTable(Operation *op)