Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MemToRegOfVec.cpp
Go to the documentation of this file.
1//===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MemToRegOfVec pass.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "mem-to-reg-of-vec"
24
25namespace circt {
26namespace firrtl {
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
29} // namespace firrtl
30} // namespace circt
31
32using namespace circt;
33using namespace firrtl;
34
35namespace {
36struct MemToRegOfVecPass
37 : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
38 MemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable)
39 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
40
  /// Pass entry point: collect all modules in the effective design and
  /// convert their memories in parallel.
  void runOnOperation() override {
    auto circtOp = getOperation();
    auto &instanceInfo = getAnalysis<InstanceInfo>();

    // NOTE(review): the condition guarding this early exit is not visible in
    // this view of the source (the line appears elided by extraction);
    // presumably it checks for the convert-mem-to-reg-of-vec annotation on
    // the circuit -- confirm against the full file.
      return markAllAnalysesPreserved();

    // Only modules instantiated (transitively) under the design are
    // converted; test harness / verification-only modules are skipped.
    DenseSet<Operation *> dutModuleSet;
    for (auto moduleOp : circtOp.getOps<FModuleOp>())
      if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
        dutModuleSet.insert(moduleOp);

    // Each module is rewritten independently, so the work can run in
    // parallel across modules.
    mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
                          [&](Operation *op) {
                            if (auto mod = dyn_cast<FModuleOp>(op))
                              runOnModule(mod);
                          });
  }
60
61 void runOnModule(FModuleOp mod) {
62
63 mod.getBodyBlock()->walk([&](MemOp memOp) {
64 LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
65
66 auto firMem = memOp.getSummary();
67 // Ignore if the memory is candidate for macro replacement.
68 // The requirements for macro replacement:
69 // 1. read latency and write latency of one.
70 // 2. only one readwrite port or write port.
71 // 3. zero or one read port.
72 // 4. undefined read-under-write behavior.
73 if (replSeqMem &&
74 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
75 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
76 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
77 return;
78
79 generateMemory(memOp, firMem);
80 ++numConvertedMems;
81 memOp.erase();
82 });
83 }
84 Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
85 Value pipeInput, StringRef name, Value gate = {}) {
86 if (!stages)
87 return pipeInput;
88
89 while (stages--) {
90 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
91 if (gate) {
92 b.create<WhenOp>(gate, /*withElseRegion*/ false, [&]() {
93 b.create<MatchingConnectOp>(reg, pipeInput);
94 });
95 } else
96 b.create<MatchingConnectOp>(reg, pipeInput);
97
98 pipeInput = reg;
99 }
100
101 return pipeInput;
102 }
103
104 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
105 return builder.create<SubfieldOp>(bundle, "clk");
106 }
107
108 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
109 return builder.create<SubfieldOp>(bundle, "addr");
110 }
111
112 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
113 return builder.create<SubfieldOp>(bundle, "wmode");
114 }
115
116 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
117 return builder.create<SubfieldOp>(bundle, "en");
118 }
119
120 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
121 auto bType = type_cast<BundleType>(bundle.getType());
122 if (bType.getElement("mask"))
123 return builder.create<SubfieldOp>(bundle, "mask");
124 return builder.create<SubfieldOp>(bundle, "wmask");
125 }
126
127 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
128 bool getWdata = false) {
129 auto bType = type_cast<BundleType>(bundle.getType());
130 if (bType.getElement("data"))
131 return builder.create<SubfieldOp>(bundle, "data");
132 if (bType.getElement("rdata") && !getWdata)
133 return builder.create<SubfieldOp>(bundle, "rdata");
134 return builder.create<SubfieldOp>(bundle, "wdata");
135 }
136
  /// Lower a read port: pipeline enable/address to model the read latency,
  /// then drive the port's data output from register[addr].
  void generateRead(const FirMemory &firMem, Value clock, Value addr,
                    Value enable, Value data, Value regOfVec,
                    ImplicitLocOpBuilder &builder) {
    if (ignoreReadEnable) {
      // If read enable is ignored, then guard the address update with read
      // enable: each latency cycle adds one address register whose update is
      // gated by the *previous* stage's enable (enLast). A deasserted enable
      // thus freezes the address pipeline instead of invalidating the output.
      for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
        auto enLast = enable;
        // The final stage's enable is never consumed, so skip pipelining it.
        if (j < e - 1)
          enable = addPipelineStages(builder, 1, clock, enable, "en");
        addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
      }
    } else {
      // Add pipeline stages to respect the read latency. One register for each
      // latency cycle, on both the enable and the address.
      enable =
          addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
      addr =
          addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
    }

    // Read the register[address] into a temporary.
    Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
    if (!ignoreReadEnable) {
      // Initialize read data out with invalid; it is only driven when the
      // (pipelined) enable is asserted.
      builder.create<MatchingConnectOp>(
          data, builder.create<InvalidValueOp>(data.getType()));
      // If enable is true, then connect the data read from memory register.
      builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
        builder.create<MatchingConnectOp>(data, rdata);
      });
    } else {
      // Ignore read enable signal: unconditionally forward the register data.
      builder.create<MatchingConnectOp>(data, rdata);
    }
  }
