21#include "mlir/IR/Dominance.h"
22#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/DepthFirstIterator.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/STLFunctionalExtras.h"
26#include "llvm/ADT/SmallPtrSet.h"
27#include "llvm/Support/Parallel.h"
33#define GEN_PASS_DEF_LOWERMEMORY
34#include "circt/Dialect/FIRRTL/Passes.h.inc"
39using namespace firrtl;
43 size_t numReadPorts = 0;
44 size_t numWritePorts = 0;
45 size_t numReadWritePorts = 0;
47 SmallVector<int32_t> writeClockIDs;
49 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
50 auto portKind = op.getPortKind(i);
51 if (portKind == MemOp::PortKind::Read)
53 else if (portKind == MemOp::PortKind::Write) {
54 for (
auto *a : op.getResult(i).getUsers()) {
55 auto subfield = dyn_cast<SubfieldOp>(a);
56 if (!subfield || subfield.getFieldIndex() != 2)
58 auto clockPort = a->getResult(0);
59 for (
auto *b : clockPort.getUsers()) {
60 if (
auto connect = dyn_cast<FConnectLike>(b)) {
61 if (connect.getDest() == clockPort) {
63 clockToLeader.insert({connect.getSrc(), numWritePorts});
65 writeClockIDs.push_back(numWritePorts);
67 writeClockIDs.push_back(result.first->second);
79 auto width = op.getDataType().getBitWidthOrSentinel();
81 op.emitError(
"'firrtl.mem' should have simple type and known width");
92 *seq::symbolizeRUW(
unsigned(op.getRuw())),
103struct LowerMemoryPass
104 :
public circt::firrtl::impl::LowerMemoryBase<LowerMemoryPass> {
108 return moduleNamespaces.try_emplace(moduleOp, moduleOp).first->second;
111 SmallVector<PortInfo> getMemoryModulePorts(
const FirMemory &mem);
112 FMemModuleOp emitMemoryModule(MemOp op,
const FirMemory &summary,
113 const SmallVectorImpl<PortInfo> &ports);
114 FMemModuleOp getOrCreateMemModule(MemOp op,
const FirMemory &summary,
115 const SmallVectorImpl<PortInfo> &ports,
117 FModuleOp createWrapperModule(MemOp op,
const FirMemory &summary,
119 InstanceOp emitMemoryInstance(MemOp op, FModuleOp moduleOp,
121 void lowerMemory(MemOp mem,
const FirMemory &summary,
bool shouldDedup);
122 LogicalResult runOnModule(FModuleOp moduleOp,
bool shouldDedup);
123 void runOnOperation()
override;
126 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
128 SymbolTable *symbolTable;
132 std::map<FirMemory, FMemModuleOp> memories;
135 SetVector<Operation *> operationsToErase;
140LowerMemoryPass::getMemoryModulePorts(
const FirMemory &mem) {
141 auto *context = &getContext();
145 FIRRTLType u1Type = UIntType::get(context, 1);
149 UIntType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.
depth)));
150 FIRRTLType clockType = ClockType::get(context);
151 Location loc = UnknownLoc::get(context);
154 SmallVector<PortInfo> ports;
156 auto nameAttr = StringAttr::get(context, name);
158 {nameAttr, type, direction, hw::InnerSymAttr{}, loc, annotations, {}});
161 auto makePortCommon = [&](StringRef prefix,
size_t idx,
FIRRTLType addrType) {
162 addPort(prefix + Twine(idx) +
"_addr", addrType, Direction::In);
163 addPort(prefix + Twine(idx) +
"_en", u1Type, Direction::In);
164 addPort(prefix + Twine(idx) +
"_clk", clockType, Direction::In);
168 makePortCommon(
"R", i, addrType);
169 addPort(
"R" + Twine(i) +
"_data", dataType, Direction::Out);
172 makePortCommon(
"RW", i, addrType);
173 addPort(
"RW" + Twine(i) +
"_wmode", u1Type, Direction::In);
174 addPort(
"RW" + Twine(i) +
"_wdata", dataType, Direction::In);
175 addPort(
"RW" + Twine(i) +
"_rdata", dataType, Direction::Out);
178 addPort(
"RW" + Twine(i) +
"_wmask", maskType, Direction::In);
182 makePortCommon(
"W", i, addrType);
183 addPort(
"W" + Twine(i) +
"_data", dataType, Direction::In);
186 addPort(
"W" + Twine(i) +
"_mask", maskType, Direction::In);
193LowerMemoryPass::emitMemoryModule(MemOp op,
const FirMemory &mem,
194 const SmallVectorImpl<PortInfo> &ports) {
196 StringRef prefix =
"";
198 prefix = mem.
prefix.getValue();
200 circuitNamespace.newName(prefix + mem.
modName.getValue(),
"ext");
201 auto moduleName = StringAttr::get(&getContext(), newName);
204 OpBuilder b(op->getParentOfType<FModuleOp>());
205 ++numCreatedMemModules;
206 auto moduleOp = FMemModuleOp::create(
210 *symbolizeRUWBehavior(
static_cast<uint32_t
>(mem.
readUnderWrite)));
211 SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
216LowerMemoryPass::getOrCreateMemModule(MemOp op,
const FirMemory &summary,
217 const SmallVectorImpl<PortInfo> &ports,
222 auto it = memories.find(summary);
223 if (it != memories.end())
229 auto moduleOp = emitMemoryModule(op, summary, ports);
234 memories[summary] = moduleOp;
239void LowerMemoryPass::lowerMemory(MemOp mem,
const FirMemory &summary,
241 auto *context = &getContext();
242 auto ports = getMemoryModulePorts(summary);
245 StringRef prefix =
"";
247 prefix = summary.
prefix.getValue();
248 auto newName = circuitNamespace.newName(prefix + mem.getName());
250 auto wrapperName = StringAttr::get(&getContext(), newName);
253 OpBuilder b(mem->getParentOfType<FModuleOp>());
254 auto wrapper = FModuleOp::create(
255 b, mem->getLoc(), wrapperName,
256 ConventionAttr::get(context, Convention::Internal), ports);
257 SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
261 auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
262 b.setInsertionPointToStart(wrapper.getBodyBlock());
263 auto memInst = InstanceOp::create(
264 b, mem->getLoc(), memModule, (mem.getName() +
"_ext").str(),
265 mem.getNameKind(), mem.getAnnotations().getValue());
269 memInst.getResults())) {
270 if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
271 MatchingConnectOp::create(b, mem->getLoc(), dst, src);
273 MatchingConnectOp::create(b, mem->getLoc(), src, dst);
278 auto inst = emitMemoryInstance(mem, wrapper, summary);
284 auto leafSym = memModule.getModuleNameAttr();
285 auto leafAttr = FlatSymbolRefAttr::get(wrapper.getModuleNameAttr());
289 auto nonlocalAttr = StringAttr::get(context,
"circt.nonlocal");
290 bool nlaUpdated =
false;
291 SmallVector<Annotation> newMemModAnnos;
292 OpBuilder nlaBuilder(context);
296 auto nlaSym = anno.
getMember<FlatSymbolRefAttr>(nonlocalAttr);
300 auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
301 StringAttr newNLAName;
302 if (newNLAIter == processedNLAs.end()) {
306 dyn_cast<hw::HierPathOp>(symbolTable->lookup(nlaSym.getAttr()));
307 auto namepath = nla.getNamepath().getValue();
308 SmallVector<Attribute> newNamepath(namepath.begin(), namepath.end());
309 if (!nla.isComponent())
312 return getModuleNamespace(mod);
314 newNamepath.push_back(leafAttr);
316 nlaBuilder.setInsertionPointAfter(nla);
317 auto newNLA = cast<hw::HierPathOp>(nlaBuilder.clone(*nla));
318 newNLA.setSymNameAttr(StringAttr::get(
319 context, circuitNamespace.newName(nla.getNameAttr().getValue())));
320 newNLA.setNamepathAttr(ArrayAttr::get(context, newNamepath));
321 newNLAName = newNLA.getNameAttr();
322 processedNLAs[nlaSym.getAttr()] = newNLAName;
324 newNLAName = newNLAIter->getSecond();
325 anno.
setMember(
"circt.nonlocal", FlatSymbolRefAttr::get(newNLAName));
327 newMemModAnnos.push_back(anno);
331 memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym));
333 newAnnos.addAnnotations(newMemModAnnos);
334 newAnnos.applyToOperation(memInst);
336 operationsToErase.insert(mem);
342 SmallVector<SubfieldOp> accesses;
343 for (
auto *op : structValue.getUsers()) {
344 assert(isa<SubfieldOp>(op));
345 auto fieldAccess = cast<SubfieldOp>(op);
347 fieldAccess.getInput().getType().base().getElementIndex(field);
348 if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
349 accesses.push_back(fieldAccess);
354InstanceOp LowerMemoryPass::emitMemoryInstance(MemOp op, FModuleOp module,
356 OpBuilder builder(op);
357 auto *context = &getContext();
358 auto memName = op.getName();
363 SmallVector<Type, 8> portTypes;
364 SmallVector<Direction> portDirections;
365 SmallVector<Attribute> portNames;
366 SmallVector<Attribute> domainInfo;
367 DenseMap<Operation *, size_t> returnHolder;
368 mlir::DominanceInfo domInfo(op->getParentOfType<FModuleOp>());
373 for (
unsigned memportKindIdx = 0; memportKindIdx != 3; ++memportKindIdx) {
374 MemOp::PortKind memportKind = MemOp::PortKind::Read;
375 auto *portLabel =
"R";
376 switch (memportKindIdx) {
380 memportKind = MemOp::PortKind::ReadWrite;
384 memportKind = MemOp::PortKind::Write;
391 unsigned portNumber = 0;
394 auto getType = [&](
size_t width) {
return UIntType::get(context, width); };
395 auto ui1Type = getType(1);
396 auto addressType = getType(std::max(1U, llvm::Log2_64_Ceil(summary.
depth)));
397 auto dataType = UIntType::get(context, summary.
dataWidth);
398 auto clockType = ClockType::get(context);
402 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
404 if (memportKind != op.getPortKind(i))
407 auto addPort = [&](
Direction direction, StringRef field, Type portType) {
410 for (
auto a : accesses)
411 returnHolder[a] = portTypes.size();
413 portTypes.push_back(portType);
414 portDirections.push_back(direction);
416 builder.getStringAttr(portLabel + Twine(portNumber) +
"_" + field));
417 domainInfo.push_back(builder.getArrayAttr({}));
420 auto getDriver = [&](StringRef field) -> Operation * {
422 for (
auto a : accesses) {
423 for (
auto *user : a->getUsers()) {
425 if (
auto connect = dyn_cast<FConnectLike>(user);
436 auto removeMask = [&](StringRef enable, StringRef
mask) {
438 auto *maskConnect = getDriver(
mask);
442 auto *enConnect = getDriver(enable);
447 OpBuilder b(maskConnect);
448 if (domInfo.dominates(maskConnect, enConnect))
449 b.setInsertionPoint(enConnect);
452 AndPrimOp::create(b, op->getLoc(), maskConnect->getOperand(1),
453 enConnect->getOperand(1));
454 enConnect->setOperand(1, andOp);
455 enConnect->moveAfter(andOp);
457 auto *maskField = maskConnect->getOperand(0).getDefiningOp();
458 operationsToErase.insert(maskConnect);
459 operationsToErase.insert(maskField);
462 if (memportKind == MemOp::PortKind::Read) {
463 addPort(Direction::In,
"addr", addressType);
464 addPort(Direction::In,
"en", ui1Type);
465 addPort(Direction::In,
"clk", clockType);
466 addPort(Direction::Out,
"data", dataType);
467 }
else if (memportKind == MemOp::PortKind::ReadWrite) {
468 addPort(Direction::In,
"addr", addressType);
469 addPort(Direction::In,
"en", ui1Type);
470 addPort(Direction::In,
"clk", clockType);
471 addPort(Direction::In,
"wmode", ui1Type);
472 addPort(Direction::In,
"wdata", dataType);
473 addPort(Direction::Out,
"rdata", dataType);
476 addPort(Direction::In,
"wmask", getType(summary.
maskBits));
478 removeMask(
"wmode",
"wmask");
480 addPort(Direction::In,
"addr", addressType);
481 addPort(Direction::In,
"en", ui1Type);
482 addPort(Direction::In,
"clk", clockType);
483 addPort(Direction::In,
"data", dataType);
486 addPort(Direction::In,
"mask", getType(summary.
maskBits));
488 removeMask(
"en",
"mask");
498 auto inst = InstanceOp::create(
499 builder, op.getLoc(), portTypes, module.getNameAttr(),
502 ArrayRef<Attribute>(),
503 ArrayRef<Attribute>(),
504 ArrayRef<Attribute>(),
false,
505 false, op.getInnerSymAttr());
508 for (
auto [subfield, result] : returnHolder) {
509 subfield->getResult(0).replaceAllUsesWith(inst.getResult(result));
510 operationsToErase.insert(subfield);
516LogicalResult LowerMemoryPass::runOnModule(FModuleOp moduleOp,
518 assert(operationsToErase.empty() &&
"operationsToErase must be empty");
520 auto result = moduleOp.walk([&](MemOp op) {
522 if (!type_isa<UIntType>(op.getDataType())) {
523 op->emitError(
"memories should be flattened before running LowerMemory");
524 return WalkResult::interrupt();
529 lowerMemory(op, summary, shouldDedup);
531 return WalkResult::advance();
534 if (result.wasInterrupted())
537 for (Operation *op : operationsToErase)
540 operationsToErase.clear();
545void LowerMemoryPass::runOnOperation() {
546 auto circuit = getOperation();
547 auto &instanceInfo = getAnalysis<InstanceInfo>();
548 symbolTable = &getAnalysis<SymbolTable>();
549 circuitNamespace.add(circuit);
556 for (
auto moduleOp : circuit.
getBodyBlock()->getOps<FModuleOp>()) {
557 auto shouldDedup = instanceInfo.anyInstanceInEffectiveDesign(moduleOp);
558 if (failed(runOnModule(moduleOp, shouldDedup)))
559 return signalPassFailure();
562 circuitNamespace.clear();
563 symbolTable =
nullptr;
assert(baseType &&"element must be base type")
FirMemory getSummary(MemOp op)
static SmallVector< SubfieldOp > getAllFieldAccesses(Value structValue, StringRef field)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
connect(destination, source)
Direction
This represents the direction of a single port.
hw::InnerRefAttr getInnerRefTo(const hw::InnerSymTarget &target, GetNamespaceCallback getNamespace)
Obtain an inner reference to the target (operation or port), adding an inner symbol as necessary.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
The namespace of a CircuitOp, generally inhabited by modules.
bool isSeqMem() const
Check whether the memory is a seq mem.
StringAttr getFirMemoryName() const