15 #include "mlir/IR/Threading.h"
16 #include "llvm/ADT/STLExtras.h"
18 #include <type_traits>
22 #define GEN_PASS_DEF_SPECIALIZELAYERS
23 #include "circt/Dialect/FIRRTL/Passes.h.inc"
28 using namespace circt;
29 using namespace firrtl;
37 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
45 struct AnnotationCleaner {
48 AnnotationCleaner(
const DenseSet<StringAttr> &removedPaths)
49 : removedPaths(removedPaths) {}
53 if (
auto nla = anno.
getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
54 return removedPaths.contains(nla.getAttr());
60 void cleanAnnotations(Operation *op) {
64 if (!oldAnnotations.empty()) {
65 auto newAnnotations = cleanAnnotations(oldAnnotations);
66 if (oldAnnotations != newAnnotations)
67 newAnnotations.applyToOperation(op);
71 void operator()(FModuleLike module) {
73 cleanAnnotations(module);
76 for (
size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
78 if (!oldAnnotations.empty()) {
79 auto newAnnotations = cleanAnnotations(oldAnnotations);
80 if (oldAnnotations != newAnnotations)
81 newAnnotations.applyToPort(module, i);
86 module->walk([&](Operation *op) {
90 if (
auto mem = dyn_cast<MemOp>(op)) {
92 for (
size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
94 if (!oldAnnotations.empty()) {
95 auto newAnnotations = cleanAnnotations(oldAnnotations);
96 if (oldAnnotations != newAnnotations)
97 newAnnotations.applyToPort(mem, i);
105 const DenseSet<StringAttr> &removedPaths;
109 struct InsertionPoint {
111 static InsertionPoint atBlockEnd(Block *block) {
112 return InsertionPoint{block, block->end()};
117 void moveOpBefore(Operation *op) {
118 op->moveBefore(block, it);
119 it = Block::iterator(op);
123 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
132 template <
typename T>
135 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
138 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
141 bool isDisabled()
const {
142 return value.getInt() == LayerSpecialization::Disable;
148 return value.getPointer();
151 operator bool()
const {
return !isDisabled(); }
154 Specialized(LayerSpecialization specialization, T value)
155 : value(value, specialization) {}
156 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
159 struct SpecializeLayers {
162 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163 std::optional<LayerSpecialization> defaultSpecialization)
164 : context(circuit->getContext()), circuit(circuit),
165 specializations(specializations),
166 defaultSpecialization(defaultSpecialization) {}
170 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171 StringAttr moduleName,
172 hw::InnerSymAttr innerSym) {
173 for (
auto field : innerSym)
179 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180 StringAttr moduleName, Block *block) {
181 block->walk([&](hw::InnerSymbolOpInterface op) {
182 if (
auto innerSym = op.getInnerSymAttr())
183 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
189 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190 auto it = specializations.find(layerRef);
191 if (it != specializations.end())
192 return it->getSecond();
193 return defaultSpecialization;
200 std::optional<LayerSpecialization>
201 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
208 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
216 auto oldRoot = layerRef.getRootReference();
217 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
220 SmallString<64> prefix;
221 SmallVector<FlatSymbolRefAttr> newRef;
228 auto helper = [&](StringAttr ref) ->
bool {
229 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
233 if (!specialization) {
242 if (*specialization == LayerSpecialization::Enable) {
243 prefix.append(ref.getValue());
252 if (!helper(oldRoot))
255 for (
auto ref : layerRef.getNestedReferences()) {
256 oldNestedRefs.push_back(ref);
257 if (!helper(ref.getAttr()))
262 return {SymbolRefAttr()};
269 auto newRoot = newRef.front().getAttr();
275 RefType specializeRefType(RefType refType) {
276 if (
auto oldLayer = refType.getLayer()) {
277 if (
auto newLayer = specializeLayerRef(oldLayer))
278 return RefType::get(refType.getType(), refType.getForceable(),
279 newLayer.getValue());
285 Type specializeType(Type type) {
286 if (
auto refType = dyn_cast<RefType>(type))
287 return specializeRefType(refType);
293 Value specializeValue(Value value) {
294 if (
auto newType = specializeType(value.getType())) {
295 value.setType(newType);
306 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307 DenseSet<Attribute> &removedSyms) {
308 auto oldLayerRef = layerBlock.getLayerNameAttr();
314 auto specialization = getSpecialization(oldLayerRef);
317 if (!specialization) {
320 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321 if (oldLayerRef != newLayerRef)
322 layerBlock.setLayerNameAttr(newLayerRef);
325 auto *block = layerBlock.getBody();
326 auto bodyIP = InsertionPoint::atBlockEnd(block);
327 specializeBlock(block, bodyIP, removedSyms);
328 insertionPoint.moveOpBefore(layerBlock);
334 if (*specialization == LayerSpecialization::Enable) {
336 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
344 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
349 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350 DenseSet<Attribute> &removedSyms) {
353 auto *thenBlock = &when.getThenBlock();
354 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355 specializeBlock(thenBlock, thenIP, removedSyms);
356 if (when.hasElseRegion()) {
357 auto *elseBlock = &when.getElseBlock();
358 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359 specializeBlock(elseBlock, elseIP, removedSyms);
361 insertionPoint.moveOpBefore(when);
364 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365 DenseSet<Attribute> &removedSyms) {
366 for (
size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367 auto *caseBlock = &match.getRegion(i).front();
368 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369 specializeBlock(caseBlock, caseIP, removedSyms);
371 insertionPoint.moveOpBefore(match);
374 void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
375 DenseSet<Attribute> &removedSyms) {
378 llvm::BitVector disabledPorts(instance->getNumResults());
379 for (
auto result : instance->getResults())
380 if (!specializeValue(result))
381 disabledPorts.set(result.getResultNumber());
382 if (disabledPorts.any()) {
383 OpBuilder builder(instance);
384 auto newInstance = instance.erasePorts(builder, disabledPorts);
386 instance = newInstance;
392 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
393 instance.setLayersAttr(newLayers.getValue());
395 insertionPoint.moveOpBefore(instance);
398 void specializeOp(InstanceChoiceOp instanceChoice,
399 InsertionPoint &insertionPoint,
400 DenseSet<Attribute> &removedSyms) {
403 llvm::BitVector disabledPorts(instanceChoice->getNumResults());
404 for (
auto result : instanceChoice->getResults())
405 if (!specializeValue(result))
406 disabledPorts.set(result.getResultNumber());
407 if (disabledPorts.any()) {
408 OpBuilder builder(instanceChoice);
409 auto newInstanceChoice =
410 instanceChoice.erasePorts(builder, disabledPorts);
411 instanceChoice->erase();
412 instanceChoice = newInstanceChoice;
418 auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
419 instanceChoice.setLayersAttr(newLayers.getValue());
421 insertionPoint.moveOpBefore(instanceChoice);
424 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
425 DenseSet<Attribute> &removedSyms) {
426 if (specializeValue(wire.getResult())) {
427 insertionPoint.moveOpBefore(wire);
429 if (
auto innerSym = wire.getInnerSymAttr())
430 recordRemovedInnerSym(removedSyms,
431 wire->getParentOfType<FModuleOp>().getNameAttr(),
437 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
438 DenseSet<Attribute> &removedSyms) {
440 if (
auto layerRef = refDefine.getDest().getType().getLayer())
441 if (!specializeLayerRef(layerRef)) {
445 insertionPoint.moveOpBefore(refDefine);
448 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
449 DenseSet<Attribute> &removedSyms) {
450 if (specializeValue(refSub->getResult(0)))
451 insertionPoint.moveOpBefore(refSub);
457 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
458 DenseSet<Attribute> &removedSyms) {
459 if (specializeValue(refCast->getResult(0)))
460 insertionPoint.moveOpBefore(refCast);
467 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
468 DenseSet<Attribute> &removedSyms) {
471 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
472 TypeSwitch<Operation *>(&op)
473 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
474 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
475 [&](
auto op) { specializeOp(op, insertionPoint, removedSyms); })
476 .Default([&](Operation *op) {
479 insertionPoint.moveOpBefore(op);
486 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
487 SmallVector<Attribute> newLayers;
488 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
489 auto newLayer = specializeLayerRef(layer);
492 if (newLayer.getValue())
493 newLayers.push_back(newLayer.getValue());
498 void specializeModulePorts(FModuleLike moduleLike,
499 DenseSet<Attribute> &removedSyms) {
500 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
503 SmallVector<Attribute> newTypeAttrs;
504 newTypeAttrs.reserve(oldTypeAttrs.size());
508 llvm::BitVector disabledPorts(oldTypeAttrs.size());
510 auto moduleName = moduleLike.getNameAttr();
511 for (
auto [index, typeAttr] :
512 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
514 if (
auto type = specializeType(typeAttr.getValue())) {
518 if (
auto portSym = moduleLike.getPortSymbolAttr(index))
519 recordRemovedInnerSym(removedSyms, moduleName, portSym);
520 disabledPorts.set(index);
525 moduleLike.erasePorts(disabledPorts);
528 moduleLike.setPortTypesAttr(
532 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
533 for (
auto [arg, typeAttr] :
534 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
535 arg.setType(cast<TypeAttr>(typeAttr).getValue());
538 template <
typename T>
539 DenseSet<Attribute> specializeModuleLike(T op) {
540 DenseSet<Attribute> removedSyms;
545 if constexpr (std::is_same_v<T, FModuleOp>) {
546 auto *block = cast<FModuleOp>(op).getBodyBlock();
547 auto bodyIP = InsertionPoint::atBlockEnd(block);
548 specializeBlock(block, bodyIP, removedSyms);
552 specializeModulePorts(op, removedSyms);
557 template <
typename T>
558 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
560 if (
auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
561 module.setLayersAttr(newLayers.getValue());
567 auto moduleName = module.getNameAttr();
569 if constexpr (std::is_same_v<T, FModuleOp>)
570 recordRemovedInnerSyms(removedSyms, moduleName,
580 void specializeLayer(LayerOp layer) {
581 StringAttr head = layer.getSymNameAttr();
582 SmallVector<FlatSymbolRefAttr> nestedRefs;
584 std::function<void(LayerOp, Block::iterator,
const Twine &)> handleLayer =
585 [&](LayerOp layer, Block::iterator insertionPoint,
586 const Twine &prefix) {
587 auto *block = &layer.getBody().getBlocks().front();
588 auto specialization = getSpecialization(head, nestedRefs);
592 if (!specialization) {
596 if (!prefix.isTriviallyEmpty()) {
597 layer.setSymNameAttr(
599 auto *parentBlock = insertionPoint->getBlock();
600 layer->moveBefore(parentBlock, insertionPoint);
603 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
605 handleLayer(nested, Block::iterator(nested),
"");
606 nestedRefs.pop_back();
613 if (*specialization == LayerSpecialization::Enable) {
615 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
617 handleLayer(nested, insertionPoint,
618 prefix + layer.getSymName() +
"_");
619 nestedRefs.pop_back();
631 handleLayer(layer, Block::iterator(layer),
"");
638 SmallVector<Operation *> specialize;
639 DenseSet<Attribute> removedSyms;
640 for (
auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
641 TypeSwitch<Operation *>(&op)
642 .Case<FModuleOp, FExtModuleOp>([&](
auto module) {
643 if (specializeEnableLayers(module, removedSyms))
644 specialize.push_back(module);
646 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
650 auto mergeSets = [](
auto &&a,
auto &&b) {
651 a.insert(b.begin(), b.end());
652 return std::forward<decltype(a)>(a);
658 context, specialize, removedSyms, mergeSets,
659 [&](Operation *op) -> DenseSet<Attribute> {
660 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
661 .Case<FModuleOp, FExtModuleOp>(
662 [&](
auto op) {
return specializeModuleLike(op); });
668 DenseSet<StringAttr> removedPaths;
669 for (
auto hierPath : llvm::make_early_inc_range(
670 circuit.getBody().getOps<hw::HierPathOp>())) {
671 auto namepath = hierPath.getNamepath().getValue();
672 auto shouldDelete = [&](Attribute ref) {
673 return removedSyms.contains(ref);
675 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
676 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
683 if (shouldDelete(namepath.back()))
689 SmallVector<FModuleLike> clean;
690 for (
auto &op : *circuit.getBodyBlock())
691 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
692 clean.push_back(cast<FModuleLike>(op));
694 parallelForEach(context, clean, [&](FModuleLike module) {
695 (AnnotationCleaner(removedPaths))(module);
699 MLIRContext *context;
701 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
704 std::optional<LayerSpecialization> defaultSpecialization;
707 struct SpecializeLayersPass
708 :
public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
710 void runOnOperation()
override {
711 auto circuit = getOperation();
712 SymbolTableCollection stc;
715 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
718 bool shouldSpecialize =
false;
721 if (
auto enabledLayers = circuit.getEnableLayersAttr()) {
722 shouldSpecialize =
true;
723 circuit.removeEnableLayersAttr();
724 for (
auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
726 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
727 mlir::emitError(circuit.getLoc()) <<
"unknown layer " << enabledLayer;
731 specializations[enabledLayer] = LayerSpecialization::Enable;
736 if (
auto disabledLayers = circuit.getDisableLayersAttr()) {
737 shouldSpecialize =
true;
738 circuit.removeDisableLayersAttr();
739 for (
auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
741 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
742 mlir::emitError(circuit.getLoc())
743 <<
"unknown layer " << disabledLayer;
749 auto [it, inserted] = specializations.try_emplace(
750 disabledLayer, LayerSpecialization::Disable);
751 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
752 mlir::emitError(circuit.getLoc())
753 <<
"layer " << disabledLayer <<
" both enabled and disabled";
760 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
761 if (
auto specialization = circuit.getDefaultLayerSpecialization()) {
762 shouldSpecialize =
true;
763 defaultSpecialization = *specialization;
769 if (!shouldSpecialize)
770 return markAllAnalysesPreserved();
773 SpecializeLayers(circuit, specializations, defaultSpecialization)();
779 return std::make_unique<SpecializeLayersPass>();
assert(baseType &&"element must be base type")
static AnnotationSet forPort(Operation *op, size_t portNo)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createSpecializeLayersPass()
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static mlir::ArrayAttr getFromVoidPointer(void *ptr)