173
174 void generateWrite(const FirMemory &firMem, Value clock, Value addr,
175 Value enable, Value maskBits, Value wdataIn,
176 Value regOfVec, ImplicitLocOpBuilder &builder) {
177
178 auto numStages = firMem.writeLatency - 1;
179 // Add pipeline stages to respect the write latency. Intermediate registers
180 // for each stage.
181 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
182 enable = addPipelineStages(builder, numStages, clock, enable, "en");
183 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
184 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
185 // Create the register access.
186 FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
187
188 // The tuple for the access to individual fields of an aggregate data type.
189 // Tuple::<register, data, mask>
190 // The logic:
191 // if (mask)
192 // register = data
193 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
194
195 // Write to each aggregate data field is guarded by the corresponding mask
196 // field. This means we have to generate read and write access for each
197 // individual field of the aggregate type.
198 // There are two options to handle this,
199 // 1. FlattenMemory: cast the aggregate data into a UInt and generate
200 // appropriate mask logic.
201 // 2. Create access for each individual field of the aggregate type.
202 // Here we implement the option 2 using getFields.
203 // getFields, creates an access to each individual field of the data and
204 // mask, and the corresponding field into the register. It populates
205 // the loweredRegDataMaskFields vector.
206 // This is similar to what happens in LowerTypes.
207 //
208 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
209 builder)) {
210 wdataIn.getDefiningOp()->emitOpError(
211 "Cannot convert memory to bank of registers");
212 return;
213 }
214 // If enable:
215 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
216 // For each data field. Only one field if not aggregate.
217 for (auto regDataMask : loweredRegDataMaskFields) {
218 auto regField = std::get<0>(regDataMask);
219 auto dataField = std::get<1>(regDataMask);
220 auto maskField = std::get<2>(regDataMask);
221 // If mask, then update the register field.
222 builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
223 builder.create<MatchingConnectOp>(regField, dataField);
224 });
225 }
226 });
227 }
228
229 void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
230 Value enable, Value maskBits, Value wdataIn,
231 Value rdataOut, Value wmode, Value regOfVec,
232 ImplicitLocOpBuilder &builder) {
233
234 // Add pipeline stages to respect the write latency. Intermediate registers
235 // for each stage. Number of pipeline stages, max of read/write latency.
236 auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
237 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
238 enable = addPipelineStages(builder, numStages, clock, enable, "en");
239 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
240 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
241
242 // Read the register[address] into a temporary.
243 Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
244
245 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
246 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
247 builder)) {
248 wdataIn.getDefiningOp()->emitOpError(
249 "Cannot convert memory to bank of registers");
250 return;
251 }
252 // Initialize read data out with invalid.
253 builder.create<MatchingConnectOp>(
254 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
255 // If enable:
256 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
257 // If write mode:
258 builder.create<WhenOp>(
259 wmode, true,
260 // Write block:
261 [&]() {
262 // For each data field. Only one field if not aggregate.
263 for (auto regDataMask : loweredRegDataMaskFields) {
264 auto regField = std::get<0>(regDataMask);
265 auto dataField = std::get<1>(regDataMask);
266 auto maskField = std::get<2>(regDataMask);
267 // If mask true, then set the field.
268 builder.create<WhenOp>(
269 maskField, /*withElseRegion*/ false, [&]() {
270 builder.create<MatchingConnectOp>(regField, dataField);
271 });
272 }
273 },
274 // Read block:
275 [&]() { builder.create<MatchingConnectOp>(rdataOut, rdata); });
276 });
277 }
278
  // Generate individual field accesses for an aggregate type, populating
  // \p results with <register-field, input-field, mask-field> triples.
  // Returns false on failure, which can happen if invalid fields are
  // present or the mask and input types do not match. The assumption is
  // that \p reg and \p input have exactly the same type, and \p mask has
  // the same bundle fields but each leaf field is of type UInt<1>. For
  // example, the first entry is the access to the first field of \p reg,
  // the first field of \p input, and the first field of \p mask.
  // This mirrors the scalarization performed in LowerTypes.
  bool getFields(Value reg, Value input, Value mask,
                 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
                 ImplicitLocOpBuilder &builder) {

    // Check if the number of fields of mask and input type match at this
    // nesting level; ground (non-aggregate) types trivially match.
    auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
      if (auto bundle = type_dyn_cast<BundleType>(inType)) {
        if (auto mBundle = type_dyn_cast<BundleType>(maskType))
          return mBundle.getNumElements() == bundle.getNumElements();
      } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
        if (auto mVec = type_dyn_cast<FVectorType>(maskType))
          return mVec.getNumElements() == vec.getNumElements();
      } else
        return true;
      return false;
    };

    // Recursively walk bundles and vectors down to integer leaves; each leaf
    // appends one triple to `results`. std::function (not auto) because the
    // lambda calls itself.
    std::function<bool(Value, Value, Value)> flatAccess =
        [&](Value reg, Value input, Value mask) -> bool {
      FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
      if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
        input.getDefiningOp()->emitOpError("Mask type is not valid");
        return false;
      }
      // NOTE(review): the type-switch header dispatching on `inType` is not
      // visible in this view of the source (the line appears elided by
      // extraction) -- confirm against the full file.
          .Case<BundleType>([&](BundleType bundle) {
            // Recurse into every bundle field via SubfieldOp accesses.
            for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
              auto regField = builder.create<SubfieldOp>(reg, i);
              auto inputField = builder.create<SubfieldOp>(input, i);
              auto maskField = builder.create<SubfieldOp>(mask, i);
              if (!flatAccess(regField, inputField, maskField))
                return false;
            }
            return true;
          })
          .Case<FVectorType>([&](auto vector) {
            // Recurse into every vector element via SubindexOp accesses.
            for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
              auto regField = builder.create<SubindexOp>(reg, i);
              auto inputField = builder.create<SubindexOp>(input, i);
              auto maskField = builder.create<SubindexOp>(mask, i);
              if (!flatAccess(regField, inputField, maskField))
                return false;
            }
            return true;
          })
          .Case<IntType>([&](auto iType) {
            // Ground leaf: record the triple; fail on uninferred widths.
            results.push_back({reg, input, mask});
            return iType.getWidth().has_value();
          })
          .Default([&](auto) { return false; });
    };
    if (flatAccess(reg, input, mask))
      return true;
    return false;
  }
