Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SpecializeLayers.cpp
Go to the documentation of this file.
1//===- SpecializeLayers.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "circt/Support/Utils.h"
15#include "mlir/IR/Threading.h"
16#include "llvm/ADT/STLExtras.h"
17#include <optional>
18#include <type_traits>
19
20namespace circt {
21namespace firrtl {
22#define GEN_PASS_DEF_SPECIALIZELAYERS
23#include "circt/Dialect/FIRRTL/Passes.h.inc"
24} // namespace firrtl
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29using namespace firrtl;
30
31// TODO: this should be upstreamed.
32namespace llvm {
33template <>
34struct PointerLikeTypeTraits<mlir::ArrayAttr>
35 : public PointerLikeTypeTraits<mlir::Attribute> {
36 static inline mlir::ArrayAttr getFromVoidPointer(void *ptr) {
37 return mlir::ArrayAttr::getFromOpaquePointer(ptr);
38 }
39};
40} // namespace llvm
41
42namespace {
43/// Removes non-local annotations whose path is no longer viable, due to
44/// the removal of module instances.
45struct AnnotationCleaner {
46 /// Create an AnnotationCleaner which removes any annotation which contains a
47 /// reference to a symbol in removedPaths.
48 AnnotationCleaner(const DenseSet<StringAttr> &removedPaths)
49 : removedPaths(removedPaths) {}
50
51 AnnotationSet cleanAnnotations(AnnotationSet annos) {
52 annos.removeAnnotations([&](Annotation anno) {
53 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
54 return removedPaths.contains(nla.getAttr());
55 return false;
56 });
57 return annos;
58 }
59
60 void cleanAnnotations(Operation *op) {
61 AnnotationSet oldAnnotations(op);
62 // We want to avoid attaching an empty annotation array on to an op that
63 // never had an annotation array in the first place.
64 if (!oldAnnotations.empty()) {
65 auto newAnnotations = cleanAnnotations(oldAnnotations);
66 if (oldAnnotations != newAnnotations)
67 newAnnotations.applyToOperation(op);
68 }
69 }
70
71 void operator()(FModuleLike module) {
72 // Clean the regular annotations.
73 cleanAnnotations(module);
74
75 // Clean all port annotations.
76 for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
77 auto oldAnnotations = AnnotationSet::forPort(module, i);
78 if (!oldAnnotations.empty()) {
79 auto newAnnotations = cleanAnnotations(oldAnnotations);
80 if (oldAnnotations != newAnnotations)
81 newAnnotations.applyToPort(module, i);
82 }
83 }
84
85 // Clean all annotations in body.
86 module->walk([&](Operation *op) {
87 // Clean regular annotations.
88 cleanAnnotations(op);
89
90 if (auto mem = dyn_cast<MemOp>(op)) {
91 // Update all annotations on ports.
92 for (size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
93 auto oldAnnotations = AnnotationSet::forPort(mem, i);
94 if (!oldAnnotations.empty()) {
95 auto newAnnotations = cleanAnnotations(oldAnnotations);
96 if (oldAnnotations != newAnnotations)
97 newAnnotations.applyToPort(mem, i);
98 }
99 }
100 }
101 });
102 }
103
104 /// A set of symbols of removed paths.
105 const DenseSet<StringAttr> &removedPaths;
106};
107
108/// Helper to keep track of an insertion point and move operations around.
109struct InsertionPoint {
110 /// Create an insertion point at the end of a block.
111 static InsertionPoint atBlockEnd(Block *block) {
112 return InsertionPoint{block, block->end()};
113 }
114
115 /// Move the target operation before the current insertion point and update
116 /// the insertion point to point to the op.
117 void moveOpBefore(Operation *op) {
118 op->moveBefore(block, it);
119 it = Block::iterator(op);
120 }
121
122private:
123 InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
124
125 Block *block;
126 Block::iterator it;
127};
128
129/// A specialized value. If the value is colored such that it is disabled,
130/// it will not contain an underlying value. Otherwise, contains the
131/// specialized version of the value.
132template <typename T>
133struct Specialized {
134 /// Create a disabled specialized value.
135 Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
136
137 /// Create an enabled specialized value.
138 Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
139
140 /// Returns true if the value was specialized away.
141 bool isDisabled() const {
142 return value.getInt() == LayerSpecialization::Disable;
143 }
144
145 /// Returns the specialized value if it still exists.
146 T getValue() const {
147 assert(!isDisabled());
148 return value.getPointer();
149 }
150
151 operator bool() const { return !isDisabled(); }
152
153private:
154 Specialized(LayerSpecialization specialization, T value)
155 : value(value, specialization) {}
156 llvm::PointerIntPair<T, 1, LayerSpecialization> value;
157};
158
159struct SpecializeLayers {
160 SpecializeLayers(
161 CircuitOp circuit,
162 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163 std::optional<LayerSpecialization> defaultSpecialization)
164 : context(circuit->getContext()), circuit(circuit),
165 specializations(specializations),
166 defaultSpecialization(defaultSpecialization) {}
167
168 /// Create a reference to every field in the inner symbol, and record it in
169 /// the list of removed symbols.
170 static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171 StringAttr moduleName,
172 hw::InnerSymAttr innerSym) {
173 for (auto field : innerSym)
174 removedSyms.insert(hw::InnerRefAttr::get(moduleName, field.getName()));
175 }
176
177 /// Create a reference to every field in the inner symbol, and record it in
178 /// the list of removed symbols.
179 static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180 StringAttr moduleName, Block *block) {
181 block->walk([&](hw::InnerSymbolOpInterface op) {
182 if (auto innerSym = op.getInnerSymAttr())
183 recordRemovedInnerSym(removedSyms, moduleName, innerSym);
184 });
185 }
186
187 /// If this layer reference is being specialized, returns the specialization
188 /// mode. Otherwise, it returns disabled value.
189 std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190 auto it = specializations.find(layerRef);
191 if (it != specializations.end())
192 return it->getSecond();
193 return defaultSpecialization;
194 }
195
196 /// Forms a symbol reference to a layer using head as the root reference, and
197 /// nestedRefs as the path to the specific layer. If this layer reference is
198 /// being specialized, returns the specialization mode. Otherwise, it returns
199 /// nullopt.
200 std::optional<LayerSpecialization>
201 getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
202 return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
203 }
204
205 /// Specialize a layer reference by removing enabled layers, mangling the
206 /// names of inlined layers, and returning a disabled value if the layer was
207 /// disabled. This can return nullptr if all layers were enabled.
208 Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
209 // Walk the layer reference from root to leaf, checking if each outer layer
210 // was specialized or not. If an outer layer was disabled, then this
211 // specific layer is implicitly disabled as well. If an outer layer was
212 // enabled, we need to use its name to mangle the name of inner layers. If
213 // the layer is not specialized, we may need to mangle its name, otherwise
214 // we leave it alone.
215
216 auto oldRoot = layerRef.getRootReference();
217 SmallVector<FlatSymbolRefAttr> oldNestedRefs;
218
219 // A prefix to be used to track how to mangle the next non-inlined layer.
220 SmallString<64> prefix;
221 SmallVector<FlatSymbolRefAttr> newRef;
222
223 // Side-effecting helper which returns false if the currently examined layer
224 // is disabled, true otherwise. If the current layer is being enabled, add
225 // its name to the prefix. If the layer is not specialized, we mangle the
226 // name with the prefix which is then reset to the empty string, and copy it
227 // into the new specialize layer reference.
228 auto helper = [&](StringAttr ref) -> bool {
229 auto specialization = getSpecialization(oldRoot, oldNestedRefs);
230
231 // We are not specializing this layer. Mangle the name with the current
232 // prefix.
233 if (!specialization) {
234 newRef.push_back(FlatSymbolRefAttr::get(
235 StringAttr::get(ref.getContext(), prefix + ref.getValue())));
236 prefix.clear();
237 return true;
238 }
239
240 // We are enabling this layer, the next non-enabled layer should
241 // include this layer's name as a prefix.
242 if (*specialization == LayerSpecialization::Enable) {
243 prefix.append(ref.getValue());
244 prefix.append("_");
245 return true;
246 }
247
248 // We are disabling this layer.
249 return false;
250 };
251
252 if (!helper(oldRoot))
253 return {};
254
255 for (auto ref : layerRef.getNestedReferences()) {
256 oldNestedRefs.push_back(ref);
257 if (!helper(ref.getAttr()))
258 return {};
259 }
260
261 if (newRef.empty())
262 return {SymbolRefAttr()};
263
264 // Root references need to be handled differently than nested references,
265 // but since we don't know before hand which layer will form the new root
266 // layer, we copy all layers to the same array, at the cost of unnecessarily
267 // wrapping the new root reference into a FlatSymbolRefAttr and having to
268 // unpack it again.
269 auto newRoot = newRef.front().getAttr();
270 return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
271 }
272
273 /// Specialize a RefType by specializing the layer color. If the RefType is
274 /// colored with a disabled layer, this will return nullptr.
275 RefType specializeRefType(RefType refType) {
276 if (auto oldLayer = refType.getLayer()) {
277 if (auto newLayer = specializeLayerRef(oldLayer))
278 return RefType::get(refType.getType(), refType.getForceable(),
279 newLayer.getValue());
280 return nullptr;
281 }
282 return refType;
283 }
284
285 Type specializeType(Type type) {
286 if (auto refType = dyn_cast<RefType>(type))
287 return specializeRefType(refType);
288 return type;
289 }
290
291 /// Specialize a value by modifying its type. Returns nullptr if this value
292 /// should be disabled, and the original value otherwise.
293 Value specializeValue(Value value) {
294 if (auto newType = specializeType(value.getType())) {
295 value.setType(newType);
296 return value;
297 }
298 return nullptr;
299 }
300
301 /// Specialize a layerblock. If the layerblock is disabled, it and all of its
302 /// contents will be erased, and all removed inner symbols will be recorded
303 /// so that we can later clean up hierarchical paths. If the layer is
304 /// enabled, then we will inline the contents of the layer to the same
305 /// position and delete the layerblock.
306 void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307 DenseSet<Attribute> &removedSyms) {
308 auto oldLayerRef = layerBlock.getLayerNameAttr();
309
310 // Get the specialization for the current layerblock, not taking into
311 // account if an outer layerblock was specialized. We should not have
312 // recursed inside of the disabled layerblock anyways, as it just gets
313 // erased.
314 auto specialization = getSpecialization(oldLayerRef);
315
316 // We are not specializing this layerblock.
317 if (!specialization) {
318 // We must update the name of this layerblock to reflect
319 // specializations of any outer layers.
320 auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321 if (oldLayerRef != newLayerRef)
322 layerBlock.setLayerNameAttr(newLayerRef);
323 // Specialize inner operations, but keep them in their original
324 // location.
325 auto *block = layerBlock.getBody();
326 auto bodyIP = InsertionPoint::atBlockEnd(block);
327 specializeBlock(block, bodyIP, removedSyms);
328 insertionPoint.moveOpBefore(layerBlock);
329 return;
330 }
331
332 // We are enabling this layer, and all contents of this layer need to be
333 // moved (inlined) to the insertion point.
334 if (*specialization == LayerSpecialization::Enable) {
335 // Move all contents to the insertion point and specialize them.
336 specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
337 // Erase the now empty layerblock.
338 layerBlock->erase();
339 return;
340 }
341
342 // We are disabling this layerblock, so we can just erase the layerblock. We
343 // need to record all the objects with symbols which are being deleted.
344 auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345 recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
346 layerBlock->erase();
347 }
348
349 void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350 DenseSet<Attribute> &removedSyms) {
351 // We need to specialize both arms of the when, but the inner ops should
352 // be left where they are (and not inlined to the insertion point).
353 auto *thenBlock = &when.getThenBlock();
354 auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355 specializeBlock(thenBlock, thenIP, removedSyms);
356 if (when.hasElseRegion()) {
357 auto *elseBlock = &when.getElseBlock();
358 auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359 specializeBlock(elseBlock, elseIP, removedSyms);
360 }
361 insertionPoint.moveOpBefore(when);
362 }
363
364 void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365 DenseSet<Attribute> &removedSyms) {
366 for (size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367 auto *caseBlock = &match.getRegion(i).front();
368 auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369 specializeBlock(caseBlock, caseIP, removedSyms);
370 }
371 insertionPoint.moveOpBefore(match);
372 }
373
374 void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
375 DenseSet<Attribute> &removedSyms) {
376 /// Update the types of any probe ports on the instance, and delete any
377 /// probe port that is permanently disabled.
378 llvm::BitVector disabledPorts(instance->getNumResults());
379 for (auto result : instance->getResults())
380 if (!specializeValue(result))
381 disabledPorts.set(result.getResultNumber());
382 if (disabledPorts.any()) {
383 OpBuilder builder(instance);
384 auto newInstance = instance.erasePorts(builder, disabledPorts);
385 instance->erase();
386 instance = newInstance;
387 }
388
389 // Specialize the required enable layers. Due to the layer verifiers, there
390 // should not be any disabled layer in this instance and this should be
391 // infallible.
392 auto newLayers = specializeEnableLayers(instance.getLayersAttr());
393 instance.setLayersAttr(newLayers.getValue());
394
395 insertionPoint.moveOpBefore(instance);
396 }
397
398 void specializeOp(InstanceChoiceOp instanceChoice,
399 InsertionPoint &insertionPoint,
400 DenseSet<Attribute> &removedSyms) {
401 /// Update the types of any probe ports on the instanceChoice, and delete
402 /// any probe port that is permanently disabled.
403 llvm::BitVector disabledPorts(instanceChoice->getNumResults());
404 for (auto result : instanceChoice->getResults())
405 if (!specializeValue(result))
406 disabledPorts.set(result.getResultNumber());
407 if (disabledPorts.any()) {
408 OpBuilder builder(instanceChoice);
409 auto newInstanceChoice =
410 instanceChoice.erasePorts(builder, disabledPorts);
411 instanceChoice->erase();
412 instanceChoice = newInstanceChoice;
413 }
414
415 // Specialize the required enable layers. Due to the layer verifiers, there
416 // should not be any disabled layer in this instanceChoice and this should
417 // be infallible.
418 auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
419 instanceChoice.setLayersAttr(newLayers.getValue());
420
421 insertionPoint.moveOpBefore(instanceChoice);
422 }
423
424 void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
425 DenseSet<Attribute> &removedSyms) {
426 if (specializeValue(wire.getResult())) {
427 insertionPoint.moveOpBefore(wire);
428 } else {
429 if (auto innerSym = wire.getInnerSymAttr())
430 recordRemovedInnerSym(removedSyms,
431 wire->getParentOfType<FModuleOp>().getNameAttr(),
432 innerSym);
433 wire.erase();
434 }
435 }
436
437 void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
438 DenseSet<Attribute> &removedSyms) {
439 // If this is connected disabled probes, erase the refdefine op.
440 if (auto layerRef = refDefine.getDest().getType().getLayer())
441 if (!specializeLayerRef(layerRef)) {
442 refDefine->erase();
443 return;
444 }
445 insertionPoint.moveOpBefore(refDefine);
446 }
447
448 void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
449 DenseSet<Attribute> &removedSyms) {
450 if (specializeValue(refSub->getResult(0)))
451 insertionPoint.moveOpBefore(refSub);
452 else
453
454 refSub.erase();
455 }
456
457 void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
458 DenseSet<Attribute> &removedSyms) {
459 if (specializeValue(refCast->getResult(0)))
460 insertionPoint.moveOpBefore(refCast);
461 else
462 refCast.erase();
463 }
464
465 /// Specialize a block of operations, removing any probes which are
466 /// disabled, and moving all operations to the insertion point.
467 void specializeBlock(Block *block, InsertionPoint &insertionPoint,
468 DenseSet<Attribute> &removedSyms) {
469 // Since this can erase operations that deal with disabled probes, we walk
470 // the block in reverse to make sure that we erase uses before defs.
471 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
472 TypeSwitch<Operation *>(&op)
473 .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
474 WireOp, RefDefineOp, RefSubOp, RefCastOp>(
475 [&](auto op) { specializeOp(op, insertionPoint, removedSyms); })
476 .Default([&](Operation *op) {
477 // By default all operations should be inlined from an enabled
478 // layer.
479 insertionPoint.moveOpBefore(op);
480 });
481 }
482 }
483
484 /// Specialize the list of known layers for an extmodule.
485 ArrayAttr specializeKnownLayers(ArrayAttr layers) {
486 SmallVector<Attribute> newLayers;
487 for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
488 if (auto result = specializeLayerRef(layer))
489 if (auto newLayer = result.getValue())
490 newLayers.push_back(newLayer);
491 }
492
493 return ArrayAttr::get(context, newLayers);
494 }
495
496 /// Specialize the list of enabled layers for a module. Return a disabled
497 /// layer if one of the required layers has been disabled.
498 Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
499 SmallVector<Attribute> newLayers;
500 for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
501 auto newLayer = specializeLayerRef(layer);
502 if (!newLayer)
503 return {};
504 if (newLayer.getValue())
505 newLayers.push_back(newLayer.getValue());
506 }
507 return ArrayAttr::get(context, newLayers);
508 }
509
510 void specializeModulePorts(FModuleLike moduleLike,
511 DenseSet<Attribute> &removedSyms) {
512 auto oldTypeAttrs = moduleLike.getPortTypesAttr();
513
514 // The list of new port types.
515 SmallVector<Attribute> newTypeAttrs;
516 newTypeAttrs.reserve(oldTypeAttrs.size());
517
518 // This is the list of port indices which need to be removed because they
519 // have been specialized away.
520 llvm::BitVector disabledPorts(oldTypeAttrs.size());
521
522 auto moduleName = moduleLike.getNameAttr();
523 for (auto [index, typeAttr] :
524 llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
525 // Specialize the type fo the port.
526 if (auto type = specializeType(typeAttr.getValue())) {
527 newTypeAttrs.push_back(TypeAttr::get(type));
528 } else {
529 // The port is being disabled, and should be removed.
530 if (auto portSym = moduleLike.getPortSymbolAttr(index))
531 recordRemovedInnerSym(removedSyms, moduleName, portSym);
532 disabledPorts.set(index);
533 }
534 }
535
536 // Erase the disabled ports.
537 moduleLike.erasePorts(disabledPorts);
538
539 // Update the rest of the port types.
540 moduleLike.setPortTypesAttr(
541 ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
542
543 // We may also need to update the types on the block arguments.
544 if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
545 for (auto [arg, typeAttr] :
546 llvm::zip(moduleOp.getArguments(), newTypeAttrs))
547 arg.setType(cast<TypeAttr>(typeAttr).getValue());
548 }
549
550 template <typename T>
551 DenseSet<Attribute> specializeModuleLike(T op) {
552 DenseSet<Attribute> removedSyms;
553
554 // Specialize all operations in the body of the module. This must be done
555 // before specializing the module ports so that we don't try to erase values
556 // that still have uses.
557 if constexpr (std::is_same_v<T, FModuleOp>) {
558 auto *block = cast<FModuleOp>(op).getBodyBlock();
559 auto bodyIP = InsertionPoint::atBlockEnd(block);
560 specializeBlock(block, bodyIP, removedSyms);
561 }
562
563 // Specialize the ports of this module.
564 specializeModulePorts(op, removedSyms);
565
566 return removedSyms;
567 }
568
569 template <typename T>
570 T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
571 // Update the required layers on the module.
572 if (auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
573 module.setLayersAttr(newLayers.getValue());
574 return module;
575 }
576
577 // If we disabled a layer which this module requires, we must delete the
578 // whole module.
579 auto moduleName = module.getNameAttr();
580 removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
581 if constexpr (std::is_same_v<T, FModuleOp>)
582 recordRemovedInnerSyms(removedSyms, moduleName,
583 cast<FModuleOp>(module).getBodyBlock());
584
585 module->erase();
586 return nullptr;
587 }
588
589 /// Specialize the known layers of an extmodule.
590 void specializeKnownLayers(FExtModuleOp module) {
591 auto knownLayers = module.getKnownLayersAttr();
592 module.setKnownLayersAttr(specializeKnownLayers(knownLayers));
593 }
594
595 /// Specialize a layer operation, by removing enabled layers and inlining
596 /// their contents, deleting disabled layers and all nested layers, and
597 /// mangling the names of any inlined layers.
598 void specializeLayer(LayerOp layer) {
599 StringAttr head = layer.getSymNameAttr();
600 SmallVector<FlatSymbolRefAttr> nestedRefs;
601
602 std::function<void(LayerOp, Block::iterator, const Twine &)> handleLayer =
603 [&](LayerOp layer, Block::iterator insertionPoint,
604 const Twine &prefix) {
605 auto *block = &layer.getBody().getBlocks().front();
606 auto specialization = getSpecialization(head, nestedRefs);
607
608 // If we are not specializing the current layer, visit the inner
609 // layers.
610 if (!specialization) {
611 // We only mangle the name and move the layer if the prefix is
612 // non-empty, which indicates that we are enabling the parent
613 // layer.
614 if (!prefix.isTriviallyEmpty()) {
615 layer.setSymNameAttr(
616 StringAttr::get(context, prefix + layer.getSymName()));
617 auto *parentBlock = insertionPoint->getBlock();
618 layer->moveBefore(parentBlock, insertionPoint);
619 }
620 for (auto nested :
621 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
622 nestedRefs.push_back(SymbolRefAttr::get(nested));
623 handleLayer(nested, Block::iterator(nested), "");
624 nestedRefs.pop_back();
625 }
626 return;
627 }
628
629 // We are enabling this layer. We must inline inner layers, and
630 // mangle their names.
631 if (*specialization == LayerSpecialization::Enable) {
632 for (auto nested :
633 llvm::make_early_inc_range(block->getOps<LayerOp>())) {
634 nestedRefs.push_back(SymbolRefAttr::get(nested));
635 handleLayer(nested, insertionPoint,
636 prefix + layer.getSymName() + "_");
637 nestedRefs.pop_back();
638 }
639 // Erase the now empty layer.
640 layer->erase();
641 return;
642 }
643
644 // If we are disabling this layer, then we can just fully delete
645 // it.
646 layer->erase();
647 };
648
649 handleLayer(layer, Block::iterator(layer), "");
650 }
651
652 void operator()() {
653 // Gather all operations we need to specialize, and all the ops that we
654 // need to clean. We specialize layers and module's enable layers here
655 // because that can delete the operations and must be done serially.
656 SmallVector<Operation *> specialize;
657 DenseSet<Attribute> removedSyms;
658 for (auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
659 TypeSwitch<Operation *>(&op)
660 .Case<FModuleOp>([&](FModuleOp module) {
661 if (specializeEnableLayers(module, removedSyms))
662 specialize.push_back(module);
663 })
664 .Case<FExtModuleOp>([&](FExtModuleOp module) {
665 specializeKnownLayers(module);
666 if (specializeEnableLayers(module, removedSyms))
667 specialize.push_back(module);
668 })
669 .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
670 }
671
672 // Function to merge two sets together.
673 auto mergeSets = [](auto &&a, auto &&b) {
674 a.insert(b.begin(), b.end());
675 return std::forward<decltype(a)>(a);
676 };
677
678 // Specialize all modules in parallel. The result is a set of all inner
679 // symbol references which are no longer valid due to disabling layers.
680 removedSyms = transformReduce(
681 context, specialize, removedSyms, mergeSets,
682 [&](Operation *op) -> DenseSet<Attribute> {
683 return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
684 .Case<FModuleOp, FExtModuleOp>(
685 [&](auto op) { return specializeModuleLike(op); });
686 });
687
688 // Remove all hierarchical path operations which reference deleted symbols,
689 // and create a set of the removed paths operations. We will have to remove
690 // all annotations which use these paths.
691 DenseSet<StringAttr> removedPaths;
692 for (auto hierPath : llvm::make_early_inc_range(
693 circuit.getBody().getOps<hw::HierPathOp>())) {
694 auto namepath = hierPath.getNamepath().getValue();
695 auto shouldDelete = [&](Attribute ref) {
696 return removedSyms.contains(ref);
697 };
698 if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
699 removedPaths.insert(SymbolTable::getSymbolName(hierPath));
700 hierPath->erase();
701 continue;
702 }
703 // If we deleted the target of the hierpath, we don't need to add it to
704 // the list of removedPaths, since no annotation will be left around to
705 // reference this path.
706 if (shouldDelete(namepath.back()))
707 hierPath->erase();
708 }
709
710 // Walk all annotations in the circuit and remove the ones that have a
711 // path which traverses or targets a removed instantiation.
712 SmallVector<FModuleLike> clean;
713 for (auto &op : *circuit.getBodyBlock())
714 if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
715 clean.push_back(cast<FModuleLike>(op));
716
717 parallelForEach(context, clean, [&](FModuleLike module) {
718 (AnnotationCleaner(removedPaths))(module);
719 });
720 }
721
722 MLIRContext *context;
723 CircuitOp circuit;
724 const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
725 /// The default specialization mode to be applied when a layer has not been
726 /// explicitly enabled or disabled.
727 std::optional<LayerSpecialization> defaultSpecialization;
728};
729
730struct SpecializeLayersPass
731 : public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
732
733 void runOnOperation() override {
734 auto circuit = getOperation();
735 SymbolTableCollection stc;
736
737 // Set of layers to enable or disable.
738 DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
739
740 // If we are not specialization any layers, we can return early.
741 bool shouldSpecialize = false;
742
743 // Record all the layers which are being enabled.
744 if (auto enabledLayers = circuit.getEnableLayersAttr()) {
745 shouldSpecialize = true;
746 circuit.removeEnableLayersAttr();
747 for (auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
748 // Verify that this is a real layer.
749 if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
750 mlir::emitError(circuit.getLoc()) << "unknown layer " << enabledLayer;
751 signalPassFailure();
752 return;
753 }
754 specializations[enabledLayer] = LayerSpecialization::Enable;
755 }
756 }
757
758 // Record all of the layers which are being disabled.
759 if (auto disabledLayers = circuit.getDisableLayersAttr()) {
760 shouldSpecialize = true;
761 circuit.removeDisableLayersAttr();
762 for (auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
763 // Verify that this is a real layer.
764 if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
765 mlir::emitError(circuit.getLoc())
766 << "unknown layer " << disabledLayer;
767 signalPassFailure();
768 return;
769 }
770
771 // Verify that we are not both enabling and disabling this layer.
772 auto [it, inserted] = specializations.try_emplace(
773 disabledLayer, LayerSpecialization::Disable);
774 if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
775 mlir::emitError(circuit.getLoc())
776 << "layer " << disabledLayer << " both enabled and disabled";
777 signalPassFailure();
778 return;
779 }
780 }
781 }
782
783 std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
784 if (auto specialization = circuit.getDefaultLayerSpecialization()) {
785 shouldSpecialize = true;
786 defaultSpecialization = *specialization;
787 }
788
789 // If we did not transform the circuit, return early.
790 // TODO: if both arrays are empty we could preserve specific analyses, but
791 // not all analyses since we have modified the circuit op.
792 if (!shouldSpecialize)
793 return markAllAnalysesPreserved();
794
795 // Run specialization on our circuit.
796 SpecializeLayers(circuit, specializations, defaultSpecialization)();
797 }
798};
799} // end anonymous namespace
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
Definition Utils.h:40
Definition hw.py:1
static mlir::ArrayAttr getFromVoidPointer(void *ptr)