CIRCT 23.0.0git
Loading...
Searching...
No Matches
MemToRegOfVec.cpp
Go to the documentation of this file.
1//===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MemToRegOfVec pass.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "mem-to-reg-of-vec"
24
25namespace circt {
26namespace firrtl {
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
29} // namespace firrtl
30} // namespace circt
31
32using namespace circt;
33using namespace firrtl;
34
35namespace {
36struct MemToRegOfVecPass
37 : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
38 using Base::Base;
39
40 void runOnOperation() override {
41 auto circtOp = getOperation();
42 auto &instanceInfo = getAnalysis<InstanceInfo>();
43
45 convertMemToRegOfVecAnnoClass))
46 return markAllAnalysesPreserved();
47
48 DenseSet<Operation *> dutModuleSet;
49 for (auto moduleOp : circtOp.getOps<FModuleOp>())
50 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
51 dutModuleSet.insert(moduleOp);
52
53 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
54 [&](Operation *op) {
55 if (auto mod = dyn_cast<FModuleOp>(op))
56 runOnModule(mod);
57 });
58 }
59
60 void runOnModule(FModuleOp mod) {
61
62 mod.getBodyBlock()->walk([&](MemOp memOp) {
63 LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
64
65 auto firMem = memOp.getSummary();
66 // Ignore if the memory is a sequential memory, i.e., something that is
67 // supposed to be an SRAM. In either possible eventual lowering by later
68 // passes (blackboxing or lowering to a behavioral model) we don't want to
69 // blow this out here as it both breaks expectations about later passes
70 // that may add asynchronous resets (InferResets) or that expect metadata
71 // on SRAMs to not be split up (LowerClasses).
72 if (firMem.isSeqMem())
73 return;
74
75 generateMemory(memOp, firMem);
76 ++numConvertedMems;
77 memOp.erase();
78 });
79 }
80 Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
81 Value pipeInput, StringRef name, Value gate = {}) {
82 if (!stages)
83 return pipeInput;
84
85 while (stages--) {
86 auto reg = RegOp::create(b, pipeInput.getType(), clock, name).getResult();
87 if (gate) {
88 WhenOp::create(b, gate, /*withElseRegion*/ false,
89 [&]() { MatchingConnectOp::create(b, reg, pipeInput); });
90 } else
91 MatchingConnectOp::create(b, reg, pipeInput);
92
93 pipeInput = reg;
94 }
95
96 return pipeInput;
97 }
98
99 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
100 return SubfieldOp::create(builder, bundle, "clk");
101 }
102
103 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
104 return SubfieldOp::create(builder, bundle, "addr");
105 }
106
107 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
108 return SubfieldOp::create(builder, bundle, "wmode");
109 }
110
111 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
112 return SubfieldOp::create(builder, bundle, "en");
113 }
114
115 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
116 auto bType = type_cast<BundleType>(bundle.getType());
117 if (bType.getElement("mask"))
118 return SubfieldOp::create(builder, bundle, "mask");
119 return SubfieldOp::create(builder, bundle, "wmask");
120 }
121
122 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
123 bool getWdata = false) {
124 auto bType = type_cast<BundleType>(bundle.getType());
125 if (bType.getElement("data"))
126 return SubfieldOp::create(builder, bundle, "data");
127 if (bType.getElement("rdata") && !getWdata)
128 return SubfieldOp::create(builder, bundle, "rdata");
129 return SubfieldOp::create(builder, bundle, "wdata");
130 }
131
132 void generateRead(const FirMemory &firMem, Value clock, Value addr,
133 Value enable, Value data, Value regOfVec,
134 ImplicitLocOpBuilder &builder) {
135 if (ignoreReadEnable) {
136 // If read enable is ignored, then guard the address update with read
137 // enable.
138 for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
139 auto enLast = enable;
140 if (j < e - 1)
141 enable = addPipelineStages(builder, 1, clock, enable, "en");
142 addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
143 }
144 } else {
145 // Add pipeline stages to respect the read latency. One register for each
146 // latency cycle.
147 enable =
148 addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
149 addr =
150 addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
151 }
152
153 // Read the register[address] into a temporary.
154 Value rdata = SubaccessOp::create(builder, regOfVec, addr);
155 if (!ignoreReadEnable) {
156 // Initialize read data out with invalid.
157 MatchingConnectOp::create(
158 builder, data, InvalidValueOp::create(builder, data.getType()));
159 // If enable is true, then connect the data read from memory register.
160 WhenOp::create(builder, enable, /*withElseRegion*/ false, [&]() {
161 MatchingConnectOp::create(builder, data, rdata);
162 });
163 } else {
164 // Ignore read enable signal.
165 MatchingConnectOp::create(builder, data, rdata);
166 }
167 }
168
169 void generateWrite(const FirMemory &firMem, Value clock, Value addr,
170 Value enable, Value maskBits, Value wdataIn,
171 Value regOfVec, ImplicitLocOpBuilder &builder) {
172
173 auto numStages = firMem.writeLatency - 1;
174 // Add pipeline stages to respect the write latency. Intermediate registers
175 // for each stage.
176 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
177 enable = addPipelineStages(builder, numStages, clock, enable, "en");
178 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
179 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
180 // Create the register access.
181 FIRRTLBaseValue rdata = SubaccessOp::create(builder, regOfVec, addr);
182
183 // The tuple for the access to individual fields of an aggregate data type.
184 // Tuple::<register, data, mask>
185 // The logic:
186 // if (mask)
187 // register = data
188 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
189
190 // Write to each aggregate data field is guarded by the corresponding mask
191 // field. This means we have to generate read and write access for each
192 // individual field of the aggregate type.
193 // There are two options to handle this,
194 // 1. FlattenMemory: cast the aggregate data into a UInt and generate
195 // appropriate mask logic.
196 // 2. Create access for each individual field of the aggregate type.
197 // Here we implement the option 2 using getFields.
198 // getFields, creates an access to each individual field of the data and
199 // mask, and the corresponding field into the register. It populates
200 // the loweredRegDataMaskFields vector.
201 // This is similar to what happens in LowerTypes.
202 //
203 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
204 builder)) {
205 wdataIn.getDefiningOp()->emitOpError(
206 "Cannot convert memory to bank of registers");
207 return;
208 }
209 // If enable:
210 WhenOp::create(builder, enable, /*withElseRegion*/ false, [&]() {
211 // For each data field. Only one field if not aggregate.
212 for (auto regDataMask : loweredRegDataMaskFields) {
213 auto regField = std::get<0>(regDataMask);
214 auto dataField = std::get<1>(regDataMask);
215 auto maskField = std::get<2>(regDataMask);
216 // If mask, then update the register field.
217 WhenOp::create(builder, maskField, /*withElseRegion*/ false, [&]() {
218 MatchingConnectOp::create(builder, regField, dataField);
219 });
220 }
221 });
222 }
223
224 void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
225 Value enable, Value maskBits, Value wdataIn,
226 Value rdataOut, Value wmode, Value regOfVec,
227 ImplicitLocOpBuilder &builder) {
228
229 // Add pipeline stages to respect the write latency. Intermediate registers
230 // for each stage. Number of pipeline stages, max of read/write latency.
231 auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
232 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
233 enable = addPipelineStages(builder, numStages, clock, enable, "en");
234 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
235 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
236
237 // Read the register[address] into a temporary.
238 Value rdata = SubaccessOp::create(builder, regOfVec, addr);
239
240 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
241 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
242 builder)) {
243 wdataIn.getDefiningOp()->emitOpError(
244 "Cannot convert memory to bank of registers");
245 return;
246 }
247 // Initialize read data out with invalid.
248 MatchingConnectOp::create(
249 builder, rdataOut, InvalidValueOp::create(builder, rdataOut.getType()));
250 // If enable:
251 WhenOp::create(builder, enable, /*withElseRegion*/ false, [&]() {
252 // If write mode:
253 WhenOp::create(
254 builder, wmode, true,
255 // Write block:
256 [&]() {
257 // For each data field. Only one field if not aggregate.
258 for (auto regDataMask : loweredRegDataMaskFields) {
259 auto regField = std::get<0>(regDataMask);
260 auto dataField = std::get<1>(regDataMask);
261 auto maskField = std::get<2>(regDataMask);
262 // If mask true, then set the field.
263 WhenOp::create(
264 builder, maskField, /*withElseRegion*/ false, [&]() {
265 MatchingConnectOp::create(builder, regField, dataField);
266 });
267 }
268 },
269 // Read block:
270 [&]() { MatchingConnectOp::create(builder, rdataOut, rdata); });
271 });
272 }
273
274 // Generate individual field accesses for an aggregate type. Return false if
275 // it fails. Which can happen if invalid fields are present of the mask and
276 // input types donot match. The assumption is that, \p reg and \p input have
277 // exactly the same type. And \p mask has the same bundle fields, but each
278 // field is of type UInt<1> So, populate the \p results with each field
279 // access. For example, the first entry should be access to first field of \p
280 // reg, first field of \p input and first field of \p mask.
281 bool getFields(Value reg, Value input, Value mask,
282 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
283 ImplicitLocOpBuilder &builder) {
284
285 // Check if the number of fields of mask and input type match.
286 auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
287 if (auto bundle = type_dyn_cast<BundleType>(inType)) {
288 if (auto mBundle = type_dyn_cast<BundleType>(maskType))
289 return mBundle.getNumElements() == bundle.getNumElements();
290 } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
291 if (auto mVec = type_dyn_cast<FVectorType>(maskType))
292 return mVec.getNumElements() == vec.getNumElements();
293 } else
294 return true;
295 return false;
296 };
297
298 std::function<bool(Value, Value, Value)> flatAccess =
299 [&](Value reg, Value input, Value mask) -> bool {
300 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
301 if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
302 input.getDefiningOp()->emitOpError("Mask type is not valid");
303 return false;
304 }
306 .Case<BundleType>([&](BundleType bundle) {
307 for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
308 auto regField = SubfieldOp::create(builder, reg, i);
309 auto inputField = SubfieldOp::create(builder, input, i);
310 auto maskField = SubfieldOp::create(builder, mask, i);
311 if (!flatAccess(regField, inputField, maskField))
312 return false;
313 }
314 return true;
315 })
316 .Case<FVectorType>([&](auto vector) {
317 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
318 auto regField = SubindexOp::create(builder, reg, i);
319 auto inputField = SubindexOp::create(builder, input, i);
320 auto maskField = SubindexOp::create(builder, mask, i);
321 if (!flatAccess(regField, inputField, maskField))
322 return false;
323 }
324 return true;
325 })
326 .Case<IntType>([&](auto iType) {
327 results.push_back({reg, input, mask});
328 return iType.getWidth().has_value();
329 })
330 .Default([&](auto) { return false; });
331 };
332 if (flatAccess(reg, input, mask))
333 return true;
334 return false;
335 }
336
337 /// Generate the logic for implementing the memory using Registers.
338 void generateMemory(MemOp memOp, FirMemory &firMem) {
339 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
340 auto dataType = memOp.getDataType();
341
342 auto innerSym = memOp.getInnerSym();
343 SmallVector<Value> debugPorts;
344
345 RegOp regOfVec = {};
346 for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
347 ++index) {
348 auto result = memOp.getResult(index);
349 if (type_isa<RefType>(result.getType())) {
350 debugPorts.push_back(result);
351 continue;
352 }
353 // Create a temporary wire to replace the memory port. This makes it
354 // simpler to delete the memOp.
355 auto wire = WireOp::create(
356 builder, result.getType(),
357 (memOp.getName() + "_" + memOp.getPortName(index)).str(),
358 memOp.getNameKind());
359 result.replaceAllUsesWith(wire.getResult());
360 result = wire.getResult();
361 // Create an access to all the common subfields.
362 auto adr = getAddr(builder, result);
363 auto enb = getEnable(builder, result);
364 auto clk = getClock(builder, result);
365 auto dta = getData(builder, result);
366 // IF the register is not yet created.
367 if (!regOfVec) {
368 // Create the register corresponding to the memory.
369 regOfVec =
370 RegOp::create(builder, FVectorType::get(dataType, firMem.depth),
371 clk, memOp.getNameAttr());
372
373 // Copy all the memory annotations.
374 if (!memOp.getAnnotationsAttr().empty())
375 regOfVec.setAnnotationsAttr(memOp.getAnnotationsAttr());
376 if (innerSym)
377 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
378 }
379 auto portKind = memOp.getPortKind(index);
380 if (portKind == MemOp::PortKind::Read) {
381 generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
382 } else if (portKind == MemOp::PortKind::Write) {
383 auto mask = getMask(builder, result);
384 generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
385 builder);
386 } else {
387 auto wmode = getWmode(builder, result);
388 auto wDta = getData(builder, result, true);
389 auto mask = getMask(builder, result);
390 generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
391 regOfVec.getResult(), builder);
392 }
393 }
394 // If a valid register is created, then replace all the debug port users
395 // with a RefType of the register. The RefType is obtained by using a
396 // RefSend on the register.
397 if (regOfVec)
398 for (auto r : debugPorts)
399 r.replaceAllUsesWith(RefSendOp::create(builder, regOfVec.getResult()));
400 }
401};
402} // end anonymous namespace
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21