CIRCT  20.0.0git
FirMemLowering.cpp
Go to the documentation of this file.
1 //===- FirMemLowering.cpp - FirMem lowering utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "FirMemLowering.h"
10 #include "mlir/IR/Threading.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace circt;
15 using namespace hw;
16 using namespace seq;
17 using llvm::MapVector;
18 
19 #define DEBUG_TYPE "lower-seq-firmem"
20 
22  : context(circuit.getContext()), circuit(circuit) {
25 
26  // For each module, assign an index. Use it to identify the insertion point
27  // for the generated ops.
28  for (auto [index, module] : llvm::enumerate(circuit.getOps<HWModuleOp>()))
29  moduleIndex[module] = index;
30 }
31 
32 /// Collect the memories in a list of HW modules.
34 FirMemLowering::collectMemories(ArrayRef<HWModuleOp> modules) {
35  // For each module in the list populate a separate vector of `FirMemOp`s in
36  // that module. This allows for the traversal of the HW modules to be
37  // parallelized.
38  using ModuleMemories = SmallVector<std::pair<FirMemConfig, FirMemOp>, 0>;
39  SmallVector<ModuleMemories> memories(modules.size());
40 
41  mlir::parallelFor(context, 0, modules.size(), [&](auto idx) {
42  // TODO: Check if this module is in the DUT hierarchy.
43  // bool isInDut = state.isInDUT(module);
44  HWModuleOp(modules[idx]).walk([&](seq::FirMemOp op) {
45  memories[idx].push_back({collectMemory(op), op});
46  });
47  });
48 
49  // Group the gathered memories by unique `FirMemConfig` details.
50  MapVector<FirMemConfig, SmallVector<FirMemOp, 1>> grouped;
51  for (auto [module, moduleMemories] : llvm::zip(modules, memories))
52  for (auto [summary, memOp] : moduleMemories)
53  grouped[summary].push_back(memOp);
54 
55  return grouped;
56 }
57 
58 /// Trace a value through wires to its original definition.
59 static Value lookThroughWires(Value value) {
60  while (value) {
61  if (auto wireOp = value.getDefiningOp<WireOp>()) {
62  value = wireOp.getInput();
63  continue;
64  }
65  break;
66  }
67  return value;
68 }
69 
70 /// Determine the exact parametrization of the memory that should be generated
71 /// for a given `FirMemOp`.
72 FirMemConfig FirMemLowering::collectMemory(FirMemOp op) {
73  FirMemConfig cfg;
74  cfg.dataWidth = op.getType().getWidth();
75  cfg.depth = op.getType().getDepth();
76  cfg.readLatency = op.getReadLatency();
77  cfg.writeLatency = op.getWriteLatency();
78  cfg.maskBits = op.getType().getMaskWidth().value_or(1);
79  cfg.readUnderWrite = op.getRuw();
80  cfg.writeUnderWrite = op.getWuw();
81  if (auto init = op.getInitAttr()) {
82  cfg.initFilename = init.getFilename();
83  cfg.initIsBinary = init.getIsBinary();
84  cfg.initIsInline = init.getIsInline();
85  }
86  cfg.outputFile = op.getOutputFileAttr();
87  if (auto prefix = op.getPrefixAttr())
88  cfg.prefix = prefix.getValue();
89  // TODO: Handle modName (maybe not?)
90  // TODO: Handle groupID (maybe not?)
91 
92  // Count the read, write, and read-write ports, and identify the clocks
93  // driving the write ports.
95  for (auto *user : op->getUsers()) {
96  if (isa<FirMemReadOp>(user))
97  ++cfg.numReadPorts;
98  else if (isa<FirMemWriteOp>(user))
99  ++cfg.numWritePorts;
100  else if (isa<FirMemReadWriteOp>(user))
101  ++cfg.numReadWritePorts;
102 
103  // Assign IDs to the values used as clock. This allows later passes to
104  // easily detect which clocks are effectively driven by the same value.
105  if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
106  auto clock = lookThroughWires(user->getOperand(2));
107  cfg.writeClockIDs.push_back(
108  clockValues.insert({clock, clockValues.size()}).first->second);
109  }
110  }
111 
112  return cfg;
113 }
114 
116  if (!schemaOp) {
117  // Create or re-use the generator schema.
118  for (auto op : circuit.getOps<hw::HWGeneratorSchemaOp>()) {
119  if (op.getDescriptor() == "FIRRTL_Memory") {
120  schemaOp = op;
121  break;
122  }
123  }
124  if (!schemaOp) {
125  auto builder = OpBuilder::atBlockBegin(circuit.getBody());
126  std::array<StringRef, 14> schemaFields = {
127  "depth", "numReadPorts",
128  "numWritePorts", "numReadWritePorts",
129  "readLatency", "writeLatency",
130  "width", "maskGran",
131  "readUnderWrite", "writeUnderWrite",
132  "writeClockIDs", "initFilename",
133  "initIsBinary", "initIsInline"};
134  schemaOp = builder.create<hw::HWGeneratorSchemaOp>(
135  circuit.getLoc(), "FIRRTLMem", "FIRRTL_Memory",
136  builder.getStrArrayAttr(schemaFields));
137  }
138  }
140 }
141 
142 /// Create the `HWModuleGeneratedOp` for a single memory parametrization.
143 HWModuleGeneratedOp
145  ArrayRef<seq::FirMemOp> memOps) {
146  auto schemaSymRef = getOrCreateSchema();
147 
148  // Identify the first module which uses the memory configuration.
149  // Insert the generated module before it.
150  HWModuleOp insertPt;
151  for (auto memOp : memOps) {
152  auto parent = memOp->getParentOfType<HWModuleOp>();
153  if (!insertPt || moduleIndex[parent] < moduleIndex[insertPt])
154  insertPt = parent;
155  }
156 
157  OpBuilder builder(context);
158  builder.setInsertionPoint(insertPt);
159 
160  // Pick a name for the memory. Honor the optional prefix and try to include
161  // the common part of the names of the memory instances that use this
162  // configuration. The resulting name is of the form:
163  //
164  // <prefix>_<commonName>_<depth>x<width>
165  //
166  StringRef baseName = "";
167  bool firstFound = false;
168  for (auto memOp : memOps) {
169  if (auto memName = memOp.getName()) {
170  if (!firstFound) {
171  baseName = *memName;
172  firstFound = true;
173  continue;
174  }
175  unsigned idx = 0;
176  for (; idx < memName->size() && idx < baseName.size(); ++idx)
177  if ((*memName)[idx] != baseName[idx])
178  break;
179  baseName = baseName.take_front(idx);
180  }
181  }
182  baseName = baseName.rtrim('_');
183 
184  SmallString<32> nameBuffer;
185  nameBuffer += mem.prefix;
186  if (!baseName.empty()) {
187  nameBuffer += baseName;
188  } else {
189  nameBuffer += "mem";
190  }
191  nameBuffer += "_";
192  (Twine(mem.depth) + "x" + Twine(mem.dataWidth)).toVector(nameBuffer);
193  auto name = builder.getStringAttr(globalNamespace.newName(nameBuffer));
194 
195  LLVM_DEBUG(llvm::dbgs() << "Creating " << name << " for " << mem.depth
196  << " x " << mem.dataWidth << " memory\n");
197 
198  bool withMask = mem.maskBits > 1;
199  SmallVector<hw::PortInfo> ports;
200 
201  // Common types used for memory ports.
202  Type clkType = ClockType::get(context);
203  Type bitType = IntegerType::get(context, 1);
204  Type dataType = IntegerType::get(context, std::max((size_t)1, mem.dataWidth));
205  Type maskType = IntegerType::get(context, mem.maskBits);
206  Type addrType =
207  IntegerType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
208 
209  // Helper to add an input port.
210  size_t inputIdx = 0;
211  auto addInput = [&](StringRef prefix, size_t idx, StringRef suffix,
212  Type type) {
213  ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
215  inputIdx++});
216  };
217 
218  // Helper to add an output port.
219  size_t outputIdx = 0;
220  auto addOutput = [&](StringRef prefix, size_t idx, StringRef suffix,
221  Type type) {
222  ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
224  outputIdx++});
225  };
226 
227  // Helper to add the ports common to read, read-write, and write ports.
228  auto addCommonPorts = [&](StringRef prefix, size_t idx) {
229  addInput(prefix, idx, "_addr", addrType);
230  addInput(prefix, idx, "_en", bitType);
231  addInput(prefix, idx, "_clk", clkType);
232  };
233 
234  // Add the read ports.
235  for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
236  addCommonPorts("R", i);
237  addOutput("R", i, "_data", dataType);
238  }
239 
240  // Add the read-write ports.
241  for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
242  addCommonPorts("RW", i);
243  addInput("RW", i, "_wmode", bitType);
244  addInput("RW", i, "_wdata", dataType);
245  addOutput("RW", i, "_rdata", dataType);
246  if (withMask)
247  addInput("RW", i, "_wmask", maskType);
248  }
249 
250  // Add the write ports.
251  for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
252  addCommonPorts("W", i);
253  addInput("W", i, "_data", dataType);
254  if (withMask)
255  addInput("W", i, "_mask", maskType);
256  }
257 
258  // Mask granularity is the number of data bits that each mask bit can
259  // guard. By default it is equal to the data bitwidth.
260  auto genAttr = [&](StringRef name, Attribute attr) {
261  return builder.getNamedAttr(name, attr);
262  };
263  auto genAttrUI32 = [&](StringRef name, uint32_t value) {
264  return genAttr(name, builder.getUI32IntegerAttr(value));
265  };
266  NamedAttribute genAttrs[] = {
267  genAttr("depth", builder.getI64IntegerAttr(mem.depth)),
268  genAttrUI32("numReadPorts", mem.numReadPorts),
269  genAttrUI32("numWritePorts", mem.numWritePorts),
270  genAttrUI32("numReadWritePorts", mem.numReadWritePorts),
271  genAttrUI32("readLatency", mem.readLatency),
272  genAttrUI32("writeLatency", mem.writeLatency),
273  genAttrUI32("width", mem.dataWidth),
274  genAttrUI32("maskGran", mem.dataWidth / mem.maskBits),
275  genAttr("readUnderWrite",
276  seq::RUWAttr::get(builder.getContext(), mem.readUnderWrite)),
277  genAttr("writeUnderWrite",
278  seq::WUWAttr::get(builder.getContext(), mem.writeUnderWrite)),
279  genAttr("writeClockIDs", builder.getI32ArrayAttr(mem.writeClockIDs)),
280  genAttr("initFilename", builder.getStringAttr(mem.initFilename)),
281  genAttr("initIsBinary", builder.getBoolAttr(mem.initIsBinary)),
282  genAttr("initIsInline", builder.getBoolAttr(mem.initIsInline))};
283 
284  // Combine the locations of all actual `FirMemOp`s to be the location of the
285  // generated memory.
286  Location loc = FirMemOp(memOps.front()).getLoc();
287  if (memOps.size() > 1) {
288  SmallVector<Location> locs;
289  for (auto memOp : memOps)
290  locs.push_back(memOp.getLoc());
291  loc = FusedLoc::get(context, locs);
292  }
293 
294  // Create the module.
295  auto genOp = builder.create<hw::HWModuleGeneratedOp>(
296  loc, schemaSymRef, name, ports, StringRef{}, ArrayAttr{}, genAttrs);
297  if (mem.outputFile)
298  genOp->setAttr("output_file", mem.outputFile);
299 
300  return genOp;
301 }
302 
303 /// Replace all `FirMemOp`s in an HW module with an instance of the
304 /// corresponding generated module.
306  HWModuleOp module,
307  ArrayRef<std::tuple<FirMemConfig *, HWModuleGeneratedOp, FirMemOp>> mems) {
308  LLVM_DEBUG(llvm::dbgs() << "Lowering " << mems.size() << " memories in "
309  << module.getName() << "\n");
310 
311  DenseMap<unsigned, Value> constOneOps;
312  auto constOne = [&](unsigned width = 1) {
313  auto it = constOneOps.try_emplace(width, Value{});
314  if (it.second) {
315  auto builder = OpBuilder::atBlockBegin(module.getBodyBlock());
316  it.first->second = builder.create<hw::ConstantOp>(
317  module.getLoc(), builder.getIntegerType(width), 1);
318  }
319  return it.first->second;
320  };
321  auto valueOrOne = [&](Value value, unsigned width = 1) {
322  return value ? value : constOne(width);
323  };
324 
325  for (auto [config, genOp, memOp] : mems) {
326  LLVM_DEBUG(llvm::dbgs() << "- Lowering " << memOp.getName() << "\n");
327  SmallVector<Value> inputs;
328  SmallVector<Value> outputs;
329 
330  auto addInput = [&](Value value) { inputs.push_back(value); };
331  auto addOutput = [&](Value value) { outputs.push_back(value); };
332 
333  // Add the read ports.
334  for (auto *op : memOp->getUsers()) {
335  auto port = dyn_cast<FirMemReadOp>(op);
336  if (!port)
337  continue;
338  addInput(port.getAddress());
339  addInput(valueOrOne(port.getEnable()));
340  addInput(port.getClk());
341  addOutput(port.getData());
342  }
343 
344  // Add the read-write ports.
345  for (auto *op : memOp->getUsers()) {
346  auto port = dyn_cast<FirMemReadWriteOp>(op);
347  if (!port)
348  continue;
349  addInput(port.getAddress());
350  addInput(valueOrOne(port.getEnable()));
351  addInput(port.getClk());
352  addInput(port.getMode());
353  addInput(port.getWriteData());
354  addOutput(port.getReadData());
355  if (config->maskBits > 1)
356  addInput(valueOrOne(port.getMask(), config->maskBits));
357  }
358 
359  // Add the write ports.
360  for (auto *op : memOp->getUsers()) {
361  auto port = dyn_cast<FirMemWriteOp>(op);
362  if (!port)
363  continue;
364  addInput(port.getAddress());
365  addInput(valueOrOne(port.getEnable()));
366  addInput(port.getClk());
367  addInput(port.getData());
368  if (config->maskBits > 1)
369  addInput(valueOrOne(port.getMask(), config->maskBits));
370  }
371 
372  // Create the module instance.
373  StringRef memName = "mem";
374  if (auto name = memOp.getName(); name && !name->empty())
375  memName = *name;
376  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
377  auto instOp = builder.create<hw::InstanceOp>(
378  genOp, builder.getStringAttr(memName + "_ext"), inputs, ArrayAttr{},
379  memOp.getInnerSymAttr());
380  for (auto [oldOutput, newOutput] : llvm::zip(outputs, instOp.getResults()))
381  oldOutput.replaceAllUsesWith(newOutput);
382 
383  // Carry attributes over from the `FirMemOp` to the `InstanceOp`.
384  auto defaultAttrNames = memOp.getAttributeNames();
385  for (auto namedAttr : memOp->getAttrs())
386  if (!llvm::is_contained(defaultAttrNames, namedAttr.getName()))
387  instOp->setAttr(namedAttr.getName(), namedAttr.getValue());
388 
389  // Get rid of the `FirMemOp`.
390  for (auto *user : llvm::make_early_inc_range(memOp->getUsers()))
391  user->erase();
392  memOp.erase();
393  }
394 }
int32_t width
Definition: FIRRTL.cpp:36
static Value lookThroughWires(Value value)
Trace a value through wires to its original definition.
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::vector< mlir::Value > toVector(mlir::ValueRange range)
UniqueConfigs collectMemories(ArrayRef< hw::HWModuleOp > modules)
Groups memories by their kind from the whole design.
void lowerMemoriesInModule(hw::HWModuleOp module, ArrayRef< MemoryConfig > mems)
Lowers a group of memories from the same module.
hw::HWGeneratorSchemaOp schemaOp
FirMemLowering(ModuleOp circuit)
hw::HWModuleGeneratedOp createMemoryModule(FirMemConfig &mem, ArrayRef< seq::FirMemOp > memOps)
Creates the generated module for a given configuration.
FlatSymbolRefAttr getOrCreateSchema()
Find the schema or create it if it does not exist.
DenseMap< hw::HWModuleOp, size_t > moduleIndex
llvm::MapVector< FirMemConfig, SmallVector< seq::FirMemOp, 1 > > UniqueConfigs
A vector of unique FirMemConfigs and all the FirMemOps that use it.
void add(mlir::ModuleOp module)
Definition: Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1
Definition: seq.py:1
The configuration of a FIR memory.
SmallVector< int32_t, 1 > writeClockIDs