19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "mem-to-reg-of-vec"
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
33using namespace firrtl;
36struct MemToRegOfVecPass
37 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
38 MemToRegOfVecPass(
bool replSeqMem,
bool ignoreReadEnable)
39 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
// Pass entry point: gathers the FModuleOps that InstanceInfo places inside the
// effective design and hands each of them to a per-module callback that runs
// in parallel.
41 void runOnOperation()
override {
42 auto circtOp = getOperation();
43 auto &instanceInfo = getAnalysis<InstanceInfo>();
// NOTE(review): the condition guarding this early exit is not visible in this
// view; when it is taken, all analyses are reported as preserved.
47 return markAllAnalysesPreserved();
// Only modules with at least one instance in the effective design are
// candidates for the rewrite.
49 DenseSet<Operation *> dutModuleSet;
50 for (
auto moduleOp : circtOp.getOps<FModuleOp>())
51 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
52 dutModuleSet.insert(moduleOp);
// Modules are visited concurrently. The callback is only partially visible
// here; presumably it forwards each FModuleOp to runOnModule.
54 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
56 if (auto mod = dyn_cast<FModuleOp>(op))
// Visits every MemOp nested in `mod`'s body block. Memories that are not
// filtered out by the (partially visible) checks below are rewritten into a
// single vector-typed register via generateMemory.
61 void runOnModule(FModuleOp mod) {
63 mod.getBodyBlock()->walk([&](MemOp memOp) {
64 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
// Summary of the memory (latencies, port counts, data width, depth).
68 auto firMem = memOp.getSummary();
// Visible tail of a filter condition: read and write latency both 1, exactly
// one write-capable port (write + read-write), at most one read port, and a
// non-zero data width. NOTE(review): the start of this condition and the
// statement it controls are not visible, so it cannot be told from here
// whether matching memories are converted or skipped.
76 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
77 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
78 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
81 generateMemory(memOp, firMem);
// Delays `pipeInput` through RegOps clocked by `clock` and named `name`.
// The loop over `stages` and the return statement are not visible here;
// presumably one register is created per stage and the last stage's value is
// returned. A visible branch connects the register only under a WhenOp on
// `gate`; another connects it unconditionally.
86 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
87 Value pipeInput, StringRef name, Value gate = {}) {
92 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
// Gated stage: the register is only connected while `gate` is true.
94 b.create<WhenOp>(gate,
false, [&]() {
95 b.create<MatchingConnectOp>(
reg, pipeInput);
// Ungated stage: always connected.
98 b.create<MatchingConnectOp>(
reg, pipeInput);
106 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
107 return builder.create<SubfieldOp>(bundle,
"clk");
110 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
111 return builder.create<SubfieldOp>(bundle,
"addr");
114 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
115 return builder.create<SubfieldOp>(bundle,
"wmode");
118 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
119 return builder.create<SubfieldOp>(bundle,
"en");
122 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
123 auto bType = type_cast<BundleType>(bundle.getType());
124 if (bType.getElement(
"mask"))
125 return builder.create<SubfieldOp>(bundle,
"mask");
126 return builder.create<SubfieldOp>(bundle,
"wmask");
129 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
130 bool getWdata =
false) {
131 auto bType = type_cast<BundleType>(bundle.getType());
132 if (bType.getElement(
"data"))
133 return builder.create<SubfieldOp>(bundle,
"data");
134 if (bType.getElement(
"rdata") && !getWdata)
135 return builder.create<SubfieldOp>(bundle,
"rdata");
136 return builder.create<SubfieldOp>(bundle,
"wdata");
// Lowers one read port onto `regOfVec`. The enable and address are delayed by
// pipeline registers according to `firMem.readLatency`. `regOfVec` is then
// indexed with the (delayed) address, and the selected element is connected
// to the port's `data` field.
139 void generateRead(
const FirMemory &firMem, Value clock, Value
addr,
140 Value enable, Value
data, Value regOfVec,
141 ImplicitLocOpBuilder &builder) {
142 if (ignoreReadEnable) {
// One stage per latency cycle; each address stage is gated by the enable
// value from before this iteration (`enLast`). NOTE(review): a line between
// these statements is not visible and may restrict when `enable` advances.
145 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
146 auto enLast = enable;
148 enable = addPipelineStages(builder, 1, clock, enable,
"en");
149 addr = addPipelineStages(builder, 1, clock,
addr,
"addr", enLast);
// Presumably the other branch: delay the enable by the full read latency in
// a single call (the start of this statement is not visible).
155 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
// Dynamic index into the register vector with the (delayed) address.
161 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
162 if (!ignoreReadEnable) {
// Read data defaults to invalid and is only driven while `enable` holds.
164 builder.create<MatchingConnectOp>(
165 data, builder.create<InvalidValueOp>(
data.getType()));
167 builder.create<WhenOp>(enable,
false, [&]() {
168 builder.create<MatchingConnectOp>(
data,
rdata);
// Presumably the ignoreReadEnable case: drive the read data unconditionally.
172 builder.create<MatchingConnectOp>(
data,
rdata);
// Lowers one write port onto `regOfVec`. Address, enable, write data and mask
// are delayed by `numStages` registers; `numStages` is computed on lines not
// visible here. getFields then flattens the destination element, the data and
// the mask into leaf (reg, data, mask) triples. Each data leaf is connected to
// its register leaf under the enable and that leaf's mask bit.
176 void generateWrite(
const FirMemory &firMem, Value clock, Value
addr,
177 Value enable, Value maskBits, Value wdataIn,
178 Value regOfVec, ImplicitLocOpBuilder &builder) {
183 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
184 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
185 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
186 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
195 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
// `rdata` is defined on lines not visible here (presumably the addressed
// register element). On failure, the error is reported against the op that
// produces the write data.
210 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
212 wdataIn.getDefiningOp()->emitOpError(
213 "Cannot convert memory to bank of registers");
// Writes happen only while the pipelined enable is asserted, and only for
// leaves whose mask bit is set.
217 builder.create<WhenOp>(enable,
false, [&]() {
219 for (
auto regDataMask : loweredRegDataMaskFields) {
220 auto regField = std::get<0>(regDataMask);
221 auto dataField = std::get<1>(regDataMask);
222 auto maskField = std::get<2>(regDataMask);
224 builder.create<WhenOp>(maskField,
false, [&]() {
225 builder.create<MatchingConnectOp>(regField, dataField);
// Lowers one read-write port onto `regOfVec`. Address, enable, write data and
// mask are delayed by `numStages` registers; `numStages` is computed on lines
// not visible here. `regOfVec` is indexed with the delayed address. Under the
// enable, a nested WhenOp selects between the two paths; its condition is not
// visible here (presumably `wmode`). One path performs the masked leaf-wise
// write; the other drives `rdataOut` from the addressed element.
231 void generateReadWrite(
const FirMemory &firMem, Value clock, Value
addr,
232 Value enable, Value maskBits, Value wdataIn,
233 Value rdataOut, Value
wmode, Value regOfVec,
234 ImplicitLocOpBuilder &builder) {
239 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
240 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
241 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
242 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Dynamic index into the register vector with the (delayed) address.
245 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
247 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
// On failure, the error is reported against the op producing the write data.
248 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
250 wdataIn.getDefiningOp()->emitOpError(
251 "Cannot convert memory to bank of registers");
// Read data defaults to invalid; it is only driven on the read path below.
255 builder.create<MatchingConnectOp>(
256 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
258 builder.create<WhenOp>(enable,
false, [&]() {
260 builder.create<WhenOp>(
// Write path: connect each data leaf to its register leaf when its mask bit
// is set.
265 for (
auto regDataMask : loweredRegDataMaskFields) {
266 auto regField = std::get<0>(regDataMask);
267 auto dataField = std::get<1>(regDataMask);
268 auto maskField = std::get<2>(regDataMask);
270 builder.create<WhenOp>(
271 maskField,
false, [&]() {
272 builder.create<MatchingConnectOp>(regField, dataField);
// Read path: drive the port's read data from the addressed element.
277 [&]() { builder.create<MatchingConnectOp>(rdataOut,
rdata); });
// Walks `reg`, `input` and `mask` in lock-step through bundle fields and
// vector elements (recursive `flatAccess`). It appends one (reg, input, mask)
// triple to `results` for every IntType leaf. flatAccess fails on:
//   * a mask whose shape does not match the data (error "Mask type is not
//     valid");
//   * any type other than bundle, vector or int;
//   * an int leaf without a known width.
// The trailing return statements are not visible here; presumably the result
// of the top-level flatAccess call is returned.
288 bool getFields(Value reg, Value input, Value
mask,
289 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
290 ImplicitLocOpBuilder &builder) {
// Presumably the body of the isValidMask helper used below (its header is not
// visible). A bundle or vector mask must have the same element count as the
// data it masks; nested levels are re-checked by each recursive flatAccess
// call.
294 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
295 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
296 return mBundle.getNumElements() == bundle.getNumElements();
297 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
298 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
299 return mVec.getNumElements() == vec.getNumElements();
// std::function so that the lambda can call itself recursively.
305 std::function<bool(Value, Value, Value)> flatAccess =
306 [&](Value
reg, Value input, Value
mask) ->
bool {
307 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
308 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
309 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
// Bundle: recurse into each field of register, data and mask in parallel.
313 .
Case<BundleType>([&](BundleType bundle) {
314 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
315 auto regField = builder.create<SubfieldOp>(
reg, i);
316 auto inputField = builder.create<SubfieldOp>(input, i);
317 auto maskField = builder.create<SubfieldOp>(
mask, i);
318 if (!flatAccess(regField, inputField, maskField))
// Vector: recurse into each element of register, data and mask in parallel.
323 .Case<FVectorType>([&](
auto vector) {
324 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
325 auto regField = builder.create<SubindexOp>(
reg, i);
326 auto inputField = builder.create<SubindexOp>(input, i);
327 auto maskField = builder.create<SubindexOp>(
mask, i);
328 if (!flatAccess(regField, inputField, maskField))
// Integer leaf: record it. NOTE(review): the triple is appended before the
// width check, so `results` may hold this leaf even when false is returned;
// the visible callers emit an error and give up on failure.
333 .Case<IntType>([&](
auto iType) {
334 results.push_back({
reg, input,
mask});
335 return iType.getWidth().has_value();
// Any other type cannot be flattened.
337 .Default([&](
auto) {
return false; });
339 if (flatAccess(reg, input,
mask))
// Rebuilds the annotation array of the vector register `op`.
// - Selected annotations are scattered into one dictionary per vector
//   element. The selection test is not visible here (presumably memory-tap
//   sources).
// - Each scattered dictionary copies the original `class` and `id`, adds
//   `circt.nonlocal` when present, and sets `circt.fieldID` to element `i`.
// - Other annotations are presumably copied through unchanged.
// The resulting array replaces `op`'s annotations.
344 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
345 ImplicitLocOpBuilder &builder) {
347 SmallVector<Attribute> regAnnotations;
348 auto vecType = type_cast<FVectorType>(op.getResult().getType());
// NOTE(review): `annos` is declared on a line not visible here (presumably an
// annotation set built from `attr`).
349 for (
auto anno : annos) {
// NOTE(review): this re-derives the vector type already held in `vecType`.
351 for (
size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
354 NamedAttrList newAnno;
355 newAnno.append(
"class", anno.getMember(
"class"));
356 newAnno.append(
"circt.fieldID",
357 builder.getI64IntegerAttr(vecType.getFieldID(i)));
358 newAnno.append(
"id", anno.getMember(
"id"));
359 if (
auto nla = anno.getMember(
"circt.nonlocal"))
360 newAnno.append(
"circt.nonlocal", nla);
// Tail of one more append: a 64-bit integer attribute holding the element
// index `i` (its key is on a line not visible here).
363 IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
365 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
// Presumably the non-scattered path: keep the annotation as-is.
368 regAnnotations.push_back(anno.getAttr());
370 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
// Rewrites the ports of `memOp` onto a single register. The register's type is
// a vector of `firMem.depth` elements of the memory's data type.
// - Each non-ref port result is replaced by a wire named "<mem>_<port>".
//   The port is then lowered by generateRead, generateWrite or
//   generateReadWrite according to its kind.
// - Ref-typed port results are collected and finally redirected to a
//   RefSendOp of the register.
374 void generateMemory(MemOp memOp,
FirMemory &firMem) {
375 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
376 auto dataType = memOp.getDataType();
// NOTE(review): `innerSym` is not referenced in the visible lines (the inner
// symbol is copied with getInnerSymAttr below); confirm it is still needed.
378 auto innerSym = memOp.getInnerSym();
379 SmallVector<Value> debugPorts;
382 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
384 auto result = memOp.getResult(index);
// Ref (debug) ports are deferred until the register exists.
385 if (type_isa<RefType>(result.getType())) {
386 debugPorts.push_back(result);
// Replace the port result with a named wire; the field accesses below are
// taken from that wire.
391 auto wire = builder.create<WireOp>(
393 (memOp.getName() +
"_" + memOp.getPortName(index).getValue()).str(),
394 memOp.getNameKind());
395 result.replaceAllUsesWith(wire.getResult());
396 result = wire.getResult();
398 auto adr = getAddr(builder, result);
399 auto enb = getEnable(builder, result);
400 auto clk = getClock(builder, result);
401 auto dta = getData(builder, result);
// NOTE(review): the guard around this creation is not visible. Presumably the
// register is created once, clocked by the `clk` of the port being processed
// at that moment.
405 regOfVec = builder.create<RegOp>(
406 FVectorType::get(dataType, firMem.
depth),
clk, memOp.getNameAttr());
// Annotations are only scattered when the memory carries any.
409 if (!memOp.getAnnotationsAttr().empty())
410 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
412 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
414 auto portKind = memOp.getPortKind(index);
415 if (portKind == MemOp::PortKind::Read) {
416 generateRead(firMem,
clk, adr, enb, dta, regOfVec.getResult(), builder);
417 }
else if (portKind == MemOp::PortKind::Write) {
418 auto mask = getMask(builder, result);
419 generateWrite(firMem,
clk, adr, enb,
mask, dta, regOfVec.getResult(),
// Remaining kind (presumably read-write). `dta` is passed as the read-data
// output and `wDta` as the write data.
422 auto wmode = getWmode(builder, result);
423 auto wDta = getData(builder, result,
true);
424 auto mask = getMask(builder, result);
425 generateReadWrite(firMem,
clk, adr, enb,
mask, wDta, dta,
wmode,
426 regOfVec.getResult(), builder);
// Every ref (debug) port now observes the whole register vector.
433 for (
auto r : debugPorts)
434 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
// Option set by the constructor and consulted by generateRead. When true, the
// read-address pipeline is gated by the enable. When false, read data is
// invalid unless the pipelined enable is asserted.
439 bool ignoreReadEnable;
// Pass factory: forwards both options to the MemToRegOfVecPass constructor.
// NOTE(review): the name/parameter line is not visible here. The tooltip
// residue lists it as
// createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false).
443std::unique_ptr<mlir::Pass>
445 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)