CIRCT  19.0.0git
MemToRegOfVec.cpp
Go to the documentation of this file.
1 //===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MemToRegOfVec pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/Threading.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "mem-to-reg-of-vec"
28 
29 using namespace circt;
30 using namespace firrtl;
31 
32 namespace {
33 struct MemToRegOfVecPass : public MemToRegOfVecBase<MemToRegOfVecPass> {
34  MemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable)
35  : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
36 
// Pass entry point. Collects a set of module operations into `dutModuleSet`
// and then invokes runOnModule() on every FModuleOp in that set, in parallel.
37  void runOnOperation() override {
38  auto circtOp = getOperation();
39  DenseSet<Operation *> dutModuleSet;
// NOTE(review): the condition guarding this early return (listing lines 40-41)
// is not visible in this excerpt; as shown, the return appears unconditional.
42  return;
43  auto *body = circtOp.getBodyBlock();
44 
45  // Find the device under test and create a set of all modules underneath it.
// NOTE(review): the predicate body (listing line 47) is not visible here, so
// how the device under test is recognised cannot be confirmed from this view.
46  auto it = llvm::find_if(*body, [&](Operation &op) -> bool {
48  });
49  if (it != body->end()) {
// A matching operation was found: walk the instance graph depth-first starting
// at its node and record the module of every node reached.
50  auto &instanceGraph = getAnalysis<InstanceGraph>();
51  auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*it));
52  llvm::for_each(llvm::depth_first(node),
53  [&](igraph::InstanceGraphNode *node) {
54  dutModuleSet.insert(node->getModule());
55  });
56  } else {
// No matching operation: fall back to every FModuleOp in the circuit.
57  auto mods = circtOp.getOps<FModuleOp>();
58  dutModuleSet.insert(mods.begin(), mods.end());
59  }
60 
// Only entries that are FModuleOps are processed; any other operation that
// ended up in the set is skipped by the dyn_cast.
61  mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
62  [&](Operation *op) {
63  if (auto mod = dyn_cast<FModuleOp>(op))
64  runOnModule(mod);
65  });
66  }
67 
// Walk every MemOp in the body of `mod` and replace it with a register-based
// implementation via generateMemory(), unless the memory is skipped by one of
// the checks below. Each converted memory bumps `numConvertedMems` and the
// original MemOp is erased.
68  void runOnModule(FModuleOp mod) {
69 
70  mod.getBodyBlock()->walk([&](MemOp memOp) {
71  LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
// NOTE(review): the condition guarding this return (listing line 72) is not
// visible in this excerpt.
73  return;
74 
75  auto firMem = memOp.getSummary();
76  // Ignore if the memory is candidate for macro replacement.
77  // The requirements for macro replacement:
78  // 1. read latency and write latency of one.
79  // 2. only one readwrite port or write port.
80  // 3. zero or one read port.
81  // 4. undefined read-under-write behavior.
// NOTE(review): the condition below tests `firMem.dataWidth > 0` as its fourth
// term; no read-under-write field is inspected here, so item 4 of the list
// above is not what the visible code checks - confirm which is intended.
// The skip only applies when the pass was constructed with replSeqMem set.
82  if (replSeqMem &&
83  ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
84  (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
85  (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
86  return;
87 
// Build the replacement logic first (it redirects all port uses), then drop
// the now-unused memory operation.
88  generateMemory(memOp, firMem);
89  ++numConvertedMems;
90  memOp.erase();
91  });
92  }
93  Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
94  Value pipeInput, StringRef name, Value gate = {}) {
95  if (!stages)
96  return pipeInput;
97 
98  while (stages--) {
99  auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
100  if (gate) {
101  b.create<WhenOp>(gate, /*withElseRegion*/ false,
102  [&]() { b.create<StrictConnectOp>(reg, pipeInput); });
103  } else
104  b.create<StrictConnectOp>(reg, pipeInput);
105 
106  pipeInput = reg;
107  }
108 
109  return pipeInput;
110  }
111 
112  Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
113  return builder.create<SubfieldOp>(bundle, "clk");
114  }
115 
116  Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
117  return builder.create<SubfieldOp>(bundle, "addr");
118  }
119 
120  Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
121  return builder.create<SubfieldOp>(bundle, "wmode");
122  }
123 
124  Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
125  return builder.create<SubfieldOp>(bundle, "en");
126  }
127 
128  Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
129  auto bType = type_cast<BundleType>(bundle.getType());
130  if (bType.getElement("mask"))
131  return builder.create<SubfieldOp>(bundle, "mask");
132  return builder.create<SubfieldOp>(bundle, "wmask");
133  }
134 
135  Value getData(ImplicitLocOpBuilder &builder, Value bundle,
136  bool getWdata = false) {
137  auto bType = type_cast<BundleType>(bundle.getType());
138  if (bType.getElement("data"))
139  return builder.create<SubfieldOp>(bundle, "data");
140  if (bType.getElement("rdata") && !getWdata)
141  return builder.create<SubfieldOp>(bundle, "rdata");
142  return builder.create<SubfieldOp>(bundle, "wdata");
143  }
144 
145  void generateRead(const FirMemory &firMem, Value clock, Value addr,
146  Value enable, Value data, Value regOfVec,
147  ImplicitLocOpBuilder &builder) {
148  if (ignoreReadEnable) {
149  // If read enable is ignored, then guard the address update with read
150  // enable.
151  for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
152  auto enLast = enable;
153  if (j < e - 1)
154  enable = addPipelineStages(builder, 1, clock, enable, "en");
155  addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
156  }
157  } else {
158  // Add pipeline stages to respect the read latency. One register for each
159  // latency cycle.
160  enable =
161  addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
162  addr =
163  addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
164  }
165 
166  // Read the register[address] into a temporary.
167  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
168  if (!ignoreReadEnable) {
169  // Initialize read data out with invalid.
170  builder.create<StrictConnectOp>(
171  data, builder.create<InvalidValueOp>(data.getType()));
172  // If enable is true, then connect the data read from memory register.
173  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
174  builder.create<StrictConnectOp>(data, rdata);
175  });
176  } else {
177  // Ignore read enable signal.
178  builder.create<StrictConnectOp>(data, rdata);
179  }
180  }
181 
182  void generateWrite(const FirMemory &firMem, Value clock, Value addr,
183  Value enable, Value maskBits, Value wdataIn,
184  Value regOfVec, ImplicitLocOpBuilder &builder) {
185 
186  auto numStages = firMem.writeLatency - 1;
187  // Add pipeline stages to respect the write latency. Intermediate registers
188  // for each stage.
189  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
190  enable = addPipelineStages(builder, numStages, clock, enable, "en");
191  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
192  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
193  // Create the register access.
194  FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
195 
196  // The tuple for the access to individual fields of an aggregate data type.
197  // Tuple::<register, data, mask>
198  // The logic:
199  // if (mask)
200  // register = data
201  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
202 
203  // Write to each aggregate data field is guarded by the corresponding mask
204  // field. This means we have to generate read and write access for each
205  // individual field of the aggregate type.
206  // There are two options to handle this,
207  // 1. FlattenMemory: cast the aggregate data into a UInt and generate
208  // appropriate mask logic.
209  // 2. Create access for each individual field of the aggregate type.
210  // Here we implement the option 2 using getFields.
211  // getFields, creates an access to each individual field of the data and
212  // mask, and the corresponding field into the register. It populates
213  // the loweredRegDataMaskFields vector.
214  // This is similar to what happens in LowerTypes.
215  //
216  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
217  builder)) {
218  wdataIn.getDefiningOp()->emitOpError(
219  "Cannot convert memory to bank of registers");
220  return;
221  }
222  // If enable:
223  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
224  // For each data field. Only one field if not aggregate.
225  for (auto regDataMask : loweredRegDataMaskFields) {
226  auto regField = std::get<0>(regDataMask);
227  auto dataField = std::get<1>(regDataMask);
228  auto maskField = std::get<2>(regDataMask);
229  // If mask, then update the register field.
230  builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
231  builder.create<StrictConnectOp>(regField, dataField);
232  });
233  }
234  });
235  }
236 
237  void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
238  Value enable, Value maskBits, Value wdataIn,
239  Value rdataOut, Value wmode, Value regOfVec,
240  ImplicitLocOpBuilder &builder) {
241 
242  // Add pipeline stages to respect the write latency. Intermediate registers
243  // for each stage. Number of pipeline stages, max of read/write latency.
244  auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
245  addr = addPipelineStages(builder, numStages, clock, addr, "addr");
246  enable = addPipelineStages(builder, numStages, clock, enable, "en");
247  wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
248  maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
249 
250  // Read the register[address] into a temporary.
251  Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
252 
253  SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
254  if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
255  builder)) {
256  wdataIn.getDefiningOp()->emitOpError(
257  "Cannot convert memory to bank of registers");
258  return;
259  }
260  // Initialize read data out with invalid.
261  builder.create<StrictConnectOp>(
262  rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
263  // If enable:
264  builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
265  // If write mode:
266  builder.create<WhenOp>(
267  wmode, true,
268  // Write block:
269  [&]() {
270  // For each data field. Only one field if not aggregate.
271  for (auto regDataMask : loweredRegDataMaskFields) {
272  auto regField = std::get<0>(regDataMask);
273  auto dataField = std::get<1>(regDataMask);
274  auto maskField = std::get<2>(regDataMask);
275  // If mask true, then set the field.
276  builder.create<WhenOp>(
277  maskField, /*withElseRegion*/ false, [&]() {
278  builder.create<StrictConnectOp>(regField, dataField);
279  });
280  }
281  },
282  // Read block:
283  [&]() { builder.create<StrictConnectOp>(rdataOut, rdata); });
284  });
285  }
286 
287  // Generate individual field accesses for an aggregate type. Return false if
288  // it fails, which can happen if invalid fields are present or the mask and
289  // input types do not match. The assumption is that \p reg and \p input have
290  // exactly the same type, and \p mask has the same bundle fields, but each
291  // field is of type UInt<1>. So, populate the \p results with each field
292  // access. For example, the first entry should be access to first field of \p
293  // reg, first field of \p input and first field of \p mask.
294  bool getFields(Value reg, Value input, Value mask,
295  SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
296  ImplicitLocOpBuilder &builder) {
297 
298  // Check if the number of fields of mask and input type match.
// For a bundle (or vector) input, the mask must also be a bundle (or vector)
// with the same element count; any other input type accepts any mask type.
299  auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
300  if (auto bundle = type_dyn_cast<BundleType>(inType)) {
301  if (auto mBundle = type_dyn_cast<BundleType>(maskType))
302  return mBundle.getNumElements() == bundle.getNumElements();
303  } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
304  if (auto mVec = type_dyn_cast<FVectorType>(maskType))
305  return mVec.getNumElements() == vec.getNumElements();
306  } else
307  return true;
308  return false;
309  };
310 
// Recursive helper (hence the std::function): descends through bundles and
// vectors in lock-step on reg/input/mask and records a triple at each integer
// leaf. Returns false as soon as any level fails.
311  std::function<bool(Value, Value, Value)> flatAccess =
312  [&](Value reg, Value input, Value mask) -> bool {
313  FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
314  if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
315  input.getDefiningOp()->emitOpError("Mask type is not valid");
316  return false;
317  }
// NOTE(review): the head of this type-switch chain (listing line 318) is not
// visible in this excerpt; presumably it switches on `inType` and returns the
// result - confirm against the full source.
// Bundle: recurse into each field by index.
319  .Case<BundleType>([&](BundleType bundle) {
320  for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
321  auto regField = builder.create<SubfieldOp>(reg, i);
322  auto inputField = builder.create<SubfieldOp>(input, i);
323  auto maskField = builder.create<SubfieldOp>(mask, i);
324  if (!flatAccess(regField, inputField, maskField))
325  return false;
326  }
327  return true;
328  })
// Vector: recurse into each element by constant index.
329  .Case<FVectorType>([&](auto vector) {
330  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
331  auto regField = builder.create<SubindexOp>(reg, i);
332  auto inputField = builder.create<SubindexOp>(input, i);
333  auto maskField = builder.create<SubindexOp>(mask, i);
334  if (!flatAccess(regField, inputField, maskField))
335  return false;
336  }
337  return true;
338  })
// Integer leaf: the triple is recorded first, then success is reported only if
// the integer type has a known width.
339  .Case<IntType>([&](auto iType) {
340  results.push_back({reg, input, mask});
341  return iType.getWidth().has_value();
342  })
// Any other type is not supported.
343  .Default([&](auto) { return false; });
344  };
345  if (flatAccess(reg, input, mask))
346  return true;
347  return false;
348  }
349 
350  void scatterMemTapAnno(RegOp op, ArrayAttr attr,
351  ImplicitLocOpBuilder &builder) {
352  AnnotationSet annos(attr);
353  SmallVector<Attribute> regAnnotations;
354  auto vecType = type_cast<FVectorType>(op.getResult().getType());
355  for (auto anno : annos) {
356  if (anno.isClass(memTapSourceClass)) {
357  for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
358  .getNumElements();
359  i != e; ++i) {
360  NamedAttrList newAnno;
361  newAnno.append("class", anno.getMember("class"));
362  newAnno.append("circt.fieldID",
363  builder.getI64IntegerAttr(vecType.getFieldID(i)));
364  newAnno.append("id", anno.getMember("id"));
365  if (auto nla = anno.getMember("circt.nonlocal"))
366  newAnno.append("circt.nonlocal", nla);
367  newAnno.append(
368  "portID",
369  IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
370 
371  regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
372  }
373  } else
374  regAnnotations.push_back(anno.getAttr());
375  }
376  op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
377  }
378 
379  /// Generate the logic for implementing the memory using Registers.
380  void generateMemory(MemOp memOp, FirMemory &firMem) {
381  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
382  auto dataType = memOp.getDataType();
383 
384  auto innerSym = memOp.getInnerSym();
385  SmallVector<Value> debugPorts;
386 
387  RegOp regOfVec = {};
388  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
389  ++index) {
390  auto result = memOp.getResult(index);
391  if (type_isa<RefType>(result.getType())) {
392  debugPorts.push_back(result);
393  continue;
394  }
395  // Create a temporary wire to replace the memory port. This makes it
396  // simpler to delete the memOp.
397  auto wire = builder.create<WireOp>(
398  result.getType(),
399  (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
400  memOp.getNameKind());
401  result.replaceAllUsesWith(wire.getResult());
402  result = wire.getResult();
403  // Create an access to all the common subfields.
404  auto adr = getAddr(builder, result);
405  auto enb = getEnable(builder, result);
406  auto clk = getClock(builder, result);
407  auto dta = getData(builder, result);
408  // IF the register is not yet created.
409  if (!regOfVec) {
410  // Create the register corresponding to the memory.
411  regOfVec = builder.create<RegOp>(
412  FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());
413 
414  // Copy all the memory annotations.
415  if (!memOp.getAnnotationsAttr().empty())
416  scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
417  if (innerSym)
418  regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
419  }
420  auto portKind = memOp.getPortKind(index);
421  if (portKind == MemOp::PortKind::Read) {
422  generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
423  } else if (portKind == MemOp::PortKind::Write) {
424  auto mask = getMask(builder, result);
425  generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
426  builder);
427  } else {
428  auto wmode = getWmode(builder, result);
429  auto wDta = getData(builder, result, true);
430  auto mask = getMask(builder, result);
431  generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
432  regOfVec.getResult(), builder);
433  }
434  }
435  // If a valid register is created, then replace all the debug port users
436  // with a RefType of the register. The RefType is obtained by using a
437  // RefSend on the register.
438  if (regOfVec)
439  for (auto r : debugPorts)
440  r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
441  }
442 
443 private:
// When set, runOnModule() leaves a memory untouched if it has read and write
// latency of one, exactly one write or read-write port, at most one read port
// and a non-zero data width.
444  bool replSeqMem;
// When set, generateRead() connects the read data unconditionally and gates
// the address pipeline registers with the read enable; when unset, the read
// data is invalid unless the (delayed) enable is high.
445  bool ignoreReadEnable;
446 };
447 } // end anonymous namespace
448 
449 std::unique_ptr<mlir::Pass>
450 circt::firrtl::createMemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable) {
451  return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
452 }
Builder builder
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:392
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * dutAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20