CIRCT 20.0.0git
Loading...
Searching...
No Matches
FirMemLowering.cpp
Go to the documentation of this file.
1//===- FirMemLowering.cpp - FirMem lowering utilities ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "FirMemLowering.h"
10#include "mlir/IR/Threading.h"
11#include "llvm/ADT/MapVector.h"
12#include "llvm/Support/Debug.h"
13
14using namespace circt;
15using namespace hw;
16using namespace seq;
17using llvm::MapVector;
18
19#define DEBUG_TYPE "lower-seq-firmem"
20
// Sets up the lowering for `circuit`: caches its MLIR context and records, for
// each `HWModuleOp` directly in the circuit body, its position in that body.
// NOTE(review): this listing is missing the constructor's signature line
// (`FirMemLowering(ModuleOp circuit)` per the header) and two statements right
// after the initializer list -- presumably symbol-cache / global-namespace
// setup, judging by the referenced `addDefinitions` and `Namespace::add`.
// TODO confirm against the real source.
22 : context(circuit.getContext()), circuit(circuit) {
25
26 // For each module, assign an index. Use it to identify the insertion point
27 // for the generated ops.
// Indices count only `HWModuleOp`s, in body order, so a smaller index means
// the module appears earlier in the circuit body.
28 for (auto [index, module] : llvm::enumerate(circuit.getOps<HWModuleOp>()))
29 moduleIndex[module] = index;
30}
31
32/// Collect the memories in a list of HW modules.
34FirMemLowering::collectMemories(ArrayRef<HWModuleOp> modules) {
35 // For each module in the list populate a separate vector of `FirMemOp`s in
36 // that module. This allows for the traversal of the HW modules to be
37 // parallelized.
38 using ModuleMemories = SmallVector<std::pair<FirMemConfig, FirMemOp>, 0>;
39 SmallVector<ModuleMemories> memories(modules.size());
40
41 mlir::parallelFor(context, 0, modules.size(), [&](auto idx) {
42 // TODO: Check if this module is in the DUT hierarchy.
43 // bool isInDut = state.isInDUT(module);
44 HWModuleOp(modules[idx]).walk([&](seq::FirMemOp op) {
45 memories[idx].push_back({collectMemory(op), op});
46 });
47 });
48
49 // Group the gathered memories by unique `FirMemConfig` details.
50 MapVector<FirMemConfig, SmallVector<FirMemOp, 1>> grouped;
51 for (auto [module, moduleMemories] : llvm::zip(modules, memories))
52 for (auto [summary, memOp] : moduleMemories)
53 grouped[summary].push_back(memOp);
54
55 return grouped;
56}
57
58/// Trace a value through wires to its original definition.
59static Value lookThroughWires(Value value) {
60 while (value) {
61 if (auto wireOp = value.getDefiningOp<WireOp>()) {
62 value = wireOp.getInput();
63 continue;
64 }
65 break;
66 }
67 return value;
68}
69
70/// Determine the exact parametrization of the memory that should be generated
71/// for a given `FirMemOp`.
///
/// Fields come from the memory type (data width, depth, and mask width, which
/// defaults to 1 when the type has no mask), from the op's attributes
/// (latencies, read/write-under-write behavior, optional init file, output
/// file and name prefix), and from a scan of the op's users that counts the
/// read, write, and read-write port ops and records a clock ID for every
/// write or read-write port, in user order.
// NOTE(review): this listing is missing the signature line (declared in the
// header as `FirMemConfig collectMemory(seq::FirMemOp op)`) and, before the
// user loop below, the line declaring `clockValues`.
73 FirMemConfig cfg;
74 cfg.dataWidth = op.getType().getWidth();
75 cfg.depth = op.getType().getDepth();
76 cfg.readLatency = op.getReadLatency();
77 cfg.writeLatency = op.getWriteLatency();
78 cfg.maskBits = op.getType().getMaskWidth().value_or(1);
79 cfg.readUnderWrite = op.getRuw();
80 cfg.writeUnderWrite = op.getWuw();
81 if (auto init = op.getInitAttr()) {
82 cfg.initFilename = init.getFilename();
83 cfg.initIsBinary = init.getIsBinary();
84 cfg.initIsInline = init.getIsInline();
85 }
86 cfg.outputFile = op.getOutputFileAttr();
87 if (auto prefix = op.getPrefixAttr())
88 cfg.prefix = prefix.getValue();
89 // TODO: Handle modName (maybe not?)
90 // TODO: Handle groupID (maybe not?)
91
92 // Count the read, write, and read-write ports, and identify the clocks
93 // driving the write ports.
95 for (auto *user : op->getUsers()) {
96 if (isa<FirMemReadOp>(user))
97 ++cfg.numReadPorts;
98 else if (isa<FirMemWriteOp>(user))
99 ++cfg.numWritePorts;
100 else if (isa<FirMemReadWriteOp>(user))
101 ++cfg.numReadWritePorts;
102
103 // Assign IDs to the values used as clock. This allows later passes to
104 // easily detect which clocks are effectively driven by the same value.
// Operand 2 is used as the port's clock for both write and read-write port
// ops (assumed from the op definitions -- TODO confirm). Wires are looked
// through first, so clocks that only differ by intervening wires share an ID.
// The candidate ID is `clockValues.size()` evaluated before the insertion; an
// already-known clock keeps its existing ID (presumably `clockValues` is a
// map-like container whose `insert` returns `{iterator, inserted}` -- its
// declaration is not in this listing).
105 if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
106 auto clock = lookThroughWires(user->getOperand(2));
107 cfg.writeClockIDs.push_back(
108 clockValues.insert({clock, clockValues.size()}).first->second);
109 }
110 }
111
112 return cfg;
113}
114
// Returns a symbol reference to the generator schema shared by all generated
// memory modules, caching the op in `schemaOp` so the lookup happens once.
// An existing `hw::HWGeneratorSchemaOp` directly in the circuit body whose
// descriptor is "FIRRTL_Memory" is reused; otherwise one named "FIRRTLMem" is
// created at the start of the circuit body. Its field names mirror the
// attribute names `createMemoryModule` attaches to each generated module;
// keep the two lists in sync.
// NOTE(review): the signature line (`FlatSymbolRefAttr getOrCreateSchema()`
// per the header) is missing from this listing.
116 if (!schemaOp) {
117 // Create or re-use the generator schema.
118 for (auto op : circuit.getOps<hw::HWGeneratorSchemaOp>()) {
119 if (op.getDescriptor() == "FIRRTL_Memory") {
120 schemaOp = op;
121 break;
122 }
123 }
124 if (!schemaOp) {
125 auto builder = OpBuilder::atBlockBegin(circuit.getBody());
126 std::array<StringRef, 14> schemaFields = {
127 "depth", "numReadPorts",
128 "numWritePorts", "numReadWritePorts",
129 "readLatency", "writeLatency",
130 "width", "maskGran",
131 "readUnderWrite", "writeUnderWrite",
132 "writeClockIDs", "initFilename",
133 "initIsBinary", "initIsInline"};
134 schemaOp = builder.create<hw::HWGeneratorSchemaOp>(
135 circuit.getLoc(), "FIRRTLMem", "FIRRTL_Memory",
136 builder.getStrArrayAttr(schemaFields));
137 }
138 }
139 return FlatSymbolRefAttr::get(schemaOp);
140}
141
142/// Create the `HWModuleGeneratedOp` for a single memory parametrization.
///
/// `memOps` are all the `FirMemOp`s sharing the configuration `mem`; it must
/// be non-empty (its first element provides the location). The generated
/// module is inserted right before the HW module with the smallest
/// `moduleIndex` among those containing one of `memOps`, and refers to the
/// schema returned by `getOrCreateSchema`.
// NOTE(review): the line carrying the name and first parameter is missing from
// this listing; per the header it is
// `createMemoryModule(FirMemConfig &mem, ArrayRef<seq::FirMemOp> memOps)`.
143HWModuleGeneratedOp
145 ArrayRef<seq::FirMemOp> memOps) {
146 auto schemaSymRef = getOrCreateSchema();
147
148 // Identify the first module which uses the memory configuration.
149 // Insert the generated module before it.
150 HWModuleOp insertPt;
151 for (auto memOp : memOps) {
152 auto parent = memOp->getParentOfType<HWModuleOp>();
153 if (!insertPt || moduleIndex[parent] < moduleIndex[insertPt])
154 insertPt = parent;
155 }
156
157 OpBuilder builder(context);
158 builder.setInsertionPoint(insertPt);
159
160 // Pick a name for the memory. Honor the optional prefix and try to include
161 // the common part of the names of the memory instances that use this
162 // configuration. The resulting name is of the form:
163 //
164 // <prefix><commonName>_<depth>x<width>
165 //
// `mem.prefix` is prepended verbatim: no separator is inserted, so any `_`
// must already be part of the prefix. <commonName> is the longest common
// leading substring of the named instances, with trailing `_` trimmed, or
// "mem" when that is empty. The result is passed through
// `globalNamespace.newName`, which returns a unique name derived from it.
166 StringRef baseName = "";
167 bool firstFound = false;
168 for (auto memOp : memOps) {
169 if (auto memName = memOp.getName()) {
170 if (!firstFound) {
171 baseName = *memName;
172 firstFound = true;
173 continue;
174 }
175 unsigned idx = 0;
176 for (; idx < memName->size() && idx < baseName.size(); ++idx)
177 if ((*memName)[idx] != baseName[idx])
178 break;
179 baseName = baseName.take_front(idx);
180 }
181 }
182 baseName = baseName.rtrim('_');
183
184 SmallString<32> nameBuffer;
185 nameBuffer += mem.prefix;
186 if (!baseName.empty()) {
187 nameBuffer += baseName;
188 } else {
189 nameBuffer += "mem";
190 }
191 nameBuffer += "_";
192 (Twine(mem.depth) + "x" + Twine(mem.dataWidth)).toVector(nameBuffer);
193 auto name = builder.getStringAttr(globalNamespace.newName(nameBuffer));
194
195 LLVM_DEBUG(llvm::dbgs() << "Creating " << name << " for " << mem.depth
196 << " x " << mem.dataWidth << " memory\n");
197
// Port layout, in order: for each read port `R<i>_addr`, `R<i>_en`,
// `R<i>_clk` inputs and an `R<i>_data` output; for each read-write port
// `RW<i>_addr/_en/_clk/_wmode/_wdata` inputs, an `RW<i>_rdata` output and,
// when masked, an `RW<i>_wmask` input; for each write port
// `W<i>_addr/_en/_clk/_data` inputs and, when masked, a `W<i>_mask` input.
// Inputs and outputs are numbered independently (`inputIdx`/`outputIdx`).
// `lowerMemoriesInModule` builds instance operands in this same order.
198 bool withMask = mem.maskBits > 1;
199 SmallVector<hw::PortInfo> ports;
200
201 // Common types used for memory ports.
// Data is at least 1 bit wide; the address is ceil(log2(depth)) bits, but
// never narrower than 1 bit.
202 Type clkType = ClockType::get(context);
203 Type bitType = IntegerType::get(context, 1);
204 Type dataType = IntegerType::get(context, std::max((size_t)1, mem.dataWidth));
205 Type maskType = IntegerType::get(context, mem.maskBits);
206 Type addrType =
207 IntegerType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
208
209 // Helper to add an input port.
210 size_t inputIdx = 0;
211 auto addInput = [&](StringRef prefix, size_t idx, StringRef suffix,
212 Type type) {
213 ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
214 ModulePort::Direction::Input},
215 inputIdx++});
216 };
217
218 // Helper to add an output port.
219 size_t outputIdx = 0;
220 auto addOutput = [&](StringRef prefix, size_t idx, StringRef suffix,
221 Type type) {
222 ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
223 ModulePort::Direction::Output},
224 outputIdx++});
225 };
226
227 // Helper to add the ports common to read, read-write, and write ports.
228 auto addCommonPorts = [&](StringRef prefix, size_t idx) {
229 addInput(prefix, idx, "_addr", addrType);
230 addInput(prefix, idx, "_en", bitType);
231 addInput(prefix, idx, "_clk", clkType);
232 };
233
234 // Add the read ports.
235 for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
236 addCommonPorts("R", i);
237 addOutput("R", i, "_data", dataType);
238 }
239
240 // Add the read-write ports.
241 for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
242 addCommonPorts("RW", i);
243 addInput("RW", i, "_wmode", bitType);
244 addInput("RW", i, "_wdata", dataType);
245 addOutput("RW", i, "_rdata", dataType);
246 if (withMask)
247 addInput("RW", i, "_wmask", maskType);
248 }
249
250 // Add the write ports.
251 for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
252 addCommonPorts("W", i);
253 addInput("W", i, "_data", dataType);
254 if (withMask)
255 addInput("W", i, "_mask", maskType);
256 }
257
// Generator parameters attached to the module. The attribute names match the
// schema field list created in `getOrCreateSchema`.
258 // Mask granularity is the number of data bits that each mask bit can
259 // guard. By default it is equal to the data bitwidth.
260 auto genAttr = [&](StringRef name, Attribute attr) {
261 return builder.getNamedAttr(name, attr);
262 };
263 auto genAttrUI32 = [&](StringRef name, uint32_t value) {
264 return genAttr(name, builder.getUI32IntegerAttr(value));
265 };
266 NamedAttribute genAttrs[] = {
267 genAttr("depth", builder.getI64IntegerAttr(mem.depth)),
268 genAttrUI32("numReadPorts", mem.numReadPorts),
269 genAttrUI32("numWritePorts", mem.numWritePorts),
270 genAttrUI32("numReadWritePorts", mem.numReadWritePorts),
271 genAttrUI32("readLatency", mem.readLatency),
272 genAttrUI32("writeLatency", mem.writeLatency),
273 genAttrUI32("width", mem.dataWidth),
// Integer division; with a single mask bit this is the full data width.
274 genAttrUI32("maskGran", mem.dataWidth / mem.maskBits),
275 genAttr("readUnderWrite",
276 seq::RUWAttr::get(builder.getContext(), mem.readUnderWrite)),
277 genAttr("writeUnderWrite",
278 seq::WUWAttr::get(builder.getContext(), mem.writeUnderWrite)),
279 genAttr("writeClockIDs", builder.getI32ArrayAttr(mem.writeClockIDs)),
280 genAttr("initFilename", builder.getStringAttr(mem.initFilename)),
281 genAttr("initIsBinary", builder.getBoolAttr(mem.initIsBinary)),
282 genAttr("initIsInline", builder.getBoolAttr(mem.initIsInline))};
283
284 // Combine the locations of all actual `FirMemOp`s to be the location of the
285 // generated memory.
286 Location loc = FirMemOp(memOps.front()).getLoc();
287 if (memOps.size() > 1) {
288 SmallVector<Location> locs;
289 for (auto memOp : memOps)
290 locs.push_back(memOp.getLoc());
291 loc = FusedLoc::get(context, locs);
292 }
293
294 // Create the module.
295 auto genOp = builder.create<hw::HWModuleGeneratedOp>(
296 loc, schemaSymRef, name, ports, StringRef{}, ArrayAttr{}, genAttrs);
// Forward the memory's output file, if any, onto the generated module.
297 if (mem.outputFile)
298 genOp->setAttr("output_file", mem.outputFile);
299
300 return genOp;
301}
302
303/// Replace all `FirMemOp`s in an HW module with an instance of the
304/// corresponding generated module.
306 HWModuleOp module,
307 ArrayRef<std::tuple<FirMemConfig *, HWModuleGeneratedOp, FirMemOp>> mems) {
308 LLVM_DEBUG(llvm::dbgs() << "Lowering " << mems.size() << " memories in "
309 << module.getName() << "\n");
310
311 DenseMap<unsigned, Value> constOneOps;
312 auto constOne = [&](unsigned width = 1) {
313 auto it = constOneOps.try_emplace(width, Value{});
314 if (it.second) {
315 auto builder = OpBuilder::atBlockBegin(module.getBodyBlock());
316 it.first->second = builder.create<hw::ConstantOp>(
317 module.getLoc(), builder.getIntegerType(width), 1);
318 }
319 return it.first->second;
320 };
321 auto valueOrOne = [&](Value value, unsigned width = 1) {
322 return value ? value : constOne(width);
323 };
324
325 for (auto [config, genOp, memOp] : mems) {
326 LLVM_DEBUG(llvm::dbgs() << "- Lowering " << memOp.getName() << "\n");
327 SmallVector<Value> inputs;
328 SmallVector<Value> outputs;
329
330 auto addInput = [&](Value value) { inputs.push_back(value); };
331 auto addOutput = [&](Value value) { outputs.push_back(value); };
332
333 // Add the read ports.
334 for (auto *op : memOp->getUsers()) {
335 auto port = dyn_cast<FirMemReadOp>(op);
336 if (!port)
337 continue;
338 addInput(port.getAddress());
339 addInput(valueOrOne(port.getEnable()));
340 addInput(port.getClk());
341 addOutput(port.getData());
342 }
343
344 // Add the read-write ports.
345 for (auto *op : memOp->getUsers()) {
346 auto port = dyn_cast<FirMemReadWriteOp>(op);
347 if (!port)
348 continue;
349 addInput(port.getAddress());
350 addInput(valueOrOne(port.getEnable()));
351 addInput(port.getClk());
352 addInput(port.getMode());
353 addInput(port.getWriteData());
354 addOutput(port.getReadData());
355 if (config->maskBits > 1)
356 addInput(valueOrOne(port.getMask(), config->maskBits));
357 }
358
359 // Add the write ports.
360 for (auto *op : memOp->getUsers()) {
361 auto port = dyn_cast<FirMemWriteOp>(op);
362 if (!port)
363 continue;
364 addInput(port.getAddress());
365 addInput(valueOrOne(port.getEnable()));
366 addInput(port.getClk());
367 addInput(port.getData());
368 if (config->maskBits > 1)
369 addInput(valueOrOne(port.getMask(), config->maskBits));
370 }
371
372 // Create the module instance.
373 StringRef memName = "mem";
374 if (auto name = memOp.getName(); name && !name->empty())
375 memName = *name;
376 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
377 auto instOp = builder.create<hw::InstanceOp>(
378 genOp, builder.getStringAttr(memName + "_ext"), inputs, ArrayAttr{},
379 memOp.getInnerSymAttr());
380 for (auto [oldOutput, newOutput] : llvm::zip(outputs, instOp.getResults()))
381 oldOutput.replaceAllUsesWith(newOutput);
382
383 // Carry attributes over from the `FirMemOp` to the `InstanceOp`.
384 auto defaultAttrNames = memOp.getAttributeNames();
385 for (auto namedAttr : memOp->getAttrs())
386 if (!llvm::is_contained(defaultAttrNames, namedAttr.getName()))
387 instOp->setAttr(namedAttr.getName(), namedAttr.getValue());
388
389 // Get rid of the `FirMemOp`.
390 for (auto *user : llvm::make_early_inc_range(memOp->getUsers()))
391 user->erase();
392 memOp.erase();
393 }
394}
static Value lookThroughWires(Value value)
Trace a value through wires to its original definition.
static std::vector< mlir::Value > toVector(mlir::ValueRange range)
UniqueConfigs collectMemories(ArrayRef< hw::HWModuleOp > modules)
Groups memories by their kind from the whole design.
void lowerMemoriesInModule(hw::HWModuleOp module, ArrayRef< MemoryConfig > mems)
Lowers a group of memories from the same module.
hw::HWGeneratorSchemaOp schemaOp
FirMemLowering(ModuleOp circuit)
hw::HWModuleGeneratedOp createMemoryModule(FirMemConfig &mem, ArrayRef< seq::FirMemOp > memOps)
Creates the generated module for a given configuration.
FlatSymbolRefAttr getOrCreateSchema()
Find the schema or create it if it does not exist.
DenseMap< hw::HWModuleOp, size_t > moduleIndex
llvm::MapVector< FirMemConfig, SmallVector< seq::FirMemOp, 1 > > UniqueConfigs
A vector of unique FirMemConfigs and all the FirMemOps that use it.
FirMemConfig collectMemory(seq::FirMemOp op)
Determine the exact parametrization of the memory that should be generated for a given FirMemOp.
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
Definition seq.py:1
The configuration of a FIR memory.
SmallVector< int32_t, 1 > writeClockIDs