CIRCT 20.0.0git
FlattenMemory.cpp
//===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the FlattenMemory pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include <numeric>

#define DEBUG_TYPE "lower-memory"

namespace circt {
namespace firrtl {
#define GEN_PASS_DEF_FLATTENMEMORY
#include "circt/Dialect/FIRRTL/Passes.h.inc"
} // namespace firrtl
} // namespace circt

using namespace circt;
using namespace firrtl;

namespace {
struct FlattenMemoryPass
    : public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
  /// This pass flattens a memory's aggregate data into a single UInt, and
  /// inserts appropriate bitcasts to access the data.
  void runOnOperation() override {
    LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
                            << getOperation().getName());
    SmallVector<Operation *> opsToErase;
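    // Check whether any port annotation targets a subfield of a port, i.e.
    // carries a "circt.fieldID" entry.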
    auto hasSubAnno = [&](MemOp op) -> bool {
      for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
        for (auto attr : op.getPortAnnotation(portIdx))
          if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
            return true;

      return false;
    };
    getOperation().getBodyBlock()->walk([&](MemOp memOp) {
      LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
      // The vector of leaf element types after flattening the data.
      SmallVector<IntType> flatMemType;
      // MaskGranularity: how many bits each mask bit controls.
      size_t maskGran = 1;
      // Total mask bitwidth after flattening.
      uint32_t totalmaskWidths = 0;
      // How many mask bits each field type requires.
      SmallVector<unsigned> maskWidths;

      // Cannot flatten a memory if it has debug ports: a debug port implies a
      // memtap, and we cannot transform the data type of a memory that is
      // tapped.
      for (auto res : memOp.getResults())
        if (isa<RefType>(res.getType()))
          return;
      // If subannotations are present on aggregate fields, we cannot flatten
      // the memory; it must instead be split into one memory per aggregate
      // field.
      if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
        return;

      SmallVector<Operation *, 8> flatData;
      SmallVector<int32_t> memWidths;
      size_t memFlatWidth = 0;
      // Get the width of each individual aggregate leaf element.
      for (auto f : flatMemType) {
        LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
        auto w = *f.getWidth();
        memWidths.push_back(w);
        memFlatWidth += w;
      }
      // If all the widths are zero, ignore the memory.
      if (!memFlatWidth)
        return;
      maskGran = memWidths[0];
      // Compute the GCD of all the data bitwidths.
      for (auto w : memWidths) {
        maskGran = std::gcd(maskGran, w);
      }
      for (auto w : memWidths) {
        // How many mask bits are required for each flattened field.
        auto mWidth = w / maskGran;
        maskWidths.push_back(mWidth);
        totalmaskWidths += mWidth;
      }
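      // For example, leaf widths {8, 32, 16} give maskGran = gcd(8, 32, 16)
      // = 8, maskWidths = {1, 4, 2}, and totalmaskWidths = 7.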
      // Now create a new memory with the flattened data type.
      // -----------------------------------------------------
      SmallVector<Type, 8> ports;
      SmallVector<Attribute, 8> portNames;

      auto *context = memOp.getContext();
      ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
      // Create the new memory data type: unsigned, with the computed width.
      auto flatType = UIntType::get(context, memFlatWidth);
      auto opPorts = memOp.getPorts();
      for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
        auto port = opPorts[portIdx];
        ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
                                              port.second, totalmaskWidths));
        portNames.push_back(port.first);
      }

      auto flatMem = builder.create<MemOp>(
          ports, memOp.getReadLatency(), memOp.getWriteLatency(),
          memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
          memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
          memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
          memOp.getInitAttr(), memOp.getPrefixAttr());
      // Hook up the new memory to the wires the old memory was replaced with.
      for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
           ++index) {
        auto result = memOp.getResult(index);
        auto wire = builder
                        .create<WireOp>(result.getType(),
                                        (memOp.getName() + "_" +
                                         memOp.getPortName(index).getValue())
                                            .str())
                        .getResult();
        result.replaceAllUsesWith(wire);
        result = wire;
        auto newResult = flatMem.getResult(index);
        auto rType = type_cast<BundleType>(result.getType());
        for (size_t fieldIndex = 0, fend = rType.getNumElements();
             fieldIndex != fend; ++fieldIndex) {
          auto name = rType.getElement(fieldIndex).name.getValue();
          auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
          FIRRTLBaseValue newField =
              builder.create<SubfieldOp>(newResult, fieldIndex);
          // The data and mask fields depend on the memory data type that was
          // flattened, and they can go in either direction, depending on the
          // port direction.
          if (!(name == "data" || name == "mask" || name == "wdata" ||
                name == "wmask" || name == "rdata")) {
            emitConnect(builder, newField, oldField);
            continue;
          }
          Value realOldField = oldField;
          if (rType.getElement(fieldIndex).isFlip) {
            // Cast the memory read data from the flat type to the aggregate.
            auto castField =
                builder.createOrFold<BitCastOp>(oldField.getType(), newField);
            // Write the aggregate read data.
            emitConnect(builder, realOldField, castField);
          } else {
            // Cast the input aggregate write data to the flat type.
            auto newFieldType = newField.getType();
            auto oldFieldBitWidth = getBitWidth(oldField.getType());
            // The following condition is true if a data field is 0 bits wide;
            // newFieldType is then narrower than the old field type.
            if (getBitWidth(newFieldType) != *oldFieldBitWidth)
              newFieldType = UIntType::get(context, *oldFieldBitWidth);
            realOldField = builder.create<BitCastOp>(newFieldType, oldField);
            // Mask bits require special handling: depending on the mask
            // granularity, some mask bits must be repeated, so a direct
            // bitcast would not work.
            if ((name == "mask" || name == "wmask") &&
                (maskWidths.size() != totalmaskWidths)) {
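              // E.g., with maskWidths = {1, 4, 2}, mask bit 0 is used once,
              // bit 1 is repeated 4 times, and bit 2 is repeated twice,
              // yielding a 7-bit flat mask.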
              Value catMasks;
              for (const auto &m : llvm::enumerate(maskWidths)) {
                // Get the mask bit.
                auto mBit = builder.createOrFold<BitsPrimOp>(
                    realOldField, m.index(), m.index());
                // Check how many times the mask bit needs to be repeated.
                for (size_t repeat = 0; repeat < m.value(); repeat++)
                  if ((m.index() == 0 && repeat == 0) || !catMasks)
                    catMasks = mBit;
                  else
                    catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
              }
              realOldField = catMasks;
            }
            // Now set the mask or write data.
            // Ensure that the types match.
            emitConnect(builder, newField,
                        builder.createOrFold<BitCastOp>(newField.getType(),
                                                        realOldField));
          }
        }
      }
      ++numFlattenedMems;
      memOp.erase();
      return;
    });
  }

private:
  // Convert an aggregate type into a flat list of fields.
  // This is used to flatten the aggregate memory datatype.
  // Recursively populate the results with each ground type field.
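  // For example, {a : UInt<8>, b : UInt<16>[2]} flattens to
  // {UInt<8>, UInt<16>, UInt<16>}; a plain ground type yields a single field
  // and is reported as non-flattenable, leaving the memory untouched.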
  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
    std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
      return FIRRTLTypeSwitch<FIRRTLType, bool>(type)
          .Case<BundleType>([&](auto bundle) {
            for (auto &elt : bundle)
              if (!flatten(elt.type))
                return false;
            return true;
          })
          .Case<FVectorType>([&](auto vector) {
            for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
              if (!flatten(vector.getElementType()))
                return false;
            return true;
          })
          .Case<IntType>([&](auto iType) {
            results.push_back({iType});
            return iType.getWidth().has_value();
          })
          .Default([&](auto) { return false; });
    };
    // Return true only if this is an aggregate with more than one element.
    if (flatten(type) && results.size() > 1)
      return true;
    return false;
  }

  Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
    if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
      return builder->create<SubfieldOp>(val, index);
    if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
      return builder->create<SubindexOp>(val, index);

    llvm_unreachable("Unknown aggregate type");
    return nullptr;
  }
};
} // end anonymous namespace

std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
  return std::make_unique<FlattenMemoryPass>();
}
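A minimal sketch of how this pass might be scheduled from C++. The pipeline nesting (a firrtl.circuit containing firrtl.module ops) and the surrounding setup are assumptions, not something this file prescribes; only createFlattenMemoryPass() comes from the code above.

#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Assumes `context` has the FIRRTL dialect loaded and `module` is a parsed
// builtin.module wrapping a firrtl.circuit.
mlir::LogicalResult runFlattenMemory(mlir::MLIRContext &context,
                                     mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  pm.nest<circt::firrtl::CircuitOp>()
      .nest<circt::firrtl::FModuleOp>()
      .addPass(circt::firrtl::createFlattenMemoryPass());
  return pm.run(module);
}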