19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "firrtl-infer-read-write"
27 using namespace circt;
28 using namespace firrtl;
31 struct InferReadWritePass :
public InferReadWriteBase<InferReadWritePass> {
/// Pass entry point. For every MemOp in the module body this first calls
/// inferUnmasked, then inspects the memory's ports (enable product terms and
/// clock drivers) and builds a replacement MemOp that carries an additional
/// ReadWrite port, redirecting the users of the old ports to it.
/// NOTE(review): several statements of this function are not visible here;
/// confirm the exact bail-out behaviour against the complete source.
41 void runOnOperation()
override {
42 LLVM_DEBUG(llvm::dbgs() <<
"\n Running Infer Read Write on module:"
// Operations queued here are processed by the loop at the end of the function.
44 SmallVector<Operation *> opsToErase;
// Visit each memory declared in the module body.
45 for (MemOp memOp : llvm::make_early_inc_range(
46 getOperation().getBodyBlock()->getOps<MemOp>())) {
47 inferUnmasked(memOp, opsToErase);
// Count the ports of each kind.
49 size_t nReads, nWrites, nRWs, nDbgs;
50 memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
// True when the memory does NOT have exactly one read port, one write port
// and no read-write port, or does NOT have read and write latency 1.
53 if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
54 !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
// Names, types and annotations of the results of the replacement memory.
56 SmallVector<Attribute, 4> resultNames;
57 SmallVector<Type, 4> resultTypes;
58 SmallVector<Attribute> portAtts;
59 SmallVector<Attribute, 4> portAnnotations;
// Product terms of the enable signals, kept separately for read and write.
62 SmallVector<Value> readTerms, writeTerms;
63 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
65 portAnno = memOp.getPortAnnotation(portIt.index());
// Debug ports are carried over: name, type and annotation are recorded.
66 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
67 resultNames.push_back(memOp.getPortName(portIt.index()));
68 resultTypes.push_back(memOp.getResult(portIt.index()).getType());
69 portAnnotations.push_back(portAnno);
// Record the port annotation in portAtts when it is non-empty.
73 if (!cast<ArrayAttr>(portAnno).
empty())
74 portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
76 Value portVal = portIt.value();
79 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
// Inspect the subfield accesses of this port.
81 for (Operation *u : portVal.getUsers())
82 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
84 auto fName = sf.getInput().getType().base().getElementName(
// "en": gather the enable's product terms into the list for this port kind.
88 if (fName.equals(
"en"))
89 getProductTerms(sf, isReadPort ? readTerms : writeTerms);
// "clk": remember the connect source of the clock field as rClock / wClock.
91 else if (fName.equals(
"clk")) {
93 rClock = getConnectSrc(sf);
95 wClock = getConnectSrc(sf);
// Compare the drivers of the read and write clocks.
100 if (!sameDriver(rClock, wClock))
105 llvm::dbgs() <<
"\n read clock:" << rClock
106 <<
" --- write clock:" << wClock;
107 llvm::dbgs() <<
"\n Read terms==>";
for (
auto t
108 : readTerms) llvm::dbgs()
111 llvm::dbgs() <<
"\n Write terms==>";
for (
auto t
112 : writeTerms) llvm::dbgs()
// Look for a read term and a write term where one is the NOT of the other.
119 auto complementTerm = checkComplement(readTerms, writeTerms);
// Append the type of a ReadWrite port to the result types of the new memory.
126 resultTypes.push_back(MemOp::getTypeForPort(
127 memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
128 memOp.getMaskBits()));
129 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
// The annotations collected in portAtts become the new port's annotation.
130 portAnnotations.push_back(
builder.getArrayAttr(portAtts));
// Create the replacement memory; RUWAttr::Undefined is passed where
// inferUnmasked passes memOp.getRuw().
132 auto rwMem =
builder.create<MemOp>(
133 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
134 memOp.getDepth(), RUWAttr::Undefined,
135 builder.getArrayAttr(resultNames), memOp.getNameAttr(),
136 memOp.getNameKind(), memOp.getAnnotations(),
137 builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
138 memOp.getInitAttr(), memOp.getPrefixAttr());
139 ++numRWPortMemoriesInferred;
// The result at index nDbgs is used as the read-write port.
140 auto rwPort = rwMem->getResult(nDbgs);
// Subfield accessors for the fields of the read-write port.
144 auto addr =
builder.create<SubfieldOp>(rwPort,
"addr");
146 auto enb =
builder.create<SubfieldOp>(rwPort,
"en");
148 auto clk =
builder.create<SubfieldOp>(rwPort,
"clk");
149 auto readData =
builder.create<SubfieldOp>(rwPort,
"rdata");
152 auto wmode =
builder.create<SubfieldOp>(rwPort,
"wmode");
153 auto writeData =
builder.create<SubfieldOp>(rwPort,
"wdata");
154 auto mask =
builder.create<SubfieldOp>(rwPort,
"wmask");
// Wires named readAddr, writeAddr, writeEnable and readEnable, typed like
// the address and enable fields of the new port.
157 builder.create<WireOp>(
addr.getType(),
"readAddr").getResult();
159 builder.create<WireOp>(
addr.getType(),
"writeAddr").getResult();
161 builder.create<WireOp>(enb.getType(),
"writeEnable").getResult();
163 builder.create<WireOp>(enb.getType(),
"readEnable").getResult();
// The port address is a mux selected by wEnWire between wAddr and rAddr.
167 builder.create<StrictConnectOp>(
168 addr,
builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
// The port enable is the OR of rEnWire and wEnWire.
170 builder.create<StrictConnectOp>(
171 enb,
builder.create<OrPrimOp>(rEnWire, wEnWire));
// Redirect the users of the old memory's ports to the new memory.
175 size_t dbgsIndex = 0;
176 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
178 Value portVal = portIt.value();
// Debug port: its users are pointed at result dbgsIndex of the new memory.
179 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
180 memOp.getResult(portIt.index())
181 .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
187 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
// Each subfield of the old port is replaced by a value selected by field
// name, and the old subfield op is queued in opsToErase.
190 for (Operation *u : portVal.getUsers())
191 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
192 StringRef fName = sf.getInput().getType().base().getElementName(
196 repl = llvm::StringSwitch<Value>(fName)
200 .Case(
"data", readData);
202 repl = llvm::StringSwitch<Value>(fName)
204 .Case(
"clk", writeClock)
206 .Case(
"data", writeData)
208 sf.replaceAllUsesWith(repl);
210 opsToErase.push_back(sf);
// Try to simplify the wmode logic of the new memory's read-write port.
213 simplifyWmode(rwMem);
// Queue the old memory as well.
215 opsToErase.push_back(memOp);
// Process everything that was queued above.
217 for (
auto *o : opsToErase)
/// Scans the users of `dst` for an operation implementing FConnectLike.
/// Callers in this file treat the returned Value as the value driving `dst`
/// and test it for null before use.
/// NOTE(review): the statements that pick and return the source are not
/// visible here; confirm the exact selection rule against the full source.
223 Value getConnectSrc(Value dst) {
224 for (
auto *c : dst.getUsers())
225 if (
auto connect = dyn_cast<FConnectLike>(c))
/// Decides whether `rClock` and `wClock` have the same driver. Identical
/// values are handled up front; otherwise both values are followed backwards
/// through getConnectSrc and the write-side chain is checked for membership
/// in the set of values seen on the read-side chain.
234 bool sameDriver(Value rClock, Value wClock) {
// Fast path: the two clocks are the very same value.
235 if (rClock == wClock)
237 DenseSet<Value> rClocks, wClocks;
// Record rClock and step to the value connected into it.
241 rClocks.insert(rClock);
242 rClock = getConnectSrc(rClock);
245 bool sameClock =
false;
// Is the current write-side value one of the recorded read-side values?
249 if (rClocks.find(wClock) != rClocks.end()) {
// Step to the value connected into wClock.
253 wClock = getConnectSrc(wClock);
/// Collects into `terms` the values reachable backwards from `enValue`.
/// The traversal uses an explicit worklist and looks through NodeOp inputs,
/// both operands of AndPrimOp, the selector and high input of a MuxPrimOp
/// whose low input is a zero constant, and the connect source of a term.
/// Every visited value is appended to `terms`.
258 void getProductTerms(Value enValue, SmallVector<Value> &terms) {
261 SmallVector<Value> worklist;
262 worklist.push_back(enValue);
263 while (!worklist.empty()) {
264 auto term = worklist.back();
266 terms.push_back(term);
// Block arguments have no defining operation to look through.
267 if (isa<BlockArgument>(term))
269 TypeSwitch<Operation *>(term.getDefiningOp())
// A node is followed to its input.
270 .Case<NodeOp>([&](
auto n) { worklist.push_back(n.getInput()); })
// An AND contributes both of its operands.
271 .Case<AndPrimOp>([&](AndPrimOp andOp) {
272 worklist.push_back(andOp.getOperand(0));
273 worklist.push_back(andOp.getOperand(1));
// A mux whose low input is the constant zero contributes its selector and
// its high input.
275 .Case<MuxPrimOp>([&](
auto muxOp) {
279 if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
280 muxOp.getLow().getDefiningOp()))
281 if (cLow.getValue().isZero()) {
282 worklist.push_back(muxOp.getSel());
283 worklist.push_back(muxOp.getHigh());
// Follow the value connected into this term, when there is one.
287 if (
auto src = getConnectSrc(term))
288 worklist.push_back(src);
/// Searches every (read term, write term) pair for a complementary pair:
/// one term is defined by a NotPrimOp whose input is the other term.
/// NOTE(review): the statements executed on a match and the final return are
/// not visible here; confirm which Value is returned.
298 Value checkComplement(
const SmallVector<Value> &readTerms,
299 const SmallVector<Value> &writeTerms) {
302 for (
auto t1 : readTerms)
303 for (
auto t2 : writeTerms) {
// Block arguments have no defining op, so they are excluded before the
// isa<NotPrimOp> test. Case: t1 == not(t2).
305 if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
306 if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
// Case: t2 == not(t1).
309 if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
310 if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
/// Fills `bits` for a CatPrimOp by copying the per-bit entries recorded in
/// valueBitsSrc for each operand into a consecutive range of `bits` that
/// starts at `lastSize`.
/// NOTE(review): the declaration and update of `lastSize` are not visible
/// here, so the order in which operands are laid out should be confirmed.
317 void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {
321 for (
auto operand : defOp->getOperands()) {
322 SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
// Bit width of the operand; .value() requires the width to be known.
324 getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
// The recorded entry count must match the operand's bit width.
325 assert(opBits.size() == s);
326 for (
long i = lastSize, e = lastSize + s; i != e; ++i)
327 bits[i] = opBits[i - lastSize];
/// Fills `bits` for a BitsPrimOp: entry i of `bits` receives the entry at
/// index lo + i of the input's record in valueBitsSrc, for the inclusive
/// range [getLo(), getHi()].
332 void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {
334 SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
// srcIndex walks the input from lo to hi inclusive; i walks the output from 0.
335 for (
size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
336 srcIndex <= e; ++srcIndex, ++i)
337 bits[i] = opBits[srcIndex];
/// Computes, with an explicit stack, the per-bit record of `val` in
/// valueBitsSrc and reports whether all entries of that record are equal
/// (llvm::all_equal). Handled defining ops: CatPrimOp and BitsPrimOp (whose
/// operands are pushed first when they have no record yet), ConstantOp, and
/// WireOp.
346 bool areBitsDrivenBySameSource(Value val) {
347 SmallVector<Value> stack;
348 stack.push_back(val);
350 while (!stack.empty()) {
// Note: this `val` shadows the function parameter inside the loop.
351 auto val = stack.back();
// A record already exists for this value.
352 if (valueBitsSrc.contains(val)) {
357 auto size =
getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
// The bit width of the value is not known.
359 if (!size.has_value())
362 auto bitsSize = size.value();
363 if (
auto *defOp = val.getDefiningOp()) {
364 if (isa<CatPrimOp>(defOp)) {
365 bool operandsDone =
true;
// Operands that have no record yet are pushed so they get computed first.
368 for (
auto operand : defOp->getOperands()) {
369 if (valueBitsSrc.contains(operand))
371 stack.push_back(operand);
372 operandsDone =
false;
// Size the record for val, then fill it from the operands' records.
377 valueBitsSrc[val].resize_for_overwrite(bitsSize);
378 handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
379 }
else if (
auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
380 auto input = bitsPrim.getInput();
// The input of the bit extract needs a record before val can be filled.
381 if (!valueBitsSrc.contains(input)) {
382 stack.push_back(input);
385 valueBitsSrc[val].resize_for_overwrite(bitsSize);
386 handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
387 }
else if (
auto constOp = dyn_cast<ConstantOp>(defOp)) {
388 auto constVal = constOp.getValue();
389 valueBitsSrc[val].resize_for_overwrite(bitsSize);
// Constants that are all ones or zero get every entry of the record written.
390 if (constVal.isAllOnes() || constVal.isZero()) {
391 for (
auto &b : valueBitsSrc[val])
395 }
else if (
auto wireOp = dyn_cast<WireOp>(defOp)) {
398 valueBitsSrc[val].resize_for_overwrite(bitsSize);
// Entry 0 is the wire's connect source when one exists; the other visible
// assignment stores the wire itself.
399 if (
auto src = getConnectSrc(wireOp.getResult())) {
400 valueBitsSrc[val][0] = src;
402 valueBitsSrc[val][0] = wireOp.getResult();
// No record could be built for val.
409 if (!valueBitsSrc.contains(val))
411 return llvm::all_equal(valueBitsSrc[val]);
/// For each ReadWrite port of `memOp`, finds the values connected into the
/// port's enable and wmode subfields. When both are found, a ConstantOp is
/// created at the start of the enclosing FModuleOp body and setEnable is
/// called with the two drivers and that constant.
416 void simplifyWmode(MemOp &memOp) {
420 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
421 auto portKind = memOp.getPortKind(portIt.index());
// Only ReadWrite ports are of interest.
422 if (portKind != MemOp::PortKind::ReadWrite)
424 Value enableDriver, wmodeDriver;
425 Value portVal = portIt.value();
427 for (Operation *u : portVal.getUsers())
428 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
431 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Field names are matched by substring (contains), not by equality.
433 if (fName.contains(
"en"))
434 enableDriver = getConnectSrc(sf.getResult());
435 if (fName.contains(
"wmode"))
436 wmodeDriver = getConnectSrc(sf.getResult());
// Both drivers were found.
439 if (enableDriver && wmodeDriver) {
440 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
// The constant is created at the very start of the parent module's body.
441 builder.setInsertionPointToStart(
442 memOp->getParentOfType<FModuleOp>().getBodyBlock());
443 auto constOne =
builder.create<ConstantOp>(
445 setEnable(enableDriver, wmodeDriver, constOne);
/// Walks backwards from `wmodeDriver` through the operands of its defining
/// operations and replaces each operand that is exactly `enableDriver` with
/// `constOne`. A visited set prevents processing the same value twice.
452 void setEnable(Value enableDriver, Value wmodeDriver, Value constOne) {
// Helper returning the operation that drives `dst`; a WireOp is looked
// through to the value connected into it.
453 auto getDriverOp = [&](Value dst) -> Operation * {
455 auto *defOp = dst.getDefiningOp();
457 if (isa<WireOp>(defOp))
458 dst = getConnectSrc(dst);
460 defOp = dst.getDefiningOp();
464 SmallVector<Value> stack;
465 llvm::SmallDenseSet<Value> visited;
466 stack.push_back(wmodeDriver);
467 while (!stack.empty()) {
468 auto driver = stack.pop_back_val();
// insert().second is false when the value was already visited.
469 if (!visited.insert(driver).second)
471 auto *defOp = getDriverOp(driver);
474 for (
auto operand : llvm::enumerate(defOp->getOperands())) {
// An operand equal to the enable driver is replaced by the constant.
475 if (operand.value() == enableDriver)
476 defOp->setOperand(operand.index(), constOne);
// Operands are pushed to continue the walk (the branch keyword between
// these two statements is not visible here).
478 stack.push_back(operand.value());
/// Examines the mask subfields of the non-read, non-debug ports of `memOp`.
/// When the visible checks succeed, builds a new MemOp whose ports are typed
/// with a mask width of 1, moves the users of the old ports over to it, and
/// queues the old subfield ops and the old memory in `opsToErase`.
/// NOTE(review): the statements that update `isMasked` and the early exits
/// are not visible here; confirm the exact conditions in the full source.
483 void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
484 bool isMasked =
true;
488 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
// Read and Debug ports are singled out here.
490 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
491 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
493 Value portVal = portIt.value();
495 for (Operation *u : portVal.getUsers())
496 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
499 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Any field whose name contains "mask" is treated as a mask field.
501 if (fName.contains(
"mask")) {
// The mask field is exactly one bit wide.
503 if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
// Check whether every bit of the value driving the mask has the same source.
508 if (
auto maskVal = getConnectSrc(sf))
509 if (areBitsDrivenBySameSource(maskVal))
517 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
// Recompute every port type with 1 as the last getTypeForPort argument
// (the position where runOnOperation passes memOp.getMaskBits()).
519 SmallVector<Type, 4> resultTypes;
520 for (
size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
521 resultTypes.push_back(
522 MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
523 memOp.getPortKind(i), 1));
// Create the new memory with the attributes of the old one.
526 auto newMem =
builder.create<MemOp>(
527 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
528 memOp.getDepth(), memOp.getRuw(), memOp.getPortNames().getValue(),
529 memOp.getNameAttr(), memOp.getNameKind(),
530 memOp.getAnnotations().getValue(),
531 memOp.getPortAnnotations().getValue(), memOp.getInnerSymAttr());
533 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
535 Value oldPort = portIt.value();
537 auto newPortVal = newMem->getResult(portIt.index());
// Read and Debug ports are replaced wholesale by the matching new result.
539 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
540 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
541 oldPort.replaceAllUsesWith(newPortVal);
// Other ports: re-create each subfield access on the new port.
545 for (Operation *u : oldPort.getUsers()) {
// NOTE(review): the dyn_cast result is dereferenced on the next visible line
// without a null check — confirm that every user is a SubfieldOp.
546 auto oldRes = dyn_cast<SubfieldOp>(u);
548 builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
550 sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Mask field: users of the old mask subfield are moved to a fresh wire of
// the old type, and the new mask subfield is connected to a constant.
553 if (fName.contains(
"mask")) {
554 WireOp dummy =
builder.create<WireOp>(oldRes.getType());
555 oldRes->replaceAllUsesWith(dummy);
556 builder.create<StrictConnectOp>(
557 sf,
builder.create<ConstantOp>(
// The old subfield's users are switched to the new subfield.
560 oldRes->replaceAllUsesWith(sf);
// Queue the old subfield op.
562 opsToErase.push_back(oldRes);
// Queue the old memory.
565 opsToErase.push_back(memOp);
/// Per-value record of bit sources: maps a Value to a vector of Values.
/// Written by areBitsDrivenBySameSource and read by handleCatPrimOp and
/// handleBitsPrimOp.
572 llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
577 return std::make_unique<InferReadWritePass>();
assert(baseType &&"element must be base type")
static InstancePath empty
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createInferReadWritePass()
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.