18#include "mlir/Pass/Pass.h"
19#include "llvm/Support/Debug.h"
22#define DEBUG_TYPE "lower-memory"
26#define GEN_PASS_DEF_FLATTENMEMORY
27#include "circt/Dialect/FIRRTL/Passes.h.inc"
32using namespace firrtl;
// FlattenMemoryPass: lowers each aggregate-typed FIRRTL memory into a single
// memory with a flat UInt data type, bit-casting port accesses to and from
// the original aggregate type.
// NOTE(review): this listing is a partial extraction — interior lines of the
// original file are missing throughout (original line numbers are fused into
// the text and jump); the commentary below hedges where code is elided.
35 struct FlattenMemoryPass
36     :
 public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
// Returns whether any port annotation of `op` targets a sub-field, i.e.
// carries a "circt.fieldID" entry.  Such memories are skipped by the pass
// (see the hasSubAnno check in runOnOperation).
// NOTE(review): the `return true;` / `return false;` statements appear to be
// elided from this extraction — confirm against the original file.
40   static bool hasSubAnno(MemOp op) {
41     for (
size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
42       for (
auto attr : op.getPortAnnotation(portIdx))
43         if (cast<DictionaryAttr>(attr).get(
"circt.fieldID"))
// Walk every MemOp in the module body; for each aggregate-typed memory,
// build a replacement MemOp whose data type is a single flat UInt and
// re-wire all port fields through bit-casts.
// NOTE(review): many statements (continues, emitConnect calls, loop
// increments, brace closers) are elided from this extraction; comments
// below describe only what the visible code demonstrates.
50   void runOnOperation()
 override {
51     LLVM_DEBUG(llvm::dbgs() <<
"\n Running lower memory on module:"
53     SmallVector<Operation *> opsToErase;
54     getOperation().getBodyBlock()->walk([&](MemOp memOp) {
55       LLVM_DEBUG(llvm::dbgs() <<
"\n Memory:" << memOp);
// Already a flat UInt memory — nothing to do (skip; body elided here).
58       if (type_isa<UIntType>(memOp.getDataType()))
// Skip memories with RefType (probe/debug) results.
64       for (
auto res : memOp.getResults())
65         if (isa<RefType>(res.getType()))
// Skip memories whose port annotations target individual sub-fields.
71       if (hasSubAnno(memOp))
// Flatten the aggregate data type into its ground-typed leaf fields.
76       SmallVector<FIRRTLBaseType> flatMemType;
77       if (!flattenType(memOp.getDataType(), flatMemType))
// Collect the bit width of each leaf field.
// NOTE(review): memFlatWidth is declared but its accumulation (presumably
// memFlatWidth += w inside this loop) is elided from this extraction.
82       size_t memFlatWidth = 0;
83       SmallVector<int32_t> memWidths;
84       for (
auto f : flatMemType) {
85         LLVM_DEBUG(llvm::dbgs() <<
"\n field type:" << f);
86         auto w = f.getBitWidthOrSentinel();
87         memWidths.push_back(w);
// Mask granularity is the GCD of all field widths, so one mask bit can
// cover `maskGran` data bits uniformly across every field.
97       auto maskGran = memWidths.front();
98       for (
auto w : ArrayRef(memWidths).drop_front())
99         maskGran = std::gcd(maskGran, w);
// Per-field mask width = field width / granularity; sum gives the total
// mask width of the flattened memory.
102       uint32_t totalmaskWidths = 0;
104       SmallVector<unsigned> maskWidths;
105       for (
auto w : memWidths) {
107         auto mWidth = w / maskGran;
108         maskWidths.push_back(mWidth);
109         totalmaskWidths += mWidth;
// Build the port types for the replacement memory: same ports, but with
// the flat UInt data type and the widened mask.
114       SmallVector<Type, 8> ports;
115       SmallVector<Attribute, 8> portNames;
117       auto *
context = memOp.getContext();
118       ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
120       auto flatType = UIntType::get(
context, memFlatWidth);
121       for (
auto port : memOp.getPorts()) {
122         ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
123                                               port.second, totalmaskWidths));
124         portNames.push_back(port.first);
// Create the flat memory, carrying over all attributes of the original.
128       auto flatMem = MemOp::create(
129           builder, ports, memOp.getReadLatency(), memOp.getWriteLatency(),
130           memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
131           memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
132           memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
133           memOp.getInitAttr(), memOp.getPrefixAttr());
// For every port: interpose a wire of the old port type so existing users
// keep their aggregate view, then connect each field of the wire to the
// corresponding field of the new flat memory.
// NOTE(review): the WireOp creation around original lines 144-147 and the
// connects are partially elided in this extraction.
136       for (
size_t index = 0, rend = memOp.getNumResults(); index < rend;
142         auto result = memOp.getResult(index);
145             builder, result.getType(),
146             (memOp.getName() +
"_" + memOp.getPortName(index)).str())
148         result.replaceAllUsesWith(wire);
150         auto newResult = flatMem.getResult(index);
151         auto rType = type_cast<BundleType>(result.getType());
152         for (
size_t fieldIndex = 0, fend = rType.getNumElements();
153              fieldIndex != fend; ++fieldIndex) {
154           auto name = rType.getElement(fieldIndex).name;
155           auto oldField = SubfieldOp::create(builder, result, fieldIndex);
157               SubfieldOp::create(builder, newResult, fieldIndex);
// Non-payload fields (addr, en, clk, …) pass straight through; only the
// data/mask-like fields below need width adaptation.
160           if (!(name ==
"data" || name ==
"mask" || name ==
"wdata" ||
161                 name ==
"wmask" || name ==
"rdata")) {
165           Value realOldField = oldField;
// Flipped (output) field: cast the flat value back to the aggregate type.
166           if (rType.getElement(fieldIndex).isFlip) {
169                 builder.createOrFold<BitCastOp>(oldField.getType(), newField);
// Input field: bit-cast the aggregate value to the flat field's width,
// padding the destination type if the widths disagree.
174             auto newFieldType = newField.getType();
175             auto oldFieldBitWidth =
 getBitWidth(oldField.getType());
178             if (
getBitWidth(newFieldType) != *oldFieldBitWidth)
179               newFieldType = UIntType::get(
context, *oldFieldBitWidth);
180             realOldField = BitCastOp::create(builder, newFieldType, oldField);
// Mask expansion: replicate each original mask bit m.value() times so the
// new mask has one bit per maskGran data bits.  Bits are concatenated
// LSB-first via CatPrimOp.
184             if ((name ==
"mask" || name ==
"wmask") &&
185                 (maskWidths.size() != totalmaskWidths)) {
187               for (
const auto &m :
 llvm::enumerate(maskWidths)) {
189                 auto mBit = builder.createOrFold<BitsPrimOp>(
190                     realOldField, m.index(), m.index());
192                 for (
size_t repeat = 0; repeat < m.value(); repeat++)
193                   if ((m.index() == 0 && repeat == 0) || !catMasks)
196                     catMasks = builder.createOrFold<CatPrimOp>(
197                         ValueRange{mBit, catMasks});
199               realOldField = catMasks;
204                 builder.createOrFold<BitCastOp>(newField.getType(),
// flattenType (signature partially elided by extraction): recursively
// flattens an aggregate FIRRTL type into its ground-typed leaves, appending
// each leaf to `results`.  Returns whether the type was fully flattenable.
// Implemented via a local recursive lambda `flatten` dispatching on the
// type with a FIRRTLTypeSwitch.
// NOTE(review): the lambda's `return true;` / `return false;` statements and
// closing braces are elided from this extraction — confirm against the
// original file.
220                  SmallVectorImpl<FIRRTLBaseType> &results) {
// Bundle: flatten every element type in order.
223         .
Case<BundleType>([&](
auto bundle) {
224           for (
auto &elt : bundle)
225             if (!flatten(elt.type))
// Vector: flatten the element type once per element.
229         .Case<FVectorType>([&](
auto vector) {
230           for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i)
231             if (!flatten(vector.getElementType()))
// Ground integer types are leaves — record them directly.
235         .Case<IntType>([&](
IntType type) {
236           results.push_back(type);
// Enums are treated as leaves as well.
239         .Case<FEnumType>([&](FEnumType type) {
240           results.emplace_back(type);
// Anything else (e.g. non-base types) is not flattenable.
243         .Default([&](
auto) {
 return false; });
245   return flatten(type);
248 Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val,
size_t index) {
249 if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
250 return SubfieldOp::create(*builder, val, index);
251 if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
252 return SubindexOp::create(*builder, val, index);
254 llvm_unreachable(
"Unknown aggregate type");
static std::unique_ptr<Context> context
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast functionality instead of llvm::dyn_cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.