21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/Threading.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "mem-to-reg-of-vec"
29 using namespace circt;
30 using namespace firrtl;
33 struct MemToRegOfVecPass :
public MemToRegOfVecBase<MemToRegOfVecPass> {
34 MemToRegOfVecPass(
bool replSeqMem,
bool ignoreReadEnable)
35 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
37 void runOnOperation()
override {
38 auto circtOp = getOperation();
39 DenseSet<Operation *> dutModuleSet;
43 auto *body = circtOp.getBodyBlock();
46 auto it = llvm::find_if(*body, [&](Operation &op) ->
bool {
49 if (it != body->end()) {
50 auto &instanceGraph = getAnalysis<InstanceGraph>();
51 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
52 llvm::for_each(llvm::depth_first(node),
57 auto mods = circtOp.getOps<FModuleOp>();
58 dutModuleSet.insert(mods.begin(), mods.end());
61 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
63 if (auto mod = dyn_cast<FModuleOp>(op))
68 void runOnModule(FModuleOp mod) {
70 mod.getBodyBlock()->walk([&](MemOp memOp) {
71 LLVM_DEBUG(
llvm::dbgs() <<
"\n Memory op:" << memOp);
75 auto firMem = memOp.getSummary();
83 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
84 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
85 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
88 generateMemory(memOp, firMem);
93 Value addPipelineStages(ImplicitLocOpBuilder &b,
size_t stages, Value clock,
94 Value pipeInput, StringRef name, Value gate = {}) {
99 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
101 b.create<WhenOp>(gate,
false,
102 [&]() { b.create<StrictConnectOp>(
reg, pipeInput); });
104 b.create<StrictConnectOp>(
reg, pipeInput);
112 Value getClock(ImplicitLocOpBuilder &
builder, Value bundle) {
113 return builder.create<SubfieldOp>(bundle,
"clk");
116 Value getAddr(ImplicitLocOpBuilder &
builder, Value bundle) {
117 return builder.create<SubfieldOp>(bundle,
"addr");
120 Value getWmode(ImplicitLocOpBuilder &
builder, Value bundle) {
121 return builder.create<SubfieldOp>(bundle,
"wmode");
124 Value getEnable(ImplicitLocOpBuilder &
builder, Value bundle) {
125 return builder.create<SubfieldOp>(bundle,
"en");
128 Value getMask(ImplicitLocOpBuilder &
builder, Value bundle) {
129 auto bType = type_cast<BundleType>(bundle.getType());
130 if (bType.getElement(
"mask"))
131 return builder.create<SubfieldOp>(bundle,
"mask");
132 return builder.create<SubfieldOp>(bundle,
"wmask");
135 Value getData(ImplicitLocOpBuilder &
builder, Value bundle,
136 bool getWdata =
false) {
137 auto bType = type_cast<BundleType>(bundle.getType());
138 if (bType.getElement(
"data"))
139 return builder.create<SubfieldOp>(bundle,
"data");
140 if (bType.getElement(
"rdata") && !getWdata)
141 return builder.create<SubfieldOp>(bundle,
"rdata");
142 return builder.create<SubfieldOp>(bundle,
"wdata");
145 void generateRead(
const FirMemory &firMem, Value clock, Value
addr,
146 Value enable, Value
data, Value regOfVec,
147 ImplicitLocOpBuilder &
builder) {
148 if (ignoreReadEnable) {
151 for (
size_t j = 0, e = firMem.
readLatency; j != e; ++j) {
152 auto enLast = enable;
154 enable = addPipelineStages(
builder, 1, clock, enable,
"en");
168 if (!ignoreReadEnable) {
170 builder.create<StrictConnectOp>(
173 builder.create<WhenOp>(enable,
false, [&]() {
182 void generateWrite(
const FirMemory &firMem, Value clock, Value
addr,
183 Value enable, Value maskBits, Value wdataIn,
184 Value regOfVec, ImplicitLocOpBuilder &
builder) {
190 enable = addPipelineStages(
builder, numStages, clock, enable,
"en");
191 wdataIn = addPipelineStages(
builder, numStages, clock, wdataIn,
"wdata");
192 maskBits = addPipelineStages(
builder, numStages, clock, maskBits,
"wmask");
201 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
216 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
218 wdataIn.getDefiningOp()->emitOpError(
219 "Cannot convert memory to bank of registers");
223 builder.create<WhenOp>(enable,
false, [&]() {
225 for (
auto regDataMask : loweredRegDataMaskFields) {
226 auto regField = std::get<0>(regDataMask);
227 auto dataField = std::get<1>(regDataMask);
228 auto maskField = std::get<2>(regDataMask);
230 builder.create<WhenOp>(maskField,
false, [&]() {
231 builder.create<StrictConnectOp>(regField, dataField);
237 void generateReadWrite(
const FirMemory &firMem, Value clock, Value
addr,
238 Value enable, Value maskBits, Value wdataIn,
239 Value rdataOut, Value
wmode, Value regOfVec,
240 ImplicitLocOpBuilder &
builder) {
246 enable = addPipelineStages(
builder, numStages, clock, enable,
"en");
247 wdataIn = addPipelineStages(
builder, numStages, clock, wdataIn,
"wdata");
248 maskBits = addPipelineStages(
builder, numStages, clock, maskBits,
"wmask");
253 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
254 if (!getFields(
rdata, wdataIn, maskBits, loweredRegDataMaskFields,
256 wdataIn.getDefiningOp()->emitOpError(
257 "Cannot convert memory to bank of registers");
261 builder.create<StrictConnectOp>(
262 rdataOut,
builder.create<InvalidValueOp>(rdataOut.getType()));
264 builder.create<WhenOp>(enable,
false, [&]() {
271 for (
auto regDataMask : loweredRegDataMaskFields) {
272 auto regField = std::get<0>(regDataMask);
273 auto dataField = std::get<1>(regDataMask);
274 auto maskField = std::get<2>(regDataMask);
277 maskField,
false, [&]() {
278 builder.create<StrictConnectOp>(regField, dataField);
283 [&]() {
builder.create<StrictConnectOp>(rdataOut,
rdata); });
294 bool getFields(Value
reg, Value input, Value
mask,
295 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
296 ImplicitLocOpBuilder &
builder) {
300 if (
auto bundle = type_dyn_cast<BundleType>(inType)) {
301 if (
auto mBundle = type_dyn_cast<BundleType>(maskType))
302 return mBundle.getNumElements() == bundle.getNumElements();
303 }
else if (
auto vec = type_dyn_cast<FVectorType>(inType)) {
304 if (
auto mVec = type_dyn_cast<FVectorType>(maskType))
305 return mVec.getNumElements() == vec.getNumElements();
311 std::function<bool(Value, Value, Value)> flatAccess =
312 [&](Value
reg, Value input, Value
mask) ->
bool {
313 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
314 if (!isValidMask(inType, type_cast<FIRRTLType>(
mask.getType()))) {
315 input.getDefiningOp()->emitOpError(
"Mask type is not valid");
319 .
Case<BundleType>([&](BundleType bundle) {
320 for (
size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
321 auto regField =
builder.create<SubfieldOp>(
reg, i);
322 auto inputField =
builder.create<SubfieldOp>(input, i);
323 auto maskField =
builder.create<SubfieldOp>(
mask, i);
324 if (!flatAccess(regField, inputField, maskField))
329 .Case<FVectorType>([&](
auto vector) {
330 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
331 auto regField =
builder.create<SubindexOp>(
reg, i);
332 auto inputField =
builder.create<SubindexOp>(input, i);
333 auto maskField =
builder.create<SubindexOp>(
mask, i);
334 if (!flatAccess(regField, inputField, maskField))
339 .Case<IntType>([&](
auto iType) {
340 results.push_back({
reg, input,
mask});
341 return iType.getWidth().has_value();
343 .Default([&](
auto) {
return false; });
345 if (flatAccess(
reg, input,
mask))
350 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
351 ImplicitLocOpBuilder &
builder) {
353 SmallVector<Attribute> regAnnotations;
354 auto vecType = type_cast<FVectorType>(op.getResult().getType());
355 for (
auto anno : annos) {
357 for (
size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
360 NamedAttrList newAnno;
361 newAnno.append(
"class", anno.getMember(
"class"));
362 newAnno.append(
"circt.fieldID",
363 builder.getI64IntegerAttr(vecType.getFieldID(i)));
364 newAnno.append(
"id", anno.getMember(
"id"));
365 if (
auto nla = anno.getMember(
"circt.nonlocal"))
366 newAnno.append(
"circt.nonlocal", nla);
371 regAnnotations.push_back(
builder.getDictionaryAttr(newAnno));
374 regAnnotations.push_back(anno.getAttr());
376 op.setAnnotationsAttr(
builder.getArrayAttr(regAnnotations));
380 void generateMemory(MemOp memOp,
FirMemory &firMem) {
381 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
382 auto dataType = memOp.getDataType();
384 auto innerSym = memOp.getInnerSym();
385 SmallVector<Value> debugPorts;
388 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
390 auto result = memOp.getResult(index);
391 if (type_isa<RefType>(result.getType())) {
392 debugPorts.push_back(result);
397 auto wire =
builder.create<WireOp>(
399 (memOp.getName() +
"_" + memOp.getPortName(index).getValue()).str(),
400 memOp.getNameKind());
401 result.replaceAllUsesWith(wire.getResult());
402 result = wire.getResult();
404 auto adr = getAddr(
builder, result);
405 auto enb = getEnable(
builder, result);
407 auto dta = getData(
builder, result);
411 regOfVec =
builder.create<RegOp>(
415 if (!memOp.getAnnotationsAttr().empty())
416 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(),
builder);
418 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
420 auto portKind = memOp.getPortKind(index);
421 if (portKind == MemOp::PortKind::Read) {
422 generateRead(firMem,
clk, adr, enb, dta, regOfVec.getResult(),
builder);
423 }
else if (portKind == MemOp::PortKind::Write) {
425 generateWrite(firMem,
clk, adr, enb,
mask, dta, regOfVec.getResult(),
429 auto wDta = getData(
builder, result,
true);
431 generateReadWrite(firMem,
clk, adr, enb,
mask, wDta, dta,
wmode,
432 regOfVec.getResult(),
builder);
439 for (
auto r : debugPorts)
440 r.replaceAllUsesWith(
builder.create<RefSendOp>(regOfVec.getResult()));
445 bool ignoreReadEnable;
449 std::unique_ptr<mlir::Pass>
451 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)