CIRCT  19.0.0git
FlattenMemory.cpp
Go to the documentation of this file.
1 //===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the FlattenMemory pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 #include <numeric>
23 
24 #define DEBUG_TYPE "lower-memory"
25 
26 using namespace circt;
27 using namespace firrtl;
28 
29 namespace {
30 struct FlattenMemoryPass : public FlattenMemoryBase<FlattenMemoryPass> {
31  /// This pass flattens the aggregate data of memory into a UInt, and inserts
32  /// appropriate bitcasts to access the data.
33  void runOnOperation() override {
34  LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
35  << getOperation().getName());
36  SmallVector<Operation *> opsToErase;
37  auto hasSubAnno = [&](MemOp op) -> bool {
38  for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
39  for (auto attr : op.getPortAnnotation(portIdx))
40  if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
41  return true;
42 
43  return false;
44  };
45  getOperation().getBodyBlock()->walk([&](MemOp memOp) {
46  LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
47  // The vector of leaf elements type after flattening the data.
48  SmallVector<IntType> flatMemType;
49  // MaskGranularity : how many bits each mask bit controls.
50  size_t maskGran = 1;
51  // Total mask bitwidth after flattening.
52  uint32_t totalmaskWidths = 0;
53  // How many mask bits each field type requires.
54  SmallVector<unsigned> maskWidths;
55 
56  // Cannot flatten a memory if it has debug ports, because debug port
57  // implies a memtap and we cannot transform the datatype for a memory that
58  // is tapped.
59  for (auto res : memOp.getResults())
60  if (isa<RefType>(res.getType()))
61  return;
62  // If subannotations present on aggregate fields, we cannot flatten the
63  // memory. It must be split into one memory per aggregate field.
64  // Do not overwrite the pass flag!
65  if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
66  return;
67 
68  SmallVector<Operation *, 8> flatData;
69  SmallVector<int32_t> memWidths;
70  size_t memFlatWidth = 0;
71  // Get the width of individual aggregate leaf elements.
72  for (auto f : flatMemType) {
73  LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
74  auto w = *f.getWidth();
75  memWidths.push_back(w);
76  memFlatWidth += w;
77  }
78  // If all the widths are zero, ignore the memory.
79  if (!memFlatWidth)
80  return;
81  maskGran = memWidths[0];
82  // Compute the GCD of all data bitwidths.
83  for (auto w : memWidths) {
84  maskGran = std::gcd(maskGran, w);
85  }
86  for (auto w : memWidths) {
87  // How many mask bits required for each flattened field.
88  auto mWidth = w / maskGran;
89  maskWidths.push_back(mWidth);
90  totalmaskWidths += mWidth;
91  }
92  // Now create a new memory of type flattened data.
93  // ----------------------------------------------
94  SmallVector<Type, 8> ports;
95  SmallVector<Attribute, 8> portNames;
96 
97  auto *context = memOp.getContext();
98  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
99  // Create a new memoty data type of unsigned and computed width.
100  auto flatType = UIntType::get(context, memFlatWidth);
101  auto opPorts = memOp.getPorts();
102  for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
103  auto port = opPorts[portIdx];
104  ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
105  port.second, totalmaskWidths));
106  portNames.push_back(port.first);
107  }
108 
109  auto flatMem = builder.create<MemOp>(
110  ports, memOp.getReadLatency(), memOp.getWriteLatency(),
111  memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
112  memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
113  memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
114  memOp.getInitAttr(), memOp.getPrefixAttr());
115  // Hook up the new memory to the wires the old memory was replaced with.
116  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
117  ++index) {
118  auto result = memOp.getResult(index);
119  auto wire = builder
120  .create<WireOp>(result.getType(),
121  (memOp.getName() + "_" +
122  memOp.getPortName(index).getValue())
123  .str())
124  .getResult();
125  result.replaceAllUsesWith(wire);
126  result = wire;
127  auto newResult = flatMem.getResult(index);
128  auto rType = type_cast<BundleType>(result.getType());
129  for (size_t fieldIndex = 0, fend = rType.getNumElements();
130  fieldIndex != fend; ++fieldIndex) {
131  auto name = rType.getElement(fieldIndex).name.getValue();
132  auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
133  FIRRTLBaseValue newField =
134  builder.create<SubfieldOp>(newResult, fieldIndex);
135  // data and mask depend on the memory type which was split. They can
136  // also go both directions, depending on the port direction.
137  if (!(name == "data" || name == "mask" || name == "wdata" ||
138  name == "wmask" || name == "rdata")) {
139  emitConnect(builder, newField, oldField);
140  continue;
141  }
142  Value realOldField = oldField;
143  if (rType.getElement(fieldIndex).isFlip) {
144  // Cast the memory read data from flat type to aggregate.
145  auto castField =
146  builder.createOrFold<BitCastOp>(oldField.getType(), newField);
147  // Write the aggregate read data.
148  emitConnect(builder, realOldField, castField);
149  } else {
150  // Cast the input aggregate write data to flat type.
151  // Cast the input aggregate write data to flat type.
152  auto newFieldType = newField.getType();
153  auto oldFieldBitWidth = getBitWidth(oldField.getType());
154  // Following condition is true, if a data field is 0 bits. Then
155  // newFieldType is of smaller bits than old.
156  if (getBitWidth(newFieldType) != *oldFieldBitWidth)
157  newFieldType = UIntType::get(context, *oldFieldBitWidth);
158  realOldField = builder.create<BitCastOp>(newFieldType, oldField);
159  // Mask bits require special handling, since some of the mask bits
160  // need to be repeated, direct bitcasting wouldn't work. Depending
161  // on the mask granularity, some mask bits will be repeated.
162  if ((name == "mask" || name == "wmask") &&
163  (maskWidths.size() != totalmaskWidths)) {
164  Value catMasks;
165  for (const auto &m : llvm::enumerate(maskWidths)) {
166  // Get the mask bit.
167  auto mBit = builder.createOrFold<BitsPrimOp>(
168  realOldField, m.index(), m.index());
169  // Check how many times the mask bit needs to be prepend.
170  for (size_t repeat = 0; repeat < m.value(); repeat++)
171  if ((m.index() == 0 && repeat == 0) || !catMasks)
172  catMasks = mBit;
173  else
174  catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
175  }
176  realOldField = catMasks;
177  }
178  // Now set the mask or write data.
179  // Ensure that the types match.
180  emitConnect(builder, newField,
181  builder.createOrFold<BitCastOp>(newField.getType(),
182  realOldField));
183  }
184  }
185  }
186  ++numFlattenedMems;
187  memOp.erase();
188  return;
189  });
190  }
191 
192 private:
193  // Convert an aggregate type into a flat list of fields.
194  // This is used to flatten the aggregate memory datatype.
195  // Recursively populate the results with each ground type field.
196  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
197  std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
199  .Case<BundleType>([&](auto bundle) {
200  for (auto &elt : bundle)
201  if (!flatten(elt.type))
202  return false;
203  return true;
204  })
205  .Case<FVectorType>([&](auto vector) {
206  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
207  if (!flatten(vector.getElementType()))
208  return false;
209  return true;
210  })
211  .Case<IntType>([&](auto iType) {
212  results.push_back({iType});
213  return iType.getWidth().has_value();
214  })
215  .Default([&](auto) { return false; });
216  };
217  // Return true only if this is an aggregate with more than one element.
218  if (flatten(type) && results.size() > 1)
219  return true;
220  return false;
221  }
222 
223  Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
224  if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
225  return builder->create<SubfieldOp>(val, index);
226  if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
227  return builder->create<SubindexOp>(val, index);
228 
229  llvm_unreachable("Unknown aggregate type");
230  return nullptr;
231  }
232 };
233 } // end anonymous namespace
234 
235 std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
236  return std::make_unique<FlattenMemoryPass>();
237 }
Builder builder
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:392
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:24
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21