CIRCT  20.0.0git
InferMemories.cpp
Go to the documentation of this file.
1 //===- InferMemories.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "circt/Dialect/HW/HWOps.h"
15 #include "circt/Support/SymCache.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Pass/Pass.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "arc-infer-memories"
21 
22 namespace circt {
23 namespace arc {
24 #define GEN_PASS_DEF_INFERMEMORIES
25 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 } // namespace arc
27 } // namespace circt
28 
29 using namespace circt;
30 using namespace arc;
31 
32 namespace {
33 struct InferMemoriesPass
34  : public arc::impl::InferMemoriesBase<InferMemoriesPass> {
35  using InferMemoriesBase::InferMemoriesBase;
36 
37  void runOnOperation() override;
38 
39  SmallVector<Operation *> opsToDelete;
40  SmallPtrSet<StringAttr, 2> schemaNames;
41  DenseMap<StringAttr, DictionaryAttr> memoryParams;
42 };
43 } // namespace
44 
45 void InferMemoriesPass::runOnOperation() {
46  auto module = getOperation();
47  opsToDelete.clear();
48  schemaNames.clear();
49  memoryParams.clear();
50 
51  SymbolCache cache;
52  cache.addDefinitions(module);
53  Namespace names;
54  names.add(cache);
55 
56  // Find the matching generator schemas.
57  for (auto schemaOp : module.getOps<hw::HWGeneratorSchemaOp>()) {
58  if (schemaOp.getDescriptor() == "FIRRTL_Memory") {
59  schemaNames.insert(schemaOp.getSymNameAttr());
60  opsToDelete.push_back(schemaOp);
61  }
62  }
63  LLVM_DEBUG(llvm::dbgs() << "Found " << schemaNames.size() << " schemas\n");
64 
65  // Find generated ops using these schemas.
66  for (auto genOp : module.getOps<hw::HWModuleGeneratedOp>()) {
67  if (!schemaNames.contains(genOp.getGeneratorKindAttr().getAttr()))
68  continue;
69  memoryParams[genOp.getModuleNameAttr()] = genOp->getAttrDictionary();
70  opsToDelete.push_back(genOp);
71  }
72  LLVM_DEBUG(llvm::dbgs() << "Found " << memoryParams.size()
73  << " memory modules\n");
74 
75  // Convert instances of the generated ops into dedicated memories.
76  unsigned numReplaced = 0;
77  module.walk([&](hw::InstanceOp instOp) {
78  auto it = memoryParams.find(instOp.getModuleNameAttr().getAttr());
79  if (it == memoryParams.end())
80  return;
81  ++numReplaced;
82  DictionaryAttr params = it->second;
83  auto width = params.getAs<IntegerAttr>("width").getValue().getZExtValue();
84  auto depth = params.getAs<IntegerAttr>("depth").getValue().getZExtValue();
85  auto maskGranAttr = params.getAs<IntegerAttr>("maskGran");
86  auto maskGran =
87  maskGranAttr ? maskGranAttr.getValue().getZExtValue() : width;
88  auto maskBits = width / maskGran;
89 
90  auto writeLatency =
91  params.getAs<IntegerAttr>("writeLatency").getValue().getZExtValue();
92  auto readLatency =
93  params.getAs<IntegerAttr>("readLatency").getValue().getZExtValue();
94  if (writeLatency != 1) {
95  instOp.emitError("unsupported memory write latency ") << writeLatency;
96  return signalPassFailure();
97  }
98 
99  // FIRRTL memories are currently underspecified. They have a single read
100  // latency, but it is unclear where within this latency the read of the
101  // underlying storage happens. The `HWMemSimImpl` pass implements memories
102  // such that the storage is probed at the very end of the latency, and that
103  // the probed value becomes available immediately. We keep the latencies
104  // configurable here, in the hopes that we'll improve our memory
105  // abstractions at some point.
106  unsigned readPreLatency = readLatency; // cycles before storage is read
107  unsigned readPostLatency = 0; // cycles to move read value to output
108 
109  ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
110  auto wordType = builder.getIntegerType(width);
111  auto addressTy = dyn_cast<IntegerType>(instOp.getOperand(0).getType());
112  if (!addressTy) {
113  instOp.emitError("expected integer type for memory addressing, got ")
114  << addressTy;
115  return signalPassFailure();
116  }
117  auto memType = MemoryType::get(&getContext(), depth, wordType, addressTy);
118  auto memOp = builder.create<MemoryOp>(memType);
119  if (tapMemories && !instOp.getInstanceName().empty())
120  memOp->setAttr("name", instOp.getInstanceNameAttr());
121 
122  unsigned argIdx = 0;
123  unsigned resultIdx = 0;
124 
125  auto applyLatency = [&](Value clock, Value data, unsigned latency) {
126  for (unsigned i = 0; i < latency; ++i)
127  data = builder.create<seq::CompRegOp>(
128  data, clock, builder.getStringAttr(""), Value{}, Value{}, Value{},
129  hw::InnerSymAttr{});
130  return data;
131  };
132 
133  SmallVector<std::tuple<Value, Value, SmallVector<Value>, bool, bool>>
134  writePorts;
135 
136  // Use `<inst-name>/` as the prefix for all port taps.
137  SmallString<64> tapPrefix(instOp.getInstanceName());
138  if (!tapPrefix.empty())
139  tapPrefix.push_back('/');
140  auto tapPrefixBaseLen = tapPrefix.size();
141 
142  auto tap = [&](Value value, const Twine &name) {
143  auto prefixedName = builder.getStringAttr(tapPrefix + "_" + name);
144  builder.create<arc::TapOp>(value, prefixedName);
145  };
146 
147  // Handle read ports.
148  auto numReadPorts =
149  params.getAs<IntegerAttr>("numReadPorts").getValue().getZExtValue();
150  for (unsigned portIdx = 0; portIdx != numReadPorts; ++portIdx) {
151  auto address = instOp.getOperand(argIdx++);
152  auto enable = instOp.getOperand(argIdx++);
153  auto clock = instOp.getOperand(argIdx++);
154  auto data = instOp.getResult(resultIdx++);
155 
156  if (address.getType() != addressTy) {
157  instOp.emitOpError("expected ")
158  << addressTy << ", but got " << address.getType();
159  return signalPassFailure();
160  }
161 
162  // Add port taps.
163  if (tapPorts) {
164  tapPrefix.resize(tapPrefixBaseLen);
165  (Twine("R") + Twine(portIdx)).toVector(tapPrefix);
166  tap(address, "addr");
167  tap(enable, "en");
168  tap(data, "data");
169  }
170 
171  // Apply the latency before the underlying storage is accessed.
172  address = applyLatency(clock, address, readPreLatency);
173  enable = applyLatency(clock, enable, readPreLatency);
174 
175  // Read the underlying storage. (The result of a disabled read port is
176  // undefined, currently we define it to be zero.)
177  Value readOp = builder.create<MemoryReadPortOp>(wordType, memOp, address);
178  Value zero = builder.create<hw::ConstantOp>(wordType, 0);
179  readOp = builder.create<comb::MuxOp>(enable, readOp, zero);
180 
181  // Apply the latency after the underlying storage was accessed. (If the
182  // latency is 0, the memory read is combinatorial without any buffer.)
183  readOp = applyLatency(clock, readOp, readPostLatency);
184  data.replaceAllUsesWith(readOp);
185  }
186 
187  // Handle read-write ports.
188  auto numReadWritePorts = params.getAs<IntegerAttr>("numReadWritePorts")
189  .getValue()
190  .getZExtValue();
191  for (unsigned portIdx = 0; portIdx != numReadWritePorts; ++portIdx) {
192  auto address = instOp.getOperand(argIdx++);
193  auto enable = instOp.getOperand(argIdx++);
194  auto clock = instOp.getOperand(argIdx++);
195  auto writeMode = instOp.getOperand(argIdx++);
196  auto writeData = instOp.getOperand(argIdx++);
197  auto writeMask = maskBits > 1 ? instOp.getOperand(argIdx++) : Value{};
198  auto readData = instOp.getResult(resultIdx++);
199 
200  if (address.getType() != addressTy) {
201  instOp.emitOpError("expected ")
202  << addressTy << ", but got " << address.getType();
203  return signalPassFailure();
204  }
205 
206  // Add port taps.
207  if (tapPorts) {
208  tapPrefix.resize(tapPrefixBaseLen);
209  (Twine("RW") + Twine(portIdx)).toVector(tapPrefix);
210  tap(address, "addr");
211  tap(enable, "en");
212  tap(writeMode, "wmode");
213  tap(writeData, "wdata");
214  if (writeMask)
215  tap(writeMask, "wmask");
216  tap(readData, "rdata");
217  }
218 
219  auto c1_i1 = builder.create<hw::ConstantOp>(builder.getI1Type(), 1);
220  auto notWriteMode = builder.create<comb::XorOp>(writeMode, c1_i1);
221  Value readEnable = builder.create<comb::AndOp>(enable, notWriteMode);
222 
223  // Apply the latency before the underlying storage is accessed.
224  Value readAddress = applyLatency(clock, address, readPreLatency);
225  readEnable = applyLatency(clock, readEnable, readPreLatency);
226 
227  // Read the underlying storage. (The result of a disabled read port is
228  // undefined, currently we define it to be zero.)
229  Value readOp =
230  builder.create<MemoryReadPortOp>(wordType, memOp, readAddress);
231  Value zero = builder.create<hw::ConstantOp>(wordType, 0);
232  readOp = builder.create<comb::MuxOp>(readEnable, readOp, zero);
233 
234  if (writeMask) {
235  unsigned maskWidth = cast<IntegerType>(writeMask.getType()).getWidth();
236  SmallVector<Value> toConcat;
237  for (unsigned i = 0; i < maskWidth; ++i) {
238  Value bit = builder.create<comb::ExtractOp>(writeMask, i, 1);
239  Value replicated = builder.create<comb::ReplicateOp>(bit, maskGran);
240  toConcat.push_back(replicated);
241  }
242  std::reverse(toConcat.begin(), toConcat.end()); // I hate concat
243  writeMask =
244  builder.create<comb::ConcatOp>(writeData.getType(), toConcat);
245  }
246 
247  // Apply the latency after the underlying storage was accessed. (If the
248  // latency is 0, the memory read is combinatorial without any buffer.)
249  readOp = applyLatency(clock, readOp, readPostLatency);
250  readData.replaceAllUsesWith(readOp);
251 
252  auto writeEnable = builder.create<comb::AndOp>(enable, writeMode);
253  SmallVector<Value> inputs({address, writeData, writeEnable});
254  if (writeMask)
255  inputs.push_back(writeMask);
256  writePorts.push_back({memOp, clock, inputs, true, !!writeMask});
257  }
258 
259  // Handle write ports.
260  auto numWritePorts =
261  params.getAs<IntegerAttr>("numWritePorts").getValue().getZExtValue();
262  for (unsigned portIdx = 0; portIdx != numWritePorts; ++portIdx) {
263  auto address = instOp.getOperand(argIdx++);
264  auto enable = instOp.getOperand(argIdx++);
265  auto clock = instOp.getOperand(argIdx++);
266  auto data = instOp.getOperand(argIdx++);
267  auto mask = maskBits > 1 ? instOp.getOperand(argIdx++) : Value{};
268 
269  if (address.getType() != addressTy) {
270  instOp.emitOpError("expected ")
271  << addressTy << ", but got " << address.getType();
272  return signalPassFailure();
273  }
274 
275  // Add port taps.
276  if (tapPorts) {
277  tapPrefix.resize(tapPrefixBaseLen);
278  (Twine("W") + Twine(portIdx)).toVector(tapPrefix);
279  tap(address, "addr");
280  tap(enable, "en");
281  tap(data, "data");
282  if (mask)
283  tap(mask, "mask");
284  }
285 
286  if (mask) {
287  unsigned maskWidth = cast<IntegerType>(mask.getType()).getWidth();
288  SmallVector<Value> toConcat;
289  for (unsigned i = 0; i < maskWidth; ++i) {
290  Value bit = builder.create<comb::ExtractOp>(mask, i, 1);
291  Value replicated = builder.create<comb::ReplicateOp>(bit, maskGran);
292  toConcat.push_back(replicated);
293  }
294  std::reverse(toConcat.begin(), toConcat.end()); // I hate concat
295  mask = builder.create<comb::ConcatOp>(data.getType(), toConcat);
296  }
297  SmallVector<Value> inputs({address, data});
298  if (enable)
299  inputs.push_back(enable);
300  if (mask)
301  inputs.push_back(mask);
302  writePorts.push_back({memOp, clock, inputs, !!enable, !!mask});
303  }
304 
305  // Create the actual write ports with a dependency arc to all read
306  // ports.
307  for (auto [memOp, clock, inputs, hasEnable, hasMask] : writePorts) {
308  auto ipSave = builder.saveInsertionPoint();
309  TypeRange types = ValueRange(inputs).getTypes();
310  builder.setInsertionPointToStart(module.getBody());
311  auto defOp = builder.create<DefineOp>(
312  names.newName("mem_write"), builder.getFunctionType(types, types));
313  auto &block = defOp.getBody().emplaceBlock();
314  auto args = block.addArguments(
315  types, SmallVector<Location>(types.size(), builder.getLoc()));
316  builder.setInsertionPointToEnd(&block);
317  builder.create<arc::OutputOp>(SmallVector<Value>(args));
318  builder.restoreInsertionPoint(ipSave);
319  builder.create<MemoryWritePortOp>(memOp, defOp.getName(), inputs, clock,
320  hasEnable, hasMask);
321  }
322 
323  opsToDelete.push_back(instOp);
324  });
325  LLVM_DEBUG(llvm::dbgs() << "Inferred " << numReplaced << " memories\n");
326 
327  for (auto *op : opsToDelete)
328  op->erase();
329 }
330 
331 std::unique_ptr<Pass>
332 arc::createInferMemoriesPass(const InferMemoriesOptions &options) {
333  return std::make_unique<InferMemoriesPass>(options);
334 }
std::map< std::string, WriteChannelPort & > writePorts
int32_t width
Definition: FIRRTL.cpp:36
static std::vector< mlir::Value > toVector(mlir::ValueRange range)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
def create(low_bit, result_type, input=None)
Definition: comb.py:187
def create(data_type, value)
Definition: hw.py:393
std::unique_ptr< mlir::Pass > createInferMemoriesPass(const InferMemoriesOptions &options={})
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21