19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "mem-to-reg-of-vec"
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
33using namespace firrtl;
// Pass that lowers qualifying FIRRTL memories (`MemOp`) into a single
// register whose type is a vector of the memory's data type (see
// generateMemory), together with the per-port read/write logic around it.
// Built on the tablegen-generated `MemToRegOfVecBase`.
36struct MemToRegOfVecPass
37 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
// Construct the pass with its two options.
//  - replSeqMem: stored on the pass. NOTE(review): not read by any pass logic
//    visible in this listing — confirm where it is consumed.
//  - ignoreReadEnable: selects how read ports treat the read enable
//    (see generateRead).
// NOTE(review): the `;` after the empty constructor body is a redundant empty
// declaration and could be dropped.
38 MemToRegOfVecPass(
bool replSeqMem,
bool ignoreReadEnable)
39 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
// Pass entry point. Collects every FModuleOp of the operation that has at
// least one instance in the effective design (per the InstanceInfo analysis)
// and processes those modules in parallel.
41 void runOnOperation()
override {
42 auto circtOp = getOperation();
43 auto &instanceInfo = getAnalysis<InstanceInfo>();
// NOTE(review): the condition guarding this early exit is on lines not shown
// here; when taken, the pass marks all analyses preserved and stops.
47 return markAllAnalysesPreserved();
// Only modules with an instance in the effective design are candidates.
49 DenseSet<Operation *> dutModuleSet;
50 for (
auto moduleOp : circtOp.getOps<FModuleOp>())
51 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
52 dutModuleSet.insert(moduleOp);
// Process the selected modules concurrently. The per-module call is presumably
// runOnModule (not visible here).
// NOTE(review): DenseSet iteration order is pointer-dependent. That is only
// benign if per-module processing is independent, which parallel execution
// already requires. The dyn_cast below should always succeed, since only
// FModuleOps are inserted.
54 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
56 if (auto mod = dyn_cast<FModuleOp>(op))
// Visit every MemOp nested in `mod`'s body. Memories whose summary passes the
// eligibility check below are lowered via generateMemory.
61 void runOnModule(FModuleOp mod) {
63 mod.getBodyBlock()->walk([&](MemOp memOp) {
64 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
// Summarize the memory (latencies, port counts, data width) for the checks.
66 auto firMem = memOp.getSummary();
// Eligibility conditions visible here:
//  - read and write latency are both 1;
//  - exactly one write-capable port (write + read-write ports == 1);
//  - at most one read port;
//  - a positive data width.
// NOTE(review): the start of this expression, and whether failing it skips
// the memory, are on lines not shown — confirm.
74 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
75 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
76 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
79 generateMemory(memOp, firMem);
// Insert pipeline register(s) after `pipeInput`.
// Each register:
//  - has the type of `pipeInput`;
//  - is clocked by `clock`;
//  - is named `name`.
// If `gate` is set, a register is connected to its input only inside
// `when (gate)`; otherwise it is connected unconditionally.
// NOTE(review): the handling of `stages` (presumably one register per stage,
// with `pipeInput` returned unchanged for zero stages) and the return
// statement are on lines not shown — confirm.
84 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
85 Value pipeInput, StringRef name, Value gate = {}) {
90 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
// Gated stage: only connect under `when (gate)`.
92 b.create<WhenOp>(gate,
false, [&]() {
93 b.create<MatchingConnectOp>(
reg, pipeInput);
// Ungated stage (presumably the path taken when `gate` is null).
96 b.create<MatchingConnectOp>(
reg, pipeInput);
// Return the "clk" field of a memory port bundle.
104 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
105 return builder.create<SubfieldOp>(bundle,
"clk");
// Return the "addr" field of a memory port bundle.
108 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
109 return builder.create<SubfieldOp>(bundle,
"addr");
// Return the "wmode" field of a memory port bundle (requested only for
// read-write ports in generateMemory).
112 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
113 return builder.create<SubfieldOp>(bundle,
"wmode");
// Return the "en" (enable) field of a memory port bundle.
116 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
117 return builder.create<SubfieldOp>(bundle,
"en");
// Return the write-mask field of a memory port bundle.
// The bundle's "mask" element is used if it exists, otherwise "wmask".
// generateMemory calls this for both write and read-write ports.
120 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
121 auto bType = type_cast<BundleType>(bundle.getType());
122 if (bType.getElement(
"mask"))
123 return builder.create<SubfieldOp>(bundle,
"mask");
124 return builder.create<SubfieldOp>(bundle,
"wmask");
// Return a data field of a memory port bundle, in this order of preference:
//   1. "data" if the bundle has it;
//   2. "rdata" if present and `getWdata` is false;
//   3. otherwise "wdata".
// For a read-write port, generateMemory calls this twice:
//  - with `getWdata` = true to get the write data;
//  - with the default to get the read data.
127 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
128 bool getWdata =
false) {
129 auto bType = type_cast<BundleType>(bundle.getType());
130 if (bType.getElement(
"data"))
131 return builder.create<SubfieldOp>(bundle,
"data");
132 if (bType.getElement(
"rdata") && !getWdata)
133 return builder.create<SubfieldOp>(bundle,
"rdata");
134 return builder.create<SubfieldOp>(bundle,
"wdata");
// Build the logic of a read port on top of `regOfVec`.
//   firMem:   memory summary; readLatency sets the number of delay stages.
//   clock:    the port's clock.
//   addr:     the port's address.
//   enable:   the port's read enable.
//   data:     the port's read-data field, driven here.
//   regOfVec: the vector register that replaces the memory.
//   builder:  insertion point for the new ops.
137 void generateRead(
const FirMemory &firMem, Value clock, Value
addr,
138 Value enable, Value
data, Value regOfVec,
139 ImplicitLocOpBuilder &builder) {
// ignoreReadEnable: for each of the readLatency cycles:
//  - add one address register that only loads while the previous stage's
//    enable (`enLast`) is true;
//  - advance `enable` by one register stage.
// NOTE(review): a line not shown between `enLast` and the enable stage may
// make that stage conditional — confirm.
140 if (ignoreReadEnable) {
143 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
144 auto enLast = enable;
146 enable = addPipelineStages(builder, 1, clock, enable,
"en");
147 addr = addPipelineStages(builder, 1, clock,
addr,
"addr", enLast);
// Read enable honored (presumably the else path): pass the enable through
// readLatency pipeline stages.
// NOTE(review): the assignment target of this call and the matching address
// pipeline are on lines not shown.
153 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
// Dynamically index the vector register with the (delayed) address.
159 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
// Read enable honored: `data` defaults to invalid and takes the indexed
// element only under `when (enable)`.
160 if (!ignoreReadEnable) {
162 builder.create<MatchingConnectOp>(
163 data, builder.create<InvalidValueOp>(
data.getType()));
165 builder.create<WhenOp>(enable,
false, [&]() {
166 builder.create<MatchingConnectOp>(
data,
rdata);
// Read enable ignored (presumably the else path): drive `data` from the
// indexed element without an enable guard.
170 builder.create<MatchingConnectOp>(
data,
rdata);
// Build the logic of a write port on top of `regOfVec`.
//   addr/enable/maskBits/wdataIn: the port's address, enable, write mask and
//                                 write data.
//   regOfVec:                     the vector register that replaces the memory.
// Steps:
//  1. Delay the port signals by `numStages` pipeline registers.
//  2. Flatten write data and mask (getFields) into leaf triples of
//     (register field, data field, mask field).
//  3. Connect each register leaf under `when (enable)` and
//     `when (mask leaf)`.
174 void generateWrite(
const FirMemory &firMem, Value clock, Value
addr,
175 Value enable, Value maskBits, Value wdataIn,
176 Value regOfVec, ImplicitLocOpBuilder &builder) {
// Delay address, enable, write data and mask by `numStages` registers each.
// NOTE(review): `numStages` is computed on lines not shown.
181 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
182 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
183 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
184 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Leaf-wise (register field, data field, mask field) triples, filled by
// getFields.
193 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
// `rdata` is defined on lines not shown; presumably it is the addressed
// element of `regOfVec`, as in generateReadWrite. On failure, an error is
// reported on the op that produces the write data.
// NOTE(review): getDefiningOp() is null for block arguments. Here wdataIn
// appears to come from a SubfieldOp or a pipeline register, so this looks
// safe — confirm.
208 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
210 wdataIn.getDefiningOp()->emitOpError(
211 "Cannot convert memory to bank of registers");
// Under `when (enable)`, write each register leaf only when its mask leaf is
// true.
215 builder.create<WhenOp>(enable,
false, [&]() {
217 for (
auto regDataMask : loweredRegDataMaskFields) {
218 auto regField = std::get<0>(regDataMask);
219 auto dataField = std::get<1>(regDataMask);
220 auto maskField = std::get<2>(regDataMask);
222 builder.create<WhenOp>(maskField,
false, [&]() {
223 builder.create<MatchingConnectOp>(regField, dataField);
// Build the logic of a read-write port on top of `regOfVec`.
//   wdataIn:  the port's write data.
//   maskBits: the port's write mask.
//   rdataOut: the port's read-data field, driven here.
//   wmode:    the port's write-mode field.
// Address, enable, write data and mask are delayed by `numStages` registers.
// The addressed element of `regOfVec` is both the write target and the
// read value.
229 void generateReadWrite(
const FirMemory &firMem, Value clock, Value
addr,
230 Value enable, Value maskBits, Value wdataIn,
231 Value rdataOut, Value
wmode, Value regOfVec,
232 ImplicitLocOpBuilder &builder) {
// Delay the port signals. NOTE(review): `numStages` is computed on lines
// not shown.
237 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
238 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
239 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
240 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Element of the vector register selected by the (delayed) address.
243 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
// Flatten into (register leaf, data leaf, mask leaf) triples. On failure,
// an error is reported on the op producing the write data.
245 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
246 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
248 wdataIn.getDefiningOp()->emitOpError(
249 "Cannot convert memory to bank of registers");
// Read data defaults to invalid; the when-tree below may override it.
253 builder.create<MatchingConnectOp>(
254 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
// Under `when (enable)` there is a nested when. Its condition is on lines
// not shown (presumably derived from `wmode`).
//  - Then-branch: performs the masked per-leaf writes.
//  - Else-branch (the trailing lambda): drives `rdataOut` from the addressed
//    element.
256 builder.create<WhenOp>(enable,
false, [&]() {
258 builder.create<WhenOp>(
263 for (
auto regDataMask : loweredRegDataMaskFields) {
264 auto regField = std::get<0>(regDataMask);
265 auto dataField = std::get<1>(regDataMask);
266 auto maskField = std::get<2>(regDataMask);
268 builder.create<WhenOp>(
269 maskField,
false, [&]() {
270 builder.create<MatchingConnectOp>(regField, dataField);
275 [&]() { builder.create<MatchingConnectOp>(rdataOut,
rdata); });
// Recursively flatten `reg`, `input` (write data) and `mask` in lockstep
// into leaf triples (reg leaf, input leaf, mask leaf), appended to `results`.
//  - Bundles are split with SubfieldOp, vectors with SubindexOp.
//  - An IntType leaf succeeds only if its width is known.
//  - Any other leaf type fails.
// At every level the mask type must pass isValidMask; otherwise
// "Mask type is not valid" is emitted on the input's defining op.
// NOTE(review): the final return statements are on lines not shown; the
// visible callers treat a false result as an error.
286 bool getFields(Value reg, Value input, Value
mask,
287 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
288 ImplicitLocOpBuilder &builder) {
// isValidMask (lambda header on lines not shown):
//  - a bundle input needs a bundle mask with the same number of elements;
//  - a vector input needs a vector mask of the same length;
//  - the remaining cases are not visible here.
292 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
293 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
294 return mBundle.getNumElements() == bundle.getNumElements();
295 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
296 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
297 return mVec.getNumElements() == vec.getNumElements();
// Held in a std::function so the lambda can call itself recursively.
303 std::function<bool(Value, Value, Value)> flatAccess =
304 [&](Value
reg, Value input, Value
mask) ->
bool {
// Reject a mask whose shape does not match the input at this level.
305 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
306 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
307 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
// Type switch on the input type (its head is on lines not shown).
// Bundle: recurse into each field of reg/input/mask.
311 .
Case<BundleType>([&](BundleType bundle) {
312 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
313 auto regField = builder.create<SubfieldOp>(
reg, i);
314 auto inputField = builder.create<SubfieldOp>(input, i);
315 auto maskField = builder.create<SubfieldOp>(
mask, i);
316 if (!flatAccess(regField, inputField, maskField))
// Vector: recurse into each element of reg/input/mask.
321 .Case<FVectorType>([&](
auto vector) {
322 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
323 auto regField = builder.create<SubindexOp>(
reg, i);
324 auto inputField = builder.create<SubindexOp>(input, i);
325 auto maskField = builder.create<SubindexOp>(
mask, i);
326 if (!flatAccess(regField, inputField, maskField))
// Integer leaf: record the triple; succeed only if the width is known.
// NOTE(review): the triple is appended even when the width is unknown and
// false is returned.
331 .Case<IntType>([&](
auto iType) {
332 results.push_back({
reg, input,
mask});
333 return iType.getWidth().has_value();
// Any other type cannot be lowered to register leaves.
335 .Default([&](
auto) {
return false; });
// Start the flattening from the top-level values.
337 if (flatAccess(reg, input,
mask))
// Rebuild the annotations of the vector register `op` from the memory's
// annotations `attr`.
// Annotations selected by a condition not shown here (presumably the
// memory-tap source class) are scattered into one annotation per vector
// element. Each copy carries:
//  - the original "class" and "id";
//  - "circt.fieldID" for that element;
//  - an i64 element index (its key is on a line not shown);
//  - "circt.nonlocal", when the original has it.
// All other annotations are kept unchanged. The result replaces `op`'s
// annotations.
342 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
343 ImplicitLocOpBuilder &builder) {
// Annotations to attach to the register, collected below.
345 SmallVector<Attribute> regAnnotations;
346 auto vecType = type_cast<FVectorType>(op.getResult().getType());
// `annos` is derived from `attr` on a line not shown (presumably an
// AnnotationSet).
347 for (
auto anno : annos) {
// One annotation per vector element.
// NOTE(review): the loop bound re-derives the vector type instead of
// reusing `vecType`.
349 for (
size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
352 NamedAttrList newAnno;
353 newAnno.append(
"class", anno.getMember(
"class"));
354 newAnno.append(
"circt.fieldID",
355 builder.getI64IntegerAttr(vecType.getFieldID(i)));
356 newAnno.append(
"id", anno.getMember(
"id"));
357 if (
auto nla = anno.getMember(
"circt.nonlocal"))
358 newAnno.append(
"circt.nonlocal", nla);
// NOTE(review): the element index is built with IntegerAttr::get on an i64
// type, while the field ID above uses builder.getI64IntegerAttr. Using one
// form consistently would read better.
361 IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
363 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
// Non-selected annotations are kept as-is (presumably the else path).
366 regAnnotations.push_back(anno.getAttr());
368 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
// Lower `memOp` into one register named after the memory. The register's
// type is a vector of `firMem.depth` elements of the memory's data type.
// Per-port logic is built on top of it. For each result port:
//  - RefType (debug) ports are collected and later driven by a RefSendOp of
//    the register;
//  - every other port is replaced by a wire named "<mem>_<port>", whose
//    fields are then lowered according to the port kind (read, write, or
//    read-write).
// NOTE(review): the guard around the register creation and the removal of
// `memOp` are on lines not shown. The register is clocked by the `clk` of
// the port being processed when it is created.
372 void generateMemory(MemOp memOp,
FirMemory &firMem) {
373 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
374 auto dataType = memOp.getDataType();
// NOTE(review): `innerSym` is not used in the visible lines; the inner
// symbol is read again via getInnerSymAttr() below.
376 auto innerSym = memOp.getInnerSym();
377 SmallVector<Value> debugPorts;
// Visit the memory's result ports by index.
380 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
382 auto result = memOp.getResult(index);
// Debug (RefType) ports are handled after the loop.
383 if (type_isa<RefType>(result.getType())) {
384 debugPorts.push_back(result);
// Replace the port with a wire named "<memName>_<portName>". Subsequent
// field accesses go through the wire.
389 auto wire = builder.create<WireOp>(
391 (memOp.getName() +
"_" + memOp.getPortName(index).getValue()).str(),
392 memOp.getNameKind());
393 result.replaceAllUsesWith(wire.getResult());
394 result = wire.getResult();
396 auto adr = getAddr(builder, result);
397 auto enb = getEnable(builder, result);
398 auto clk = getClock(builder, result);
399 auto dta = getData(builder, result);
// Create the vector register, named after the memory (guard not shown).
403 regOfVec = builder.create<RegOp>(
404 FVectorType::get(dataType, firMem.
depth),
clk, memOp.getNameAttr());
// Carry the memory's annotations, if any, over to the register.
407 if (!memOp.getAnnotationsAttr().empty())
408 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
// Transfer the memory's inner symbol to the register.
410 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
// Dispatch on port kind: read, write, and (in the remaining branch, whose
// header is not shown) read-write, which also needs wmode and write data.
412 auto portKind = memOp.getPortKind(index);
413 if (portKind == MemOp::PortKind::Read) {
414 generateRead(firMem,
clk, adr, enb, dta, regOfVec.getResult(), builder);
415 }
else if (portKind == MemOp::PortKind::Write) {
416 auto mask = getMask(builder, result);
417 generateWrite(firMem,
clk, adr, enb,
mask, dta, regOfVec.getResult(),
420 auto wmode = getWmode(builder, result);
421 auto wDta = getData(builder, result,
true);
422 auto mask = getMask(builder, result);
423 generateReadWrite(firMem,
clk, adr, enb,
mask, wDta, dta,
wmode,
424 regOfVec.getResult(), builder);
// Drive each debug port with a probe of the register.
// NOTE(review): a new RefSendOp is created per debug port; a single one
// could be shared.
431 for (
auto r : debugPorts)
432 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
// Set by the constructor. When true, generateRead gates the address pipeline
// with the read enable instead of guarding the read-data connection with it.
// NOTE(review): the `replSeqMem` member initialized by the constructor is
// declared on a line not shown here.
437 bool ignoreReadEnable;
// Factory for the pass: forwards both options to the MemToRegOfVecPass
// constructor. The declaration
// `createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)`
// defaults both options to false.
441std::unique_ptr<mlir::Pass>
443 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)