CIRCT  20.0.0git
FlattenMemory.cpp
Go to the documentation of this file.
1 //===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the FlattenMemory pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/Pass/Pass.h"
19 #include "llvm/Support/Debug.h"
20 #include <numeric>
21 
22 #define DEBUG_TYPE "lower-memory"
23 
24 namespace circt {
25 namespace firrtl {
26 #define GEN_PASS_DEF_FLATTENMEMORY
27 #include "circt/Dialect/FIRRTL/Passes.h.inc"
28 } // namespace firrtl
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace firrtl;
33 
34 namespace {
/// Pass that rewrites every eligible MemOp in the current module so that its
/// aggregate data type becomes a single UInt whose width is the sum of the
/// leaf field widths. The original port values are replaced by wires, and the
/// wires are connected to the new memory through bitcasts.
///
/// A memory is left untouched when any of the following holds:
///  - one of its results has a RefType,
///  - one of its port annotations carries a "circt.fieldID" entry,
///  - its data type does not flatten into more than one known-width integer
///    leaf (see flattenType),
///  - the total width of all leaves is zero.
35 struct FlattenMemoryPass
36  : public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
37  /// This pass flattens the aggregate data of memory into a UInt, and inserts
38  /// appropriate bitcasts to access the data.
39  void runOnOperation() override {
    // NOTE(review): the debug text says "lower memory" (and DEBUG_TYPE is
    // "lower-memory") although this is the FlattenMemory pass - confirm
    // whether that is intentional.
40  LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
41  << getOperation().getName());
    // NOTE(review): opsToErase is never read or written below; the old memory
    // is erased directly with memOp.erase().
42  SmallVector<Operation *> opsToErase;
    // Returns true if any annotation on any port of `op` has a
    // "circt.fieldID" entry, i.e. the annotation targets a sub-field.
43  auto hasSubAnno = [&](MemOp op) -> bool {
44  for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
45  for (auto attr : op.getPortAnnotation(portIdx))
46  if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
47  return true;
48 
49  return false;
50  };
    // Visit every MemOp in the module body. A plain `return` from the
    // callback skips the current memory without modifying it.
51  getOperation().getBodyBlock()->walk([&](MemOp memOp) {
52  LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
53  // The leaf element types after flattening the data.
54  SmallVector<IntType> flatMemType;
55  // MaskGranularity : how many bits each mask bit controls.
56  size_t maskGran = 1;
57  // Total mask bitwidth after flattening.
58  uint32_t totalmaskWidths = 0;
59  // How many mask bits each field type requires.
60  SmallVector<unsigned> maskWidths;
61 
62  // Cannot flatten a memory if it has debug ports, because debug port
63  // implies a memtap and we cannot transform the datatype for a memory that
64  // is tapped.
65  for (auto res : memOp.getResults())
66  if (isa<RefType>(res.getType()))
67  return;
68  // If subannotations present on aggregate fields, we cannot flatten the
69  // memory. It must be split into one memory per aggregate field.
70  // Do not overwrite the pass flag!
71  if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
72  return;
73 
    // NOTE(review): flatData is declared but never used in this function.
74  SmallVector<Operation *, 8> flatData;
75  SmallVector<int32_t> memWidths;
76  size_t memFlatWidth = 0;
77  // Get the width of individual aggregate leaf elements.
78  for (auto f : flatMemType) {
79  LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
    // Dereferencing is safe here: flattenType only returned true if every
    // leaf reported a width.
80  auto w = *f.getWidth();
81  memWidths.push_back(w);
82  memFlatWidth += w;
83  }
84  // If all the widths are zero, ignore the memory.
85  if (!memFlatWidth)
86  return;
87  maskGran = memWidths[0];
88  // Compute the GCD of all data bitwidths.
    // The total width is nonzero at this point, so at least one width is
    // nonzero and the resulting GCD is a valid (nonzero) divisor below.
89  for (auto w : memWidths) {
90  maskGran = std::gcd(maskGran, w);
91  }
92  for (auto w : memWidths) {
93  // How many mask bits required for each flattened field.
94  auto mWidth = w / maskGran;
95  maskWidths.push_back(mWidth);
96  totalmaskWidths += mWidth;
97  }
98  // Now create a new memory of type flattened data.
99  // ----------------------------------------------
100  SmallVector<Type, 8> ports;
101  SmallVector<Attribute, 8> portNames;
102 
103  auto *context = memOp.getContext();
     // New operations are created through this builder, which is constructed
     // from the old memory's location and the old memory op itself.
104  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
105  // Create a new memory data type of unsigned and computed width.
106  auto flatType = UIntType::get(context, memFlatWidth);
107  auto opPorts = memOp.getPorts();
     // Recompute each port's type for the flat data type and the new total
     // mask width; port names (first) and kinds (second) are carried over.
108  for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
109  auto port = opPorts[portIdx];
110  ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
111  port.second, totalmaskWidths));
112  portNames.push_back(port.first);
113  }
114 
     // The replacement memory copies every other property (latencies, depth,
     // RUW behavior, name, annotations, inner symbol, init, prefix) from the
     // original memory.
115  auto flatMem = builder.create<MemOp>(
116  ports, memOp.getReadLatency(), memOp.getWriteLatency(),
117  memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
118  memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
119  memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
120  memOp.getInitAttr(), memOp.getPrefixAttr());
121  // Hook up the new memory to the wires the old memory was replaced with.
122  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
123  ++index) {
124  auto result = memOp.getResult(index);
     // Each old port is replaced by a wire of the original (aggregate) port
     // type, named "<memory name>_<port name>".
125  auto wire = builder
126  .create<WireOp>(result.getType(),
127  (memOp.getName() + "_" +
128  memOp.getPortName(index).getValue())
129  .str())
130  .getResult();
131  result.replaceAllUsesWith(wire);
     // From here on `result` refers to the wire standing in for the old port.
132  result = wire;
133  auto newResult = flatMem.getResult(index);
134  auto rType = type_cast<BundleType>(result.getType());
     // Walk the port bundle field by field. The same field index is used on
     // the old wire and on the new memory port.
135  for (size_t fieldIndex = 0, fend = rType.getNumElements();
136  fieldIndex != fend; ++fieldIndex) {
137  auto name = rType.getElement(fieldIndex).name.getValue();
138  auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
139  FIRRTLBaseValue newField =
140  builder.create<SubfieldOp>(newResult, fieldIndex);
141  // data and mask depend on the memory type which was split. They can
142  // also go both directions, depending on the port direction.
     // Every other field is connected straight through.
143  if (!(name == "data" || name == "mask" || name == "wdata" ||
144  name == "wmask" || name == "rdata")) {
145  emitConnect(builder, newField, oldField);
146  continue;
147  }
148  Value realOldField = oldField;
149  if (rType.getElement(fieldIndex).isFlip) {
150  // Cast the memory read data from flat type to aggregate.
151  auto castField =
152  builder.createOrFold<BitCastOp>(oldField.getType(), newField);
153  // Write the aggregate read data.
154  emitConnect(builder, realOldField, castField);
155  } else {
156  // Cast the input aggregate write data to flat type.
157  // Start from the type of the matching field on the flattened memory.
158  auto newFieldType = newField.getType();
159  auto oldFieldBitWidth = getBitWidth(oldField.getType());
160  // Following condition is true, if a data field is 0 bits. Then
161  // newFieldType is of smaller bits than old.
162  if (getBitWidth(newFieldType) != *oldFieldBitWidth)
163  newFieldType = UIntType::get(context, *oldFieldBitWidth);
164  realOldField = builder.create<BitCastOp>(newFieldType, oldField);
165  // Mask bits require special handling, since some of the mask bits
166  // need to be repeated, direct bitcasting wouldn't work. Depending
167  // on the mask granularity, some mask bits will be repeated.
     // The size comparison below is true exactly when the per-field mask
     // widths do not sum to one bit per field.
168  if ((name == "mask" || name == "wmask") &&
169  (maskWidths.size() != totalmaskWidths)) {
170  Value catMasks;
171  for (const auto &m : llvm::enumerate(maskWidths)) {
172  // Get the mask bit.
173  auto mBit = builder.createOrFold<BitsPrimOp>(
174  realOldField, m.index(), m.index());
175  // Check how many times the mask bit needs to be prepended.
     // The first bit emitted starts the chain; each later copy is
     // concatenated in front of the bits accumulated so far.
176  for (size_t repeat = 0; repeat < m.value(); repeat++)
177  if ((m.index() == 0 && repeat == 0) || !catMasks)
178  catMasks = mBit;
179  else
180  catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
181  }
182  realOldField = catMasks;
183  }
184  // Now set the mask or write data.
185  // Ensure that the types match.
186  emitConnect(builder, newField,
187  builder.createOrFold<BitCastOp>(newField.getType(),
188  realOldField));
189  }
190  }
191  }
     // numFlattenedMems is not declared in this file; presumably a pass
     // statistic provided by the generated base class - TODO confirm.
