19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "firrtl-infer-read-write"
27 using namespace circt;
28 using namespace firrtl;
31 struct InferReadWritePass :
public InferReadWriteBase<InferReadWritePass> {
41 void runOnOperation()
override {
42 LLVM_DEBUG(
llvm::dbgs() <<
"\n Running Infer Read Write on module:"
44 SmallVector<Operation *> opsToErase;
45 for (MemOp memOp : llvm::make_early_inc_range(
46 getOperation().getBodyBlock()->getOps<MemOp>())) {
47 inferUnmasked(memOp, opsToErase);
48 size_t nReads, nWrites, nRWs, nDbgs;
49 memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
52 if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
53 !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
55 SmallVector<Attribute, 4> resultNames;
56 SmallVector<Type, 4> resultTypes;
57 SmallVector<Attribute> portAtts;
58 SmallVector<Attribute, 4> portAnnotations;
61 SmallVector<Value> readTerms, writeTerms;
62 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
64 portAnno = memOp.getPortAnnotation(portIt.index());
65 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
66 resultNames.push_back(memOp.getPortName(portIt.index()));
67 resultTypes.push_back(memOp.getResult(portIt.index()).getType());
68 portAnnotations.push_back(portAnno);
72 if (!cast<ArrayAttr>(portAnno).
empty())
73 portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
75 Value portVal = portIt.value();
78 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
80 for (Operation *u : portVal.getUsers())
81 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
83 auto fName = sf.getInput().getType().get().getElementName(
87 if (fName.equals(
"en"))
88 getProductTerms(sf, isReadPort ? readTerms : writeTerms);
90 else if (fName.equals(
"clk")) {
92 rClock = getConnectSrc(sf);
94 wClock = getConnectSrc(sf);
99 if (!sameDriver(rClock, wClock))
105 <<
" --- write clock:" << wClock;
106 llvm::dbgs() <<
"\n Read terms==>";
for (
auto t
110 llvm::dbgs() <<
"\n Write terms==>";
for (
auto t
118 auto complementTerm = checkComplement(readTerms, writeTerms);
125 resultTypes.push_back(MemOp::getTypeForPort(
126 memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
127 memOp.getMaskBits()));
128 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
129 portAnnotations.push_back(
builder.getArrayAttr(portAtts));
131 auto rwMem =
builder.create<MemOp>(
132 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
133 memOp.getDepth(), RUWAttr::Undefined,
134 builder.getArrayAttr(resultNames), memOp.getNameAttr(),
135 memOp.getNameKind(), memOp.getAnnotations(),
136 builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
137 memOp.getInitAttr(), memOp.getPrefixAttr());
138 ++numRWPortMemoriesInferred;
139 auto rwPort = rwMem->getResult(nDbgs);
143 auto addr =
builder.create<SubfieldOp>(rwPort,
"addr");
145 auto enb =
builder.create<SubfieldOp>(rwPort,
"en");
147 auto clk =
builder.create<SubfieldOp>(rwPort,
"clk");
148 auto readData =
builder.create<SubfieldOp>(rwPort,
"rdata");
151 auto wmode =
builder.create<SubfieldOp>(rwPort,
"wmode");
152 auto writeData =
builder.create<SubfieldOp>(rwPort,
"wdata");
153 auto mask =
builder.create<SubfieldOp>(rwPort,
"wmask");
156 builder.create<WireOp>(
addr.getType(),
"readAddr").getResult();
158 builder.create<WireOp>(
addr.getType(),
"writeAddr").getResult();
160 builder.create<WireOp>(enb.getType(),
"writeEnable").getResult();
162 builder.create<WireOp>(enb.getType(),
"readEnable").getResult();
166 builder.create<StrictConnectOp>(
167 addr,
builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
169 builder.create<StrictConnectOp>(
170 enb,
builder.create<OrPrimOp>(rEnWire, wEnWire));
174 size_t dbgsIndex = 0;
175 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
177 Value portVal = portIt.value();
178 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
179 memOp.getResult(portIt.index())
180 .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
186 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
189 for (Operation *u : portVal.getUsers())
190 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
191 StringRef fName = sf.getInput().getType().get().getElementName(
195 repl = llvm::StringSwitch<Value>(fName)
199 .Case(
"data", readData);
201 repl = llvm::StringSwitch<Value>(fName)
203 .Case(
"clk", writeClock)
205 .Case(
"data", writeData)
207 sf.replaceAllUsesWith(repl);
209 opsToErase.push_back(sf);
213 opsToErase.push_back(memOp);
215 for (
auto *o : opsToErase)
221 Value getConnectSrc(Value dst) {
222 for (
auto *c : dst.getUsers())
223 if (
auto connect = dyn_cast<FConnectLike>(c))
232 bool sameDriver(Value rClock, Value wClock) {
233 if (rClock == wClock)
235 DenseSet<Value> rClocks, wClocks;
239 rClocks.insert(rClock);
240 rClock = getConnectSrc(rClock);
243 bool sameClock =
false;
247 if (rClocks.find(wClock) != rClocks.end()) {
251 wClock = getConnectSrc(wClock);
256 void getProductTerms(Value enValue, SmallVector<Value> &terms) {
259 SmallVector<Value> worklist;
260 worklist.push_back(enValue);
261 while (!worklist.empty()) {
262 auto term = worklist.back();
264 terms.push_back(term);
265 if (isa<BlockArgument>(term))
267 TypeSwitch<Operation *>(term.getDefiningOp())
268 .Case<NodeOp>([&](
auto n) { worklist.push_back(n.getInput()); })
269 .Case<AndPrimOp>([&](AndPrimOp andOp) {
270 worklist.push_back(andOp.getOperand(0));
271 worklist.push_back(andOp.getOperand(1));
273 .Case<MuxPrimOp>([&](
auto muxOp) {
277 if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
278 muxOp.getLow().getDefiningOp()))
279 if (cLow.getValue().isZero()) {
280 worklist.push_back(muxOp.getSel());
281 worklist.push_back(muxOp.getHigh());
285 if (
auto src = getConnectSrc(term))
286 worklist.push_back(src);
296 Value checkComplement(
const SmallVector<Value> &readTerms,
297 const SmallVector<Value> &writeTerms) {
300 for (
auto t1 : readTerms)
301 for (
auto t2 : writeTerms) {
303 if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
304 if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
307 if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
308 if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
315 void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
316 bool isMasked =
true;
320 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
322 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
323 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
325 Value portVal = portIt.value();
327 for (Operation *u : portVal.getUsers())
328 if (
auto sf = dyn_cast<SubfieldOp>(u)) {
331 sf.getInput().getType().get().getElementName(sf.getFieldIndex());
333 if (fName.contains(
"mask")) {
335 if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
339 if (
auto maskVal = getConnectSrc(sf))
340 if (
auto constVal = dyn_cast<ConstantOp>(maskVal.getDefiningOp()))
341 if (constVal.getValue().isAllOnes())
349 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
351 SmallVector<Type, 4> resultTypes;
352 for (
size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
353 resultTypes.push_back(
354 MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
355 memOp.getPortKind(i), 1));
358 auto newMem =
builder.create<MemOp>(
359 resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
360 memOp.getDepth(), memOp.getRuw(), memOp.getPortNames().getValue(),
361 memOp.getNameAttr(), memOp.getNameKind(),
362 memOp.getAnnotations().getValue(),
363 memOp.getPortAnnotations().getValue(), memOp.getInnerSymAttr());
365 for (
const auto &portIt : llvm::enumerate(memOp.getResults())) {
367 Value oldPort = portIt.value();
369 auto newPortVal = newMem->getResult(portIt.index());
371 if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
372 memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
373 oldPort.replaceAllUsesWith(newPortVal);
377 for (Operation *u : oldPort.getUsers()) {
378 auto oldRes = dyn_cast<SubfieldOp>(u);
380 builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
382 sf.getInput().getType().get().getElementName(sf.getFieldIndex());
385 if (fName.contains(
"mask")) {
386 WireOp dummy =
builder.create<WireOp>(oldRes.getType());
387 oldRes->replaceAllUsesWith(dummy);
388 builder.create<StrictConnectOp>(
389 sf,
builder.create<ConstantOp>(
392 oldRes->replaceAllUsesWith(sf);
394 opsToErase.push_back(oldRes);
397 opsToErase.push_back(memOp);
405 return std::make_unique<InferReadWritePass>();
static InstancePath empty
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createInferReadWritePass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()