Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
FlattenMemory.cpp
Go to the documentation of this file.
1//===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the FlattenMemory pass.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/Pass/Pass.h"
19#include "llvm/Support/Debug.h"
20#include <numeric>
21
22#define DEBUG_TYPE "lower-memory"
23
24namespace circt {
25namespace firrtl {
26#define GEN_PASS_DEF_FLATTENMEMORY
27#include "circt/Dialect/FIRRTL/Passes.h.inc"
28} // namespace firrtl
29} // namespace circt
30
31using namespace circt;
32using namespace firrtl;
33
34namespace {
35struct FlattenMemoryPass
36 : public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
37
38 /// Returns true if the the memory has annotations on a subfield of any of the
39 /// ports.
40 static bool hasSubAnno(MemOp op) {
41 for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
42 for (auto attr : op.getPortAnnotation(portIdx))
43 if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
44 return true;
45 return false;
46 };
47
48 /// This pass flattens the aggregate data of memory into a UInt, and inserts
49 /// appropriate bitcasts to access the data.
50 void runOnOperation() override {
51 LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
52 << getOperation().getName());
53 SmallVector<Operation *> opsToErase;
54 getOperation().getBodyBlock()->walk([&](MemOp memOp) {
55 LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
56
57 // Cannot flatten a memory if it has debug ports, because debug port
58 // implies a memtap and we cannot transform the datatype for a memory that
59 // is tapped.
60 for (auto res : memOp.getResults())
61 if (isa<RefType>(res.getType()))
62 return;
63
64 // If subannotations present on aggregate fields, we cannot flatten the
65 // memory. It must be split into one memory per aggregate field.
66 // Do not overwrite the pass flag!
67 if (hasSubAnno(memOp))
68 return;
69
70 // The vector of leaf elements type after flattening the data. If any of
71 // the datatypes cannot be flattened, then we cannot flatten the memory.
72 SmallVector<FIRRTLBaseType> flatMemType;
73 if (!flattenType(memOp.getDataType(), flatMemType))
74 return;
75
76 // Calculate the width of the memory data type, and the width of
77 // each individual aggregate leaf elements.
78 size_t memFlatWidth = 0;
79 SmallVector<int32_t> memWidths;
80 for (auto f : flatMemType) {
81 LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
82 auto w = f.getBitWidthOrSentinel();
83 memWidths.push_back(w);
84 memFlatWidth += w;
85 }
86 // If all the widths are zero, ignore the memory.
87 if (!memFlatWidth)
88 return;
89
90 // Calculate the mask granularity of this memory, which is how many bits
91 // of the data each mask bit controls. This is the greatest common
92 // denominator of the widths of the flattened data types.
93 auto maskGran = memWidths.front();
94 for (auto w : ArrayRef(memWidths).drop_front())
95 maskGran = std::gcd(maskGran, w);
96
97 // Total mask bitwidth after flattening.
98 uint32_t totalmaskWidths = 0;
99 // How many mask bits each field type requires.
100 SmallVector<unsigned> maskWidths;
101 for (auto w : memWidths) {
102 // How many mask bits required for each flattened field.
103 auto mWidth = w / maskGran;
104 maskWidths.push_back(mWidth);
105 totalmaskWidths += mWidth;
106 }
107
108 // Now create a new memory of type flattened data.
109 // ----------------------------------------------
110 SmallVector<Type, 8> ports;
111 SmallVector<Attribute, 8> portNames;
112
113 auto *context = memOp.getContext();
114 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
115 // Create a new memory data type of unsigned and computed width.
116 auto flatType = UIntType::get(context, memFlatWidth);
117 for (auto port : memOp.getPorts()) {
118 ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
119 port.second, totalmaskWidths));
120 portNames.push_back(port.first);
121 }
122
123 // Create the new flattened memory.
124 auto flatMem = builder.create<MemOp>(
125 ports, memOp.getReadLatency(), memOp.getWriteLatency(),
126 memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
127 memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
128 memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
129 memOp.getInitAttr(), memOp.getPrefixAttr());
130
131 // Hook up the new memory to the wires the old memory was replaced with.
132 for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
133 ++index) {
134
135 // Create a wire with the original type, and replace all uses of the old
136 // memory with the wire. We will be reconstructing the original type
137 // in the wire from the bitvector of the flattened memory.
138 auto result = memOp.getResult(index);
139 auto wire = builder
140 .create<WireOp>(result.getType(),
141 (memOp.getName() + "_" +
142 memOp.getPortName(index).getValue())
143 .str())
144 .getResult();
145 result.replaceAllUsesWith(wire);
146 result = wire;
147 auto newResult = flatMem.getResult(index);
148 auto rType = type_cast<BundleType>(result.getType());
149 for (size_t fieldIndex = 0, fend = rType.getNumElements();
150 fieldIndex != fend; ++fieldIndex) {
151 auto name = rType.getElement(fieldIndex).name;
152 auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
153 FIRRTLBaseValue newField =
154 builder.create<SubfieldOp>(newResult, fieldIndex);
155 // data and mask depend on the memory type which was split. They can
156 // also go both directions, depending on the port direction.
157 if (!(name == "data" || name == "mask" || name == "wdata" ||
158 name == "wmask" || name == "rdata")) {
159 emitConnect(builder, newField, oldField);
160 continue;
161 }
162 Value realOldField = oldField;
163 if (rType.getElement(fieldIndex).isFlip) {
164 // Cast the memory read data from flat type to aggregate.
165 auto castField =
166 builder.createOrFold<BitCastOp>(oldField.getType(), newField);
167 // Write the aggregate read data.
168 emitConnect(builder, realOldField, castField);
169 } else {
170 // Cast the input aggregate write data to flat type.
171 auto newFieldType = newField.getType();
172 auto oldFieldBitWidth = getBitWidth(oldField.getType());
173 // Following condition is true, if a data field is 0 bits. Then
174 // newFieldType is of smaller bits than old.
175 if (getBitWidth(newFieldType) != *oldFieldBitWidth)
176 newFieldType = UIntType::get(context, *oldFieldBitWidth);
177 realOldField = builder.create<BitCastOp>(newFieldType, oldField);
178 // Mask bits require special handling, since some of the mask bits
179 // need to be repeated, direct bitcasting wouldn't work. Depending
180 // on the mask granularity, some mask bits will be repeated.
181 if ((name == "mask" || name == "wmask") &&
182 (maskWidths.size() != totalmaskWidths)) {
183 Value catMasks;
184 for (const auto &m : llvm::enumerate(maskWidths)) {
185 // Get the mask bit.
186 auto mBit = builder.createOrFold<BitsPrimOp>(
187 realOldField, m.index(), m.index());
188 // Check how many times the mask bit needs to be prepend.
189 for (size_t repeat = 0; repeat < m.value(); repeat++)
190 if ((m.index() == 0 && repeat == 0) || !catMasks)
191 catMasks = mBit;
192 else
193 catMasks = builder.createOrFold<CatPrimOp>(
194 ValueRange{mBit, catMasks});
195 }
196 realOldField = catMasks;
197 }
198 // Now set the mask or write data.
199 // Ensure that the types match.
200 emitConnect(builder, newField,
201 builder.createOrFold<BitCastOp>(newField.getType(),
202 realOldField));
203 }
204 }
205 }
206 ++numFlattenedMems;
207 memOp.erase();
208 return;
209 });
210 }
211
212private:
213 // Convert an aggregate type into a flat list of fields. This is used to
214 // flatten the aggregate memory datatype. Recursively populate the results
215 // with each ground type field.
216 static bool flattenType(FIRRTLType type,
217 SmallVectorImpl<FIRRTLBaseType> &results) {
218 std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
220 .Case<BundleType>([&](auto bundle) {
221 for (auto &elt : bundle)
222 if (!flatten(elt.type))
223 return false;
224 return true;
225 })
226 .Case<FVectorType>([&](auto vector) {
227 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
228 if (!flatten(vector.getElementType()))
229 return false;
230 return true;
231 })
232 .Case<IntType>([&](IntType type) {
233 results.push_back(type);
234 return type.getWidth().has_value();
235 })
236 .Case<FEnumType>([&](FEnumType type) {
237 results.emplace_back(type);
238 return true;
239 })
240 .Default([&](auto) { return false; });
241 };
242 // Return true only if this is an aggregate with more than one element.
243 return flatten(type) && results.size() > 1;
244 }
245
246 Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
247 if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
248 return builder->create<SubfieldOp>(val, index);
249 if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
250 return builder->create<SubindexOp>(val, index);
251
252 llvm_unreachable("Unknown aggregate type");
253 return nullptr;
254 }
255};
256} // end anonymous namespace
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.