CIRCT  19.0.0git
FirMemLowering.cpp
Go to the documentation of this file.
1 //===- FirMemLowering.cpp - FirMem lowering utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "FirMemLowering.h"
10 #include "mlir/IR/Threading.h"
11 #include "llvm/Support/Debug.h"
12 
13 using namespace circt;
14 using namespace hw;
15 using namespace seq;
16 using llvm::MapVector;
17 
18 #define DEBUG_TYPE "lower-seq-firmem"
19 
21  : context(circuit.getContext()), circuit(circuit) {
  // Constructor body: besides caching the context and circuit (initializer
  // list above), number the top-level HW modules of the circuit.
24 
25  // For each module, assign an index. Use it to identify the insertion point
26  // for the generated ops.
  // Indices follow the iteration order of `circuit.getOps<HWModuleOp>()`,
  // i.e. the order of the modules in the circuit body. `createMemoryModule`
  // compares these indices to place each generated memory module before the
  // earliest module that uses it.
27  for (auto [index, module] : llvm::enumerate(circuit.getOps<HWModuleOp>()))
28  moduleIndex[module] = index;
29 }
30 
31 /// Collect the memories in a list of HW modules.
///
/// Every `seq::FirMemOp` nested in `modules` is summarized by
/// `collectMemory`, and the ops are grouped by identical `FirMemConfig`. The
/// result is an `llvm::MapVector`, so configurations appear in the order they
/// are first encountered: the order of `modules`, then walk order within each
/// module.
33 FirMemLowering::collectMemories(ArrayRef<HWModuleOp> modules) {
34  // For each module in the list populate a separate vector of `FirMemOp`s in
35  // that module. This allows for the traversal of the HW modules to be
36  // parallelized.
37  using ModuleMemories = SmallVector<std::pair<FirMemConfig, FirMemOp>, 0>;
38  SmallVector<ModuleMemories> memories(modules.size());
39 
40  mlir::parallelFor(context, 0, modules.size(), [&](auto idx) {
  // Each task appends only to its own pre-sized slot `memories[idx]`.
  // `collectMemory` runs concurrently here, so it must not mutate shared
  // `FirMemLowering` state. NOTE(review): keep this in mind when changing it.
41  // TODO: Check if this module is in the DUT hierarchy.
42  // bool isInDut = state.isInDUT(module);
43  HWModuleOp(modules[idx]).walk([&](seq::FirMemOp op) {
44  memories[idx].push_back({collectMemory(op), op});
45  });
46  });
47 
48  // Group the gathered memories by unique `FirMemConfig` details.
  // This merge runs sequentially after `parallelFor` has returned, so the
  // resulting grouping order is deterministic regardless of thread scheduling.
  // NOTE(review): `auto [summary, memOp]` copies each pair, including its
  // `FirMemConfig`; binding with `const auto &` would avoid the copy. The
  // `module` binding of the zip is unused.
49  MapVector<FirMemConfig, SmallVector<FirMemOp, 1>> grouped;
50  for (auto [module, moduleMemories] : llvm::zip(modules, memories))
51  for (auto [summary, memOp] : moduleMemories)
52  grouped[summary].push_back(memOp);
53 
54  return grouped;
55 }
56 
57 /// Trace a value through wires to its original definition.
58 static Value lookThroughWires(Value value) {
59  while (value) {
60  if (auto wireOp = value.getDefiningOp<WireOp>()) {
61  value = wireOp.getInput();
62  continue;
63  }
64  break;
65  }
66  return value;
67 }
68 
69 /// Determine the exact parametrization of the memory that should be generated
70 /// for a given `FirMemOp`.
///
/// Copies width, depth, latencies, mask width, read/write-under-write
/// behaviour, init settings, output file and prefix from the op and its type,
/// then counts the read, write and read-write port ops that use `op`.
/// NOTE(review): the port counters and `writeClockIDs` are assumed to start at
/// zero/empty in a default-constructed `FirMemConfig`; confirm against the
/// struct definition.
71 FirMemConfig FirMemLowering::collectMemory(FirMemOp op) {
72  FirMemConfig cfg;
73  cfg.dataWidth = op.getType().getWidth();
74  cfg.depth = op.getType().getDepth();
75  cfg.readLatency = op.getReadLatency();
76  cfg.writeLatency = op.getWriteLatency();
  // A memory type without a mask width is treated as having one mask bit.
77  cfg.maskBits = op.getType().getMaskWidth().value_or(1);
78  cfg.readUnderWrite = op.getRuw();
79  cfg.writeUnderWrite = op.getWuw();
80  if (auto init = op.getInitAttr()) {
81  cfg.initFilename = init.getFilename();
82  cfg.initIsBinary = init.getIsBinary();
83  cfg.initIsInline = init.getIsInline();
84  }
85  cfg.outputFile = op.getOutputFileAttr();
86  if (auto prefix = op.getPrefixAttr())
87  cfg.prefix = prefix.getValue();
88  // TODO: Handle modName (maybe not?)
89  // TODO: Handle groupID (maybe not?)
90 
91  // Count the read, write, and read-write ports, and identify the clocks
92  // driving the write ports.
  // NOTE(review): `clockValues` is not declared in the lines shown here. This
  // function is called from inside `mlir::parallelFor` (collectMemories), so
  // the map must be local to this call; confirm it is not shared state.
  // NOTE(review): `writeClockIDs` is filled in user order, with write and
  // read-write ports interleaved, while `createMemoryModule` lays out all
  // read-write ports before all write ports. Confirm that consumers of
  // `writeClockIDs` expect this ordering.
94  for (auto *user : op->getUsers()) {
95  if (isa<FirMemReadOp>(user))
96  ++cfg.numReadPorts;
97  else if (isa<FirMemWriteOp>(user))
98  ++cfg.numWritePorts;
99  else if (isa<FirMemReadWriteOp>(user))
100  ++cfg.numReadWritePorts;
101 
102  // Assign IDs to the values used as clock. This allows later passes to
103  // easily detect which clocks are effectively driven by the same value.
104  if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
  // NOTE(review): assumes operand #2 of both write and read-write port ops
  // is the clock; confirm against the op definitions. Wires are looked
  // through so clocks forwarded via wires share one ID.
105  auto clock = lookThroughWires(user->getOperand(2));
  // `insert` keeps an existing entry. The first port driven by a given clock
  // value therefore gets the next sequential ID (the map size before this
  // insertion), and later ports driven by the same value reuse that ID.
106  cfg.writeClockIDs.push_back(
107  clockValues.insert({clock, clockValues.size()}).first->second);
108  }
109  }
110 
111  return cfg;
112 }
113 
  // Body of the schema accessor. The result is cached in the `schemaOp`
  // member; the search/creation below only runs while it is still null.
