CIRCT  18.0.0git
FlattenMemory.cpp
Go to the documentation of this file.
1 //===- FlattenMemory.cpp - Flatten Memory Pass ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the FlattenMemory pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 #include <numeric>
23 
24 #define DEBUG_TYPE "lower-memory"
25 
26 using namespace circt;
27 using namespace firrtl;
28 
29 namespace {
30 struct FlattenMemoryPass : public FlattenMemoryBase<FlattenMemoryPass> {
31  /// This pass flattens the aggregate data of memory into a UInt, and inserts
32  /// appropriate bitcasts to access the data.
33  void runOnOperation() override {
34  LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
35  << getOperation().getName());
36  SmallVector<Operation *> opsToErase;
37  auto hasSubAnno = [&](MemOp op) -> bool {
38  for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
39  for (auto attr : op.getPortAnnotation(portIdx))
40  if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
41  return true;
42 
43  return false;
44  };
45  getOperation().getBodyBlock()->walk([&](MemOp memOp) {
46  LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
47  // The vector of leaf elements type after flattening the data.
48  SmallVector<IntType> flatMemType;
49  // MaskGranularity : how many bits each mask bit controls.
50  size_t maskGran = 1;
51  // Total mask bitwidth after flattening.
52  uint32_t totalmaskWidths = 0;
53  // How many mask bits each field type requires.
54  SmallVector<unsigned> maskWidths;
55 
56  // Cannot flatten a memory if it has debug ports, because debug port
57  // implies a memtap and we cannot transform the datatype for a memory that
58  // is tapped.
59  for (auto res : memOp.getResults())
60  if (isa<RefType>(res.getType()))
61  return;
62  // If subannotations present on aggregate fields, we cannot flatten the
63  // memory. It must be split into one memory per aggregate field.
64  // Do not overwrite the pass flag!
65  if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
66  return;
67 
68  SmallVector<Operation *, 8> flatData;
69  SmallVector<int32_t> memWidths;
70  size_t memFlatWidth = 0;
71  // Get the width of individual aggregate leaf elements.
72  for (auto f : flatMemType) {
73  LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
74  auto w = *f.getWidth();
75  memWidths.push_back(w);
76  memFlatWidth += w;
77  }
78  // If all the widths are zero, ignore the memory.
79  if (!memFlatWidth)
80  return;
81  maskGran = memWidths[0];
82  // Compute the GCD of all data bitwidths.
83  for (auto w : memWidths) {
84  maskGran = std::gcd(maskGran, w);
85  }
86  for (auto w : memWidths) {
87  // How many mask bits required for each flattened field.
88  auto mWidth = w / maskGran;
89  maskWidths.push_back(mWidth);
90  totalmaskWidths += mWidth;
91  }
92  // Now create a new memory of type flattened data.
93  // ----------------------------------------------
94  SmallVector<Type, 8> ports;
95  SmallVector<Attribute, 8> portNames;
96 
97  auto *context = memOp.getContext();
98  ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
99  // Create a new memoty data type of unsigned and computed width.
100  auto flatType = UIntType::get(context, memFlatWidth);
101  auto opPorts = memOp.getPorts();
102  for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
103  auto port = opPorts[portIdx];
104  ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
105  port.second, totalmaskWidths));
106  portNames.push_back(port.first);
107  }
108 
109  auto flatMem = builder.create<MemOp>(
110  ports, memOp.getReadLatency(), memOp.getWriteLatency(),
111  memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
112  memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
113  memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
114  memOp.getInitAttr(), memOp.getPrefixAttr());
115  // Hook up the new memory to the wires the old memory was replaced with.
116  for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
117  ++index) {
118  auto result = memOp.getResult(index);
119  auto wire = builder
120  .create<WireOp>(result.getType(),
121  (memOp.getName() + "_" +
122  memOp.getPortName(index).getValue())
123  .str())
124  .getResult();
125  result.replaceAllUsesWith(wire);
126  result = wire;
127  auto newResult = flatMem.getResult(index);
128  auto rType = type_cast<BundleType>(result.getType());
129  for (size_t fieldIndex = 0, fend = rType.getNumElements();
130  fieldIndex != fend; ++fieldIndex) {
131  auto name = rType.getElement(fieldIndex).name.getValue();
132  auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
133  FIRRTLBaseValue newField =
134  builder.create<SubfieldOp>(newResult, fieldIndex);
135  // data and mask depend on the memory type which was split. They can
136  // also go both directions, depending on the port direction.
137  if (!(name == "data" || name == "mask" || name == "wdata" ||
138  name == "wmask" || name == "rdata")) {
139  emitConnect(builder, newField, oldField);
140  continue;
141  }
142  Value realOldField = oldField;
143  if (rType.getElement(fieldIndex).isFlip) {
144  // Cast the memory read data from flat type to aggregate.
145  auto castField =
146  builder.createOrFold<BitCastOp>(oldField.getType(), newField);
147  // Write the aggregate read data.
148  emitConnect(builder, realOldField, castField);
149  } else {
150  // Cast the input aggregate write data to flat type.
151  // Cast the input aggregate write data to flat type.
152  auto newFieldType = newField.getType();
153  auto oldFieldBitWidth = getBitWidth(oldField.getType());
154  // Following condition is true, if a data field is 0 bits. Then
155  // newFieldType is of smaller bits than old.
156  if (getBitWidth(newFieldType) != *oldFieldBitWidth)
157  newFieldType = UIntType::get(context, *oldFieldBitWidth);
158  realOldField = builder.create<BitCastOp>(newFieldType, oldField);
159  // Mask bits require special handling, since some of the mask bits
160  // need to be repeated, direct bitcasting wouldn't work. Depending
161  // on the mask granularity, some mask bits will be repeated.
162  if ((name == "mask" || name == "wmask") &&
163  (maskWidths.size() != totalmaskWidths)) {
164  Value catMasks;
165  for (const auto &m : llvm::enumerate(maskWidths)) {
166  // Get the mask bit.
167  auto mBit = builder.createOrFold<BitsPrimOp>(
168  realOldField, m.index(), m.index());
169  // Check how many times the mask bit needs to be prepend.
170  for (size_t repeat = 0; repeat < m.value(); repeat++)
171  if ((m.index() == 0 && repeat == 0) || !catMasks)
172  catMasks = mBit;
173  else
174  catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
175  }
176  realOldField = catMasks;
177  }
178  // Now set the mask or write data.
179  // Ensure that the types match.
180  emitConnect(builder, newField,
181  builder.createOrFold<BitCastOp>(newField.getType(),
182  realOldField));
183  }
184  }
185  }
186  ++numFlattenedMems;
187  memOp.erase();
188  return;
189  });
190  }
191 
192 private:
193  // Convert an aggregate type into a flat list of fields.
194  // This is used to flatten the aggregate memory datatype.
195  // Recursively populate the results with each ground type field.
196  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
197  std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
199  .Case<BundleType>([&](auto bundle) {
200  for (auto &elt : bundle)
201  if (!flatten(elt.type))
202  return false;
203  return true;
204  })
205  .Case<FVectorType>([&](auto vector) {
206  for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
207  if (!flatten(vector.getElementType()))
208  return false;
209  return true;
210  })
211  .Case<IntType>([&](auto iType) {
212  results.push_back({iType});
213  return iType.getWidth().has_value();
214  })
215  .Default([&](auto) { return false; });
216  };
217  // Return true only if this is an aggregate with more than one element.
218  if (flatten(type) && results.size() > 1)
219  return true;
220  return false;
221  }
222 
223  Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
224  if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
225  return builder->create<SubfieldOp>(val, index);
226  if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
227  return builder->create<SubindexOp>(val, index);
228 
229  llvm_unreachable("Unknown aggregate type");
230  return nullptr;
231  }
232 };
233 } // end anonymous namespace
234 
235 std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
236  return std::make_unique<FlattenMemoryPass>();
237 }
Builder builder
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:515
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:525
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
mlir::TypedValue< FIRRTLBaseType > FIRRTLBaseValue
Definition: FIRRTLTypes.h:389
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:23
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::unique_ptr< mlir::Pass > createFlattenMemoryPass()
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28