InferMemories.cpp
//===- InferMemories.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "arc-infer-memories"

namespace circt {
namespace arc {
#define GEN_PASS_DEF_INFERMEMORIES
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt

using namespace circt;
using namespace arc;

namespace {
struct InferMemoriesPass
    : public arc::impl::InferMemoriesBase<InferMemoriesPass> {
  using InferMemoriesBase::InferMemoriesBase;

  void runOnOperation() override;

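  // Rewrite state: operations queued for deletion, the symbol names of the
  // matched "FIRRTL_Memory" generator schemas, and each generated memory
  // module's parameter dictionary, keyed by module name.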
  SmallVector<Operation *> opsToDelete;
  SmallPtrSet<StringAttr, 2> schemaNames;
  DenseMap<StringAttr, DictionaryAttr> memoryParams;
};
} // namespace

void InferMemoriesPass::runOnOperation() {
  auto module = getOperation();
  opsToDelete.clear();
  schemaNames.clear();
  memoryParams.clear();

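  // Seed a namespace with the module's existing symbols so that the write arcs
  // created below get names that do not collide with anything already present.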
  SymbolCache cache;
  cache.addDefinitions(module);
  Namespace names;
  names.add(cache);

  // Find the matching generator schemas.
  for (auto schemaOp : module.getOps<hw::HWGeneratorSchemaOp>()) {
    if (schemaOp.getDescriptor() == "FIRRTL_Memory") {
      schemaNames.insert(schemaOp.getSymNameAttr());
      opsToDelete.push_back(schemaOp);
    }
  }
  LLVM_DEBUG(llvm::dbgs() << "Found " << schemaNames.size() << " schemas\n");

  // Find generated ops using these schemas.
  for (auto genOp : module.getOps<hw::HWModuleGeneratedOp>()) {
    if (!schemaNames.contains(genOp.getGeneratorKindAttr().getAttr()))
      continue;
    memoryParams[genOp.getModuleNameAttr()] = genOp->getAttrDictionary();
    opsToDelete.push_back(genOp);
  }
  LLVM_DEBUG(llvm::dbgs() << "Found " << memoryParams.size()
                          << " memory modules\n");

  // Convert instances of the generated ops into dedicated memories.
  unsigned numReplaced = 0;
  module.walk([&](hw::InstanceOp instOp) {
    auto it = memoryParams.find(instOp.getModuleNameAttr().getAttr());
    if (it == memoryParams.end())
      return;
    ++numReplaced;
    DictionaryAttr params = it->second;
    auto width = params.getAs<IntegerAttr>("width").getValue().getZExtValue();
    auto depth = params.getAs<IntegerAttr>("depth").getValue().getZExtValue();
    auto maskGranAttr = params.getAs<IntegerAttr>("maskGran");
    auto maskGran =
        maskGranAttr ? maskGranAttr.getValue().getZExtValue() : width;
    auto maskBits = width / maskGran;

    auto writeLatency =
        params.getAs<IntegerAttr>("writeLatency").getValue().getZExtValue();
    auto readLatency =
        params.getAs<IntegerAttr>("readLatency").getValue().getZExtValue();
    if (writeLatency != 1) {
      instOp.emitError("unsupported memory write latency ") << writeLatency;
      return signalPassFailure();
    }

    // FIRRTL memories are currently underspecified. They have a single read
    // latency, but it is unclear where within this latency the read of the
    // underlying storage happens. The `HWMemSimImpl` pass implements memories
    // such that the storage is probed at the very end of the latency, and that
    // the probed value becomes available immediately. We keep the latencies
    // configurable here, in the hopes that we'll improve our memory
    // abstractions at some point.
    unsigned readPreLatency = readLatency; // cycles before storage is read
    unsigned readPostLatency = 0; // cycles to move read value to output

    ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
    auto wordType = builder.getIntegerType(width);
    auto addressTy = dyn_cast<IntegerType>(instOp.getOperand(0).getType());
    if (!addressTy) {
      instOp.emitError("expected integer type for memory addressing, got ")
          << instOp.getOperand(0).getType();
      return signalPassFailure();
    }
    auto memType = MemoryType::get(&getContext(), depth, wordType, addressTy);
    auto memOp = builder.create<MemoryOp>(memType);
    if (tapMemories && !instOp.getInstanceName().empty())
      memOp->setAttr("name", instOp.getInstanceNameAttr());

    unsigned argIdx = 0;
    unsigned resultIdx = 0;

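    // Delay `data` by `latency` cycles by chaining anonymous registers clocked
    // by `clock`.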
    auto applyLatency = [&](Value clock, Value data, unsigned latency) {
      for (unsigned i = 0; i < latency; ++i)
        data = builder.create<seq::CompRegOp>(
            data, clock, builder.getStringAttr(""), Value{}, Value{}, Value{},
            hw::InnerSymAttr{});
      return data;
    };

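    // Write ports are collected while visiting the instance's ports and
    // materialized at the end, as tuples of (memory, clock, operands,
    // hasEnable, hasMask).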
    SmallVector<std::tuple<Value, Value, SmallVector<Value>, bool, bool>>
        writePorts;

    // Use `<inst-name>/` as the prefix for all port taps.
    SmallString<64> tapPrefix(instOp.getInstanceName());
    if (!tapPrefix.empty())
      tapPrefix.push_back('/');
    auto tapPrefixBaseLen = tapPrefix.size();

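    // Emit an `arc.tap` that exposes `value` under a name of the form
    // "<inst-name>/<port>_<signal>".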
    auto tap = [&](Value value, const Twine &name) {
      auto prefixedName = builder.getStringAttr(tapPrefix + "_" + name);
      builder.create<arc::TapOp>(value, prefixedName);
    };

    // Handle read ports.
    auto numReadPorts =
        params.getAs<IntegerAttr>("numReadPorts").getValue().getZExtValue();
    for (unsigned portIdx = 0; portIdx != numReadPorts; ++portIdx) {
      auto address = instOp.getOperand(argIdx++);
      auto enable = instOp.getOperand(argIdx++);
      auto clock = instOp.getOperand(argIdx++);
      auto data = instOp.getResult(resultIdx++);

      if (address.getType() != addressTy) {
        instOp.emitOpError("expected ")
            << addressTy << ", but got " << address.getType();
        return signalPassFailure();
      }

      // Add port taps.
      if (tapPorts) {
        tapPrefix.resize(tapPrefixBaseLen);
        (Twine("R") + Twine(portIdx)).toVector(tapPrefix);
        tap(address, "addr");
        tap(enable, "en");
        tap(data, "data");
      }

      // Apply the latency before the underlying storage is accessed.
      address = applyLatency(clock, address, readPreLatency);
      enable = applyLatency(clock, enable, readPreLatency);

      // Read the underlying storage. (The result of a disabled read port is
      // undefined, currently we define it to be zero.)
      Value readOp = builder.create<MemoryReadPortOp>(wordType, memOp, address);
      Value zero = builder.create<hw::ConstantOp>(wordType, 0);
      readOp = builder.create<comb::MuxOp>(enable, readOp, zero);

      // Apply the latency after the underlying storage was accessed. (If the
      // latency is 0, the memory read is combinatorial without any buffer.)
      readOp = applyLatency(clock, readOp, readPostLatency);
      data.replaceAllUsesWith(readOp);
    }

    // Handle read-write ports.
    auto numReadWritePorts = params.getAs<IntegerAttr>("numReadWritePorts")
                                 .getValue()
                                 .getZExtValue();
    for (unsigned portIdx = 0; portIdx != numReadWritePorts; ++portIdx) {
      auto address = instOp.getOperand(argIdx++);
      auto enable = instOp.getOperand(argIdx++);
      auto clock = instOp.getOperand(argIdx++);
      auto writeMode = instOp.getOperand(argIdx++);
      auto writeData = instOp.getOperand(argIdx++);
      auto writeMask = maskBits > 1 ? instOp.getOperand(argIdx++) : Value{};
      auto readData = instOp.getResult(resultIdx++);

      if (address.getType() != addressTy) {
        instOp.emitOpError("expected ")
            << addressTy << ", but got " << address.getType();
        return signalPassFailure();
      }

      // Add port taps.
      if (tapPorts) {
        tapPrefix.resize(tapPrefixBaseLen);
        (Twine("RW") + Twine(portIdx)).toVector(tapPrefix);
        tap(address, "addr");
        tap(enable, "en");
        tap(writeMode, "wmode");
        tap(writeData, "wdata");
        if (writeMask)
          tap(writeMask, "wmask");
        tap(readData, "rdata");
      }

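      // A read-write port only reads the storage when it is enabled and not in
      // write mode.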
      auto c1_i1 = builder.create<hw::ConstantOp>(builder.getI1Type(), 1);
      auto notWriteMode = builder.create<comb::XorOp>(writeMode, c1_i1);
      Value readEnable = builder.create<comb::AndOp>(enable, notWriteMode);

      // Apply the latency before the underlying storage is accessed.
      Value readAddress = applyLatency(clock, address, readPreLatency);
      readEnable = applyLatency(clock, readEnable, readPreLatency);

      // Read the underlying storage. (The result of a disabled read port is
      // undefined, currently we define it to be zero.)
      Value readOp =
          builder.create<MemoryReadPortOp>(wordType, memOp, readAddress);
      Value zero = builder.create<hw::ConstantOp>(wordType, 0);
      readOp = builder.create<comb::MuxOp>(readEnable, readOp, zero);

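      // Expand the per-lane write mask to the full data width by replicating
      // each mask bit over `maskGran` data bits. (`comb.concat` operands are
      // MSB-first, hence the reverse.)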
      if (writeMask) {
        unsigned maskWidth = cast<IntegerType>(writeMask.getType()).getWidth();
        SmallVector<Value> toConcat;
        for (unsigned i = 0; i < maskWidth; ++i) {
          Value bit = builder.create<comb::ExtractOp>(writeMask, i, 1);
          Value replicated = builder.create<comb::ReplicateOp>(bit, maskGran);
          toConcat.push_back(replicated);
        }
        std::reverse(toConcat.begin(), toConcat.end()); // I hate concat
        writeMask =
            builder.create<comb::ConcatOp>(writeData.getType(), toConcat);
      }

      // Apply the latency after the underlying storage was accessed. (If the
      // latency is 0, the memory read is combinatorial without any buffer.)
      readOp = applyLatency(clock, readOp, readPostLatency);
      readData.replaceAllUsesWith(readOp);

      auto writeEnable = builder.create<comb::AndOp>(enable, writeMode);
      SmallVector<Value> inputs({address, writeData, writeEnable});
      if (writeMask)
        inputs.push_back(writeMask);
      writePorts.push_back({memOp, clock, inputs, true, !!writeMask});
    }

    // Handle write ports.
    auto numWritePorts =
        params.getAs<IntegerAttr>("numWritePorts").getValue().getZExtValue();
    for (unsigned portIdx = 0; portIdx != numWritePorts; ++portIdx) {
      auto address = instOp.getOperand(argIdx++);
      auto enable = instOp.getOperand(argIdx++);
      auto clock = instOp.getOperand(argIdx++);
      auto data = instOp.getOperand(argIdx++);
      auto mask = maskBits > 1 ? instOp.getOperand(argIdx++) : Value{};

      if (address.getType() != addressTy) {
        instOp.emitOpError("expected ")
            << addressTy << ", but got " << address.getType();
        return signalPassFailure();
      }

      // Add port taps.
      if (tapPorts) {
        tapPrefix.resize(tapPrefixBaseLen);
        (Twine("W") + Twine(portIdx)).toVector(tapPrefix);
        tap(address, "addr");
        tap(enable, "en");
        tap(data, "data");
        if (mask)
          tap(mask, "mask");
      }

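      // Expand the write mask to the full data width, same as for the
      // read-write ports above.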
      if (mask) {
        unsigned maskWidth = cast<IntegerType>(mask.getType()).getWidth();
        SmallVector<Value> toConcat;
        for (unsigned i = 0; i < maskWidth; ++i) {
          Value bit = builder.create<comb::ExtractOp>(mask, i, 1);
          Value replicated = builder.create<comb::ReplicateOp>(bit, maskGran);
          toConcat.push_back(replicated);
        }
        std::reverse(toConcat.begin(), toConcat.end()); // I hate concat
        mask = builder.create<comb::ConcatOp>(data.getType(), toConcat);
      }
      SmallVector<Value> inputs({address, data});
      if (enable)
        inputs.push_back(enable);
      if (mask)
        inputs.push_back(mask);
      writePorts.push_back({memOp, clock, inputs, !!enable, !!mask});
    }

    // Create the actual write ports with a dependency arc to all read
    // ports.
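    // Each write port references an arc that simply forwards its operands
    // (address, data, and optionally enable and mask) to the memory.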
    for (auto [memOp, clock, inputs, hasEnable, hasMask] : writePorts) {
      auto ipSave = builder.saveInsertionPoint();
      TypeRange types = ValueRange(inputs).getTypes();
      builder.setInsertionPointToStart(module.getBody());
      auto defOp = builder.create<DefineOp>(
          names.newName("mem_write"), builder.getFunctionType(types, types));
      auto &block = defOp.getBody().emplaceBlock();
      auto args = block.addArguments(
          types, SmallVector<Location>(types.size(), builder.getLoc()));
      builder.setInsertionPointToEnd(&block);
      builder.create<arc::OutputOp>(SmallVector<Value>(args));
      builder.restoreInsertionPoint(ipSave);
      builder.create<MemoryWritePortOp>(memOp, defOp.getName(), inputs, clock,
                                        hasEnable, hasMask);
    }

    opsToDelete.push_back(instOp);
  });
  LLVM_DEBUG(llvm::dbgs() << "Inferred " << numReplaced << " memories\n");

  for (auto *op : opsToDelete)
    op->erase();
}

std::unique_ptr<Pass>
arc::createInferMemoriesPass(const InferMemoriesOptions &options) {
  return std::make_unique<InferMemoriesPass>(options);
}
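A minimal usage sketch, not part of the file above: scheduling the pass on an mlir::PassManager via the factory function defined at the end of the file. The `tapPorts` field name on `InferMemoriesOptions` is an assumption, mirroring the pass option of the same name used in the code.

#include "circt/Dialect/Arc/ArcPasses.h"
#include "mlir/Pass/PassManager.h"

void buildExamplePipeline(mlir::PassManager &pm) {
  circt::arc::InferMemoriesOptions options;
  options.tapPorts = true; // assumed field name, see note above
  pm.addPass(circt::arc::createInferMemoriesPass(options));
}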