CIRCT 20.0.0git
Loading...
Searching...
No Matches
MemToRegOfVec.cpp
Go to the documentation of this file.
1//===- MemToRegOfVec.cpp - MemToRegOfVec Pass -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MemToRegOfVec pass.
10//
11//===----------------------------------------------------------------------===//
12
// NOTE(review): the CIRCT project-local includes were dropped by the
// documentation scrape; reconstructed from the identifiers this file uses
// (AnnotationSet, InstanceInfo, FIRRTLTypeSwitch, annotation-class constants).
// Confirm the exact list against the upstream source.
#include "circt/Dialect/FIRRTL/AnnotationDetails.h"
#include "circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
#include "circt/Dialect/FIRRTL/FIRRTLInstanceInfo.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "mem-to-reg-of-vec"
24
25namespace circt {
26namespace firrtl {
27#define GEN_PASS_DEF_MEMTOREGOFVEC
28#include "circt/Dialect/FIRRTL/Passes.h.inc"
29} // namespace firrtl
30} // namespace circt
31
32using namespace circt;
33using namespace firrtl;
34
35namespace {
36struct MemToRegOfVecPass
37 : public circt::firrtl::impl::MemToRegOfVecBase<MemToRegOfVecPass> {
38 MemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable)
39 : replSeqMem(replSeqMem), ignoreReadEnable(ignoreReadEnable){};
40
41 void runOnOperation() override {
42 auto circtOp = getOperation();
43 auto &instanceInfo = getAnalysis<InstanceInfo>();
44
47 return markAllAnalysesPreserved();
48
49 DenseSet<Operation *> dutModuleSet;
50 for (auto moduleOp : circtOp.getOps<FModuleOp>())
51 if (instanceInfo.anyInstanceInEffectiveDesign(moduleOp))
52 dutModuleSet.insert(moduleOp);
53
54 mlir::parallelForEach(circtOp.getContext(), dutModuleSet,
55 [&](Operation *op) {
56 if (auto mod = dyn_cast<FModuleOp>(op))
57 runOnModule(mod);
58 });
59 }
60
61 void runOnModule(FModuleOp mod) {
62
63 mod.getBodyBlock()->walk([&](MemOp memOp) {
64 LLVM_DEBUG(llvm::dbgs() << "\n Memory op:" << memOp);
66 return;
67
68 auto firMem = memOp.getSummary();
69 // Ignore if the memory is candidate for macro replacement.
70 // The requirements for macro replacement:
71 // 1. read latency and write latency of one.
72 // 2. only one readwrite port or write port.
73 // 3. zero or one read port.
74 // 4. undefined read-under-write behavior.
75 if (replSeqMem &&
76 ((firMem.readLatency == 1 && firMem.writeLatency == 1) &&
77 (firMem.numWritePorts + firMem.numReadWritePorts == 1) &&
78 (firMem.numReadPorts <= 1) && firMem.dataWidth > 0))
79 return;
80
81 generateMemory(memOp, firMem);
82 ++numConvertedMems;
83 memOp.erase();
84 });
85 }
86 Value addPipelineStages(ImplicitLocOpBuilder &b, size_t stages, Value clock,
87 Value pipeInput, StringRef name, Value gate = {}) {
88 if (!stages)
89 return pipeInput;
90
91 while (stages--) {
92 auto reg = b.create<RegOp>(pipeInput.getType(), clock, name).getResult();
93 if (gate) {
94 b.create<WhenOp>(gate, /*withElseRegion*/ false, [&]() {
95 b.create<MatchingConnectOp>(reg, pipeInput);
96 });
97 } else
98 b.create<MatchingConnectOp>(reg, pipeInput);
99
100 pipeInput = reg;
101 }
102
103 return pipeInput;
104 }
105
106 Value getClock(ImplicitLocOpBuilder &builder, Value bundle) {
107 return builder.create<SubfieldOp>(bundle, "clk");
108 }
109
110 Value getAddr(ImplicitLocOpBuilder &builder, Value bundle) {
111 return builder.create<SubfieldOp>(bundle, "addr");
112 }
113
114 Value getWmode(ImplicitLocOpBuilder &builder, Value bundle) {
115 return builder.create<SubfieldOp>(bundle, "wmode");
116 }
117
118 Value getEnable(ImplicitLocOpBuilder &builder, Value bundle) {
119 return builder.create<SubfieldOp>(bundle, "en");
120 }
121
122 Value getMask(ImplicitLocOpBuilder &builder, Value bundle) {
123 auto bType = type_cast<BundleType>(bundle.getType());
124 if (bType.getElement("mask"))
125 return builder.create<SubfieldOp>(bundle, "mask");
126 return builder.create<SubfieldOp>(bundle, "wmask");
127 }
128
129 Value getData(ImplicitLocOpBuilder &builder, Value bundle,
130 bool getWdata = false) {
131 auto bType = type_cast<BundleType>(bundle.getType());
132 if (bType.getElement("data"))
133 return builder.create<SubfieldOp>(bundle, "data");
134 if (bType.getElement("rdata") && !getWdata)
135 return builder.create<SubfieldOp>(bundle, "rdata");
136 return builder.create<SubfieldOp>(bundle, "wdata");
137 }
138
139 void generateRead(const FirMemory &firMem, Value clock, Value addr,
140 Value enable, Value data, Value regOfVec,
141 ImplicitLocOpBuilder &builder) {
142 if (ignoreReadEnable) {
143 // If read enable is ignored, then guard the address update with read
144 // enable.
145 for (size_t j = 0, e = firMem.readLatency; j != e; ++j) {
146 auto enLast = enable;
147 if (j < e - 1)
148 enable = addPipelineStages(builder, 1, clock, enable, "en");
149 addr = addPipelineStages(builder, 1, clock, addr, "addr", enLast);
150 }
151 } else {
152 // Add pipeline stages to respect the read latency. One register for each
153 // latency cycle.
154 enable =
155 addPipelineStages(builder, firMem.readLatency, clock, enable, "en");
156 addr =
157 addPipelineStages(builder, firMem.readLatency, clock, addr, "addr");
158 }
159
160 // Read the register[address] into a temporary.
161 Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
162 if (!ignoreReadEnable) {
163 // Initialize read data out with invalid.
164 builder.create<MatchingConnectOp>(
165 data, builder.create<InvalidValueOp>(data.getType()));
166 // If enable is true, then connect the data read from memory register.
167 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
168 builder.create<MatchingConnectOp>(data, rdata);
169 });
170 } else {
171 // Ignore read enable signal.
172 builder.create<MatchingConnectOp>(data, rdata);
173 }
174 }
175
176 void generateWrite(const FirMemory &firMem, Value clock, Value addr,
177 Value enable, Value maskBits, Value wdataIn,
178 Value regOfVec, ImplicitLocOpBuilder &builder) {
179
180 auto numStages = firMem.writeLatency - 1;
181 // Add pipeline stages to respect the write latency. Intermediate registers
182 // for each stage.
183 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
184 enable = addPipelineStages(builder, numStages, clock, enable, "en");
185 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
186 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
187 // Create the register access.
188 FIRRTLBaseValue rdata = builder.create<SubaccessOp>(regOfVec, addr);
189
190 // The tuple for the access to individual fields of an aggregate data type.
191 // Tuple::<register, data, mask>
192 // The logic:
193 // if (mask)
194 // register = data
195 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
196
197 // Write to each aggregate data field is guarded by the corresponding mask
198 // field. This means we have to generate read and write access for each
199 // individual field of the aggregate type.
200 // There are two options to handle this,
201 // 1. FlattenMemory: cast the aggregate data into a UInt and generate
202 // appropriate mask logic.
203 // 2. Create access for each individual field of the aggregate type.
204 // Here we implement the option 2 using getFields.
205 // getFields, creates an access to each individual field of the data and
206 // mask, and the corresponding field into the register. It populates
207 // the loweredRegDataMaskFields vector.
208 // This is similar to what happens in LowerTypes.
209 //
210 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
211 builder)) {
212 wdataIn.getDefiningOp()->emitOpError(
213 "Cannot convert memory to bank of registers");
214 return;
215 }
216 // If enable:
217 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
218 // For each data field. Only one field if not aggregate.
219 for (auto regDataMask : loweredRegDataMaskFields) {
220 auto regField = std::get<0>(regDataMask);
221 auto dataField = std::get<1>(regDataMask);
222 auto maskField = std::get<2>(regDataMask);
223 // If mask, then update the register field.
224 builder.create<WhenOp>(maskField, /*withElseRegion*/ false, [&]() {
225 builder.create<MatchingConnectOp>(regField, dataField);
226 });
227 }
228 });
229 }
230
231 void generateReadWrite(const FirMemory &firMem, Value clock, Value addr,
232 Value enable, Value maskBits, Value wdataIn,
233 Value rdataOut, Value wmode, Value regOfVec,
234 ImplicitLocOpBuilder &builder) {
235
236 // Add pipeline stages to respect the write latency. Intermediate registers
237 // for each stage. Number of pipeline stages, max of read/write latency.
238 auto numStages = std::max(firMem.readLatency, firMem.writeLatency) - 1;
239 addr = addPipelineStages(builder, numStages, clock, addr, "addr");
240 enable = addPipelineStages(builder, numStages, clock, enable, "en");
241 wdataIn = addPipelineStages(builder, numStages, clock, wdataIn, "wdata");
242 maskBits = addPipelineStages(builder, numStages, clock, maskBits, "wmask");
243
244 // Read the register[address] into a temporary.
245 Value rdata = builder.create<SubaccessOp>(regOfVec, addr);
246
247 SmallVector<std::tuple<Value, Value, Value>, 8> loweredRegDataMaskFields;
248 if (!getFields(rdata, wdataIn, maskBits, loweredRegDataMaskFields,
249 builder)) {
250 wdataIn.getDefiningOp()->emitOpError(
251 "Cannot convert memory to bank of registers");
252 return;
253 }
254 // Initialize read data out with invalid.
255 builder.create<MatchingConnectOp>(
256 rdataOut, builder.create<InvalidValueOp>(rdataOut.getType()));
257 // If enable:
258 builder.create<WhenOp>(enable, /*withElseRegion*/ false, [&]() {
259 // If write mode:
260 builder.create<WhenOp>(
261 wmode, true,
262 // Write block:
263 [&]() {
264 // For each data field. Only one field if not aggregate.
265 for (auto regDataMask : loweredRegDataMaskFields) {
266 auto regField = std::get<0>(regDataMask);
267 auto dataField = std::get<1>(regDataMask);
268 auto maskField = std::get<2>(regDataMask);
269 // If mask true, then set the field.
270 builder.create<WhenOp>(
271 maskField, /*withElseRegion*/ false, [&]() {
272 builder.create<MatchingConnectOp>(regField, dataField);
273 });
274 }
275 },
276 // Read block:
277 [&]() { builder.create<MatchingConnectOp>(rdataOut, rdata); });
278 });
279 }
280
281 // Generate individual field accesses for an aggregate type. Return false if
282 // it fails. Which can happen if invalid fields are present of the mask and
283 // input types donot match. The assumption is that, \p reg and \p input have
284 // exactly the same type. And \p mask has the same bundle fields, but each
285 // field is of type UInt<1> So, populate the \p results with each field
286 // access. For example, the first entry should be access to first field of \p
287 // reg, first field of \p input and first field of \p mask.
288 bool getFields(Value reg, Value input, Value mask,
289 SmallVectorImpl<std::tuple<Value, Value, Value>> &results,
290 ImplicitLocOpBuilder &builder) {
291
292 // Check if the number of fields of mask and input type match.
293 auto isValidMask = [&](FIRRTLType inType, FIRRTLType maskType) -> bool {
294 if (auto bundle = type_dyn_cast<BundleType>(inType)) {
295 if (auto mBundle = type_dyn_cast<BundleType>(maskType))
296 return mBundle.getNumElements() == bundle.getNumElements();
297 } else if (auto vec = type_dyn_cast<FVectorType>(inType)) {
298 if (auto mVec = type_dyn_cast<FVectorType>(maskType))
299 return mVec.getNumElements() == vec.getNumElements();
300 } else
301 return true;
302 return false;
303 };
304
305 std::function<bool(Value, Value, Value)> flatAccess =
306 [&](Value reg, Value input, Value mask) -> bool {
307 FIRRTLType inType = type_cast<FIRRTLType>(input.getType());
308 if (!isValidMask(inType, type_cast<FIRRTLType>(mask.getType()))) {
309 input.getDefiningOp()->emitOpError("Mask type is not valid");
310 return false;
311 }
313 .Case<BundleType>([&](BundleType bundle) {
314 for (size_t i = 0, e = bundle.getNumElements(); i != e; ++i) {
315 auto regField = builder.create<SubfieldOp>(reg, i);
316 auto inputField = builder.create<SubfieldOp>(input, i);
317 auto maskField = builder.create<SubfieldOp>(mask, i);
318 if (!flatAccess(regField, inputField, maskField))
319 return false;
320 }
321 return true;
322 })
323 .Case<FVectorType>([&](auto vector) {
324 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
325 auto regField = builder.create<SubindexOp>(reg, i);
326 auto inputField = builder.create<SubindexOp>(input, i);
327 auto maskField = builder.create<SubindexOp>(mask, i);
328 if (!flatAccess(regField, inputField, maskField))
329 return false;
330 }
331 return true;
332 })
333 .Case<IntType>([&](auto iType) {
334 results.push_back({reg, input, mask});
335 return iType.getWidth().has_value();
336 })
337 .Default([&](auto) { return false; });
338 };
339 if (flatAccess(reg, input, mask))
340 return true;
341 return false;
342 }
343
344 void scatterMemTapAnno(RegOp op, ArrayAttr attr,
345 ImplicitLocOpBuilder &builder) {
346 AnnotationSet annos(attr);
347 SmallVector<Attribute> regAnnotations;
348 auto vecType = type_cast<FVectorType>(op.getResult().getType());
349 for (auto anno : annos) {
350 if (anno.isClass(memTapSourceClass)) {
351 for (size_t i = 0, e = type_cast<FVectorType>(op.getResult().getType())
352 .getNumElements();
353 i != e; ++i) {
354 NamedAttrList newAnno;
355 newAnno.append("class", anno.getMember("class"));
356 newAnno.append("circt.fieldID",
357 builder.getI64IntegerAttr(vecType.getFieldID(i)));
358 newAnno.append("id", anno.getMember("id"));
359 if (auto nla = anno.getMember("circt.nonlocal"))
360 newAnno.append("circt.nonlocal", nla);
361 newAnno.append(
362 "portID",
363 IntegerAttr::get(IntegerType::get(builder.getContext(), 64), i));
364
365 regAnnotations.push_back(builder.getDictionaryAttr(newAnno));
366 }
367 } else
368 regAnnotations.push_back(anno.getAttr());
369 }
370 op.setAnnotationsAttr(builder.getArrayAttr(regAnnotations));
371 }
372
373 /// Generate the logic for implementing the memory using Registers.
374 void generateMemory(MemOp memOp, FirMemory &firMem) {
375 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
376 auto dataType = memOp.getDataType();
377
378 auto innerSym = memOp.getInnerSym();
379 SmallVector<Value> debugPorts;
380
381 RegOp regOfVec = {};
382 for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
383 ++index) {
384 auto result = memOp.getResult(index);
385 if (type_isa<RefType>(result.getType())) {
386 debugPorts.push_back(result);
387 continue;
388 }
389 // Create a temporary wire to replace the memory port. This makes it
390 // simpler to delete the memOp.
391 auto wire = builder.create<WireOp>(
392 result.getType(),
393 (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(),
394 memOp.getNameKind());
395 result.replaceAllUsesWith(wire.getResult());
396 result = wire.getResult();
397 // Create an access to all the common subfields.
398 auto adr = getAddr(builder, result);
399 auto enb = getEnable(builder, result);
400 auto clk = getClock(builder, result);
401 auto dta = getData(builder, result);
402 // IF the register is not yet created.
403 if (!regOfVec) {
404 // Create the register corresponding to the memory.
405 regOfVec = builder.create<RegOp>(
406 FVectorType::get(dataType, firMem.depth), clk, memOp.getNameAttr());
407
408 // Copy all the memory annotations.
409 if (!memOp.getAnnotationsAttr().empty())
410 scatterMemTapAnno(regOfVec, memOp.getAnnotationsAttr(), builder);
411 if (innerSym)
412 regOfVec.setInnerSymAttr(memOp.getInnerSymAttr());
413 }
414 auto portKind = memOp.getPortKind(index);
415 if (portKind == MemOp::PortKind::Read) {
416 generateRead(firMem, clk, adr, enb, dta, regOfVec.getResult(), builder);
417 } else if (portKind == MemOp::PortKind::Write) {
418 auto mask = getMask(builder, result);
419 generateWrite(firMem, clk, adr, enb, mask, dta, regOfVec.getResult(),
420 builder);
421 } else {
422 auto wmode = getWmode(builder, result);
423 auto wDta = getData(builder, result, true);
424 auto mask = getMask(builder, result);
425 generateReadWrite(firMem, clk, adr, enb, mask, wDta, dta, wmode,
426 regOfVec.getResult(), builder);
427 }
428 }
429 // If a valid register is created, then replace all the debug port users
430 // with a RefType of the register. The RefType is obtained by using a
431 // RefSend on the register.
432 if (regOfVec)
433 for (auto r : debugPorts)
434 r.replaceAllUsesWith(builder.create<RefSendOp>(regOfVec.getResult()));
435 }
436
437private:
438 bool replSeqMem;
439 bool ignoreReadEnable;
440};
441} // end anonymous namespace
442
443std::unique_ptr<mlir::Pass>
444circt::firrtl::createMemToRegOfVecPass(bool replSeqMem, bool ignoreReadEnable) {
445 return std::make_unique<MemToRegOfVecPass>(replSeqMem, ignoreReadEnable);
446}
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
constexpr const char * excludeMemToRegAnnoClass
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
constexpr const char * convertMemToRegOfVecAnnoClass
constexpr const char * memTapSourceClass
std::unique_ptr< mlir::Pass > createMemToRegOfVecPass(bool replSeqMem=false, bool ignoreReadEnable=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21