Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MemToRegOfVec.cpp
Go to the documentation of this file.
1//===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MemToRegOfVec pass.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "mem-to-reg-of-vec"
24
25namespace circt {
26namespace firrtl {
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
29} // namespace firrtl
30} // namespace circt
31
32using namespace circt;
33using namespace firrtl;
34
35namespace {
36struct MemToRegOfVecPass
37 : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
38 using Base::Base;
39
40 void runOnOperation() override {
41 auto circtOp = getOperation();
42 auto &instanceInfo = getAnalysis<InstanceInfo>();
43
46 return markAllAnalysesPreserved();
47
48 DenseSet<Operation *> dutModuleSet;
49 for (auto moduleOp : circtOp.getOps<FModuleOp>())
50 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
51 dutModuleSet.insert(moduleOp);
52
53 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
54 [&](Operation *op) {
55 if (auto mod = dyn_cast<FModuleOp>(op))
56 runOnModule(mod);
57 });
58 }
59
60 void runOnModule(FModuleOp mod) {
61
62 mod.getBodyBlock()->walk([&](MemOp memOp) {
63 LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
64
65 auto firMem = memOp.getSummary();
66 // Ignore if the memory is candidate for macro replacement.
67 // The requirements for macro replacement:
68 // 1. read latency and write latency of one.
69 // 2. only one readwrite port or write port.
70 // 3. zero or one read port.
71 // 4. undefined read-under-write behavior.
72 if (replSeqMem &&
73 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
74 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
75 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
76 return;
77
78 generateMemory(memOp, firMem);
79 ++numConvertedMems;
80 memOp.erase();
81 });
82 }
83 Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
84 Value pipeInput, StringRef name, Value gate = {}) {
85 if (!stages)
86 return pipeInput;
87
88 while (stages--) {
89 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
90 if (gate) {
91 b.create<WhenOp>(gate, /*withElseRegion*/ false, [&]() {
92 b.create<MatchingConnectOp>(reg, pipeInput);
93 });
94 } else
95 b.create<MatchingConnectOp>(reg, pipeInput);
96
97 pipeInput = reg;
98 }
99
100 return pipeInput;
101 }
102
103 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
104 return builder.create<SubfieldOp>(bundle, "clk");
105 }
106
107 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
108 return builder.create<SubfieldOp>(bundle, "addr");
109 }
110
111 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
112 return builder.create<SubfieldOp>(bundle, "wmode");
113 }
114
115 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
116 return builder.create<SubfieldOp>(bundle, "en");
117 }
118
119 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
120 auto bType = type_cast<BundleType>(bundle.getType());
121 if (bType.getElement("mask"))
122 return builder.create<SubfieldOp>(bundle, "mask");
123 return builder.create<SubfieldOp>(bundle, "wmask");
124 }
125
126 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
127 bool getWdata = false) {
128 auto bType = type_cast<BundleType>(bundle.getType());
129 if (bType.getElement("data"))
130 return builder.create<SubfieldOp>(bundle, "data");
131 if (bType.getElement("rdata") && !getWdata)
132 return builder.create<SubfieldOp>(bundle, "rdata");
133 return builder.create<SubfieldOp>(bundle, "wdata");
134 }
135
136 void generateRead(const FirMemory &firMem, Value clock, Value addr,
137 Value enable, Value data, Value regOfVec,
138 ImplicitLocOpBuilder &builder) {
139 if (ignoreReadEnable) {
140 // If read enable is ignored, then guard the address update with read
141 // enable.
142 for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
143 auto enLast = enable;
144 if (j < e - 1)
145 enable = addPipelineStages(builder, 1, clock, enable, "en");
146 addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
147 }
148 } else {
149 // Add pipeline stages to respect the read latency. One register for each
150 // latency cycle.
151 enable =
152 addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
153 addr =
154 addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
155 }
156
157 // Read the register[address] into a temporary.
158 Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
159 if (!ignoreReadEnable) {
160 // Initialize read data out with invalid.
161 builder.create<MatchingConnectOp>(
162 data, builder.create<InvalidValueOp>(data.getType()));
163 // If enable is true, then connect the data read from memory register.
164 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
165 builder.create<MatchingConnectOp>(data, rdata);
166 });
167 } else {
168 // Ignore read enable signal.
169 builder.create<MatchingConnectOp>(data, rdata);
170 }
171 }
172
173 void generateWrite(const FirMemory &firMem, Value clock, Value addr,
174 Value enable, Value maskBits, Value wdataIn,
175 Value regOfVec, ImplicitLocOpBuilder &builder) {
176
177 auto numStages = firMem.writeLatency - 1;
178 // Add pipeline stages to respect the write latency. Intermediate registers
179 // for each stage.
180 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
181 enable = addPipelineStages(builder, numStages, clock, enable, "en");
182 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
183 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
184 // Create the register access.
185 FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
186
187 // The tuple for the access to individual fields of an aggregate data type.
188 // Tuple::<register, data, mask>
189 // The logic:
190 // if (mask)
191 // register = data
192 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
193
194 // Write to each aggregate data field is guarded by the corresponding mask
195 // field. This means we have to generate read and write access for each
196 // individual field of the aggregate type.
197 // There are two options to handle this,
198 // 1. FlattenMemory: cast the aggregate data into a UInt and generate
199 // appropriate mask logic.
200 // 2. Create access for each individual field of the aggregate type.
201 // Here we implement the option 2 using getFields.
202 // getFields, creates an access to each individual field of the data and
203 // mask, and the corresponding field into the register. It populates
204 // the loweredRegDataMaskFields vector.
205 // This is similar to what happens in LowerTypes.
206 //
207 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
208 builder)) {
209 wdataIn.getDefiningOp()->emitOpError(
210 "Cannot convert memory to bank of registers");
211 return;
212 }
213 // If enable:
214 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
215 // For each data field. Only one field if not aggregate.
216 for (auto regDataMask : loweredRegDataMaskFields) {
217 auto regField = std::get<0>(regDataMask);
218 auto dataField = std::get<1>(regDataMask);
219 auto maskField = std::get<2>(regDataMask);
220 // If mask, then update the register field.
221 builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
222 builder.create<MatchingConnectOp>(regField, dataField);
223 });
224 }
225 });
226 }
227
228 void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
229 Value enable, Value maskBits, Value wdataIn,
230 Value rdataOut, Value wmode, Value regOfVec,
231 ImplicitLocOpBuilder &builder) {
232
233 // Add pipeline stages to respect the write latency. Intermediate registers
234 // for each stage. Number of pipeline stages, max of read/write latency.
235 auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
236 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
237 enable = addPipelineStages(builder, numStages, clock, enable, "en");
238 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
239 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
240
241 // Read the register[address] into a temporary.
242 Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
243
244 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
245 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
246 builder)) {
247 wdataIn.getDefiningOp()->emitOpError(
248 "Cannot convert memory to bank of registers");
249 return;
250 }
251 // Initialize read data out with invalid.
252 builder.create<MatchingConnectOp>(
253 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
254 // If enable:
255 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
256 // If write mode:
257 builder.create<WhenOp>(
258 wmode, true,
259 // Write block:
260 [&]() {
261 // For each data field. Only one field if not aggregate.
262 for (auto regDataMask : loweredRegDataMaskFields) {
263 auto regField = std::get<0>(regDataMask);
264 auto dataField = std::get<1>(regDataMask);
265 auto maskField = std::get<2>(regDataMask);
266 // If mask true, then set the field.
267 builder.create<WhenOp>(
268 maskField, /*withElseRegion*/ false, [&]() {
269 builder.create<MatchingConnectOp>(regField, dataField);
270 });
271 }
272 },
273 // Read block:
274 [&]() { builder.create<MatchingConnectOp>(rdataOut, rdata); });
275 });
276 }
277
278 // Generate individual field accesses for an aggregate type. Return false if
279 // it fails. Which can happen if invalid fields are present of the mask and
280 // input types donot match. The assumption is that, \p reg and \p input have
281 // exactly the same type. And \p mask has the same bundle fields, but each
282 // field is of type UInt<1> So, populate the \p results with each field
283 // access. For example, the first entry should be access to first field of \p
284 // reg, first field of \p input and first field of \p mask.
285 bool getFields(Value reg, Value input, Value mask,
286 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
287 ImplicitLocOpBuilder &builder) {
288
289 // Check if the number of fields of mask and input type match.
290 auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
291 if (auto bundle = type_dyn_cast<BundleType>(inType)) {
292 if (auto mBundle = type_dyn_cast<BundleType>(maskType))
293 return mBundle.getNumElements() == bundle.getNumElements();
294 } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
295 if (auto mVec = type_dyn_cast<FVectorType>(maskType))
296 return mVec.getNumElements() == vec.getNumElements();
297 } else
298 return true;
299 return false;
300 };
301
302 std::function<bool(Value, Value, Value)> flatAccess =
303 [&](Value reg, Value input, Value mask) -> bool {
304 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
305 if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
306 input.getDefiningOp()->emitOpError("Mask type is not valid");
307 return false;
308 }
310 .Case<BundleType>([&](BundleType bundle) {
311 for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
312 auto regField = builder.create<SubfieldOp>(reg, i);
313 auto inputField = builder.create<SubfieldOp>(input, i);
314 auto maskField = builder.create<SubfieldOp>(mask, i);
315 if (!flatAccess(regField, inputField, maskField))
316 return false;
317 }
318 return true;
319 })
320 .Case<FVectorType>([&](auto vector) {
321 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
322 auto regField = builder.create<SubindexOp>(reg, i);
323 auto inputField = builder.create<SubindexOp>(input, i);
324 auto maskField = builder.create<SubindexOp>(mask, i);
325 if (!flatAccess(regField, inputField, maskField))
326 return false;
327 }
328 return true;
329 })
330 .Case<IntType>([&](auto iType) {
331 results.push_back({reg, input, mask});
332 return iType.getWidth().has_value();
333 })
334 .Default([&](auto) { return false; });
335 };
336 if (flatAccess(reg, input, mask))
337 return true;
338 return false;
339 }
340
341 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
342 ImplicitLocOpBuilder &builder) {
343 AnnotationSet annos(attr);
344 SmallVector<Attribute> regAnnotations;
345 auto vecType = type_cast<FVectorType>(op.getResult().getType());
346 for (auto anno : annos) {
347 if (anno.isClass(memTapSourceClass)) {
348 for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
349 .getNumElements();
350 i != e; ++i) {
351 NamedAttrList newAnno;
352 newAnno.append("class", anno.getMember("class"));
353 newAnno.append("circt.fieldID",
354 builder.getI64IntegerAttr(vecType.getFieldID(i)));
355 newAnno.append("id", anno.getMember("id"));
356 if (auto nla = anno.getMember("circt.nonlocal"))
357 newAnno.append("circt.nonlocal", nla);
358 newAnno.append(
359 "portID",
360 IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
361
362 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
363 }
364 } else
365 regAnnotations.push_back(anno.getAttr());
366 }
367 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
368 }
369
370 /// Generate the logic for implementing the memory using Registers.
371 void generateMemory(MemOp memOp, FirMemory &firMem) {
372 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
373 auto dataType = memOp.getDataType();
374
375 auto innerSym = memOp.getInnerSym();
376 SmallVector<Value> debugPorts;
377
378 RegOp regOfVec = {};
379 for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
380 ++index) {
381 auto result = memOp.getResult(index);
382 if (type_isa<RefType>(result.getType())) {
383 debugPorts.push_back(result);
384 continue;
385 }
386 // Create a temporary wire to replace the memory port. This makes it
387 // simpler to delete the memOp.
388 auto wire = builder.create<WireOp>(
389 result.getType(),
390 (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
391 memOp.getNameKind());
392 result.replaceAllUsesWith(wire.getResult());
393 result = wire.getResult();
394 // Create an access to all the common subfields.
395 auto adr = getAddr(builder, result);
396 auto enb = getEnable(builder, result);
397 auto clk = getClock(builder, result);
398 auto dta = getData(builder, result);
399 // IF the register is not yet created.
400 if (!regOfVec) {
401 // Create the register corresponding to the memory.
402 regOfVec = builder.create<RegOp>(
403 FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());
404
405 // Copy all the memory annotations.
406 if (!memOp.getAnnotationsAttr().empty())
407 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
408 if (innerSym)
409 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
410 }
411 auto portKind = memOp.getPortKind(index);
412 if (portKind == MemOp::PortKind::Read) {
413 generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
414 } else if (portKind == MemOp::PortKind::Write) {
415 auto mask = getMask(builder, result);
416 generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
417 builder);
418 } else {
419 auto wmode = getWmode(builder, result);
420 auto wDta = getData(builder, result, true);
421 auto mask = getMask(builder, result);
422 generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
423 regOfVec.getResult(), builder);
424 }
425 }
426 // If a valid register is created, then replace all the debug port users
427 // with a RefType of the register. The RefType is obtained by using a
428 // RefSend on the register.
429 if (regOfVec)
430 for (auto r : debugPorts)
431 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
432 }
433};
434} // end anonymous namespace
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * memTapSourceClass
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21