CIRCT  20.0.0git
FlattenMemory.cpp
Go to the documentation of this file.
1 //===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the FlattenMemory pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 #include <numeric>
23 
24 #define DEBUG_TYPE "lower-memory"
25 
26 namespace circt {
27 namespace firrtl {
28 #define GEN_PASS_DEF_FLATTENMEMORY
29 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 } // namespace firrtl
31 } // namespace circt
32 
33 using namespace circt;
34 using namespace firrtl;
35 
36 namespace {
37 struct FlattenMemoryPass
38  : public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
39  /// This pass flattens the aggregate data of memory into a UInt, and inserts
40  /// appropriate bitcasts to access the data.
41  void runOnOperation() override {
42  LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
43  << getOperation().getName());
44  SmallVector<Operation *> opsToErase;
45  auto hasSubAnno = [&](MemOp op) -> bool {
46  for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
47  for (auto attr : op.getPortAnnotation(portIdx))
48  if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
49  return true;
50 
51  return false;
52  };
53  getOperation().getBodyBlock()->walk([&](MemOp memOp) {
54  LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
55  // The vector of leaf elements type after flattening the data.
56  SmallVector<IntType> flatMemType;
57  // MaskGranularity : how many bits each mask bit controls.
58  size_t maskGran = 1;
59  // Total mask bitwidth after flattening.
60  uint32_t totalmaskWidths = 0;
61  // How many mask bits each field type requires.
62  SmallVector<unsigned> maskWidths;
63 
64  // Cannot flatten a memory if it has debug ports, because debug port
65  // implies a memtap and we cannot transform the datatype for a memory that
66  // is tapped.
67  for (auto res : memOp.getResults())
68  if (isa<RefType>(res.getType()))
69  return;
70  // If subannotations present on aggregate fields, we cannot flatten the
71  // memory. It must be split into one memory per aggregate field.
72  // Do not overwrite the pass flag!
73  if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
74  return;
75 
76  SmallVector<Operation *, 8> flatData;
77  SmallVector<int32_t> memWidths;
78  size_t memFlatWidth = 0;
79  // Get the width of individual aggregate leaf elements.
80  for (auto f : flatMemType) {
81  LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
82  auto w = *f.getWidth();
83  memWidths.push_back(w);
84  memFlatWidth += w;
85  }
86  // If all the widths are zero, ignore the memory.
87  if (!memFlatWidth)
88  return;
89  maskGran = memWidths[0];
90  // Compute the GCD of all data bitwidths.
91  for (auto w : memWidths) {
92  maskGran = std::gcd(maskGran, w);
93  }
94  for (auto w : memWidths) {
95  // How many mask bits required for each flattened field.
96  auto mWidth = w / maskGran;
97  maskWidths.push_back(mWidth);
98  totalmaskWidths += mWidth;
99  }
100  // Now create a new memory of type flattened data.
101  // ----------------------------------------------
102  SmallVector<Type, 8> ports;
103  SmallVector<Attribute, 8> portNames;
104 
105  auto *context = memOp.getContext();
106  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
107  // Create a new memoty data type of unsigned and computed width.
108  auto flatType = UIntType::get(context, memFlatWidth);
109  auto opPorts = memOp.getPorts();
110  for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
111  auto port = opPorts[portIdx];
112  ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
113  port.second, totalmaskWidths));
114  portNames.push_back(port.first);
115  }
116 
117  auto flatMem = builder.create<MemOp>(
118  ports, memOp.getReadLatency(), memOp.getWriteLatency(),
119  memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
120  memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
121  memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
122  memOp.getInitAttr(), memOp.getPrefixAttr());
123  // Hook up the new memory to the wires the old memory was replaced with.
124  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
125  ++index) {
126  auto result = memOp.getResult(index);
127  auto wire = builder
128  .create<WireOp>(result.getType(),
129  (memOp.getName() + "_" +
130  memOp.getPortName(index).getValue())
131  .str())
132  .getResult();
133  result.replaceAllUsesWith(wire);
134  result = wire;
135  auto newResult = flatMem.getResult(index);
136  auto rType = type_cast<BundleType>(result.getType());
137  for (size_t fieldIndex = 0, fend = rType.getNumElements();
138  fieldIndex != fend; ++fieldIndex) {
139  auto name = rType.getElement(fieldIndex).name.getValue();
140  auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
141  FIRRTLBaseValue newField =
142  builder.create<SubfieldOp>(newResult, fieldIndex);
143  // data and mask depend on the memory type which was split. They can
144  // also go both directions, depending on the port direction.
145  if (!(name == "data" || name == "mask" || name == "wdata" ||
146  name == "wmask" || name == "rdata")) {
147  emitConnect(builder, newField, oldField);
148  continue;
149  }
150  Value realOldField = oldField;
151  if (rType.getElement(fieldIndex).isFlip) {
152  // Cast the memory read data from flat type to aggregate.
153  auto castField =
154  builder.createOrFold<BitCastOp>(oldField.getType(), newField);
155  // Write the aggregate read data.
156  emitConnect(builder, realOldField, castField);
157  } else {
158  // Cast the input aggregate write data to flat type.
159  // Cast the input aggregate write data to flat type.
160  auto newFieldType = newField.getType();
161  auto oldFieldBitWidth = getBitWidth(oldField.getType());
162  // Following condition is true, if a data field is 0 bits. Then
163  // newFieldType is of smaller bits than old.
164  if (getBitWidth(newFieldType) != *oldFieldBitWidth)
165  newFieldType = UIntType::get(context, *oldFieldBitWidth);
166  realOldField = builder.create<BitCastOp>(newFieldType, oldField);
167  // Mask bits require special handling, since some of the mask bits
168  // need to be repeated, direct bitcasting wouldn't work. Depending
169  // on the mask granularity, some mask bits will be repeated.
170  if ((name == "mask" || name == "wmask") &&
171  (maskWidths.size() != totalmaskWidths)) {
172  Value catMasks;
173  for (const auto &m : llvm::enumerate(maskWidths)) {
174  // Get the mask bit.
175  auto mBit = builder.createOrFold<BitsPrimOp>(
176  realOldField, m.index(), m.index());
177  // Check how many times the mask bit needs to be prepend.
178  for (size_t repeat = 0; repeat < m.value(); repeat++)
179  if ((m.index() == 0 && repeat == 0) || !catMasks)
180  catMasks = mBit;
181  else
182  catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
183  }
184  realOldField = catMasks;
185  }
186  // Now set the mask or write data.
187  // Ensure that the types match.
188  emitConnect(builder, newField,
189  builder.createOrFold<BitCastOp>(newField.getType(),
190  realOldField));
191  }
192  }
193  }
194  ++numFlattenedMems;
195  memOp.erase();
196  return;
197  });
198  }
199 
200 private:
201  // Convert an aggregate type into a flat list of fields.
202  // This is used to flatten the aggregate memory datatype.
203  // Recursively populate the results with each ground type field.
204  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
205  std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
207  .Case<BundleType>([&](auto bundle) {
208  for (auto &elt : bundle)
209  if (!flatten(elt.type))
210  return false;
211  return true;
212  })
213  .Case<FVectorType>([&](auto vector) {
214  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
215  if (!flatten(vector.getElementType()))
216  return false;
217  return true;
218  })
219  .Case<IntType>([&](auto iType) {
220  results.push_back({iType});
221  return iType.getWidth().has_value();
222  })
223  .Default([&](auto) { return false; });
224  };
225  // Return true only if this is an aggregate with more than one element.
226  if (flatten(type) && results.size() > 1)
227  return true;
228  return false;
229  }
230 
231  Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
232  if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
233  return builder->create<SubfieldOp>(val, index);
234  if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
235  return builder->create<SubindexOp>(val, index);
236 
237  llvm_unreachable("Unknown aggregate type");
238  return nullptr;
239  }
240 };
241 } // end anonymous namespace
242 
243 std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
244  return std::make_unique<FlattenMemoryPass>();
245 }
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:394
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21