19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "lower-memory"
26 using namespace circt;
27 using namespace firrtl;
/// Flattens the aggregate data type of FIRRTL memories (`MemOp`) into a single
/// flat memory. For each eligible memory, the visible code builds a new `MemOp`
/// whose port types come from `MemOp::getTypeForPort(..., flatType, ...)`, and
/// bridges old and new port fields with `BitCastOp`s (plus mask-bit
/// replication when a field needs more than one mask bit).
/// NOTE(review): this text is a documentation-page extract; several original
/// source lines (closing braces, early `return`s, and declarations such as
/// `maskGran` and `flatType`) are not present here.
30 struct FlattenMemoryPass :
public FlattenMemoryBase<FlattenMemoryPass> {
/// Pass entry point: walks every `MemOp` in the body block of the operation
/// this pass runs on and rewrites the ones that can be flattened.
/// NOTE(review): `opsToErase` (and `flatData` further down) are not used in
/// any visible line; confirm whether they are dead locals.
33 void runOnOperation()
override {
34 LLVM_DEBUG(
llvm::dbgs() <<
"\n Running lower memory on module:"
36 SmallVector<Operation *> opsToErase;
// Returns whether any port annotation of `op` carries a "circt.fieldID"
// entry. Memories for which this holds are not flattened (see the
// `hasSubAnno(memOp)` check below). The lambda's `return` statements are
// among the lines missing from this extract.
37 auto hasSubAnno = [&](MemOp op) ->
bool {
38 for (
size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
39 for (
auto attr : op.getPortAnnotation(portIdx))
40 if (cast<DictionaryAttr>(attr).get(
"circt.fieldID"))
// Visit each memory in the module body.
45 getOperation().getBodyBlock()->walk([&](MemOp memOp) {
46 LLVM_DEBUG(
llvm::dbgs() <<
"\n Memory:" << memOp);
// Leaf integer types of the memory's data type, filled by flattenType().
48 SmallVector<IntType> flatMemType;
// Sum of the per-field mask widths; used as the mask width of the new ports.
52 uint32_t totalmaskWidths = 0;
// Per leaf field: number of mask bits (field width / maskGran).
54 SmallVector<unsigned> maskWidths;
// Memories with any RefType result are excluded (the statement guarded by
// this `if` is not visible; presumably an early return).
59 for (
auto res : memOp.getResults())
60 if (isa<RefType>(res.getType()))
// Also excluded: memories with "circt.fieldID" port annotations, and data
// types that flattenType() rejects (it only accepts types with more than one
// integer leaf).
65 if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
68 SmallVector<Operation *, 8> flatData;
// Bit width of each leaf field, and (presumably; the accumulating line is
// not visible) their total.
69 SmallVector<int32_t> memWidths;
70 size_t memFlatWidth = 0;
// `*f.getWidth()` relies on flattenType() having rejected any IntType leaf
// whose width is unknown.
72 for (
auto f : flatMemType) {
73 LLVM_DEBUG(
llvm::dbgs() <<
"\n field type:" << f);
74 auto w = *f.getWidth();
75 memWidths.push_back(w);
// Mask granularity = GCD of all leaf widths. memWidths is non-empty here
// because flattenType() only succeeds with more than one leaf.
// NOTE(review): if every width is 0 the GCD is 0 and `w / maskGran` below
// divides by zero; confirm a guard exists on the lines not shown.
81 maskGran = memWidths[0];
83 for (
auto w : memWidths) {
84 maskGran = std::gcd(maskGran, w);
// Each field needs w / maskGran mask bits in the flat memory.
86 for (
auto w : memWidths) {
88 auto mWidth = w / maskGran;
89 maskWidths.push_back(mWidth);
90 totalmaskWidths += mWidth;
// Port types and names for the replacement memory.
94 SmallVector<Type, 8> ports;
95 SmallVector<Attribute, 8> portNames;
97 auto *context = memOp.getContext();
// New ops get the original memory's location.
98 ImplicitLocOpBuilder
builder(memOp.getLoc(), memOp);
// Same port kinds and names as the original, but with the flat data type and
// the combined mask width. (`flatType` is declared on a line not shown.)
101 auto opPorts = memOp.getPorts();
102 for (
size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
103 auto port = opPorts[portIdx];
104 ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
105 port.second, totalmaskWidths));
106 portNames.push_back(port.first);
// The replacement memory copies every other attribute of the original:
// latencies, depth, ruw, name/name kind, annotations, port annotations,
// inner symbol, init and prefix.
109 auto flatMem =
builder.create<MemOp>(
110 ports, memOp.getReadLatency(), memOp.getWriteLatency(),
111 memOp.getDepth(), memOp.getRuw(),
builder.getArrayAttr(portNames),
112 memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
113 memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
114 memOp.getInitAttr(), memOp.getPrefixAttr());
// For each port: users of the old result are redirected to a new wire named
// "<memory>_<port>", and the port's fields are then connected to the
// corresponding fields of the flat memory's port.
116 for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
118 auto result = memOp.getResult(index);
120 .create<WireOp>(result.getType(),
121 (memOp.getName() +
"_" +
122 memOp.getPortName(index).getValue())
125 result.replaceAllUsesWith(wire);
127 auto newResult = flatMem.getResult(index);
// NOTE(review): whether `result` is rebound to the wire before this point is
// not visible. `type_cast` assumes every port is bundle-typed.
128 auto rType = type_cast<BundleType>(result.getType());
// Pair up each field of the old port with the same field of the new port.
129 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
130 fieldIndex != fend; ++fieldIndex) {
131 auto name = rType.getElement(fieldIndex).name.getValue();
132 auto oldField =
builder.create<SubfieldOp>(result, fieldIndex);
134 builder.create<SubfieldOp>(newResult, fieldIndex);
// Only the data and mask fields change type. The handling of every other
// field is on lines not shown.
137 if (!(name ==
"data" || name ==
"mask" || name ==
"wdata" ||
138 name ==
"wmask" || name ==
"rdata")) {
142 Value realOldField = oldField;
// Flipped fields (presumably read data): the flat value from the new memory
// is bitcast back to the original aggregate type.
143 if (rType.getElement(fieldIndex).isFlip) {
146 builder.createOrFold<BitCastOp>(oldField.getType(), newField);
// Otherwise the old aggregate value is bitcast toward the flat type. The
// body of the width-mismatch `if` below is not visible (presumably it
// adjusts `newFieldType` when the bit widths differ).
152 auto newFieldType = newField.getType();
153 auto oldFieldBitWidth =
getBitWidth(oldField.getType());
156 if (
getBitWidth(newFieldType) != *oldFieldBitWidth)
158 realOldField =
builder.create<BitCastOp>(newFieldType, oldField);
// Mask expansion. Bit i of the bitcast old mask is repeated maskWidths[i]
// times. Each copy becomes the first (more significant) operand of
// CatPrimOp, so copies for later fields end up in higher bit positions.
// The size-vs-sum test is used as "not every field has exactly one mask
// bit".
// NOTE(review): zero-width fields can make the sum equal the count even
// though widths differ (e.g. mask widths 0, 2, 1), which skips this
// expansion. Also confirm that this bit ordering matches the layout
// produced by the data bitcast.
162 if ((name ==
"mask" || name ==
"wmask") &&
163 (maskWidths.size() != totalmaskWidths)) {
165 for (
const auto &m : llvm::enumerate(maskWidths)) {
167 auto mBit =
builder.createOrFold<BitsPrimOp>(
168 realOldField, m.index(), m.index());
// The first emitted bit presumably seeds `catMasks` (that assignment is not
// shown); every later copy is concatenated in front of it.
170 for (
size_t repeat = 0; repeat < m.value(); repeat++)
171 if ((m.index() == 0 && repeat == 0) || !catMasks)
174 catMasks =
builder.createOrFold<CatPrimOp>(mBit, catMasks);
176 realOldField = catMasks;
// Final cast so the driven value has exactly the new field's type.
181 builder.createOrFold<BitCastOp>(newField.getType(),
/// Collects the integer leaf types of `type` into `results`, recursing through
/// bundles and vectors. For a vector, the element type is visited once per
/// element.
/// Returns true only if flattening succeeded and more than one leaf was
/// collected. Success requires every leaf to be an IntType with a known width,
/// assuming the bundle/vector cases propagate failure as their
/// `if (!flatten(...))` checks suggest.
/// NOTE(review): the local `flatten` helper and the type-switch head are
/// declared on lines not shown here.
196 static bool flattenType(
FIRRTLType type, SmallVectorImpl<IntType> &results) {
199 .
Case<BundleType>([&](
auto bundle) {
200 for (
auto &elt : bundle)
201 if (!flatten(elt.type))
205 .Case<FVectorType>([&](
auto vector) {
206 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i)
207 if (!flatten(vector.getElementType()))
// A leaf is recorded even when its width is unknown. The false return value
// then makes the overall flattening fail.
211 .Case<IntType>([&](
auto iType) {
212 results.push_back({iType});
213 return iType.getWidth().has_value();
// Any type not matched above cannot be flattened.
215 .Default([&](
auto) {
return false; });
// Types with a single leaf (or none) are not treated as flattenable.
218 if (flatten(type) && results.size() > 1)
/// Returns element `index` of the aggregate `val`: a SubfieldOp for bundles and
/// a SubindexOp for vectors. Any other type hits llvm_unreachable.
/// NOTE(review): no caller of this helper is visible in this extract; it may
/// be dead code.
223 Value getSubWhatever(ImplicitLocOpBuilder *
builder, Value val,
size_t index) {
224 if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
225 return builder->create<SubfieldOp>(val, index);
226 if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
227 return builder->create<SubindexOp>(val, index);
229 llvm_unreachable(
"Unknown aggregate type");
/// Creates a new FlattenMemoryPass instance, owned by the returned unique_ptr.
/// NOTE(review): the enclosing function signature is not part of this extract
/// (presumably `createFlattenMemoryPass()`).
236 return std::make_unique<FlattenMemoryPass>();
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()