CIRCT  18.0.0git
InferReadWrite.cpp
Go to the documentation of this file.
1 //===- InferReadWrite.cpp - Infer Read Write Memory -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferReadWrite pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "firrtl-infer-read-write"
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 namespace {
31 struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
32 
33  /// This pass performs two memory transformations:
34  /// 1. If the multi-bit enable port is connected to a constant 1,
35  /// then, replace with a single bit mask. Create a new memory with a
36  /// 1 bit mask, and replace the old memory with it. The single bit mask
37  /// memory is always lowered to an unmasked memory.
38  /// 2. If the read and write enable ports are trivially mutually exclusive,
39  /// then create a new memory with a single read/write port, and replace
40  /// the old memory with it.
41  void runOnOperation() override {
42  LLVM_DEBUG(llvm::dbgs() << "\n Running Infer Read Write on module:"
43  << getOperation().getName());
44  SmallVector<Operation *> opsToErase;
45  for (MemOp memOp : llvm::make_early_inc_range(
46  getOperation().getBodyBlock()->getOps<MemOp>())) {
47  inferUnmasked(memOp, opsToErase);
48  size_t nReads, nWrites, nRWs, nDbgs;
49  memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs);
50  // Run the analysis only for Seq memories (latency=1) and a single read
51  // and write ports.
52  if (!(nReads == 1 && nWrites == 1 && nRWs == 0) ||
53  !(memOp.getReadLatency() == 1 && memOp.getWriteLatency() == 1))
54  continue;
55  SmallVector<Attribute, 4> resultNames;
56  SmallVector<Type, 4> resultTypes;
57  SmallVector<Attribute> portAtts;
58  SmallVector<Attribute, 4> portAnnotations;
59  Value rClock, wClock;
60  // The memory has exactly two ports.
61  SmallVector<Value> readTerms, writeTerms;
62  for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
63  Attribute portAnno;
64  portAnno = memOp.getPortAnnotation(portIt.index());
65  if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
66  resultNames.push_back(memOp.getPortName(portIt.index()));
67  resultTypes.push_back(memOp.getResult(portIt.index()).getType());
68  portAnnotations.push_back(portAnno);
69  continue;
70  }
71  // Append the annotations from the two ports.
72  if (!cast<ArrayAttr>(portAnno).empty())
73  portAtts.push_back(memOp.getPortAnnotation(portIt.index()));
74  // Get the port value.
75  Value portVal = portIt.value();
76  // Get the port kind.
77  bool isReadPort =
78  memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
79  // Iterate over all users of the port.
80  for (Operation *u : portVal.getUsers())
81  if (auto sf = dyn_cast<SubfieldOp>(u)) {
82  // Get the field name.
83  auto fName = sf.getInput().getType().get().getElementName(
84  sf.getFieldIndex());
85  // If this is the enable field, record the product terms(the And
86  // expression tree).
87  if (fName.equals("en"))
88  getProductTerms(sf, isReadPort ? readTerms : writeTerms);
89 
90  else if (fName.equals("clk")) {
91  if (isReadPort)
92  rClock = getConnectSrc(sf);
93  else
94  wClock = getConnectSrc(sf);
95  }
96  }
97  // End of loop for getting MemOp port users.
98  }
99  if (!sameDriver(rClock, wClock))
100  continue;
101 
102  rClock = wClock;
103  LLVM_DEBUG(
104  llvm::dbgs() << "\n read clock:" << rClock
105  << " --- write clock:" << wClock;
106  llvm::dbgs() << "\n Read terms==>"; for (auto t
107  : readTerms) llvm::dbgs()
108  << "\n term::" << t;
109 
110  llvm::dbgs() << "\n Write terms==>"; for (auto t
111  : writeTerms) llvm::dbgs()
112  << "\n term::" << t;
113 
114  );
115  // If the read and write clocks are the same, and if any of the write
116  // enable product terms are a complement of the read enable, then return
117  // the write enable term.
118  auto complementTerm = checkComplement(readTerms, writeTerms);
119  if (!complementTerm)
120  continue;
121 
122  // Create the merged rw port for the new memory.
123  resultNames.push_back(StringAttr::get(memOp.getContext(), "rw"));
124  // Set the type of the rw port.
125  resultTypes.push_back(MemOp::getTypeForPort(
126  memOp.getDepth(), memOp.getDataType(), MemOp::PortKind::ReadWrite,
127  memOp.getMaskBits()));
128  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
129  portAnnotations.push_back(builder.getArrayAttr(portAtts));
130  // Create the new rw memory.
131  auto rwMem = builder.create<MemOp>(
132  resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
133  memOp.getDepth(), RUWAttr::Undefined,
134  builder.getArrayAttr(resultNames), memOp.getNameAttr(),
135  memOp.getNameKind(), memOp.getAnnotations(),
136  builder.getArrayAttr(portAnnotations), memOp.getInnerSymAttr(),
137  memOp.getInitAttr(), memOp.getPrefixAttr());
138  ++numRWPortMemoriesInferred;
139  auto rwPort = rwMem->getResult(nDbgs);
140  // Create the subfield access to all fields of the port.
141  // The addr should be connected to read/write address depending on the
142  // read/write mode.
143  auto addr = builder.create<SubfieldOp>(rwPort, "addr");
144  // Enable is high whenever the memory is written or read.
145  auto enb = builder.create<SubfieldOp>(rwPort, "en");
146  // Read/Write clock.
147  auto clk = builder.create<SubfieldOp>(rwPort, "clk");
148  auto readData = builder.create<SubfieldOp>(rwPort, "rdata");
149  // wmode is high when the port is in write mode. That is this can be
150  // connected to the write enable.
151  auto wmode = builder.create<SubfieldOp>(rwPort, "wmode");
152  auto writeData = builder.create<SubfieldOp>(rwPort, "wdata");
153  auto mask = builder.create<SubfieldOp>(rwPort, "wmask");
154  // Temp wires to replace the original memory connects.
155  auto rAddr =
156  builder.create<WireOp>(addr.getType(), "readAddr").getResult();
157  auto wAddr =
158  builder.create<WireOp>(addr.getType(), "writeAddr").getResult();
159  auto wEnWire =
160  builder.create<WireOp>(enb.getType(), "writeEnable").getResult();
161  auto rEnWire =
162  builder.create<WireOp>(enb.getType(), "readEnable").getResult();
163  auto writeClock =
164  builder.create<WireOp>(ClockType::get(enb.getContext())).getResult();
165  // addr = Mux(WriteEnable, WriteAddress, ReadAddress).
166  builder.create<StrictConnectOp>(
167  addr, builder.create<MuxPrimOp>(wEnWire, wAddr, rAddr));
168  // Enable = Or(WriteEnable, ReadEnable).
169  builder.create<StrictConnectOp>(
170  enb, builder.create<OrPrimOp>(rEnWire, wEnWire));
171  builder.setInsertionPointToEnd(wmode->getBlock());
172  builder.create<StrictConnectOp>(wmode, complementTerm);
173  // Now iterate over the original memory read and write ports.
174  size_t dbgsIndex = 0;
175  for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
176  // Get the port value.
177  Value portVal = portIt.value();
178  if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
179  memOp.getResult(portIt.index())
180  .replaceAllUsesWith(rwMem.getResult(dbgsIndex));
181  dbgsIndex++;
182  continue;
183  }
184  // Get the port kind.
185  bool isReadPort =
186  memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
187  // Iterate over all users of the port, which are the subfield ops, and
188  // replace them.
189  for (Operation *u : portVal.getUsers())
190  if (auto sf = dyn_cast<SubfieldOp>(u)) {
191  StringRef fName = sf.getInput().getType().get().getElementName(
192  sf.getFieldIndex());
193  Value repl;
194  if (isReadPort)
195  repl = llvm::StringSwitch<Value>(fName)
196  .Case("en", rEnWire)
197  .Case("clk", clk)
198  .Case("addr", rAddr)
199  .Case("data", readData);
200  else
201  repl = llvm::StringSwitch<Value>(fName)
202  .Case("en", wEnWire)
203  .Case("clk", writeClock)
204  .Case("addr", wAddr)
205  .Case("data", writeData)
206  .Case("mask", mask);
207  sf.replaceAllUsesWith(repl);
208  // Once all the uses of the subfield op replaced, delete it.
209  opsToErase.push_back(sf);
210  }
211  }
212  // All uses for all results of mem removed, now erase the memOp.
213  opsToErase.push_back(memOp);
214  }
215  for (auto *o : opsToErase)
216  o->erase();
217  }
218 
219 private:
220  // Get the source value which is connected to the dst.
221  Value getConnectSrc(Value dst) {
222  for (auto *c : dst.getUsers())
223  if (auto connect = dyn_cast<FConnectLike>(c))
224  if (connect.getDest() == dst)
225  return connect.getSrc();
226 
227  return nullptr;
228  }
229 
230  /// If the ports are not directly connected to the same clock, then check
231  /// if indirectly connected to the same clock.
232  bool sameDriver(Value rClock, Value wClock) {
233  if (rClock == wClock)
234  return true;
235  DenseSet<Value> rClocks, wClocks;
236  while (rClock) {
237  // Record all the values which are indirectly connected to the clock
238  // port.
239  rClocks.insert(rClock);
240  rClock = getConnectSrc(rClock);
241  }
242 
243  bool sameClock = false;
244  // Now check all the indirect connections to the write memory clock
245  // port.
246  while (wClock) {
247  if (rClocks.find(wClock) != rClocks.end()) {
248  sameClock = true;
249  break;
250  }
251  wClock = getConnectSrc(wClock);
252  }
253  return sameClock;
254  }
255 
256  void getProductTerms(Value enValue, SmallVector<Value> &terms) {
257  if (!enValue)
258  return;
259  SmallVector<Value> worklist;
260  worklist.push_back(enValue);
261  while (!worklist.empty()) {
262  auto term = worklist.back();
263  worklist.pop_back();
264  terms.push_back(term);
265  if (isa<BlockArgument>(term))
266  continue;
267  TypeSwitch<Operation *>(term.getDefiningOp())
268  .Case<NodeOp>([&](auto n) { worklist.push_back(n.getInput()); })
269  .Case<AndPrimOp>([&](AndPrimOp andOp) {
270  worklist.push_back(andOp.getOperand(0));
271  worklist.push_back(andOp.getOperand(1));
272  })
273  .Case<MuxPrimOp>([&](auto muxOp) {
274  // Check for the pattern when low is 0, which is equivalent to (sel
275  // & high)
276  // term = mux (sel, high, 0) => term = sel & high
277  if (ConstantOp cLow = dyn_cast_or_null<ConstantOp>(
278  muxOp.getLow().getDefiningOp()))
279  if (cLow.getValue().isZero()) {
280  worklist.push_back(muxOp.getSel());
281  worklist.push_back(muxOp.getHigh());
282  }
283  })
284  .Default([&](auto) {
285  if (auto src = getConnectSrc(term))
286  worklist.push_back(src);
287  });
288  }
289  }
290 
291  /// If any of the terms in the read enable, prodTerms[0] is a complement of
292  /// any of the terms in the write enable prodTerms[1], return the
293  /// corresponding write enable term. prodTerms[0], prodTerms[1] is a vector of
294  /// Value, each of which correspond to the two product terms of read and write
295  /// enable respectively.
296  Value checkComplement(const SmallVector<Value> &readTerms,
297  const SmallVector<Value> &writeTerms) {
298  // Foreach Value in first term, check if it is the complement of any of the
299  // Value in second term.
300  for (auto t1 : readTerms)
301  for (auto t2 : writeTerms) {
302  // Return t2, t1 is a Not of t2.
303  if (!isa<BlockArgument>(t1) && isa<NotPrimOp>(t1.getDefiningOp()))
304  if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
305  return t2;
306  // Else Return t2, if t2 is a Not of t1.
307  if (!isa<BlockArgument>(t2) && isa<NotPrimOp>(t2.getDefiningOp()))
308  if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
309  return t2;
310  }
311 
312  return {};
313  }
314 
315  void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {
316  bool isMasked = true;
317 
318  // Iterate over all results, and check if the mask field of the result is
319  // connected to a multi-bit constant 1.
320  for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
321  // Read ports donot have the mask field.
322  if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
323  memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug)
324  continue;
325  Value portVal = portIt.value();
326  // Iterate over all users of the write/rw port.
327  for (Operation *u : portVal.getUsers())
328  if (auto sf = dyn_cast<SubfieldOp>(u)) {
329  // Get the field name.
330  auto fName =
331  sf.getInput().getType().get().getElementName(sf.getFieldIndex());
332  // Check if this is the mask field.
333  if (fName.contains("mask")) {
334  // Already 1 bit, nothing to do.
335  if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
336  continue;
337  // Check what is the mask field directly connected to.
338  // If, a constant 1, then we can replace with unMasked memory.
339  if (auto maskVal = getConnectSrc(sf))
340  if (auto constVal = dyn_cast<ConstantOp>(maskVal.getDefiningOp()))
341  if (constVal.getValue().isAllOnes())
342  isMasked = false;
343  }
344  }
345  }
346 
347  if (!isMasked) {
348  // Replace with a new memory of 1 bit mask.
349  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
350  // Copy the result type, except the mask bits.
351  SmallVector<Type, 4> resultTypes;
352  for (size_t i = 0, e = memOp.getNumResults(); i != e; ++i)
353  resultTypes.push_back(
354  MemOp::getTypeForPort(memOp.getDepth(), memOp.getDataType(),
355  memOp.getPortKind(i), /*maskBits=*/1));
356 
357  // Copy everything from old memory, except the result type.
358  auto newMem = builder.create<MemOp>(
359  resultTypes, memOp.getReadLatency(), memOp.getWriteLatency(),
360  memOp.getDepth(), memOp.getRuw(), memOp.getPortNames().getValue(),
361  memOp.getNameAttr(), memOp.getNameKind(),
362  memOp.getAnnotations().getValue(),
363  memOp.getPortAnnotations().getValue(), memOp.getInnerSymAttr());
364  // Now replace the result of old memory with the new one.
365  for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
366  // Old result.
367  Value oldPort = portIt.value();
368  // New result.
369  auto newPortVal = newMem->getResult(portIt.index());
370  // If read port, then blindly replace.
371  if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read ||
372  memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) {
373  oldPort.replaceAllUsesWith(newPortVal);
374  continue;
375  }
376  // Otherwise, all fields can be blindly replaced, except mask field.
377  for (Operation *u : oldPort.getUsers()) {
378  auto oldRes = dyn_cast<SubfieldOp>(u);
379  auto sf =
380  builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
381  auto fName =
382  sf.getInput().getType().get().getElementName(sf.getFieldIndex());
383  // Replace all mask fields with a one bit constant 1.
384  // Replace all other fields with the new port.
385  if (fName.contains("mask")) {
386  WireOp dummy = builder.create<WireOp>(oldRes.getType());
387  oldRes->replaceAllUsesWith(dummy);
388  builder.create<StrictConnectOp>(
389  sf, builder.create<ConstantOp>(
390  UIntType::get(builder.getContext(), 1), APInt(1, 1)));
391  } else
392  oldRes->replaceAllUsesWith(sf);
393 
394  opsToErase.push_back(oldRes);
395  }
396  }
397  opsToErase.push_back(memOp);
398  memOp = newMem;
399  }
400  }
401 };
402 } // end anonymous namespace
403 
404 std::unique_ptr<mlir::Pass> circt::firrtl::createInferReadWritePass() {
405  return std::make_unique<InferReadWritePass>();
406 }
static InstancePath empty
Builder builder
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
std::unique_ptr< mlir::Pass > createInferReadWritePass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28