CIRCT  20.0.0git
LowerSignatures.cpp
Go to the documentation of this file.
1 //===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the LowerSignatures pass. This pass replaces aggregate
10 // types with expanded values in module arguments as specified by the ABI
11 // information.
12 //
13 //===----------------------------------------------------------------------===//
14 
21 #include "circt/Support/Debug.h"
22 #include "mlir/IR/Threading.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/Support/Debug.h"
25 
26 #define DEBUG_TYPE "firrtl-lower-signatures"
27 
28 namespace circt {
29 namespace firrtl {
30 #define GEN_PASS_DEF_LOWERSIGNATURES
31 #include "circt/Dialect/FIRRTL/Passes.h.inc"
32 } // namespace firrtl
33 } // namespace circt
34 
35 using namespace circt;
36 using namespace firrtl;
37 
38 //===----------------------------------------------------------------------===//
39 // Module Type Lowering
40 //===----------------------------------------------------------------------===//
41 namespace {
42 
43 struct AttrCache {
44  AttrCache(MLIRContext *context) {
45  nameAttr = StringAttr::get(context, "name");
46  sPortDirections = StringAttr::get(context, "portDirections");
47  sPortNames = StringAttr::get(context, "portNames");
48  sPortTypes = StringAttr::get(context, "portTypes");
49  sPortLocations = StringAttr::get(context, "portLocations");
50  sPortAnnotations = StringAttr::get(context, "portAnnotations");
51  sInternalPaths = StringAttr::get(context, "internalPaths");
52  }
53  AttrCache(const AttrCache &) = default;
54 
55  StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
56  sPortAnnotations, sInternalPaths;
57 };
58 
59 struct FieldMapEntry : public PortInfo {
60  size_t portID;
61  size_t resultID;
62  size_t fieldID;
63 };
64 
65 using PortConversion = SmallVector<FieldMapEntry>;
66 
67 template <typename T>
68 class FieldIDSearch {
69  using E = typename T::ElementType;
70  using V = SmallVector<E>;
71 
72 public:
73  using const_iterator = typename V::const_iterator;
74 
75  template <typename Container>
76  FieldIDSearch(const Container &src) {
77  if constexpr (std::is_convertible_v<Container, Attribute>)
78  if (!src)
79  return;
80  for (auto attr : src)
81  vals.push_back(attr);
82  std::sort(vals.begin(), vals.end(), fieldComp);
83  }
84 
85  std::pair<const_iterator, const_iterator> find(uint64_t low,
86  uint64_t high) const {
87  return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
88  std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
89  }
90 
91  bool empty(uint64_t low, uint64_t high) const {
92  auto [b, e] = find(low, high);
93  return b == e;
94  }
95 
96 private:
97  static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
98  return lhs.getFieldID() < rhs.getFieldID();
99  };
100  static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
101  return lhs.getFieldID() < rhs;
102  };
103  static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
104  return lhs < rhs.getFieldID();
105  };
106 
107  V vals;
108 };
109 
110 } // namespace
111 
112 static hw::InnerSymAttr
113 symbolsForFieldIDRange(MLIRContext *ctx,
114  const FieldIDSearch<hw::InnerSymAttr> &syms,
115  uint64_t low, uint64_t high) {
116  auto [b, e] = syms.find(low, high);
117  SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
118  if (newSyms.empty())
119  return {};
120  for (auto &sym : newSyms)
122  ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
123  return hw::InnerSymAttr::get(ctx, newSyms);
124 }
125 
126 static AnnotationSet
127 annosForFieldIDRange(MLIRContext *ctx,
128  const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
129  uint64_t high) {
130  AnnotationSet newAnnos(ctx);
131  auto [b, e] = annos.find(low, high);
132  for (; b != e; ++b)
133  newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
134  return newAnnos;
135 }
136 
137 static LogicalResult
138 computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
139  size_t portID, const PortInfo &port, bool isFlip,
140  Twine name, FIRRTLType type, uint64_t fieldID,
141  const FieldIDSearch<hw::InnerSymAttr> &syms,
142  const FieldIDSearch<AnnotationSet> &annos) {
143  auto *ctx = type.getContext();
145  .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
146  // This should be enhanced to be able to handle bundle<all flips of
147  // passive>, or this should be a canonicalizer
148  if (conv != Convention::Scalarized && bundle.isPassive()) {
149  auto lastId = fieldID + bundle.getMaxFieldID();
150  newPorts.push_back(
151  {{StringAttr::get(ctx, name), type,
152  isFlip ? Direction::Out : Direction::In,
153  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
154  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
155  portID,
156  newPorts.size(),
157  fieldID});
158  } else {
159  for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
160  if (failed(computeLoweringImpl(
161  mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
162  name + "_" + elem.name.getValue(), elem.type,
163  fieldID + bundle.getFieldID(idx), syms, annos)))
164  return failure();
165  if (!syms.empty(fieldID, fieldID))
166  return mod.emitError("Port [")
167  << port.name
168  << "] should be subdivided, but cannot be because of "
169  "symbol ["
170  << port.sym.getSymIfExists(fieldID) << "] on a bundle";
171  if (!annos.empty(fieldID, fieldID)) {
172  auto err = mod.emitError("Port [")
173  << port.name
174  << "] should be subdivided, but cannot be because of "
175  "annotations [";
176  auto [b, e] = annos.find(fieldID, fieldID);
177  err << b->getClass() << "(" << b->getFieldID() << ")";
178  b++;
179  for (; b != e; ++b)
180  err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
181  err << "] on a bundle";
182  return err;
183  }
184  }
185  }
186  return success();
187  })
188  .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
189  if (conv != Convention::Scalarized &&
190  vector.getElementType().isPassive()) {
191  auto lastId = fieldID + vector.getMaxFieldID();
192  newPorts.push_back(
193  {{StringAttr::get(ctx, name), type,
194  isFlip ? Direction::Out : Direction::In,
195  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
196  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
197  portID,
198  newPorts.size(),
199  fieldID});
200  } else {
201  for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
202  if (failed(computeLoweringImpl(
203  mod, newPorts, conv, portID, port, isFlip,
204  name + "_" + Twine(i), vector.getElementType(),
205  fieldID + vector.getFieldID(i), syms, annos)))
206  return failure();
207  if (!syms.empty(fieldID, fieldID))
208  return mod.emitError("Port [")
209  << port.name
210  << "] should be subdivided, but cannot be because of "
211  "symbol ["
212  << port.sym.getSymIfExists(fieldID) << "] on a vector";
213  if (!annos.empty(fieldID, fieldID)) {
214  auto err = mod.emitError("Port [")
215  << port.name
216  << "] should be subdivided, but cannot be because of "
217  "annotations [";
218  auto [b, e] = annos.find(fieldID, fieldID);
219  err << b->getClass();
220  ++b;
221  for (; b != e; ++b)
222  err << ", " << b->getClass();
223  err << "] on a vector";
224  return err;
225  }
226  }
227  }
228  return success();
229  })
230  .Case<FEnumType>([&](FEnumType fenum) { return failure(); })
231  .Default([&](FIRRTLType type) {
232  // Properties and other types wind up here.
233  newPorts.push_back(
234  {{StringAttr::get(ctx, name), type,
235  isFlip ? Direction::Out : Direction::In,
236  symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
237  annosForFieldIDRange(ctx, annos, fieldID, fieldID)},
238  portID,
239  newPorts.size(),
240  fieldID});
241  return success();
242  });
243 }
244 
245 // compute a new moduletype from an old module type and lowering convention.
246 // Also compute a fieldID map from port, fieldID -> port
247 static LogicalResult computeLowering(FModuleLike mod, Convention conv,
248  PortConversion &newPorts) {
249  for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
250  if (failed(computeLoweringImpl(
251  mod, newPorts, conv, idx, port, port.direction == Direction::Out,
252  port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
253  FieldIDSearch<hw::InnerSymAttr>(port.sym),
254  FieldIDSearch<AnnotationSet>(port.annotations))))
255  return failure();
256  }
257  return success();
258 }
259 
260 static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
261  AttrCache &cache,
262  PortConversion &newPorts) {
263  ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
264  if (computeLowering(module, conv, newPorts).failed())
265  return failure();
266  if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
267  Block *body = mod.getBodyBlock();
268  theBuilder.setInsertionPointToStart(body);
269  auto oldNumArgs = body->getNumArguments();
270 
271  // Compute the replacement value for old arguments
272  // This creates all the new arguments and produces bounce wires when
273  // necessary
274  SmallVector<Value> bounceWires(oldNumArgs);
275  for (auto &p : newPorts) {
276  auto newArg = body->addArgument(p.type, p.loc);
277  // Get or create a bounce wire for changed ports
278  // For unmodified ports, move the uses to the replacement port
279  if (p.fieldID != 0) {
280  auto &wire = bounceWires[p.portID];
281  if (!wire)
282  wire = theBuilder
283  .create<WireOp>(module.getPortType(p.portID),
284  module.getPortNameAttr(p.portID),
285  NameKindEnum::InterestingName)
286  .getResult();
287  } else {
288  bounceWires[p.portID] = newArg;
289  }
290  }
291  // replace old arguments. Somethings get dropped completely, like
292  // zero-length vectors.
293  for (auto idx = 0U; idx < oldNumArgs; ++idx) {
294  if (!bounceWires[idx]) {
295  bounceWires[idx] = theBuilder
296  .create<WireOp>(module.getPortType(idx),
297  module.getPortNameAttr(idx))
298  .getResult();
299  }
300  body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
301  }
302 
303  // Goodbye old ports, now ResultID in the PortInfo is correct.
304  body->eraseArguments(0, oldNumArgs);
305 
306  // Connect the bounce wires to the new arguments
307  for (auto &p : newPorts) {
308  if (isa<BlockArgument>(bounceWires[p.portID]))
309  continue;
310  if (p.isOutput())
311  emitConnect(
312  theBuilder, body->getArgument(p.resultID),
313  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
314  else
315  emitConnect(
316  theBuilder,
317  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
318  body->getArgument(p.resultID));
319  }
320  }
321 
322  SmallVector<NamedAttribute, 8> newModuleAttrs;
323 
324  // Copy over any attributes that weren't original argument attributes.
325  for (auto attr : module->getAttrDictionary())
326  // Drop old "portNames", directions, and argument attributes. These are
327  // handled differently below.
328  if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
329  attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
330  attr.getName() != "portSymbols" && attr.getName() != "portLocations" &&
331  attr.getName() != "internalPaths")
332  newModuleAttrs.push_back(attr);
333 
334  SmallVector<Direction> newPortDirections;
335  SmallVector<Attribute> newPortNames;
336  SmallVector<Attribute> newPortTypes;
337  SmallVector<Attribute> newPortSyms;
338  SmallVector<Attribute> newPortLocations;
339  SmallVector<Attribute, 8> newPortAnnotations;
340  SmallVector<Attribute> newInternalPaths;
341 
342  bool hasInternalPaths = false;
343  auto internalPaths = module->getAttrOfType<ArrayAttr>("internalPaths");
344  for (auto p : newPorts) {
345  newPortTypes.push_back(TypeAttr::get(p.type));
346  newPortNames.push_back(p.name);
347  newPortDirections.push_back(p.direction);
348  newPortSyms.push_back(p.sym);
349  newPortLocations.push_back(p.loc);
350  newPortAnnotations.push_back(p.annotations.getArrayAttr());
351  if (internalPaths) {
352  auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
353  newInternalPaths.push_back(internalPath);
354  if (internalPath.getPath())
355  hasInternalPaths = true;
356  }
357  }
358 
359  newModuleAttrs.push_back(NamedAttribute(
360  cache.sPortDirections,
361  direction::packAttribute(module.getContext(), newPortDirections)));
362 
363  newModuleAttrs.push_back(
364  NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
365 
366  newModuleAttrs.push_back(
367  NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
368 
369  newModuleAttrs.push_back(NamedAttribute(
370  cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
371 
372  newModuleAttrs.push_back(NamedAttribute(
373  cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
374 
375  assert(newInternalPaths.empty() ||
376  newInternalPaths.size() == newPorts.size());
377  if (hasInternalPaths) {
378  newModuleAttrs.emplace_back(cache.sInternalPaths,
379  theBuilder.getArrayAttr(newInternalPaths));
380  }
381 
382  // Update the module's attributes.
383  module->setAttrs(newModuleAttrs);
384  FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
385  module.setPortSymbols(newPortSyms);
386  return success();
387 }
388 
389 static void lowerModuleBody(FModuleOp mod,
390  const DenseMap<StringAttr, PortConversion> &ports) {
391  mod->walk([&](InstanceOp inst) -> void {
392  ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
393  const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
394 
395  // Fix up the Instance
396  SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
397  for (auto p : modPorts) {
398  p.sym = {};
399  // Might need to partially copy stuff from the old instance.
400  p.annotations = AnnotationSet{mod.getContext()};
401  instPorts.push_back(p);
402  }
403  auto annos = inst.getAnnotations();
404  auto newOp = theBuilder.create<InstanceOp>(
405  instPorts, inst.getModuleName(), inst.getName(), inst.getNameKind(),
406  annos.getValue(), inst.getLayers(), inst.getLowerToBind(),
407  inst.getInnerSymAttr());
408 
409  auto oldDict = inst->getDiscardableAttrDictionary();
410  auto newDict = newOp->getDiscardableAttrDictionary();
411  auto oldNames = inst.getPortNamesAttr();
412  SmallVector<NamedAttribute> newAttrs;
413  for (auto na : oldDict)
414  if (!newDict.contains(na.getName()))
415  newOp->setDiscardableAttr(na.getName(), na.getValue());
416 
417  // Connect up the old instance users to the new instance
418  SmallVector<WireOp> bounce(inst.getNumResults());
419  for (auto p : modPorts) {
420  // No change? No bounce wire.
421  if (p.fieldID == 0) {
422  inst.getResult(p.portID).replaceAllUsesWith(
423  newOp.getResult(p.resultID));
424  continue;
425  }
426  if (!bounce[p.portID]) {
427  bounce[p.portID] = theBuilder.create<WireOp>(
428  inst.getResult(p.portID).getType(),
429  theBuilder.getStringAttr(
430  inst.getName() + "." +
431  cast<StringAttr>(oldNames[p.portID]).getValue()));
432  inst.getResult(p.portID).replaceAllUsesWith(
433  bounce[p.portID].getResult());
434  }
435  // Connect up the Instance to the bounce wires
436  if (p.isInput())
437  emitConnect(theBuilder, newOp.getResult(p.resultID),
438  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
439  p.fieldID));
440  else
441  emitConnect(theBuilder,
442  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
443  p.fieldID),
444  newOp.getResult(p.resultID));
445  }
446  // Zero Width ports may have dangling connects since they are not preserved
447  // and do not have bounce wires.
448  for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
449  assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
450  use->erase();
451  }
452  inst->erase();
453  return;
454  });
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // Pass Infrastructure
459 //===----------------------------------------------------------------------===//
460 
461 namespace {
462 struct LowerSignaturesPass
463  : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
464  void runOnOperation() override;
465 };
466 } // end anonymous namespace
467 
468 // This is the main entrypoint for the lowering pass.
469 void LowerSignaturesPass::runOnOperation() {
470  LLVM_DEBUG(debugPassHeader(this) << "\n");
471  // Cached attr
472  AttrCache cache(&getContext());
473 
474  DenseMap<StringAttr, PortConversion> portMap;
475  auto circuit = getOperation();
476 
477  for (auto mod : circuit.getOps<FModuleLike>()) {
478  if (lowerModuleSignature(mod, mod.getConvention(), cache,
479  portMap[mod.getNameAttr()])
480  .failed())
481  return signalPassFailure();
482  }
483  parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
484  [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
485 }
486 
487 /// This is the pass constructor.
488 std::unique_ptr<mlir::Pass> circt::firrtl::createLowerSignaturesPass() {
489  return std::make_unique<LowerSignaturesPass>();
490 }
assert(baseType &&"element must be base type")
static InstancePath empty
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
Base class for the port conversion of a particular port.
Definition: PortConverter.h:97
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
std::unique_ptr< mlir::Pass > createLowerSignaturesPass()
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31
This holds the name and type that describes the module's ports.