22 #include "mlir/IR/Dominance.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/STLFunctionalExtras.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/Support/Parallel.h"
31 using namespace circt;
32 using namespace firrtl;
36 size_t numReadPorts = 0;
37 size_t numWritePorts = 0;
38 size_t numReadWritePorts = 0;
40 SmallVector<int32_t> writeClockIDs;
42 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
43 auto portKind = op.getPortKind(i);
44 if (portKind == MemOp::PortKind::Read)
46 else if (portKind == MemOp::PortKind::Write) {
47 for (
auto *a : op.getResult(i).getUsers()) {
48 auto subfield = dyn_cast<SubfieldOp>(a);
49 if (!subfield || subfield.getFieldIndex() != 2)
51 auto clockPort = a->getResult(0);
52 for (
auto *b : clockPort.getUsers()) {
53 if (
auto connect = dyn_cast<FConnectLike>(b)) {
54 if (
connect.getDest() == clockPort) {
56 clockToLeader.insert({
connect.getSrc(), numWritePorts});
58 writeClockIDs.push_back(numWritePorts);
60 writeClockIDs.push_back(result.first->second);
72 auto width = op.getDataType().getBitWidthOrSentinel();
74 op.emitError(
"'firrtl.mem' should have simple type and known width");
85 *seq::symbolizeRUW(
unsigned(op.getRuw())),
96 struct LowerMemoryPass :
public LowerMemoryBase<LowerMemoryPass> {
99 hw::InnerSymbolNamespace &getModuleNamespace(FModuleLike module) {
100 return moduleNamespaces.try_emplace(module, module).first->second;
103 SmallVector<PortInfo> getMemoryModulePorts(
const FirMemory &mem);
104 FMemModuleOp emitMemoryModule(MemOp op,
const FirMemory &summary,
105 const SmallVectorImpl<PortInfo> &ports);
106 FMemModuleOp getOrCreateMemModule(MemOp op,
const FirMemory &summary,
107 const SmallVectorImpl<PortInfo> &ports,
109 FModuleOp createWrapperModule(MemOp op,
const FirMemory &summary,
111 InstanceOp emitMemoryInstance(MemOp op, FModuleOp module,
113 void lowerMemory(MemOp mem,
const FirMemory &summary,
bool shouldDedup);
114 LogicalResult runOnModule(FModuleOp module,
bool shouldDedup);
115 void runOnOperation()
override;
118 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
120 SymbolTable *symbolTable;
124 std::map<FirMemory, FMemModuleOp> memories;
128 SmallVector<PortInfo>
129 LowerMemoryPass::getMemoryModulePorts(
const FirMemory &mem) {
130 auto *context = &getContext();
143 SmallVector<PortInfo> ports;
147 {nameAttr, type, direction, hw::InnerSymAttr{}, loc, annotations});
150 auto makePortCommon = [&](StringRef prefix,
size_t idx,
FIRRTLType addrType) {
151 addPort(prefix + Twine(idx) +
"_addr", addrType,
Direction::In);
153 addPort(prefix + Twine(idx) +
"_clk", clockType,
Direction::In);
157 makePortCommon(
"R", i, addrType);
161 makePortCommon(
"RW", i, addrType);
163 addPort(
"RW" + Twine(i) +
"_wdata", dataType,
Direction::In);
167 addPort(
"RW" + Twine(i) +
"_wmask", maskType,
Direction::In);
171 makePortCommon(
"W", i, addrType);
182 LowerMemoryPass::emitMemoryModule(MemOp op,
const FirMemory &mem,
183 const SmallVectorImpl<PortInfo> &ports) {
185 auto newName = circuitNamespace.newName(mem.
modName.getValue(),
"ext");
189 auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
190 ++numCreatedMemModules;
191 auto moduleOp = b.create<FMemModuleOp>(
195 SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
200 LowerMemoryPass::getOrCreateMemModule(MemOp op,
const FirMemory &summary,
201 const SmallVectorImpl<PortInfo> &ports,
206 auto it = memories.find(summary);
207 if (it != memories.end())
213 auto module = emitMemoryModule(op, summary, ports);
218 memories[summary] = module;
223 void LowerMemoryPass::lowerMemory(MemOp mem,
const FirMemory &summary,
225 auto *context = &getContext();
226 auto ports = getMemoryModulePorts(summary);
229 auto newName = circuitNamespace.newName(mem.getName());
233 auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
234 auto wrapper = b.create<FModuleOp>(
235 mem->getLoc(), wrapperName,
237 SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
241 auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
242 b.setInsertionPointToStart(wrapper.getBodyBlock());
245 b.create<InstanceOp>(mem->getLoc(), memModule, memModule.getModuleName(),
246 mem.getNameKind(), mem.getAnnotations().getValue());
249 for (
auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
250 memInst.getResults())) {
251 if (wrapper.getPortDirection(dst.getArgNumber()) ==
Direction::Out)
252 b.create<StrictConnectOp>(mem->getLoc(), dst, src);
254 b.create<StrictConnectOp>(mem->getLoc(), src, dst);
259 auto inst = emitMemoryInstance(mem, wrapper, summary);
265 auto leafSym = memModule.getModuleNameAttr();
271 bool nlaUpdated =
false;
272 SmallVector<Annotation> newMemModAnnos;
273 OpBuilder nlaBuilder(context);
277 auto nlaSym = anno.
getMember<FlatSymbolRefAttr>(nonlocalAttr);
281 auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
282 StringAttr newNLAName;
283 if (newNLAIter == processedNLAs.end()) {
287 dyn_cast<hw::HierPathOp>(symbolTable->lookup(nlaSym.getAttr()));
288 auto namepath = nla.getNamepath().getValue();
289 SmallVector<Attribute> newNamepath(namepath.begin(), namepath.end());
290 if (!nla.isComponent())
292 getInnerRefTo(inst, [&](
auto mod) -> hw::InnerSymbolNamespace & {
293 return getModuleNamespace(mod);
295 newNamepath.push_back(leafAttr);
297 nlaBuilder.setInsertionPointAfter(nla);
298 auto newNLA = cast<hw::HierPathOp>(nlaBuilder.clone(*nla));
300 context, circuitNamespace.newName(nla.getNameAttr().getValue())));
302 newNLAName = newNLA.getNameAttr();
303 processedNLAs[nlaSym.getAttr()] = newNLAName;
305 newNLAName = newNLAIter->getSecond();
308 newMemModAnnos.push_back(anno);
314 newAnnos.addAnnotations(newMemModAnnos);
315 newAnnos.applyToOperation(memInst);
323 SmallVector<SubfieldOp> accesses;
324 for (
auto *op : structValue.getUsers()) {
325 assert(isa<SubfieldOp>(op));
326 auto fieldAccess = cast<SubfieldOp>(op);
328 fieldAccess.getInput().getType().get().getElementIndex(field);
329 if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
330 accesses.push_back(fieldAccess);
335 InstanceOp LowerMemoryPass::emitMemoryInstance(MemOp op, FModuleOp module,
338 auto *context = &getContext();
339 auto memName = op.getName();
344 SmallVector<Type, 8> portTypes;
345 SmallVector<Direction> portDirections;
346 SmallVector<Attribute> portNames;
347 DenseMap<Operation *, size_t> returnHolder;
348 mlir::DominanceInfo domInfo(op->getParentOfType<FModuleOp>());
353 for (
unsigned memportKindIdx = 0; memportKindIdx != 3; ++memportKindIdx) {
354 MemOp::PortKind memportKind = MemOp::PortKind::Read;
355 auto *portLabel =
"R";
356 switch (memportKindIdx) {
360 memportKind = MemOp::PortKind::ReadWrite;
364 memportKind = MemOp::PortKind::Write;
371 unsigned portNumber = 0;
375 auto ui1Type = getType(1);
376 auto addressType = getType(std::max(1U, llvm::Log2_64_Ceil(summary.
depth)));
382 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
384 if (memportKind != op.getPortKind(i))
387 auto addPort = [&](
Direction direction, StringRef field, Type portType) {
390 for (
auto a : accesses)
391 returnHolder[a] = portTypes.size();
393 portTypes.push_back(portType);
394 portDirections.push_back(direction);
396 builder.getStringAttr(portLabel + Twine(portNumber) +
"_" + field));
399 auto getDriver = [&](StringRef field) -> Operation * {
401 for (
auto a : accesses) {
402 for (
auto *user : a->getUsers()) {
404 if (
auto connect = dyn_cast<FConnectLike>(user);
415 auto removeMask = [&](StringRef enable, StringRef
mask) {
417 auto *maskConnect = getDriver(mask);
421 auto *enConnect = getDriver(enable);
426 OpBuilder b(maskConnect);
427 if (domInfo.dominates(maskConnect, enConnect))
428 b.setInsertionPoint(enConnect);
430 auto andOp = b.create<AndPrimOp>(
431 op->getLoc(), maskConnect->getOperand(1), enConnect->getOperand(1));
432 enConnect->setOperand(1, andOp);
433 enConnect->moveAfter(andOp);
435 auto *maskField = maskConnect->getOperand(0).getDefiningOp();
436 maskConnect->erase();
440 if (memportKind == MemOp::PortKind::Read) {
445 }
else if (memportKind == MemOp::PortKind::ReadWrite) {
456 removeMask(
"wmode",
"wmask");
466 removeMask(
"en",
"mask");
476 auto inst =
builder.create<InstanceOp>(
478 op.getNameKind(), portDirections, portNames,
479 ArrayRef<Attribute>(),
480 ArrayRef<Attribute>(),
false,
481 op.getInnerSymAttr());
484 for (
auto [subfield, result] : returnHolder) {
485 subfield->getResult(0).replaceAllUsesWith(inst.getResult(result));
492 LogicalResult LowerMemoryPass::runOnModule(FModuleOp module,
bool shouldDedup) {
494 llvm::make_early_inc_range(module.getBodyBlock()->getOps<MemOp>())) {
496 if (!type_isa<UIntType>(op.getDataType()))
497 return op->emitError(
498 "memories should be flattened before running LowerMemory");
504 lowerMemory(op, summary, shouldDedup);
509 void LowerMemoryPass::runOnOperation() {
510 auto circuit = getOperation();
511 auto *body = circuit.getBodyBlock();
512 auto &instanceGraph = getAnalysis<InstanceGraph>();
513 symbolTable = &getAnalysis<SymbolTable>();
514 circuitNamespace.add(circuit);
518 auto *dut = instanceGraph.getTopLevelNode();
519 auto it = llvm::find_if(*body, [&](Operation &op) ->
bool {
522 if (it != body->end())
523 dut = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
526 DenseSet<Operation *> dutModuleSet;
533 for (
auto module : body->getOps<FModuleOp>()) {
535 auto shouldDedup = dutModuleSet.contains(module);
536 if (failed(runOnModule(module, shouldDedup)))
537 return signalPassFailure();
540 circuitNamespace.clear();
541 symbolTable =
nullptr;
546 return std::make_unique<LowerMemoryPass>();
assert(baseType &&"element must be base type")
static SmallVector< SubfieldOp > getAllFieldAccesses(Value structValue, StringRef field)
FirMemory getSummary(MemOp op)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
This represents the direction of a single port.
constexpr const char * dutAnnoClass
hw::InnerRefAttr getInnerRefTo(const hw::InnerSymTarget &target, GetNamespaceCallback getNamespace)
Obtain an inner reference to the target (operation or port), adding an inner symbol as necessary.
std::unique_ptr< mlir::Pass > createLowerMemoryPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
The namespace of a CircuitOp, generally inhabited by modules.
bool isSeqMem() const
Check whether the memory is a seq mem.
StringAttr getFirMemoryName() const