CIRCT  20.0.0git
SpecializeLayers.cpp
Go to the documentation of this file.
1 //===- SpecializeLayers.cpp -------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "circt/Dialect/HW/HWOps.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/IR/Threading.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include <optional>
20 #include <type_traits>
21 
22 namespace circt {
23 namespace firrtl {
24 #define GEN_PASS_DEF_SPECIALIZELAYERS
25 #include "circt/Dialect/FIRRTL/Passes.h.inc"
26 } // namespace firrtl
27 } // namespace circt
28 
29 using namespace mlir;
30 using namespace circt;
31 using namespace firrtl;
32 
33 // TODO: this should be upstreamed.
34 namespace llvm {
35 template <>
36 struct PointerLikeTypeTraits<mlir::ArrayAttr>
37  : public PointerLikeTypeTraits<mlir::Attribute> {
38  static inline mlir::ArrayAttr getFromVoidPointer(void *ptr) {
39  return mlir::ArrayAttr::getFromOpaquePointer(ptr);
40  }
41 };
42 } // namespace llvm
43 
44 namespace {
45 /// Removes non-local annotations whose path is no longer viable, due to
46 /// the removal of module instances.
47 struct AnnotationCleaner {
48  /// Create an AnnotationCleaner which removes any annotation which contains a
49  /// reference to a symbol in removedPaths.
50  AnnotationCleaner(const DenseSet<StringAttr> &removedPaths)
51  : removedPaths(removedPaths) {}
52 
53  AnnotationSet cleanAnnotations(AnnotationSet annos) {
54  annos.removeAnnotations([&](Annotation anno) {
55  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
56  return removedPaths.contains(nla.getAttr());
57  return false;
58  });
59  return annos;
60  }
61 
62  void cleanAnnotations(Operation *op) {
63  AnnotationSet oldAnnotations(op);
64  // We want to avoid attaching an empty annotation array on to an op that
65  // never had an annotation array in the first place.
66  if (!oldAnnotations.empty()) {
67  auto newAnnotations = cleanAnnotations(oldAnnotations);
68  if (oldAnnotations != newAnnotations)
69  newAnnotations.applyToOperation(op);
70  }
71  }
72 
73  void operator()(FModuleLike module) {
74  // Clean the regular annotations.
75  cleanAnnotations(module);
76 
77  // Clean all port annotations.
78  for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
79  auto oldAnnotations = AnnotationSet::forPort(module, i);
80  if (!oldAnnotations.empty()) {
81  auto newAnnotations = cleanAnnotations(oldAnnotations);
82  if (oldAnnotations != newAnnotations)
83  newAnnotations.applyToPort(module, i);
84  }
85  }
86 
87  // Clean all annotations in body.
88  module->walk([&](Operation *op) {
89  // Clean regular annotations.
90  cleanAnnotations(op);
91 
92  if (auto mem = dyn_cast<MemOp>(op)) {
93  // Update all annotations on ports.
94  for (size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
95  auto oldAnnotations = AnnotationSet::forPort(mem, i);
96  if (!oldAnnotations.empty()) {
97  auto newAnnotations = cleanAnnotations(oldAnnotations);
98  if (oldAnnotations != newAnnotations)
99  newAnnotations.applyToPort(mem, i);
100  }
101  }
102  }
103  });
104  }
105 
106  /// A set of symbols of removed paths.
107  const DenseSet<StringAttr> &removedPaths;
108 };
109 
110 /// Helper to keep track of an insertion point and move operations around.
111 struct InsertionPoint {
112  /// Create an insertion point at the end of a block.
113  static InsertionPoint atBlockEnd(Block *block) {
114  return InsertionPoint{block, block->end()};
115  }
116 
117  /// Move the target operation before the current insertion point and update
118  /// the insertion point to point to the op.
119  void moveOpBefore(Operation *op) {
120  op->moveBefore(block, it);
121  it = Block::iterator(op);
122  }
123 
124 private:
125  InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
126 
127  Block *block;
128  Block::iterator it;
129 };
130 
131 /// A specialized value. If the value is colored such that it is disabled,
132 /// it will not contain an underlying value. Otherwise, contains the
133 /// specialized version of the value.
134 template <typename T>
135 struct Specialized {
136  /// Create a disabled specialized value.
137  Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
138 
139  /// Create an enabled specialized value.
140  Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
141 
142  /// Returns true if the value was specialized away.
143  bool isDisabled() const {
144  return value.getInt() == LayerSpecialization::Disable;
145  }
146 
147  /// Returns the specialized value if it still exists.
148  T getValue() const {
149  assert(!isDisabled());
150  return value.getPointer();
151  }
152 
153  operator bool() const { return !isDisabled(); }
154 
155 private:
156  Specialized(LayerSpecialization specialization, T value)
157  : value(value, specialization) {}
158  llvm::PointerIntPair<T, 1, LayerSpecialization> value;
159 };
160 
161 struct SpecializeLayers {
162  SpecializeLayers(
163  CircuitOp circuit,
164  const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
165  std::optional<LayerSpecialization> defaultSpecialization)
166  : context(circuit->getContext()), circuit(circuit),
167  specializations(specializations),
168  defaultSpecialization(defaultSpecialization) {}
169 
170  /// Create a reference to every field in the inner symbol, and record it in
171  /// the list of removed symbols.
172  static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
173  StringAttr moduleName,
174  hw::InnerSymAttr innerSym) {
175  for (auto field : innerSym)
176  removedSyms.insert(hw::InnerRefAttr::get(moduleName, field.getName()));
177  }
178 
179  /// Create a reference to every field in the inner symbol, and record it in
180  /// the list of removed symbols.
181  static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
182  StringAttr moduleName, Block *block) {
183  block->walk([&](hw::InnerSymbolOpInterface op) {
184  if (auto innerSym = op.getInnerSymAttr())
185  recordRemovedInnerSym(removedSyms, moduleName, innerSym);
186  });
187  }
188 
189  /// If this layer reference is being specialized, returns the specialization
190  /// mode. Otherwise, it returns disabled value.
191  std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
192  auto it = specializations.find(layerRef);
193  if (it != specializations.end())
194  return it->getSecond();
195  return defaultSpecialization;
196  }
197 
198  /// Forms a symbol reference to a layer using head as the root reference, and
199  /// nestedRefs as the path to the specific layer. If this layer reference is
200  /// being specialized, returns the specialization mode. Otherwise, it returns
201  /// nullopt.
202  std::optional<LayerSpecialization>
203  getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
204  return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
205  }
206 
207  /// Specialize a layer reference by removing enabled layers, mangling the
208  /// names of inlined layers, and returning a disabled value if the layer was
209  /// disabled. This can return nullptr if all layers were enabled.
210  Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
211  // Walk the layer reference from root to leaf, checking if each outer layer
212  // was specialized or not. If an outer layer was disabled, then this
213  // specific layer is implicitly disabled as well. If an outer layer was
214  // enabled, we need to use its name to mangle the name of inner layers. If
215  // the layer is not specialized, we may need to mangle its name, otherwise
216  // we leave it alone.
217 
218  auto oldRoot = layerRef.getRootReference();
219  SmallVector<FlatSymbolRefAttr> oldNestedRefs;
220 
221  // A prefix to be used to track how to mangle the next non-inlined layer.
222  SmallString<64> prefix;
223  SmallVector<FlatSymbolRefAttr> newRef;
224 
225  // Side-effecting helper which returns false if the currently examined layer
226  // is disabled, true otherwise. If the current layer is being enabled, add
227  // its name to the prefix. If the layer is not specialized, we mangle the
228  // name with the prefix which is then reset to the empty string, and copy it
229  // into the new specialize layer reference.
230  auto helper = [&](StringAttr ref) -> bool {
231  auto specialization = getSpecialization(oldRoot, oldNestedRefs);
232 
233  // We are not specializing this layer. Mangle the name with the current
234  // prefix.
235  if (!specialization) {
236  newRef.push_back(FlatSymbolRefAttr::get(
237  StringAttr::get(ref.getContext(), prefix + ref.getValue())));
238  prefix.clear();
239  return true;
240  }
241 
242  // We are enabling this layer, the next non-enabled layer should
243  // include this layer's name as a prefix.
244  if (*specialization == LayerSpecialization::Enable) {
245  prefix.append(ref.getValue());
246  prefix.append("_");
247  return true;
248  }
249 
250  // We are disabling this layer.
251  return false;
252  };
253 
254  if (!helper(oldRoot))
255  return {};
256 
257  for (auto ref : layerRef.getNestedReferences()) {
258  oldNestedRefs.push_back(ref);
259  if (!helper(ref.getAttr()))
260  return {};
261  }
262 
263  if (newRef.empty())
264  return {SymbolRefAttr()};
265 
266  // Root references need to be handled differently than nested references,
267  // but since we don't know before hand which layer will form the new root
268  // layer, we copy all layers to the same array, at the cost of unnecessarily
269  // wrapping the new root reference into a FlatSymbolRefAttr and having to
270  // unpack it again.
271  auto newRoot = newRef.front().getAttr();
272  return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
273  }
274 
275  /// Specialize a RefType by specializing the layer color. If the RefType is
276  /// colored with a disabled layer, this will return nullptr.
277  RefType specializeRefType(RefType refType) {
278  if (auto oldLayer = refType.getLayer()) {
279  if (auto newLayer = specializeLayerRef(oldLayer))
280  return RefType::get(refType.getType(), refType.getForceable(),
281  newLayer.getValue());
282  return nullptr;
283  }
284  return refType;
285  }
286 
287  Type specializeType(Type type) {
288  if (auto refType = dyn_cast<RefType>(type))
289  return specializeRefType(refType);
290  return type;
291  }
292 
293  /// Specialize a value by modifying its type. Returns nullptr if this value
294  /// should be disabled, and the original value otherwise.
295  Value specializeValue(Value value) {
296  if (auto newType = specializeType(value.getType())) {
297  value.setType(newType);
298  return value;
299  }
300  return nullptr;
301  }
302 
303  /// Specialize a layerblock. If the layerblock is disabled, it and all of its
304  /// contents will be erased, and all removed inner symbols will be recorded
305  /// so that we can later clean up hierarchical paths. If the layer is
306  /// enabled, then we will inline the contents of the layer to the same
307  /// position and delete the layerblock.
308  void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
309  DenseSet<Attribute> &removedSyms) {
310  auto oldLayerRef = layerBlock.getLayerNameAttr();
311 
312  // Get the specialization for the current layerblock, not taking into
313  // account if an outer layerblock was specialized. We should not have
314  // recursed inside of the disabled layerblock anyways, as it just gets
315  // erased.
316  auto specialization = getSpecialization(oldLayerRef);
317 
318  // We are not specializing this layerblock.
319  if (!specialization) {
320  // We must update the name of this layerblock to reflect
321  // specializations of any outer layers.
322  auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
323  if (oldLayerRef != newLayerRef)
324  layerBlock.setLayerNameAttr(newLayerRef);
325  // Specialize inner operations, but keep them in their original
326  // location.
327  auto *block = layerBlock.getBody();
328  auto bodyIP = InsertionPoint::atBlockEnd(block);
329  specializeBlock(block, bodyIP, removedSyms);
330  insertionPoint.moveOpBefore(layerBlock);
331  return;
332  }
333 
334  // We are enabling this layer, and all contents of this layer need to be
335  // moved (inlined) to the insertion point.
336  if (*specialization == LayerSpecialization::Enable) {
337  // Move all contents to the insertion point and specialize them.
338  specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
339  // Erase the now empty layerblock.
340  layerBlock->erase();
341  return;
342  }
343 
344  // We are disabling this layerblock, so we can just erase the layerblock. We
345  // need to record all the objects with symbols which are being deleted.
346  auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
347  recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
348  layerBlock->erase();
349  }
350 
351  void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
352  DenseSet<Attribute> &removedSyms) {
353  // We need to specialize both arms of the when, but the inner ops should
354  // be left where they are (and not inlined to the insertion point).
355  auto *thenBlock = &when.getThenBlock();
356  auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
357  specializeBlock(thenBlock, thenIP, removedSyms);
358  if (when.hasElseRegion()) {
359  auto *elseBlock = &when.getElseBlock();
360  auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
361  specializeBlock(elseBlock, elseIP, removedSyms);
362  }
363  insertionPoint.moveOpBefore(when);
364  }
365 
366  void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
367  DenseSet<Attribute> &removedSyms) {
368  for (size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
369  auto *caseBlock = &match.getRegion(i).front();
370  auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
371  specializeBlock(caseBlock, caseIP, removedSyms);
372  }
373  insertionPoint.moveOpBefore(match);
374  }
375 
376  void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
377  DenseSet<Attribute> &removedSyms) {
378  /// Update the types of any probe ports on the instance, and delete any
379  /// probe port that is permanently disabled.
380  llvm::BitVector disabledPorts(instance->getNumResults());
381  for (auto result : instance->getResults())
382  if (!specializeValue(result))
383  disabledPorts.set(result.getResultNumber());
384  if (disabledPorts.any()) {
385  OpBuilder builder(instance);
386  auto newInstance = instance.erasePorts(builder, disabledPorts);
387  instance->erase();
388  instance = newInstance;
389  }
390 
391  // Specialize the required enable layers. Due to the layer verifiers, there
392  // should not be any disabled layer in this instance and this should be
393  // infallible.
394  auto newLayers = specializeEnableLayers(instance.getLayersAttr());
395  instance.setLayersAttr(newLayers.getValue());
396 
397  insertionPoint.moveOpBefore(instance);
398  }
399 
400  void specializeOp(InstanceChoiceOp instanceChoice,
401  InsertionPoint &insertionPoint,
402  DenseSet<Attribute> &removedSyms) {
403  /// Update the types of any probe ports on the instanceChoice, and delete
404  /// any probe port that is permanently disabled.
405  llvm::BitVector disabledPorts(instanceChoice->getNumResults());
406  for (auto result : instanceChoice->getResults())
407  if (!specializeValue(result))
408  disabledPorts.set(result.getResultNumber());
409  if (disabledPorts.any()) {
410  OpBuilder builder(instanceChoice);
411  auto newInstanceChoice =
412  instanceChoice.erasePorts(builder, disabledPorts);
413  instanceChoice->erase();
414  instanceChoice = newInstanceChoice;
415  }
416 
417  // Specialize the required enable layers. Due to the layer verifiers, there
418  // should not be any disabled layer in this instanceChoice and this should
419  // be infallible.
420  auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
421  instanceChoice.setLayersAttr(newLayers.getValue());
422 
423  insertionPoint.moveOpBefore(instanceChoice);
424  }
425 
426  void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
427  DenseSet<Attribute> &removedSyms) {
428  if (specializeValue(wire.getResult())) {
429  insertionPoint.moveOpBefore(wire);
430  } else {
431  if (auto innerSym = wire.getInnerSymAttr())
432  recordRemovedInnerSym(removedSyms,
433  wire->getParentOfType<FModuleOp>().getNameAttr(),
434  innerSym);
435  wire.erase();
436  }
437  }
438 
439  void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
440  DenseSet<Attribute> &removedSyms) {
441  // If this is connected disabled probes, erase the refdefine op.
442  if (auto layerRef = refDefine.getDest().getType().getLayer())
443  if (!specializeLayerRef(layerRef)) {
444  refDefine->erase();
445  return;
446  }
447  insertionPoint.moveOpBefore(refDefine);
448  }
449 
450  void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
451  DenseSet<Attribute> &removedSyms) {
452  if (specializeValue(refSub->getResult(0)))
453  insertionPoint.moveOpBefore(refSub);
454  else
455 
456  refSub.erase();
457  }
458 
459  void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
460  DenseSet<Attribute> &removedSyms) {
461  if (specializeValue(refCast->getResult(0)))
462  insertionPoint.moveOpBefore(refCast);
463  else
464  refCast.erase();
465  }
466 
467  /// Specialize a block of operations, removing any probes which are
468  /// disabled, and moving all operations to the insertion point.
469  void specializeBlock(Block *block, InsertionPoint &insertionPoint,
470  DenseSet<Attribute> &removedSyms) {
471  // Since this can erase operations that deal with disabled probes, we walk
472  // the block in reverse to make sure that we erase uses before defs.
473  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
474  TypeSwitch<Operation *>(&op)
475  .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
476  WireOp, RefDefineOp, RefSubOp, RefCastOp>(
477  [&](auto op) { specializeOp(op, insertionPoint, removedSyms); })
478  .Default([&](Operation *op) {
479  // By default all operations should be inlined from an enabled
480  // layer.
481  insertionPoint.moveOpBefore(op);
482  });
483  }
484  }
485 
486  /// Specialize the list of enabled layers for a module. Return a disabled
487  /// layer if one of the required layers has been disabled.
488  Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
489  SmallVector<Attribute> newLayers;
490  for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
491  auto newLayer = specializeLayerRef(layer);
492  if (!newLayer)
493  return {};
494  if (newLayer.getValue())
495  newLayers.push_back(newLayer.getValue());
496  }
497  return ArrayAttr::get(context, newLayers);
498  }
499 
500  void specializeModulePorts(FModuleLike moduleLike,
501  DenseSet<Attribute> &removedSyms) {
502  auto oldTypeAttrs = moduleLike.getPortTypesAttr();
503 
504  // The list of new port types.
505  SmallVector<Attribute> newTypeAttrs;
506  newTypeAttrs.reserve(oldTypeAttrs.size());
507 
508  // This is the list of port indices which need to be removed because they
509  // have been specialized away.
510  llvm::BitVector disabledPorts(oldTypeAttrs.size());
511 
512  auto moduleName = moduleLike.getNameAttr();
513  for (auto [index, typeAttr] :
514  llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
515  // Specialize the type fo the port.
516  if (auto type = specializeType(typeAttr.getValue())) {
517  newTypeAttrs.push_back(TypeAttr::get(type));
518  } else {
519  // The port is being disabled, and should be removed.
520  if (auto portSym = moduleLike.getPortSymbolAttr(index))
521  recordRemovedInnerSym(removedSyms, moduleName, portSym);
522  disabledPorts.set(index);
523  }
524  }
525 
526  // Erase the disabled ports.
527  moduleLike.erasePorts(disabledPorts);
528 
529  // Update the rest of the port types.
530  moduleLike.setPortTypesAttr(
531  ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
532 
533  // We may also need to update the types on the block arguments.
534  if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
535  for (auto [arg, typeAttr] :
536  llvm::zip(moduleOp.getArguments(), newTypeAttrs))
537  arg.setType(cast<TypeAttr>(typeAttr).getValue());
538  }
539 
540  template <typename T>
541  DenseSet<Attribute> specializeModuleLike(T op) {
542  DenseSet<Attribute> removedSyms;
543 
544  // Specialize all operations in the body of the module. This must be done
545  // before specializing the module ports so that we don't try to erase values
546  // that still have uses.
547  if constexpr (std::is_same_v<T, FModuleOp>) {
548  auto *block = cast<FModuleOp>(op).getBodyBlock();
549  auto bodyIP = InsertionPoint::atBlockEnd(block);
550  specializeBlock(block, bodyIP, removedSyms);
551  }
552 
553  // Specialize the ports of this module.
554  specializeModulePorts(op, removedSyms);
555 
556  return removedSyms;
557  }
558 
559  template <typename T>
560  T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
561  // Update the required layers on the module.
562  if (auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
563  module.setLayersAttr(newLayers.getValue());
564  return module;
565  }
566 
567  // If we disabled a layer which this module requires, we must delete the
568  // whole module.
569  auto moduleName = module.getNameAttr();
570  removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
571  if constexpr (std::is_same_v<T, FModuleOp>)
572  recordRemovedInnerSyms(removedSyms, moduleName,
573  cast<FModuleOp>(module).getBodyBlock());
574 
575  module->erase();
576  return nullptr;
577  }
578 
579  /// Specialize a layer operation, by removing enabled layers and inlining
580  /// their contents, deleting disabled layers and all nested layers, and
581  /// mangling the names of any inlined layers.
582  void specializeLayer(LayerOp layer) {
583  StringAttr head = layer.getSymNameAttr();
584  SmallVector<FlatSymbolRefAttr> nestedRefs;
585 
586  std::function<void(LayerOp, Block::iterator, const Twine &)> handleLayer =
587  [&](LayerOp layer, Block::iterator insertionPoint,
588  const Twine &prefix) {
589  auto *block = &layer.getBody().getBlocks().front();
590  auto specialization = getSpecialization(head, nestedRefs);
591 
592  // If we are not specializing the current layer, visit the inner
593  // layers.
594  if (!specialization) {
595  // We only mangle the name and move the layer if the prefix is
596  // non-empty, which indicates that we are enabling the parent
597  // layer.
598  if (!prefix.isTriviallyEmpty()) {
599  layer.setSymNameAttr(
600  StringAttr::get(context, prefix + layer.getSymName()));
601  auto *parentBlock = insertionPoint->getBlock();
602  layer->moveBefore(parentBlock, insertionPoint);
603  }
604  for (auto nested :
605  llvm::make_early_inc_range(block->getOps<LayerOp>())) {
606  nestedRefs.push_back(SymbolRefAttr::get(nested));
607  handleLayer(nested, Block::iterator(nested), "");
608  nestedRefs.pop_back();
609  }
610  return;
611  }
612 
613  // We are enabling this layer. We must inline inner layers, and
614  // mangle their names.
615  if (*specialization == LayerSpecialization::Enable) {
616  for (auto nested :
617  llvm::make_early_inc_range(block->getOps<LayerOp>())) {
618  nestedRefs.push_back(SymbolRefAttr::get(nested));
619  handleLayer(nested, insertionPoint,
620  prefix + layer.getSymName() + "_");
621  nestedRefs.pop_back();
622  }
623  // Erase the now empty layer.
624  layer->erase();
625  return;
626  }
627 
628  // If we are disabling this layer, then we can just fully delete
629  // it.
630  layer->erase();
631  };
632 
633  handleLayer(layer, Block::iterator(layer), "");
634  }
635 
636  void operator()() {
637  // Gather all operations we need to specialize, and all the ops that we
638  // need to clean. We specialize layers and module's enable layers here
639  // because that can delete the operations and must be done serially.
640  SmallVector<Operation *> specialize;
641  DenseSet<Attribute> removedSyms;
642  for (auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
643  TypeSwitch<Operation *>(&op)
644  .Case<FModuleOp, FExtModuleOp>([&](auto module) {
645  if (specializeEnableLayers(module, removedSyms))
646  specialize.push_back(module);
647  })
648  .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
649  }
650 
651  // Function to merge two sets together.
652  auto mergeSets = [](auto &&a, auto &&b) {
653  a.insert(b.begin(), b.end());
654  return std::forward<decltype(a)>(a);
655  };
656 
657  // Specialize all modules in parallel. The result is a set of all inner
658  // symbol references which are no longer valid due to disabling layers.
659  removedSyms = transformReduce(
660  context, specialize, removedSyms, mergeSets,
661  [&](Operation *op) -> DenseSet<Attribute> {
662  return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
663  .Case<FModuleOp, FExtModuleOp>(
664  [&](auto op) { return specializeModuleLike(op); });
665  });
666 
667  // Remove all hierarchical path operations which reference deleted symbols,
668  // and create a set of the removed paths operations. We will have to remove
669  // all annotations which use these paths.
670  DenseSet<StringAttr> removedPaths;
671  for (auto hierPath : llvm::make_early_inc_range(
672  circuit.getBody().getOps<hw::HierPathOp>())) {
673  auto namepath = hierPath.getNamepath().getValue();
674  auto shouldDelete = [&](Attribute ref) {
675  return removedSyms.contains(ref);
676  };
677  if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
678  removedPaths.insert(SymbolTable::getSymbolName(hierPath));
679  hierPath->erase();
680  continue;
681  }
682  // If we deleted the target of the hierpath, we don't need to add it to
683  // the list of removedPaths, since no annotation will be left around to
684  // reference this path.
685  if (shouldDelete(namepath.back()))
686  hierPath->erase();
687  }
688 
689  // Walk all annotations in the circuit and remove the ones that have a
690  // path which traverses or targets a removed instantiation.
691  SmallVector<FModuleLike> clean;
692  for (auto &op : *circuit.getBodyBlock())
693  if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
694  clean.push_back(cast<FModuleLike>(op));
695 
696  parallelForEach(context, clean, [&](FModuleLike module) {
697  (AnnotationCleaner(removedPaths))(module);
698  });
699  }
700 
701  MLIRContext *context;
702  CircuitOp circuit;
703  const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
704  /// The default specialization mode to be applied when a layer has not been
705  /// explicitly enabled or disabled.
706  std::optional<LayerSpecialization> defaultSpecialization;
707 };
708 
709 struct SpecializeLayersPass
710  : public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
711 
712  void runOnOperation() override {
713  auto circuit = getOperation();
714  SymbolTableCollection stc;
715 
716  // Set of layers to enable or disable.
717  DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
718 
719  // If we are not specialization any layers, we can return early.
720  bool shouldSpecialize = false;
721 
722  // Record all the layers which are being enabled.
723  if (auto enabledLayers = circuit.getEnableLayersAttr()) {
724  shouldSpecialize = true;
725  circuit.removeEnableLayersAttr();
726  for (auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
727  // Verify that this is a real layer.
728  if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
729  mlir::emitError(circuit.getLoc()) << "unknown layer " << enabledLayer;
730  signalPassFailure();
731  return;
732  }
733  specializations[enabledLayer] = LayerSpecialization::Enable;
734  }
735  }
736 
737  // Record all of the layers which are being disabled.
738  if (auto disabledLayers = circuit.getDisableLayersAttr()) {
739  shouldSpecialize = true;
740  circuit.removeDisableLayersAttr();
741  for (auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
742  // Verify that this is a real layer.
743  if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
744  mlir::emitError(circuit.getLoc())
745  << "unknown layer " << disabledLayer;
746  signalPassFailure();
747  return;
748  }
749 
750  // Verify that we are not both enabling and disabling this layer.
751  auto [it, inserted] = specializations.try_emplace(
752  disabledLayer, LayerSpecialization::Disable);
753  if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
754  mlir::emitError(circuit.getLoc())
755  << "layer " << disabledLayer << " both enabled and disabled";
756  signalPassFailure();
757  return;
758  }
759  }
760  }
761 
762  std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
763  if (auto specialization = circuit.getDefaultLayerSpecialization()) {
764  shouldSpecialize = true;
765  defaultSpecialization = *specialization;
766  }
767 
768  // If we did not transform the circuit, return early.
769  // TODO: if both arrays are empty we could preserve specific analyses, but
770  // not all analyses since we have modified the circuit op.
771  if (!shouldSpecialize)
772  return markAllAnalysesPreserved();
773 
774  // Run specialization on our circuit.
775  SpecializeLayers(circuit, specializations, defaultSpecialization)();
776  }
777 };
778 } // end anonymous namespace
779 
780 std::unique_ptr<Pass> firrtl::createSpecializeLayersPass() {
781  return std::make_unique<SpecializeLayersPass>();
782 }
assert(baseType &&"element must be base type")
static AnnotationSet forPort(Operation *op, size_t portNo)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createSpecializeLayersPass()
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
Definition: FIRRTLUtils.h:295
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
static mlir::ArrayAttr getFromVoidPointer(void *ptr)