16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/IR/Threading.h"
18 #include "llvm/ADT/STLExtras.h"
20 #include <type_traits>
24 #define GEN_PASS_DEF_SPECIALIZELAYERS
25 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 using namespace circt;
31 using namespace firrtl;
// Convert the opaque pointer back into an ArrayAttr.
// NOTE(review): the enclosing signature is omitted from this listing;
// presumably this is the body of a `getFromVoidPointer(void *ptr)` traits
// hook that lets ArrayAttr be stored in an llvm::PointerIntPair (as
// Specialized<ArrayAttr> in this file does) — confirm against the full file.
39 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
// Strips annotations whose "circt.nonlocal" member names a hierarchical path
// listed in `removedPaths`. Applied to a module-like op, it cleans the
// module's own annotations, each port's annotations, and annotations found
// while walking the module's nested ops (including per-result port
// annotations of MemOps).
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded line numbers); these comments describe only what is visible.
47 struct AnnotationCleaner {
// Keeps a reference to `removedPaths`; the set must outlive the cleaner.
50 AnnotationCleaner(
const DenseSet<StringAttr> &removedPaths)
51 : removedPaths(removedPaths) {}
// Stale-annotation test: true when the annotation's "circt.nonlocal"
// reference names a removed path.
55 if (
auto nla = anno.
getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
56 return removedPaths.contains(nla.getAttr());
// Clean the annotations attached directly to `op`. The attribute is only
// rewritten when cleaning actually changed the set.
62 void cleanAnnotations(Operation *op) {
66 if (!oldAnnotations.empty()) {
67 auto newAnnotations = cleanAnnotations(oldAnnotations);
68 if (oldAnnotations != newAnnotations)
69 newAnnotations.applyToOperation(op);
// Clean a whole module: the op itself, every port, then nested ops.
73 void operator()(FModuleLike module) {
75 cleanAnnotations(module);
// Per-port annotations, rewritten only when changed.
78 for (
size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
80 if (!oldAnnotations.empty()) {
81 auto newAnnotations = cleanAnnotations(oldAnnotations);
82 if (oldAnnotations != newAnnotations)
83 newAnnotations.applyToPort(module, i);
88 module->walk([&](Operation *op) {
// Memories also carry per-result (port) annotations.
92 if (
auto mem = dyn_cast<MemOp>(op)) {
94 for (
size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
96 if (!oldAnnotations.empty()) {
97 auto newAnnotations = cleanAnnotations(oldAnnotations);
98 if (oldAnnotations != newAnnotations)
99 newAnnotations.applyToPort(mem, i);
// Paths whose annotations must be dropped (borrowed, not owned).
107 const DenseSet<StringAttr> &removedPaths;
// A movable cursor within a block. Each moveOpBefore() places an op in front
// of the cursor and then re-points the cursor at that op. Moves made while
// visiting ops back-to-front (as specializeBlock does with llvm::reverse)
// therefore reproduce the original op order.
111 struct InsertionPoint {
// Cursor positioned past the last op of `block`.
113 static InsertionPoint atBlockEnd(Block *block) {
114 return InsertionPoint{block, block->end()};
// Move `op` (possibly out of a different block) directly before the cursor,
// then make `op` the new cursor position.
119 void moveOpBefore(Operation *op) {
120 op->moveBefore(block, it);
121 it = Block::iterator(op);
// In this listing only the atBlockEnd() factory uses this constructor.
125 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
// A value of type T tagged with the LayerSpecialization that produced it.
// Both are packed into one llvm::PointerIntPair, so T must be pointer-like
// and the specialization must fit in the single tag bit.
//  - Default-constructed: Disable, null payload.
//  - Constructed from a T: Enable, payload = that T. The T may itself be
//    null; specializeLayerRef returns `{SymbolRefAttr()}` in that form.
134 template <
typename T>
// Disabled, with a null payload.
137 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
// Enabled, carrying `value` (which may be null).
140 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
// True when the tag is Disable.
143 bool isDisabled()
const {
144 return value.getInt() == LayerSpecialization::Disable;
// Returns the payload. The enclosing accessor's signature is omitted from
// this listing; callers in this file use it as getValue().
150 return value.getPointer();
// True unless disabled. An enabled-but-null payload still tests true.
// NOTE(review): consider making this conversion `explicit`; every use
// visible in this file is in a boolean context.
153 operator bool()
const {
return !isDisabled(); }
// Shared delegate for the public constructors.
156 Specialized(LayerSpecialization specialization, T value)
157 : value(value, specialization) {}
158 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
// Performs layer specialization over one circuit. Each layer resolves to
// Enable, Disable, or unspecialized (std::nullopt):
//  - enabled layers are flattened into their parent;
//  - disabled layers, and values/ports typed by them, are removed;
//  - unspecialized layers are kept, renamed if an enclosing layer was enabled.
161 struct SpecializeLayers {
// Constructor. Its leading parameter(s), including `circuit`, are omitted
// from this listing. `specializations` is stored by reference and must
// outlive this object. `defaultSpecialization` applies to any layer without
// an explicit entry.
164 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
165 std::optional<LayerSpecialization> defaultSpecialization)
166 : context(circuit->getContext()), circuit(circuit),
167 specializations(specializations),
168 defaultSpecialization(defaultSpecialization) {}
// Record each field of `innerSym`, scoped to module `moduleName`, in
// `removedSyms`. The circuit-level driver later checks hierarchical-path
// elements against this set.
// NOTE(review): the loop body is omitted from this listing; presumably it
// inserts one inner-ref attribute per field.
172 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
173 StringAttr moduleName,
174 hw::InnerSymAttr innerSym) {
175 for (
auto field : innerSym)
// Record the inner symbols of every op nested anywhere under `block` (via
// walk) that implements hw::InnerSymbolOpInterface and has an inner symbol.
181 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
182 StringAttr moduleName, Block *block) {
183 block->walk([&](hw::InnerSymbolOpInterface op) {
184 if (
auto innerSym = op.getInnerSymAttr())
185 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
// Look up the requested specialization for `layerRef`. An explicit entry
// wins; otherwise the circuit-wide default is returned. The default may be
// std::nullopt, meaning "leave this layer unspecialized".
191 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
192 auto it = specializations.find(layerRef);
193 if (it != specializations.end())
194 return it->getSecond();
195 return defaultSpecialization;
// Same lookup, for the layer path formed by root `head` followed by
// `nestedRefs`.
// NOTE(review): the body is omitted from this listing; presumably it builds
// the SymbolRefAttr and defers to the SymbolRefAttr overload.
202 std::optional<LayerSpecialization>
203 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
// Rewrite a (possibly nested) layer reference according to the requested
// specializations. Each component is checked against the path prefix that
// ends at it:
//  - enabled components are dropped and their names accumulate in `prefix`
//    for the next surviving component;
//  - the handling of disabled and unspecialized components is partly on
//    lines omitted from this listing.
// The visible `return {SymbolRefAttr()}` yields an Enable-tagged null
// reference. Presumably this is the case where no component survives, i.e.
// the code is no longer under any layer.
210 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
218 auto oldRoot = layerRef.getRootReference();
// Nested components seen so far: the path prefix passed to getSpecialization.
219 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
// Names of enabled components, carried over to the next kept component.
222 SmallString<64> prefix;
// Components of the rewritten reference.
223 SmallVector<FlatSymbolRefAttr> newRef;
// Process one component; a false return stops the scan early.
230 auto helper = [&](StringAttr ref) ->
bool {
231 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
// No specialization requested for this path (branch body omitted here).
235 if (!specialization) {
// Enabled: drop the component but remember its name as a prefix.
244 if (*specialization == LayerSpecialization::Enable) {
245 prefix.append(ref.getValue());
254 if (!helper(oldRoot))
// Extend the looked-up path one nested component at a time.
257 for (
auto ref : layerRef.getNestedReferences()) {
258 oldNestedRefs.push_back(ref);
259 if (!helper(ref.getAttr()))
// Enable-tagged null reference (see the function comment).
264 return {SymbolRefAttr()};
// Rebuild a SymbolRefAttr from the surviving components.
271 auto newRoot = newRef.front().getAttr();
// Re-type a RefType whose type names a layer. If the layer reference is not
// disabled, the RefType is rebuilt with the rewritten (possibly null) layer.
// The disabled case and layer-less RefTypes are handled on lines omitted
// from this listing.
277 RefType specializeRefType(RefType refType) {
278 if (
auto oldLayer = refType.getLayer()) {
279 if (
auto newLayer = specializeLayerRef(oldLayer))
280 return RefType::get(refType.getType(), refType.getForceable(),
281 newLayer.getValue());
// Specialize an arbitrary type. RefTypes are delegated to specializeRefType;
// handling of other types is omitted from this listing. Callers (e.g.
// specializeModulePorts) treat a null result as "this entity must be
// removed".
287 Type specializeType(Type type) {
288 if (
auto refType = dyn_cast<RefType>(type))
289 return specializeRefType(refType);
// Update `value`'s type in place when its specialized type is non-null.
// Callers treat a null result as "this value belongs to a disabled layer";
// for example, the InstanceOp overload marks such results as ports to erase.
295 Value specializeValue(Value value) {
296 if (
auto newType = specializeType(value.getType())) {
297 value.setType(newType);
// Handle a layerblock according to its layer's specialization:
//  - unspecialized: rewrite its layer name (enclosing layers may have been
//    enabled), specialize its body in place, and keep it;
//  - Enable: specialize its body directly into the caller's insertion point,
//    inlining the body into the enclosing block;
//  - Disable: record the inner symbols of its body as removed. The erase
//    itself is presumably on lines omitted from this listing.
308 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
309 DenseSet<Attribute> &removedSyms) {
310 auto oldLayerRef = layerBlock.getLayerNameAttr();
316 auto specialization = getSpecialization(oldLayerRef);
319 if (!specialization) {
// The name may change if an enclosing layer was enabled.
322 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
323 if (oldLayerRef != newLayerRef)
324 layerBlock.setLayerNameAttr(newLayerRef);
// Specialize the body in place, then keep the layerblock in order.
327 auto *block = layerBlock.getBody();
328 auto bodyIP = InsertionPoint::atBlockEnd(block);
329 specializeBlock(block, bodyIP, removedSyms);
330 insertionPoint.moveOpBefore(layerBlock);
// Enabled: survivors of the body are moved to the parent's insertion point.
336 if (*specialization == LayerSpecialization::Enable) {
338 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
// Disabled: every inner symbol inside the body disappears with it.
346 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
347 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
// Specialize the then-body and, if present, the else-body of a `when` in
// place. The when op itself is always kept at the current insertion point.
351 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
352 DenseSet<Attribute> &removedSyms) {
355 auto *thenBlock = &when.getThenBlock();
356 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
357 specializeBlock(thenBlock, thenIP, removedSyms);
358 if (when.hasElseRegion()) {
359 auto *elseBlock = &when.getElseBlock();
360 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
361 specializeBlock(elseBlock, elseIP, removedSyms);
363 insertionPoint.moveOpBefore(when);
// Specialize the body of every case region of a `match` in place. The match
// op itself is always kept at the current insertion point.
366 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
367 DenseSet<Attribute> &removedSyms) {
368 for (
size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
369 auto *caseBlock = &match.getRegion(i).front();
370 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
371 specializeBlock(caseBlock, caseIP, removedSyms);
373 insertionPoint.moveOpBefore(match);
// Specialize an instance:
//  - results (ports) whose types belong to a disabled layer are dropped by
//    rebuilding the instance without them;
//  - the instance's enabled-layers list is rewritten;
//  - the op is kept at the current insertion point.
// NOTE(review): unlike the InstanceChoiceOp overload, no erase of the old
// instance is visible after erasePorts (original line 387 is omitted from
// this listing). Confirm the replaced instance is erased.
// NOTE(review): `newLayers` is used without an isDisabled() check. A
// disabled result would make getValue() null; confirm that case cannot
// reach this point.
376 void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
377 DenseSet<Attribute> &removedSyms) {
// One bit per result; a set bit means the port must be erased.
380 llvm::BitVector disabledPorts(instance->getNumResults());
381 for (
auto result : instance->getResults())
382 if (!specializeValue(result))
383 disabledPorts.set(result.getResultNumber());
384 if (disabledPorts.any()) {
385 OpBuilder builder(instance);
386 auto newInstance = instance.erasePorts(builder, disabledPorts);
388 instance = newInstance;
394 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
395 instance.setLayersAttr(newLayers.getValue());
397 insertionPoint.moveOpBefore(instance);
// Specialize an instance choice, mirroring the InstanceOp overload:
//  - results whose types belong to a disabled layer are dropped by building
//    a replacement op, and the original op is erased;
//  - the enabled-layers list is rewritten;
//  - the op is kept at the current insertion point.
400 void specializeOp(InstanceChoiceOp instanceChoice,
401 InsertionPoint &insertionPoint,
402 DenseSet<Attribute> &removedSyms) {
// One bit per result; a set bit means the port must be erased.
405 llvm::BitVector disabledPorts(instanceChoice->getNumResults());
406 for (
auto result : instanceChoice->getResults())
407 if (!specializeValue(result))
408 disabledPorts.set(result.getResultNumber());
409 if (disabledPorts.any()) {
410 OpBuilder builder(instanceChoice);
411 auto newInstanceChoice =
412 instanceChoice.erasePorts(builder, disabledPorts);
413 instanceChoice->erase();
414 instanceChoice = newInstanceChoice;
420 auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
421 instanceChoice.setLayersAttr(newLayers.getValue());
423 insertionPoint.moveOpBefore(instanceChoice);
// Keep a wire whose type survives specialization. Otherwise record its inner
// symbol (if any) as removed; the erase itself is on lines omitted from this
// listing.
426 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
427 DenseSet<Attribute> &removedSyms) {
428 if (specializeValue(wire.getResult())) {
429 insertionPoint.moveOpBefore(wire);
431 if (
auto innerSym = wire.getInnerSymAttr())
432 recordRemovedInnerSym(removedSyms,
433 wire->getParentOfType<FModuleOp>().getNameAttr(),
// A ref.define whose destination type names a disabled layer is dropped (the
// branch body is omitted from this listing). All other ref.defines are kept
// at the current insertion point.
439 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
440 DenseSet<Attribute> &removedSyms) {
442 if (
auto layerRef = refDefine.getDest().getType().getLayer())
443 if (!specializeLayerRef(layerRef)) {
447 insertionPoint.moveOpBefore(refDefine);
// Keep a ref.sub whose result type survives specialization. The other case
// is handled on lines omitted from this listing (presumably an erase).
450 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
451 DenseSet<Attribute> &removedSyms) {
452 if (specializeValue(refSub->getResult(0)))
453 insertionPoint.moveOpBefore(refSub);
// Keep a ref.cast whose result type survives specialization. The other case
// is handled on lines omitted from this listing (presumably an erase).
459 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
460 DenseSet<Attribute> &removedSyms) {
461 if (specializeValue(refCast->getResult(0)))
462 insertionPoint.moveOpBefore(refCast);
// Specialize every op in `block`, placing survivors at `insertionPoint`. The
// insertion point may belong to a different block when an enabled
// layerblock is being inlined.
// Ops are visited last-to-first: combined with InsertionPoint's
// backward-moving cursor, this preserves their relative order.
// make_early_inc_range keeps the iteration valid while ops are moved or
// erased.
469 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
470 DenseSet<Attribute> &removedSyms) {
473 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
474 TypeSwitch<Operation *>(&op)
475 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
476 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
477 [&](
auto op) { specializeOp(op, insertionPoint, removedSyms); })
// Any other op: some handling is omitted from this listing, then the op
// is kept in order.
478 .Default([&](Operation *op) {
481 insertionPoint.moveOpBefore(op);
// Rewrite a list of enabled layers, such as the `layers` attribute of a
// module or instance. Entries whose rewritten reference is null (no layer
// left) are dropped.
// Disabled entries are handled on lines omitted from this listing. Callers
// treat a Disable-tagged result as "the owning module must be removed".
488 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
489 SmallVector<Attribute> newLayers;
490 for (
auto layer : layers.getAsRange<SymbolRefAttr>()) {
491 auto newLayer = specializeLayerRef(layer);
494 if (newLayer.getValue())
495 newLayers.push_back(newLayer.getValue());
// Specialize the port types of a module-like op:
//  - ports whose types become null (disabled layer) are erased, and their
//    port symbols are recorded as removed;
//  - surviving ports get their specialized types;
//  - for FModuleOp, the body's block arguments are re-typed to match.
500 void specializeModulePorts(FModuleLike moduleLike,
501 DenseSet<Attribute> &removedSyms) {
502 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
505 SmallVector<Attribute> newTypeAttrs;
506 newTypeAttrs.reserve(oldTypeAttrs.size());
// One bit per original port; set bits are erased below.
510 llvm::BitVector disabledPorts(oldTypeAttrs.size());
512 auto moduleName = moduleLike.getNameAttr();
513 for (
auto [index, typeAttr] :
514 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
// Surviving port; recording its new type is on lines omitted here.
516 if (
auto type = specializeType(typeAttr.getValue())) {
// Disabled port: forget its inner symbol and mark it for erasure.
520 if (
auto portSym = moduleLike.getPortSymbolAttr(index))
521 recordRemovedInnerSym(removedSyms, moduleName, portSym);
522 disabledPorts.set(index);
527 moduleLike.erasePorts(disabledPorts);
530 moduleLike.setPortTypesAttr(
// Keep the body's block arguments in sync with the new port types.
534 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
535 for (
auto [arg, typeAttr] :
536 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
537 arg.setType(cast<TypeAttr>(typeAttr).getValue());
// Specialize one module-like op and return the symbols removed from it. The
// circuit driver calls this once per surviving module inside a
// transformReduce and merges the returned sets.
// For FModuleOp the body is specialized before the ports. Presumably this
// ensures uses of disabled port arguments are gone before erasePorts
// removes them.
540 template <
typename T>
541 DenseSet<Attribute> specializeModuleLike(T op) {
542 DenseSet<Attribute> removedSyms;
547 if constexpr (std::is_same_v<T, FModuleOp>) {
548 auto *block = cast<FModuleOp>(op).getBodyBlock();
549 auto bodyIP = InsertionPoint::atBlockEnd(block);
550 specializeBlock(block, bodyIP, removedSyms);
554 specializeModulePorts(op, removedSyms);
// Rewrite a module's enabled-layers list.
//  - If the list survives, the module is kept. The driver treats a truthy
//    return as "module still exists" and queues it for specialization.
//  - Otherwise the module depends on a disabled layer. Its removed symbols
//    are recorded; for FModuleOp this includes every inner symbol in its
//    body. The remainder (presumably the erase and a null return) is on
//    lines omitted from this listing.
559 template <
typename T>
560 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
562 if (
auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
563 module.setLayersAttr(newLayers.getValue());
569 auto moduleName = module.getNameAttr();
571 if constexpr (std::is_same_v<T, FModuleOp>)
572 recordRemovedInnerSyms(removedSyms, moduleName,
// Specialize the layer declaration `layer` and, recursively, its nested layer
// declarations. `nestedRefs` tracks the path from `head` to the layer being
// visited, so the matching specialization can be looked up.
//  - unspecialized: the layer is kept. If `prefix` is non-empty (an ancestor
//    was enabled), it is renamed and moved to `insertionPoint`. Its children
//    are then visited in place with an empty prefix.
//  - Enable: its children are visited with this layer's insertion point and
//    the prefix "<prefix><name>_", which hoists and renames them.
//  - Disable: handled on lines omitted from this listing.
582 void specializeLayer(LayerOp layer) {
583 StringAttr head = layer.getSymNameAttr();
584 SmallVector<FlatSymbolRefAttr> nestedRefs;
// std::function (rather than auto) so the lambda can call itself.
586 std::function<void(LayerOp, Block::iterator,
const Twine &)> handleLayer =
587 [&](LayerOp layer, Block::iterator insertionPoint,
588 const Twine &prefix) {
589 auto *block = &layer.getBody().getBlocks().front();
590 auto specialization = getSpecialization(head, nestedRefs);
594 if (!specialization) {
598 if (!prefix.isTriviallyEmpty()) {
599 layer.setSymNameAttr(
601 auto *parentBlock = insertionPoint->getBlock();
602 layer->moveBefore(parentBlock, insertionPoint);
605 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
607 handleLayer(nested, Block::iterator(nested),
"");
608 nestedRefs.pop_back();
615 if (*specialization == LayerSpecialization::Enable) {
617 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
619 handleLayer(nested, insertionPoint,
620 prefix + layer.getSymName() +
"_");
621 nestedRefs.pop_back();
// Start at the top-level layer, in place, with no prefix.
633 handleLayer(layer, Block::iterator(layer),
"");
// Body of the circuit-level driver; its signature is omitted from this
// listing. The phases visible here are:
//  1. Serially visit top-level ops: rewrite each module's enabled-layers
//     list (collecting modules that survive) and specialize layer
//     declarations.
//  2. Specialize the surviving modules with transformReduce, merging each
//     module's removed symbols into `removedSyms`.
//  3. Find hw.hierpath ops that pass through a removed symbol.
//  4. In parallel, strip annotations that reference a removed path.
640 SmallVector<Operation *> specialize;
641 DenseSet<Attribute> removedSyms;
642 for (
auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
643 TypeSwitch<Operation *>(&op)
644 .Case<FModuleOp, FExtModuleOp>([&](
auto module) {
645 if (specializeEnableLayers(module, removedSyms))
646 specialize.push_back(module);
648 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
// Reduce step: union `b` into `a` and pass `a` along.
652 auto mergeSets = [](
auto &&a,
auto &&b) {
653 a.insert(b.begin(), b.end());
654 return std::forward<decltype(a)>(a);
// Arguments of the transformReduce call (its opening line is omitted here).
660 context, specialize, removedSyms, mergeSets,
661 [&](Operation *op) -> DenseSet<Attribute> {
662 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
663 .Case<FModuleOp, FExtModuleOp>(
664 [&](
auto op) {
return specializeModuleLike(op); });
// Names of paths that went through a removed symbol; annotations naming
// them must be dropped.
670 DenseSet<StringAttr> removedPaths;
671 for (
auto hierPath : llvm::make_early_inc_range(
672 circuit.getBody().getOps<hw::HierPathOp>())) {
673 auto namepath = hierPath.getNamepath().getValue();
674 auto shouldDelete = [&](Attribute ref) {
675 return removedSyms.contains(ref);
// A removed symbol anywhere before the final element invalidates the path.
677 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
678 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
// Only the final element was removed (handling omitted from this listing).
685 if (shouldDelete(namepath.back()))
// Collect every module-like op and clean their annotations in parallel.
691 SmallVector<FModuleLike> clean;
692 for (
auto &op : *circuit.getBodyBlock())
693 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
694 clean.push_back(cast<FModuleLike>(op));
696 parallelForEach(context, clean, [&](FModuleLike module) {
697 (AnnotationCleaner(removedPaths))(module);
// Context of the circuit; passed to the parallel helpers (transformReduce,
// parallelForEach).
701 MLIRContext *context;
// Explicit per-layer specializations. Borrowed: the caller owns the map.
703 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
// Specialization for layers without an explicit entry; nullopt leaves such
// layers unspecialized.
706 std::optional<LayerSpecialization> defaultSpecialization;
// Pass wrapper. It reads the circuit's enable-layers and disable-layers
// attributes and its default layer specialization, validates them, removes
// the two list attributes from the circuit, and runs SpecializeLayers if
// anything was requested.
709 struct SpecializeLayersPass
710 :
public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
712 void runOnOperation()
override {
713 auto circuit = getOperation();
// Used to check that each requested layer actually exists.
714 SymbolTableCollection stc;
// Requested specialization per layer, built from the circuit attributes.
717 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
// Set once any specialization request is found.
720 bool shouldSpecialize =
false;
// Enabled layers: each must name an existing layer.
// NOTE(review): the attribute is removed before its entries are validated,
// so on error the IR no longer carries it. The error path is omitted here
// (presumably signalPassFailure); confirm this is intended.
723 if (
auto enabledLayers = circuit.getEnableLayersAttr()) {
724 shouldSpecialize =
true;
725 circuit.removeEnableLayersAttr();
726 for (
auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
728 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
729 mlir::emitError(circuit.getLoc()) <<
"unknown layer " << enabledLayer;
733 specializations[enabledLayer] = LayerSpecialization::Enable;
// Disabled layers: each must exist and must not also be enabled.
738 if (
auto disabledLayers = circuit.getDisableLayersAttr()) {
739 shouldSpecialize =
true;
740 circuit.removeDisableLayersAttr();
741 for (
auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
743 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
744 mlir::emitError(circuit.getLoc())
745 <<
"unknown layer " << disabledLayer;
// try_emplace leaves an existing entry untouched, so a conflict with an
// earlier Enable can be detected.
751 auto [it, inserted] = specializations.try_emplace(
752 disabledLayer, LayerSpecialization::Disable);
753 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
754 mlir::emitError(circuit.getLoc())
755 <<
"layer " << disabledLayer <<
" both enabled and disabled";
// Optional default for layers not listed above.
762 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
763 if (
auto specialization = circuit.getDefaultLayerSpecialization()) {
764 shouldSpecialize =
true;
765 defaultSpecialization = *specialization;
// Nothing requested: the IR is left untouched.
771 if (!shouldSpecialize)
772 return markAllAnalysesPreserved();
775 SpecializeLayers(circuit, specializations, defaultSpecialization)();
// Factory for the pass. The enclosing signature is omitted from this listing;
// presumably it is createSpecializeLayersPass().
781 return std::make_unique<SpecializeLayersPass>();
assert(baseType &&"element must be base type")
static AnnotationSet forPort(Operation *op, size_t portNo)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createSpecializeLayersPass()
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static mlir::ArrayAttr getFromVoidPointer(void *ptr)