CIRCT  20.0.0git
SpecializeLayers.cpp
Go to the documentation of this file.
1 //===- SpecializeLayers.cpp -------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "circt/Dialect/HW/HWOps.h"
15 #include "mlir/IR/Threading.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include <optional>
18 #include <type_traits>
19 
20 namespace circt {
21 namespace firrtl {
22 #define GEN_PASS_DEF_SPECIALIZELAYERS
23 #include "circt/Dialect/FIRRTL/Passes.h.inc"
24 } // namespace firrtl
25 } // namespace circt
26 
27 using namespace mlir;
28 using namespace circt;
29 using namespace firrtl;
30 
31 // TODO: this should be upstreamed.
32 namespace llvm {
33 template <>
34 struct PointerLikeTypeTraits<mlir::ArrayAttr>
35  : public PointerLikeTypeTraits<mlir::Attribute> {
36  static inline mlir::ArrayAttr getFromVoidPointer(void *ptr) {
37  return mlir::ArrayAttr::getFromOpaquePointer(ptr);
38  }
39 };
40 } // namespace llvm
41 
42 namespace {
43 /// Removes non-local annotations whose path is no longer viable, due to
44 /// the removal of module instances.
45 struct AnnotationCleaner {
46  /// Create an AnnotationCleaner which removes any annotation which contains a
47  /// reference to a symbol in removedPaths.
48  AnnotationCleaner(const DenseSet<StringAttr> &removedPaths)
49  : removedPaths(removedPaths) {}
50 
51  AnnotationSet cleanAnnotations(AnnotationSet annos) {
52  annos.removeAnnotations([&](Annotation anno) {
53  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
54  return removedPaths.contains(nla.getAttr());
55  return false;
56  });
57  return annos;
58  }
59 
60  void cleanAnnotations(Operation *op) {
61  AnnotationSet oldAnnotations(op);
62  // We want to avoid attaching an empty annotation array on to an op that
63  // never had an annotation array in the first place.
64  if (!oldAnnotations.empty()) {
65  auto newAnnotations = cleanAnnotations(oldAnnotations);
66  if (oldAnnotations != newAnnotations)
67  newAnnotations.applyToOperation(op);
68  }
69  }
70 
71  void operator()(FModuleLike module) {
72  // Clean the regular annotations.
73  cleanAnnotations(module);
74 
75  // Clean all port annotations.
76  for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
77  auto oldAnnotations = AnnotationSet::forPort(module, i);
78  if (!oldAnnotations.empty()) {
79  auto newAnnotations = cleanAnnotations(oldAnnotations);
80  if (oldAnnotations != newAnnotations)
81  newAnnotations.applyToPort(module, i);
82  }
83  }
84 
85  // Clean all annotations in body.
86  module->walk([&](Operation *op) {
87  // Clean regular annotations.
88  cleanAnnotations(op);
89 
90  if (auto mem = dyn_cast<MemOp>(op)) {
91  // Update all annotations on ports.
92  for (size_t i = 0, e = mem.getNumResults(); i < e; ++i) {
93  auto oldAnnotations = AnnotationSet::forPort(mem, i);
94  if (!oldAnnotations.empty()) {
95  auto newAnnotations = cleanAnnotations(oldAnnotations);
96  if (oldAnnotations != newAnnotations)
97  newAnnotations.applyToPort(mem, i);
98  }
99  }
100  }
101  });
102  }
103 
104  /// A set of symbols of removed paths.
105  const DenseSet<StringAttr> &removedPaths;
106 };
107 
108 /// Helper to keep track of an insertion point and move operations around.
109 struct InsertionPoint {
110  /// Create an insertion point at the end of a block.
111  static InsertionPoint atBlockEnd(Block *block) {
112  return InsertionPoint{block, block->end()};
113  }
114 
115  /// Move the target operation before the current insertion point and update
116  /// the insertion point to point to the op.
117  void moveOpBefore(Operation *op) {
118  op->moveBefore(block, it);
119  it = Block::iterator(op);
120  }
121 
122 private:
123  InsertionPoint(Block *block, Block::iterator it) : block(block), it(it) {}
124 
125  Block *block;
126  Block::iterator it;
127 };
128 
129 /// A specialized value. If the value is colored such that it is disabled,
130 /// it will not contain an underlying value. Otherwise, contains the
131 /// specialized version of the value.
132 template <typename T>
133 struct Specialized {
134  /// Create a disabled specialized value.
135  Specialized() : Specialized(LayerSpecialization::Disable, nullptr) {}
136 
137  /// Create an enabled specialized value.
138  Specialized(T value) : Specialized(LayerSpecialization::Enable, value) {}
139 
140  /// Returns true if the value was specialized away.
141  bool isDisabled() const {
142  return value.getInt() == LayerSpecialization::Disable;
143  }
144 
145  /// Returns the specialized value if it still exists.
146  T getValue() const {
147  assert(!isDisabled());
148  return value.getPointer();
149  }
150 
151  operator bool() const { return !isDisabled(); }
152 
153 private:
154  Specialized(LayerSpecialization specialization, T value)
155  : value(value, specialization) {}
156  llvm::PointerIntPair<T, 1, LayerSpecialization> value;
157 };
158 
159 struct SpecializeLayers {
160  SpecializeLayers(
161  CircuitOp circuit,
162  const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations,
163  std::optional<LayerSpecialization> defaultSpecialization)
164  : context(circuit->getContext()), circuit(circuit),
165  specializations(specializations),
166  defaultSpecialization(defaultSpecialization) {}
167 
168  /// Create a reference to every field in the inner symbol, and record it in
169  /// the list of removed symbols.
170  static void recordRemovedInnerSym(DenseSet<Attribute> &removedSyms,
171  StringAttr moduleName,
172  hw::InnerSymAttr innerSym) {
173  for (auto field : innerSym)
174  removedSyms.insert(hw::InnerRefAttr::get(moduleName, field.getName()));
175  }
176 
177  /// Create a reference to every field in the inner symbol, and record it in
178  /// the list of removed symbols.
179  static void recordRemovedInnerSyms(DenseSet<Attribute> &removedSyms,
180  StringAttr moduleName, Block *block) {
181  block->walk([&](hw::InnerSymbolOpInterface op) {
182  if (auto innerSym = op.getInnerSymAttr())
183  recordRemovedInnerSym(removedSyms, moduleName, innerSym);
184  });
185  }
186 
187  /// If this layer reference is being specialized, returns the specialization
188  /// mode. Otherwise, it returns disabled value.
189  std::optional<LayerSpecialization> getSpecialization(SymbolRefAttr layerRef) {
190  auto it = specializations.find(layerRef);
191  if (it != specializations.end())
192  return it->getSecond();
193  return defaultSpecialization;
194  }
195 
196  /// Forms a symbol reference to a layer using head as the root reference, and
197  /// nestedRefs as the path to the specific layer. If this layer reference is
198  /// being specialized, returns the specialization mode. Otherwise, it returns
199  /// nullopt.
200  std::optional<LayerSpecialization>
201  getSpecialization(StringAttr head, ArrayRef<FlatSymbolRefAttr> nestedRefs) {
202  return getSpecialization(SymbolRefAttr::get(head, nestedRefs));
203  }
204 
205  /// Specialize a layer reference by removing enabled layers, mangling the
206  /// names of inlined layers, and returning a disabled value if the layer was
207  /// disabled. This can return nullptr if all layers were enabled.
208  Specialized<SymbolRefAttr> specializeLayerRef(SymbolRefAttr layerRef) {
209  // Walk the layer reference from root to leaf, checking if each outer layer
210  // was specialized or not. If an outer layer was disabled, then this
211  // specific layer is implicitly disabled as well. If an outer layer was
212  // enabled, we need to use its name to mangle the name of inner layers. If
213  // the layer is not specialized, we may need to mangle its name, otherwise
214  // we leave it alone.
215 
216  auto oldRoot = layerRef.getRootReference();
217  SmallVector<FlatSymbolRefAttr> oldNestedRefs;
218 
219  // A prefix to be used to track how to mangle the next non-inlined layer.
220  SmallString<64> prefix;
221  SmallVector<FlatSymbolRefAttr> newRef;
222 
223  // Side-effecting helper which returns false if the currently examined layer
224  // is disabled, true otherwise. If the current layer is being enabled, add
225  // its name to the prefix. If the layer is not specialized, we mangle the
226  // name with the prefix which is then reset to the empty string, and copy it
227  // into the new specialize layer reference.
228  auto helper = [&](StringAttr ref) -> bool {
229  auto specialization = getSpecialization(oldRoot, oldNestedRefs);
230 
231  // We are not specializing this layer. Mangle the name with the current
232  // prefix.
233  if (!specialization) {
234  newRef.push_back(FlatSymbolRefAttr::get(
235  StringAttr::get(ref.getContext(), prefix + ref.getValue())));
236  prefix.clear();
237  return true;
238  }
239 
240  // We are enabling this layer, the next non-enabled layer should
241  // include this layer's name as a prefix.
242  if (*specialization == LayerSpecialization::Enable) {
243  prefix.append(ref.getValue());
244  prefix.append("_");
245  return true;
246  }
247 
248  // We are disabling this layer.
249  return false;
250  };
251 
252  if (!helper(oldRoot))
253  return {};
254 
255  for (auto ref : layerRef.getNestedReferences()) {
256  oldNestedRefs.push_back(ref);
257  if (!helper(ref.getAttr()))
258  return {};
259  }
260 
261  if (newRef.empty())
262  return {SymbolRefAttr()};
263 
264  // Root references need to be handled differently than nested references,
265  // but since we don't know before hand which layer will form the new root
266  // layer, we copy all layers to the same array, at the cost of unnecessarily
267  // wrapping the new root reference into a FlatSymbolRefAttr and having to
268  // unpack it again.
269  auto newRoot = newRef.front().getAttr();
270  return {SymbolRefAttr::get(newRoot, ArrayRef(newRef).drop_front())};
271  }
272 
273  /// Specialize a RefType by specializing the layer color. If the RefType is
274  /// colored with a disabled layer, this will return nullptr.
275  RefType specializeRefType(RefType refType) {
276  if (auto oldLayer = refType.getLayer()) {
277  if (auto newLayer = specializeLayerRef(oldLayer))
278  return RefType::get(refType.getType(), refType.getForceable(),
279  newLayer.getValue());
280  return nullptr;
281  }
282  return refType;
283  }
284 
285  Type specializeType(Type type) {
286  if (auto refType = dyn_cast<RefType>(type))
287  return specializeRefType(refType);
288  return type;
289  }
290 
291  /// Specialize a value by modifying its type. Returns nullptr if this value
292  /// should be disabled, and the original value otherwise.
293  Value specializeValue(Value value) {
294  if (auto newType = specializeType(value.getType())) {
295  value.setType(newType);
296  return value;
297  }
298  return nullptr;
299  }
300 
301  /// Specialize a layerblock. If the layerblock is disabled, it and all of its
302  /// contents will be erased, and all removed inner symbols will be recorded
303  /// so that we can later clean up hierarchical paths. If the layer is
304  /// enabled, then we will inline the contents of the layer to the same
305  /// position and delete the layerblock.
306  void specializeOp(LayerBlockOp layerBlock, InsertionPoint &insertionPoint,
307  DenseSet<Attribute> &removedSyms) {
308  auto oldLayerRef = layerBlock.getLayerNameAttr();
309 
310  // Get the specialization for the current layerblock, not taking into
311  // account if an outer layerblock was specialized. We should not have
312  // recursed inside of the disabled layerblock anyways, as it just gets
313  // erased.
314  auto specialization = getSpecialization(oldLayerRef);
315 
316  // We are not specializing this layerblock.
317  if (!specialization) {
318  // We must update the name of this layerblock to reflect
319  // specializations of any outer layers.
320  auto newLayerRef = specializeLayerRef(oldLayerRef).getValue();
321  if (oldLayerRef != newLayerRef)
322  layerBlock.setLayerNameAttr(newLayerRef);
323  // Specialize inner operations, but keep them in their original
324  // location.
325  auto *block = layerBlock.getBody();
326  auto bodyIP = InsertionPoint::atBlockEnd(block);
327  specializeBlock(block, bodyIP, removedSyms);
328  insertionPoint.moveOpBefore(layerBlock);
329  return;
330  }
331 
332  // We are enabling this layer, and all contents of this layer need to be
333  // moved (inlined) to the insertion point.
334  if (*specialization == LayerSpecialization::Enable) {
335  // Move all contents to the insertion point and specialize them.
336  specializeBlock(layerBlock.getBody(), insertionPoint, removedSyms);
337  // Erase the now empty layerblock.
338  layerBlock->erase();
339  return;
340  }
341 
342  // We are disabling this layerblock, so we can just erase the layerblock. We
343  // need to record all the objects with symbols which are being deleted.
344  auto moduleName = layerBlock->getParentOfType<FModuleOp>().getNameAttr();
345  recordRemovedInnerSyms(removedSyms, moduleName, layerBlock.getBody());
346  layerBlock->erase();
347  }
348 
349  void specializeOp(WhenOp when, InsertionPoint &insertionPoint,
350  DenseSet<Attribute> &removedSyms) {
351  // We need to specialize both arms of the when, but the inner ops should
352  // be left where they are (and not inlined to the insertion point).
353  auto *thenBlock = &when.getThenBlock();
354  auto thenIP = InsertionPoint::atBlockEnd(thenBlock);
355  specializeBlock(thenBlock, thenIP, removedSyms);
356  if (when.hasElseRegion()) {
357  auto *elseBlock = &when.getElseBlock();
358  auto elseIP = InsertionPoint::atBlockEnd(elseBlock);
359  specializeBlock(elseBlock, elseIP, removedSyms);
360  }
361  insertionPoint.moveOpBefore(when);
362  }
363 
364  void specializeOp(MatchOp match, InsertionPoint &insertionPoint,
365  DenseSet<Attribute> &removedSyms) {
366  for (size_t i = 0, e = match.getNumRegions(); i < e; ++i) {
367  auto *caseBlock = &match.getRegion(i).front();
368  auto caseIP = InsertionPoint::atBlockEnd(caseBlock);
369  specializeBlock(caseBlock, caseIP, removedSyms);
370  }
371  insertionPoint.moveOpBefore(match);
372  }
373 
374  void specializeOp(InstanceOp instance, InsertionPoint &insertionPoint,
375  DenseSet<Attribute> &removedSyms) {
376  /// Update the types of any probe ports on the instance, and delete any
377  /// probe port that is permanently disabled.
378  llvm::BitVector disabledPorts(instance->getNumResults());
379  for (auto result : instance->getResults())
380  if (!specializeValue(result))
381  disabledPorts.set(result.getResultNumber());
382  if (disabledPorts.any()) {
383  OpBuilder builder(instance);
384  auto newInstance = instance.erasePorts(builder, disabledPorts);
385  instance->erase();
386  instance = newInstance;
387  }
388 
389  // Specialize the required enable layers. Due to the layer verifiers, there
390  // should not be any disabled layer in this instance and this should be
391  // infallible.
392  auto newLayers = specializeEnableLayers(instance.getLayersAttr());
393  instance.setLayersAttr(newLayers.getValue());
394 
395  insertionPoint.moveOpBefore(instance);
396  }
397 
398  void specializeOp(InstanceChoiceOp instanceChoice,
399  InsertionPoint &insertionPoint,
400  DenseSet<Attribute> &removedSyms) {
401  /// Update the types of any probe ports on the instanceChoice, and delete
402  /// any probe port that is permanently disabled.
403  llvm::BitVector disabledPorts(instanceChoice->getNumResults());
404  for (auto result : instanceChoice->getResults())
405  if (!specializeValue(result))
406  disabledPorts.set(result.getResultNumber());
407  if (disabledPorts.any()) {
408  OpBuilder builder(instanceChoice);
409  auto newInstanceChoice =
410  instanceChoice.erasePorts(builder, disabledPorts);
411  instanceChoice->erase();
412  instanceChoice = newInstanceChoice;
413  }
414 
415  // Specialize the required enable layers. Due to the layer verifiers, there
416  // should not be any disabled layer in this instanceChoice and this should
417  // be infallible.
418  auto newLayers = specializeEnableLayers(instanceChoice.getLayersAttr());
419  instanceChoice.setLayersAttr(newLayers.getValue());
420 
421  insertionPoint.moveOpBefore(instanceChoice);
422  }
423 
424  void specializeOp(WireOp wire, InsertionPoint &insertionPoint,
425  DenseSet<Attribute> &removedSyms) {
426  if (specializeValue(wire.getResult())) {
427  insertionPoint.moveOpBefore(wire);
428  } else {
429  if (auto innerSym = wire.getInnerSymAttr())
430  recordRemovedInnerSym(removedSyms,
431  wire->getParentOfType<FModuleOp>().getNameAttr(),
432  innerSym);
433  wire.erase();
434  }
435  }
436 
437  void specializeOp(RefDefineOp refDefine, InsertionPoint &insertionPoint,
438  DenseSet<Attribute> &removedSyms) {
439  // If this is connected disabled probes, erase the refdefine op.
440  if (auto layerRef = refDefine.getDest().getType().getLayer())
441  if (!specializeLayerRef(layerRef)) {
442  refDefine->erase();
443  return;
444  }
445  insertionPoint.moveOpBefore(refDefine);
446  }
447 
448  void specializeOp(RefSubOp refSub, InsertionPoint &insertionPoint,
449  DenseSet<Attribute> &removedSyms) {
450  if (specializeValue(refSub->getResult(0)))
451  insertionPoint.moveOpBefore(refSub);
452  else
453 
454  refSub.erase();
455  }
456 
457  void specializeOp(RefCastOp refCast, InsertionPoint &insertionPoint,
458  DenseSet<Attribute> &removedSyms) {
459  if (specializeValue(refCast->getResult(0)))
460  insertionPoint.moveOpBefore(refCast);
461  else
462  refCast.erase();
463  }
464 
465  /// Specialize a block of operations, removing any probes which are
466  /// disabled, and moving all operations to the insertion point.
467  void specializeBlock(Block *block, InsertionPoint &insertionPoint,
468  DenseSet<Attribute> &removedSyms) {
469  // Since this can erase operations that deal with disabled probes, we walk
470  // the block in reverse to make sure that we erase uses before defs.
471  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
472  TypeSwitch<Operation *>(&op)
473  .Case<LayerBlockOp, WhenOp, MatchOp, InstanceOp, InstanceChoiceOp,
474  WireOp, RefDefineOp, RefSubOp, RefCastOp>(
475  [&](auto op) { specializeOp(op, insertionPoint, removedSyms); })
476  .Default([&](Operation *op) {
477  // By default all operations should be inlined from an enabled
478  // layer.
479  insertionPoint.moveOpBefore(op);
480  });
481  }
482  }
483 
484  /// Specialize the list of enabled layers for a module. Return a disabled
485  /// layer if one of the required layers has been disabled.
486  Specialized<ArrayAttr> specializeEnableLayers(ArrayAttr layers) {
487  SmallVector<Attribute> newLayers;
488  for (auto layer : layers.getAsRange<SymbolRefAttr>()) {
489  auto newLayer = specializeLayerRef(layer);
490  if (!newLayer)
491  return {};
492  if (newLayer.getValue())
493  newLayers.push_back(newLayer.getValue());
494  }
495  return ArrayAttr::get(context, newLayers);
496  }
497 
498  void specializeModulePorts(FModuleLike moduleLike,
499  DenseSet<Attribute> &removedSyms) {
500  auto oldTypeAttrs = moduleLike.getPortTypesAttr();
501 
502  // The list of new port types.
503  SmallVector<Attribute> newTypeAttrs;
504  newTypeAttrs.reserve(oldTypeAttrs.size());
505 
506  // This is the list of port indices which need to be removed because they
507  // have been specialized away.
508  llvm::BitVector disabledPorts(oldTypeAttrs.size());
509 
510  auto moduleName = moduleLike.getNameAttr();
511  for (auto [index, typeAttr] :
512  llvm::enumerate(oldTypeAttrs.getAsRange<TypeAttr>())) {
513  // Specialize the type fo the port.
514  if (auto type = specializeType(typeAttr.getValue())) {
515  newTypeAttrs.push_back(TypeAttr::get(type));
516  } else {
517  // The port is being disabled, and should be removed.
518  if (auto portSym = moduleLike.getPortSymbolAttr(index))
519  recordRemovedInnerSym(removedSyms, moduleName, portSym);
520  disabledPorts.set(index);
521  }
522  }
523 
524  // Erase the disabled ports.
525  moduleLike.erasePorts(disabledPorts);
526 
527  // Update the rest of the port types.
528  moduleLike.setPortTypesAttr(
529  ArrayAttr::get(moduleLike.getContext(), newTypeAttrs));
530 
531  // We may also need to update the types on the block arguments.
532  if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation()))
533  for (auto [arg, typeAttr] :
534  llvm::zip(moduleOp.getArguments(), newTypeAttrs))
535  arg.setType(cast<TypeAttr>(typeAttr).getValue());
536  }
537 
538  template <typename T>
539  DenseSet<Attribute> specializeModuleLike(T op) {
540  DenseSet<Attribute> removedSyms;
541 
542  // Specialize all operations in the body of the module. This must be done
543  // before specializing the module ports so that we don't try to erase values
544  // that still have uses.
545  if constexpr (std::is_same_v<T, FModuleOp>) {
546  auto *block = cast<FModuleOp>(op).getBodyBlock();
547  auto bodyIP = InsertionPoint::atBlockEnd(block);
548  specializeBlock(block, bodyIP, removedSyms);
549  }
550 
551  // Specialize the ports of this module.
552  specializeModulePorts(op, removedSyms);
553 
554  return removedSyms;
555  }
556 
557  template <typename T>
558  T specializeEnableLayers(T module, DenseSet<Attribute> &removedSyms) {
559  // Update the required layers on the module.
560  if (auto newLayers = specializeEnableLayers(module.getLayersAttr())) {
561  module.setLayersAttr(newLayers.getValue());
562  return module;
563  }
564 
565  // If we disabled a layer which this module requires, we must delete the
566  // whole module.
567  auto moduleName = module.getNameAttr();
568  removedSyms.insert(FlatSymbolRefAttr::get(moduleName));
569  if constexpr (std::is_same_v<T, FModuleOp>)
570  recordRemovedInnerSyms(removedSyms, moduleName,
571  cast<FModuleOp>(module).getBodyBlock());
572 
573  module->erase();
574  return nullptr;
575  }
576 
577  /// Specialize a layer operation, by removing enabled layers and inlining
578  /// their contents, deleting disabled layers and all nested layers, and
579  /// mangling the names of any inlined layers.
580  void specializeLayer(LayerOp layer) {
581  StringAttr head = layer.getSymNameAttr();
582  SmallVector<FlatSymbolRefAttr> nestedRefs;
583 
584  std::function<void(LayerOp, Block::iterator, const Twine &)> handleLayer =
585  [&](LayerOp layer, Block::iterator insertionPoint,
586  const Twine &prefix) {
587  auto *block = &layer.getBody().getBlocks().front();
588  auto specialization = getSpecialization(head, nestedRefs);
589 
590  // If we are not specializing the current layer, visit the inner
591  // layers.
592  if (!specialization) {
593  // We only mangle the name and move the layer if the prefix is
594  // non-empty, which indicates that we are enabling the parent
595  // layer.
596  if (!prefix.isTriviallyEmpty()) {
597  layer.setSymNameAttr(
598  StringAttr::get(context, prefix + layer.getSymName()));
599  auto *parentBlock = insertionPoint->getBlock();
600  layer->moveBefore(parentBlock, insertionPoint);
601  }
602  for (auto nested :
603  llvm::make_early_inc_range(block->getOps<LayerOp>())) {
604  nestedRefs.push_back(SymbolRefAttr::get(nested));
605  handleLayer(nested, Block::iterator(nested), "");
606  nestedRefs.pop_back();
607  }
608  return;
609  }
610 
611  // We are enabling this layer. We must inline inner layers, and
612  // mangle their names.
613  if (*specialization == LayerSpecialization::Enable) {
614  for (auto nested :
615  llvm::make_early_inc_range(block->getOps<LayerOp>())) {
616  nestedRefs.push_back(SymbolRefAttr::get(nested));
617  handleLayer(nested, insertionPoint,
618  prefix + layer.getSymName() + "_");
619  nestedRefs.pop_back();
620  }
621  // Erase the now empty layer.
622  layer->erase();
623  return;
624  }
625 
626  // If we are disabling this layer, then we can just fully delete
627  // it.
628  layer->erase();
629  };
630 
631  handleLayer(layer, Block::iterator(layer), "");
632  }
633 
634  void operator()() {
635  // Gather all operations we need to specialize, and all the ops that we
636  // need to clean. We specialize layers and module's enable layers here
637  // because that can delete the operations and must be done serially.
638  SmallVector<Operation *> specialize;
639  DenseSet<Attribute> removedSyms;
640  for (auto &op : llvm::make_early_inc_range(*circuit.getBodyBlock())) {
641  TypeSwitch<Operation *>(&op)
642  .Case<FModuleOp, FExtModuleOp>([&](auto module) {
643  if (specializeEnableLayers(module, removedSyms))
644  specialize.push_back(module);
645  })
646  .Case<LayerOp>([&](LayerOp layer) { specializeLayer(layer); });
647  }
648 
649  // Function to merge two sets together.
650  auto mergeSets = [](auto &&a, auto &&b) {
651  a.insert(b.begin(), b.end());
652  return std::forward<decltype(a)>(a);
653  };
654 
655  // Specialize all modules in parallel. The result is a set of all inner
656  // symbol references which are no longer valid due to disabling layers.
657  removedSyms = transformReduce(
658  context, specialize, removedSyms, mergeSets,
659  [&](Operation *op) -> DenseSet<Attribute> {
660  return TypeSwitch<Operation *, DenseSet<Attribute>>(op)
661  .Case<FModuleOp, FExtModuleOp>(
662  [&](auto op) { return specializeModuleLike(op); });
663  });
664 
665  // Remove all hierarchical path operations which reference deleted symbols,
666  // and create a set of the removed paths operations. We will have to remove
667  // all annotations which use these paths.
668  DenseSet<StringAttr> removedPaths;
669  for (auto hierPath : llvm::make_early_inc_range(
670  circuit.getBody().getOps<hw::HierPathOp>())) {
671  auto namepath = hierPath.getNamepath().getValue();
672  auto shouldDelete = [&](Attribute ref) {
673  return removedSyms.contains(ref);
674  };
675  if (llvm::any_of(namepath.drop_back(), shouldDelete)) {
676  removedPaths.insert(SymbolTable::getSymbolName(hierPath));
677  hierPath->erase();
678  continue;
679  }
680  // If we deleted the target of the hierpath, we don't need to add it to
681  // the list of removedPaths, since no annotation will be left around to
682  // reference this path.
683  if (shouldDelete(namepath.back()))
684  hierPath->erase();
685  }
686 
687  // Walk all annotations in the circuit and remove the ones that have a
688  // path which traverses or targets a removed instantiation.
689  SmallVector<FModuleLike> clean;
690  for (auto &op : *circuit.getBodyBlock())
691  if (isa<FModuleOp, FExtModuleOp, FIntModuleOp, FMemModuleOp>(op))
692  clean.push_back(cast<FModuleLike>(op));
693 
694  parallelForEach(context, clean, [&](FModuleLike module) {
695  (AnnotationCleaner(removedPaths))(module);
696  });
697  }
698 
699  MLIRContext *context;
700  CircuitOp circuit;
701  const DenseMap<SymbolRefAttr, LayerSpecialization> &specializations;
702  /// The default specialization mode to be applied when a layer has not been
703  /// explicitly enabled or disabled.
704  std::optional<LayerSpecialization> defaultSpecialization;
705 };
706 
707 struct SpecializeLayersPass
708  : public circt::firrtl::impl::SpecializeLayersBase<SpecializeLayersPass> {
709 
710  void runOnOperation() override {
711  auto circuit = getOperation();
712  SymbolTableCollection stc;
713 
714  // Set of layers to enable or disable.
715  DenseMap<SymbolRefAttr, LayerSpecialization> specializations;
716 
717  // If we are not specialization any layers, we can return early.
718  bool shouldSpecialize = false;
719 
720  // Record all the layers which are being enabled.
721  if (auto enabledLayers = circuit.getEnableLayersAttr()) {
722  shouldSpecialize = true;
723  circuit.removeEnableLayersAttr();
724  for (auto enabledLayer : enabledLayers.getAsRange<SymbolRefAttr>()) {
725  // Verify that this is a real layer.
726  if (!stc.lookupSymbolIn(circuit, enabledLayer)) {
727  mlir::emitError(circuit.getLoc()) << "unknown layer " << enabledLayer;
728  signalPassFailure();
729  return;
730  }
731  specializations[enabledLayer] = LayerSpecialization::Enable;
732  }
733  }
734 
735  // Record all of the layers which are being disabled.
736  if (auto disabledLayers = circuit.getDisableLayersAttr()) {
737  shouldSpecialize = true;
738  circuit.removeDisableLayersAttr();
739  for (auto disabledLayer : disabledLayers.getAsRange<SymbolRefAttr>()) {
740  // Verify that this is a real layer.
741  if (!stc.lookupSymbolIn(circuit, disabledLayer)) {
742  mlir::emitError(circuit.getLoc())
743  << "unknown layer " << disabledLayer;
744  signalPassFailure();
745  return;
746  }
747 
748  // Verify that we are not both enabling and disabling this layer.
749  auto [it, inserted] = specializations.try_emplace(
750  disabledLayer, LayerSpecialization::Disable);
751  if (!inserted && it->getSecond() == LayerSpecialization::Enable) {
752  mlir::emitError(circuit.getLoc())
753  << "layer " << disabledLayer << " both enabled and disabled";
754  signalPassFailure();
755  return;
756  }
757  }
758  }
759 
760  std::optional<LayerSpecialization> defaultSpecialization = std::nullopt;
761  if (auto specialization = circuit.getDefaultLayerSpecialization()) {
762  shouldSpecialize = true;
763  defaultSpecialization = *specialization;
764  }
765 
766  // If we did not transform the circuit, return early.
767  // TODO: if both arrays are empty we could preserve specific analyses, but
768  // not all analyses since we have modified the circuit op.
769  if (!shouldSpecialize)
770  return markAllAnalysesPreserved();
771 
772  // Run specialization on our circuit.
773  SpecializeLayers(circuit, specializations, defaultSpecialization)();
774  }
775 };
776 } // end anonymous namespace
777 
778 std::unique_ptr<Pass> firrtl::createSpecializeLayersPass() {
779  return std::make_unique<SpecializeLayersPass>();
780 }
assert(baseType &&"element must be base type")
static AnnotationSet forPort(Operation *op, size_t portNo)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createSpecializeLayersPass()
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
Definition: FIRRTLUtils.h:295
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
static mlir::ArrayAttr getFromVoidPointer(void *ptr)