CIRCT  20.0.0git
MemToRegOfVec.cpp
Go to the documentation of this file.
1 //===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MemToRegOfVec pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/DepthFirstIterator.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "mem-to-reg-of-vec"
25 
26 namespace circt {
27 namespace firrtl {
28 #define GEN_PASS_DEF_MEMTOREGOFVEC
29 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 } // namespace firrtl
31 } // namespace circt
32 
33 using namespace circt;
34 using namespace firrtl;
35 
36 namespace {
37 struct MemToRegOfVecPass
38  : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
39  MemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable)
40  : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
41 
/// Pass entry point. Builds `dutModuleSet`, the set of modules to process,
/// and then calls runOnModule() on every FModuleOp in that set through
/// mlir::parallelForEach.
///
/// The set is populated in one of two ways:
///  - if the find_if over the circuit body yields an operation, the
///    InstanceGraph node of that operation is walked depth-first and the
///    module of every visited node is inserted;
///  - otherwise every FModuleOp of the circuit is inserted.
///
/// NOTE(review): this listing appears to have lost lines during extraction:
/// the embedded numbering skips 45-46 and 52. As written, the `return` on
/// line 47 is unconditional and the find_if predicate on lines 51-53 has no
/// body. Presumably line 47 was guarded by a circuit-level annotation check
/// and the predicate tested for a DUT annotation - confirm against the
/// upstream file before relying on this text.
42  void runOnOperation() override {
43  auto circtOp = getOperation();
44  DenseSet<Operation *> dutModuleSet;
47  return markAllAnalysesPreserved();
48  auto *body = circtOp.getBodyBlock();
49 
50  // Find the device under test and create a set of all modules underneath it.
51  auto it = llvm::find_if(*body, [&](Operation &op) -> bool {
53  });
54  if (it != body->end()) {
 // An operation matched: collect the module of every instance graph node
 // reachable from it (depth_first starts at the matched node itself).
55  auto &instanceGraph = getAnalysis<InstanceGraph>();
56  auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
57  llvm::for_each(llvm::depth_first(node),
58  [&](igraph::InstanceGraphNode *node) {
59  dutModuleSet.insert(node->getModule());
60  });
61  } else {
 // Nothing matched: fall back to every FModuleOp in the circuit.
62  auto mods = circtOp.getOps<FModuleOp>();
63  dutModuleSet.insert(mods.begin(), mods.end());
64  }
65 
 // The set holds generic Operation pointers; entries that are not
 // FModuleOps are silently skipped by the dyn_cast below.
66  mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
67  [&](Operation *op) {
68  if (auto mod = dyn_cast<FModuleOp>(op))
69  runOnModule(mod);
70  });
71  }
72 
/// Walk every MemOp in the body of \p mod. Each visited memory that is not
/// skipped is lowered with generateMemory(), counted in `numConvertedMems`
/// and then erased.
///
/// NOTE(review): embedded line 77 is missing from this listing, so the
/// `return` on line 78 is unconditional as written. Presumably it was
/// guarded by a per-memory exclusion check - confirm against the upstream
/// file.
73  void runOnModule(FModuleOp mod) {
74 
75  mod.getBodyBlock()->walk([&](MemOp memOp) {
76  LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
78  return;
79 
80  auto firMem = memOp.getSummary();
81  // Ignore if the memory is candidate for macro replacement.
82  // The requirements for macro replacement:
83  // 1. read latency and write latency of one.
84  // 2. only one readwrite port or write port.
85  // 3. zero or one read port.
86  // 4. undefined read-under-write behavior.
 // NOTE(review): the fourth conjunct below tests `firMem.dataWidth > 0`,
 // not read-under-write behavior - confirm which one is intended.
 // The skip only applies when `replSeqMem` is set.
87  if (replSeqMem &&
88  ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
89  (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
90  (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
91  return;
92 
 // Build the replacement logic first, then drop the original memory.
93  generateMemory(memOp, firMem);
94  ++numConvertedMems;
95  memOp.erase();
96  });
97  }
98  Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
99  Value pipeInput, StringRef name, Value gate = {}) {
100  if (!stages)
101  return pipeInput;
102 
103  while (stages--) {
104  auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
105  if (gate) {
106  b.create<WhenOp>(gate, /*withElseRegion*/ false, [&]() {
107  b.create<MatchingConnectOp>(reg, pipeInput);
108  });
109  } else
110  b.create<MatchingConnectOp>(reg, pipeInput);
111 
112  pipeInput = reg;
113  }
114 
115  return pipeInput;
116  }
117 
118  Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
119  return builder.create<SubfieldOp>(bundle, "clk");
120  }
121 
122  Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
123  return builder.create<SubfieldOp>(bundle, "addr");
124  }
125 
126  Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
127  return builder.create<SubfieldOp>(bundle, "wmode");
128  }
129 
130  Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
131  return builder.create<SubfieldOp>(bundle, "en");
132  }
133 
134  Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
135  auto bType = type_cast<BundleType>(bundle.getType());
136  if (bType.getElement("mask"))
137  return builder.create<SubfieldOp>(bundle, "mask");
138  return builder.create<SubfieldOp>(bundle, "wmask");
139  }
140 
141  Value getData(ImplicitLocOpBuilder &builder, Value bundle,
142  bool getWdata = false) {
143  auto bType = type_cast<BundleType>(bundle.getType());
144  if (bType.getElement("data"))
145  return builder.create<SubfieldOp>(bundle, "data");
146  if (bType.getElement("rdata") && !getWdata)
147  return builder.create<SubfieldOp>(bundle, "rdata");
148  return builder.create<SubfieldOp>(bundle, "wdata");
149  }
150 
151  void generateRead(const FirMemory &firMem, Value clock, Value addr,
152  Value enable, Value data, Value regOfVec,
153  ImplicitLocOpBuilder &builder) {
154  if (ignoreReadEnable) {
155  // If read enable is ignored, then guard the address update with read
156  // enable.
157  for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
158  auto enLast = enable;
159  if (j < e - 1)
160  enable = addPipelineStages(builder, 1, clock, enable, "en");
161  addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
162  }
163  } else {
164  // Add pipeline stages to respect the read latency. One register for each
165  // latency cycle.
166  enable =
167  addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
168  addr =
169  addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
170  }
171 
172  // Read the register[address] into a temporary.
173  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
174  if (!ignoreReadEnable) {
175  // Initialize read data out with invalid.
176  builder.create<MatchingConnectOp>(
177  data, builder.create<InvalidValueOp>(data.getType()));
178  // If enable is true, then connect the data read from memory register.
179  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
180  builder.create<MatchingConnectOp>(data, rdata);
181  });
182  } else {
183  // Ignore read enable signal.
184  builder.create<MatchingConnectOp>(data, rdata);
185  }
186  }
187 
188  void generateWrite(const FirMemory &firMem, Value clock, Value addr,
189  Value enable, Value maskBits, Value wdataIn,
190  Value regOfVec, ImplicitLocOpBuilder &builder) {
191 
192  auto numStages = firMem.writeLatency - 1;
193  // Add pipeline stages to respect the write latency. Intermediate registers
194  // for each stage.
195  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
196  enable = addPipelineStages(builder, numStages, clock, enable, "en");
197  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
198  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
199  // Create the register access.
200  FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
201 
202  // The tuple for the access to individual fields of an aggregate data type.
203  // Tuple::<register, data, mask>
204  // The logic:
205  // if (mask)
206  // register = data
207  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
208 
209  // Write to each aggregate data field is guarded by the corresponding mask
210  // field. This means we have to generate read and write access for each
211  // individual field of the aggregate type.
212  // There are two options to handle this,
213  // 1. FlattenMemory: cast the aggregate data into a UInt and generate
214  // appropriate mask logic.
215  // 2. Create access for each individual field of the aggregate type.
216  // Here we implement the option 2 using getFields.
217  // getFields, creates an access to each individual field of the data and
218  // mask, and the corresponding field into the register. It populates
219  // the loweredRegDataMaskFields vector.
220  // This is similar to what happens in LowerTypes.
221  //
222  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
223  builder)) {
224  wdataIn.getDefiningOp()->emitOpError(
225  "Cannot convert memory to bank of registers");
226  return;
227  }
228  // If enable:
229  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
230  // For each data field. Only one field if not aggregate.
231  for (auto regDataMask : loweredRegDataMaskFields) {
232  auto regField = std::get<0>(regDataMask);
233  auto dataField = std::get<1>(regDataMask);
234  auto maskField = std::get<2>(regDataMask);
235  // If mask, then update the register field.
236  builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
237  builder.create<MatchingConnectOp>(regField, dataField);
238  });
239  }
240  });
241  }
242 
243  void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
244  Value enable, Value maskBits, Value wdataIn,
245  Value rdataOut, Value wmode, Value regOfVec,
246  ImplicitLocOpBuilder &builder) {
247 
248  // Add pipeline stages to respect the write latency. Intermediate registers
249  // for each stage. Number of pipeline stages, max of read/write latency.
250  auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
251  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
252  enable = addPipelineStages(builder, numStages, clock, enable, "en");
253  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
254  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
255 
256  // Read the register[address] into a temporary.
257  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
258 
259  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
260  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
261  builder)) {
262  wdataIn.getDefiningOp()->emitOpError(
263  "Cannot convert memory to bank of registers");
264  return;
265  }
266  // Initialize read data out with invalid.
267  builder.create<MatchingConnectOp>(
268  rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
269  // If enable:
270  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
271  // If write mode:
272  builder.create<WhenOp>(
273  wmode, true,
274  // Write block:
275  [&]() {
276  // For each data field. Only one field if not aggregate.
277  for (auto regDataMask : loweredRegDataMaskFields) {
278  auto regField = std::get<0>(regDataMask);
279  auto dataField = std::get<1>(regDataMask);
280  auto maskField = std::get<2>(regDataMask);
281  // If mask true, then set the field.
282  builder.create<WhenOp>(
283  maskField, /*withElseRegion*/ false, [&]() {
284  builder.create<MatchingConnectOp>(regField, dataField);
285  });
286  }
287  },
288  // Read block:
289  [&]() { builder.create<MatchingConnectOp>(rdataOut, rdata); });
290  });
291  }
292 
293  // Generate individual field accesses for an aggregate type. Returns false
294  // on failure, which can happen if invalid fields are present or if the mask
295  // and input types do not match. The assumption is that \p reg and \p input
296  // have exactly the same type, and that \p mask has the same bundle fields
297  // but with each field of type UInt<1>. Populates \p results with one entry
298  // per leaf field: e.g. the first entry is the access to the first field of
299  // \p reg, the first field of \p input and the first field of \p mask.
 //
 // NOTE(review): embedded line 324 is missing from this listing, so the
 // `.Case<...>` chain starting at line 325 has no visible head expression.
 // Presumably it was a `return` of a type switch over `inType` - confirm
 // against the upstream file.
