CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerMemory.cpp
Go to the documentation of this file.
1//===- LowerMemory.cpp - Lower Memories -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// This file defines the LowerMemories pass.
9//
10//===----------------------------------------------------------------------===//
11
21#include "mlir/IR/Dominance.h"
22#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/DepthFirstIterator.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/STLFunctionalExtras.h"
26#include "llvm/ADT/SmallPtrSet.h"
27#include "llvm/Support/Parallel.h"
28#include <optional>
29#include <set>
30
31namespace circt {
32namespace firrtl {
33#define GEN_PASS_DEF_LOWERMEMORY
34#include "circt/Dialect/FIRRTL/Passes.h.inc"
35} // namespace firrtl
36} // namespace circt
37
38using namespace circt;
39using namespace firrtl;
40
41// Extract all the relevant attributes from the MemOp and return the FirMemory.
// The returned summary bundles the port counts (read, write, read-write), one
// clock ID per observed write-port clock connection, and the memory's width,
// depth, latencies, mask bits, read-under-write mode, name, init and prefix
// attributes and location. The pass keys its `memories` dedup map on this
// summary, so two memories with equal summaries can share one memory module.
43 size_t numReadPorts = 0;
44 size_t numWritePorts = 0;
45 size_t numReadWritePorts = 0;
47 SmallVector<int32_t> writeClockIDs;
48
// Classify every port of the memory by kind and count it.
49 for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
50 auto portKind = op.getPortKind(i);
51 if (portKind == MemOp::PortKind::Read)
52 ++numReadPorts;
53 else if (portKind == MemOp::PortKind::Write) {
// For a write port, look at the subfield with field index 2 (the clock field
// of a write-port bundle) and at the connects that drive it. `clockToLeader`
// maps each driving value to the index of the first write port seen with that
// driver. A port whose clock driver was already seen reuses that leader's ID;
// otherwise the port's own index becomes the ID.
// NOTE(review): a write port whose clock is never connected pushes no ID, and
// one whose clock is connected several times pushes several, so
// writeClockIDs.size() need not equal numWritePorts. Confirm this is intended.
54 for (auto *a : op.getResult(i).getUsers()) {
55 auto subfield = dyn_cast<SubfieldOp>(a);
56 if (!subfield || subfield.getFieldIndex() != 2)
57 continue;
58 auto clockPort = a->getResult(0);
59 for (auto *b : clockPort.getUsers()) {
60 if (auto connect = dyn_cast<FConnectLike>(b)) {
61 if (connect.getDest() == clockPort) {
62 auto result =
63 clockToLeader.insert({connect.getSrc(), numWritePorts});
64 if (result.second) {
65 writeClockIDs.push_back(numWritePorts);
66 } else {
67 writeClockIDs.push_back(result.first->second);
68 }
69 }
70 }
71 }
// Only the first clock subfield of this port is examined.
72 break;
73 }
74 ++numWritePorts;
75 } else
76 ++numReadWritePorts;
77 }
78
// A non-positive width (getBitWidthOrSentinel presumably returns a negative
// sentinel for unknown widths; zero also lands here) is diagnosed, and the
// summary falls back to width 0.
// NOTE(review): this function still returns a summary after the error, so the
// caller is not stopped from lowering the memory.
79 auto width = op.getDataType().getBitWidthOrSentinel();
80 if (width <= 0) {
81 op.emitError("'firrtl.mem' should have simple type and known width");
82 width = 0;
83 }
// The read-under-write enum is converted to the seq dialect enum by numeric
// value, and the optional is dereferenced without a check. Write-under-write
// is always PortOrder. A memory counts as masked only when it has more than
// one mask bit.
84 return {numReadPorts,
85 numWritePorts,
86 numReadWritePorts,
87 (size_t)width,
88 op.getDepth(),
89 op.getReadLatency(),
90 op.getWriteLatency(),
91 op.getMaskBits(),
92 *seq::symbolizeRUW(unsigned(op.getRuw())),
93 seq::WUW::PortOrder,
94 writeClockIDs,
95 op.getNameAttr(),
96 op.getMaskBits() > 1,
97 op.getInitAttr(),
98 op.getPrefixAttr(),
99 op.getLoc()};
100}
101
102namespace {
103struct LowerMemoryPass
104 : public circt::firrtl::impl::LowerMemoryBase<LowerMemoryPass> {
105
106 /// Get the cached namespace for a module.
107 hw::InnerSymbolNamespace &getModuleNamespace(FModuleLike moduleOp) {
108 return moduleNamespaces.try_emplace(moduleOp, moduleOp).first->second;
109 }
110
111 SmallVector<PortInfo> getMemoryModulePorts(const FirMemory &mem);
112 FMemModuleOp emitMemoryModule(MemOp op, const FirMemory &summary,
113 const SmallVectorImpl<PortInfo> &ports);
114 FMemModuleOp getOrCreateMemModule(MemOp op, const FirMemory &summary,
115 const SmallVectorImpl<PortInfo> &ports,
116 bool shouldDedup);
117 FModuleOp createWrapperModule(MemOp op, const FirMemory &summary,
118 bool shouldDedup);
119 InstanceOp emitMemoryInstance(MemOp op, FModuleOp moduleOp,
120 const FirMemory &summary);
121 void lowerMemory(MemOp mem, const FirMemory &summary, bool shouldDedup);
122 LogicalResult runOnModule(FModuleOp moduleOp, bool shouldDedup);
123 void runOnOperation() override;
124
125 /// Cached module namespaces.
126 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
127 CircuitNamespace circuitNamespace;
128 SymbolTable *symbolTable;
129
130 /// The set of all memories seen so far. This is used to "deduplicate"
131 /// memories by emitting modules one module for equivalent memories.
132 std::map<FirMemory, FMemModuleOp> memories;
133
134 /// A sequence of operations that should be erased later.
135 SetVector<Operation *> operationsToErase;
136};
137} // end anonymous namespace
138
139SmallVector<PortInfo>
140LowerMemoryPass::getMemoryModulePorts(const FirMemory &mem) {
141 auto *context = &getContext();
142
143 // We don't need a single bit mask, it can be combined with enable. Create
144 // an unmasked memory if maskBits = 1.
145 FIRRTLType u1Type = UIntType::get(context, 1);
146 FIRRTLType dataType = UIntType::get(context, mem.dataWidth);
147 FIRRTLType maskType = UIntType::get(context, mem.maskBits);
148 FIRRTLType addrType =
149 UIntType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
150 FIRRTLType clockType = ClockType::get(context);
151 Location loc = UnknownLoc::get(context);
152 AnnotationSet annotations = AnnotationSet(context);
153
154 SmallVector<PortInfo> ports;
155 auto addPort = [&](const Twine &name, FIRRTLType type, Direction direction) {
156 auto nameAttr = StringAttr::get(context, name);
157 ports.push_back(
158 {nameAttr, type, direction, hw::InnerSymAttr{}, loc, annotations, {}});
159 };
160
161 auto makePortCommon = [&](StringRef prefix, size_t idx, FIRRTLType addrType) {
162 addPort(prefix + Twine(idx) + "_addr", addrType, Direction::In);
163 addPort(prefix + Twine(idx) + "_en", u1Type, Direction::In);
164 addPort(prefix + Twine(idx) + "_clk", clockType, Direction::In);
165 };
166
167 for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
168 makePortCommon("R", i, addrType);
169 addPort("R" + Twine(i) + "_data", dataType, Direction::Out);
170 }
171 for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
172 makePortCommon("RW", i, addrType);
173 addPort("RW" + Twine(i) + "_wmode", u1Type, Direction::In);
174 addPort("RW" + Twine(i) + "_wdata", dataType, Direction::In);
175 addPort("RW" + Twine(i) + "_rdata", dataType, Direction::Out);
176 // Ignore mask port, if maskBits =1
177 if (mem.isMasked)
178 addPort("RW" + Twine(i) + "_wmask", maskType, Direction::In);
179 }
180
181 for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
182 makePortCommon("W", i, addrType);
183 addPort("W" + Twine(i) + "_data", dataType, Direction::In);
184 // Ignore mask port, if maskBits =1
185 if (mem.isMasked)
186 addPort("W" + Twine(i) + "_mask", maskType, Direction::In);
187 }
188
189 return ports;
190}
191
// Create a new private `FMemModuleOp` for `op`, described by `mem` and using
// the given `ports`. It is inserted immediately before the FModuleOp that
// contains `op`, and `numCreatedMemModules` is incremented. `mem` is not
// modified.
192FMemModuleOp
193LowerMemoryPass::emitMemoryModule(MemOp op, const FirMemory &mem,
194 const SmallVectorImpl<PortInfo> &ports) {
195 // Get a non-colliding name for the memory module (the summary is not modified).
// The candidate name is the optional prefix followed by `mem.modName`. The
// second argument "ext" to newName is presumably the suffix used to
// disambiguate; TODO confirm against CircuitNamespace::newName.
196 StringRef prefix = "";
197 if (mem.prefix)
198 prefix = mem.prefix.getValue();
199 auto newName =
200 circuitNamespace.newName(prefix + mem.modName.getValue(), "ext");
201 auto moduleName = StringAttr::get(&getContext(), newName);
202
203 // Insert the memory module just above the current module.
204 OpBuilder b(op->getParentOfType<FModuleOp>());
205 ++numCreatedMemModules;
// The read-under-write mode is converted back to the FIRRTL RUWBehavior enum
// by numeric value. The optional returned by symbolizeRUWBehavior is
// dereferenced without a check.
206 auto moduleOp = FMemModuleOp::create(
207 b, mem.loc, moduleName, ports, mem.numReadPorts, mem.numWritePorts,
209 mem.writeLatency, mem.depth,
210 *symbolizeRUWBehavior(static_cast<uint32_t>(mem.readUnderWrite)));
211 SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
212 return moduleOp;
213}
214
215FMemModuleOp
216LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,
217 const SmallVectorImpl<PortInfo> &ports,
218 bool shouldDedup) {
219 // Try to find a matching memory blackbox that we already created. If
220 // shouldDedup is true, we will just generate a new memory module.
221 if (shouldDedup) {
222 auto it = memories.find(summary);
223 if (it != memories.end())
224 return it->second;
225 }
226
227 // Create a new module for this memory. This can update the name recorded in
228 // the memory's summary.
229 auto moduleOp = emitMemoryModule(op, summary, ports);
230
231 // Record the memory module. We don't want to use this module for other
232 // memories, then we don't add it to the table.
233 if (shouldDedup)
234 memories[summary] = moduleOp;
235
236 return moduleOp;
237}
238
239void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
240 bool shouldDedup) {
241 auto *context = &getContext();
242 auto ports = getMemoryModulePorts(summary);
243
244 // Get a non-colliding name for the memory module, and update the summary.
245 StringRef prefix = "";
246 if (summary.prefix)
247 prefix = summary.prefix.getValue();
248 auto newName = circuitNamespace.newName(prefix + mem.getName());
249
250 auto wrapperName = StringAttr::get(&getContext(), newName);
251
252 // Create the wrapper module, inserting it just before the current module.
253 OpBuilder b(mem->getParentOfType<FModuleOp>());
254 auto wrapper = FModuleOp::create(
255 b, mem->getLoc(), wrapperName,
256 ConventionAttr::get(context, Convention::Internal), ports);
257 SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
258
259 // Create an instance of the external memory module. The instance has the
260 // same name as the target module.
261 auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
262 b.setInsertionPointToStart(wrapper.getBodyBlock());
263 auto memInst = InstanceOp::create(
264 b, mem->getLoc(), memModule, (mem.getName() + "_ext").str(),
265 mem.getNameKind(), mem.getAnnotations().getValue());
266
267 // Wire all the ports together.
268 for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
269 memInst.getResults())) {
270 if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
271 MatchingConnectOp::create(b, mem->getLoc(), dst, src);
272 else
273 MatchingConnectOp::create(b, mem->getLoc(), src, dst);
274 }
275
276 // Create an instance of the wrapper memory module, which will replace the
277 // original mem op.
278 auto inst = emitMemoryInstance(mem, wrapper, summary);
279
280 // We fixup the annotations here. We will be copying all annotations on to the
281 // module op, so we have to fix up the NLA to have the module as the leaf
282 // element.
283
284 auto leafSym = memModule.getModuleNameAttr();
285 auto leafAttr = FlatSymbolRefAttr::get(wrapper.getModuleNameAttr());
286
287 // NLAs that we have already processed.
289 auto nonlocalAttr = StringAttr::get(context, "circt.nonlocal");
290 bool nlaUpdated = false;
291 SmallVector<Annotation> newMemModAnnos;
292 OpBuilder nlaBuilder(context);
293
294 AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
295 // We're only looking for non-local annotations.
296 auto nlaSym = anno.getMember<FlatSymbolRefAttr>(nonlocalAttr);
297 if (!nlaSym)
298 return false;
299 // If we have already seen this NLA, don't re-process it.
300 auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
301 StringAttr newNLAName;
302 if (newNLAIter == processedNLAs.end()) {
303
304 // Update the NLA path to have the additional wrapper module.
305 auto nla =
306 dyn_cast<hw::HierPathOp>(symbolTable->lookup(nlaSym.getAttr()));
307 auto namepath = nla.getNamepath().getValue();
308 SmallVector<Attribute> newNamepath(namepath.begin(), namepath.end());
309 if (!nla.isComponent())
310 newNamepath.back() =
311 getInnerRefTo(inst, [&](auto mod) -> hw::InnerSymbolNamespace & {
312 return getModuleNamespace(mod);
313 });
314 newNamepath.push_back(leafAttr);
315
316 nlaBuilder.setInsertionPointAfter(nla);
317 auto newNLA = cast<hw::HierPathOp>(nlaBuilder.clone(*nla));
318 newNLA.setSymNameAttr(StringAttr::get(
319 context, circuitNamespace.newName(nla.getNameAttr().getValue())));
320 newNLA.setNamepathAttr(ArrayAttr::get(context, newNamepath));
321 newNLAName = newNLA.getNameAttr();
322 processedNLAs[nlaSym.getAttr()] = newNLAName;
323 } else
324 newNLAName = newNLAIter->getSecond();
325 anno.setMember("circt.nonlocal", FlatSymbolRefAttr::get(newNLAName));
326 nlaUpdated = true;
327 newMemModAnnos.push_back(anno);
328 return true;
329 });
330 if (nlaUpdated) {
331 memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym));
332 AnnotationSet newAnnos(memInst);
333 newAnnos.addAnnotations(newMemModAnnos);
334 newAnnos.applyToOperation(memInst);
335 }
336 operationsToErase.insert(mem);
337 ++numLoweredMems;
338}
339
340static SmallVector<SubfieldOp> getAllFieldAccesses(Value structValue,
341 StringRef field) {
342 SmallVector<SubfieldOp> accesses;
343 for (auto *op : structValue.getUsers()) {
344 assert(isa<SubfieldOp>(op));
345 auto fieldAccess = cast<SubfieldOp>(op);
346 auto elemIndex =
347 fieldAccess.getInput().getType().base().getElementIndex(field);
348 if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
349 accesses.push_back(fieldAccess);
350 }
351 return accesses;
352}
353
354InstanceOp LowerMemoryPass::emitMemoryInstance(MemOp op, FModuleOp module,
355 const FirMemory &summary) {
356 OpBuilder builder(op);
357 auto *context = &getContext();
358 auto memName = op.getName();
359 if (memName.empty())
360 memName = "mem";
361
362 // Process each port in turn.
363 SmallVector<Type, 8> portTypes;
364 SmallVector<Direction> portDirections;
365 SmallVector<Attribute> portNames;
366 SmallVector<Attribute> domainInfo;
367 DenseMap<Operation *, size_t> returnHolder;
368 mlir::DominanceInfo domInfo(op->getParentOfType<FModuleOp>());
369
370 // The result values of the memory are not necessarily in the same order as
371 // the memory module that we're lowering to. We need to lower the read
372 // ports before the read/write ports, before the write ports.
373 for (unsigned memportKindIdx = 0; memportKindIdx != 3; ++memportKindIdx) {
374 MemOp::PortKind memportKind = MemOp::PortKind::Read;
375 auto *portLabel = "R";
376 switch (memportKindIdx) {
377 default:
378 break;
379 case 1:
380 memportKind = MemOp::PortKind::ReadWrite;
381 portLabel = "RW";
382 break;
383 case 2:
384 memportKind = MemOp::PortKind::Write;
385 portLabel = "W";
386 break;
387 }
388
389 // This is set to the count of the kind of memport we're emitting, for
390 // label names.
391 unsigned portNumber = 0;
392
393 // Get an unsigned type with the specified width.
394 auto getType = [&](size_t width) { return UIntType::get(context, width); };
395 auto ui1Type = getType(1);
396 auto addressType = getType(std::max(1U, llvm::Log2_64_Ceil(summary.depth)));
397 auto dataType = UIntType::get(context, summary.dataWidth);
398 auto clockType = ClockType::get(context);
399
400 // Memories return multiple structs, one for each port, which means we
401 // have two layers of type to split apart.
402 for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
403 // Process all of one kind before the next.
404 if (memportKind != op.getPortKind(i))
405 continue;
406
407 auto addPort = [&](Direction direction, StringRef field, Type portType) {
408 // Map subfields of the memory port to module ports.
409 auto accesses = getAllFieldAccesses(op.getResult(i), field);
410 for (auto a : accesses)
411 returnHolder[a] = portTypes.size();
412 // Record the new port information.
413 portTypes.push_back(portType);
414 portDirections.push_back(direction);
415 portNames.push_back(
416 builder.getStringAttr(portLabel + Twine(portNumber) + "_" + field));
417 domainInfo.push_back(builder.getArrayAttr({}));
418 };
419
420 auto getDriver = [&](StringRef field) -> Operation * {
421 auto accesses = getAllFieldAccesses(op.getResult(i), field);
422 for (auto a : accesses) {
423 for (auto *user : a->getUsers()) {
424 // If this is a connect driving a value to the field, return it.
425 if (auto connect = dyn_cast<FConnectLike>(user);
426 connect && connect.getDest() == a)
427 return connect;
428 }
429 }
430 return nullptr;
431 };
432
433 // Find the value connected to the enable and 'and' it with the mask,
434 // and then remove the mask entirely. This is used to remove the mask when
435 // it is 1 bit.
436 auto removeMask = [&](StringRef enable, StringRef mask) {
437 // Get the connect which drives a value to the mask element.
438 auto *maskConnect = getDriver(mask);
439 if (!maskConnect)
440 return;
441 // Get the connect which drives a value to the en element
442 auto *enConnect = getDriver(enable);
443 if (!enConnect)
444 return;
445 // Find the proper place to create the And operation. The mask and en
446 // signals must both dominate the new operation.
447 OpBuilder b(maskConnect);
448 if (domInfo.dominates(maskConnect, enConnect))
449 b.setInsertionPoint(enConnect);
450 // 'and' the enable and mask signals together and use it as the enable.
451 auto andOp =
452 AndPrimOp::create(b, op->getLoc(), maskConnect->getOperand(1),
453 enConnect->getOperand(1));
454 enConnect->setOperand(1, andOp);
455 enConnect->moveAfter(andOp);
456 // Erase the old mask connect.
457 auto *maskField = maskConnect->getOperand(0).getDefiningOp();
458 operationsToErase.insert(maskConnect);
459 operationsToErase.insert(maskField);
460 };
461
462 if (memportKind == MemOp::PortKind::Read) {
463 addPort(Direction::In, "addr", addressType);
464 addPort(Direction::In, "en", ui1Type);
465 addPort(Direction::In, "clk", clockType);
466 addPort(Direction::Out, "data", dataType);
467 } else if (memportKind == MemOp::PortKind::ReadWrite) {
468 addPort(Direction::In, "addr", addressType);
469 addPort(Direction::In, "en", ui1Type);
470 addPort(Direction::In, "clk", clockType);
471 addPort(Direction::In, "wmode", ui1Type);
472 addPort(Direction::In, "wdata", dataType);
473 addPort(Direction::Out, "rdata", dataType);
474 // Ignore mask port, if maskBits =1
475 if (summary.isMasked)
476 addPort(Direction::In, "wmask", getType(summary.maskBits));
477 else
478 removeMask("wmode", "wmask");
479 } else {
480 addPort(Direction::In, "addr", addressType);
481 addPort(Direction::In, "en", ui1Type);
482 addPort(Direction::In, "clk", clockType);
483 addPort(Direction::In, "data", dataType);
484 // Ignore mask port, if maskBits == 1
485 if (summary.isMasked)
486 addPort(Direction::In, "mask", getType(summary.maskBits));
487 else
488 removeMask("en", "mask");
489 }
490
491 ++portNumber;
492 }
493 }
494
495 // Create the instance to replace the memop. The instance name matches the
496 // name of the original memory module before deduplication.
497 // TODO: how do we lower port annotations?
498 auto inst = InstanceOp::create(
499 builder, op.getLoc(), portTypes, module.getNameAttr(),
500 summary.getFirMemoryName(), op.getNameKind(), portDirections, portNames,
501 domainInfo,
502 /*annotations=*/ArrayRef<Attribute>(),
503 /*portAnnotations=*/ArrayRef<Attribute>(),
504 /*layers=*/ArrayRef<Attribute>(), /*lowerToBind=*/false,
505 /*doNotPrint=*/false, op.getInnerSymAttr());
506
507 // Update all users of the result of read ports
508 for (auto [subfield, result] : returnHolder) {
509 subfield->getResult(0).replaceAllUsesWith(inst.getResult(result));
510 operationsToErase.insert(subfield);
511 }
512
513 return inst;
514}
515
516LogicalResult LowerMemoryPass::runOnModule(FModuleOp moduleOp,
517 bool shouldDedup) {
518 assert(operationsToErase.empty() && "operationsToErase must be empty");
519
520 auto result = moduleOp.walk([&](MemOp op) {
521 // Check that the memory has been properly lowered already.
522 if (!type_isa<UIntType>(op.getDataType())) {
523 op->emitError("memories should be flattened before running LowerMemory");
524 return WalkResult::interrupt();
525 }
526
527 auto summary = getSummary(op);
528 if (summary.isSeqMem())
529 lowerMemory(op, summary, shouldDedup);
530
531 return WalkResult::advance();
532 });
533
534 if (result.wasInterrupted())
535 return failure();
536
537 for (Operation *op : operationsToErase)
538 op->erase();
539
540 operationsToErase.clear();
541
542 return success();
543}
544
545void LowerMemoryPass::runOnOperation() {
546 auto circuit = getOperation();
547 auto &instanceInfo = getAnalysis<InstanceInfo>();
548 symbolTable = &getAnalysis<SymbolTable>();
549 circuitNamespace.add(circuit);
550
551 // We iterate the circuit from top-to-bottom. This ensures that we get
552 // consistent memory names. (Memory modules will be inserted before the
553 // module we are processing to prevent these being unnecessarily visited.)
554 // Deduplication of memories is allowed if the module is under the "effective"
555 // design-under-test (DUT).
556 for (auto moduleOp : circuit.getBodyBlock()->getOps<FModuleOp>()) {
557 auto shouldDedup = instanceInfo.anyInstanceInEffectiveDesign(moduleOp);
558 if (failed(runOnModule(moduleOp, shouldDedup)))
559 return signalPassFailure();
560 }
561
562 circuitNamespace.clear();
563 symbolTable = nullptr;
564 memories.clear();
565}
assert(baseType &&"element must be base type")
FirMemory getSummary(MemOp op)
static SmallVector< SubfieldOp > getAllFieldAccesses(Value structValue, StringRef field)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
connect(destination, source)
Definition support.py:39
Direction
This represents the direction of a single port.
hw::InnerRefAttr getInnerRefTo(const hw::InnerSymTarget &target, GetNamespaceCallback getNamespace)
Obtain an inner reference to the target (operation or port), adding an inner symbol as necessary.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
The namespace of a CircuitOp, generally inhabited by modules.
Definition Namespace.h:24
bool isSeqMem() const
Check whether the memory is a seq mem.
Definition FIRRTLOps.h:214
StringAttr getFirMemoryName() const