CIRCT  20.0.0git
MemToRegOfVec.cpp
Go to the documentation of this file.
1 //===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MemToRegOfVec pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/Threading.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "mem-to-reg-of-vec"
28 
29 namespace circt {
30 namespace firrtl {
31 #define GEN_PASS_DEF_MEMTOREGOFVEC
32 #include "circt/Dialect/FIRRTL/Passes.h.inc"
33 } // namespace firrtl
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace firrtl;
38 
39 namespace {
40 struct MemToRegOfVecPass
41  : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
42  MemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable)
43  : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
44 
/// Pass entry point. Gathers a set of operations to process and then calls
/// runOnModule() on every FModuleOp in that set via mlir::parallelForEach.
/// NOTE(review): this listing skips its own line numbers 48-49 and 55, so the
/// condition guarding the early return below and the body of the find_if
/// predicate are not visible here -- confirm against the complete file.
45  void runOnOperation() override {
46  auto circtOp = getOperation();
 // Operations whose memories will be considered for conversion.
47  DenseSet<Operation *> dutModuleSet;
 // NOTE(review): the statement that guards this return is missing from the
 // listing; as shown, nothing is modified on this path and all analyses are
 // reported as preserved.
50  return markAllAnalysesPreserved();
51  auto *body = circtOp.getBodyBlock();
52 
53  // Find the device under test and create a set of all modules underneath it.
 // NOTE(review): the predicate body (listing line 55) is not visible here.
54  auto it = llvm::find_if(*body, [&](Operation &op) -> bool {
56  });
57  if (it != body->end()) {
 // A matching operation was found: walk the instance graph depth-first from
 // its node and record the module of every node visited.
58  auto &instanceGraph = getAnalysis<InstanceGraph>();
59  auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
60  llvm::for_each(llvm::depth_first(node),
61  [&](igraph::InstanceGraphNode *node) {
62  dutModuleSet.insert(node->getModule());
63  });
64  } else {
 // No match: fall back to every FModuleOp returned by circtOp.getOps.
65  auto mods = circtOp.getOps<FModuleOp>();
66  dutModuleSet.insert(mods.begin(), mods.end());
67  }
68 
 // Process the collected operations; entries that are not FModuleOps are
 // skipped by the dyn_cast.
69  mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
70  [&](Operation *op) {
71  if (auto mod = dyn_cast<FModuleOp>(op))
72  runOnModule(mod);
73  });
74  }
75 
/// Walk every MemOp in the body of \p mod and, unless it is skipped, replace it
/// with register-based logic: generateMemory() builds the replacement, the
/// numConvertedMems counter is incremented, and the MemOp is erased.
/// NOTE(review): this listing skips its own line number 80, so the condition
/// guarding the first `return;` in the walk callback is not visible here --
/// confirm against the complete file.
76  void runOnModule(FModuleOp mod) {
77 
78  mod.getBodyBlock()->walk([&](MemOp memOp) {
79  LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
 // NOTE(review): guard for this return is missing from the listing.
81  return;
82 
83  auto firMem = memOp.getSummary();
84  // Ignore if the memory is a candidate for macro replacement.
85  // The requirements for macro replacement:
86  // 1. read latency and write latency of one.
87  // 2. only one readwrite port or write port.
88  // 3. zero or one read port.
89  // 4. undefined read-under-write behavior.
 // NOTE(review): the condition below checks items 1-3 and a non-zero data
 // width; no read-under-write check is visible, so item 4 may be stale.
 // The skip only applies when replSeqMem is set.
90  if (replSeqMem &&
91  ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
92  (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
93  (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
94  return;
95 
 // Build the register implementation first, then drop the memory op.
96  generateMemory(memOp, firMem);
97  ++numConvertedMems;
98  memOp.erase();
99  });
100  }
101  Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
102  Value pipeInput, StringRef name, Value gate = {}) {
103  if (!stages)
104  return pipeInput;
105 
106  while (stages--) {
107  auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
108  if (gate) {
109  b.create<WhenOp>(gate, /*withElseRegion*/ false, [&]() {
110  b.create<MatchingConnectOp>(reg, pipeInput);
111  });
112  } else
113  b.create<MatchingConnectOp>(reg, pipeInput);
114 
115  pipeInput = reg;
116  }
117 
118  return pipeInput;
119  }
120 
121  Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
122  return builder.create<SubfieldOp>(bundle, "clk");
123  }
124 
125  Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
126  return builder.create<SubfieldOp>(bundle, "addr");
127  }
128 
129  Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
130  return builder.create<SubfieldOp>(bundle, "wmode");
131  }
132 
133  Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
134  return builder.create<SubfieldOp>(bundle, "en");
135  }
136 
137  Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
138  auto bType = type_cast<BundleType>(bundle.getType());
139  if (bType.getElement("mask"))
140  return builder.create<SubfieldOp>(bundle, "mask");
141  return builder.create<SubfieldOp>(bundle, "wmask");
142  }
143 
144  Value getData(ImplicitLocOpBuilder &builder, Value bundle,
145  bool getWdata = false) {
146  auto bType = type_cast<BundleType>(bundle.getType());
147  if (bType.getElement("data"))
148  return builder.create<SubfieldOp>(bundle, "data");
149  if (bType.getElement("rdata") && !getWdata)
150  return builder.create<SubfieldOp>(bundle, "rdata");
151  return builder.create<SubfieldOp>(bundle, "wdata");
152  }
153 
154  void generateRead(const FirMemory &firMem, Value clock, Value addr,
155  Value enable, Value data, Value regOfVec,
156  ImplicitLocOpBuilder &builder) {
157  if (ignoreReadEnable) {
158  // If read enable is ignored, then guard the address update with read
159  // enable.
160  for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
161  auto enLast = enable;
162  if (j < e - 1)
163  enable = addPipelineStages(builder, 1, clock, enable, "en");
164  addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
165  }
166  } else {
167  // Add pipeline stages to respect the read latency. One register for each
168  // latency cycle.
169  enable =
170  addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
171  addr =
172  addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
173  }
174 
175  // Read the register[address] into a temporary.
176  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
177  if (!ignoreReadEnable) {
178  // Initialize read data out with invalid.
179  builder.create<MatchingConnectOp>(
180  data, builder.create<InvalidValueOp>(data.getType()));
181  // If enable is true, then connect the data read from memory register.
182  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
183  builder.create<MatchingConnectOp>(data, rdata);
184  });
185  } else {
186  // Ignore read enable signal.
187  builder.create<MatchingConnectOp>(data, rdata);
188  }
189  }
190 
191  void generateWrite(const FirMemory &firMem, Value clock, Value addr,
192  Value enable, Value maskBits, Value wdataIn,
193  Value regOfVec, ImplicitLocOpBuilder &builder) {
194 
195  auto numStages = firMem.writeLatency - 1;
196  // Add pipeline stages to respect the write latency. Intermediate registers
197  // for each stage.
198  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
199  enable = addPipelineStages(builder, numStages, clock, enable, "en");
200  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
201  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
202  // Create the register access.
203  FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
204 
205  // The tuple for the access to individual fields of an aggregate data type.
206  // Tuple::<register, data, mask>
207  // The logic:
208  // if (mask)
209  // register = data
210  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
211 
212  // Write to each aggregate data field is guarded by the corresponding mask
213  // field. This means we have to generate read and write access for each
214  // individual field of the aggregate type.
215  // There are two options to handle this,
216  // 1. FlattenMemory: cast the aggregate data into a UInt and generate
217  // appropriate mask logic.
218  // 2. Create access for each individual field of the aggregate type.
219  // Here we implement the option 2 using getFields.
220  // getFields, creates an access to each individual field of the data and
221  // mask, and the corresponding field into the register. It populates
222  // the loweredRegDataMaskFields vector.
223  // This is similar to what happens in LowerTypes.
224  //
225  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
226  builder)) {
227  wdataIn.getDefiningOp()->emitOpError(
228  "Cannot convert memory to bank of registers");
229  return;
230  }
231  // If enable:
232  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
233  // For each data field. Only one field if not aggregate.
234  for (auto regDataMask : loweredRegDataMaskFields) {
235  auto regField = std::get<0>(regDataMask);
236  auto dataField = std::get<1>(regDataMask);
237  auto maskField = std::get<2>(regDataMask);
238  // If mask, then update the register field.
239  builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
240  builder.create<MatchingConnectOp>(regField, dataField);
241  });
242  }
243  });
244  }
245 
246  void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
247  Value enable, Value maskBits, Value wdataIn,
248  Value rdataOut, Value wmode, Value regOfVec,
249  ImplicitLocOpBuilder &builder) {
250 
251  // Add pipeline stages to respect the write latency. Intermediate registers
252  // for each stage. Number of pipeline stages, max of read/write latency.
253  auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
254  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
255  enable = addPipelineStages(builder, numStages, clock, enable, "en");
256  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
257  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
258 
259  // Read the register[address] into a temporary.
260  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
261 
262  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
263  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
264  builder)) {
265  wdataIn.getDefiningOp()->emitOpError(
266  "Cannot convert memory to bank of registers");
267  return;
268  }
269  // Initialize read data out with invalid.
270  builder.create<MatchingConnectOp>(
271  rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
272  // If enable:
273  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
274  // If write mode:
275  builder.create<WhenOp>(
276  wmode, true,
277  // Write block:
278  [&]() {
279  // For each data field. Only one field if not aggregate.
280  for (auto regDataMask : loweredRegDataMaskFields) {
281  auto regField = std::get<0>(regDataMask);
282  auto dataField = std::get<1>(regDataMask);
283  auto maskField = std::get<2>(regDataMask);
284  // If mask true, then set the field.
285  builder.create<WhenOp>(
286  maskField, /*withElseRegion*/ false, [&]() {
287  builder.create<MatchingConnectOp>(regField, dataField);
288  });
289  }
290  },
291  // Read block:
292  [&]() { builder.create<MatchingConnectOp>(rdataOut, rdata); });
293  });
294  }
295 
296  // Generate individual field accesses for an aggregate type. Return false
297  // on failure, which can happen if invalid fields are present or the mask and
298  // input types do not match. The assumption is that \p reg and \p input have
299  // exactly the same type, and \p mask has the same bundle fields, but each
300  // field is of type UInt<1>. Populate \p results with one entry per leaf
301  // field. For example, the first entry should be access to first field of \p
302  // reg, first field of \p input and first field of \p mask.
 // NOTE(review): this listing skips its own line number 327, so the head of
 // the `.Case<...>` chain inside flatAccess is not visible here -- confirm
 // against the complete file.