300  bool getFields(Value reg, Value input, Value mask,
301  SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
302  ImplicitLocOpBuilder &builder) {
303 
304  // Check if the number of fields of mask and input type match.
 // Only the element count at the current nesting level is compared. A
 // non-aggregate input accepts any mask; an aggregate input whose mask is
 // not the same kind of aggregate falls through to `return false`.
305  auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
306  if (auto bundle = type_dyn_cast<BundleType>(inType)) {
307  if (auto mBundle = type_dyn_cast<BundleType>(maskType))
308  return mBundle.getNumElements() == bundle.getNumElements();
309  } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
310  if (auto mVec = type_dyn_cast<FVectorType>(maskType))
311  return mVec.getNumElements() == vec.getNumElements();
312  } else
313  return true;
314  return false;
315  };
316 
 // Recursive walk over the type of `input`; declared as a std::function so
 // that the lambda can call itself. The parameters shadow the outer ones.
317  std::function<bool(Value, Value, Value)> flatAccess =
318  [&](Value reg, Value input, Value mask) -> bool {
319  FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
320  if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
321  input.getDefiningOp()->emitOpError("Mask type is not valid");
322  return false;
323  }
 // Bundle: create a subfield access per element index on all three values
 // and recurse; stop at the first failure.
325  .Case<BundleType>([&](BundleType bundle) {
326  for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
327  auto regField = builder.create<SubfieldOp>(reg, i);
328  auto inputField = builder.create<SubfieldOp>(input, i);
329  auto maskField = builder.create<SubfieldOp>(mask, i);
330  if (!flatAccess(regField, inputField, maskField))
331  return false;
332  }
333  return true;
334  })
 // Vector: same as the bundle case, using subindex accesses.
