18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "lower-memory"
28 #define GEN_PASS_DEF_FLATTENMEMORY
29 #include "circt/Dialect/FIRRTL/Passes.h.inc"
33 using namespace circt;
34 using namespace firrtl;
// Pass that replaces each MemOp with a new MemOp whose port types are built
// from `flatType`, and reconnects the original port fields through wires and
// bitcasts.
// NOTE(review): several statements of this definition are not visible in this
// view (e.g. the definition of `flatType`, `maskGran`, `newField`, `wire` and
// `catMasks`); comments below describe only what the visible lines establish
// and hedge the rest.
37 struct FlattenMemoryPass
38 :
public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
// Pass entry point: walks every MemOp in the body block of the operation the
// pass is running on.
41 void runOnOperation()
override {
42 LLVM_DEBUG(llvm::dbgs() <<
"\n Running lower memory on module:"
// No use of this vector is visible here; presumably it collects ops for
// deferred erasure -- TODO confirm.
44 SmallVector<Operation *> opsToErase;
// Predicate over a MemOp: scans the annotations of every port for a dictionary
// entry keyed "circt.fieldID". The statement taken on a match is not visible;
// presumably it returns true.
45 auto hasSubAnno = [&](MemOp op) ->
bool {
46 for (
size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
47 for (
auto attr : op.getPortAnnotation(portIdx))
48 if (cast<DictionaryAttr>(attr).get(
"circt.fieldID"))
// Per-memory rewrite.
53 getOperation().getBodyBlock()->walk([&](MemOp memOp) {
54 LLVM_DEBUG(llvm::dbgs() <<
"\n Memory:" << memOp);
// Filled by flattenType() with the integer leaf types of the memory data type.
56 SmallVector<IntType> flatMemType;
// Sum of all entries of maskWidths; later passed to MemOp::getTypeForPort.
60 uint32_t totalmaskWidths = 0;
// One entry per flattened field: that field's width divided by maskGran.
62 SmallVector<unsigned> maskWidths;
// Tests each result of the memory for a RefType. The action taken is not
// visible; presumably such memories are left untouched -- TODO confirm.
67 for (
auto res : memOp.getResults())
68 if (isa<RefType>(res.getType()))
// Guard for memories with a "circt.fieldID" port annotation, or whose data
// type flattenType() rejects. The guarded statement is not visible (presumably
// an early return from the walk callback).
73 if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
76 SmallVector<Operation *, 8> flatData;
// Bit width of each flattened field, in flatMemType order.
77 SmallVector<int32_t> memWidths;
78 size_t memFlatWidth = 0;
// Record every field's width. getWidth() is dereferenced unchecked here;
// flattenType() reports has_value() for each IntType it collects, which
// presumably makes this safe once the guard above has passed.
80 for (
auto f : flatMemType) {
81 LLVM_DEBUG(llvm::dbgs() <<
"\n field type:" << f);
82 auto w = *f.getWidth();
83 memWidths.push_back(w);
// Mask granularity: seeded with the first field's width, then reduced with
// std::gcd over all field widths.
89 maskGran = memWidths[0];
91 for (
auto w : memWidths) {
92 maskGran = std::gcd(maskGran, w);
// Each field gets (width / maskGran) mask bits; accumulate the total.
94 for (
auto w : memWidths) {
96 auto mWidth = w / maskGran;
97 maskWidths.push_back(mWidth);
98 totalmaskWidths += mWidth;
// Result types and port names for the replacement memory.
102 SmallVector<Type, 8> ports;
103 SmallVector<Attribute, 8> portNames;
105 auto *context = memOp.getContext();
// Builder constructed from the original memory's location and the memory op
// itself.
106 ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
// For every port of the original memory, compute the new port type from the
// depth, flatType, the port's second member and the total mask width, and keep
// the port's first member as its name.
109 auto opPorts = memOp.getPorts();
110 for (
size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
111 auto port = opPorts[portIdx];
112 ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
113 port.second, totalmaskWidths));
114 portNames.push_back(port.first);
// Create the replacement memory. Latencies, depth, RUW, name, name kind,
// annotations, port annotations, inner symbol, init and prefix are all taken
// from the original memOp; only the port types differ.
117 auto flatMem = builder.create<MemOp>(
118 ports, memOp.getReadLatency(), memOp.getWriteLatency(),
119 memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
120 memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
121 memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
122 memOp.getInitAttr(), memOp.getPrefixAttr());
// Rewire each port of the original memory.
124 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
126 auto result = memOp.getResult(index);
// A WireOp with the original port type, named "<memory name>_<port name>",
// takes over all uses of the original result.
128 .create<WireOp>(result.getType(),
129 (memOp.getName() +
"_" +
130 memOp.getPortName(index).getValue())
133 result.replaceAllUsesWith(wire);
// Walk the fields of the port bundle, pairing the field of the old port with
// the field at the same index of the new memory's port.
135 auto newResult = flatMem.getResult(index);
136 auto rType = type_cast<BundleType>(result.getType());
137 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
138 fieldIndex != fend; ++fieldIndex) {
139 auto name = rType.getElement(fieldIndex).name.getValue();
140 auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
142 builder.create<SubfieldOp>(newResult, fieldIndex);
// Fields other than data/mask/wdata/wmask/rdata take this branch. Its body is
// not visible; presumably these fields are connected without any cast.
145 if (!(name ==
"data" || name ==
"mask" || name ==
"wdata" ||
146 name ==
"wmask" || name ==
"rdata")) {
150 Value realOldField = oldField;
// Flipped field: the new memory's field is bitcast to the old field's type.
151 if (rType.getElement(fieldIndex).isFlip) {
154 builder.createOrFold<BitCastOp>(oldField.getType(), newField);
// Otherwise the old field is bitcast towards the new memory's field type.
160 auto newFieldType = newField.getType();
161 auto oldFieldBitWidth =
getBitWidth(oldField.getType());
// Width-mismatch check; the statement it controls is not visible in this view.
164 if (
getBitWidth(newFieldType) != *oldFieldBitWidth)
166 realOldField = builder.create<BitCastOp>(newFieldType, oldField);
// Mask expansion: only when the number of flattened fields differs from the
// total number of mask bits.
170 if ((name ==
"mask" || name ==
"wmask") &&
171 (maskWidths.size() != totalmaskWidths)) {
// For field i, extract bit i of the old mask and concatenate it into catMasks
// once per mask bit assigned to that field (m.value() times).
173 for (
const auto &m : llvm::enumerate(maskWidths)) {
175 auto mBit = builder.createOrFold<BitsPrimOp>(
176 realOldField, m.index(), m.index());
// The very first bit (or an unset catMasks) takes a branch that is not
// visible here -- presumably it seeds catMasks with mBit.
178 for (
size_t repeat = 0; repeat < m.value(); repeat++)
179 if ((m.index() == 0 && repeat == 0) || !catMasks)
182 catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
184 realOldField = catMasks;
// Final bitcast to the new field's type; the remaining arguments and the use
// of the result are not visible in this view.
189 builder.createOrFold<BitCastOp>(newField.getType(),
// Flattens `type` into `results`, appending one IntType per integer leaf.
// Bundles and vectors are traversed recursively through `flatten`; any other
// non-integer type hits the Default case, which yields false.
// NOTE(review): the declaration of `flatten`, the head of the type switch and
// the return statements are not visible in this view. The final condition
// tests that flattening succeeded and produced more than one leaf; presumably
// that is what the function returns -- TODO confirm.
204 static bool flattenType(
FIRRTLType type, SmallVectorImpl<IntType> &results) {
207 .
Case<BundleType>([&](
auto bundle) {
// Recurse into every element of the bundle; the statement taken when an
// element fails to flatten is not visible (presumably `return false`).
208 for (
auto &elt : bundle)
209 if (!flatten(elt.type))
213 .Case<FVectorType>([&](
auto vector) {
// Recurse into the element type once per vector element, so `results` gets
// one set of entries per element.
214 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i)
215 if (!flatten(vector.getElementType()))
219 .Case<IntType>([&](
auto iType) {
// Integer leaf: always recorded; this case yields true only when the type
// reports a width.
220 results.push_back({iType});
221 return iType.getWidth().has_value();
223 .Default([&](
auto) {
return false; });
226 if (flatten(type) && results.size() > 1)
// Creates an op selecting element `index` of the aggregate value `val` using
// `builder`: a SubfieldOp when `val` has a BundleType, a SubindexOp when it
// has an FVectorType. Any other type reaches llvm_unreachable.
231 Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val,
size_t index) {
232 if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
233 return builder->create<SubfieldOp>(val, index);
234 if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
235 return builder->create<SubindexOp>(val, index);
// Callers are expected to pass only bundle- or vector-typed values.
237 llvm_unreachable(
"Unknown aggregate type");
244 return std::make_unique<FlattenMemoryPass>();
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic casting.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.