303  bool getFields(Value reg, Value input, Value mask,
304  SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
305  ImplicitLocOpBuilder &builder) {
306 
307  // Check if the number of fields of mask and input type match.
 // Bundle input needs a bundle mask with the same element count; vector
 // input needs a vector mask with the same element count; any other input
 // type is accepted without looking at the mask type.
308  auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
309  if (auto bundle = type_dyn_cast<BundleType>(inType)) {
310  if (auto mBundle = type_dyn_cast<BundleType>(maskType))
311  return mBundle.getNumElements() == bundle.getNumElements();
312  } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
313  if (auto mVec = type_dyn_cast<FVectorType>(maskType))
314  return mVec.getNumElements() == vec.getNumElements();
315  } else
316  return true;
317  return false;
318  };
319 
 // Recursive worker; held in a std::function so the lambda can call itself.
 // Its parameters shadow the enclosing function's reg/input/mask.
320  std::function<bool(Value, Value, Value)> flatAccess =
321  [&](Value reg, Value input, Value mask) -> bool {
322  FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
323  if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
324  input.getDefiningOp()->emitOpError("Mask type is not valid");
325  return false;
326  }
 // Bundle: recurse into each element by index, stopping at the first failure.
328  .Case<BundleType>([&](BundleType bundle) {
329  for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
330  auto regField = builder.create<SubfieldOp>(reg, i);
331  auto inputField = builder.create<SubfieldOp>(input, i);
332  auto maskField = builder.create<SubfieldOp>(mask, i);
333  if (!flatAccess(regField, inputField, maskField))
334  return false;
335  }
336  return true;
337  })
 // Vector: recurse into each element by index, stopping at the first failure.
