17 #include "mlir/Pass/Pass.h"
18 #include "llvm/ADT/APSInt.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "firrtl-infer-read-write"
27 #define GEN_PASS_DEF_INFERREADWRITE
28 #include "circt/Dialect/FIRRTL/Passes.h.inc"
32 using namespace circt;
33 using namespace firrtl;
36 struct InferReadWritePass
37 :
public circt::firrtl::impl::InferReadWriteBase<InferReadWritePass> {
47 void runOnOperation()
override {
48 LLVM_DEBUG(llvm::dbgs() <<
"\n Running Infer Read Write on module:"
50 SmallVector<Operation *> opsToErase;
52 auto result = getOperation().walk([&](Operation *op) {
60 if (isa<WhenOp>(op)) {
62 <<
"is unsupported by InferReadWrite as this pass cannot trace "
63 "signal drivers in their presence. Please run `ExpandWhens` to "
64 "remove these operations before running this pass.";
65 return WalkResult::interrupt();
68 MemOp memOp = dyn_cast<MemOp>(op);
70 return WalkResult::advance();
72 inferUnmasked(memOp, opsToErase);
74 size_t nReads, nWrites, nRWs, nDbgs;
75 memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
78 if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
79 !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
80 return WalkResult::skip();
81 SmallVector<Attribute, 4> resultNames;
82 SmallVector<Type, 4> resultTypes;
83 SmallVector<Attribute> portAtts;
84 SmallVector<Attribute, 4> portAnnotations;
87 SmallVector<Value> readTerms, writeTerms;
88 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
90 portAnno = memOp.getPortAnnotation(portIt.index());
91 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
92 resultNames.push_back(memOp.getPortName(portIt.index()));
93 resultTypes.push_back(memOp.getResult(portIt.index()).getType());
94 portAnnotations.push_back(portAnno);
98 if (!cast<ArrayAttr>(portAnno).
empty())
99 portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
101 Value portVal = portIt.value();
104 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
106 for (Operation *u : portVal.getUsers())
107 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
109 auto fName = sf.getInput().getType().base().getElementName(
114 getProductTerms(sf, isReadPort ? readTerms : writeTerms);
116 else if (fName ==
"clk") {
118 rClock = getConnectSrc(sf);
120 wClock = getConnectSrc(sf);
125 if (!sameDriver(rClock, wClock))
126 return WalkResult::skip();
130 llvm::dbgs() <<
"\n read clock:" << rClock
131 <<
" --- write clock:" << wClock;
132 llvm::dbgs() <<
"\n Read terms==>";
for (
auto t
133 : readTerms) llvm::dbgs()
136 llvm::dbgs() <<
"\n Write terms==>";
for (
auto t
137 : writeTerms) llvm::dbgs()
144 auto complementTerm = checkComplement(readTerms, writeTerms);
146 return WalkResult::skip();
151 resultTypes.push_back(MemOp::getTypeForPort(
152 memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
153 memOp.getMaskBits()));
154 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
155 portAnnotations.push_back(builder.getArrayAttr(portAtts));
157 auto rwMem = builder.create<MemOp>(
158 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
159 memOp.getDepth(), RUWAttr::Undefined,
160 builder.getArrayAttr(resultNames), memOp.getNameAttr(),
161 memOp.getNameKind(), memOp.getAnnotations(),
162 builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
163 memOp.getInitAttr(), memOp.getPrefixAttr());
164 ++numRWPortMemoriesInferred;
165 auto rwPort = rwMem->getResult(nDbgs);
169 auto addr = builder.create<SubfieldOp>(rwPort,
"addr");
171 auto enb = builder.create<SubfieldOp>(rwPort,
"en");
173 auto clk = builder.create<SubfieldOp>(rwPort,
"clk");
174 auto readData = builder.create<SubfieldOp>(rwPort,
"rdata");
177 auto wmode = builder.create<SubfieldOp>(rwPort,
"wmode");
178 auto writeData = builder.create<SubfieldOp>(rwPort,
"wdata");
179 auto mask = builder.create<SubfieldOp>(rwPort,
"wmask");
182 builder.create<WireOp>(
addr.getType(),
"readAddr").getResult();
184 builder.create<WireOp>(
addr.getType(),
"writeAddr").getResult();
186 builder.create<WireOp>(enb.getType(),
"writeEnable").getResult();
188 builder.create<WireOp>(enb.getType(),
"readEnable").getResult();
190 builder.create<WireOp>(
ClockType::get(enb.getContext())).getResult();
192 builder.create<MatchingConnectOp>(
193 addr, builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
195 builder.create<MatchingConnectOp>(
196 enb, builder.create<OrPrimOp>(rEnWire, wEnWire));
197 builder.setInsertionPointToEnd(
wmode->getBlock());
198 builder.create<MatchingConnectOp>(
wmode, complementTerm);
200 size_t dbgsIndex = 0;
201 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
203 Value portVal = portIt.value();
204 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
205 memOp.getResult(portIt.index())
206 .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
212 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
215 for (Operation *u : portVal.getUsers())
216 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
217 StringRef fName = sf.getInput().getType().base().getElementName(
221 repl = llvm::StringSwitch<Value>(fName)
225 .Case(
"data", readData);
227 repl = llvm::StringSwitch<Value>(fName)
229 .Case(
"clk", writeClock)
231 .Case(
"data", writeData)
233 sf.replaceAllUsesWith(repl);
235 opsToErase.push_back(sf);
238 simplifyWmode(rwMem);
240 opsToErase.push_back(memOp);
241 return WalkResult::advance();
244 if (result.wasInterrupted())
245 return signalPassFailure();
247 for (
auto *o : opsToErase)
253 Value getConnectSrc(Value dst) {
254 for (
auto *c : dst.getUsers())
255 if (
auto connect = dyn_cast<FConnectLike>(c))
264 bool sameDriver(Value rClock, Value wClock) {
265 if (rClock == wClock)
267 DenseSet<Value> rClocks, wClocks;
271 rClocks.insert(rClock);
272 rClock = getConnectSrc(rClock);
275 bool sameClock =
false;
279 if (rClocks.find(wClock) != rClocks.end()) {
283 wClock = getConnectSrc(wClock);
288 void getProductTerms(Value enValue, SmallVector<Value> &terms) {
291 SmallVector<Value> worklist;
292 worklist.push_back(enValue);
293 while (!worklist.empty()) {
294 auto term = worklist.back();
296 terms.push_back(term);
297 if (isa<BlockArgument>(term))
299 TypeSwitch<Operation *>(term.getDefiningOp())
300 .Case<NodeOp>([&](
auto n) { worklist.push_back(n.getInput()); })
301 .Case<AndPrimOp>([&](AndPrimOp andOp) {
302 worklist.push_back(andOp.getOperand(0));
303 worklist.push_back(andOp.getOperand(1));
305 .Case<MuxPrimOp>([&](
auto muxOp) {
309 if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
310 muxOp.getLow().getDefiningOp()))
311 if (cLow.getValue().isZero()) {
312 worklist.push_back(muxOp.getSel());
313 worklist.push_back(muxOp.getHigh());
317 if (
auto src = getConnectSrc(term))
318 worklist.push_back(src);
328 Value checkComplement(
const SmallVector<Value> &readTerms,
329 const SmallVector<Value> &writeTerms) {
332 for (
auto t1 : readTerms)
333 for (
auto t2 : writeTerms) {
335 if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
336 if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
339 if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
340 if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
347 void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {
351 for (
auto operand : defOp->getOperands()) {
352 SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
354 getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
355 assert(opBits.size() == s);
356 for (
long i = lastSize, e = lastSize + s; i != e; ++i)
357 bits[i] = opBits[i - lastSize];
362 void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {
364 SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
365 for (
size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
366 srcIndex <= e; ++srcIndex, ++i)
367 bits[i] = opBits[srcIndex];
376 bool areBitsDrivenBySameSource(Value val) {
377 SmallVector<Value> stack;
378 stack.push_back(val);
380 while (!stack.empty()) {
381 auto val = stack.back();
382 if (valueBitsSrc.contains(val)) {
387 auto size =
getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
389 if (!size.has_value())
392 auto bitsSize = size.value();
393 if (
auto *defOp = val.getDefiningOp()) {
394 if (isa<CatPrimOp>(defOp)) {
395 bool operandsDone =
true;
398 for (
auto operand : defOp->getOperands()) {
399 if (valueBitsSrc.contains(operand))
401 stack.push_back(operand);
402 operandsDone =
false;
407 valueBitsSrc[val].resize_for_overwrite(bitsSize);
408 handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
409 }
else if (
auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
410 auto input = bitsPrim.getInput();
411 if (!valueBitsSrc.contains(input)) {
412 stack.push_back(input);
415 valueBitsSrc[val].resize_for_overwrite(bitsSize);
416 handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
417 }
else if (
auto constOp = dyn_cast<ConstantOp>(defOp)) {
418 auto constVal = constOp.getValue();
419 valueBitsSrc[val].resize_for_overwrite(bitsSize);
420 if (constVal.isAllOnes() || constVal.isZero()) {
421 for (
auto &b : valueBitsSrc[val])
425 }
else if (
auto wireOp = dyn_cast<WireOp>(defOp)) {
428 valueBitsSrc[val].resize_for_overwrite(bitsSize);
429 if (
auto src = getConnectSrc(wireOp.getResult())) {
430 valueBitsSrc[val][0] = src;
432 valueBitsSrc[val][0] = wireOp.getResult();
439 if (!valueBitsSrc.contains(val))
441 return llvm::all_equal(valueBitsSrc[val]);
446 void simplifyWmode(MemOp &memOp) {
450 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
451 auto portKind = memOp.getPortKind(portIt.index());
452 if (portKind != MemOp::PortKind::ReadWrite)
454 Value enableDriver, wmodeDriver;
455 Value portVal = portIt.value();
457 for (Operation *u : portVal.getUsers())
458 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
461 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
463 if (fName.contains(
"en"))
464 enableDriver = getConnectSrc(sf.getResult());
465 if (fName.contains(
"wmode"))
466 wmodeDriver = getConnectSrc(sf.getResult());
469 if (enableDriver && wmodeDriver) {
470 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
471 builder.setInsertionPointToStart(
472 memOp->getParentOfType<FModuleOp>().getBodyBlock());
473 auto constOne = builder.create<ConstantOp>(
475 setEnable(enableDriver, wmodeDriver, constOne);
482 void setEnable(Value enableDriver, Value wmodeDriver, Value constOne) {
483 auto getDriverOp = [&](Value dst) -> Operation * {
485 auto *defOp = dst.getDefiningOp();
487 if (isa<WireOp>(defOp))
488 dst = getConnectSrc(dst);
490 defOp = dst.getDefiningOp();
494 SmallVector<Value> stack;
495 llvm::SmallDenseSet<Value> visited;
496 stack.push_back(wmodeDriver);
497 while (!stack.empty()) {
498 auto driver = stack.pop_back_val();
499 if (!visited.insert(driver).second)
501 auto *defOp = getDriverOp(driver);
504 for (
auto operand : llvm::enumerate(defOp->getOperands())) {
505 if (operand.value() == enableDriver)
506 defOp->setOperand(operand.index(), constOne);
508 stack.push_back(operand.value());
513 void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
514 bool isMasked =
true;
518 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
520 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
521 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
523 Value portVal = portIt.value();
525 for (Operation *u : portVal.getUsers())
526 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
529 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
531 if (fName.contains(
"mask")) {
533 if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
538 if (
auto maskVal = getConnectSrc(sf))
539 if (areBitsDrivenBySameSource(maskVal))
547 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
549 SmallVector<Type, 4> resultTypes;
550 for (
size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
551 resultTypes.push_back(
552 MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
553 memOp.getPortKind(i), 1));
556 auto newMem = builder.create<MemOp>(
557 resultTypes, memOp.getReadLatencyAttr(), memOp.getWriteLatencyAttr(),
558 memOp.getDepthAttr(), memOp.getRuwAttr(), memOp.getPortNamesAttr(),
559 memOp.getNameAttr(), memOp.getNameKindAttr(),
560 memOp.getAnnotationsAttr(), memOp.getPortAnnotationsAttr(),
561 memOp.getInnerSymAttr(), memOp.getInitAttr(), memOp.getPrefixAttr());
563 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
565 Value oldPort = portIt.value();
567 auto newPortVal = newMem->getResult(portIt.index());
569 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
570 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
571 oldPort.replaceAllUsesWith(newPortVal);
575 for (Operation *u : oldPort.getUsers()) {
576 auto oldRes = dyn_cast<SubfieldOp>(u);
578 builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
580 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
583 if (fName.contains(
"mask")) {
584 WireOp dummy = builder.create<WireOp>(oldRes.getType());
585 oldRes->replaceAllUsesWith(dummy);
586 builder.create<MatchingConnectOp>(
587 sf, builder.create<ConstantOp>(
590 oldRes->replaceAllUsesWith(sf);
592 opsToErase.push_back(oldRes);
595 opsToErase.push_back(memOp);
602 llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
607 return std::make_unique<InferReadWritePass>();
assert(baseType &&"element must be base type")
static InstancePath empty
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createInferReadWritePass()
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.