18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "firrtl-infer-read-write"
29 #define GEN_PASS_DEF_INFERREADWRITE
30 #include "circt/Dialect/FIRRTL/Passes.h.inc"
34 using namespace circt;
35 using namespace firrtl;
38 struct InferReadWritePass
39 :
public circt::firrtl::impl::InferReadWriteBase<InferReadWritePass> {
// Pass entry point. For each MemOp visited, inferUnmasked() is called first.
// The rest of the body gathers the enable product terms and the clock
// drivers of the memory's read and write ports, compares the clocks with
// sameDriver(), and asks checkComplement() for a term on one port that is the
// negation of a term on the other. It then builds a replacement MemOp that
// carries the original debug ports plus a single ReadWrite port, wires that
// port through freshly created wires, redirects the users of the old
// subfields, runs simplifyWmode() on the new memory, and queues the replaced
// ops in opsToErase.
// NOTE(review): several statements of the original (loop headers, the bodies
// of the guarding `if`s, closing braces) are not visible in this listing;
// the comments below describe only what the visible lines show.
49 void runOnOperation()
override {
50 LLVM_DEBUG(llvm::dbgs() <<
"\n Running Infer Read Write on module:"
52 SmallVector<Operation *> opsToErase;
53 for (MemOp memOp : llvm::make_early_inc_range(
55 inferUnmasked(memOp, opsToErase);
// Count the ports of each kind. The condition below is true for memories
// that do NOT have exactly one read port, one write port, no read-write port
// and read/write latency 1; the statement it guards is not visible here
// (presumably it skips the memory -- TODO confirm).
57 size_t nReads, nWrites, nRWs, nDbgs;
58 memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
61 if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
62 !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
64 SmallVector<Attribute, 4> resultNames;
65 SmallVector<Type, 4> resultTypes;
66 SmallVector<Attribute> portAtts;
67 SmallVector<Attribute, 4> portAnnotations;
70 SmallVector<Value> readTerms, writeTerms;
// First walk over the ports: debug ports are carried over with their name,
// type and annotation; for the other ports the non-empty annotations are
// accumulated in portAtts and the subfield users are inspected.
71 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
73 portAnno = memOp.getPortAnnotation(portIt.index());
74 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
75 resultNames.push_back(memOp.getPortName(portIt.index()));
76 resultTypes.push_back(memOp.getResult(portIt.index()).getType());
77 portAnnotations.push_back(portAnno);
81 if (!cast<ArrayAttr>(portAnno).
empty())
82 portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
84 Value portVal = portIt.value();
87 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
89 for (Operation *u : portVal.getUsers())
90 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
92 auto fName = sf.getInput().getType().base().getElementName(
// Product terms are collected into readTerms or writeTerms depending on the
// port kind; for the "clk" field the connect source is recorded as the read
// or write clock.
97 getProductTerms(sf, isReadPort ? readTerms : writeTerms);
99 else if (fName ==
"clk") {
101 rClock = getConnectSrc(sf);
103 wClock = getConnectSrc(sf);
108 if (!sameDriver(rClock, wClock))
113 llvm::dbgs() <<
"\n read clock:" << rClock
114 <<
" --- write clock:" << wClock;
115 llvm::dbgs() <<
"\n Read terms==>";
for (
auto t
116 : readTerms) llvm::dbgs()
119 llvm::dbgs() <<
"\n Write terms==>";
for (
auto t
120 : writeTerms) llvm::dbgs()
// Look for a term on one port that is the negation of a term on the other.
127 auto complementTerm = checkComplement(readTerms, writeTerms);
// Append the type of the new ReadWrite port after the debug port types.
134 resultTypes.push_back(MemOp::getTypeForPort(
135 memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
136 memOp.getMaskBits()));
137 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
138 portAnnotations.push_back(builder.getArrayAttr(portAtts));
// Build the replacement memory. Note that the read-under-write attribute is
// set to RUWAttr::Undefined rather than copied from memOp, while name,
// annotations, inner symbol, init and prefix attributes are carried over.
140 auto rwMem = builder.create<MemOp>(
141 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
142 memOp.getDepth(), RUWAttr::Undefined,
143 builder.getArrayAttr(resultNames), memOp.getNameAttr(),
144 memOp.getNameKind(), memOp.getAnnotations(),
145 builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
146 memOp.getInitAttr(), memOp.getPrefixAttr());
147 ++numRWPortMemoriesInferred;
// The read-write port is taken from result index nDbgs, i.e. right after the
// debug ports that were pushed first.
148 auto rwPort = rwMem->getResult(nDbgs);
152 auto addr = builder.create<SubfieldOp>(rwPort,
"addr");
154 auto enb = builder.create<SubfieldOp>(rwPort,
"en");
156 auto clk = builder.create<SubfieldOp>(rwPort,
"clk");
157 auto readData = builder.create<SubfieldOp>(rwPort,
"rdata");
160 auto wmode = builder.create<SubfieldOp>(rwPort,
"wmode");
161 auto writeData = builder.create<SubfieldOp>(rwPort,
"wdata");
162 auto mask = builder.create<SubfieldOp>(rwPort,
"wmask");
// Intermediate wires standing in for the old ports' address, enable and
// clock fields (the variables receiving these results are declared on lines
// not visible here).
165 builder.create<WireOp>(
addr.getType(),
"readAddr").getResult();
167 builder.create<WireOp>(
addr.getType(),
"writeAddr").getResult();
169 builder.create<WireOp>(enb.getType(),
"writeEnable").getResult();
171 builder.create<WireOp>(enb.getType(),
"readEnable").getResult();
173 builder.create<WireOp>(
ClockType::get(enb.getContext())).getResult();
// addr = mux(writeEnable, writeAddr, readAddr); en = readEnable | writeEnable.
175 builder.create<MatchingConnectOp>(
176 addr, builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
178 builder.create<MatchingConnectOp>(
179 enb, builder.create<OrPrimOp>(rEnWire, wEnWire));
// The wmode connect is emitted at the end of the block containing wmode.
180 builder.setInsertionPointToEnd(
wmode->getBlock());
181 builder.create<MatchingConnectOp>(
wmode, complementTerm);
// Second walk over the old ports: debug ports are redirected to the matching
// result of the new memory; for the other ports each subfield is replaced by
// the value selected from its field name.
183 size_t dbgsIndex = 0;
184 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
186 Value portVal = portIt.value();
187 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
188 memOp.getResult(portIt.index())
189 .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
195 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
198 for (Operation *u : portVal.getUsers())
199 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
200 StringRef fName = sf.getInput().getType().base().getElementName(
204 repl = llvm::StringSwitch<Value>(fName)
208 .Case(
"data", readData);
210 repl = llvm::StringSwitch<Value>(fName)
212 .Case(
"clk", writeClock)
214 .Case(
"data", writeData)
216 sf.replaceAllUsesWith(repl);
218 opsToErase.push_back(sf);
221 simplifyWmode(rwMem);
223 opsToErase.push_back(memOp);
// Deferred cleanup of everything queued above (loop body not visible here).
225 for (
auto *o : opsToErase)
// Scans the users of `dst` for an operation implementing FConnectLike.
// Callers in this file use the returned Value as the driver of `dst` and
// test it for null (e.g. `if (auto src = getConnectSrc(term))`).
// NOTE(review): the lines that inspect the connect and produce the return
// value are not visible in this listing -- confirm the exact matching rule
// against the full source.
231 Value getConnectSrc(Value dst) {
232 for (
auto *c : dst.getUsers())
233 if (
auto connect = dyn_cast<FConnectLike>(c))
// Decides whether the read clock and the write clock come from the same
// driver. Identical values are caught by the initial equality test. Otherwise
// the visible lines record rClock and its successive connect sources in
// rClocks, then follow wClock through its connect sources looking for a value
// already present in rClocks.
// NOTE(review): the loop headers and return statements are not visible in
// this listing; `sameClock` presumably carries the result -- TODO confirm.
242 bool sameDriver(Value rClock, Value wClock) {
243 if (rClock == wClock)
245 DenseSet<Value> rClocks, wClocks;
249 rClocks.insert(rClock);
250 rClock = getConnectSrc(rClock);
253 bool sameClock =
false;
257 if (rClocks.find(wClock) != rClocks.end()) {
261 wClock = getConnectSrc(wClock);
// Collects into `terms` the values that make up the enable expression rooted
// at `enValue`, using an explicit worklist instead of recursion.
//   enValue - value to start the traversal from.
//   terms   - output vector; every value taken from the worklist is appended.
// Traversal rules shown by the visible lines:
//   * NodeOp: continue with its input.
//   * AndPrimOp: continue with both operands.
//   * MuxPrimOp whose low operand is a zero ConstantOp: continue with the
//     select and the high operand.
//   * otherwise, if getConnectSrc(term) yields a value, continue with it
//     (presumably the TypeSwitch default case -- the enclosing line is not
//     visible here).
// Block arguments have no defining op and are not expanded further.
266 void getProductTerms(Value enValue, SmallVector<Value> &terms) {
269 SmallVector<Value> worklist;
270 worklist.push_back(enValue);
271 while (!worklist.empty()) {
272 auto term = worklist.back();
274 terms.push_back(term);
275 if (isa<BlockArgument>(term))
277 TypeSwitch<Operation *>(term.getDefiningOp())
278 .Case<NodeOp>([&](
auto n) { worklist.push_back(n.getInput()); })
279 .Case<AndPrimOp>([&](AndPrimOp andOp) {
280 worklist.push_back(andOp.getOperand(0));
281 worklist.push_back(andOp.getOperand(1));
283 .Case<MuxPrimOp>([&](
auto muxOp) {
287 if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
288 muxOp.getLow().getDefiningOp()))
289 if (cLow.getValue().isZero()) {
290 worklist.push_back(muxOp.getSel());
291 worklist.push_back(muxOp.getHigh());
295 if (
auto src = getConnectSrc(term))
296 worklist.push_back(src);
// Searches every (read term, write term) pair for two values where one is
// defined by a NotPrimOp whose input is the other.
//   readTerms  - product terms collected for the read port.
//   writeTerms - product terms collected for the write port.
// Block arguments are excluded before the defining op is inspected, so
// getDefiningOp() is only dereferenced for op results.
// NOTE(review): the return statements are not visible in this listing, so
// which of the two terms is returned on a match (and the value returned when
// nothing matches) must be confirmed against the full source.
306 Value checkComplement(
const SmallVector<Value> &readTerms,
307 const SmallVector<Value> &writeTerms) {
310 for (
auto t1 : readTerms)
311 for (
auto t2 : writeTerms) {
313 if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
314 if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
317 if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
318 if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
// Fills `bits` with the per-bit sources of a CatPrimOp result.
//   defOp - the concatenation being analysed.
//   bits  - destination vector, indexed by bit position of the cat result.
// For each operand, its entry in valueBitsSrc is looked up, asserted to hold
// exactly as many elements as the operand's bit width `s`, and copied into
// bits[lastSize .. lastSize + s).
// Uses operator[] on valueBitsSrc, so the operands are expected to have been
// analysed already; the caller only invokes this once every operand is
// present in the map.
// NOTE(review): the declaration and update of `lastSize` are not visible in
// this listing.
325 void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {
329 for (
auto operand : defOp->getOperands()) {
330 SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
332 getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
333 assert(opBits.size() == s);
334 for (
long i = lastSize, e = lastSize + s; i != e; ++i)
335 bits[i] = opBits[i - lastSize];
// Fills `bits` with the per-bit sources of a BitsPrimOp result.
//   bitsPrim - the bit extraction being analysed.
//   bits     - destination vector; element i receives the source of input bit
//              (lo + i).
// The inclusive range [getLo(), getHi()] of the input's entry in valueBitsSrc
// is copied into bits starting at index 0. The input is expected to be
// present in valueBitsSrc already (operator[] is used for the lookup).
340 void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {
342 SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
343 for (
size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
344 srcIndex <= e; ++srcIndex, ++i)
345 bits[i] = opBits[srcIndex];
// Returns whether every bit of `val` is driven by the same source value.
// The analysis is iterative: values are kept on an explicit stack and their
// per-bit sources are memoised in valueBitsSrc. A value whose operands are
// not yet in the map pushes those operands and is revisited later.
// Handled defining ops in the visible lines: CatPrimOp, BitsPrimOp,
// ConstantOp and WireOp. The final answer is llvm::all_equal over the entry
// computed for `val`.
// NOTE(review): several statements (pops, early returns, closing braces) are
// not visible in this listing; in particular the outcome when the bit width
// is unknown or when `val` never gets an entry is presumably `false` --
// TODO confirm.
354 bool areBitsDrivenBySameSource(Value val) {
355 SmallVector<Value> stack;
356 stack.push_back(val);
358 while (!stack.empty()) {
// Note: this `val` shadows the function parameter inside the loop.
359 auto val = stack.back();
360 if (valueBitsSrc.contains(val)) {
365 auto size =
getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
367 if (!size.has_value())
370 auto bitsSize = size.value();
371 if (
auto *defOp = val.getDefiningOp()) {
372 if (isa<CatPrimOp>(defOp)) {
// All operands must be analysed before the cat itself; missing ones are
// pushed and operandsDone is cleared.
373 bool operandsDone =
true;
376 for (
auto operand : defOp->getOperands()) {
377 if (valueBitsSrc.contains(operand))
379 stack.push_back(operand);
380 operandsDone =
false;
385 valueBitsSrc[val].resize_for_overwrite(bitsSize);
386 handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
387 }
else if (
auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
// The input of a bits extraction must be analysed first.
388 auto input = bitsPrim.getInput();
389 if (!valueBitsSrc.contains(input)) {
390 stack.push_back(input);
393 valueBitsSrc[val].resize_for_overwrite(bitsSize);
394 handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
395 }
else if (
auto constOp = dyn_cast<ConstantOp>(defOp)) {
// Only all-ones / all-zero constants get their bits filled in the visible
// lines (the loop body that assigns each bit is not visible here).
396 auto constVal = constOp.getValue();
397 valueBitsSrc[val].resize_for_overwrite(bitsSize);
398 if (constVal.isAllOnes() || constVal.isZero()) {
399 for (
auto &b : valueBitsSrc[val])
403 }
else if (
auto wireOp = dyn_cast<WireOp>(defOp)) {
// For a wire, element 0 is set to the wire's connect source when one exists,
// otherwise to the wire itself; only index 0 is written in the visible lines.
406 valueBitsSrc[val].resize_for_overwrite(bitsSize);
407 if (
auto src = getConnectSrc(wireOp.getResult())) {
408 valueBitsSrc[val][0] = src;
410 valueBitsSrc[val][0] = wireOp.getResult();
417 if (!valueBitsSrc.contains(val))
419 return llvm::all_equal(valueBitsSrc[val]);
// For every ReadWrite port of `memOp`, finds the values connected to the
// port's enable and wmode subfields and, when both exist, hands them to
// setEnable() together with a freshly created constant.
//   memOp - memory whose read-write ports are inspected.
// The field is identified with StringRef::contains, i.e. a substring match on
// the field name ("en" / "wmode").
// The constant is created at the start of the enclosing FModuleOp's body
// block; its type and value arguments are on a line not visible in this
// listing (the name `constOne` suggests a 1-valued constant -- TODO confirm).
424 void simplifyWmode(MemOp &memOp) {
428 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
429 auto portKind = memOp.getPortKind(portIt.index());
430 if (portKind != MemOp::PortKind::ReadWrite)
432 Value enableDriver, wmodeDriver;
433 Value portVal = portIt.value();
435 for (Operation *u : portVal.getUsers())
436 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
439 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
441 if (fName.contains(
"en"))
442 enableDriver = getConnectSrc(sf.getResult());
443 if (fName.contains(
"wmode"))
444 wmodeDriver = getConnectSrc(sf.getResult());
447 if (enableDriver && wmodeDriver) {
448 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
449 builder.setInsertionPointToStart(
450 memOp->getParentOfType<FModuleOp>().getBodyBlock());
451 auto constOne = builder.create<ConstantOp>(
453 setEnable(enableDriver, wmodeDriver, constOne);
// Walks backwards from `wmodeDriver` through the operands of its defining
// ops and replaces every operand equal to `enableDriver` with `constOne`.
//   enableDriver - value driving the port's enable field.
//   wmodeDriver  - value driving the port's wmode field; traversal root.
//   constOne     - replacement value for occurrences of enableDriver.
// The traversal is a depth-first walk with an explicit stack; `visited`
// ensures each value is expanded at most once. The getDriverOp helper looks
// through a WireOp by following its connect source before taking the
// defining op.
// NOTE(review): the null checks, the `else` that presumably guards the
// push of non-matching operands, and the lambda's return are not visible in
// this listing.
460 void setEnable(Value enableDriver, Value wmodeDriver, Value constOne) {
461 auto getDriverOp = [&](Value dst) -> Operation * {
463 auto *defOp = dst.getDefiningOp();
465 if (isa<WireOp>(defOp))
466 dst = getConnectSrc(dst);
468 defOp = dst.getDefiningOp();
472 SmallVector<Value> stack;
473 llvm::SmallDenseSet<Value> visited;
474 stack.push_back(wmodeDriver);
475 while (!stack.empty()) {
476 auto driver = stack.pop_back_val();
477 if (!visited.insert(driver).second)
479 auto *defOp = getDriverOp(driver);
482 for (
auto operand : llvm::enumerate(defOp->getOperands())) {
483 if (operand.value() == enableDriver)
484 defOp->setOperand(operand.index(), constOne);
486 stack.push_back(operand.value());
// Rebuilds `memOp` with a 1-bit mask on every port.
//   memOp      - memory to analyse; queued for erasure once replaced.
//   opsToErase - receives the old subfield ops and the old memory.
// Analysis phase: Read and Debug ports are ignored. For the remaining ports
// the subfield whose name contains "mask" is examined: either it is already
// 1 bit wide, or its connect source is checked with
// areBitsDrivenBySameSource().
// Rewrite phase: a new MemOp is created whose port types are computed with a
// mask width of 1. Read/debug ports are redirected wholesale; for the other
// ports each old subfield gets a counterpart on the new port, and the old
// mask subfield's users are moved to a dummy wire while the new mask is
// connected to a constant.
// NOTE(review): the statements that update `isMasked` and the early exit
// between the two phases are not visible in this listing.
// NOTE(review): unlike the MemOp built in runOnOperation, the create call
// below passes no init/prefix attributes -- confirm that dropping them here
// is intended.
491 void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
492 bool isMasked =
true;
496 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
498 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
499 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
501 Value portVal = portIt.value();
503 for (Operation *u : portVal.getUsers())
504 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
507 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
509 if (fName.contains(
"mask")) {
511 if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
516 if (
auto maskVal = getConnectSrc(sf))
517 if (areBitsDrivenBySameSource(maskVal))
525 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
// Recompute every port type with a mask width of 1.
527 SmallVector<Type, 4> resultTypes;
528 for (
size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
529 resultTypes.push_back(
530 MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
531 memOp.getPortKind(i), 1));
534 auto newMem = builder.create<MemOp>(
535 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
536 memOp.getDepth(), memOp.getRuw(), memOp.getPortNames().getValue(),
537 memOp.getNameAttr(), memOp.getNameKind(),
538 memOp.getAnnotations().getValue(),
539 memOp.getPortAnnotations().getValue(), memOp.getInnerSymAttr());
541 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
543 Value oldPort = portIt.value();
545 auto newPortVal = newMem->getResult(portIt.index());
// Read and debug port types carry no mask, so their uses move over directly.
547 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
548 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
549 oldPort.replaceAllUsesWith(newPortVal);
553 for (Operation *u : oldPort.getUsers()) {
// dyn_cast may yield null for a non-subfield user; any null check sits on a
// line not visible here.
554 auto oldRes = dyn_cast<SubfieldOp>(u);
556 builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
558 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Old mask users are parked on a dummy wire of the old mask type; the new
// mask field is driven by a constant (its arguments are not visible here).
561 if (fName.contains(
"mask")) {
562 WireOp dummy = builder.create<WireOp>(oldRes.getType());
563 oldRes->replaceAllUsesWith(dummy);
564 builder.create<MatchingConnectOp>(
565 sf, builder.create<ConstantOp>(
568 oldRes->replaceAllUsesWith(sf);
570 opsToErase.push_back(oldRes);
573 opsToErase.push_back(memOp);
// Memoisation table for the mask analysis: maps a value to a vector holding,
// per bit, the value that drives that bit. Filled by
// areBitsDrivenBySameSource() and read by handleCatPrimOp() and
// handleBitsPrimOp().
580 llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
585 return std::make_unique<InferReadWritePass>();
assert(baseType &&"element must be base type")
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createInferReadWritePass()
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.