115  if (!schemaOp) {
116  // Create or re-use the generator schema.
  // An existing schema is matched purely by its "FIRRTL_Memory" descriptor;
  // its field list is not checked.
117  for (auto op : circuit.getOps<hw::HWGeneratorSchemaOp>()) {
118  if (op.getDescriptor() == "FIRRTL_Memory") {
119  schemaOp = op;
120  break;
121  }
122  }
123  if (!schemaOp) {
  // New schemas are placed at the very beginning of the circuit body.
124  auto builder = OpBuilder::atBlockBegin(circuit.getBody());
  // These 14 field names must stay in sync, name for name, with the
  // attributes that createMemoryModule attaches to every generated memory
  // module (its `genAttrs` array).
125  std::array<StringRef, 14> schemaFields = {
126  "depth", "numReadPorts",
127  "numWritePorts", "numReadWritePorts",
128  "readLatency", "writeLatency",
129  "width", "maskGran",
130  "readUnderWrite", "writeUnderWrite",
131  "writeClockIDs", "initFilename",
132  "initIsBinary", "initIsInline"};
133  schemaOp = builder.create<hw::HWGeneratorSchemaOp>(
134  circuit.getLoc(), "FIRRTLMem", "FIRRTL_Memory",
135  builder.getStrArrayAttr(schemaFields));
136  }
137  }
139 }
140 
141 /// Create the `HWModuleGeneratedOp` for a single memory parametrization.
///
/// `memOps` are all the `FirMemOp`s that share the configuration `mem`. The
/// list must be non-empty: `memOps.front()` supplies the location, and an
/// empty list would leave the insertion point null. The generated module is
/// inserted before the earliest (by `moduleIndex`) HW module that contains one
/// of the memories.
142 HWModuleGeneratedOp
144  ArrayRef<seq::FirMemOp> memOps) {
145  auto schemaSymRef = getOrCreateSchema();
146 
147  // Identify the first module which uses the memory configuration.
148  // Insert the generated module before it.
  // NOTE(review): `moduleIndex[...]` is DenseMap::operator[], which inserts
  // index 0 for a module the constructor did not record.
149  HWModuleOp insertPt;
150  for (auto memOp : memOps) {
151  auto parent = memOp->getParentOfType<HWModuleOp>();
152  if (!insertPt || moduleIndex[parent] < moduleIndex[insertPt])
153  insertPt = parent;
154  }
155 
156  OpBuilder builder(context);
157  builder.setInsertionPoint(insertPt);
158 
159  // Pick a name for the memory. Honor the optional prefix and try to include
160  // the common part of the names of the memory instances that use this
161  // configuration. The resulting name is of the form:
162  //
163  // <prefix>_<commonName>_<depth>x<width>
164  //
  // Note: despite the form shown above, the code inserts no separator between
  // `mem.prefix` and the common name. Any `_` after the prefix must already
  // be part of `mem.prefix`.
  // Memories without a name are ignored when computing the common prefix.
  // Trailing underscores of the common part are stripped. If nothing remains,
  // "mem" is used instead. `globalNamespace.newName` makes the final name
  // unique.
165  StringRef baseName = "";
166  bool firstFound = false;
167  for (auto memOp : memOps) {
168  if (auto memName = memOp.getName()) {
169  if (!firstFound) {
170  baseName = *memName;
171  firstFound = true;
172  continue;
173  }
  // Shrink `baseName` to the longest common prefix with this name.
174  unsigned idx = 0;
175  for (; idx < memName->size() && idx < baseName.size(); ++idx)
176  if ((*memName)[idx] != baseName[idx])
177  break;
178  baseName = baseName.take_front(idx);
179  }
180  }
181  baseName = baseName.rtrim('_');
182 
183  SmallString<32> nameBuffer;
184  nameBuffer += mem.prefix;
185  if (!baseName.empty()) {
186  nameBuffer += baseName;
187  } else {
188  nameBuffer += "mem";
189  }
190  nameBuffer += "_";
191  (Twine(mem.depth) + "x" + Twine(mem.dataWidth)).toVector(nameBuffer);
192  auto name = builder.getStringAttr(globalNamespace.newName(nameBuffer));
193 
194  LLVM_DEBUG(llvm::dbgs() << "Creating " << name << " for " << mem.depth
195  << " x " << mem.dataWidth << " memory\n");
196 
  // Mask ports are only emitted when there is more than one mask bit.
197  bool withMask = mem.maskBits > 1;
198  SmallVector<hw::PortInfo> ports;
199 
200  // Common types used for memory ports.
  // The data type is at least 1 bit wide, even when `mem.dataWidth` is 0. The
  // address width is ceil(log2(depth)), also clamped to at least 1 bit.
201  Type clkType = ClockType::get(context);
202  Type bitType = IntegerType::get(context, 1);
203  Type dataType = IntegerType::get(context, std::max((size_t)1, mem.dataWidth));
204  Type maskType = IntegerType::get(context, mem.maskBits);
205  Type addrType =
206  IntegerType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
207 
  // Port layout built below: all read ports, then all read-write ports, then
  // all write ports. Inputs and outputs are numbered separately via
  // `inputIdx`/`outputIdx`. lowerMemoriesInModule builds the instance operands
  // and results in this same order, so the two must be kept in sync.
208  // Helper to add an input port.
209  size_t inputIdx = 0;
210  auto addInput = [&](StringRef prefix, size_t idx, StringRef suffix,
211  Type type) {
212  ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
214  inputIdx++});
215  };
216 
217  // Helper to add an output port.
218  size_t outputIdx = 0;
219  auto addOutput = [&](StringRef prefix, size_t idx, StringRef suffix,
220  Type type) {
221  ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
223  outputIdx++});
224  };
225 
226  // Helper to add the ports common to read, read-write, and write ports.
227  auto addCommonPorts = [&](StringRef prefix, size_t idx) {
228  addInput(prefix, idx, "_addr", addrType);
229  addInput(prefix, idx, "_en", bitType);
230  addInput(prefix, idx, "_clk", clkType);
231  };
232 
233  // Add the read ports.
234  for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
235  addCommonPorts("R", i);
236  addOutput("R", i, "_data", dataType);
237  }
238 
239  // Add the read-write ports.
240  for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
241  addCommonPorts("RW", i);
242  addInput("RW", i, "_wmode", bitType);
243  addInput("RW", i, "_wdata", dataType);
244  addOutput("RW", i, "_rdata", dataType);
245  if (withMask)
246  addInput("RW", i, "_wmask", maskType);
247  }
248 
249  // Add the write ports.
250  for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
251  addCommonPorts("W", i);
252  addInput("W", i, "_data", dataType);
253  if (withMask)
254  addInput("W", i, "_mask", maskType);
255  }
256 
257  // Mask granularity is the number of data bits that each mask bit can
258  // guard. By default it is equal to the data bitwidth.
  // The attribute names below mirror, one for one, the field list of the
  // "FIRRTL_Memory" schema created by getOrCreateSchema. `maskGran` is
  // computed with integer division.
