19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "mem-to-reg-of-vec"
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
33using namespace firrtl;
36struct MemToRegOfVecPass
37 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
/// Overrides runOnOperation: gathers the FModuleOps that InstanceInfo reports
/// as having an instance in the effective design, then visits that set with
/// mlir::parallelForEach.
/// NOTE(review): several original lines are absent from this listing (the
/// guard of the early return, the parallelForEach callback header, closing
/// braces); the comments below cover only the visible statements.
40 void runOnOperation()
override {
41 auto circtOp = getOperation();
42 auto &instanceInfo = getAnalysis<InstanceInfo>();
// Early exit that marks all analyses as preserved; its condition is not
// visible here.
46 return markAllAnalysesPreserved();
// Set of modules selected for processing.
48 DenseSet<Operation *> dutModuleSet;
49 for (
auto moduleOp : circtOp.getOps<FModuleOp>())
50 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
51 dutModuleSet.insert(moduleOp);
// The set stores Operation*, so each visited entry is narrowed back to
// FModuleOp. The action taken on `mod` is outside the visible lines --
// presumably runOnModule; confirm.
53 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
55 if (auto mod = dyn_cast<FModuleOp>(op))
/// Walks every MemOp in the body block of `mod`, logs it under LLVM_DEBUG,
/// reads its FirMemory summary and hands it to generateMemory.
/// NOTE(review): the beginning of the condition that starts mid-expression
/// below is missing from this listing, so the full filter cannot be read here.
60 void runOnModule(FModuleOp mod) {
62 mod.getBodyBlock()->walk([&](MemOp memOp) {
63 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
65 auto firMem = memOp.getSummary();
// Visible part of the condition: read and write latency both 1, exactly one
// write-or-readwrite port, at most one read port, and a non-zero data width.
73 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
74 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
75 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
// Presumably guarded by the condition above -- confirm against the full file.
78 generateMemory(memOp, firMem);
/// Builds pipeline registers for `pipeInput`.
///
/// \param b         builder used to create the ops.
/// \param stages    requested number of stages (the code consuming it is not
///                  visible in this listing).
/// \param clock     clock passed to each created RegOp.
/// \param pipeInput value to be registered; the register takes its type.
/// \param name      name given to the created RegOp.
/// \param gate      optional value (defaults to a null Value) used as the
///                  condition of the WhenOp that wraps the register update.
/// NOTE(review): the loop over `stages`, the choice between the gated and
/// ungated connect, and the return statement are missing from this listing.
83 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
84 Value pipeInput, StringRef name, Value gate = {}) {
89 auto reg = RegOp::create(b, pipeInput.getType(), clock, name).getResult();
// Gated form: the register is connected to its input inside a WhenOp on
// `gate` (the `false` argument presumably means "no else region" -- confirm).
91 WhenOp::create(b, gate,
false,
92 [&]() { MatchingConnectOp::create(b, reg, pipeInput); });
// Ungated form: unconditional connect. Presumably taken when `gate` is unset.
94 MatchingConnectOp::create(b, reg, pipeInput);
/// Creates and returns a SubfieldOp selecting the "clk" field of `bundle`.
102 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
103 return SubfieldOp::create(builder, bundle,
"clk");
/// Creates and returns a SubfieldOp selecting the "addr" field of `bundle`.
106 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
107 return SubfieldOp::create(builder, bundle,
"addr");
/// Creates and returns a SubfieldOp selecting the "wmode" field of `bundle`.
110 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
111 return SubfieldOp::create(builder, bundle,
"wmode");
/// Creates and returns a SubfieldOp selecting the "en" field of `bundle`.
114 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
115 return SubfieldOp::create(builder, bundle,
"en");
/// Returns the mask subfield of `bundle`: the "mask" field when the bundle
/// type has an element of that name, otherwise the "wmask" field.
118 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
119 auto bType = type_cast<BundleType>(bundle.getType());
120 if (bType.getElement(
"mask"))
121 return SubfieldOp::create(builder, bundle,
"mask");
// No "mask" element: fall back to the "wmask" field name.
122 return SubfieldOp::create(builder, bundle,
"wmask");
/// Returns the data subfield of `bundle`, trying field names in this order:
///   1. "data" if the bundle type has such an element;
///   2. "rdata" if present and `getWdata` is false;
///   3. "wdata" otherwise.
/// \param getWdata when true, skips "rdata" so that "wdata" is selected for
///                 bundles that lack a plain "data" field.
125 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
126 bool getWdata =
false) {
127 auto bType = type_cast<BundleType>(bundle.getType());
128 if (bType.getElement(
"data"))
129 return SubfieldOp::create(builder, bundle,
"data");
130 if (bType.getElement(
"rdata") && !getWdata)
131 return SubfieldOp::create(builder, bundle,
"rdata");
132 return SubfieldOp::create(builder, bundle,
"wdata");
/// Emits the read path of one port: pipelines the enable/address signals,
/// reads `regOfVec` at `addr` with a SubaccessOp and connects the result to
/// `data`.
/// NOTE(review): some original lines are missing from this listing (parts of
/// the branch structure around the pipelining and around the final connects);
/// the comments describe the visible statements only.
135 void generateRead(
const FirMemory &firMem, Value clock, Value
addr,
136 Value enable, Value
data, Value regOfVec,
137 ImplicitLocOpBuilder &builder) {
// `ignoreReadEnable` is not declared in this function; presumably a pass
// option/member -- confirm.
138 if (ignoreReadEnable) {
// One iteration per cycle of read latency. `enLast` keeps the enable value
// from before this iteration's update and is passed as the gate of the
// address register.
141 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
142 auto enLast = enable;
144 enable = addPipelineStages(builder, 1, clock, enable,
"en");
145 addr = addPipelineStages(builder, 1, clock,
addr,
"addr", enLast);
// Ungated pipelining of the enable over the full read latency. The start of
// this statement and the branch it belongs to are not visible here.
151 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
// Read the register vector at the (pipelined) address.
157 Value
rdata = SubaccessOp::create(builder, regOfVec,
addr);
158 if (!ignoreReadEnable) {
// Drive `data` with an invalid value first, then override it with the read
// value inside a WhenOp conditioned on `enable`.
160 MatchingConnectOp::create(
161 builder,
data, InvalidValueOp::create(builder,
data.getType()));
163 WhenOp::create(builder, enable,
false, [&]() {
164 MatchingConnectOp::create(builder,
data,
rdata);
// Direct connect of the read value; presumably the ignoreReadEnable path --
// the enclosing branch is not visible here.
168 MatchingConnectOp::create(builder,
data,
rdata);
/// Emits the write path of one port: pipelines address, enable, write data and
/// mask, flattens register/data/mask into matching leaf fields via getFields,
/// and connects each data leaf to its register leaf under the enable and the
/// corresponding mask bit.
/// NOTE(review): the definitions of `numStages` and `rdata` used below are on
/// lines missing from this listing.
172 void generateWrite(
const FirMemory &firMem, Value clock, Value
addr,
173 Value enable, Value maskBits, Value wdataIn,
174 Value regOfVec, ImplicitLocOpBuilder &builder) {
// Delay every write-side signal by the same number of stages.
179 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
180 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
181 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
182 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Tuples of (register field, data field, mask field) filled in by getFields.
191 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
// getFields returning false is reported as an op error on the op defining the
// write data.
206 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
208 wdataIn.getDefiningOp()->emitOpError(
209 "Cannot convert memory to bank of registers");
// when(enable) { for each leaf: when(mask) { reg <= data } }
213 WhenOp::create(builder, enable,
false, [&]() {
215 for (
auto regDataMask : loweredRegDataMaskFields) {
216 auto regField = std::get<0>(regDataMask);
217 auto dataField = std::get<1>(regDataMask);
218 auto maskField = std::get<2>(regDataMask);
220 WhenOp::create(builder, maskField,
false, [&]() {
221 MatchingConnectOp::create(builder, regField, dataField);
/// Emits the logic of a combined read-write port: pipelines the write-side
/// signals, reads `regOfVec` at `addr`, flattens register/data/mask into leaf
/// fields via getFields, and under `enable` either performs the masked write
/// or drives `rdataOut` with the read value.
/// NOTE(review): the definition of `numStages` and the heads of the nested
/// WhenOp::create calls are on lines missing from this listing.
227 void generateReadWrite(
const FirMemory &firMem, Value clock, Value
addr,
228 Value enable, Value maskBits, Value wdataIn,
229 Value rdataOut, Value
wmode, Value regOfVec,
230 ImplicitLocOpBuilder &builder) {
// Delay address, enable, write data and mask by the same number of stages.
235 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
236 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
237 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
238 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Element of the register vector selected by the address.
241 Value
rdata = SubaccessOp::create(builder, regOfVec,
addr);
// Tuples of (register field, data field, mask field) filled in by getFields;
// a false result is reported as an op error on the op defining the write data.
243 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
244 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
246 wdataIn.getDefiningOp()->emitOpError(
247 "Cannot convert memory to bank of registers");
// Default: the read output is driven with an invalid value.
251 MatchingConnectOp::create(
252 builder, rdataOut, InvalidValueOp::create(builder, rdataOut.getType()));
// Under `enable`, an inner construct on `wmode` takes two callbacks (the
// `true` argument presumably requests an else region -- confirm): the first
// performs the per-leaf masked writes, the second connects rdataOut <= rdata.
254 WhenOp::create(builder, enable,
false, [&]() {
257 builder,
wmode,
true,
261 for (
auto regDataMask : loweredRegDataMaskFields) {
262 auto regField = std::get<0>(regDataMask);
263 auto dataField = std::get<1>(regDataMask);
264 auto maskField = std::get<2>(regDataMask);
267 builder, maskField,
false, [&]() {
268 MatchingConnectOp::create(builder, regField, dataField);
273 [&]() { MatchingConnectOp::create(builder, rdataOut,
rdata); });
/// Recursively decomposes `reg`, `input` and `mask` in lock-step down to
/// integer-typed leaves, appending one (reg, input, mask) tuple to `results`
/// per leaf.
/// \returns a bool; the visible code yields false from the type switch for
///          unsupported types and for integer leaves whose width is unknown.
/// NOTE(review): several original lines are missing from this listing (the
/// header of the mask-validity helper, the head of the type switch, failure
/// returns and the final return), so the exact return paths cannot be read
/// here.
284 bool getFields(Value reg, Value input, Value
mask,
285 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
286 ImplicitLocOpBuilder &builder) {
// Shape check between a data type and a mask type: for bundles and for
// vectors the element counts must be equal. Presumably the body of the
// `isValidMask` helper called further down -- its header is not visible.
290 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
291 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
292 return mBundle.getNumElements() == bundle.getNumElements();
293 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
294 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
295 return mVec.getNumElements() == vec.getNumElements();
// Recursive worker; held in a std::function so the lambda can call itself.
301 std::function<bool(Value, Value, Value)> flatAccess =
302 [&](Value
reg, Value input, Value
mask) ->
bool {
303 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
// An incompatible mask is reported as an op error on the op defining `input`.
304 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
305 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
// Bundle: recurse into the i-th subfield of reg, input and mask.
309 .
Case<BundleType>([&](BundleType bundle) {
310 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
311 auto regField = SubfieldOp::create(builder, reg, i);
312 auto inputField = SubfieldOp::create(builder, input, i);
313 auto maskField = SubfieldOp::create(builder,
mask, i);
314 if (!flatAccess(regField, inputField, maskField))
// Vector: recurse into the i-th subindex of reg, input and mask.
319 .Case<FVectorType>([&](
auto vector) {
320 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
321 auto regField = SubindexOp::create(builder, reg, i);
322 auto inputField = SubindexOp::create(builder, input, i);
323 auto maskField = SubindexOp::create(builder,
mask, i);
324 if (!flatAccess(regField, inputField, maskField))
// Integer leaf: record the tuple; the result is true only when the width is
// known.
329 .Case<IntType>([&](
auto iType) {
330 results.push_back({
reg, input,
mask});
331 return iType.getWidth().has_value();
// Any other type is rejected.
333 .Default([&](
auto) {
return false; });
// Kick off the recursion on the top-level values.
335 if (flatAccess(reg, input,
mask))
/// Replaces the ports of `memOp` with wires and builds a register of vector
/// type (`dataType` x `firMem.depth`) plus the read / write / read-write logic
/// for each port, dispatching on the port kind.
/// NOTE(review): several original lines are missing from this listing (the
/// declaration of `regOfVec`, loop/branch headers and closing braces); the
/// comments describe the visible statements only.
341 void generateMemory(MemOp memOp,
FirMemory &firMem) {
// Builder created from the memory op's location and the op itself.
342 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
343 auto dataType = memOp.getDataType();
345 auto innerSym = memOp.getInnerSym();
// Results of RefType are collected here and rewired after the port loop.
346 SmallVector<Value> debugPorts;
349 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
351 auto result = memOp.getResult(index);
352 if (type_isa<RefType>(result.getType())) {
353 debugPorts.push_back(result);
// Create a wire named "<memory name>_<port name>" of the port's type, redirect
// all uses of the port to it, and continue with the wire as `result`.
358 auto wire = WireOp::create(
359 builder, result.getType(),
360 (memOp.getName() +
"_" + memOp.getPortName(index)).str(),
361 memOp.getNameKind());
362 result.replaceAllUsesWith(wire.getResult());
363 result = wire.getResult();
// Subfields of the port bundle shared by all port kinds.
365 auto adr = getAddr(builder, result);
366 auto enb = getEnable(builder, result);
367 auto clk = getClock(builder, result);
368 auto dta = getData(builder, result);
// Register with vector type, clocked by this port's clock and named after the
// memory. The start of this statement (presumably assigning `regOfVec`) is not
// visible here.
373 RegOp::create(builder, FVectorType::get(dataType, firMem.
depth),
374 clk, memOp.getNameAttr());
// Copy the memory's annotations (when non-empty) and inner symbol attribute
// onto the register.
377 if (!memOp.getAnnotationsAttr().empty())
378 regOfVec.setAnnotationsAttr(memOp.getAnnotationsAttr());
380 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
// Dispatch on the port kind.
382 auto portKind = memOp.getPortKind(index);
383 if (portKind == MemOp::PortKind::Read) {
384 generateRead(firMem,
clk, adr, enb, dta, regOfVec.getResult(), builder);
385 }
else if (portKind == MemOp::PortKind::Write) {
386 auto mask = getMask(builder, result);
387 generateWrite(firMem,
clk, adr, enb,
mask, dta, regOfVec.getResult(),
390 auto wmode = getWmode(builder, result);
// getData(..., true) requests the write-data field; `dta` from above is passed
// as the read-data output.
391 auto wDta = getData(builder, result,
true);
392 auto mask = getMask(builder, result);
393 generateReadWrite(firMem,
clk, adr, enb,
mask, wDta, dta,
wmode,
394 regOfVec.getResult(), builder);
// Every collected RefType result is replaced by a RefSendOp of the register.
401 for (
auto r : debugPorts)
402 r.replaceAllUsesWith(RefSendOp::create(builder, regOfVec.getResult()));
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)