CIRCT 23.0.0git
Loading...
Searching...
No Matches
SpecializeLayers.cpp
Go to the documentation of this file.
//===- SpecializeLayers.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
14#include "circt/Support/Utils.h"
15#include "mlir/IR/Threading.h"
16#include "llvm/ADT/STLExtras.h"
17#include <optional>
18#include <type_traits>
19
20namespace circt {
21namespace firrtl {
22#define GEN_PASS_DEF_SPECIALIZELAYERS
23#include "circt/Dialect/FIRRTL/Passes.h.inc"
24} // namespace firrtl
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29using namespace firrtl;
30
31// TODO: this should be upstreamed.
32namespace llvm {
33template <>
34struct PointerLikeTypeTraits<mlir::ArrayAttr>
35 : public PointerLikeTypeTraits<mlir::Attribute> {
36 static inline mlir::ArrayAttr getFromVoidPointer(void *ptr) {
37 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
38 }
39};
40} // namespace llvm
41
42namespace {
43/// Removes non-local annotations whose path is no longer viable, due to
44/// the removal of module instances.
45struct AnnotationCleaner {
46 /// Create an AnnotationCleaner which removes any annotation which contains a
47 /// reference to a symbol in removedPaths.
48 AnnotationCleaner(const DenseSet<StringAttr> &removedPaths)
49 : removedPaths(removedPaths) {}
50
51 AnnotationSet cleanAnnotations(AnnotationSet annos) {
52 annos.removeAnnotations([&](Annotation anno) {
53 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
54 return removedPaths.contains(nla.getAttr());
55 return false;
56 });
57 return annos;
58 }
59
60 void cleanAnnotations(Operation *op) {
61 AnnotationSet oldAnnotations(op);
62 // We want to avoid attaching an empty annotation array on to an op that
63 // never had an annotation array in the first place.
64 if (!oldAnnotations.empty()) {
65 auto newAnnotations = cleanAnnotations(oldAnnotations);
66 if (oldAnnotations != newAnnotations)
67 newAnnotations.applyToOperation(op);
68 }
69 }
70
71 void operator()(FModuleLike module) {
72 // Clean the regular annotations.
73 cleanAnnotations(module);
74
75 // Clean all port annotations.
76 for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
77 auto oldAnnotations = AnnotationSet::forPort(module, i);
78 if (!oldAnnotations.empty()) {
79 auto newAnnotations = cleanAnnotations(oldAnnotations);
80 if (oldAnnotations != newAnnotations)
81 newAnnotations.applyToPort(module, i);
82 }
83 }
84
85 // Clean all annotations in body.
86 module->walk([&](Operation *op) {
87 // Clean regular annotations.
88 cleanAnnotations(op);
89
90 if (auto mem = dyn_cast<MemOp>(op)) {
91 // Update all annotations on ports.
92 for (size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
93 auto oldAnnotations = AnnotationSet::forPort(mem, i);
94 if (!oldAnnotations.empty()) {
95 auto newAnnotations = cleanAnnotations(oldAnnotations);
96 if (oldAnnotations != newAnnotations)
97 newAnnotations.applyToPort(mem, i);
98 }
99 }
100 }
101 });
102 }
103
104 /// A set of symbols of removed paths.
105 const DenseSet<StringAttr> &removedPaths;
106};
107
108/// Helper to keep track of an insertion point and move operations around.
109struct InsertionPoint {
110 /// Create an insertion point at the end of a block.
111 static InsertionPoint atBlockEnd(Block *block) {
112 return InsertionPoint{block, block->end()};
113 }
114
115 /// Move the target operation before the current insertion point and update
116 /// the insertion point to point to the op.
117 void moveOpBefore(Operation *op) {
118 op->moveBefore(block, it);
119 it = Block::iterator(op);
120 }
121
122private:
123 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
124
125 Block *block;
126 Block::iterator it;
127};
128
129/// A specialized value. If the value is colored such that it is disabled,
130/// it will not contain an underlying value. Otherwise, contains the
131/// specialized version of the value.
132template <typename T>
133struct Specialized {
134 /// Create a disabled specialized value.
135 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
136
137 /// Create an enabled specialized value.
138 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
139
140 /// Returns true if the value was specialized away.
141 bool isDisabled() const {
142 return value.getInt() == LayerSpecialization::Disable;
143 }
144
145 /// Returns the specialized value if it still exists.
146 T getValue() const {
147 assert(!isDisabled());
148 return value.getPointer();
149 }
150
151 operator bool() const { return !isDisabled(); }
152
153private:
154 Specialized(LayerSpecialization specialization, T value)
155 : value(value, specialization) {}
156 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
157};
158
159struct SpecializeLayers {
160 SpecializeLayers(
161 CircuitOp circuit,
162 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163 std::optional<LayerSpecialization> defaultSpecialization)
164 : context(circuit->getContext()), circuit(circuit),
165 specializations(specializations),
166 defaultSpecialization(defaultSpecialization) {}
167
168 /// Create a reference to every field in the inner symbol, and record it in
169 /// the list of removed symbols.
170 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171 StringAttr moduleName,
172 hw::InnerSymAttr innerSym) {
173 for (auto field : innerSym)
174 removedSyms.insert(hw::InnerRefAttr::get(moduleName, field.getName()));
175 }
176
177 /// Create a reference to every field in the inner symbol, and record it in
178 /// the list of removed symbols.
179 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180 StringAttr moduleName, Block *block) {
181 block->walk([&](hw::InnerSymbolOpInterface op) {
182 if (auto innerSym = op.getInnerSymAttr())
183 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
184 });
185 }
186
187 /// If this layer reference is being specialized, returns the specialization
188 /// mode. Otherwise, it returns disabled value.
189 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190 auto it = specializations.find(layerRef);
191 if (it != specializations.end())
192 return it->getSecond();
193 return defaultSpecialization;
194 }
195
196 /// Forms a symbol reference to a layer using head as the root reference, and
197 /// nestedRefs as the path to the specific layer. If this layer reference is
198 /// being specialized, returns the specialization mode. Otherwise, it returns
199 /// nullopt.
200 std::optional<LayerSpecialization>
201 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
202 return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
203 }
204
205 /// Specialize a layer reference by removing enabled layers, mangling the
206 /// names of inlined layers, and returning a disabled value if the layer was
207 /// disabled. This can return nullptr if all layers were enabled.
208 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
209 // Walk the layer reference from root to leaf, checking if each outer layer
210 // was specialized or not. If an outer layer was disabled, then this
211 // specific layer is implicitly disabled as well. If an outer layer was
212 // enabled, we need to use its name to mangle the name of inner layers. If
213 // the layer is not specialized, we may need to mangle its name, otherwise
214 // we leave it alone.
215
216 auto oldRoot = layerRef.getRootReference();
217 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
218
219 // A prefix to be used to track how to mangle the next non-inlined layer.
220 SmallString<64> prefix;
221 SmallVector<FlatSymbolRefAttr> newRef;
222
223 // Side-effecting helper which returns false if the currently examined layer
224 // is disabled, true otherwise. If the current layer is being enabled, add
225 // its name to the prefix. If the layer is not specialized, we mangle the
226 // name with the prefix which is then reset to the empty string, and copy it
227 // into the new specialize layer reference.
228 auto helper = [&](StringAttr ref) -> bool {
229 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
230
231 // We are not specializing this layer. Mangle the name with the current
232 // prefix.
233 if (!specialization) {
234 newRef.push_back(FlatSymbolRefAttr::get(
235 StringAttr::get(ref.getContext(), prefix + ref.getValue())));
236 prefix.clear();
237 return true;
238 }
239
240 // We are enabling this layer, the next non-enabled layer should
241 // include this layer's name as a prefix.
242 if (*specialization == LayerSpecialization::Enable) {
243 prefix.append(ref.getValue());
244 prefix.append("_");
245 return true;
246 }
247
248 // We are disabling this layer.
249 return false;
250 };
251
252 if (!helper(oldRoot))
253 return {};
254
255 for (auto ref : layerRef.getNestedReferences()) {
256 oldNestedRefs.push_back(ref);
257 if (!helper(ref.getAttr()))
258 return {};
259 }
260
261 if (newRef.empty())
262 return {SymbolRefAttr()};
263
264 // Root references need to be handled differently than nested references,
265 // but since we don't know before hand which layer will form the new root
266 // layer, we copy all layers to the same array, at the cost of unnecessarily
267 // wrapping the new root reference into a FlatSymbolRefAttr and having to
268 // unpack it again.
269 auto newRoot = newRef.front().getAttr();
270 return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
271 }
272
273 /// Specialize a RefType by specializing the layer color. If the RefType is
274 /// colored with a disabled layer, this will return nullptr.
275 RefType specializeRefType(RefType refType) {
276 if (auto oldLayer = refType.getLayer()) {
277 if (auto newLayer = specializeLayerRef(oldLayer))
278 return RefType::get(refType.getType(), refType.getForceable(),
279 newLayer.getValue());
280 return nullptr;
281 }
282 return refType;
283 }
284
285 Type specializeType(Type type) {
286 if (auto refType = dyn_cast<RefType>(type))
287 return specializeRefType(refType);
288 return type;
289 }
290
291 /// Specialize a value by modifying its type. Returns nullptr if this value
292 /// should be disabled, and the original value otherwise.
293 Value specializeValue(Value value) {
294 if (auto newType = specializeType(value.getType())) {
295 value.setType(newType);
296 return value;
297 }
298 return nullptr;
299 }
300
301 /// Specialize a layerblock. If the layerblock is disabled, it and all of its
302 /// contents will be erased, and all removed inner symbols will be recorded
303 /// so that we can later clean up hierarchical paths. If the layer is
304 /// enabled, then we will inline the contents of the layer to the same
305 /// position and delete the layerblock.
306 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307 DenseSet<Attribute> &removedSyms) {
308 auto oldLayerRef = layerBlock.getLayerNameAttr();
309
310 // Get the specialization for the current layerblock, not taking into
311 // account if an outer layerblock was specialized. We should not have
312 // recursed inside of the disabled layerblock anyways, as it just gets
313 // erased.
314 auto specialization = getSpecialization(oldLayerRef);
315
316 // We are not specializing this layerblock.
317 if (!specialization) {
318 // We must update the name of this layerblock to reflect
319 // specializations of any outer layers.
320 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321 if (oldLayerRef != newLayerRef)
322 layerBlock.setLayerNameAttr(newLayerRef);
323 // Specialize inner operations, but keep them in their original
324 // location.
325 auto *block = layerBlock.getBody();
326 auto bodyIP = InsertionPoint::atBlockEnd(block);
327 specializeBlock(block, bodyIP, removedSyms);
328 insertionPoint.moveOpBefore(layerBlock);
329 return;
330 }
331
332 // We are enabling this layer, and all contents of this layer need to be
333 // moved (inlined) to the insertion point.
334 if (*specialization == LayerSpecialization::Enable) {
335 // Move all contents to the insertion point and specialize them.
336 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
337 // Erase the now empty layerblock.
338 layerBlock->erase();
339 return;
340 }
341
342 // We are disabling this layerblock, so we can just erase the layerblock. We
343 // need to record all the objects with symbols which are being deleted.
344 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
346 layerBlock->erase();
347 }
348
349 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350 DenseSet<Attribute> &removedSyms) {
351 // We need to specialize both arms of the when, but the inner ops should
352 // be left where they are (and not inlined to the insertion point).
353 auto *thenBlock = &when.getThenBlock();
354 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355 specializeBlock(thenBlock, thenIP, removedSyms);
356 if (when.hasElseRegion()) {
357 auto *elseBlock = &when.getElseBlock();
358 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359 specializeBlock(elseBlock, elseIP, removedSyms);
360 }
361 insertionPoint.moveOpBefore(when);
362 }
363
364 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365 DenseSet<Attribute> &removedSyms) {
366 for (size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367 auto *caseBlock = &match.getRegion(i).front();
368 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369 specializeBlock(caseBlock, caseIP, removedSyms);
370 }
371 insertionPoint.moveOpBefore(match);
372 }
373
374 void specializeOp(FInstanceLike instance, InsertionPoint &insertionPoint,
375 DenseSet<Attribute> &removedSyms) {
376 /// Update the types of any probe ports on the instance, and delete any
377 /// probe port that is permanently disabled.
378 llvm::BitVector disabledPorts(instance->getNumResults());
379 for (auto result : instance->getResults())
380 if (!specializeValue(result))
381 disabledPorts.set(result.getResultNumber());
382
383 if (disabledPorts.any()) {
384 auto newInstance =
385 instance.cloneWithErasedPortsAndReplaceUses(disabledPorts);
386 instance->erase();
387 instance = newInstance;
388 }
389
390 // Specialize the required enable layers. Due to the layer verifiers, there
391 // should not be any disabled layer in this instance and this should be
392 // infallible.
393 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
394 instance.setLayersAttr(newLayers.getValue());
395
396 insertionPoint.moveOpBefore(instance);
397 }
398
399 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
400 DenseSet<Attribute> &removedSyms) {
401 if (specializeValue(wire.getResult())) {
402 insertionPoint.moveOpBefore(wire);
403 } else {
404 if (auto innerSym = wire.getInnerSymAttr())
405 recordRemovedInnerSym(removedSyms,
406 wire->getParentOfType<FModuleOp>().getNameAttr(),
407 innerSym);
408 wire.erase();
409 }
410 }
411
412 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
413 DenseSet<Attribute> &removedSyms) {
414 // If this is connected disabled probes, erase the refdefine op.
415 if (auto layerRef = refDefine.getDest().getType().getLayer())
416 if (!specializeLayerRef(layerRef)) {
417 refDefine->erase();
418 return;
419 }
420 insertionPoint.moveOpBefore(refDefine);
421 }
422
423 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
424 DenseSet<Attribute> &removedSyms) {
425 if (specializeValue(refSub->getResult(0)))
426 insertionPoint.moveOpBefore(refSub);
427 else
428
429 refSub.erase();
430 }
431
432 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
433 DenseSet<Attribute> &removedSyms) {
434 if (specializeValue(refCast->getResult(0)))
435 insertionPoint.moveOpBefore(refCast);
436 else
437 refCast.erase();
438 }
439
440 /// Specialize a block of operations, removing any probes which are
441 /// disabled, and moving all operations to the insertion point.
442 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
443 DenseSet<Attribute> &removedSyms) {
444 // Since this can erase operations that deal with disabled probes, we walk
445 // the block in reverse to make sure that we erase uses before defs.
446 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
447 TypeSwitch<Operation *>(&op)
448 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
449 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
450 [&](auto op) { specializeOp(op, insertionPoint, removedSyms); })
451 .Default([&](Operation *op) {
452 // By default all operations should be inlined from an enabled
453 // layer.
454 insertionPoint.moveOpBefore(op);
455 });
456 }
457 }
458
459 /// Specialize the list of known layers for an extmodule.
460 ArrayAttr specializeKnownLayers(ArrayAttr layers) {
461 SmallVector<Attribute> newLayers;
462 for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
463 if (auto result = specializeLayerRef(layer))
464 if (auto newLayer = result.getValue())
465 newLayers.push_back(newLayer);
466 }
467
468 return ArrayAttr::get(context, newLayers);
469 }
470
471 /// Specialize the list of enabled layers for a module. Return a disabled
472 /// layer if one of the required layers has been disabled.
473 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
474 SmallVector<Attribute> newLayers;
475 for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
476 auto newLayer = specializeLayerRef(layer);
477 if (!newLayer)
478 return {};
479 if (newLayer.getValue())
480 newLayers.push_back(newLayer.getValue());
481 }
482 return ArrayAttr::get(context, newLayers);
483 }
484
485 void specializeModulePorts(FModuleLike moduleLike,
486 DenseSet<Attribute> &removedSyms) {
487 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
488
489 // The list of new port types.
490 SmallVector<Attribute> newTypeAttrs;
491 newTypeAttrs.reserve(oldTypeAttrs.size());
492
493 // This is the list of port indices which need to be removed because they
494 // have been specialized away.
495 llvm::BitVector disabledPorts(oldTypeAttrs.size());
496
497 auto moduleName = moduleLike.getNameAttr();
498 for (auto [index, typeAttr] :
499 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
500 // Specialize the type fo the port.
501 if (auto type = specializeType(typeAttr.getValue())) {
502 newTypeAttrs.push_back(TypeAttr::get(type));
503 } else {
504 // The port is being disabled, and should be removed.
505 if (auto portSym = moduleLike.getPortSymbolAttr(index))
506 recordRemovedInnerSym(removedSyms, moduleName, portSym);
507 disabledPorts.set(index);
508 }
509 }
510
511 // Erase the disabled ports.
512 moduleLike.erasePorts(disabledPorts);
513
514 // Update the rest of the port types.
515 moduleLike.setPortTypesAttr(
516 ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
517
518 // We may also need to update the types on the block arguments.
519 if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
520 for (auto [arg, typeAttr] :
521 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
522 arg.setType(cast<TypeAttr>(typeAttr).getValue());
523 }
524
525 template <typename T>
526 DenseSet<Attribute> specializeModuleLike(T op) {
527 DenseSet<Attribute> removedSyms;
528
529 // Specialize all operations in the body of the module. This must be done
530 // before specializing the module ports so that we don't try to erase values
531 // that still have uses.
532 if constexpr (std::is_same_v<T, FModuleOp>) {
533 auto *block = cast<FModuleOp>(op).getBodyBlock();
534 auto bodyIP = InsertionPoint::atBlockEnd(block);
535 specializeBlock(block, bodyIP, removedSyms);
536 }
537
538 // Specialize the ports of this module.
539 specializeModulePorts(op, removedSyms);
540
541 return removedSyms;
542 }
543
544 template <typename T>
545 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
546 // Update the required layers on the module.
547 if (auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
548 module.setLayersAttr(newLayers.getValue());
549 return module;
550 }
551
552 // If we disabled a layer which this module requires, we must delete the
553 // whole module.
554 auto moduleName = module.getNameAttr();
555 removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
556 if constexpr (std::is_same_v<T, FModuleOp>)
557 recordRemovedInnerSyms(removedSyms, moduleName,
558 cast<FModuleOp>(module).getBodyBlock());
559
560 module->erase();
561 return nullptr;
562 }
563
564 /// Specialize the known layers of an extmodule.
565 void specializeKnownLayers(FExtModuleOp module) {
566 auto knownLayers = module.getKnownLayersAttr();
567 module.setKnownLayersAttr(specializeKnownLayers(knownLayers));
568 }
569
570 /// Specialize a layer operation, by removing enabled layers and inlining
571 /// their contents, deleting disabled layers and all nested layers, and
572 /// mangling the names of any inlined layers.
573 void specializeLayer(LayerOp layer) {
574 StringAttr head = layer.getSymNameAttr();
575 SmallVector<FlatSymbolRefAttr> nestedRefs;
576
577 std::function<void(LayerOp, Block::iterator, const Twine &)> handleLayer =
578 [&](LayerOp layer, Block::iterator insertionPoint,
579 const Twine &prefix) {
580 auto *block = &layer.getBody().getBlocks().front();
581 auto specialization = getSpecialization(head, nestedRefs);
582
583 // If we are not specializing the current layer, visit the inner
584 // layers.
585 if (!specialization) {
586 // We only mangle the name and move the layer if the prefix is
587 // non-empty, which indicates that we are enabling the parent
588 // layer.
589 if (!prefix.isTriviallyEmpty()) {
590 layer.setSymNameAttr(
591 StringAttr::get(context, prefix + layer.getSymName()));
592 auto *parentBlock = insertionPoint->getBlock();
593 layer->moveBefore(parentBlock, insertionPoint);
594 }
595 for (auto nested :
596 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
597 nestedRefs.push_back(SymbolRefAttr::get(nested));
598 handleLayer(nested, Block::iterator(nested), "");
599 nestedRefs.pop_back();
600 }
601 return;
602 }
603
604 // We are enabling this layer. We must inline inner layers, and
605 // mangle their names.
606 if (*specialization == LayerSpecialization::Enable) {
607 for (auto nested :
608 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
609 nestedRefs.push_back(SymbolRefAttr::get(nested));
610 handleLayer(nested, insertionPoint,
611 prefix + layer.getSymName() + "_");
612 nestedRefs.pop_back();
613 }
614 // Erase the now empty layer.
615 layer->erase();
616 return;
617 }
618
619 // If we are disabling this layer, then we can just fully delete
620 // it.
621 layer->erase();
622 };
623
624 handleLayer(layer, Block::iterator(layer), "");
625 }
626
627 void operator()() {
628 // Gather all operations we need to specialize, and all the ops that we
629 // need to clean. We specialize layers and module's enable layers here
630 // because that can delete the operations and must be done serially.
631 SmallVector<Operation *> specialize;
632 DenseSet<Attribute> removedSyms;
633 for (auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
634 TypeSwitch<Operation *>(&op)
635 .Case<FModuleOp>([&](FModuleOp module) {
636 if (specializeEnableLayers(module, removedSyms))
637 specialize.push_back(module);
638 })
639 .Case<FExtModuleOp>([&](FExtModuleOp module) {
640 specializeKnownLayers(module);
641 if (specializeEnableLayers(module, removedSyms))
642 specialize.push_back(module);
643 })
644 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
645 }
646
647 // Function to merge two sets together.
648 auto mergeSets = [](auto &&a, auto &&b) {
649 a.insert(b.begin(), b.end());
650 return std::forward<decltype(a)>(a);
651 };
652
653 // Specialize all modules in parallel. The result is a set of all inner
654 // symbol references which are no longer valid due to disabling layers.
655 removedSyms = transformReduce(
656 context, specialize, removedSyms, mergeSets,
657 [&](Operation *op) -> DenseSet<Attribute> {
658 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
659 .Case<FModuleOp, FExtModuleOp>(
660 [&](auto op) { return specializeModuleLike(op); });
661 });
662
663 // Remove all hierarchical path operations which reference deleted symbols,
664 // and create a set of the removed paths operations. We will have to remove
665 // all annotations which use these paths.
666 DenseSet<StringAttr> removedPaths;
667 for (auto hierPath : llvm::make_early_inc_range(
668 circuit.getBody().getOps<hw::HierPathOp>())) {
669 auto namepath = hierPath.getNamepath().getValue();
670 auto shouldDelete = [&](Attribute ref) {
671 return removedSyms.contains(ref);
672 };
673 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
674 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
675 hierPath->erase();
676 continue;
677 }
678 // If we deleted the target of the hierpath, we don't need to add it to
679 // the list of removedPaths, since no annotation will be left around to
680 // reference this path.
681 if (shouldDelete(namepath.back()))
682 hierPath->erase();
683 }
684
685 // Walk all annotations in the circuit and remove the ones that have a
686 // path which traverses or targets a removed instantiation.
687 SmallVector<FModuleLike> clean;
688 for (auto &op : *circuit.getBodyBlock())
689 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
690 clean.push_back(cast<FModuleLike>(op));
691
692 parallelForEach(context, clean, [&](FModuleLike module) {
693 (AnnotationCleaner(removedPaths))(module);
694 });
695 }
696
697 MLIRContext *context;
698 CircuitOp circuit;
699 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
700 /// The default specialization mode to be applied when a layer has not been
701 /// explicitly enabled or disabled.
702 std::optional<LayerSpecialization> defaultSpecialization;
703};
704
705struct SpecializeLayersPass
706 : public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
707
708 void runOnOperation() override {
709 auto circuit = getOperation();
710 SymbolTableCollection stc;
711
712 // Set of layers to enable or disable.
713 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
714
715 // If we are not specialization any layers, we can return early.
716 bool shouldSpecialize = false;
717
718 // Record all the layers which are being enabled.
719 if (auto enabledLayers = circuit.getEnableLayersAttr()) {
720 shouldSpecialize = true;
721 circuit.removeEnableLayersAttr();
722 for (auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
723 // Verify that this is a real layer.
724 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
725 mlir::emitError(circuit.getLoc()) << "unknown layer " << enabledLayer;
726 signalPassFailure();
727 return;
728 }
729 specializations[enabledLayer] = LayerSpecialization::Enable;
730 }
731 }
732
733 // Record all of the layers which are being disabled.
734 if (auto disabledLayers = circuit.getDisableLayersAttr()) {
735 shouldSpecialize = true;
736 circuit.removeDisableLayersAttr();
737 for (auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
738 // Verify that this is a real layer.
739 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
740 mlir::emitError(circuit.getLoc())
741 << "unknown layer " << disabledLayer;
742 signalPassFailure();
743 return;
744 }
745
746 // Verify that we are not both enabling and disabling this layer.
747 auto [it, inserted] = specializations.try_emplace(
748 disabledLayer, LayerSpecialization::Disable);
749 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
750 mlir::emitError(circuit.getLoc())
751 << "layer " << disabledLayer << " both enabled and disabled";
752 signalPassFailure();
753 return;
754 }
755 }
756 }
757
758 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
759 if (auto specialization = circuit.getDefaultLayerSpecialization()) {
760 shouldSpecialize = true;
761 defaultSpecialization = *specialization;
762 }
763
764 // If we did not transform the circuit, return early.
765 // TODO: if both arrays are empty we could preserve specific analyses, but
766 // not all analyses since we have modified the circuit op.
767 if (!shouldSpecialize)
768 return markAllAnalysesPreserved();
769
770 // Run specialization on our circuit.
771 SpecializeLayers(circuit, specializations, defaultSpecialization)();
772 }
773};
774} // end anonymous namespace
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
Definition Utils.h:40
Definition hw.py:1
static mlir::ArrayAttr getFromVoidPointer(void *ptr)