259  auto genAttr = [&](StringRef name, Attribute attr) {
260  return builder.getNamedAttr(name, attr);
261  };
262  auto genAttrUI32 = [&](StringRef name, uint32_t value) {
263  return genAttr(name, builder.getUI32IntegerAttr(value));
264  };
265  NamedAttribute genAttrs[] = {
266  genAttr("depth", builder.getI64IntegerAttr(mem.depth)),
267  genAttrUI32("numReadPorts", mem.numReadPorts),
268  genAttrUI32("numWritePorts", mem.numWritePorts),
269  genAttrUI32("numReadWritePorts", mem.numReadWritePorts),
270  genAttrUI32("readLatency", mem.readLatency),
271  genAttrUI32("writeLatency", mem.writeLatency),
272  genAttrUI32("width", mem.dataWidth),
273  genAttrUI32("maskGran", mem.dataWidth / mem.maskBits),
274  genAttr("readUnderWrite",
275  seq::RUWAttr::get(builder.getContext(), mem.readUnderWrite)),
276  genAttr("writeUnderWrite",
277  seq::WUWAttr::get(builder.getContext(), mem.writeUnderWrite)),
278  genAttr("writeClockIDs", builder.getI32ArrayAttr(mem.writeClockIDs)),
279  genAttr("initFilename", builder.getStringAttr(mem.initFilename)),
280  genAttr("initIsBinary", builder.getBoolAttr(mem.initIsBinary)),
281  genAttr("initIsInline", builder.getBoolAttr(mem.initIsInline))};
282 
283  // Combine the locations of all actual `FirMemOp`s to be the location of the
284  // generated memory.
285  Location loc = FirMemOp(memOps.front()).getLoc();
286  if (memOps.size() > 1) {
287  SmallVector<Location> locs;
288  for (auto memOp : memOps)
289  locs.push_back(memOp.getLoc());
290  loc = FusedLoc::get(context, locs);
291  }
292 
293  // Create the module.
294  auto genOp = builder.create<hw::HWModuleGeneratedOp>(
295  loc, schemaSymRef, name, ports, StringRef{}, ArrayAttr{}, genAttrs);
  // Propagate the memory's output-file attribute, if any, to the module.