341
342 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
343 ImplicitLocOpBuilder &builder) {
344 AnnotationSet annos(attr);
345 SmallVector<Attribute> regAnnotations;
346 auto vecType = type_cast<FVectorType>(op.getResult().getType());
347 for (auto anno : annos) {
348 if (anno.isClass(memTapSourceClass)) {
349 for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
350 .getNumElements();
351 i != e; ++i) {
352 NamedAttrList newAnno;
353 newAnno.append("class", anno.getMember("class"));
354 newAnno.append("circt.fieldID",
355 builder.getI64IntegerAttr(vecType.getFieldID(i)));
356 newAnno.append("id", anno.getMember("id"));
357 if (auto nla = anno.getMember("circt.nonlocal"))
358 newAnno.append("circt.nonlocal", nla);
359 newAnno.append(
360 "portID",
361 IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
362
363 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
364 }
365 } else
366 regAnnotations.push_back(anno.getAttr());
367 }
368 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
369 }
370
  /// Generate the logic for implementing the memory using Registers.
  /// A single vector-typed register is created lazily on the first non-debug
  /// port; every port is then lowered into read/write/read-write logic
  /// against that register, and debug (RefType) ports are rewired to a
  /// RefSend of it.
  void generateMemory(MemOp memOp, FirMemory &firMem) {
    ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
    auto dataType = memOp.getDataType();

    auto innerSym = memOp.getInnerSym();
    // RefType results are collected here and rewired after the register
    // exists (see the end of this function).
    SmallVector<Value> debugPorts;

    RegOp regOfVec = {};
    for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
         ++index) {
      auto result = memOp.getResult(index);
      if (type_isa<RefType>(result.getType())) {
        debugPorts.push_back(result);
        continue;
      }
      // Create a temporary wire to replace the memory port. This makes it
      // simpler to delete the memOp.
      auto wire = builder.create<WireOp>(
          result.getType(),
          (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
          memOp.getNameKind());
      result.replaceAllUsesWith(wire.getResult());
      result = wire.getResult();
      // Create an access to all the common subfields shared by every port
      // kind: address, enable, clock, and data.
      auto adr = getAddr(builder, result);
      auto enb = getEnable(builder, result);
      auto clk = getClock(builder, result);
      auto dta = getData(builder, result);
      // IF the register is not yet created (first non-debug port reached).
      if (!regOfVec) {
        // Create the register corresponding to the memory: a vector of
        // `depth` elements of the memory's data type, clocked by this port's
        // clock.
        regOfVec = builder.create<RegOp>(
            FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());

        // Copy all the memory annotations (scattering memory-tap sources
        // into per-element annotations).
        if (!memOp.getAnnotationsAttr().empty())
          scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
        // Carry over the inner symbol so symbol references stay valid.
        if (innerSym)
          regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
      }
      auto portKind = memOp.getPortKind(index);
      if (portKind == MemOp::PortKind::Read) {
        generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
      } else if (portKind == MemOp::PortKind::Write) {
        auto mask = getMask(builder, result);
        generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
                      builder);
      } else {
        // Read-write port: additionally needs wmode, the write data, and the
        // mask; `dta` is the read data ("rdata") here.
        auto wmode = getWmode(builder, result);
        auto wDta = getData(builder, result, true);
        auto mask = getMask(builder, result);
        generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
                          regOfVec.getResult(), builder);
      }
    }
    // If a valid register is created, then replace all the debug port users
    // with a RefType of the register. The RefType is obtained by using a
    // RefSend on the register.
    if (regOfVec)
      for (auto r : debugPorts)
        r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
  }
434
private:
  // When set, memories qualifying for macro replacement are left untouched
  // (see runOnModule).
  bool replSeqMem;
  // When set, read enable is not modeled on the data path; the address
  // pipeline is gated by it instead (see generateRead).
  bool ignoreReadEnable;
};
} // end anonymous namespace
440
441std::unique_ptr<mlir::Pass>
442circt::firrtl::createMemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable) {
443 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
444}
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21