15#include "mlir/IR/Threading.h"
16#include "llvm/ADT/STLExtras.h"
22#define GEN_PASS_DEF_SPECIALIZELAYERS
23#include "circt/Dialect/FIRRTL/Passes.h.inc"
29using namespace firrtl;
37 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
/// Removes annotations that refer, through their "circt.nonlocal" member, to a
/// hierarchical path whose symbol name is in `removedPaths`. Applied to a
/// module, its ports, and memory ports inside the module.
45struct AnnotationCleaner {
/// Stores a reference to `removedPaths`; the set must outlive this cleaner.
48 AnnotationCleaner(
const DenseSet<StringAttr> &removedPaths)
49 : removedPaths(removedPaths) {}
53 if (
auto nla = anno.
getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
54 return removedPaths.contains(nla.getAttr());
60 void cleanAnnotations(Operation *op) {
64 if (!oldAnnotations.empty()) {
65 auto newAnnotations = cleanAnnotations(oldAnnotations);
66 if (oldAnnotations != newAnnotations)
67 newAnnotations.applyToOperation(op);
71 void operator()(FModuleLike module) {
73 cleanAnnotations(module);
76 for (
size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
78 if (!oldAnnotations.empty()) {
79 auto newAnnotations = cleanAnnotations(oldAnnotations);
80 if (oldAnnotations != newAnnotations)
81 newAnnotations.applyToPort(module, i);
86 module->walk([&](Operation *op) {
90 if (
auto mem = dyn_cast<MemOp>(op)) {
92 for (
size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
94 if (!oldAnnotations.empty()) {
95 auto newAnnotations = cleanAnnotations(oldAnnotations);
96 if (oldAnnotations != newAnnotations)
97 newAnnotations.applyToPort(mem, i);
105 const DenseSet<StringAttr> &removedPaths;
/// A cursor into a block, used to re-insert operations as they are specialized.
/// Each operation moved here is placed before the current position, and that
/// operation then becomes the new position.
109struct InsertionPoint {
/// An insertion point positioned at the end of `block`.
111 static InsertionPoint atBlockEnd(Block *block) {
112 return InsertionPoint{block, block->end()};
/// Move `op` immediately before the current position and make `op` the new
/// position. Successive calls therefore lay operations out in the reverse of
/// the order in which they were moved.
117 void moveOpBefore(Operation *op) {
118 op->moveBefore(block, it);
119 it = Block::iterator(op);
/// Construct from an explicit block and a position within it.
123 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
/// Default: a disabled result carrying a null value.
135 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
/// Wrap `value` as an enabled result (`value` itself may be null).
138 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
/// True if this result represents a disabled specialization.
141 bool isDisabled()
const {
142 return value.getInt() == LayerSpecialization::Disable;
148 return value.getPointer();
/// Converts to true unless disabled. An enabled result may still hold a null
/// value, so check the value separately when that matters.
151 operator bool()
const {
return !isDisabled(); }
/// Store the specialization kind in the single spare bit of the value's
/// PointerIntPair.
154 Specialized(LayerSpecialization specialization, T value)
155 : value(value, specialization) {}
156 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
159struct SpecializeLayers {
162 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163 std::optional<LayerSpecialization> defaultSpecialization)
164 : context(circuit->getContext()), circuit(circuit),
165 specializations(specializations),
166 defaultSpecialization(defaultSpecialization) {}
/// Record every field of `innerSym` in `removedSyms` as an inner reference
/// formed from (`moduleName`, field name).
170 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171 StringAttr moduleName,
172 hw::InnerSymAttr innerSym) {
173 for (
auto field : innerSym)
174 removedSyms.insert(
hw::InnerRefAttr::
get(moduleName, field.
getName()));
/// Walk every operation nested under `block` that implements
/// InnerSymbolOpInterface. For each one that carries an inner symbol, record
/// that symbol in `removedSyms`, qualified by `moduleName`.
179 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180 StringAttr moduleName, Block *block) {
181 block->walk([&](hw::InnerSymbolOpInterface op) {
182 if (
auto innerSym = op.getInnerSymAttr())
183 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
/// Return the specialization explicitly requested for `layerRef`, falling back
/// to the pass-wide default. A result of std::nullopt means neither an explicit
/// nor a default specialization applies.
189 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190 auto it = specializations.find(layerRef);
191 if (it != specializations.end())
192 return it->getSecond();
193 return defaultSpecialization;
/// Convenience overload: build the layer reference from its root name `head`
/// and the nested references, then look up its specialization.
200 std::optional<LayerSpecialization>
201 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
202 return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
208 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
216 auto oldRoot = layerRef.getRootReference();
217 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
220 SmallString<64> prefix;
221 SmallVector<FlatSymbolRefAttr> newRef;
228 auto helper = [&](StringAttr ref) ->
bool {
229 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
233 if (!specialization) {
234 newRef.push_back(FlatSymbolRefAttr::get(
235 StringAttr::get(ref.getContext(), prefix + ref.getValue())));
242 if (*specialization == LayerSpecialization::Enable) {
243 prefix.append(ref.getValue());
252 if (!helper(oldRoot))
255 for (
auto ref : layerRef.getNestedReferences()) {
256 oldNestedRefs.push_back(ref);
257 if (!helper(ref.getAttr()))
262 return {SymbolRefAttr()};
269 auto newRoot = newRef.front().getAttr();
270 return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
275 RefType specializeRefType(RefType refType) {
276 if (
auto oldLayer = refType.getLayer()) {
277 if (
auto newLayer = specializeLayerRef(oldLayer))
278 return RefType::get(refType.getType(), refType.getForceable(),
279 newLayer.getValue());
285 Type specializeType(Type type) {
286 if (
auto refType = dyn_cast<RefType>(type))
287 return specializeRefType(refType);
293 Value specializeValue(Value value) {
294 if (
auto newType = specializeType(value.getType())) {
295 value.setType(newType);
306 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307 DenseSet<Attribute> &removedSyms) {
308 auto oldLayerRef = layerBlock.getLayerNameAttr();
314 auto specialization = getSpecialization(oldLayerRef);
317 if (!specialization) {
320 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321 if (oldLayerRef != newLayerRef)
322 layerBlock.setLayerNameAttr(newLayerRef);
325 auto *block = layerBlock.getBody();
326 auto bodyIP = InsertionPoint::atBlockEnd(block);
327 specializeBlock(block, bodyIP, removedSyms);
328 insertionPoint.moveOpBefore(layerBlock);
334 if (*specialization == LayerSpecialization::Enable) {
336 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
344 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
/// Specialize a `when` op. The then-block and, if present, the else-block are
/// specialized in place, each with its own insertion point at the block's end.
/// The `when` op itself is always kept and moved to the outer insertion point.
349 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350 DenseSet<Attribute> &removedSyms) {
353 auto *thenBlock = &when.getThenBlock();
354 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355 specializeBlock(thenBlock, thenIP, removedSyms);
356 if (when.hasElseRegion()) {
357 auto *elseBlock = &when.getElseBlock();
358 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359 specializeBlock(elseBlock, elseIP, removedSyms);
361 insertionPoint.moveOpBefore(when);
/// Specialize a `match` op. The first block of every case region is
/// specialized in place, each with its own insertion point at that block's end.
/// The `match` op itself is always kept and moved to the outer insertion point.
364 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365 DenseSet<Attribute> &removedSyms) {
366 for (
size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367 auto *caseBlock = &match.getRegion(i).front();
368 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369 specializeBlock(caseBlock, caseIP, removedSyms);
371 insertionPoint.moveOpBefore(match);
374 void specializeOp(FInstanceLike instance, InsertionPoint &insertionPoint,
375 DenseSet<Attribute> &removedSyms) {
378 llvm::BitVector disabledPorts(instance->getNumResults());
379 for (
auto result : instance->getResults())
380 if (!specializeValue(result))
381 disabledPorts.set(result.getResultNumber());
383 if (disabledPorts.any()) {
385 instance.cloneWithErasedPortsAndReplaceUses(disabledPorts);
387 instance = newInstance;
393 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
394 instance.setLayersAttr(newLayers.getValue());
396 insertionPoint.moveOpBefore(instance);
399 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
400 DenseSet<Attribute> &removedSyms) {
401 if (specializeValue(wire.getResult())) {
402 insertionPoint.moveOpBefore(wire);
404 if (
auto innerSym = wire.getInnerSymAttr())
405 recordRemovedInnerSym(removedSyms,
406 wire->getParentOfType<FModuleOp>().getNameAttr(),
412 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
413 DenseSet<Attribute> &removedSyms) {
415 if (
auto layerRef = refDefine.getDest().getType().getLayer())
416 if (!specializeLayerRef(layerRef)) {
420 insertionPoint.moveOpBefore(refDefine);
423 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
424 DenseSet<Attribute> &removedSyms) {
425 if (specializeValue(refSub->getResult(0)))
426 insertionPoint.moveOpBefore(refSub);
432 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
433 DenseSet<Attribute> &removedSyms) {
434 if (specializeValue(refCast->getResult(0)))
435 insertionPoint.moveOpBefore(refCast);
442 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
443 DenseSet<Attribute> &removedSyms) {
446 for (
auto &op :
llvm::make_early_inc_range(
llvm::reverse(*block))) {
447 TypeSwitch<Operation *>(&op)
448 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
449 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
450 [&](
auto op) { specializeOp(op, insertionPoint, removedSyms); })
451 .Default([&](Operation *op) {
454 insertionPoint.moveOpBefore(op);
/// Specialize each layer reference in `layers` and return the results as a new
/// ArrayAttr. A layer is dropped if its specialization is disabled or if its
/// specialized reference is null.
460 ArrayAttr specializeKnownLayers(ArrayAttr layers) {
461 SmallVector<Attribute> newLayers;
462 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
463 if (
auto result = specializeLayerRef(layer))
464 if (
auto newLayer = result.getValue())
465 newLayers.push_back(newLayer);
468 return ArrayAttr::get(context, newLayers);
473 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
474 SmallVector<Attribute> newLayers;
475 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
476 auto newLayer = specializeLayerRef(layer);
479 if (newLayer.getValue())
480 newLayers.push_back(newLayer.getValue());
482 return ArrayAttr::get(context, newLayers);
485 void specializeModulePorts(FModuleLike moduleLike,
486 DenseSet<Attribute> &removedSyms) {
487 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
490 SmallVector<Attribute> newTypeAttrs;
491 newTypeAttrs.reserve(oldTypeAttrs.size());
495 llvm::BitVector disabledPorts(oldTypeAttrs.size());
497 auto moduleName = moduleLike.getNameAttr();
498 for (
auto [index, typeAttr] :
499 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
501 if (
auto type = specializeType(typeAttr.getValue())) {
502 newTypeAttrs.push_back(TypeAttr::get(type));
505 if (
auto portSym = moduleLike.getPortSymbolAttr(index))
506 recordRemovedInnerSym(removedSyms, moduleName, portSym);
507 disabledPorts.set(index);
512 moduleLike.erasePorts(disabledPorts);
515 moduleLike.setPortTypesAttr(
516 ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
519 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
520 for (
auto [arg, typeAttr] :
521 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
522 arg.setType(cast<TypeAttr>(typeAttr).getValue());
525 template <
typename T>
526 DenseSet<Attribute> specializeModuleLike(T op) {
527 DenseSet<Attribute> removedSyms;
532 if constexpr (std::is_same_v<T, FModuleOp>) {
533 auto *block = cast<FModuleOp>(op).getBodyBlock();
534 auto bodyIP = InsertionPoint::atBlockEnd(block);
535 specializeBlock(block, bodyIP, removedSyms);
539 specializeModulePorts(op, removedSyms);
544 template <
typename T>
545 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
547 if (
auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
548 module.setLayersAttr(newLayers.getValue());
554 auto moduleName =
module.getNameAttr();
555 removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
556 if constexpr (std::is_same_v<T, FModuleOp>)
557 recordRemovedInnerSyms(removedSyms, moduleName,
558 cast<FModuleOp>(module).getBodyBlock());
/// Rewrite the external module's known-layers attribute with the specialized
/// list produced by the ArrayAttr overload. The attribute is always reset.
565 void specializeKnownLayers(FExtModuleOp module) {
566 auto knownLayers =
module.getKnownLayersAttr();
567 module.setKnownLayersAttr(specializeKnownLayers(knownLayers));
573 void specializeLayer(LayerOp layer) {
574 StringAttr head = layer.getSymNameAttr();
575 SmallVector<FlatSymbolRefAttr> nestedRefs;
577 std::function<void(LayerOp, Block::iterator,
const Twine &)> handleLayer =
578 [&](LayerOp layer, Block::iterator insertionPoint,
579 const Twine &prefix) {
580 auto *block = &layer.getBody().getBlocks().front();
581 auto specialization = getSpecialization(head, nestedRefs);
585 if (!specialization) {
589 if (!prefix.isTriviallyEmpty()) {
590 layer.setSymNameAttr(
591 StringAttr::get(context, prefix + layer.getSymName()));
592 auto *parentBlock = insertionPoint->getBlock();
593 layer->moveBefore(parentBlock, insertionPoint);
596 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
597 nestedRefs.push_back(SymbolRefAttr::get(nested));
598 handleLayer(nested, Block::iterator(nested),
"");
599 nestedRefs.pop_back();
606 if (*specialization == LayerSpecialization::Enable) {
608 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
609 nestedRefs.push_back(SymbolRefAttr::get(nested));
610 handleLayer(nested, insertionPoint,
611 prefix + layer.getSymName() +
"_");
612 nestedRefs.pop_back();
624 handleLayer(layer, Block::iterator(layer),
"");
631 SmallVector<Operation *> specialize;
632 DenseSet<Attribute> removedSyms;
634 TypeSwitch<Operation *>(&op)
635 .Case<FModuleOp>([&](FModuleOp module) {
636 if (specializeEnableLayers(module, removedSyms))
637 specialize.push_back(module);
639 .Case<FExtModuleOp>([&](FExtModuleOp module) {
640 specializeKnownLayers(module);
641 if (specializeEnableLayers(module, removedSyms))
642 specialize.push_back(module);
644 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
648 auto mergeSets = [](
auto &&
a,
auto &&
b) {
649 a.insert(
b.begin(),
b.end());
650 return std::forward<decltype(a)>(a);
656 context, specialize, removedSyms, mergeSets,
657 [&](Operation *op) -> DenseSet<Attribute> {
658 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
659 .Case<FModuleOp, FExtModuleOp>(
660 [&](
auto op) {
return specializeModuleLike(op); });
666 DenseSet<StringAttr> removedPaths;
667 for (
auto hierPath :
llvm::make_early_inc_range(
668 circuit.getBody().getOps<
hw::HierPathOp>())) {
669 auto namepath = hierPath.getNamepath().getValue();
670 auto shouldDelete = [&](Attribute ref) {
671 return removedSyms.contains(ref);
673 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
674 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
681 if (shouldDelete(namepath.back()))
687 SmallVector<FModuleLike> clean;
689 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
690 clean.push_back(cast<FModuleLike>(op));
692 parallelForEach(context, clean, [&](FModuleLike module) {
693 (AnnotationCleaner(removedPaths))(module);
697 MLIRContext *context;
699 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
702 std::optional<LayerSpecialization> defaultSpecialization;
705struct SpecializeLayersPass
706 :
public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
708 void runOnOperation()
override {
709 auto circuit = getOperation();
710 SymbolTableCollection stc;
713 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
716 bool shouldSpecialize =
false;
719 if (
auto enabledLayers = circuit.getEnableLayersAttr()) {
720 shouldSpecialize =
true;
721 circuit.removeEnableLayersAttr();
722 for (
auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
724 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
725 mlir::emitError(circuit.getLoc()) <<
"unknown layer " << enabledLayer;
729 specializations[enabledLayer] = LayerSpecialization::Enable;
734 if (
auto disabledLayers = circuit.getDisableLayersAttr()) {
735 shouldSpecialize =
true;
736 circuit.removeDisableLayersAttr();
737 for (
auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
739 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
740 mlir::emitError(circuit.getLoc())
741 <<
"unknown layer " << disabledLayer;
747 auto [it, inserted] = specializations.try_emplace(
748 disabledLayer, LayerSpecialization::Disable);
749 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
750 mlir::emitError(circuit.getLoc())
751 <<
"layer " << disabledLayer <<
" both enabled and disabled";
758 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
759 if (
auto specialization = circuit.getDefaultLayerSpecialization()) {
760 shouldSpecialize =
true;
761 defaultSpecialization = *specialization;
767 if (!shouldSpecialize)
768 return markAllAnalysesPreserved();
771 SpecializeLayers(circuit, specializations, defaultSpecialization)();
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
static mlir::ArrayAttr getFromVoidPointer(void *ptr)