18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "firrtl-infer-read-write"
29 #define GEN_PASS_DEF_INFERREADWRITE
30 #include "circt/Dialect/FIRRTL/Passes.h.inc"
34 using namespace circt;
35 using namespace firrtl;
38 struct InferReadWritePass
39 :
public circt::firrtl::impl::InferReadWriteBase<InferReadWritePass> {
// Pass entry point. Walks every op nested under getOperation():
//  - a WhenOp triggers the diagnostic below and interrupts the walk, after
//    which the pass returns signalPassFailure();
//  - each MemOp is first handed to inferUnmasked(); then, if it has exactly
//    one read port, one write port, no read-write port, and read and write
//    latency 1, it is considered for merging into a memory with a single
//    ReadWrite port (debug ports are carried over).
// Replaced ops are queued in opsToErase and visited once the walk finishes.
// NOTE(review): this excerpt elides many lines (closing braces, some guards
// and declarations); the comments below describe only visible statements.
49 void runOnOperation()
override {
50 LLVM_DEBUG(llvm::dbgs() <<
"\n Running Infer Read Write on module:"
// Ops scheduled for removal; processed by the loop after the walk.
52 SmallVector<Operation *> opsToErase;
54 auto result = getOperation().walk([&](Operation *op) {
// `when` blocks prevent the driver tracing done below, so they are
// rejected with a diagnostic and the walk is interrupted.
62 if (isa<WhenOp>(op)) {
64 <<
"is unsupported by InferReadWrite as this pass cannot trace "
65 "signal drivers in their presence. Please run `ExpandWhens` to "
66 "remove these operations before running this pass.";
67 return WalkResult::interrupt();
// Only MemOps are transformed; other ops continue the walk (guard elided).
70 MemOp memOp = dyn_cast<MemOp>(op);
72 return WalkResult::advance();
// First try to narrow this memory's write masks to 1 bit. memOp is passed
// by reference. NOTE(review): presumably so inferUnmasked can retarget it
// to the rebuilt memory; that assignment is not visible here.
74 inferUnmasked(memOp, opsToErase);
// Count the ports by kind.
76 size_t nReads, nWrites, nRWs, nDbgs;
77 memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
// Candidates need exactly one read port, one write port, no read-write
// port, and read/write latency 1. Anything else is skipped.
80 if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
81 !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
82 return WalkResult::skip();
// Ports of the merged memory. Debug ports are copied as-is, and one
// ReadWrite port is appended further down.
83 SmallVector<Attribute, 4> resultNames;
84 SmallVector<Type, 4> resultTypes;
85 SmallVector<Attribute> portAtts;
86 SmallVector<Attribute, 4> portAnnotations;
// Product terms of the read and write enables (filled by getProductTerms).
89 SmallVector<Value> readTerms, writeTerms;
// Scan the old ports:
//  - debug ports contribute their name, type and annotations unchanged;
//  - non-empty annotations of the other ports are gathered in portAtts for
//    the merged port;
//  - for the read/write ports, the enable product terms are recorded, and
//    the "clk" drivers are recorded in rClock/wClock.
90 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
92 portAnno = memOp.getPortAnnotation(portIt.index());
93 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
94 resultNames.push_back(memOp.getPortName(portIt.index()));
95 resultTypes.push_back(memOp.getResult(portIt.index()).getType());
96 portAnnotations.push_back(portAnno);
100 if (!cast<ArrayAttr>(portAnno).
empty())
101 portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
103 Value portVal = portIt.value();
106 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
108 for (Operation *u : portVal.getUsers())
109 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
111 auto fName = sf.getInput().getType().base().getElementName(
116 getProductTerms(sf, isReadPort ? readTerms : writeTerms);
118 else if (fName ==
"clk") {
120 rClock = getConnectSrc(sf);
122 wClock = getConnectSrc(sf);
// Merging requires both ports to be clocked by the same driver.
127 if (!sameDriver(rClock, wClock))
128 return WalkResult::skip();
132 llvm::dbgs() <<
"\n read clock:" << rClock
133 <<
" --- write clock:" << wClock;
134 llvm::dbgs() <<
"\n Read terms==>";
for (
auto t
135 : readTerms) llvm::dbgs()
138 llvm::dbgs() <<
"\n Write terms==>";
for (
auto t
139 : writeTerms) llvm::dbgs()
// A read-enable term that is the complement (`not`) of a write-enable term
// lets one port serve both accesses; the result drives wmode below.
// NOTE(review): the guard before the skip is elided; presumably the memory
// is skipped when no such term exists.
146 auto complementTerm = checkComplement(readTerms, writeTerms);
148 return WalkResult::skip();
// The merged port is a ReadWrite port with the old memory's depth, data
// type and mask width. It is annotated with the gathered port annotations.
153 resultTypes.push_back(MemOp::getTypeForPort(
154 memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
155 memOp.getMaskBits()));
156 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
157 portAnnotations.push_back(builder.getArrayAttr(portAtts));
// Build the replacement memory. It reuses the old memory's latencies, depth,
// name, name kind, annotations, inner symbol, init and prefix, and sets
// read-under-write to Undefined.
159 auto rwMem = builder.create<MemOp>(
160 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
161 memOp.getDepth(), RUWAttr::Undefined,
162 builder.getArrayAttr(resultNames), memOp.getNameAttr(),
163 memOp.getNameKind(), memOp.getAnnotations(),
164 builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
165 memOp.getInitAttr(), memOp.getPrefixAttr());
166 ++numRWPortMemoriesInferred;
// Debug ports were pushed first, so the ReadWrite port is result nDbgs.
167 auto rwPort = rwMem->getResult(nDbgs);
// Fields of the new ReadWrite port.
171 auto addr = builder.create<SubfieldOp>(rwPort,
"addr");
173 auto enb = builder.create<SubfieldOp>(rwPort,
"en");
175 auto clk = builder.create<SubfieldOp>(rwPort,
"clk");
176 auto readData = builder.create<SubfieldOp>(rwPort,
"rdata");
179 auto wmode = builder.create<SubfieldOp>(rwPort,
"wmode");
180 auto writeData = builder.create<SubfieldOp>(rwPort,
"wdata");
181 auto mask = builder.create<SubfieldOp>(rwPort,
"wmask");
// Wires that take over the old ports' address, enable and clock fields.
// The old port fields are redirected to these wires, or to the new port's
// fields, further down.
184 builder.create<WireOp>(
addr.getType(),
"readAddr").getResult();
186 builder.create<WireOp>(
addr.getType(),
"writeAddr").getResult();
188 builder.create<WireOp>(enb.getType(),
"writeEnable").getResult();
190 builder.create<WireOp>(enb.getType(),
"readEnable").getResult();
192 builder.create<WireOp>(
ClockType::get(enb.getContext())).getResult();
// addr selects the write address while the write enable is set and the read
// address otherwise. en is the OR of both enables.
194 builder.create<MatchingConnectOp>(
195 addr, builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
197 builder.create<MatchingConnectOp>(
198 enb, builder.create<OrPrimOp>(rEnWire, wEnWire));
// wmode is driven by the complementary term, connected at the end of
// wmode's block.
199 builder.setInsertionPointToEnd(
wmode->getBlock());
200 builder.create<MatchingConnectOp>(
wmode, complementTerm);
// Redirect users of the old ports:
//  - debug-port results map to rwMem's debug results in order;
//  - subfields of the read/write ports are replaced by the matching new
//    field or wire (chosen by field name) and queued for erasure.
202 size_t dbgsIndex = 0;
203 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
205 Value portVal = portIt.value();
206 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
207 memOp.getResult(portIt.index())
208 .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
214 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
217 for (Operation *u : portVal.getUsers())
218 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
219 StringRef fName = sf.getInput().getType().base().getElementName(
223 repl = llvm::StringSwitch<Value>(fName)
227 .Case(
"data", readData);
229 repl = llvm::StringSwitch<Value>(fName)
231 .Case(
"clk", writeClock)
233 .Case(
"data", writeData)
235 sf.replaceAllUsesWith(repl);
237 opsToErase.push_back(sf);
// Try to simplify the logic that drives wmode (see simplifyWmode).
240 simplifyWmode(rwMem);
// The old memory has been fully replaced.
242 opsToErase.push_back(memOp);
243 return WalkResult::advance();
// A WhenOp was found: fail the pass.
246 if (result.wasInterrupted())
247 return signalPassFailure();
// Handle the queued ops now that the walk is over.
// NOTE(review): the loop body is elided; presumably it erases each op.
249 for (
auto *o : opsToErase)
// Looks for a connect-like op (FConnectLike) among the users of `dst`.
// NOTE(review): the return statements are elided in this excerpt. Callers
// test the result as a possibly-null Value (`if (auto src = getConnectSrc(..))`).
// Presumably it returns the connect's source when `dst` is the destination,
// and a null Value otherwise; confirm against the full file.
255 Value getConnectSrc(Value dst) {
256 for (
auto *c : dst.getUsers())
257 if (
auto connect = dyn_cast<FConnectLike>(c))
// Decides whether the read and write clocks share a driver.
// Identical values match immediately. Otherwise, the read clock and its chain
// of connect sources (followed with getConnectSrc) are collected in rClocks.
// The write clock's chain is then walked, looking for a member of that set.
// NOTE(review): the loop headers and return statements are elided here.
// wClocks is declared but is not used in any visible line.
266 bool sameDriver(Value rClock, Value wClock) {
267 if (rClock == wClock)
269 DenseSet<Value> rClocks, wClocks;
// Record every value on the read clock's driver chain.
273 rClocks.insert(rClock);
274 rClock = getConnectSrc(rClock);
277 bool sameClock =
false;
// Walk the write clock's driver chain until it meets the read chain.
281 if (rClocks.find(wClock) != rClocks.end()) {
285 wClock = getConnectSrc(wClock);
// Collects into `terms` the values that feed the enable `enValue` as
// conjunction factors. Starting from enValue, it looks through:
//  - node ops;
//  - both operands of an `and`;
//  - mux(sel, high, low) when low is a zero constant (equivalent to sel & high);
//  - connect sources of the current value.
// Every value taken from the worklist is appended to `terms`.
// NOTE(review): the worklist pop, the action taken for block arguments
// (presumably they end the trace) and the TypeSwitch default are elided here.
290 void getProductTerms(Value enValue, SmallVector<Value> &terms) {
293 SmallVector<Value> worklist;
294 worklist.push_back(enValue);
295 while (!worklist.empty()) {
296 auto term = worklist.back();
298 terms.push_back(term);
299 if (isa<BlockArgument>(term))
301 TypeSwitch<Operation *>(term.getDefiningOp())
302 .Case<NodeOp>([&](
auto n) { worklist.push_back(n.getInput()); })
303 .Case<AndPrimOp>([&](AndPrimOp andOp) {
304 worklist.push_back(andOp.getOperand(0));
305 worklist.push_back(andOp.getOperand(1));
307 .Case<MuxPrimOp>([&](
auto muxOp) {
// mux(sel, high, 0) == sel & high, so both are product terms.
311 if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
312 muxOp.getLow().getDefiningOp()))
313 if (cLow.getValue().isZero()) {
314 worklist.push_back(muxOp.getSel());
315 worklist.push_back(muxOp.getHigh());
// Follow the value through the connect that drives it.
319 if (
auto src = getConnectSrc(term))
320 worklist.push_back(src);
// Searches every (read term, write term) pair for two values where one is the
// input of a `not` that produces the other.
// NOTE(review): the return statements are elided here. runOnOperation drives
// the merged port's wmode with the result and skips the memory otherwise.
// Presumably a term is returned on a match (confirm which of t1/t2) and a
// null Value when no pair exists.
330 Value checkComplement(
const SmallVector<Value> &readTerms,
331 const SmallVector<Value> &writeTerms) {
334 for (
auto t1 : readTerms)
335 for (
auto t2 : writeTerms) {
// The BlockArgument check keeps getDefiningOp() non-null for the isa<>.
337 if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
338 if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
341 if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
342 if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
// Fills `bits` with the per-bit drivers of a `cat`. Each operand's already
// memoized entry in valueBitsSrc is copied into consecutive positions, with
// operands taken in order.
// The assert checks that each operand's entry has one element per bit
// (width `s`).
// NOTE(review): the declaration of `s` and the initialization/update of
// `lastSize` are elided in this excerpt.
349 void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {
353 for (
auto operand : defOp->getOperands()) {
354 SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
356 getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
357 assert(opBits.size() == s);
// Copy this operand's bits to [lastSize, lastSize + s).
358 for (
long i = lastSize, e = lastSize + s; i != e; ++i)
359 bits[i] = opBits[i - lastSize];
// Fills bits[0 .. hi-lo] with the memoized per-bit drivers of the `bits` op's
// input, taken from index lo through hi inclusive.
364 void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {
366 SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
367 for (
size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
368 srcIndex <= e; ++srcIndex, ++i)
369 bits[i] = opBits[srcIndex];
// Returns whether all bits of `val` are driven by the same source value
// (llvm::all_equal over val's entry in valueBitsSrc).
// Per-bit sources are computed iteratively with an explicit stack and
// memoized in valueBitsSrc:
//  - a `cat` waits until every operand is known, then concatenates them
//    (handleCatPrimOp);
//  - a `bits` op waits for its input, then slices it (handleBitsPrimOp);
//  - constants and wires are resolved directly.
// NOTE(review): several lines are elided here: stack pops, early returns,
// and the body of the constant-bit loop.
// NOTE(review): in the visible wire case, only element [0] is assigned after
// resize_for_overwrite(bitsSize). If bitsSize can exceed 1 there, the other
// entries are never written. Confirm an elided width check rules this out.
378 bool areBitsDrivenBySameSource(Value val) {
379 SmallVector<Value> stack;
380 stack.push_back(val);
382 while (!stack.empty()) {
// This loop-local `val` shadows the parameter.
383 auto val = stack.back();
384 if (valueBitsSrc.contains(val)) {
389 auto size =
getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
391 if (!size.has_value())
394 auto bitsSize = size.value();
395 if (
auto *defOp = val.getDefiningOp()) {
396 if (isa<CatPrimOp>(defOp)) {
// Push unresolved operands first; the cat is resolved on a later visit.
397 bool operandsDone =
true;
400 for (
auto operand : defOp->getOperands()) {
401 if (valueBitsSrc.contains(operand))
403 stack.push_back(operand);
404 operandsDone =
false;
409 valueBitsSrc[val].resize_for_overwrite(bitsSize);
410 handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
411 }
else if (
auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
// Resolve the input first, then slice it.
412 auto input = bitsPrim.getInput();
413 if (!valueBitsSrc.contains(input)) {
414 stack.push_back(input);
417 valueBitsSrc[val].resize_for_overwrite(bitsSize);
418 handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
419 }
else if (
auto constOp = dyn_cast<ConstantOp>(defOp)) {
// Only all-ones and all-zero constants have uniform bits.
420 auto constVal = constOp.getValue();
421 valueBitsSrc[val].resize_for_overwrite(bitsSize);
422 if (constVal.isAllOnes() || constVal.isZero()) {
423 for (
auto &b : valueBitsSrc[val])
427 }
else if (
auto wireOp = dyn_cast<WireOp>(defOp)) {
// A wire resolves to its connect source, or to itself when undriven.
430 valueBitsSrc[val].resize_for_overwrite(bitsSize);
431 if (
auto src = getConnectSrc(wireOp.getResult())) {
432 valueBitsSrc[val][0] = src;
434 valueBitsSrc[val][0] = wireOp.getResult();
// Refers to the parameter `val`. NOTE(review): this appears to be after the
// loop; the closing brace is elided.
441 if (!valueBitsSrc.contains(val))
443 return llvm::all_equal(valueBitsSrc[val]);
// For each ReadWrite port of memOp, finds the connect sources of its enable
// and "wmode" fields (matched by substring on the field name). When both
// exist, a ConstantOp is created at the start of the enclosing FModuleOp body.
// setEnable() then replaces uses of the enable driver inside the wmode logic
// with that constant.
// NOTE(review): the constant's type/value arguments are elided; the name
// constOne suggests the value 1. Confirm.
448 void simplifyWmode(MemOp &memOp) {
452 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
453 auto portKind = memOp.getPortKind(portIt.index());
454 if (portKind != MemOp::PortKind::ReadWrite)
456 Value enableDriver, wmodeDriver;
457 Value portVal = portIt.value();
459 for (Operation *u : portVal.getUsers())
460 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
463 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
465 if (fName.contains(
"en"))
466 enableDriver = getConnectSrc(sf.getResult());
467 if (fName.contains(
"wmode"))
468 wmodeDriver = getConnectSrc(sf.getResult());
471 if (enableDriver && wmodeDriver) {
472 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
473 builder.setInsertionPointToStart(
474 memOp->getParentOfType<FModuleOp>().getBodyBlock());
475 auto constOne = builder.create<ConstantOp>(
477 setEnable(enableDriver, wmodeDriver, constOne);
// Walks the logic that drives wmodeDriver (depth-first, using an explicit
// stack and a visited set) and rewrites every operand equal to enableDriver
// to constOne.
// NOTE(review): the apparent rationale is that wmode only matters while the
// port is enabled, so the enable can be treated as 1 inside wmode's logic.
// This is inferred; confirm.
// NOTE(review): the null checks after getDriverOp and the control flow
// between setOperand and push_back are elided in this excerpt.
484 void setEnable(Value enableDriver, Value wmodeDriver, Value constOne) {
// Looks through a wire to its connect source before taking the defining op.
485 auto getDriverOp = [&](Value dst) -> Operation * {
487 auto *defOp = dst.getDefiningOp();
489 if (isa<WireOp>(defOp))
490 dst = getConnectSrc(dst);
492 defOp = dst.getDefiningOp();
496 SmallVector<Value> stack;
497 llvm::SmallDenseSet<Value> visited;
498 stack.push_back(wmodeDriver);
499 while (!stack.empty()) {
500 auto driver = stack.pop_back_val();
// Each driver value is processed at most once.
501 if (!visited.insert(driver).second)
503 auto *defOp = getDriverOp(driver);
506 for (
auto operand : llvm::enumerate(defOp->getOperands())) {
507 if (operand.value() == enableDriver)
508 defOp->setOperand(operand.index(), constOne);
510 stack.push_back(operand.value());
// Tries to replace the memory with an equivalent one whose masks are 1 bit
// wide.
// Read and debug ports are ignored. For the other ports, the "mask" subfield
// is inspected: a 1-bit mask, or a mask whose driver's bits all come from one
// source (areBitsDrivenBySameSource), is acceptable.
// NOTE(review): the statements that update isMasked or bail out are elided,
// so the exact acceptance rule should be confirmed against the full file.
// When the rewrite happens:
//  - a new MemOp is built with mask width 1 for every port;
//  - read/debug port uses are forwarded to the new memory directly;
//  - subfields of the other ports are recreated on the new memory;
//  - users of old mask subfields are moved to a dummy wire, and the new mask
//    is connected to a constant;
//  - the old subfields and the old memory are queued in opsToErase.
// memOp is taken by reference. NOTE(review): presumably so the caller sees
// the new memory afterwards; that assignment is not visible here.
// NOTE(review): unlike the MemOp built in runOnOperation, the visible
// builder call here does not pass memOp's init or prefix attributes.
// Confirm whether dropping them is intended.
515 void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
516 bool isMasked =
true;
// Inspect the mask of every write/read-write port.
520 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
522 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
523 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
525 Value portVal = portIt.value();
527 for (Operation *u : portVal.getUsers())
528 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
531 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
533 if (fName.contains(
"mask")) {
535 if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
540 if (
auto maskVal = getConnectSrc(sf))
541 if (areBitsDrivenBySameSource(maskVal))
// Build the replacement memory with 1-bit masks on every port.
549 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
551 SmallVector<Type, 4> resultTypes;
552 for (
size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
553 resultTypes.push_back(
554 MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
555 memOp.getPortKind(i), 1));
558 auto newMem = builder.create<MemOp>(
559 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
560 memOp.getDepth(), memOp.getRuw(), memOp.getPortNames().getValue(),
561 memOp.getNameAttr(), memOp.getNameKind(),
562 memOp.getAnnotations().getValue(),
563 memOp.getPortAnnotations().getValue(), memOp.getInnerSymAttr());
// Rewire every old port onto the corresponding new port.
565 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
567 Value oldPort = portIt.value();
569 auto newPortVal = newMem->getResult(portIt.index());
571 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
572 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
573 oldPort.replaceAllUsesWith(newPortVal);
577 for (Operation *u : oldPort.getUsers()) {
578 auto oldRes = dyn_cast<SubfieldOp>(u);
580 builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
582 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Old mask users now write a dummy wire; the new 1-bit mask gets a constant.
585 if (fName.contains(
"mask")) {
586 WireOp dummy = builder.create<WireOp>(oldRes.getType());
587 oldRes->replaceAllUsesWith(dummy);
588 builder.create<MatchingConnectOp>(
589 sf, builder.create<ConstantOp>(
592 oldRes->replaceAllUsesWith(sf);
594 opsToErase.push_back(oldRes);
597 opsToErase.push_back(memOp);
// Memoized per-bit drivers: maps a value to one entry per bit. Filled by
// areBitsDrivenBySameSource and read by handleCatPrimOp/handleBitsPrimOp.
604 llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
609 return std::make_unique<InferReadWritePass>();
assert(baseType &&"element must be base type")
static InstancePath empty
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createInferReadWritePass()
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.