19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "mem-to-reg-of-vec"
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
33using namespace firrtl;
// MemToRegOfVec pass: lowers FIRRTL memories (MemOp) into a register whose
// type is an FVectorType of the memory's data type, plus the read/write port
// logic that accesses it. The pass boilerplate comes from the generated
// MemToRegOfVecBase (GEN_PASS_DEF_MEMTOREGOFVEC / Passes.h.inc).
// NOTE(review): this copy of the file carries fused viewer line numbers and is
// missing several lines; the comments in it describe only the visible code.
36struct MemToRegOfVecPass
37 :
public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
// Pass entry point.
//  - Fetches the InstanceInfo analysis for the circuit.
//  - Checks the circuit for convertMemToRegOfVecAnnoClass. When the (partly
//    elided) condition holds, it returns early with all analyses preserved.
//    NOTE(review): the start of that condition is not visible here. It
//    presumably removes the annotation via AnnotationSet::removeAnnotations
//    and bails out when none was present -- confirm.
//  - Collects every FModuleOp with an instance in the effective design and
//    processes the collected modules via mlir::parallelForEach.
40 void runOnOperation()
override {
41 auto circtOp = getOperation();
42 auto &instanceInfo = getAnalysis<InstanceInfo>();
45 convertMemToRegOfVecAnnoClass))
46 return markAllAnalysesPreserved();
// Only modules instantiated in the effective design are considered.
48 DenseSet<Operation *> dutModuleSet;
49 for (
auto moduleOp : circtOp.getOps<FModuleOp>())
50 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
51 dutModuleSet.insert(moduleOp);
// Process the collected modules with mlir::parallelForEach. The callback
// dispatches on FModuleOp; its body is only partly visible here (presumably
// it calls runOnModule -- confirm).
53 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
55 if (auto mod = dyn_cast<FModuleOp>(op))
// Visits every MemOp in the module's body block (via walk) and lowers it.
// A memory whose FirMemory summary reports isSeqMem() is handled by a
// statement that is not visible in this excerpt. That statement is presumably
// an early return that leaves such memories untouched -- confirm.
// The remaining memories are passed to generateMemory.
60 void runOnModule(FModuleOp mod) {
62 mod.getBodyBlock()->walk([&](MemOp memOp) {
63 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory op:" << memOp);
// Summary of the memory, consumed by isSeqMem() and generateMemory().
65 auto firMem = memOp.getSummary();
72 if (firMem.isSeqMem())
75 generateMemory(memOp, firMem);
// Inserts register stages in front of `pipeInput`. The function returns a
// Value, presumably the output of the last stage (the return is not visible
// here -- confirm).
//   b         - builder used to create all ops.
//   stages    - requested number of register stages.
//   clock     - clock for every created RegOp.
//   pipeInput - value to delay; each RegOp gets pipeInput's type.
//   name      - name given to each created RegOp.
//   gate      - optional condition (defaults to a null Value). The visible
//               code connects into the register either under a WhenOp on
//               `gate` or unconditionally.
// NOTE(review): not visible in this excerpt:
//   - the iteration over `stages`;
//   - the choice between the gated and ungated connect;
//   - the final return.
80 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
81 Value pipeInput, StringRef name, Value gate = {}) {
86 auto reg = RegOp::create(b, pipeInput.getType(), clock, name).getResult();
// Gated form: the connect into the register sits inside a WhenOp on `gate`.
// The `false` argument presumably means "no else region" -- confirm.
88 WhenOp::create(b, gate,
false,
89 [&]() { MatchingConnectOp::create(b, reg, pipeInput); });
// Ungated form: connect pipeInput into the register unconditionally.
91 MatchingConnectOp::create(b, reg, pipeInput);
// Returns the "clk" field of a memory port bundle.
99 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
100 return SubfieldOp::create(builder, bundle,
"clk");
// Returns the "addr" field of a memory port bundle.
103 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
104 return SubfieldOp::create(builder, bundle,
"addr");
// Returns the "wmode" field of a memory port bundle. Within this file it is
// requested for read-write ports only.
107 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
108 return SubfieldOp::create(builder, bundle,
"wmode");
// Returns the "en" (enable) field of a memory port bundle.
111 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
112 return SubfieldOp::create(builder, bundle,
"en");
// Returns the write-mask field of a memory port bundle. This is "mask" when
// the bundle type has such an element, otherwise "wmask".
// `bundle` is expected to be bundle-typed (it goes through
// type_cast<BundleType>).
115 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
116 auto bType = type_cast<BundleType>(bundle.getType());
117 if (bType.getElement(
"mask"))
118 return SubfieldOp::create(builder, bundle,
"mask");
119 return SubfieldOp::create(builder, bundle,
"wmask");
// Returns a data field of a memory port bundle:
//   - "data" if the bundle has it (regardless of getWdata);
//   - otherwise "rdata" if present and getWdata is false;
//   - otherwise "wdata".
// Passing getWdata=true on a bundle that has both "rdata" and "wdata" selects
// the write-data field.
122 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
123 bool getWdata =
false) {
124 auto bType = type_cast<BundleType>(bundle.getType());
125 if (bType.getElement(
"data"))
126 return SubfieldOp::create(builder, bundle,
"data");
127 if (bType.getElement(
"rdata") && !getWdata)
128 return SubfieldOp::create(builder, bundle,
"rdata");
129 return SubfieldOp::create(builder, bundle,
"wdata");
// Builds the logic for one read port of the register-of-vectors memory.
//   firMem   - memory summary; readLatency sets the number of pipeline stages.
//   clock    - clock for the pipeline registers.
//   addr     - port address; pipelined, then used to index regOfVec.
//   enable   - port read enable; pipelined alongside the address.
//   data     - port output that receives the read value.
//   regOfVec - the vector register holding the memory contents.
//   builder  - builder used for all created ops.
// `ignoreReadEnable` is declared outside this excerpt (presumably a pass
// option -- confirm).
// NOTE(review): some lines of this function are not visible here, including
// the `else` headers and any condition around the per-cycle enable stage.
132 void generateRead(
const FirMemory &firMem, Value clock, Value
addr,
133 Value enable, Value
data, Value regOfVec,
134 ImplicitLocOpBuilder &builder) {
135 if (ignoreReadEnable) {
// One step per read-latency cycle. Each address stage is gated by the enable
// value from before that step (enLast).
138 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
139 auto enLast = enable;
141 enable = addPipelineStages(builder, 1, clock, enable,
"en");
142 addr = addPipelineStages(builder, 1, clock,
addr,
"addr", enLast);
// Non-ignoreReadEnable path (branch header not visible): delay the enable by
// readLatency stages. The matching address delay is on lines not visible here.
148 addPipelineStages(builder, firMem.
readLatency, clock, enable,
"en");
// Index the register vector with the (pipelined) address.
154 Value
rdata = SubaccessOp::create(builder, regOfVec,
addr);
155 if (!ignoreReadEnable) {
// Default the output to invalid, then drive the read value under a WhenOp on
// the (pipelined) enable.
157 MatchingConnectOp::create(
158 builder,
data, InvalidValueOp::create(builder,
data.getType()));
160 WhenOp::create(builder, enable,
false, [&]() {
161 MatchingConnectOp::create(builder,
data,
rdata);
// Read enable ignored: drive the read value unconditionally.
165 MatchingConnectOp::create(builder,
data,
rdata);
// Builds the logic for one write port of the register-of-vectors memory.
//   addr, enable, wdataIn and maskBits are delayed by `numStages` register
//   stages. numStages is computed on a line not visible here, presumably from
//   firMem's write latency -- confirm.
//   getFields() splits the register element `rdata` (its definition is not
//   visible here), the write data and the mask into matching leaf
//   (register, data, mask) triples. On failure an error is emitted on the
//   write-data op.
//   The resulting update is: when(enable) { for each leaf: when(mask) reg = data }.
169 void generateWrite(
const FirMemory &firMem, Value clock, Value
addr,
170 Value enable, Value maskBits, Value wdataIn,
171 Value regOfVec, ImplicitLocOpBuilder &builder) {
// Delay all port signals by numStages register stages.
176 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
177 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
178 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
179 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Leaf (register field, data field, mask field) triples filled by getFields.
188 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
// Flatten aggregates into leaves. When the types cannot be matched, report an
// error (the statement after the error is not visible here).
203 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
205 wdataIn.getDefiningOp()->emitOpError(
206 "Cannot convert memory to bank of registers");
// Write each leaf under the port enable and that leaf's own mask.
210 WhenOp::create(builder, enable,
false, [&]() {
212 for (
auto regDataMask : loweredRegDataMaskFields) {
213 auto regField = std::get<0>(regDataMask);
214 auto dataField = std::get<1>(regDataMask);
215 auto maskField = std::get<2>(regDataMask);
217 WhenOp::create(builder, maskField,
false, [&]() {
218 MatchingConnectOp::create(builder, regField, dataField);
// Builds the logic for one read-write port of the register-of-vectors memory.
//   addr/enable/wdataIn/maskBits - port signals, delayed by `numStages`
//       register stages (numStages is defined on a line not visible here).
//   rdataOut - port read-data output; defaults to invalid.
//   wmode    - selects between the write path and the read path.
//   regOfVec - the vector register holding the memory contents.
// rdata is the register element selected by the pipelined address.
// getFields() splits rdata, the write data and the mask into leaf triples,
// and an error is emitted on failure.
// Under `enable`, a WhenOp on `wmode` does the following:
//   - it is presumably built with an else region, given the `true` argument;
//   - its then-branch performs the masked leaf writes;
//   - the trailing lambda (presumably the else branch) drives rdataOut from
//     rdata.
// NOTE(review): parts of the WhenOp construction are not visible in this
// excerpt.
224 void generateReadWrite(
const FirMemory &firMem, Value clock, Value
addr,
225 Value enable, Value maskBits, Value wdataIn,
226 Value rdataOut, Value
wmode, Value regOfVec,
227 ImplicitLocOpBuilder &builder) {
// Delay all port signals by numStages register stages.
232 addr = addPipelineStages(builder, numStages, clock,
addr,
"addr");
233 enable = addPipelineStages(builder, numStages, clock, enable,
"en");
234 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn,
"wdata");
235 maskBits = addPipelineStages(builder, numStages, clock, maskBits,
"wmask");
// Register element addressed by the pipelined address.
238 Value
rdata = SubaccessOp::create(builder, regOfVec,
addr);
// Leaf (register field, data field, mask field) triples filled by getFields.
240 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
241 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
243 wdataIn.getDefiningOp()->emitOpError(
244 "Cannot convert memory to bank of registers");
// Default the read-data output to invalid.
248 MatchingConnectOp::create(
249 builder, rdataOut, InvalidValueOp::create(builder, rdataOut.getType()));
// Under enable: masked writes when wmode is set, otherwise a read.
251 WhenOp::create(builder, enable,
false, [&]() {
254 builder,
wmode,
true,
258 for (
auto regDataMask : loweredRegDataMaskFields) {
259 auto regField = std::get<0>(regDataMask);
260 auto dataField = std::get<1>(regDataMask);
261 auto maskField = std::get<2>(regDataMask);
264 builder, maskField,
false, [&]() {
265 MatchingConnectOp::create(builder, regField, dataField);
// Read path: drive the read-data output from the addressed register element.
270 [&]() { MatchingConnectOp::create(builder, rdataOut,
rdata); });
// Walks `reg`, `input` (write data) and `mask` in lock-step and appends the
// leaf (register, data, mask) triples to `results`.
//   - Bundles recurse field by field (SubfieldOp).
//   - Vectors recurse element by element (SubindexOp).
//   - An IntType leaf is appended; the walk succeeds only if its width is
//     known.
//   - Any other type fails.
// Before descending, the mask type is checked against the data type by
// isValidMask (its header is not visible here). For a bundle, the mask must
// be a bundle with the same number of elements. For a vector, it must be a
// vector with the same number of elements. Other cases are not visible in
// this excerpt. An invalid mask emits "Mask type is not valid" on the data
// value's defining op.
// Returns whether the walk succeeded (the final return statements are not
// fully visible here).
// NOTE(review): an IntType leaf is pushed to `results` before its width is
// checked, so a failed call can leave partial entries in `results`. The
// visible callers emit an error when this returns false.
281 bool getFields(Value reg, Value input, Value
mask,
282 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
283 ImplicitLocOpBuilder &builder) {
287 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
288 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
289 return mBundle.getNumElements() == bundle.getNumElements();
290 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
291 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
292 return mVec.getNumElements() == vec.getNumElements();
// Recursive walker. It is a std::function so the lambda can refer to itself.
298 std::function<bool(Value, Value, Value)> flatAccess =
299 [&](Value
reg, Value input, Value
mask) ->
bool {
300 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
301 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
302 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
306 .
Case<BundleType>([&](BundleType bundle) {
307 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
308 auto regField = SubfieldOp::create(builder, reg, i);
309 auto inputField = SubfieldOp::create(builder, input, i);
310 auto maskField = SubfieldOp::create(builder,
mask, i);
311 if (!flatAccess(regField, inputField, maskField))
316 .Case<FVectorType>([&](
auto vector) {
317 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
318 auto regField = SubindexOp::create(builder, reg, i);
319 auto inputField = SubindexOp::create(builder, input, i);
320 auto maskField = SubindexOp::create(builder,
mask, i);
321 if (!flatAccess(regField, inputField, maskField))
326 .Case<IntType>([&](
auto iType) {
327 results.push_back({
reg, input,
mask});
328 return iType.getWidth().has_value();
330 .Default([&](
auto) {
return false; });
// Run the walk from the top-level values.
332 if (flatAccess(reg, input,
mask))
// Lowers `memOp` into a register of type FVectorType(dataType, firMem.depth),
// named after the memory, plus per-port logic. The builder is constructed from
// memOp, so ops are created at memOp's location.
//   - Results of RefType (debug ports) are collected in debugPorts. At the
//     end, each one's uses are replaced by a RefSendOp of the register.
//   - Every other port result is redirected to a new WireOp named
//     "<memName>_<portName>", and the port logic is built on that wire.
//   - Memory annotations (when non-empty) and the memory's inner symbol are
//     copied onto the register.
//   - Read and Write ports go to generateRead / generateWrite. The final
//     branch (presumably ReadWrite ports) goes to generateReadWrite.
// NOTE(review): the following lines are not visible in this excerpt (confirm
// against the full file):
//   - the declaration of regOfVec;
//   - any guard that creates the register only once across ports;
//   - the use of `innerSym`;
//   - the statement after debugPorts.push_back (presumably `continue`).
338 void generateMemory(MemOp memOp,
FirMemory &firMem) {
339 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
340 auto dataType = memOp.getDataType();
342 auto innerSym = memOp.getInnerSym();
343 SmallVector<Value> debugPorts;
346 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
348 auto result = memOp.getResult(index);
349 if (type_isa<RefType>(result.getType())) {
350 debugPorts.push_back(result);
355 auto wire = WireOp::create(
356 builder, result.getType(),
357 (memOp.getName() +
"_" + memOp.getPortName(index)).str(),
358 memOp.getNameKind());
359 result.replaceAllUsesWith(wire.getResult());
360 result = wire.getResult();
// Fields common to every port kind, read from the replacement wire.
362 auto adr = getAddr(builder, result);
363 auto enb = getEnable(builder, result);
364 auto clk = getClock(builder, result);
365 auto dta = getData(builder, result);
370 RegOp::create(builder, FVectorType::get(dataType, firMem.
depth),
371 clk, memOp.getNameAttr());
// Carry the memory's annotations and inner symbol over to the register.
374 if (!memOp.getAnnotationsAttr().empty())
375 regOfVec.setAnnotationsAttr(memOp.getAnnotationsAttr());
377 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
// Dispatch on the port kind.
379 auto portKind = memOp.getPortKind(index);
380 if (portKind == MemOp::PortKind::Read) {
381 generateRead(firMem,
clk, adr, enb, dta, regOfVec.getResult(), builder);
382 }
else if (portKind == MemOp::PortKind::Write) {
383 auto mask = getMask(builder, result);
384 generateWrite(firMem,
clk, adr, enb,
mask, dta, regOfVec.getResult(),
387 auto wmode = getWmode(builder, result);
388 auto wDta = getData(builder, result,
true);
389 auto mask = getMask(builder, result);
390 generateReadWrite(firMem,
clk, adr, enb,
mask, wDta, dta,
wmode,
391 regOfVec.getResult(), builder);
// Point each debug (RefType) port at the register through a RefSendOp.
398 for (
auto r : debugPorts)
399 r.replaceAllUsesWith(RefSendOp::create(builder, regOfVec.getResult()));
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)