19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/DepthFirstIterator.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "mem-to-reg-of-vec"
28 #define GEN_PASS_DEF_MEMTOREGOFVEC
29 #include "circt/Dialect/FIRRTL/Passes.h.inc"
33 using namespace circt;
34 using namespace firrtl;
37 struct MemToRegOfVecPass
38 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
39 MemToRegOfVecPass(
bool replSeqMem,
bool ignoreReadEnable)
40 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
42 void runOnOperation()
override {
43 auto circtOp = getOperation();
44 DenseSet<Operation *> dutModuleSet;
47 return markAllAnalysesPreserved();
48 auto *body = circtOp.getBodyBlock();
51 auto it = llvm::find_if(*body, [&](Operation &op) ->
bool {
54 if (it != body->end()) {
55 auto &instanceGraph = getAnalysis<InstanceGraph>();
56 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
57 llvm::for_each(llvm::depth_first(node),
62 auto mods = circtOp.getOps<FModuleOp>();
63 dutModuleSet.insert(mods.begin(), mods.end());
66 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
68 if (auto mod = dyn_cast<FModuleOp>(op))
73 void runOnModule(FModuleOp mod) {
75 mod.getBodyBlock()->walk([&](MemOp memOp) {
76 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
80 auto firMem = memOp.getSummary();
88 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
89 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
90 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
93 generateMemory(memOp, firMem);
98 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
99 Value pipeInput, StringRef name, Value gate = {}) {
104 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
106 b.create<WhenOp>(gate,
false, [&]() {
107 b.create<MatchingConnectOp>(
reg, pipeInput);
110 b.create<MatchingConnectOp>(
reg, pipeInput);
118 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
119 return builder.create<SubfieldOp>(bundle,
"clk");
122 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
123 return builder.create<SubfieldOp>(bundle,
"addr");
126 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
127 return builder.create<SubfieldOp>(bundle,
"wmode");
130 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
131 return builder.create<SubfieldOp>(bundle,
"en");
134 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
135 auto bType = type_cast<BundleType>(bundle.getType());
136 if (bType.getElement(
"mask"))
137 return builder.create<SubfieldOp>(bundle,
"mask");
138 return builder.create<SubfieldOp>(bundle,
"wmask");
141 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
142 bool getWdata =
false) {
143 auto bType = type_cast<BundleType>(bundle.getType());
144 if (bType.getElement(
"data"))
145 return builder.create<SubfieldOp>(bundle,
"data");
146 if (bType.getElement(
"rdata") && !getWdata)
147 return builder.create<SubfieldOp>(bundle,
"rdata");
148 return builder.create<SubfieldOp>(bundle,
"wdata");
151 void generateRead(
const FirMemory &firMem, Value clock, Value addr,
152 Value enable, Value data, Value regOfVec,
153 ImplicitLocOpBuilder &builder) {
154 if (ignoreReadEnable) {
157 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
158 auto enLast = enable;
160 enable = addPipelineStages(builder, 1, clock, enable,
"en");
161 addr = addPipelineStages(builder, 1, clock, addr,
"addr", enLast);
167 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
169 addPipelineStages(builder, firMem.
readLatency, clock, addr,
"addr");
173 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
174 if (!ignoreReadEnable) {
176 builder.create<MatchingConnectOp>(
177 data, builder.create<InvalidValueOp>(
data.getType()));
179 builder.create<WhenOp>(enable,
false, [&]() {
180 builder.create<MatchingConnectOp>(
data,
rdata);
184 builder.create<MatchingConnectOp>(
data,
rdata);
188 void generateWrite(
const FirMemory &firMem, Value clock, Value addr,
189 Value enable, Value maskBits, Value wdataIn,
190 Value regOfVec, ImplicitLocOpBuilder &builder) {
195 addr = addPipelineStages(builder, numStages, clock, addr,
"addr");
196 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
197 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
198 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
207 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
222 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
224 wdataIn.getDefiningOp()->emitOpError(
225 "Cannot convert memory to bank of registers");
229 builder.create<WhenOp>(enable,
false, [&]() {
231 for (
auto regDataMask : loweredRegDataMaskFields) {
232 auto regField = std::get<0>(regDataMask);
233 auto dataField = std::get<1>(regDataMask);
234 auto maskField = std::get<2>(regDataMask);
236 builder.create<WhenOp>(maskField,
false, [&]() {
237 builder.create<MatchingConnectOp>(regField, dataField);
243 void generateReadWrite(
const FirMemory &firMem, Value clock, Value addr,
244 Value enable, Value maskBits, Value wdataIn,
245 Value rdataOut, Value wmode, Value regOfVec,
246 ImplicitLocOpBuilder &builder) {
251 addr = addPipelineStages(builder, numStages, clock, addr,
"addr");
252 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
253 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
254 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
257 Value
rdata = builder.create<SubaccessOp>(regOfVec,
addr);
259 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
260 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
262 wdataIn.getDefiningOp()->emitOpError(
263 "Cannot convert memory to bank of registers");
267 builder.create<MatchingConnectOp>(
268 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
270 builder.create<WhenOp>(enable,
false, [&]() {
272 builder.create<WhenOp>(
277 for (
auto regDataMask : loweredRegDataMaskFields) {
278 auto regField = std::get<0>(regDataMask);
279 auto dataField = std::get<1>(regDataMask);
280 auto maskField = std::get<2>(regDataMask);
282 builder.create<WhenOp>(
283 maskField,
false, [&]() {
284 builder.create<MatchingConnectOp>(regField, dataField);
289 [&]() { builder.create<MatchingConnectOp>(rdataOut,
rdata); });
300 bool getFields(Value
reg, Value input, Value mask,
301 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
302 ImplicitLocOpBuilder &builder) {
306 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
307 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
308 return mBundle.getNumElements() == bundle.getNumElements();
309 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
310 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
311 return mVec.getNumElements() == vec.getNumElements();
317 std::function<bool(Value, Value, Value)> flatAccess =
318 [&](Value
reg, Value input, Value
mask) ->
bool {
319 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
320 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
321 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
325 .
Case<BundleType>([&](BundleType bundle) {
326 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
327 auto regField = builder.create<SubfieldOp>(
reg, i);
328 auto inputField = builder.create<SubfieldOp>(input, i);
329 auto maskField = builder.create<SubfieldOp>(
mask, i);
330 if (!flatAccess(regField, inputField, maskField))
335 .Case<FVectorType>([&](
auto vector) {
336 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
337 auto regField = builder.create<SubindexOp>(
reg, i);
338 auto inputField = builder.create<SubindexOp>(input, i);
339 auto maskField = builder.create<SubindexOp>(
mask, i);
340 if (!flatAccess(regField, inputField, maskField))
345 .Case<IntType>([&](
auto iType) {
346 results.push_back({
reg, input,
mask});
347 return iType.getWidth().has_value();
349 .Default([&](
auto) {
return false; });
351 if (flatAccess(
reg, input, mask))
356 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
357 ImplicitLocOpBuilder &builder) {
359 SmallVector<Attribute> regAnnotations;
360 auto vecType = type_cast<FVectorType>(op.getResult().getType());
361 for (
auto anno : annos) {
363 for (
size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
366 NamedAttrList newAnno;
367 newAnno.append(
"class", anno.getMember(
"class"));
368 newAnno.append(
"circt.fieldID",
369 builder.getI64IntegerAttr(vecType.getFieldID(i)));
370 newAnno.append(
"id", anno.getMember(
"id"));
371 if (
auto nla = anno.getMember(
"circt.nonlocal"))
372 newAnno.append(
"circt.nonlocal", nla);
377 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
380 regAnnotations.push_back(anno.getAttr());
382 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
386 void generateMemory(MemOp memOp,
FirMemory &firMem) {
387 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
388 auto dataType = memOp.getDataType();
390 auto innerSym = memOp.getInnerSym();
391 SmallVector<Value> debugPorts;
394 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
396 auto result = memOp.getResult(index);
397 if (type_isa<RefType>(result.getType())) {
398 debugPorts.push_back(result);
403 auto wire = builder.create<WireOp>(
405 (memOp.getName() +
"_" + memOp.getPortName(index).getValue()).str(),
406 memOp.getNameKind());
407 result.replaceAllUsesWith(wire.getResult());
408 result = wire.getResult();
410 auto adr = getAddr(builder, result);
411 auto enb = getEnable(builder, result);
412 auto clk = getClock(builder, result);
413 auto dta = getData(builder, result);
417 regOfVec = builder.create<RegOp>(
421 if (!memOp.getAnnotationsAttr().empty())
422 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
424 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
426 auto portKind = memOp.getPortKind(index);
427 if (portKind == MemOp::PortKind::Read) {
428 generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
429 }
else if (portKind == MemOp::PortKind::Write) {
430 auto mask = getMask(builder, result);
431 generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
434 auto wmode = getWmode(builder, result);
435 auto wDta = getData(builder, result,
true);
436 auto mask = getMask(builder, result);
437 generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
438 regOfVec.getResult(), builder);
445 for (
auto r : debugPorts)
446 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
451 bool ignoreReadEnable;
455 std::unique_ptr<mlir::Pass>
457 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)