338  .Case<FVectorType>([&](auto vector) {
339  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
340  auto regField = builder.create<SubindexOp>(reg, i);
341  auto inputField = builder.create<SubindexOp>(input, i);
342  auto maskField = builder.create<SubindexOp>(mask, i);
343  if (!flatAccess(regField, inputField, maskField))
344  return false;
345  }
346  return true;
347  })
 // Integer leaf: record the triple; succeeds only when the width is known.
 // The triple is appended to \p results even when this returns false.
348  .Case<IntType>([&](auto iType) {
349  results.push_back({reg, input, mask});
350  return iType.getWidth().has_value();
351  })
 // Any other type cannot be lowered.
352  .Default([&](auto) { return false; });
353  };
354  if (flatAccess(reg, input, mask))
355  return true;
356  return false;
357  }
358 
359  void scatterMemTapAnno(RegOp op, ArrayAttr attr,
360  ImplicitLocOpBuilder &builder) {
361  AnnotationSet annos(attr);
362  SmallVector<Attribute> regAnnotations;
363  auto vecType = type_cast<FVectorType>(op.getResult().getType());
364  for (auto anno : annos) {
365  if (anno.isClass(memTapSourceClass)) {
366  for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
367  .getNumElements();
368  i != e; ++i) {
369  NamedAttrList newAnno;
370  newAnno.append("class", anno.getMember("class"));
371  newAnno.append("circt.fieldID",
372  builder.getI64IntegerAttr(vecType.getFieldID(i)));
373  newAnno.append("id", anno.getMember("id"));
374  if (auto nla = anno.getMember("circt.nonlocal"))
375  newAnno.append("circt.nonlocal", nla);
376  newAnno.append(
377  "portID",
378  IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
379 
380  regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
381  }
382  } else
383  regAnnotations.push_back(anno.getAttr());
384  }
385  op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
386  }
387 
388  /// Generate the logic for implementing the memory using Registers.
389  void generateMemory(MemOp memOp, FirMemory &firMem) {
390  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
391  auto dataType = memOp.getDataType();
392 
393  auto innerSym = memOp.getInnerSym();
394  SmallVector<Value> debugPorts;
395 
396  RegOp regOfVec = {};
397  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
398  ++index) {
399  auto result = memOp.getResult(index);
400  if (type_isa<RefType>(result.getType())) {
401  debugPorts.push_back(result);
402  continue;
403  }
404  // Create a temporary wire to replace the memory port. This makes it
405  // simpler to delete the memOp.
406  auto wire = builder.create<WireOp>(
407  result.getType(),
408  (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
409  memOp.getNameKind());
410  result.replaceAllUsesWith(wire.getResult());
411  result = wire.getResult();
412  // Create an access to all the common subfields.
413  auto adr = getAddr(builder, result);
414  auto enb = getEnable(builder, result);
415  auto clk = getClock(builder, result);
416  auto dta = getData(builder, result);
417  // IF the register is not yet created.
418  if (!regOfVec) {
419  // Create the register corresponding to the memory.
420  regOfVec = builder.create<RegOp>(
421  FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());
422 
423  // Copy all the memory annotations.
424  if (!memOp.getAnnotationsAttr().empty())
425  scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
426  if (innerSym)
427  regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
428  }
429  auto portKind = memOp.getPortKind(index);
430  if (portKind == MemOp::PortKind::Read) {
431  generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
432  } else if (portKind == MemOp::PortKind::Write) {
433  auto mask = getMask(builder, result);
434  generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
435  builder);
436  } else {
437  auto wmode = getWmode(builder, result);
438  auto wDta = getData(builder, result, true);
439  auto mask = getMask(builder, result);
440  generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
441  regOfVec.getResult(), builder);
442  }
443  }
444  // If a valid register is created, then replace all the debug port users
445  // with a RefType of the register. The RefType is obtained by using a
446  // RefSend on the register.
447  if (regOfVec)
448  for (auto r : debugPorts)
449  r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
450  }
451 
452 private:
453  bool replSeqMem;
454  bool ignoreReadEnable;
455 };
456 } // end anonymous namespace
457 
458 std::unique_ptr<mlir::Pass>
459 circt::firrtl::createMemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable) {
460  return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
461 }
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:394
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21