296  if (mem.outputFile)
297  genOp->setAttr("output_file", mem.outputFile);
298 
299  return genOp;
300 }
301 
302 /// Replace all `FirMemOp`s in an HW module with an instance of the
303 /// corresponding generated module.
///
/// Each entry of `mems` pairs a `FirMemOp` in `module` with its configuration
/// and generated module. Instance operands are collected in the generated
/// module's port order: read ports, then read-write ports, then write ports.
/// The old port results are rewired to the instance results, and the port ops
/// and the `FirMemOp` are then erased.
305  HWModuleOp module,
306  ArrayRef<std::tuple<FirMemConfig *, HWModuleGeneratedOp, FirMemOp>> mems) {
307  LLVM_DEBUG(llvm::dbgs() << "Lowering " << mems.size() << " memories in "
308  << module.getName() << "\n");
309 
  // Lazily materializes one `hw.constant` of value 1 per requested bit width
  // at the start of the module body, so repeated uses share a single op.
  // NOTE(review): this presumably builds the integer 1 of `width` bits, i.e.
  // 0b0...01 rather than all-ones. A port without a mask on a masked memory
  // (`valueOrOne(port.getMask(), config->maskBits)` below) would then have
  // only mask bit 0 enabled. Confirm whether an all-ones default was intended.
310  DenseMap<unsigned, Value> constOneOps;
311  auto constOne = [&](unsigned width = 1) {
312  auto it = constOneOps.try_emplace(width, Value{});
313  if (it.second) {
314  auto builder = OpBuilder::atBlockBegin(module.getBodyBlock());
315  it.first->second = builder.create<hw::ConstantOp>(
316  module.getLoc(), builder.getIntegerType(width), 1);
317  }
318  return it.first->second;
319  };
  // Substitute the shared constant for an absent optional operand.
320  auto valueOrOne = [&](Value value, unsigned width = 1) {
321  return value ? value : constOne(width);
322  };
323 
324  for (auto [config, genOp, memOp] : mems) {
325  LLVM_DEBUG(llvm::dbgs() << "- Lowering " << memOp.getName() << "\n");
326  SmallVector<Value> inputs;
327  SmallVector<Value> outputs;
328 
329  auto addInput = [&](Value value) { inputs.push_back(value); };
330  auto addOutput = [&](Value value) { outputs.push_back(value); };
331 
  // Within each port kind, the order of `memOp->getUsers()` decides which
  // port op is bound to R0/RW0/W0, R1/RW1/W1, and so on.
332  // Add the read ports.
333  for (auto *op : memOp->getUsers()) {
334  auto port = dyn_cast<FirMemReadOp>(op);
335  if (!port)
336  continue;
337  addInput(port.getAddress());
338  addInput(valueOrOne(port.getEnable()));
339  addInput(port.getClk());
340  addOutput(port.getData());
341  }
342 
343  // Add the read-write ports.
344  for (auto *op : memOp->getUsers()) {
345  auto port = dyn_cast<FirMemReadWriteOp>(op);
346  if (!port)
347  continue;
348  addInput(port.getAddress());
349  addInput(valueOrOne(port.getEnable()));
350  addInput(port.getClk());
351  addInput(port.getMode());
352  addInput(port.getWriteData());
353  addOutput(port.getReadData());
354  if (config->maskBits > 1)
355  addInput(valueOrOne(port.getMask(), config->maskBits));
356  }
357 
358  // Add the write ports.
359  for (auto *op : memOp->getUsers()) {
360  auto port = dyn_cast<FirMemWriteOp>(op);
361  if (!port)
362  continue;
363  addInput(port.getAddress());
364  addInput(valueOrOne(port.getEnable()));
365  addInput(port.getClk());
366  addInput(port.getData());
367  if (config->maskBits > 1)
368  addInput(valueOrOne(port.getMask(), config->maskBits));
369  }
370 
371  // Create the module instance.
  // The instance is named "<name>_ext", falling back to "mem_ext" when the
  // FirMemOp has no name or an empty one. The FirMemOp's inner symbol is
  // passed to the instance builder, presumably so that references to it keep
  // working.
372  StringRef memName = "mem";
373  if (auto name = memOp.getName(); name && !name->empty())
374  memName = *name;
375  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
376  auto instOp = builder.create<hw::InstanceOp>(
377  genOp, builder.getStringAttr(memName + "_ext"), inputs, ArrayAttr{},
378  memOp.getInnerSymAttr());
379  for (auto [oldOutput, newOutput] : llvm::zip(outputs, instOp.getResults()))
380  oldOutput.replaceAllUsesWith(newOutput);
381 
382  // Carry attributes over from the `FirMemOp` to the `InstanceOp`.
  // Only attributes not listed in `memOp.getAttributeNames()` (the op's own
  // declared attributes) are copied.
383  auto defaultAttrNames = memOp.getAttributeNames();
384  for (auto namedAttr : memOp->getAttrs())
385  if (!llvm::is_contained(defaultAttrNames, namedAttr.getName()))
386  instOp->setAttr(namedAttr.getName(), namedAttr.getValue());
387 
388  // Get rid of the `FirMemOp`.
  // The port ops use `memOp`, so they are erased first. The early-increment
  // range keeps the user iteration valid while erasing.
389  for (auto *user : llvm::make_early_inc_range(memOp->getUsers()))
390  user->erase();
391  memOp.erase();
392  }
393 }
int32_t width
Definition: FIRRTL.cpp:36
static Value lookThroughWires(Value value)
Trace a value through wires to its original definition.
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::vector< mlir::Value > toVector(mlir::ValueRange range)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
UniqueConfigs collectMemories(ArrayRef< hw::HWModuleOp > modules)
Groups memories by their kind from the whole design.
void lowerMemoriesInModule(hw::HWModuleOp module, ArrayRef< MemoryConfig > mems)
Lowers a group of memories from the same module.
hw::HWGeneratorSchemaOp schemaOp
FirMemLowering(ModuleOp circuit)
hw::HWModuleGeneratedOp createMemoryModule(FirMemConfig &mem, ArrayRef< seq::FirMemOp > memOps)
Creates the generated module for a given configuration.
FlatSymbolRefAttr getOrCreateSchema()
Find the schema or create it if it does not exist.
DenseMap< hw::HWModuleOp, size_t > moduleIndex
llvm::MapVector< FirMemConfig, SmallVector< seq::FirMemOp, 1 > > UniqueConfigs
A vector of unique FirMemConfigs and all the FirMemOps that use it.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
Definition: Namespace.h:47
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:63
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1
Definition: seq.py:1
The configuration of a FIR memory.
SmallVector< int32_t, 1 > writeClockIDs