335  .Case<FVectorType>([&](auto vector) {
336  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
337  auto regField = builder.create<SubindexOp>(reg, i);
338  auto inputField = builder.create<SubindexOp>(input, i);
339  auto maskField = builder.create<SubindexOp>(mask, i);
340  if (!flatAccess(regField, inputField, maskField))
341  return false;
342  }
343  return true;
344  })
 // Integer leaf: record the triple. The entry is appended even when the
 // width is unknown, in which case the walk reports failure.
345  .Case<IntType>([&](auto iType) {
346  results.push_back({reg, input, mask});
347  return iType.getWidth().has_value();
348  })
 // Any other type cannot be lowered.
349  .Default([&](auto) { return false; });
350  };
351  if (flatAccess(reg, input, mask))
352  return true;
353  return false;
354  }
355 
356  void scatterMemTapAnno(RegOp op, ArrayAttr attr,
357  ImplicitLocOpBuilder &builder) {
358  AnnotationSet annos(attr);
359  SmallVector<Attribute> regAnnotations;
360  auto vecType = type_cast<FVectorType>(op.getResult().getType());
361  for (auto anno : annos) {
362  if (anno.isClass(memTapSourceClass)) {
363  for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
364  .getNumElements();
365  i != e; ++i) {
366  NamedAttrList newAnno;
367  newAnno.append("class", anno.getMember("class"));
368  newAnno.append("circt.fieldID",
369  builder.getI64IntegerAttr(vecType.getFieldID(i)));
370  newAnno.append("id", anno.getMember("id"));
371  if (auto nla = anno.getMember("circt.nonlocal"))
372  newAnno.append("circt.nonlocal", nla);
373  newAnno.append(
374  "portID",
375  IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
376 
377  regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
378  }
379  } else
380  regAnnotations.push_back(anno.getAttr());
381  }
382  op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
383  }
384 
385  /// Generate the logic for implementing the memory using Registers.
386  void generateMemory(MemOp memOp, FirMemory &firMem) {
387  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
388  auto dataType = memOp.getDataType();
389 
390  auto innerSym = memOp.getInnerSym();
391  SmallVector<Value> debugPorts;
392 
393  RegOp regOfVec = {};
394  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
395  ++index) {
396  auto result = memOp.getResult(index);
397  if (type_isa<RefType>(result.getType())) {
398  debugPorts.push_back(result);
399  continue;
400  }
401  // Create a temporary wire to replace the memory port. This makes it
402  // simpler to delete the memOp.
403  auto wire = builder.create<WireOp>(
404  result.getType(),
405  (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
406  memOp.getNameKind());
407  result.replaceAllUsesWith(wire.getResult());
408  result = wire.getResult();
409  // Create an access to all the common subfields.
410  auto adr = getAddr(builder, result);
411  auto enb = getEnable(builder, result);
412  auto clk = getClock(builder, result);
413  auto dta = getData(builder, result);
414  // IF the register is not yet created.
415  if (!regOfVec) {
416  // Create the register corresponding to the memory.
417  regOfVec = builder.create<RegOp>(
418  FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());
419 
420  // Copy all the memory annotations.
421  if (!memOp.getAnnotationsAttr().empty())
422  scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
423  if (innerSym)
424  regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
425  }
426  auto portKind = memOp.getPortKind(index);
427  if (portKind == MemOp::PortKind::Read) {
428  generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
429  } else if (portKind == MemOp::PortKind::Write) {
430  auto mask = getMask(builder, result);
431  generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
432  builder);
433  } else {
434  auto wmode = getWmode(builder, result);
435  auto wDta = getData(builder, result, true);
436  auto mask = getMask(builder, result);
437  generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
438  regOfVec.getResult(), builder);
439  }
440  }
441  // If a valid register is created, then replace all the debug port users
442  // with a RefType of the register. The RefType is obtained by using a
443  // RefSend on the register.
444  if (regOfVec)
445  for (auto r : debugPorts)
446  r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
447  }
448 
449 private:
 /// When set, runOnModule() leaves memories that satisfy its
 /// macro-replacement condition untouched instead of lowering them.
450  bool replSeqMem;
 /// When set, generateRead() drives the read data unconditionally and uses
 /// the read enable only to gate the address pipeline registers.
451  bool ignoreReadEnable;
452 };
453 } // end anonymous namespace
454 
455 std::unique_ptr<mlir::Pass>
456 circt::firrtl::createMemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable) {
457  return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
458 }
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic casts.
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:394
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21