20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/Threading.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "mem-to-reg-of-vec"
31 #define GEN_PASS_DEF_MEMTOREGOFVEC
32 #include "circt/Dialect/FIRRTL/Passes.h.inc"
36 using namespace circt;
37 using namespace firrtl;
40 struct MemToRegOfVecPass
41 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
42 MemToRegOfVecPass(
bool replSeqMem,
bool ignoreReadEnable)
43 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
// Pass entry point: collects the modules to transform into `dutModuleSet`
// and then visits that set with mlir::parallelForEach.
// NOTE(review): this listing omits original lines 48-49, 52-53, 55-56,
// 61-64, 67-68, 70 and 72-75 -- among them the condition guarding the early
// `return markAllAnalysesPreserved();`, the body of the `find_if` predicate,
// and the call made for each module; confirm against the full file.
45 void runOnOperation()
override {
46 auto circtOp = getOperation();
// Operations (modules) that this pass will process.
47 DenseSet<Operation *> dutModuleSet;
50 return markAllAnalysesPreserved();
51 auto *body = circtOp.getBodyBlock();
// Search the circuit body for a top-level op matching a predicate whose body
// is not visible in this listing.
54 auto it = llvm::find_if(*body, [&](Operation &op) ->
bool {
// If such an op was found, walk the instance graph depth-first starting at
// its node (the per-node action is on lines missing here).
57 if (it != body->end()) {
58 auto &instanceGraph = getAnalysis<InstanceGraph>();
59 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
60 llvm::for_each(llvm::depth_first(node),
// Every FModuleOp in the circuit is added to the set here; presumably this is
// the branch taken when no op was found -- TODO confirm, the enclosing brace
// lines are missing from this listing.
65 auto mods = circtOp.getOps<FModuleOp>();
66 dutModuleSet.insert(mods.begin(), mods.end());
// Process the collected operations in parallel. Only FModuleOps are handled;
// the call made for each one (presumably runOnModule) is on a missing line.
69 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
71 if (auto mod = dyn_cast<FModuleOp>(op))
// Walks every MemOp nested in `mod`'s body, summarizes it with getSummary(),
// and calls generateMemory() for it, subject to a condition that is only
// partially visible here.
// NOTE(review): original lines 77, 80-82, 84-90, 94-95 and 97-100 are
// missing from this listing.
76 void runOnModule(FModuleOp mod) {
78 mod.getBodyBlock()->walk([&](MemOp memOp) {
79 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
// Port counts, latencies and data width used by the check below.
83 auto firMem = memOp.getSummary();
// NOTE(review): the leading part of this condition and the statement it
// guards are on missing lines. Whether matching memories are skipped or
// converted therefore cannot be determined here. The visible conjunction
// matches memories with read and write latency 1, exactly one
// write-or-readwrite port, at most one read port, and a positive data width.
91 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
92 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
93 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
96 generateMemory(memOp, firMem);
// Inserts register stage(s) after `pipeInput`. Each register created has
// `pipeInput`'s type, is clocked by `clock` and is named `name`.
// Two connection modes are visible. In the first, the register is connected
// inside a WhenOp on `gate` built with `false` (presumably: no else region),
// so it only loads while `gate` is true. In the second, it is connected
// unconditionally. Presumably the first is used when `gate` is provided; the
// selecting condition is on a missing line.
// NOTE(review): original lines 103-106, 108, 111-112 and 114-120 are missing
// from this listing, including how `stages` is consumed (presumably one
// register per stage) and the return statement -- confirm against the full
// file.
101 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
102 Value pipeInput, StringRef name, Value gate = {}) {
107 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
// Gated stage: update the register only while `gate` holds.
109 b.create<WhenOp>(gate,
false, [&]() {
110 b.create<MatchingConnectOp>(
reg, pipeInput);
// Ungated stage: update the register every cycle.
113 b.create<MatchingConnectOp>(
reg, pipeInput);
121 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
122 return builder.create<SubfieldOp>(bundle,
"clk");
125 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
126 return builder.create<SubfieldOp>(bundle,
"addr");
129 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
130 return builder.create<SubfieldOp>(bundle,
"wmode");
133 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
134 return builder.create<SubfieldOp>(bundle,
"en");
137 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
138 auto bType = type_cast<BundleType>(bundle.getType());
139 if (bType.getElement(
"mask"))
140 return builder.create<SubfieldOp>(bundle,
"mask");
141 return builder.create<SubfieldOp>(bundle,
"wmask");
144 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
145 bool getWdata =
false) {
146 auto bType = type_cast<BundleType>(bundle.getType());
147 if (bType.getElement(
"data"))
148 return builder.create<SubfieldOp>(bundle,
"data");
149 if (bType.getElement(
"rdata") && !getWdata)
150 return builder.create<SubfieldOp>(bundle,
"rdata");
151 return builder.create<SubfieldOp>(bundle,
"wdata");
// Lowers one read port onto the register vector `regOfVec`.
// The enable and address are delayed by `firMem.readLatency` register stages.
// The delayed address then selects an element of `regOfVec` (SubaccessOp),
// which drives `data`. The two modes differ as follows:
//  - ignoreReadEnable: each address stage is gated by the previous-stage
//    enable, and `data` is driven from the selected element unconditionally.
//  - otherwise: `data` defaults to invalid and is driven from the selected
//    element only while the delayed enable is high.
// NOTE(review): original lines 158-159, 162, 165-169, 171, 173-175, 178, 181,
// 184-186 and 188-190 are missing from this listing. Among them are the
// `else` branches and the assignment targets of the two addPipelineStages
// calls that follow the loop.
154 void generateRead(
const FirMemory &firMem, Value clock, Value addr,
155 Value enable, Value data, Value regOfVec,
156 ImplicitLocOpBuilder &builder) {
157 if (ignoreReadEnable) {
// One stage per iteration; the address register only loads while the enable
// from the previous stage (`enLast`) is high.
160 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
161 auto enLast = enable;
163 enable = addPipelineStages(builder, 1, clock, enable,
"en");
164 addr = addPipelineStages(builder, 1, clock, addr,
"addr", enLast);
// Presumably the non-ignoreReadEnable branch: delay enable and address by
// readLatency stages in one call each (the left-hand sides are on missing
// lines).
170 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
172 addPipelineStages(builder, firMem.
readLatency, clock, addr,
"addr");
// Select the addressed element of the register vector.
176 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
// Read data is invalid unless the delayed enable is high.
177 if (!ignoreReadEnable) {
179 builder.create<MatchingConnectOp>(
180 data, builder.create<InvalidValueOp>(
data.getType()));
182 builder.create<WhenOp>(enable,
false, [&]() {
183 builder.create<MatchingConnectOp>(
data,
rdata);
// Presumably the ignoreReadEnable branch: drive the data unconditionally.
187 builder.create<MatchingConnectOp>(
data,
rdata);
// Lowers one write port onto the register vector `regOfVec`.
// The address, enable, write data and mask are each delayed by `numStages`
// register stages. `rdata`, the register-side value, is split together with
// the write data and mask into matching leaf fields via getFields().
// `rdata` is defined on missing lines; presumably it is the element of
// `regOfVec` selected by `addr`, as in generateReadWrite. Each register leaf
// is then written from its data leaf only when both the (delayed) enable and
// that leaf's mask bit are high.
// NOTE(review): original lines 194-197 (definition of `numStages`), 202-209
// (definition of `rdata`), 211-224, 226, 229-231, 233, 238 and 241-245 are
// missing from this listing.
191 void generateWrite(
const FirMemory &firMem, Value clock, Value addr,
192 Value enable, Value maskBits, Value wdataIn,
193 Value regOfVec, ImplicitLocOpBuilder &builder) {
198 addr = addPipelineStages(builder, numStages, clock, addr,
"addr");
199 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
200 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
201 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// (register field, write-data field, mask field) triples, one per leaf.
210 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
225 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
// Flattening failed: report it on the op that defines the write data.
227 wdataIn.getDefiningOp()->emitOpError(
228 "Cannot convert memory to bank of registers");
// Write only while the port is enabled ...
232 builder.create<WhenOp>(enable,
false, [&]() {
234 for (
auto regDataMask : loweredRegDataMaskFields) {
235 auto regField = std::get<0>(regDataMask);
236 auto dataField = std::get<1>(regDataMask);
237 auto maskField = std::get<2>(regDataMask);
// ... and, per leaf, only where that leaf's mask bit is set.
239 builder.create<WhenOp>(maskField,
false, [&]() {
240 builder.create<MatchingConnectOp>(regField, dataField);
// Lowers one read-write port onto the register vector `regOfVec`.
// The address, enable, write data and mask are delayed by `numStages`
// register stages, and the delayed address selects an element `rdata` of
// `regOfVec`. `rdataOut` defaults to invalid.
// While the enable is high, a nested WhenOp chooses between writing and
// reading. Its condition is on missing lines; presumably it is `wmode`.
//  - Write branch: each masked leaf of the write data is written into the
//    corresponding leaf of `rdata`.
//  - Other branch: `rdataOut` is driven from `rdata`.
// NOTE(review): original lines 250-253 (definition of `numStages`), 258-259,
// 261, 264, 267-269, 272, 274, 276-279, 284, 288-291 and 293-302 are missing
// from this listing.
246 void generateReadWrite(
const FirMemory &firMem, Value clock, Value addr,
247 Value enable, Value maskBits, Value wdataIn,
248 Value rdataOut, Value wmode, Value regOfVec,
249 ImplicitLocOpBuilder &builder) {
254 addr = addPipelineStages(builder, numStages, clock, addr,
"addr");
255 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
256 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
257 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Element of the register vector addressed by the delayed address.
260 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
// (register field, write-data field, mask field) triples, one per leaf.
262 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
263 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
// Flattening failed: report it on the op that defines the write data.
265 wdataIn.getDefiningOp()->emitOpError(
266 "Cannot convert memory to bank of registers");
// Read output is invalid unless driven below.
270 builder.create<MatchingConnectOp>(
271 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
273 builder.create<WhenOp>(enable,
false, [&]() {
275 builder.create<WhenOp>(
280 for (
auto regDataMask : loweredRegDataMaskFields) {
281 auto regField = std::get<0>(regDataMask);
282 auto dataField = std::get<1>(regDataMask);
283 auto maskField = std::get<2>(regDataMask);
// Write this leaf only where its mask bit is set.
285 builder.create<WhenOp>(
286 maskField,
false, [&]() {
287 builder.create<MatchingConnectOp>(regField, dataField);
// Presumably the else region of the inner WhenOp: the read path.
292 [&]() { builder.create<MatchingConnectOp>(rdataOut,
rdata); });
// Walks `reg`, `input` and `mask` in lockstep and appends a
// (regField, inputField, maskField) tuple to `results` for every IntType leaf.
// Bundles are descended with SubfieldOp and vectors with SubindexOp, using the
// same element index on all three values. The recursive helper `flatAccess`
// returns false in three cases:
//  - the mask shape is rejected by `isValidMask` (it then also emits "Mask
//    type is not valid");
//  - a leaf is not a bundle, vector or IntType;
//  - an IntType leaf has no known width.
// The function's own return statements are on missing lines; presumably it
// mirrors flatAccess().
// NOTE(review): an IntType leaf is pushed into `results` before its width is
// checked, so a failing walk can leave partial entries in `results`.
// NOTE(review): original lines 306-308 (declaration of `isValidMask`),
// 315-319, 325-327, 334-337, 344-347, 351, 353 and 355-358 are missing from
// this listing.
303 bool getFields(Value
reg, Value input, Value mask,
304 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
305 ImplicitLocOpBuilder &builder) {
// Visible part of `isValidMask`: a bundle (vector) input with a bundle
// (vector) mask is accepted when the element counts match.
309 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
310 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
311 return mBundle.getNumElements() == bundle.getNumElements();
312 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
313 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
314 return mVec.getNumElements() == vec.getNumElements();
// Recursive flattening of one (reg, input, mask) triple.
320 std::function<bool(Value, Value, Value)> flatAccess =
321 [&](Value
reg, Value input, Value
mask) ->
bool {
322 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
323 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
324 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
// Bundle: recurse into each field by index.
328 .
Case<BundleType>([&](BundleType bundle) {
329 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
330 auto regField = builder.create<SubfieldOp>(
reg, i);
331 auto inputField = builder.create<SubfieldOp>(input, i);
332 auto maskField = builder.create<SubfieldOp>(
mask, i);
333 if (!flatAccess(regField, inputField, maskField))
// Vector: recurse into each element by index.
338 .Case<FVectorType>([&](
auto vector) {
339 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
340 auto regField = builder.create<SubindexOp>(
reg, i);
341 auto inputField = builder.create<SubindexOp>(input, i);
342 auto maskField = builder.create<SubindexOp>(
mask, i);
343 if (!flatAccess(regField, inputField, maskField))
// Integer leaf: record it; succeed only if its width is known.
348 .Case<IntType>([&](
auto iType) {
349 results.push_back({
reg, input,
mask});
350 return iType.getWidth().has_value();
// Any other type cannot be flattened.
352 .Default([&](
auto) {
return false; });
354 if (flatAccess(
reg, input, mask))
// Rewrites the annotations of the register `op` (treated as an FVectorType)
// from `attr`, then replaces `op`'s annotations with the result.
// Which annotations are expanded is decided on a missing line (presumably
// memory-tap annotations). Each expanded annotation becomes one annotation
// per vector element, carrying:
//  - the original "class" and "id";
//  - a "circt.fieldID" for that element;
//  - "circt.nonlocal" when present.
// Annotations that are not expanded are presumably kept unchanged; the else
// structure is not visible.
// NOTE(review): original lines 361 (definition of `annos`), 365, 367-368,
// 376-379, 381-382, 384 and 386-388 are missing from this listing.
359 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
360 ImplicitLocOpBuilder &builder) {
362 SmallVector<Attribute> regAnnotations;
363 auto vecType = type_cast<FVectorType>(op.getResult().getType());
364 for (
auto anno : annos) {
// One new annotation per vector element.
// NOTE(review): the vector type is re-derived here although `vecType` above
// already holds it.
366 for (
size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
369 NamedAttrList newAnno;
370 newAnno.append(
"class", anno.getMember(
"class"));
371 newAnno.append(
"circt.fieldID",
372 builder.getI64IntegerAttr(vecType.getFieldID(i)));
373 newAnno.append(
"id", anno.getMember(
"id"));
374 if (
auto nla = anno.getMember(
"circt.nonlocal"))
375 newAnno.append(
"circt.nonlocal", nla);
380 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
// Annotation not expanded: keep it as-is.
383 regAnnotations.push_back(anno.getAttr());
385 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
// Lowers `memOp` into a register `regOfVec` plus per-port logic. The
// register's type is built on missing lines; scatterMemTapAnno treats it as
// an FVectorType.
// For each memory result:
//  - RefType results are collected as debug ports.
//  - Any other result has its uses replaced by a WireOp named
//    "<memory name>_<port name>". The wire's addr/en/clk/data fields (plus
//    mask and wmode as needed) feed generateRead, generateWrite or
//    generateReadWrite according to the port kind.
// The register takes the memory's inner symbol and, when the memory has
// annotations, receives them through scatterMemTapAnno. Finally the uses of
// every debug port are replaced with a RefSendOp of the register.
// NOTE(review): `innerSym` is not used on any visible line.
// NOTE(review): original lines 392, 395-396, 398, 402-405, 407, 412,
// 417-419, 421-423, 426, 428, 435-436, 442-447 and 450-453 are missing from
// this listing.
389 void generateMemory(MemOp memOp,
FirMemory &firMem) {
390 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
391 auto dataType = memOp.getDataType();
393 auto innerSym = memOp.getInnerSym();
// Probe (RefType) results; rewired to the register at the end.
394 SmallVector<Value> debugPorts;
397 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
399 auto result = memOp.getResult(index);
400 if (type_isa<RefType>(result.getType())) {
401 debugPorts.push_back(result);
// Replace the port with a wire named "<mem>_<port>" and work on the wire.
406 auto wire = builder.create<WireOp>(
408 (memOp.getName() +
"_" + memOp.getPortName(index).getValue()).str(),
409 memOp.getNameKind());
410 result.replaceAllUsesWith(wire.getResult());
411 result = wire.getResult();
413 auto adr = getAddr(builder, result);
414 auto enb = getEnable(builder, result);
415 auto clk = getClock(builder, result);
416 auto dta = getData(builder, result);
420 regOfVec = builder.create<RegOp>(
424 if (!memOp.getAnnotationsAttr().empty())
425 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
427 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
// Dispatch on the port kind.
429 auto portKind = memOp.getPortKind(index);
430 if (portKind == MemOp::PortKind::Read) {
431 generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
432 }
else if (portKind == MemOp::PortKind::Write) {
433 auto mask = getMask(builder, result);
434 generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
437 auto wmode = getWmode(builder, result);
438 auto wDta = getData(builder, result,
true);
439 auto mask = getMask(builder, result);
440 generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
441 regOfVec.getResult(), builder);
// Each probe port now refers to the whole register.
448 for (
auto r : debugPorts)
449 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
// When set, generateRead() does not gate the read data on the port enable.
454 bool ignoreReadEnable;
// Factory for the pass; forwards both flags to the MemToRegOfVecPass
// constructor.
// NOTE(review): the signature line (459) is missing from this listing.
458 std::unique_ptr<mlir::Pass>
460 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)