15#include "mlir/IR/Threading.h"
16#include "llvm/ADT/STLExtras.h"
22#define GEN_PASS_DEF_SPECIALIZELAYERS
23#include "circt/Dialect/FIRRTL/Passes.h.inc"
29using namespace firrtl;
// NOTE(review): fragment — the enclosing definition is not in this listing.
// Converts an opaque pointer back into an mlir::ArrayAttr; presumably part of a
// pointer-like-traits specialization that lets attributes be stored in
// llvm::PointerIntPair (see the Specialized helper below) — confirm.
37 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
/// Removes annotations whose "circt.nonlocal" member refers to a hierarchical
/// path listed in `removedPaths`. It is applied to a module's own annotations,
/// to each of the module's ports, and to operations nested in the module
/// (memories are handled result-port by result-port).
45struct AnnotationCleaner {
/// Stores a reference to `removedPaths`; the set must outlive this object.
48 AnnotationCleaner(
const DenseSet<StringAttr> &removedPaths)
49 : removedPaths(removedPaths) {}
// Annotation predicate (its signature is not in this listing). It returns true
// when the annotation's non-local anchor names a removed path. Presumably used
// as the removal predicate when cleaning an annotation set.
53 if (
auto nla = anno.
getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
54 return removedPaths.contains(nla.getAttr());
/// Clean the annotations attached directly to `op`.
60 void cleanAnnotations(Operation *op) {
64 if (!oldAnnotations.empty()) {
65 auto newAnnotations = cleanAnnotations(oldAnnotations);
// Only rewrite the attribute when something was actually removed.
66 if (oldAnnotations != newAnnotations)
67 newAnnotations.applyToOperation(op);
/// Clean a module: its own annotations, its per-port annotations, and the
/// annotations of operations nested inside it.
71 void operator()(FModuleLike module) {
73 cleanAnnotations(module);
// Per-port annotations of the module.
76 for (
size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
78 if (!oldAnnotations.empty()) {
79 auto newAnnotations = cleanAnnotations(oldAnnotations);
80 if (oldAnnotations != newAnnotations)
81 newAnnotations.applyToPort(module, i);
// Operations nested in the module.
86 module->walk([&](Operation *op) {
// Memories carry annotations per result port, so clean each port separately.
90 if (
auto mem = dyn_cast<MemOp>(op)) {
92 for (
size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
94 if (!oldAnnotations.empty()) {
95 auto newAnnotations = cleanAnnotations(oldAnnotations);
96 if (oldAnnotations != newAnnotations)
97 newAnnotations.applyToPort(mem, i);
// Hierarchical paths deleted by the pass (borrowed, not owned).
105 const DenseSet<StringAttr> &removedPaths;
/// A position in a block at which operations are re-inserted. Each moved op is
/// placed before the current position and then becomes the new position.
/// Moving ops in reverse order therefore preserves their relative order.
109struct InsertionPoint {
/// An insertion point at the end of `block`.
111 static InsertionPoint atBlockEnd(Block *block) {
112 return InsertionPoint{block, block->end()};
/// Move `op` immediately before the current position, then make `op` the new
/// position so that the next moved op lands in front of it.
117 void moveOpBefore(Operation *op) {
118 op->moveBefore(block, it);
119 it = Block::iterator(op);
123 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
// Members of the Specialized<T> result wrapper (template header not in this
// listing). A result is either "disabled" or "enabled with a value".
// Default construction yields the disabled state holding a null value.
135 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
// Wrapping a value yields the enabled state; the value itself may be null.
138 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
/// True when this result represents a disabled entity.
141 bool isDisabled()
const {
142 return value.getInt() == LayerSpecialization::Disable;
// Value accessor (signature not in this listing): returns the wrapped value,
// which is null in the disabled state.
148 return value.getPointer();
/// Contextual conversion: true unless disabled. An enabled result may still
/// carry a null value, so check the value separately when that matters.
151 operator bool()
const {
return !isDisabled(); }
154 Specialized(LayerSpecialization specialization, T value)
155 : value(value, specialization) {}
// The value and the enable/disable flag packed into a single PointerIntPair.
156 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
/// Specializes the layers of one circuit:
///  - layers marked Enable are inlined into their parent;
///  - layers marked Disable are removed, along with the ports, values and
///    symbols that depend on them.
159struct SpecializeLayers {
// Constructor (its leading line is not in this listing). `specializations` is
// stored by reference and must outlive this object. `defaultSpecialization`
// applies to layers that have no explicit entry in `specializations`.
162 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163 std::optional<LayerSpecialization> defaultSpecialization)
164 : context(circuit->getContext()), circuit(circuit),
165 specializations(specializations),
166 defaultSpecialization(defaultSpecialization) {}
/// Record in `removedSyms` one inner reference (`moduleName` plus field name)
/// for every field of `innerSym`.
170 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171 StringAttr moduleName,
172 hw::InnerSymAttr innerSym) {
173 for (
auto field : innerSym)
174 removedSyms.insert(
hw::InnerRefAttr::
get(moduleName, field.
getName()));
/// Walk `block` and record, as removed, the inner symbol of every nested
/// operation that implements InnerSymbolOpInterface and has one.
179 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180 StringAttr moduleName, Block *block) {
181 block->walk([&](hw::InnerSymbolOpInterface op) {
182 if (
auto innerSym = op.getInnerSymAttr())
183 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
/// Look up the specialization for a layer by its full symbol reference. Falls
/// back to the default specialization, which may be std::nullopt; callers treat
/// nullopt as "leave this layer as is".
189 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190 auto it = specializations.find(layerRef);
191 if (it != specializations.end())
192 return it->getSecond();
193 return defaultSpecialization;
/// Convenience overload: build the layer reference from a root name and a list
/// of nested references, then look it up.
200 std::optional<LayerSpecialization>
201 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
202 return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
/// Rewrite a layer reference to reflect the specializations along its path,
/// walking from the root layer through each nested layer.
/// - A layer with no specialization is kept. Its name gets any prefix
///   accumulated from enabled ancestors.
/// - An enabled layer is dropped from the path; its name is appended to that
///   prefix instead.
/// - When `helper` returns false, the walk stops. That handling is not in this
///   listing; presumably the whole reference is reported as disabled.
/// - When no layer is kept, an enabled result holding a null reference is
///   returned.
208 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
216 auto oldRoot = layerRef.getRootReference();
217 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
220 SmallString<64> prefix;
221 SmallVector<FlatSymbolRefAttr> newRef;
// Process one path component. `oldRoot` + `oldNestedRefs` identify the layer
// that `ref` names. Returns whether the walk should continue.
228 auto helper = [&](StringAttr ref) ->
bool {
229 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
// Not specialized: keep this layer, renamed with the accumulated prefix.
233 if (!specialization) {
234 newRef.push_back(FlatSymbolRefAttr::get(
235 StringAttr::get(ref.getContext(), prefix + ref.getValue())));
// Enabled: fold this layer's name into the prefix for the next kept layer.
242 if (*specialization == LayerSpecialization::Enable) {
243 prefix.append(ref.getValue());
252 if (!helper(oldRoot))
255 for (
auto ref : layerRef.getNestedReferences()) {
256 oldNestedRefs.push_back(ref);
257 if (!helper(ref.getAttr()))
// No layer was kept (condition not in this listing): enabled, null reference.
262 return {SymbolRefAttr()};
// Rebuild the reference from the kept (possibly renamed) components.
269 auto newRoot = newRef.front().getAttr();
270 return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
/// Specialize the layer, if any, carried by a RefType. When the layer is not
/// disabled, the ref type is rebuilt with the specialized layer, which may be
/// null if the whole layer path was enabled. Other cases are not in this
/// listing.
275 RefType specializeRefType(RefType refType) {
276 if (
auto oldLayer = refType.getLayer()) {
277 if (
auto newLayer = specializeLayerRef(oldLayer))
278 return RefType::get(refType.getType(), refType.getForceable(),
279 newLayer.getValue());
/// Specialize a type. Ref types are delegated to specializeRefType; handling of
/// other types is not in this listing. Callers treat a null result as
/// "disabled".
285 Type specializeType(Type type) {
286 if (
auto refType = dyn_cast<RefType>(type))
287 return specializeRefType(refType);
/// Retype `value` in place when its type can be specialized. Callers test the
/// returned Value for truthiness, so a null result presumably means the value
/// depends on a disabled layer.
293 Value specializeValue(Value value) {
294 if (
auto newType = specializeType(value.getType())) {
295 value.setType(newType);
/// Specialize a layer block according to its layer's specialization.
/// - None: keep the block, rename its layer reference if needed, specialize its
///   body in place, and move it to `insertionPoint`.
/// - Enable: specialize its body directly into the enclosing `insertionPoint`,
///   which inlines the contents.
/// - Otherwise: record the inner symbols defined in its body as removed.
306 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307 DenseSet<Attribute> &removedSyms) {
308 auto oldLayerRef = layerBlock.getLayerNameAttr();
314 auto specialization = getSpecialization(oldLayerRef);
317 if (!specialization) {
320 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321 if (oldLayerRef != newLayerRef)
322 layerBlock.setLayerNameAttr(newLayerRef);
325 auto *block = layerBlock.getBody();
326 auto bodyIP = InsertionPoint::atBlockEnd(block);
327 specializeBlock(block, bodyIP, removedSyms);
328 insertionPoint.moveOpBefore(layerBlock);
334 if (*specialization == LayerSpecialization::Enable) {
336 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
// Disabled: the body's inner symbols become dangling, so record them.
344 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
/// Specialize both regions of a `when` in place, then move the op to
/// `insertionPoint`.
349 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350 DenseSet<Attribute> &removedSyms) {
353 auto *thenBlock = &when.getThenBlock();
354 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355 specializeBlock(thenBlock, thenIP, removedSyms);
356 if (when.hasElseRegion()) {
357 auto *elseBlock = &when.getElseBlock();
358 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359 specializeBlock(elseBlock, elseIP, removedSyms);
361 insertionPoint.moveOpBefore(when);
/// Specialize every case region of a `match` in place (using the first block of
/// each region), then move the op to `insertionPoint`.
364 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365 DenseSet<Attribute> &removedSyms) {
366 for (
size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367 auto *caseBlock = &match.getRegion(i).front();
368 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369 specializeBlock(caseBlock, caseIP, removedSyms);
371 insertionPoint.moveOpBefore(match);
/// Specialize an instance:
///  1. retype its results, and erase result ports whose type is disabled;
///  2. rewrite its enabled-layers list;
///  3. move it to `insertionPoint`.
374 void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
375 DenseSet<Attribute> &removedSyms) {
378 llvm::BitVector disabledPorts(instance->getNumResults());
379 for (
auto result : instance->getResults())
380 if (!specializeValue(result))
381 disabledPorts.set(result.getResultNumber());
382 if (disabledPorts.any()) {
383 OpBuilder builder(instance);
384 auto newInstance = instance.erasePorts(builder, disabledPorts);
// NOTE(review): the InstanceChoiceOp overload explicitly erases the old op
// before rebinding. That erase is not visible here (original line omitted from
// this listing) — confirm the old instance is erased.
386 instance = newInstance;
// NOTE(review): getValue() is null if the layers list is disabled; the check
// for that case is not in this listing — confirm it precedes this use.
392 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
393 instance.setLayersAttr(newLayers.getValue());
395 insertionPoint.moveOpBefore(instance);
/// Specialize an instance choice, mirroring the InstanceOp overload: erase
/// disabled result ports (replacing and erasing the old op), rewrite the
/// enabled-layers list, and move the op to `insertionPoint`.
398 void specializeOp(InstanceChoiceOp instanceChoice,
399 InsertionPoint &insertionPoint,
400 DenseSet<Attribute> &removedSyms) {
403 llvm::BitVector disabledPorts(instanceChoice->getNumResults());
404 for (
auto result : instanceChoice->getResults())
405 if (!specializeValue(result))
406 disabledPorts.set(result.getResultNumber());
407 if (disabledPorts.any()) {
408 OpBuilder builder(instanceChoice);
409 auto newInstanceChoice =
410 instanceChoice.erasePorts(builder, disabledPorts);
411 instanceChoice->erase();
412 instanceChoice = newInstanceChoice;
418 auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
419 instanceChoice.setLayersAttr(newLayers.getValue());
421 insertionPoint.moveOpBefore(instanceChoice);
/// Keep a wire whose type can be specialized by moving it to `insertionPoint`.
/// Otherwise its inner symbol (if any) is recorded as removed; the surrounding
/// else-branch lines are not in this listing.
424 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
425 DenseSet<Attribute> &removedSyms) {
426 if (specializeValue(wire.getResult())) {
427 insertionPoint.moveOpBefore(wire);
429 if (
auto innerSym = wire.getInnerSymAttr())
430 recordRemovedInnerSym(removedSyms,
431 wire->getParentOfType<FModuleOp>().getNameAttr(),
/// Handle a ref.define whose destination carries a layer. If that layer
/// reference is disabled, a special case applies; its body is not in this
/// listing (presumably the define is dropped). Otherwise the op is moved to
/// `insertionPoint`.
437 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
438 DenseSet<Attribute> &removedSyms) {
440 if (
auto layerRef = refDefine.getDest().getType().getLayer())
441 if (!specializeLayerRef(layerRef)) {
445 insertionPoint.moveOpBefore(refDefine);
/// Keep a ref.sub whose result type can be specialized by moving it to
/// `insertionPoint`. The disabled case is not in this listing.
448 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
449 DenseSet<Attribute> &removedSyms) {
450 if (specializeValue(refSub->getResult(0)))
451 insertionPoint.moveOpBefore(refSub);
/// Keep a ref.cast whose result type can be specialized by moving it to
/// `insertionPoint`. The disabled case is not in this listing.
457 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
458 DenseSet<Attribute> &removedSyms) {
459 if (specializeValue(refCast->getResult(0)))
460 insertionPoint.moveOpBefore(refCast);
/// Specialize every operation in `block`, visiting them in reverse order.
/// Surviving ops are moved to `insertionPoint`, so they keep their original
/// relative order. That target may be in `block` or in an enclosing block when
/// a layer block is being inlined.
/// The listed op kinds dispatch to a dedicated specializeOp overload; all other
/// ops go through the Default case, part of which is not in this listing.
467 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
468 DenseSet<Attribute> &removedSyms) {
// Early-increment iteration: the current op may be moved or erased.
471 for (
auto &op :
llvm::make_early_inc_range(
llvm::reverse(*block))) {
472 TypeSwitch<Operation *>(&op)
473 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
474 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
475 [&](
auto op) { specializeOp(op, insertionPoint, removedSyms); })
476 .Default([&](Operation *op) {
479 insertionPoint.moveOpBefore(op);
/// Specialize a list of known layers. A reference is dropped when it is
/// disabled, or when it specializes to a null reference (fully enabled). Every
/// other reference is replaced by its specialized form.
485 ArrayAttr specializeKnownLayers(ArrayAttr layers) {
486 SmallVector<Attribute> newLayers;
487 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
488 if (
auto result = specializeLayerRef(layer))
489 if (
auto newLayer = result.getValue())
490 newLayers.push_back(newLayer);
493 return ArrayAttr::get(context, newLayers);
/// Specialize an enabled-layers list. Non-null specialized references are kept;
/// fully-enabled (null) ones are dropped. The handling of disabled entries is
/// not in this listing; presumably any disabled entry makes the whole result
/// disabled.
498 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
499 SmallVector<Attribute> newLayers;
500 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
501 auto newLayer = specializeLayerRef(layer);
504 if (newLayer.getValue())
505 newLayers.push_back(newLayer.getValue());
507 return ArrayAttr::get(context, newLayers);
/// Specialize a module's port types.
/// - Ports whose type cannot be specialized are erased, after recording their
///   port symbol as removed.
/// - The remaining port types are rewritten.
/// - For an FModuleOp, the body's block arguments are retyped to match.
510 void specializeModulePorts(FModuleLike moduleLike,
511 DenseSet<Attribute> &removedSyms) {
512 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
515 SmallVector<Attribute> newTypeAttrs;
516 newTypeAttrs.reserve(oldTypeAttrs.size());
520 llvm::BitVector disabledPorts(oldTypeAttrs.size());
522 auto moduleName = moduleLike.getNameAttr();
523 for (
auto [index, typeAttr] :
524 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
526 if (
auto type = specializeType(typeAttr.getValue())) {
527 newTypeAttrs.push_back(TypeAttr::get(type));
// Disabled port: its symbol becomes dangling; mark the port for erasure.
530 if (
auto portSym = moduleLike.getPortSymbolAttr(index))
531 recordRemovedInnerSym(removedSyms, moduleName, portSym);
532 disabledPorts.set(index);
537 moduleLike.erasePorts(disabledPorts);
540 moduleLike.setPortTypesAttr(
541 ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
// Keep block argument types consistent with the rewritten port types.
544 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
545 for (
auto [arg, typeAttr] :
546 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
547 arg.setType(cast<TypeAttr>(typeAttr).getValue());
/// Specialize a module-like op: its body (FModuleOp only) and its ports.
/// Returns the set of symbols removed while doing so. Each module gets its own
/// set, which the driver merges afterwards.
550 template <
typename T>
551 DenseSet<Attribute> specializeModuleLike(T op) {
552 DenseSet<Attribute> removedSyms;
557 if constexpr (std::is_same_v<T, FModuleOp>) {
558 auto *block = cast<FModuleOp>(op).getBodyBlock();
559 auto bodyIP = InsertionPoint::atBlockEnd(block);
560 specializeBlock(block, bodyIP, removedSyms);
564 specializeModulePorts(op, removedSyms);
/// Specialize a module's enabled-layers list.
/// - If the list is not disabled, store the new list.
/// - Otherwise, record the module's symbol (and, for FModuleOp, the inner
///   symbols in its body) as removed.
/// The driver treats a truthy return as "keep and specialize this module".
569 template <
typename T>
570 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
572 if (
auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
573 module.setLayersAttr(newLayers.getValue());
579 auto moduleName =
module.getNameAttr();
580 removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
581 if constexpr (std::is_same_v<T, FModuleOp>)
582 recordRemovedInnerSyms(removedSyms, moduleName,
583 cast<FModuleOp>(module).getBodyBlock());
/// Rewrite an external module's known-layers list using the list overload of
/// specializeKnownLayers.
590 void specializeKnownLayers(FExtModuleOp module) {
591 auto knownLayers =
module.getKnownLayersAttr();
592 module.setKnownLayersAttr(specializeKnownLayers(knownLayers));
/// Specialize a top-level layer declaration and, recursively, its nested
/// layers. `head` + `nestedRefs` track the path of the layer being visited.
/// - Enabled layers are flattened: their nested layers are handled at the
///   enclosing insertion point, with `<prefix><name>_` prepended to their names.
/// - Unspecialized layers are renamed with the accumulated prefix and moved to
///   the insertion point when that prefix is non-empty.
/// The disabled case is not in this listing.
598 void specializeLayer(LayerOp layer) {
599 StringAttr head = layer.getSymNameAttr();
600 SmallVector<FlatSymbolRefAttr> nestedRefs;
602 std::function<void(LayerOp, Block::iterator,
const Twine &)> handleLayer =
603 [&](LayerOp layer, Block::iterator insertionPoint,
604 const Twine &prefix) {
605 auto *block = &layer.getBody().getBlocks().front();
606 auto specialization = getSpecialization(head, nestedRefs);
// Not specialized: rename/move if inside an enabled ancestor, then recurse
// into nested layers with an empty prefix, keeping them in place.
610 if (!specialization) {
614 if (!prefix.isTriviallyEmpty()) {
615 layer.setSymNameAttr(
616 StringAttr::get(context, prefix + layer.getSymName()));
617 auto *parentBlock = insertionPoint->getBlock();
618 layer->moveBefore(parentBlock, insertionPoint);
621 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
622 nestedRefs.push_back(SymbolRefAttr::get(nested));
623 handleLayer(nested, Block::iterator(nested),
"");
624 nestedRefs.pop_back();
// Enabled: hoist nested layers to the enclosing insertion point, extending
// the prefix with this layer's name.
631 if (*specialization == LayerSpecialization::Enable) {
633 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
634 nestedRefs.push_back(SymbolRefAttr::get(nested));
635 handleLayer(nested, insertionPoint,
636 prefix + layer.getSymName() +
"_");
637 nestedRefs.pop_back();
// Start at the top-level layer, in place, with no prefix.
649 handleLayer(layer, Block::iterator(layer),
"");
// Driver body (its signature is not in this listing):
//  1. Serially:
//     - specialize layer declarations;
//     - rewrite module enable-layers lists;
//     - collect the modules to keep.
//  2. Specialize kept modules (bodies and ports), merging each module's removed
//     symbols into `removedSyms`.
//  3. Delete hierarchical paths that pass through removed symbols.
//  4. Clean annotations referencing those paths, in parallel.
656 SmallVector<Operation *> specialize;
657 DenseSet<Attribute> removedSyms;
// Per top-level op (the loop header is not in this listing).
659 TypeSwitch<Operation *>(&op)
660 .Case<FModuleOp>([&](FModuleOp module) {
661 if (specializeEnableLayers(module, removedSyms))
662 specialize.push_back(module);
664 .Case<FExtModuleOp>([&](FExtModuleOp module) {
665 specializeKnownLayers(module);
666 if (specializeEnableLayers(module, removedSyms))
667 specialize.push_back(module);
669 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
// Reduction step: merge set `b` into `a` and forward `a`. The call it is passed
// to is not in this listing; presumably mlir::transformReduce, running
// specializeModuleLike on each kept module.
673 auto mergeSets = [](
auto &&a,
auto &&b) {
674 a.insert(b.begin(), b.end());
675 return std::forward<decltype(a)>(a);
681 context, specialize, removedSyms, mergeSets,
682 [&](Operation *op) -> DenseSet<Attribute> {
683 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
684 .Case<FModuleOp, FExtModuleOp>(
685 [&](
auto op) {
return specializeModuleLike(op); });
// Hierarchical paths that pass through a removed symbol are recorded in
// `removedPaths`. Handling of a removed final element is partly outside this
// listing.
691 DenseSet<StringAttr> removedPaths;
692 for (
auto hierPath :
llvm::make_early_inc_range(
693 circuit.getBody().getOps<
hw::HierPathOp>())) {
694 auto namepath = hierPath.getNamepath().getValue();
695 auto shouldDelete = [&](Attribute ref) {
696 return removedSyms.contains(ref);
698 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
699 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
706 if (shouldDelete(namepath.back()))
// Strip annotations that reference deleted paths from every module kind, in
// parallel.
712 SmallVector<FModuleLike> clean;
714 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
715 clean.push_back(cast<FModuleLike>(op));
717 parallelForEach(context, clean, [&](FModuleLike module) {
718 (AnnotationCleaner(removedPaths))(module);
// Context of the circuit being specialized.
722 MLIRContext *context;
// Explicit per-layer specializations (borrowed; must outlive this object).
724 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
// Specialization for layers without an explicit entry; std::nullopt leaves
// such layers unchanged.
727 std::optional<LayerSpecialization> defaultSpecialization;
/// Pass wrapper around SpecializeLayers. It reads the circuit's enable-layers
/// list, disable-layers list and default layer specialization, removing the two
/// lists from the circuit. It then checks that each named layer exists and is
/// not both enabled and disabled, and runs the specialization.
730struct SpecializeLayersPass
731 :
public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
733 void runOnOperation()
override {
734 auto circuit = getOperation();
735 SymbolTableCollection stc;
// Explicit specializations gathered from the circuit attributes.
738 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
// Set when any specialization source is present on the circuit.
741 bool shouldSpecialize =
false;
// Enabled layers: each must resolve to a symbol in the circuit.
744 if (
auto enabledLayers = circuit.getEnableLayersAttr()) {
745 shouldSpecialize =
true;
746 circuit.removeEnableLayersAttr();
747 for (
auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
749 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
750 mlir::emitError(circuit.getLoc()) <<
"unknown layer " << enabledLayer;
754 specializations[enabledLayer] = LayerSpecialization::Enable;
// Disabled layers: each must resolve, and must not also be enabled.
759 if (
auto disabledLayers = circuit.getDisableLayersAttr()) {
760 shouldSpecialize =
true;
761 circuit.removeDisableLayersAttr();
762 for (
auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
764 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
765 mlir::emitError(circuit.getLoc())
766 <<
"unknown layer " << disabledLayer;
// try_emplace keeps an existing entry, so a conflicting Enable is detectable.
772 auto [it, inserted] = specializations.try_emplace(
773 disabledLayer, LayerSpecialization::Disable);
774 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
775 mlir::emitError(circuit.getLoc())
776 <<
"layer " << disabledLayer <<
" both enabled and disabled";
// Optional fallback for layers not named explicitly.
783 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
784 if (
auto specialization = circuit.getDefaultLayerSpecialization()) {
785 shouldSpecialize =
true;
786 defaultSpecialization = *specialization;
// Nothing requested: the IR is untouched.
792 if (!shouldSpecialize)
793 return markAllAnalysesPreserved();
796 SpecializeLayers(circuit, specializations, defaultSpecialization)();
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi-threading is disabled.
static mlir::ArrayAttr getFromVoidPointer(void *ptr)