192  ++numFlattenedMems;
     // All uses of the old results were redirected to wires above, so the
     // original memory can be removed.
193  memOp.erase();
194  return;
195  });
196  }
197 
198 private:
199  // Convert an aggregate type into a flat list of fields.
200  // This is used to flatten the aggregate memory datatype.
201  // Recursively populate the results with each ground type field.
     // Returns true only if every leaf is an integer type with a known width
     // and more than one leaf was collected. `results` may already have been
     // appended to when false is returned.
202  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
203  std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
     // NOTE(review): the head of the chained expression below is missing
     // from this listing (its numbering jumps from 203 to 205). It is
     // presumably a `return` of a FIRRTLTypeSwitch over `type` - confirm
     // against the upstream file.
     // Bundle: every element type must flatten successfully.
205  .Case<BundleType>([&](auto bundle) {
206  for (auto &elt : bundle)
207  if (!flatten(elt.type))
208  return false;
209  return true;
210  })
     // Vector: the element type is flattened once per element, so each
     // element contributes its own leaves to `results`.
211  .Case<FVectorType>([&](auto vector) {
212  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
213  if (!flatten(vector.getElementType()))
214  return false;
215  return true;
216  })
     // Integer leaf: recorded unconditionally; succeeds only if its width
     // is known.
217  .Case<IntType>([&](auto iType) {
218  results.push_back({iType});
219  return iType.getWidth().has_value();
220  })
     // Any other type cannot be flattened.
221  .Default([&](auto) { return false; });
222  };
223  // Return true only if this is an aggregate with more than one element.
224  if (flatten(type) && results.size() > 1)
225  return true;
226  return false;
227  }
228 
     // Creates an accessor for element `index` of the aggregate value `val`:
     // a SubfieldOp for a bundle, a SubindexOp for a vector. Any other type
     // hits llvm_unreachable.
     // NOTE(review): no caller of this helper appears in this listing.
229  Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
230  if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
231  return builder->create<SubfieldOp>(val, index);
232  if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
233  return builder->create<SubindexOp>(val, index);
234 
235  llvm_unreachable("Unknown aggregate type");
236  return nullptr;
237  }
238 };
239 } // end anonymous namespace
240 
/// Factory for the pass: returns a newly allocated FlattenMemoryPass owned by
/// the returned unique_ptr.
241 std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
242  return std::make_unique<FlattenMemoryPass>();
243 }
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:394
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21