21 #include "mlir/IR/Dominance.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/STLFunctionalExtras.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/Support/Parallel.h"
33 #define GEN_PASS_DEF_LOWERMEMORY
34 #include "circt/Dialect/FIRRTL/Passes.h.inc"
38 using namespace circt;
39 using namespace firrtl;
// Fragment of the memory summary computation (getSummary). Several lines of
// the original function are missing from this view; these notes cover only
// what is visible.
// Per-kind port counters.
43 size_t numReadPorts = 0;
44 size_t numWritePorts = 0;
45 size_t numReadWritePorts = 0;
// One entry per write port: the numWritePorts index of the first write port
// whose clock was connected from the same source value (tracked through
// clockToLeader).
47 SmallVector<int32_t> writeClockIDs;
// Walk every memory result (one per port) and classify it by port kind.
49 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
50 auto portKind = op.getPortKind(i);
51 if (portKind == MemOp::PortKind::Read)
53 else if (portKind == MemOp::PortKind::Write) {
// For write ports, find the subfield access with field index 2 (presumably
// the clock field of the write-port bundle -- TODO confirm) and look at the
// connects that drive it.
54 for (
auto *a : op.getResult(i).getUsers()) {
55 auto subfield = dyn_cast<SubfieldOp>(a);
56 if (!subfield || subfield.getFieldIndex() != 2)
58 auto clockPort = a->getResult(0);
59 for (
auto *b : clockPort.getUsers()) {
60 if (
auto connect = dyn_cast<FConnectLike>(b)) {
61 if (
connect.getDest() == clockPort) {
// The first write port seen for a given clock source becomes its leader;
// later ports driven by the same source reuse the leader's ID.
63 clockToLeader.insert({
connect.getSrc(), numWritePorts});
65 writeClockIDs.push_back(numWritePorts);
// NOTE(review): `result` is assumed to be the value returned by the
// clockToLeader.insert() above; its declaration is not visible here.
67 writeClockIDs.push_back(result.first->second);
// Data width of the memory; presumably the sentinel returned for an
// unknown/non-simple width triggers the error below.
79 auto width = op.getDataType().getBitWidthOrSentinel();
81 op.emitError(
"'firrtl.mem' should have simple type and known width");
// NOTE(review): the FIRRTL read-under-write value is mapped onto the seq
// enum through its integer value, and the optional is dereferenced without a
// check. This assumes both enums share the same encoding.
92 *seq::symbolizeRUW(
unsigned(op.getRuw())),
// Pass that lowers firrtl.mem operations into generated memory modules,
// each instantiated inside a private per-memory wrapper module.
103 struct LowerMemoryPass
104 :
public circt::firrtl::impl::LowerMemoryBase<LowerMemoryPass> {
// Return the inner-symbol namespace for `module`, creating and caching it on
// first use.
107 hw::InnerSymbolNamespace &getModuleNamespace(FModuleLike module) {
108 return moduleNamespaces.try_emplace(module, module).first->second;
// Build the port list of the generated memory module for `mem`.
111 SmallVector<PortInfo> getMemoryModulePorts(
const FirMemory &mem);
// Create a new FMemModuleOp for `summary` at the end of the circuit body.
112 FMemModuleOp emitMemoryModule(MemOp op,
const FirMemory &summary,
113 const SmallVectorImpl<PortInfo> &ports);
// Look up a previously generated memory module in `memories`, or emit one.
114 FMemModuleOp getOrCreateMemModule(MemOp op,
const FirMemory &summary,
115 const SmallVectorImpl<PortInfo> &ports,
// NOTE(review): no definition of createWrapperModule is visible in this
// view; lowerMemory builds its wrapper module inline.
117 FModuleOp createWrapperModule(MemOp op,
const FirMemory &summary,
// Emit an instance of `module` in place of `op` and redirect the memory's
// subfield users to the instance results.
119 InstanceOp emitMemoryInstance(MemOp op, FModuleOp module,
121 void lowerMemory(MemOp mem,
const FirMemory &summary,
bool shouldDedup);
122 LogicalResult runOnModule(FModuleOp module,
bool shouldDedup);
123 void runOnOperation()
override;
// Lazily populated inner-symbol namespaces, keyed by module operation.
126 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
// Circuit symbol table. Assigned in runOnOperation and reset to nullptr when
// that function completes without failure.
128 SymbolTable *symbolTable;
// Generated memory modules keyed by memory summary, used to deduplicate
// identical memories.
132 std::map<FirMemory, FMemModuleOp> memories;
// Build the port list of the generated memory module for `mem`. Port names
// consist of a kind prefix ("R", "RW", "W"), the port index, and a field
// suffix such as "_addr", "_clk", "_wdata" or "_wmask".
136 SmallVector<PortInfo>
137 LowerMemoryPass::getMemoryModulePorts(
const FirMemory &mem) {
138 auto *context = &getContext();
151 SmallVector<PortInfo> ports;
155 {nameAttr, type, direction, hw::InnerSymAttr{}, loc, annotations});
// Add the input ports shared by every port kind. Address and clock are
// visible here; any other common fields are not shown in this view.
158 auto makePortCommon = [&](StringRef prefix,
size_t idx,
FIRRTLType addrType) {
159 addPort(prefix + Twine(idx) +
"_addr", addrType,
Direction::In);
161 addPort(prefix + Twine(idx) +
"_clk", clockType,
Direction::In);
// Read ports.
165 makePortCommon(
"R", i, addrType);
// Read-write ports: the common ports plus write-data and write-mask inputs.
169 makePortCommon(
"RW", i, addrType);
171 addPort(
"RW" + Twine(i) +
"_wdata", dataType,
Direction::In);
175 addPort(
"RW" + Twine(i) +
"_wmask", maskType,
Direction::In);
// Write ports.
179 makePortCommon(
"W", i, addrType);
// Create a fresh FMemModuleOp for the memory described by `mem`. The module
// is appended to the end of the circuit body and given private symbol
// visibility. Its name is made unique through circuitNamespace.newName,
// using mem.modName and "ext".
190 LowerMemoryPass::emitMemoryModule(MemOp op,
const FirMemory &mem,
191 const SmallVectorImpl<PortInfo> &ports) {
193 auto newName = circuitNamespace.newName(mem.
modName.getValue(),
"ext");
197 auto b = OpBuilder::atBlockEnd(getOperation().
getBodyBlock());
// Counter of generated memory modules (presumably a pass statistic).
198 ++numCreatedMemModules;
199 auto moduleOp = b.create<FMemModuleOp>(
203 SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
// Return a memory module for `summary`. First look up `summary` in the
// `memories` cache (the hit path is not fully visible here). On a miss, emit
// a new module and record it in the cache.
// NOTE(review): the lines not shown presumably make reuse conditional on the
// dedup flag -- TODO confirm.
208 LowerMemoryPass::getOrCreateMemModule(MemOp op,
const FirMemory &summary,
209 const SmallVectorImpl<PortInfo> &ports,
214 auto it = memories.find(summary);
215 if (it != memories.end())
221 auto module = emitMemoryModule(op, summary, ports);
226 memories[summary] = module;
// Lower one firrtl.mem in four steps:
// - create a private wrapper FModuleOp;
// - instantiate the (possibly shared) memory module inside it and connect
//   the wrapper ports to the instance ports;
// - replace the original memory with an instance of the wrapper;
// - retarget non-local annotations onto the memory-module instance.
231 void LowerMemoryPass::lowerMemory(MemOp mem,
const FirMemory &summary,
233 auto *context = &getContext();
234 auto ports = getMemoryModulePorts(summary);
237 auto newName = circuitNamespace.newName(mem.getName());
// The wrapper module is appended to the circuit body and made private.
241 auto b = OpBuilder::atBlockEnd(getOperation().
getBodyBlock());
242 auto wrapper = b.create<FModuleOp>(
243 mem->getLoc(), wrapperName,
245 SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
// shouldDedup is forwarded to the memory-module cache lookup.
249 auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
250 b.setInsertionPointToStart(wrapper.getBodyBlock());
// The memory's annotations are carried onto the memory-module instance.
253 b.create<InstanceOp>(mem->getLoc(), memModule, memModule.getModuleName(),
254 mem.getNameKind(), mem.getAnnotations().getValue());
// Connect wrapper ports to instance ports positionally. Wrapper outputs are
// driven by the instance; the other ports drive the instance.
257 for (
auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
258 memInst.getResults())) {
259 if (wrapper.getPortDirection(dst.getArgNumber()) ==
Direction::Out)
260 b.create<MatchingConnectOp>(mem->getLoc(), dst, src);
262 b.create<MatchingConnectOp>(mem->getLoc(), src, dst);
// Replace the original memory with an instance of the wrapper.
267 auto inst = emitMemoryInstance(mem, wrapper, summary);
// Non-local annotations: each referenced hierarchical path is cloned once
// (memoized in processedNLAs). The clone's namepath is extended through an
// inner ref to the wrapper instance (`inst`) and a leaf attribute, and the
// clone gets a unique name from circuitNamespace.
273 auto leafSym = memModule.getModuleNameAttr();
279 bool nlaUpdated =
false;
280 SmallVector<Annotation> newMemModAnnos;
281 OpBuilder nlaBuilder(context);
285 auto nlaSym = anno.
getMember<FlatSymbolRefAttr>(nonlocalAttr);
289 auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
290 StringAttr newNLAName;
291 if (newNLAIter == processedNLAs.end()) {
295 dyn_cast<hw::HierPathOp>(symbolTable->lookup(nlaSym.getAttr()));
// NOTE(review): `nla` comes from the dyn_cast above and is used here without
// a null check. If the symbol is not a hw::HierPathOp, this dereferences null.
// cast<> (which asserts) or an explicit check would be safer.
296 auto namepath = nla.getNamepath().getValue();
297 SmallVector<Attribute> newNamepath(namepath.begin(), namepath.end());
298 if (!nla.isComponent())
300 getInnerRefTo(inst, [&](
auto mod) -> hw::InnerSymbolNamespace & {
301 return getModuleNamespace(mod);
303 newNamepath.push_back(leafAttr);
305 nlaBuilder.setInsertionPointAfter(nla);
306 auto newNLA = cast<hw::HierPathOp>(nlaBuilder.clone(*nla));
308 context, circuitNamespace.newName(nla.getNameAttr().getValue())));
310 newNLAName = newNLA.getNameAttr();
311 processedNLAs[nlaSym.getAttr()] = newNLAName;
313 newNLAName = newNLAIter->getSecond();
316 newMemModAnnos.push_back(anno);
// Attach the rewritten annotations to the memory-module instance.
322 newAnnos.addAnnotations(newMemModAnnos);
323 newAnnos.applyToOperation(memInst);
// Collect the subfield accesses of `structValue` that select `field`. Every
// user of `structValue` is expected to be a SubfieldOp.
// NOTE(review): the assert is redundant with cast<>, which already asserts in
// builds with assertions. In release builds neither one checks, so a user
// that is not a SubfieldOp would be mis-cast.
331 SmallVector<SubfieldOp> accesses;
332 for (
auto *op : structValue.getUsers()) {
333 assert(isa<SubfieldOp>(op));
334 auto fieldAccess = cast<SubfieldOp>(op);
336 fieldAccess.getInput().getType().base().getElementIndex(field);
// Keep only accesses whose field index equals the element index of `field`
// in the input's base type.
337 if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
338 accesses.push_back(fieldAccess);
// Replace memory `op` with an instance of `module`. Instance ports are built
// one port kind at a time and named "<label><n>_<field>". Memory masks are
// folded into the corresponding enables (removeMask). After the instance is
// created, every recorded subfield access of the old memory is redirected to
// the matching instance result.
343 InstanceOp LowerMemoryPass::emitMemoryInstance(MemOp op, FModuleOp module,
345 OpBuilder builder(op);
346 auto *context = &getContext();
347 auto memName = op.getName();
// Parallel arrays describing the new instance's ports.
352 SmallVector<Type, 8> portTypes;
353 SmallVector<Direction> portDirections;
354 SmallVector<Attribute> portNames;
// Maps each old subfield op to the index of the instance result replacing it.
355 DenseMap<Operation *, size_t> returnHolder;
// Used by removeMask to decide where the combined enable is built.
356 mlir::DominanceInfo domInfo(op->getParentOfType<FModuleOp>());
// Iterate over the three port kinds. Each iteration starts from the read
// defaults ("R"); the switch selects read-write or write for the other
// indices.
361 for (
unsigned memportKindIdx = 0; memportKindIdx != 3; ++memportKindIdx) {
362 MemOp::PortKind memportKind = MemOp::PortKind::Read;
363 auto *portLabel =
"R";
364 switch (memportKindIdx) {
368 memportKind = MemOp::PortKind::ReadWrite;
372 memportKind = MemOp::PortKind::Write;
379 unsigned portNumber = 0;
383 auto ui1Type = getType(1);
// Address width is ceil(log2(depth)), clamped to at least 1 bit so that a
// depth-1 memory still has an address port.
384 auto addressType = getType(std::max(1U, llvm::Log2_64_Ceil(summary.
depth)));
// Only ports of the current kind are handled in this pass over the results.
390 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
392 if (memportKind != op.getPortKind(i))
// Append an instance port for `field` and map the old field accesses to the
// new port's index.
395 auto addPort = [&](
Direction direction, StringRef field, Type portType) {
398 for (
auto a : accesses)
399 returnHolder[a] = portTypes.size();
401 portTypes.push_back(portType);
402 portDirections.push_back(direction);
404 builder.getStringAttr(portLabel + Twine(portNumber) +
"_" + field));
// Presumably returns the connect op driving `field` of the current port;
// part of its body is not visible here.
407 auto getDriver = [&](StringRef field) -> Operation * {
409 for (
auto a : accesses) {
410 for (
auto *user : a->getUsers()) {
412 if (
auto connect = dyn_cast<FConnectLike>(user);
// Fold the mask into the enable (enable := mask & enable), then erase the
// mask's connect.
423 auto removeMask = [&](StringRef enable, StringRef
mask) {
425 auto *maskConnect = getDriver(mask);
429 auto *enConnect = getDriver(enable);
// The AND is inserted before the mask connect by default, or before the
// enable connect when the mask connect dominates it. The enable connect is
// then moved after the AND so that it follows its new operand.
434 OpBuilder b(maskConnect);
435 if (domInfo.dominates(maskConnect, enConnect))
436 b.setInsertionPoint(enConnect);
438 auto andOp = b.create<AndPrimOp>(
439 op->getLoc(), maskConnect->getOperand(1), enConnect->getOperand(1));
440 enConnect->setOperand(1, andOp);
441 enConnect->moveAfter(andOp);
443 auto *maskField = maskConnect->getOperand(0).getDefiningOp();
444 maskConnect->erase();
// Per-kind fields: read-write ports fold "wmask" into "wmode", and write
// ports fold "mask" into "en".
448 if (memportKind == MemOp::PortKind::Read) {
453 }
else if (memportKind == MemOp::PortKind::ReadWrite) {
464 removeMask(
"wmode",
"wmask");
474 removeMask(
"en",
"mask");
// Create the instance from the collected port list, reusing the memory's
// name kind and inner symbol.
484 auto inst = builder.create<InstanceOp>(
486 op.getNameKind(), portDirections, portNames,
487 ArrayRef<Attribute>(),
488 ArrayRef<Attribute>(),
489 ArrayRef<Attribute>(),
false,
490 op.getInnerSymAttr());
// Redirect each old subfield's uses to its instance result. Each replacement
// is independent, so the DenseMap iteration order does not matter.
493 for (
auto [subfield, result] : returnHolder) {
494 subfield->getResult(0).replaceAllUsesWith(inst.getResult(result));
// Lower every firrtl.mem found directly in `module`'s body block; ops nested
// in other regions are not visited by getOps. A memory whose data type is not
// a UInt is rejected with an error, since aggregate memories must be
// flattened first. make_early_inc_range keeps the loop valid if the current
// memory is erased during lowering.
501 LogicalResult LowerMemoryPass::runOnModule(FModuleOp module,
bool shouldDedup) {
503 llvm::make_early_inc_range(module.getBodyBlock()->getOps<MemOp>())) {
505 if (!type_isa<UIntType>(op.getDataType()))
506 return op->emitError(
507 "memories should be flattened before running LowerMemory");
513 lowerMemory(op, summary, shouldDedup);
// Pass entry point. Sets up the instance graph, symbol table and circuit
// namespace, determines the DUT node, then lowers the memories of every
// FModuleOp in the circuit body.
518 void LowerMemoryPass::runOnOperation() {
519 auto circuit = getOperation();
520 auto *body = circuit.getBodyBlock();
521 auto &instanceGraph = getAnalysis<InstanceGraph>();
522 symbolTable = &getAnalysis<SymbolTable>();
523 circuitNamespace.add(circuit);
// The DUT defaults to the instance graph's top-level node. A module matched
// by the find_if below overrides it; presumably that is the module carrying
// the DUT annotation -- TODO confirm.
527 auto *dut = instanceGraph.getTopLevelNode();
528 auto it = llvm::find_if(*body, [&](Operation &op) ->
bool {
531 if (it != body->end())
532 dut = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
// Modules in this set are lowered with dedup enabled. The code that fills it
// is not visible here.
535 DenseSet<Operation *> dutModuleSet;
542 for (
auto module : body->getOps<FModuleOp>()) {
544 auto shouldDedup = dutModuleSet.contains(module);
// NOTE(review): the early return on failure skips the circuitNamespace and
// symbolTable reset below.
545 if (failed(runOnModule(module, shouldDedup)))
546 return signalPassFailure();
// Reset per-run state.
549 circuitNamespace.clear();
550 symbolTable =
nullptr;
// Creates a new, heap-allocated LowerMemoryPass owned by the returned unique_ptr.
555 return std::make_unique<LowerMemoryPass>();
assert(baseType &&"element must be base type")
static SmallVector< SubfieldOp > getAllFieldAccesses(Value structValue, StringRef field)
FirMemory getSummary(MemOp op)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
This represents the direction of a single port.
constexpr const char * dutAnnoClass
hw::InnerRefAttr getInnerRefTo(const hw::InnerSymTarget &target, GetNamespaceCallback getNamespace)
Obtain an inner reference to the target (operation or port), adding an inner symbol as necessary.
std::unique_ptr< mlir::Pass > createLowerMemoryPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
The namespace of a CircuitOp, generally inhabited by modules.
bool isSeqMem() const
Check whether the memory is a seq mem.
StringAttr getFirMemoryName() const