CIRCT  20.0.0git
LowerSignatures.cpp
Go to the documentation of this file.
1 //===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the LowerSignatures pass. This pass replaces aggregate
10 // types with expanded values in module arguments as specified by the ABI
11 // information.
12 //
13 //===----------------------------------------------------------------------===//
14 
17 #include "mlir/Pass/Pass.h"
18 
30 #include "circt/Dialect/SV/SVOps.h"
31 #include "circt/Support/Debug.h"
32 #include "mlir/IR/ImplicitLocOpBuilder.h"
33 #include "mlir/IR/Threading.h"
34 #include "llvm/ADT/APSInt.h"
35 #include "llvm/ADT/BitVector.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/Parallel.h"
38 
39 #define DEBUG_TYPE "firrtl-lower-signatures"
40 
41 namespace circt {
42 namespace firrtl {
43 #define GEN_PASS_DEF_LOWERSIGNATURES
44 #include "circt/Dialect/FIRRTL/Passes.h.inc"
45 } // namespace firrtl
46 } // namespace circt
47 
48 using namespace circt;
49 using namespace firrtl;
50 
51 //===----------------------------------------------------------------------===//
52 // Module Type Lowering
53 //===----------------------------------------------------------------------===//
54 namespace {
55 
56 struct AttrCache {
57  AttrCache(MLIRContext *context) {
58  nameAttr = StringAttr::get(context, "name");
59  sPortDirections = StringAttr::get(context, "portDirections");
60  sPortNames = StringAttr::get(context, "portNames");
61  sPortTypes = StringAttr::get(context, "portTypes");
62  sPortLocations = StringAttr::get(context, "portLocations");
63  sPortAnnotations = StringAttr::get(context, "portAnnotations");
64  sInternalPaths = StringAttr::get(context, "internalPaths");
65  }
66  AttrCache(const AttrCache &) = default;
67 
68  StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
69  sPortAnnotations, sInternalPaths;
70 };
71 
72 struct FieldMapEntry : public PortInfo {
73  size_t portID;
74  size_t resultID;
75  size_t fieldID;
76 };
77 
78 using PortConversion = SmallVector<FieldMapEntry>;
79 
80 template <typename T>
81 class FieldIDSearch {
82  using E = typename T::ElementType;
83  using V = SmallVector<E>;
84 
85 public:
86  using const_iterator = typename V::const_iterator;
87 
88  template <typename Container>
89  FieldIDSearch(const Container &src) {
90  if constexpr (std::is_convertible_v<Container, Attribute>)
91  if (!src)
92  return;
93  for (auto attr : src)
94  vals.push_back(attr);
95  std::sort(vals.begin(), vals.end(), fieldComp);
96  }
97 
98  std::pair<const_iterator, const_iterator> find(uint64_t low,
99  uint64_t high) const {
100  return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
101  std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
102  }
103 
104  bool empty(uint64_t low, uint64_t high) const {
105  auto [b, e] = find(low, high);
106  return b == e;
107  }
108 
109 private:
110  static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
111  return lhs.getFieldID() < rhs.getFieldID();
112  };
113  static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
114  return lhs.getFieldID() < rhs;
115  };
116  static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
117  return lhs < rhs.getFieldID();
118  };
119 
120  V vals;
121 };
122 
123 } // namespace
124 
125 static hw::InnerSymAttr
126 symbolsForFieldIDRange(MLIRContext *ctx,
127  const FieldIDSearch<hw::InnerSymAttr> &syms,
128  uint64_t low, uint64_t high) {
129  auto [b, e] = syms.find(low, high);
130  SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
131  if (newSyms.empty())
132  return {};
133  for (auto &sym : newSyms)
135  ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
136  return hw::InnerSymAttr::get(ctx, newSyms);
137 }
138 
139 static AnnotationSet
140 annosForFieldIDRange(MLIRContext *ctx,
141  const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
142  uint64_t high) {
143  AnnotationSet newAnnos(ctx);
144  auto [b, e] = annos.find(low, high);
145  for (; b != e; ++b)
146  newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
147  return newAnnos;
148 }
149 
150 static LogicalResult
151 computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
152  size_t portID, const PortInfo &port, bool isFlip,
153  Twine name, FIRRTLType type, uint64_t fieldID,
154  const FieldIDSearch<hw::InnerSymAttr> &syms,
155  const FieldIDSearch<AnnotationSet> &annos) {
156  auto *ctx = type.getContext();
158  .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
159  // This should be enhanced to be able to handle bundle<all flips of
160  // passive>, or this should be a canonicalizer
161  if (conv != Convention::Scalarized && bundle.isPassive()) {
162  auto lastId = fieldID + bundle.getMaxFieldID();
163  newPorts.push_back(
164  {{StringAttr::get(ctx, name), type,
165  isFlip ? Direction::Out : Direction::In,
166  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
167  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
168  portID,
169  newPorts.size(),
170  fieldID});
171  } else {
172  for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
173  if (failed(computeLoweringImpl(
174  mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
175  name + "_" + elem.name.getValue(), elem.type,
176  fieldID + bundle.getFieldID(idx), syms, annos)))
177  return failure();
178  if (!syms.empty(fieldID, fieldID))
179  return mod.emitError("Port [")
180  << port.name
181  << "] should be subdivided, but cannot be because of "
182  "symbol ["
183  << port.sym.getSymIfExists(fieldID) << "] on a bundle";
184  if (!annos.empty(fieldID, fieldID)) {
185  auto err = mod.emitError("Port [")
186  << port.name
187  << "] should be subdivided, but cannot be because of "
188  "annotations [";
189  auto [b, e] = annos.find(fieldID, fieldID);
190  err << b->getClass() << "(" << b->getFieldID() << ")";
191  b++;
192  for (; b != e; ++b)
193  err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
194  err << "] on a bundle";
195  return err;
196  }
197  }
198  }
199  return success();
200  })
201  .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
202  if (conv != Convention::Scalarized &&
203  vector.getElementType().isPassive()) {
204  auto lastId = fieldID + vector.getMaxFieldID();
205  newPorts.push_back(
206  {{StringAttr::get(ctx, name), type,
207  isFlip ? Direction::Out : Direction::In,
208  symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
209  annosForFieldIDRange(ctx, annos, fieldID, lastId)},
210  portID,
211  newPorts.size(),
212  fieldID});
213  } else {
214  for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
215  if (failed(computeLoweringImpl(
216  mod, newPorts, conv, portID, port, isFlip,
217  name + "_" + Twine(i), vector.getElementType(),
218  fieldID + vector.getFieldID(i), syms, annos)))
219  return failure();
220  if (!syms.empty(fieldID, fieldID))
221  return mod.emitError("Port [")
222  << port.name
223  << "] should be subdivided, but cannot be because of "
224  "symbol ["
225  << port.sym.getSymIfExists(fieldID) << "] on a vector";
226  if (!annos.empty(fieldID, fieldID)) {
227  auto err = mod.emitError("Port [")
228  << port.name
229  << "] should be subdivided, but cannot be because of "
230  "annotations [";
231  auto [b, e] = annos.find(fieldID, fieldID);
232  err << b->getClass();
233  ++b;
234  for (; b != e; ++b)
235  err << ", " << b->getClass();
236  err << "] on a vector";
237  return err;
238  }
239  }
240  }
241  return success();
242  })
243  .Case<FEnumType>([&](FEnumType fenum) { return failure(); })
244  .Default([&](FIRRTLType type) {
245  // Properties and other types wind up here.
246  newPorts.push_back(
247  {{StringAttr::get(ctx, name), type,
248  isFlip ? Direction::Out : Direction::In,
249  symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
250  annosForFieldIDRange(ctx, annos, fieldID, fieldID)},
251  portID,
252  newPorts.size(),
253  fieldID});
254  return success();
255  });
256 }
257 
258 // compute a new moduletype from an old module type and lowering convention.
259 // Also compute a fieldID map from port, fieldID -> port
260 static LogicalResult computeLowering(FModuleLike mod, Convention conv,
261  PortConversion &newPorts) {
262  for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
263  if (failed(computeLoweringImpl(
264  mod, newPorts, conv, idx, port, port.direction == Direction::Out,
265  port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
266  FieldIDSearch<hw::InnerSymAttr>(port.sym),
267  FieldIDSearch<AnnotationSet>(port.annotations))))
268  return failure();
269  }
270  return success();
271 }
272 
273 static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
274  AttrCache &cache,
275  PortConversion &newPorts) {
276  ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
277  if (computeLowering(module, conv, newPorts).failed())
278  return failure();
279  if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
280  Block *body = mod.getBodyBlock();
281  theBuilder.setInsertionPointToStart(body);
282  auto oldNumArgs = body->getNumArguments();
283 
284  // Compute the replacement value for old arguments
285  // This creates all the new arguments and produces bounce wires when
286  // necessary
287  SmallVector<Value> bounceWires(oldNumArgs);
288  for (auto &p : newPorts) {
289  auto newArg = body->addArgument(p.type, p.loc);
290  // Get or create a bounce wire for changed ports
291  // For unmodified ports, move the uses to the replacement port
292  if (p.fieldID != 0) {
293  auto &wire = bounceWires[p.portID];
294  if (!wire)
295  wire = theBuilder
296  .create<WireOp>(module.getPortType(p.portID),
297  module.getPortNameAttr(p.portID),
298  NameKindEnum::InterestingName)
299  .getResult();
300  } else {
301  bounceWires[p.portID] = newArg;
302  }
303  }
304  // replace old arguments. Somethings get dropped completely, like
305  // zero-length vectors.
306  for (auto idx = 0U; idx < oldNumArgs; ++idx) {
307  if (!bounceWires[idx]) {
308  bounceWires[idx] = theBuilder
309  .create<WireOp>(module.getPortType(idx),
310  module.getPortNameAttr(idx))
311  .getResult();
312  }
313  body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
314  }
315 
316  // Goodbye old ports, now ResultID in the PortInfo is correct.
317  body->eraseArguments(0, oldNumArgs);
318 
319  // Connect the bounce wires to the new arguments
320  for (auto &p : newPorts) {
321  if (isa<BlockArgument>(bounceWires[p.portID]))
322  continue;
323  if (p.isOutput())
324  emitConnect(
325  theBuilder, body->getArgument(p.resultID),
326  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
327  else
328  emitConnect(
329  theBuilder,
330  getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
331  body->getArgument(p.resultID));
332  }
333  }
334 
335  SmallVector<NamedAttribute, 8> newModuleAttrs;
336 
337  // Copy over any attributes that weren't original argument attributes.
338  for (auto attr : module->getAttrDictionary())
339  // Drop old "portNames", directions, and argument attributes. These are
340  // handled differently below.
341  if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
342  attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
343  attr.getName() != "portSymbols" && attr.getName() != "portLocations" &&
344  attr.getName() != "internalPaths")
345  newModuleAttrs.push_back(attr);
346 
347  SmallVector<Direction> newPortDirections;
348  SmallVector<Attribute> newPortNames;
349  SmallVector<Attribute> newPortTypes;
350  SmallVector<Attribute> newPortSyms;
351  SmallVector<Attribute> newPortLocations;
352  SmallVector<Attribute, 8> newPortAnnotations;
353  SmallVector<Attribute> newInternalPaths;
354 
355  bool hasInternalPaths = false;
356  auto internalPaths = module->getAttrOfType<ArrayAttr>("internalPaths");
357  for (auto p : newPorts) {
358  newPortTypes.push_back(TypeAttr::get(p.type));
359  newPortNames.push_back(p.name);
360  newPortDirections.push_back(p.direction);
361  newPortSyms.push_back(p.sym);
362  newPortLocations.push_back(p.loc);
363  newPortAnnotations.push_back(p.annotations.getArrayAttr());
364  if (internalPaths) {
365  auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
366  newInternalPaths.push_back(internalPath);
367  if (internalPath.getPath())
368  hasInternalPaths = true;
369  }
370  }
371 
372  newModuleAttrs.push_back(NamedAttribute(
373  cache.sPortDirections,
374  direction::packAttribute(module.getContext(), newPortDirections)));
375 
376  newModuleAttrs.push_back(
377  NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
378 
379  newModuleAttrs.push_back(
380  NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
381 
382  newModuleAttrs.push_back(NamedAttribute(
383  cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
384 
385  newModuleAttrs.push_back(NamedAttribute(
386  cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
387 
388  assert(newInternalPaths.empty() ||
389  newInternalPaths.size() == newPorts.size());
390  if (hasInternalPaths) {
391  newModuleAttrs.emplace_back(cache.sInternalPaths,
392  theBuilder.getArrayAttr(newInternalPaths));
393  }
394 
395  // Update the module's attributes.
396  module->setAttrs(newModuleAttrs);
397  FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
398  module.setPortSymbols(newPortSyms);
399  return success();
400 }
401 
402 static void lowerModuleBody(FModuleOp mod,
403  const DenseMap<StringAttr, PortConversion> &ports) {
404  mod->walk([&](InstanceOp inst) -> void {
405  ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
406  const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
407 
408  // Fix up the Instance
409  SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
410  for (auto p : modPorts) {
411  p.sym = {};
412  // Might need to partially copy stuff from the old instance.
413  p.annotations = AnnotationSet{mod.getContext()};
414  instPorts.push_back(p);
415  }
416  auto annos = inst.getAnnotations();
417  auto newOp = theBuilder.create<InstanceOp>(
418  instPorts, inst.getModuleName(), inst.getName(), inst.getNameKind(),
419  annos.getValue(), inst.getLayers(), inst.getLowerToBind(),
420  inst.getInnerSymAttr());
421 
422  auto oldDict = inst->getDiscardableAttrDictionary();
423  auto newDict = newOp->getDiscardableAttrDictionary();
424  auto oldNames = inst.getPortNamesAttr();
425  SmallVector<NamedAttribute> newAttrs;
426  for (auto na : oldDict)
427  if (!newDict.contains(na.getName()))
428  newOp->setDiscardableAttr(na.getName(), na.getValue());
429 
430  // Connect up the old instance users to the new instance
431  SmallVector<WireOp> bounce(inst.getNumResults());
432  for (auto p : modPorts) {
433  // No change? No bounce wire.
434  if (p.fieldID == 0) {
435  inst.getResult(p.portID).replaceAllUsesWith(
436  newOp.getResult(p.resultID));
437  continue;
438  }
439  if (!bounce[p.portID]) {
440  bounce[p.portID] = theBuilder.create<WireOp>(
441  inst.getResult(p.portID).getType(),
442  theBuilder.getStringAttr(
443  inst.getName() + "." +
444  cast<StringAttr>(oldNames[p.portID]).getValue()));
445  inst.getResult(p.portID).replaceAllUsesWith(
446  bounce[p.portID].getResult());
447  }
448  // Connect up the Instance to the bounce wires
449  if (p.isInput())
450  emitConnect(theBuilder, newOp.getResult(p.resultID),
451  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
452  p.fieldID));
453  else
454  emitConnect(theBuilder,
455  getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
456  p.fieldID),
457  newOp.getResult(p.resultID));
458  }
459  // Zero Width ports may have dangling connects since they are not preserved
460  // and do not have bounce wires.
461  for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
462  assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
463  use->erase();
464  }
465  inst->erase();
466  return;
467  });
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // Pass Infrastructure
472 //===----------------------------------------------------------------------===//
473 
474 namespace {
475 struct LowerSignaturesPass
476  : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
477  void runOnOperation() override;
478 };
479 } // end anonymous namespace
480 
481 // This is the main entrypoint for the lowering pass.
482 void LowerSignaturesPass::runOnOperation() {
483  LLVM_DEBUG(debugPassHeader(this) << "\n");
484  // Cached attr
485  AttrCache cache(&getContext());
486 
487  DenseMap<StringAttr, PortConversion> portMap;
488  auto circuit = getOperation();
489 
490  for (auto mod : circuit.getOps<FModuleLike>()) {
491  if (lowerModuleSignature(mod, mod.getConvention(), cache,
492  portMap[mod.getNameAttr()])
493  .failed())
494  return signalPassFailure();
495  }
496  parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
497  [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
498 }
499 
500 /// This is the pass constructor.
501 std::unique_ptr<mlir::Pass> circt::firrtl::createLowerSignaturesPass() {
502  return std::make_unique<LowerSignaturesPass>();
503 }
assert(baseType &&"element must be base type")
static InstancePath empty
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
Base class for the port conversion of a particular port.
Definition: PortConverter.h:97
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
std::unique_ptr< mlir::Pass > createLowerSignaturesPass()
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31
This holds the name and type that describes the module's ports.