1 //===- LowerMemory.cpp - Lower Memories -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
// This file defines the LowerMemory pass.
9 //
10 //===----------------------------------------------------------------------===//
18 #include "circt/Dialect/HW/HWOps.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/STLFunctionalExtras.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/Support/Parallel.h"
28 #include <optional>
29 #include <set>
31 namespace circt {
32 namespace firrtl {
34 #include "circt/Dialect/FIRRTL/"
35 } // namespace firrtl
36 } // namespace circt
38 using namespace circt;
39 using namespace firrtl;
41 // Extract all the relevant attributes from the MemOp and return the FirMemory.
42 FirMemory getSummary(MemOp op) {
43  size_t numReadPorts = 0;
44  size_t numWritePorts = 0;
45  size_t numReadWritePorts = 0;
47  SmallVector<int32_t> writeClockIDs;
49  for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
50  auto portKind = op.getPortKind(i);
51  if (portKind == MemOp::PortKind::Read)
52  ++numReadPorts;
53  else if (portKind == MemOp::PortKind::Write) {
54  for (auto *a : op.getResult(i).getUsers()) {
55  auto subfield = dyn_cast<SubfieldOp>(a);
56  if (!subfield || subfield.getFieldIndex() != 2)
57  continue;
58  auto clockPort = a->getResult(0);
59  for (auto *b : clockPort.getUsers()) {
60  if (auto connect = dyn_cast<FConnectLike>(b)) {
61  if (connect.getDest() == clockPort) {
62  auto result =
63  clockToLeader.insert({connect.getSrc(), numWritePorts});
64  if (result.second) {
65  writeClockIDs.push_back(numWritePorts);
66  } else {
67  writeClockIDs.push_back(result.first->second);
68  }
69  }
70  }
71  }
72  break;
73  }
74  ++numWritePorts;
75  } else
76  ++numReadWritePorts;
77  }
79  auto width = op.getDataType().getBitWidthOrSentinel();
80  if (width <= 0) {
81  op.emitError("'firrtl.mem' should have simple type and known width");
82  width = 0;
83  }
84  return {numReadPorts,
85  numWritePorts,
86  numReadWritePorts,
87  (size_t)width,
88  op.getDepth(),
89  op.getReadLatency(),
90  op.getWriteLatency(),
91  op.getMaskBits(),
92  *seq::symbolizeRUW(unsigned(op.getRuw())),
93  seq::WUW::PortOrder,
94  writeClockIDs,
95  op.getNameAttr(),
96  op.getMaskBits() > 1,
97  op.getInitAttr(),
98  op.getPrefixAttr(),
99  op.getLoc()};
100 }
102 namespace {
103 struct LowerMemoryPass
104  : public circt::firrtl::impl::LowerMemoryBase<LowerMemoryPass> {
106  /// Get the cached namespace for a module.
107  hw::InnerSymbolNamespace &getModuleNamespace(FModuleLike module) {
108  return moduleNamespaces.try_emplace(module, module).first->second;
109  }
111  SmallVector<PortInfo> getMemoryModulePorts(const FirMemory &mem);
112  FMemModuleOp emitMemoryModule(MemOp op, const FirMemory &summary,
113  const SmallVectorImpl<PortInfo> &ports);
114  FMemModuleOp getOrCreateMemModule(MemOp op, const FirMemory &summary,
115  const SmallVectorImpl<PortInfo> &ports,
116  bool shouldDedup);
117  FModuleOp createWrapperModule(MemOp op, const FirMemory &summary,
118  bool shouldDedup);
119  InstanceOp emitMemoryInstance(MemOp op, FModuleOp module,
120  const FirMemory &summary);
121  void lowerMemory(MemOp mem, const FirMemory &summary, bool shouldDedup);
122  LogicalResult runOnModule(FModuleOp module, bool shouldDedup);
123  void runOnOperation() override;
125  /// Cached module namespaces.
126  DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
127  CircuitNamespace circuitNamespace;
128  SymbolTable *symbolTable;
130  /// The set of all memories seen so far. This is used to "deduplicate"
131  /// memories by emitting modules one module for equivalent memories.
132  std::map<FirMemory, FMemModuleOp> memories;
133 };
134 } // end anonymous namespace
136 SmallVector<PortInfo>
137 LowerMemoryPass::getMemoryModulePorts(const FirMemory &mem) {
138  auto *context = &getContext();
140  // We don't need a single bit mask, it can be combined with enable. Create
141  // an unmasked memory if maskBits = 1.
142  FIRRTLType u1Type = UIntType::get(context, 1);
143  FIRRTLType dataType = UIntType::get(context, mem.dataWidth);
144  FIRRTLType maskType = UIntType::get(context, mem.maskBits);
145  FIRRTLType addrType =
146  UIntType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
147  FIRRTLType clockType = ClockType::get(context);
148  Location loc = UnknownLoc::get(context);
149  AnnotationSet annotations = AnnotationSet(context);
151  SmallVector<PortInfo> ports;
152  auto addPort = [&](const Twine &name, FIRRTLType type, Direction direction) {
153  auto nameAttr = StringAttr::get(context, name);
154  ports.push_back(
155  {nameAttr, type, direction, hw::InnerSymAttr{}, loc, annotations});
156  };
158  auto makePortCommon = [&](StringRef prefix, size_t idx, FIRRTLType addrType) {
159  addPort(prefix + Twine(idx) + "_addr", addrType, Direction::In);
160  addPort(prefix + Twine(idx) + "_en", u1Type, Direction::In);
161  addPort(prefix + Twine(idx) + "_clk", clockType, Direction::In);
162  };
164  for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
165  makePortCommon("R", i, addrType);
166  addPort("R" + Twine(i) + "_data", dataType, Direction::Out);
167  }
168  for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
169  makePortCommon("RW", i, addrType);
170  addPort("RW" + Twine(i) + "_wmode", u1Type, Direction::In);
171  addPort("RW" + Twine(i) + "_wdata", dataType, Direction::In);
172  addPort("RW" + Twine(i) + "_rdata", dataType, Direction::Out);
173  // Ignore mask port, if maskBits =1
174  if (mem.isMasked)
175  addPort("RW" + Twine(i) + "_wmask", maskType, Direction::In);
176  }
178  for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
179  makePortCommon("W", i, addrType);
180  addPort("W" + Twine(i) + "_data", dataType, Direction::In);
181  // Ignore mask port, if maskBits =1
182  if (mem.isMasked)
183  addPort("W" + Twine(i) + "_mask", maskType, Direction::In);
184  }
186  return ports;
187 }
189 FMemModuleOp
190 LowerMemoryPass::emitMemoryModule(MemOp op, const FirMemory &mem,
191  const SmallVectorImpl<PortInfo> &ports) {
192  // Get a non-colliding name for the memory module, and update the summary.
193  auto newName = circuitNamespace.newName(mem.modName.getValue(), "ext");
194  auto moduleName = StringAttr::get(&getContext(), newName);
196  // Insert the memory module at the bottom of the circuit.
197  auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
198  ++numCreatedMemModules;
199  auto moduleOp = b.create<FMemModuleOp>(
200  mem.loc, moduleName, ports, mem.numReadPorts, mem.numWritePorts,
201  mem.numReadWritePorts, mem.dataWidth, mem.maskBits, mem.readLatency,
202  mem.writeLatency, mem.depth);
203  SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
204  return moduleOp;
205 }
207 FMemModuleOp
208 LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,
209  const SmallVectorImpl<PortInfo> &ports,
210  bool shouldDedup) {
211  // Try to find a matching memory blackbox that we already created. If
212  // shouldDedup is true, we will just generate a new memory module.
213  if (shouldDedup) {
214  auto it = memories.find(summary);
215  if (it != memories.end())
216  return it->second;
217  }
219  // Create a new module for this memory. This can update the name recorded in
220  // the memory's summary.
221  auto module = emitMemoryModule(op, summary, ports);
223  // Record the memory module. We don't want to use this module for other
224  // memories, then we don't add it to the table.
225  if (shouldDedup)
226  memories[summary] = module;
228  return module;
229 }
231 void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
232  bool shouldDedup) {
233  auto *context = &getContext();
234  auto ports = getMemoryModulePorts(summary);
236  // Get a non-colliding name for the memory module, and update the summary.
237  auto newName = circuitNamespace.newName(mem.getName());
238  auto wrapperName = StringAttr::get(&getContext(), newName);
240  // Create the wrapper module, inserting it into the bottom of the circuit.
241  auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
242  auto wrapper = b.create<FModuleOp>(
243  mem->getLoc(), wrapperName,
244  ConventionAttr::get(context, Convention::Internal), ports);
245  SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
247  // Create an instance of the external memory module. The instance has the
248  // same name as the target module.
249  auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
250  b.setInsertionPointToStart(wrapper.getBodyBlock());
252  auto memInst =
253  b.create<InstanceOp>(mem->getLoc(), memModule, memModule.getModuleName(),
254  mem.getNameKind(), mem.getAnnotations().getValue());
256  // Wire all the ports together.
257  for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
258  memInst.getResults())) {
259  if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
260  b.create<MatchingConnectOp>(mem->getLoc(), dst, src);
261  else
262  b.create<MatchingConnectOp>(mem->getLoc(), src, dst);
263  }
265  // Create an instance of the wrapper memory module, which will replace the
266  // original mem op.
267  auto inst = emitMemoryInstance(mem, wrapper, summary);
269  // We fixup the annotations here. We will be copying all annotations on to the
270  // module op, so we have to fix up the NLA to have the module as the leaf
271  // element.
273  auto leafSym = memModule.getModuleNameAttr();
274  auto leafAttr = FlatSymbolRefAttr::get(wrapper.getModuleNameAttr());
276  // NLAs that we have already processed.
278  auto nonlocalAttr = StringAttr::get(context, "circt.nonlocal");
279  bool nlaUpdated = false;
280  SmallVector<Annotation> newMemModAnnos;
281  OpBuilder nlaBuilder(context);
283  AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
284  // We're only looking for non-local annotations.
285  auto nlaSym = anno.getMember<FlatSymbolRefAttr>(nonlocalAttr);
286  if (!nlaSym)
287  return false;
288  // If we have already seen this NLA, don't re-process it.
289  auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
290  StringAttr newNLAName;
291  if (newNLAIter == processedNLAs.end()) {
293  // Update the NLA path to have the additional wrapper module.
294  auto nla =
295  dyn_cast<hw::HierPathOp>(symbolTable->lookup(nlaSym.getAttr()));
296  auto namepath = nla.getNamepath().getValue();
297  SmallVector<Attribute> newNamepath(namepath.begin(), namepath.end());
298  if (!nla.isComponent())
299  newNamepath.back() =
300  getInnerRefTo(inst, [&](auto mod) -> hw::InnerSymbolNamespace & {
301  return getModuleNamespace(mod);
302  });
303  newNamepath.push_back(leafAttr);
305  nlaBuilder.setInsertionPointAfter(nla);
306  auto newNLA = cast<hw::HierPathOp>(nlaBuilder.clone(*nla));
307  newNLA.setSymNameAttr(StringAttr::get(
308  context, circuitNamespace.newName(nla.getNameAttr().getValue())));
309  newNLA.setNamepathAttr(ArrayAttr::get(context, newNamepath));
310  newNLAName = newNLA.getNameAttr();
311  processedNLAs[nlaSym.getAttr()] = newNLAName;
312  } else
313  newNLAName = newNLAIter->getSecond();
314  anno.setMember("circt.nonlocal", FlatSymbolRefAttr::get(newNLAName));
315  nlaUpdated = true;
316  newMemModAnnos.push_back(anno);
317  return true;
318  });
319  if (nlaUpdated) {
320  memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym));
321  AnnotationSet newAnnos(memInst);
322  newAnnos.addAnnotations(newMemModAnnos);
323  newAnnos.applyToOperation(memInst);
324  }
325  mem->erase();
326  ++numLoweredMems;
327 }
329 static SmallVector<SubfieldOp> getAllFieldAccesses(Value structValue,
330  StringRef field) {
331  SmallVector<SubfieldOp> accesses;
332  for (auto *op : structValue.getUsers()) {
333  assert(isa<SubfieldOp>(op));
334  auto fieldAccess = cast<SubfieldOp>(op);
335  auto elemIndex =
336  fieldAccess.getInput().getType().base().getElementIndex(field);
337  if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
338  accesses.push_back(fieldAccess);
339  }
340  return accesses;
341 }
/// Replace the ports of memory `op` with an instance of `module`.
///
/// Builds the instance's port types, directions and names from the ports of
/// `op` (read ports first, then read-write ports, then write ports), creates
/// the instance immediately before `op`, and redirects every recorded port
/// subfield access to the matching instance result, erasing the subfield ops.
/// For memories that are not masked (`summary.isMasked` is false) the mask is
/// folded into the enable ("en") or write-mode ("wmode") signal and the
/// connect driving the mask is erased.  `op` itself is not erased here.
///
/// Returns the newly created instance.
InstanceOp LowerMemoryPass::emitMemoryInstance(MemOp op, FModuleOp module,
                                               const FirMemory &summary) {
  OpBuilder builder(op);
  auto *context = &getContext();
  // NOTE(review): memName is not read again in this function.
  auto memName = op.getName();
  if (memName.empty())
    memName = "mem";

  // Process each port in turn.
  SmallVector<Type, 8> portTypes;
  SmallVector<Direction> portDirections;
  SmallVector<Attribute> portNames;
  // Maps each subfield access of a memory port to the index of the instance
  // port that replaces it.
  DenseMap<Operation *, size_t> returnHolder;
  mlir::DominanceInfo domInfo(op->getParentOfType<FModuleOp>());

  // The result values of the memory are not necessarily in the same order as
  // the memory module that we're lowering to.  We need to lower the read
  // ports before the read/write ports, before the write ports.
  for (unsigned memportKindIdx = 0; memportKindIdx != 3; ++memportKindIdx) {
    // Index 0 (the default case) selects read ports, 1 read-write ports and
    // 2 write ports.
    MemOp::PortKind memportKind = MemOp::PortKind::Read;
    auto *portLabel = "R";
    switch (memportKindIdx) {
    default:
      break;
    case 1:
      memportKind = MemOp::PortKind::ReadWrite;
      portLabel = "RW";
      break;
    case 2:
      memportKind = MemOp::PortKind::Write;
      portLabel = "W";
      break;
    }

    // This is set to the count of the kind of memport we're emitting, for
    // label names.
    unsigned portNumber = 0;

    // Get an unsigned type with the specified width.
    auto getType = [&](size_t width) { return UIntType::get(context, width); };
    auto ui1Type = getType(1);
    // The address is always at least one bit wide.
    auto addressType = getType(std::max(1U, llvm::Log2_64_Ceil(summary.depth)));
    auto dataType = UIntType::get(context, summary.dataWidth);
    auto clockType = ClockType::get(context);

    // Memories return multiple structs, one for each port, which means we
    // have two layers of type to split apart.
    for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
      // Process all of one kind before the next.
      if (memportKind != op.getPortKind(i))
        continue;

      // Append one instance port for `field` of memory port `i` and remember
      // which subfield accesses it replaces.
      auto addPort = [&](Direction direction, StringRef field, Type portType) {
        // Map subfields of the memory port to module ports.
        auto accesses = getAllFieldAccesses(op.getResult(i), field);
        for (auto a : accesses)
          returnHolder[a] = portTypes.size();
        // Record the new port information.
        portTypes.push_back(portType);
        portDirections.push_back(direction);
        portNames.push_back(
            builder.getStringAttr(portLabel + Twine(portNumber) + "_" + field));
      };

      // Return the first connect found that drives `field` of memory port
      // `i`, or null if there is none.
      auto getDriver = [&](StringRef field) -> Operation * {
        auto accesses = getAllFieldAccesses(op.getResult(i), field);
        for (auto a : accesses) {
          for (auto *user : a->getUsers()) {
            // If this is a connect driving a value to the field, return it.
            if (auto connect = dyn_cast<FConnectLike>(user);
                connect && connect.getDest() == a)
              return connect;
          }
        }
        return nullptr;
      };

      // Find the value connected to the enable and 'and' it with the mask,
      // and then remove the mask entirely.  This is used to remove the mask
      // when it is 1 bit.  Nothing is changed if either field has no driver.
      auto removeMask = [&](StringRef enable, StringRef mask) {
        // Get the connect which drives a value to the mask element.
        auto *maskConnect = getDriver(mask);
        if (!maskConnect)
          return;
        // Get the connect which drives a value to the en element
        auto *enConnect = getDriver(enable);
        if (!enConnect)
          return;
        // Find the proper place to create the And operation.  The mask and en
        // signals must both dominate the new operation, so it is created at
        // whichever of the two connects comes later.
        OpBuilder b(maskConnect);
        if (domInfo.dominates(maskConnect, enConnect))
          b.setInsertionPoint(enConnect);
        // 'and' the enable and mask signals together and use it as the enable.
        auto andOp = b.create<AndPrimOp>(
            op->getLoc(), maskConnect->getOperand(1), enConnect->getOperand(1));
        enConnect->setOperand(1, andOp);
        // The enable connect now uses the 'and', so it has to follow it.
        enConnect->moveAfter(andOp);
        // Erase the old mask connect and the op defining its destination.
        auto *maskField = maskConnect->getOperand(0).getDefiningOp();
        maskConnect->erase();
        maskField->erase();
      };

      if (memportKind == MemOp::PortKind::Read) {
        addPort(Direction::In, "addr", addressType);
        addPort(Direction::In, "en", ui1Type);
        addPort(Direction::In, "clk", clockType);
        addPort(Direction::Out, "data", dataType);
      } else if (memportKind == MemOp::PortKind::ReadWrite) {
        addPort(Direction::In, "addr", addressType);
        addPort(Direction::In, "en", ui1Type);
        addPort(Direction::In, "clk", clockType);
        addPort(Direction::In, "wmode", ui1Type);
        addPort(Direction::In, "wdata", dataType);
        addPort(Direction::Out, "rdata", dataType);
        // An unmasked memory gets no mask port; its mask is folded into the
        // write-mode signal instead.
        if (summary.isMasked)
          addPort(Direction::In, "wmask", getType(summary.maskBits));
        else
          removeMask("wmode", "wmask");
      } else {
        addPort(Direction::In, "addr", addressType);
        addPort(Direction::In, "en", ui1Type);
        addPort(Direction::In, "clk", clockType);
        addPort(Direction::In, "data", dataType);
        // An unmasked memory gets no mask port; its mask is folded into the
        // enable signal instead.
        if (summary.isMasked)
          addPort(Direction::In, "mask", getType(summary.maskBits));
        else
          removeMask("en", "mask");
      }

      ++portNumber;
    }
  }

  // Create the instance to replace the memop.  The instance is named with
  // summary.getFirMemoryName() and carries the inner symbol of the memop.
  // TODO: how do we lower port annotations?
  auto inst = builder.create<InstanceOp>(
      op.getLoc(), portTypes, module.getNameAttr(), summary.getFirMemoryName(),
      op.getNameKind(), portDirections, portNames,
      /*annotations=*/ArrayRef<Attribute>(),
      /*portAnnotations=*/ArrayRef<Attribute>(),
      /*layers=*/ArrayRef<Attribute>(), /*lowerToBind=*/false,
      op.getInnerSymAttr());

  // Redirect the users of every recorded subfield access to the matching
  // instance result, then drop the subfield op.
  for (auto [subfield, result] : returnHolder) {
    subfield->getResult(0).replaceAllUsesWith(inst.getResult(result));
    subfield->erase();
  }

  return inst;
}
501 LogicalResult LowerMemoryPass::runOnModule(FModuleOp module, bool shouldDedup) {
502  for (auto op :
503  llvm::make_early_inc_range(module.getBodyBlock()->getOps<MemOp>())) {
504  // Check that the memory has been properly lowered already.
505  if (!type_isa<UIntType>(op.getDataType()))
506  return op->emitError(
507  "memories should be flattened before running LowerMemory");
509  auto summary = getSummary(op);
510  if (!summary.isSeqMem())
511  continue;
513  lowerMemory(op, summary, shouldDedup);
514  }
515  return success();
516 }
518 void LowerMemoryPass::runOnOperation() {
519  auto circuit = getOperation();
520  auto *body = circuit.getBodyBlock();
521  auto &instanceGraph = getAnalysis<InstanceGraph>();
522  symbolTable = &getAnalysis<SymbolTable>();
523  circuitNamespace.add(circuit);
525  // Find the device under test and create a set of all modules underneath it.
526  // If no module is marked as the DUT, then the top module is the DUT.
527  auto *dut = instanceGraph.getTopLevelNode();
528  auto it = llvm::find_if(*body, [&](Operation &op) -> bool {
530  });
531  if (it != body->end())
532  dut = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
534  // The set of all modules underneath the design under test module.
535  DenseSet<Operation *> dutModuleSet;
536  llvm::for_each(llvm::depth_first(dut), [&](igraph::InstanceGraphNode *node) {
537  dutModuleSet.insert(node->getModule());
538  });
540  // We iterate the circuit from top-to-bottom to make sure that we get
541  // consistent memory names.
542  for (auto module : body->getOps<FModuleOp>()) {
543  // We don't dedup memories in the testharness with any other memories.
544  auto shouldDedup = dutModuleSet.contains(module);
545  if (failed(runOnModule(module, shouldDedup)))
546  return signalPassFailure();
547  }
549  circuitNamespace.clear();
550  symbolTable = nullptr;
551  memories.clear();
552 }
554 std::unique_ptr<mlir::Pass> circt::firrtl::createLowerMemoryPass() {
555  return std::make_unique<LowerMemoryPass>();
556 }
