CIRCT  19.0.0git
LowerSignatures.cpp
Go to the documentation of this file.
1 //===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the LowerSignatures pass. This pass replaces aggregate
10 // types with expanded values in module arguments as specified by the ABI
11 // information.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetails.h"
16 
28 #include "circt/Dialect/SV/SVOps.h"
29 #include "circt/Support/Debug.h"
30 #include "mlir/IR/ImplicitLocOpBuilder.h"
31 #include "mlir/IR/Threading.h"
32 #include "llvm/ADT/APSInt.h"
33 #include "llvm/ADT/BitVector.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Parallel.h"
36 
37 #define DEBUG_TYPE "firrtl-lower-signatures"
38 
39 using namespace circt;
40 using namespace firrtl;
41 
42 //===----------------------------------------------------------------------===//
43 // Module Type Lowering
44 //===----------------------------------------------------------------------===//
45 namespace {
46 
47 struct AttrCache {
48  AttrCache(MLIRContext *context) {
49  nameAttr = StringAttr::get(context, "name");
50  sPortDirections = StringAttr::get(context, "portDirections");
51  sPortNames = StringAttr::get(context, "portNames");
52  sPortTypes = StringAttr::get(context, "portTypes");
53  sPortLocations = StringAttr::get(context, "portLocations");
54  sPortAnnotations = StringAttr::get(context, "portAnnotations");
55  sInternalPaths = StringAttr::get(context, "internalPaths");
56  }
57  AttrCache(const AttrCache &) = default;
58 
59  StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
60  sPortAnnotations, sInternalPaths;
61 };
62 
63 struct FieldMapEntry : public PortInfo {
64  size_t portID;
65  size_t resultID;
66  size_t fieldID;
67 };
68 
69 using PortConversion = SmallVector<FieldMapEntry>;
70 
71 template <typename T>
72 class FieldIDSearch {
73  using E = typename T::ElementType;
74  using V = SmallVector<E>;
75 
76 public:
77  using const_iterator = typename V::const_iterator;
78 
79  template <typename Container>
80  FieldIDSearch(const Container &src) {
81  if constexpr (std::is_convertible_v<Container, Attribute>)
82  if (!src)
83  return;
84  for (auto attr : src)
85  vals.push_back(attr);
86  std::sort(vals.begin(), vals.end(), fieldComp);
87  }
88 
89  std::pair<const_iterator, const_iterator> find(uint64_t low,
90  uint64_t high) const {
91  return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
92  std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
93  }
94 
95  bool empty(uint64_t low, uint64_t high) const {
96  auto [b, e] = find(low, high);
97  return b == e;
98  }
99 
100 private:
101  static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
102  return lhs.getFieldID() < rhs.getFieldID();
103  };
104  static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
105  return lhs.getFieldID() < rhs;
106  };
107  static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
108  return lhs < rhs.getFieldID();
109  };
110 
111  V vals;
112 };
113 
114 } // namespace
115 
116 static hw::InnerSymAttr
117 symbolsForFieldIDRange(MLIRContext *ctx,
118  const FieldIDSearch<hw::InnerSymAttr> &syms,
119  uint64_t low, uint64_t high) {
120  auto [b, e] = syms.find(low, high);
121  SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
122  if (newSyms.empty())
123  return {};
124  for (auto &sym : newSyms)
126  ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
127  return hw::InnerSymAttr::get(ctx, newSyms);
128 }
129 
130 static AnnotationSet
131 annosForFieldIDRange(MLIRContext *ctx,
132  const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
133  uint64_t high) {
134  AnnotationSet newAnnos(ctx);
135  auto [b, e] = annos.find(low, high);
136  for (; b != e; ++b)
137  newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
138  return newAnnos;
139 }
140 
141 static LogicalResult
142 computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
143  size_t portID, const PortInfo &port, bool isFlip,
144  Twine name, FIRRTLType type, uint64_t fieldID,
145  const FieldIDSearch<hw::InnerSymAttr> &syms,
146  const FieldIDSearch<AnnotationSet> &annos) {
147  auto *ctx = type.getContext();
149  .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
150  // This should be enhanced to be able to handle bundle<all flips of
151  // passive>, or this should be a canonicalizer
152  if (conv != Convention::Scalarized && bundle.isPassive()) {
153  auto lastId = fieldID + bundle.getMaxFieldID();
154  newPorts.push_back(
155  {{StringAttr::get(ctx, name), type,
156  isFlip ? Direction::Out : Direction::In,
157  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
158  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
159  portID,
160  newPorts.size(),
161  fieldID});
162  } else {
163  for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
164  if (failed(computeLoweringImpl(
165  mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
166  name + "_" + elem.name.getValue(), elem.type,
167  fieldID + bundle.getFieldID(idx), syms, annos)))
168  return failure();
169  if (!syms.empty(fieldID, fieldID))
170  return mod.emitError("Port [")
171  << port.name
172  << "] should be subdivided, but cannot be because of "
173  "symbol ["
174  << port.sym.getSymIfExists(fieldID) << "] on a bundle";
175  if (!annos.empty(fieldID, fieldID)) {
176  auto err = mod.emitError("Port [")
177  << port.name
178  << "] should be subdivided, but cannot be because of "
179  "annotations [";
180  auto [b, e] = annos.find(fieldID, fieldID);
181  err << b->getClass() << "(" << b->getFieldID() << ")";
182  b++;
183  for (; b != e; ++b)
184  err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
185  err << "] on a bundle";
186  return err;
187  }
188  }
189  }
190  return success();
191  })
192  .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
193  if (conv != Convention::Scalarized &&
194  vector.getElementType().isPassive()) {
195  auto lastId = fieldID + vector.getMaxFieldID();
196  newPorts.push_back(
197  {{StringAttr::get(ctx, name), type,
198  isFlip ? Direction::Out : Direction::In,
199  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
200  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
201  portID,
202  newPorts.size(),
203  fieldID});
204  } else {
205  for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
206  if (failed(computeLoweringImpl(
207  mod, newPorts, conv, portID, port, isFlip,
208  name + "_" + Twine(i), vector.getElementType(),
209  fieldID + vector.getFieldID(i), syms, annos)))
210  return failure();
211  if (!syms.empty(fieldID, fieldID))
212  return mod.emitError("Port [")
213  << port.name
214  << "] should be subdivided, but cannot be because of "
215  "symbol ["
216  << port.sym.getSymIfExists(fieldID) << "] on a vector";
217  if (!annos.empty(fieldID, fieldID)) {
218  auto err = mod.emitError("Port [")
219  << port.name
220  << "] should be subdivided, but cannot be because of "
221  "annotations [";
222  auto [b, e] = annos.find(fieldID, fieldID);
223  err << b->getClass();
224  ++b;
225  for (; b != e; ++b)
226  err << ", " << b->getClass();
227  err << "] on a vector";
228  return err;
229  }
230  }
231  }
232  return success();
233  })
234  .Case<FEnumType>([&](FEnumType fenum) { return failure(); })
235  .Default([&](FIRRTLType type) {
236  // Properties and other types wind up here.
237  newPorts.push_back(
238  {{StringAttr::get(ctx, name), type,
239  isFlip ? Direction::Out : Direction::In,
240  symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
241  annosForFieldIDRange(ctx, annos, fieldID, fieldID)},
242  portID,
243  newPorts.size(),
244  fieldID});
245  return success();
246  });
247 }
248 
249 // compute a new moduletype from an old module type and lowering convention.
250 // Also compute a fieldID map from port, fieldID -> port
251 static LogicalResult computeLowering(FModuleLike mod, Convention conv,
252  PortConversion &newPorts) {
253  for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
254  if (failed(computeLoweringImpl(
255  mod, newPorts, conv, idx, port, port.direction == Direction::Out,
256  port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
257  FieldIDSearch<hw::InnerSymAttr>(port.sym),
258  FieldIDSearch<AnnotationSet>(port.annotations))))
259  return failure();
260  }
261  return success();
262 }
263 
264 static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
265  AttrCache &cache,
266  PortConversion &newPorts) {
267  ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
268  if (computeLowering(module, conv, newPorts).failed())
269  return failure();
270  if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
271  Block *body = mod.getBodyBlock();
272  theBuilder.setInsertionPointToStart(body);
273  auto oldNumArgs = body->getNumArguments();
274 
275  // Compute the replacement value for old arguments
276  // This creates all the new arguments and produces bounce wires when
277  // necessary
278  SmallVector<Value> bounceWires(oldNumArgs);
279  for (auto &p : newPorts) {
280  auto newArg = body->addArgument(p.type, p.loc);
281  // Get or create a bounce wire for changed ports
282  // For unmodified ports, move the uses to the replacement port
283  if (p.fieldID != 0) {
284  auto &wire = bounceWires[p.portID];
285  if (!wire)
286  wire = theBuilder
287  .create<WireOp>(module.getPortType(p.portID),
288  module.getPortNameAttr(p.portID),
289  NameKindEnum::InterestingName)
290  .getResult();
291  } else {
292  bounceWires[p.portID] = newArg;
293  }
294  }
295  // replace old arguments. Somethings get dropped completely, like
296  // zero-length vectors.
297  for (auto idx = 0U; idx < oldNumArgs; ++idx) {
298  if (!bounceWires[idx]) {
299  bounceWires[idx] = theBuilder
300  .create<WireOp>(module.getPortType(idx),
301  module.getPortNameAttr(idx))
302  .getResult();
303  }
304  body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
305  }
306 
307  // Goodbye old ports, now ResultID in the PortInfo is correct.
308  body->eraseArguments(0, oldNumArgs);
309 
310  // Connect the bounce wires to the new arguments
311  for (auto &p : newPorts) {
312  if (isa<BlockArgument>(bounceWires[p.portID]))
313  continue;
314  if (p.isOutput())
315  emitConnect(
316  theBuilder, body->getArgument(p.resultID),
317  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
318  else
319  emitConnect(
320  theBuilder,
321  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
322  body->getArgument(p.resultID));
323  }
324  }
325 
326  SmallVector<NamedAttribute, 8> newModuleAttrs;
327 
328  // Copy over any attributes that weren't original argument attributes.
329  for (auto attr : module->getAttrDictionary())
330  // Drop old "portNames", directions, and argument attributes. These are
331  // handled differently below.
332  if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
333  attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
334  attr.getName() != "portSyms" && attr.getName() != "portLocations" &&
335  attr.getName() != "internalPaths")
336  newModuleAttrs.push_back(attr);
337 
338  SmallVector<Direction> newPortDirections;
339  SmallVector<Attribute> newPortNames;
340  SmallVector<Attribute> newPortTypes;
341  SmallVector<Attribute> newPortSyms;
342  SmallVector<Attribute> newPortLocations;
343  SmallVector<Attribute, 8> newPortAnnotations;
344  SmallVector<Attribute> newInternalPaths;
345 
346  bool hasInternalPaths = false;
347  auto internalPaths = module->getAttrOfType<ArrayAttr>("internalPaths");
348  for (auto p : newPorts) {
349  newPortTypes.push_back(TypeAttr::get(p.type));
350  newPortNames.push_back(p.name);
351  newPortDirections.push_back(p.direction);
352  newPortSyms.push_back(p.sym);
353  newPortLocations.push_back(p.loc);
354  newPortAnnotations.push_back(p.annotations.getArrayAttr());
355  if (internalPaths) {
356  auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
357  newInternalPaths.push_back(internalPath);
358  if (internalPath.getPath())
359  hasInternalPaths = true;
360  }
361  }
362 
363  newModuleAttrs.push_back(NamedAttribute(
364  cache.sPortDirections,
365  direction::packAttribute(module.getContext(), newPortDirections)));
366 
367  newModuleAttrs.push_back(
368  NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
369 
370  newModuleAttrs.push_back(
371  NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
372 
373  newModuleAttrs.push_back(NamedAttribute(
374  cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
375 
376  newModuleAttrs.push_back(NamedAttribute(
377  cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
378 
379  assert(newInternalPaths.empty() ||
380  newInternalPaths.size() == newPorts.size());
381  if (hasInternalPaths) {
382  newModuleAttrs.emplace_back(cache.sInternalPaths,
383  theBuilder.getArrayAttr(newInternalPaths));
384  }
385 
386  // Update the module's attributes.
387  module->setAttrs(newModuleAttrs);
388  module.setPortSymbols(newPortSyms);
389  return success();
390 }
391 
392 static void lowerModuleBody(FModuleOp mod,
393  const DenseMap<StringAttr, PortConversion> &ports) {
394  mod->walk([&](InstanceOp inst) -> void {
395  ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
396  const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
397 
398  // Fix up the Instance
399  SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
400  for (auto p : modPorts) {
401  p.sym = {};
402  // Might need to partially copy stuff from the old instance.
403  p.annotations = AnnotationSet{mod.getContext()};
404  instPorts.push_back(p);
405  }
406  auto annos = inst.getAnnotations();
407  auto newOp = theBuilder.create<InstanceOp>(
408  instPorts, inst.getModuleName(), inst.getName(), inst.getNameKind(),
409  annos.getValue(), inst.getLayers(), inst.getLowerToBind(),
410  inst.getInnerSymAttr());
411 
412  auto oldDict = inst->getDiscardableAttrDictionary();
413  auto newDict = newOp->getDiscardableAttrDictionary();
414  auto oldNames = inst.getPortNamesAttr();
415  SmallVector<NamedAttribute> newAttrs;
416  for (auto na : oldDict)
417  if (!newDict.contains(na.getName()))
418  newOp->setDiscardableAttr(na.getName(), na.getValue());
419 
420  // Connect up the old instance users to the new instance
421  SmallVector<WireOp> bounce(inst.getNumResults());
422  for (auto p : modPorts) {
423  // No change? No bounce wire.
424  if (p.fieldID == 0) {
425  inst.getResult(p.portID).replaceAllUsesWith(
426  newOp.getResult(p.resultID));
427  continue;
428  }
429  if (!bounce[p.portID]) {
430  bounce[p.portID] = theBuilder.create<WireOp>(
431  inst.getResult(p.portID).getType(),
432  theBuilder.getStringAttr(
433  inst.getName() + "." +
434  cast<StringAttr>(oldNames[p.portID]).getValue()));
435  inst.getResult(p.portID).replaceAllUsesWith(
436  bounce[p.portID].getResult());
437  }
438  // Connect up the Instance to the bounce wires
439  if (p.isInput())
440  emitConnect(theBuilder, newOp.getResult(p.resultID),
441  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
442  p.fieldID));
443  else
444  emitConnect(theBuilder,
445  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
446  p.fieldID),
447  newOp.getResult(p.resultID));
448  }
449  // Zero Width ports may have dangling connects since they are not preserved
450  // and do not have bounce wires.
451  for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
452  assert(isa<StrictConnectOp>(use) || isa<ConnectOp>(use));
453  use->erase();
454  }
455  inst->erase();
456  return;
457  });
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // Pass Infrastructure
462 //===----------------------------------------------------------------------===//
463 
464 namespace {
465 struct LowerSignaturesPass : public LowerSignaturesBase<LowerSignaturesPass> {
466  void runOnOperation() override;
467 };
468 } // end anonymous namespace
469 
470 // This is the main entrypoint for the lowering pass.
471 void LowerSignaturesPass::runOnOperation() {
472  LLVM_DEBUG(debugPassHeader(this) << "\n");
473  // Cached attr
474  AttrCache cache(&getContext());
475 
476  DenseMap<StringAttr, PortConversion> portMap;
477  auto circuit = getOperation();
478 
479  for (auto mod : circuit.getOps<FModuleLike>()) {
480  if (lowerModuleSignature(mod, mod.getConvention(), cache,
481  portMap[mod.getNameAttr()])
482  .failed())
483  return signalPassFailure();
484  }
485  parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
486  [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
487 }
488 
489 /// This is the pass constructor.
490 std::unique_ptr<mlir::Pass> circt::firrtl::createLowerSignaturesPass() {
491  return std::make_unique<LowerSignaturesPass>();
492 }
assert(baseType &&"element must be base type")
static InstancePath empty
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
Base class for the port conversion of a particular port.
Definition: PortConverter.h:97
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
std::unique_ptr< mlir::Pass > createLowerSignaturesPass()
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:24
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31
This holds the name and type that describes the module's ports.