CIRCT  19.0.0git
Go to the documentation of this file.
//===- LowerExtmemToHW.cpp - Lower extmem to HW pass ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the lower extmem pass, which replaces memref
// arguments of handshake.func operations that are served by
// handshake.extmemory operations with explicit load/store ports, and which can
// optionally wrap the result in an ESI wrapper module.
//
//===----------------------------------------------------------------------===//
13 #include "PassDetails.h"
16 #include "circt/Dialect/HW/HWOps.h"
21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
27 using namespace circt;
28 using namespace handshake;
29 using namespace mlir;
30 namespace {
31 using NamedType = std::pair<StringAttr, Type>;
32 struct HandshakeMemType {
33  llvm::SmallVector<NamedType> inputTypes, outputTypes;
34  MemRefType memRefType;
35  unsigned loadPorts, storePorts;
36 };
38 struct LoadName {
39  StringAttr dataIn;
40  StringAttr addrOut;
42  static LoadName get(MLIRContext *ctx, unsigned idx) {
43  return {StringAttr::get(ctx, "ld" + std::to_string(idx) + ".data"),
44  StringAttr::get(ctx, "ld" + std::to_string(idx) + ".addr")};
45  }
46 };
48 struct StoreNames {
49  StringAttr doneIn;
50  StringAttr out;
52  static StoreNames get(MLIRContext *ctx, unsigned idx) {
53  return {StringAttr::get(ctx, "st" + std::to_string(idx) + ".done"),
54  StringAttr::get(ctx, "st" + std::to_string(idx))};
55  }
56 };
58 } // namespace
60 static Type indexToMemAddr(Type t, MemRefType memRef) {
61  assert(t.isa<IndexType>() && "Expected index type");
62  auto shape = memRef.getShape();
63  assert(shape.size() == 1 && "Expected 1D memref");
64  unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
65  return IntegerType::get(t.getContext(), addrWidth);
66 }
68 static HandshakeMemType getMemTypeForExtmem(Value v) {
69  auto *ctx = v.getContext();
70  assert(v.getType().isa<mlir::MemRefType>() && "Value is not a memref type");
71  auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
72  HandshakeMemType memType;
73  llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
75  // Add memory type.
76  memType.memRefType = v.getType().cast<MemRefType>();
77  memType.loadPorts = extmemOp.getLdCount();
78  memType.storePorts = extmemOp.getStCount();
80  // Add load ports.
81  for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
82  auto names = LoadName::get(ctx, i);
83  memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
84  memType.outputTypes.push_back(
85  {names.addrOut,
86  indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
87  }
89  // Add store ports.
90  for (auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
91  auto names = StoreNames::get(ctx, i);
93  // Incoming store data and address
94  llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
95  storeOutFields.push_back(
96  {StringAttr::get(ctx, "address"),
97  indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
98  storeOutFields.push_back(
99  {StringAttr::get(ctx, "data"), stif.dataIn.getType()});
100  auto inType = hw::StructType::get(ctx, storeOutFields);
101  memType.outputTypes.push_back({names.out, inType});
102  memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
103  }
105  return memType;
106 }
108 namespace {
109 struct HandshakeLowerExtmemToHWPass
110  : public HandshakeLowerExtmemToHWBase<HandshakeLowerExtmemToHWPass> {
112  HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
113  if (createESIWrapper)
114  this->createESIWrapper = *createESIWrapper;
115  }
117  void runOnOperation() override {
118  auto op = getOperation();
119  for (auto func : op.getOps<handshake::FuncOp>()) {
120  if (failed(lowerExtmemToHW(func))) {
121  signalPassFailure();
122  return;
123  }
124  }
125  };
127  LogicalResult lowerExtmemToHW(handshake::FuncOp func);
128  LogicalResult
129  wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
130  const std::map<unsigned, HandshakeMemType> &argReplacements);
131 };
133 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
134  handshake::FuncOp func, hw::ModulePortInfo origPorts,
135  const std::map<unsigned, HandshakeMemType> &argReplacements) {
136  auto *ctx = func.getContext();
137  OpBuilder b(func);
138  auto loc = func.getLoc();
140  // Create external module which will match the interface of 'func' after it's
141  // been lowered to HW.
142  b.setInsertionPoint(func);
143  auto newPortInfo = handshake::getPortInfoForOpTypes(
144  func, func.getArgumentTypes(), func.getResultTypes());
145  auto extMod = b.create<hw::HWModuleExternOp>(
146  loc, StringAttr::get(ctx, "__" + func.getName() + "_hw"), newPortInfo);
148  // Add an attribute to the original handshake function to indicate that it
149  // needs to resolve to extMod in a later pass.
150  func->setAttr(kPredeclarationAttr,
151  FlatSymbolRefAttr::get(ctx, extMod.getName()));
153  // Create wrapper module. This will have the same ports as the original
154  // module, sans the replaced arguments.
155  auto wrapperModPortInfo = origPorts;
156  llvm::SmallVector<unsigned> argReplacementsIdxs;
157  llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
158  [](auto &pair) { return pair.first; });
159  for (auto i : llvm::reverse(argReplacementsIdxs))
160  wrapperModPortInfo.eraseInput(i);
161  auto wrapperMod = b.create<hw::HWModuleOp>(
162  loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
163  wrapperModPortInfo);
164  Value clk = wrapperMod.getBodyBlock()->getArgument(
165  wrapperMod.getBodyBlock()->getNumArguments() - 2);
166  Value rst = wrapperMod.getBodyBlock()->getArgument(
167  wrapperMod.getBodyBlock()->getNumArguments() - 1);
168  SmallVector<Value> clkRes = {clk, rst};
170  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
171  BackedgeBuilder bb(b, loc);
173  // Create backedges for the results of the external module. These will be
174  // replaced by the service instance requests if associated with a memory.
175  llvm::SmallVector<Backedge> backedges;
176  for (auto resType : extMod.getOutputTypes())
177  backedges.push_back(bb.get(resType));
179  // Maintain which index we're currently at in the lowered handshake module's
180  // return.
181  unsigned resIdx = origPorts.sizeOutputs();
183  // Maintain the arguments which each memory will add to the inner module
184  // instance.
185  llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
187  for (auto [i, memType] : argReplacements) {
189  b.setInsertionPoint(wrapperMod);
190  // Create a memory service declaration for each memref argument that was
191  // served.
192  auto origPortInfo = origPorts.atInput(i);
193  auto memrefShape = memType.memRefType.getShape();
194  auto dataType = memType.memRefType.getElementType();
195  assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
196  unsigned memrefSize = memrefShape[0];
197  auto memServiceDecl = b.create<esi::RandomAccessMemoryDeclOp>(
198  loc,, TypeAttr::get(dataType),
199  b.getI64IntegerAttr(memrefSize));
200  esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
201  esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
203  SmallVector<Value> instanceArgsFromThisMem;
205  // Create service requests. This MUST follow the order of which ports were
206  // added in other parts of this pass (load ports first, then store ports).
207  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
209  // Load ports:
210  for (unsigned i = 0; i < memType.loadPorts; ++i) {
211  auto req = b.create<esi::RequestConnectionOp>(
212  loc, readPortInfo.type, readPortInfo.port,
213  esi::AppIDAttr::get(ctx, b.getStringAttr("load"), {}));
214  auto reqUnpack = b.create<esi::UnpackBundleOp>(
215  loc, req.getToClient(), ValueRange{backedges[resIdx]});
216  instanceArgsFromThisMem.push_back(
217  reqUnpack.getToChannels()
218  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
219  ++resIdx;
220  }
222  // Store ports:
223  for (unsigned i = 0; i < memType.storePorts; ++i) {
224  auto req = b.create<esi::RequestConnectionOp>(
225  loc, writePortInfo.type, writePortInfo.port,
226  esi::AppIDAttr::get(ctx, b.getStringAttr("store"), {}));
227  auto reqUnpack = b.create<esi::UnpackBundleOp>(
228  loc, req.getToClient(), ValueRange{backedges[resIdx]});
229  instanceArgsFromThisMem.push_back(
230  reqUnpack.getToChannels()
231  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
232  ++resIdx;
233  }
235  instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
236  }
238  // Stitch together arguments from the top-level ESI wrapper and the instance
239  // arguments generated from the service requests.
240  llvm::SmallVector<Value> instanceArgs;
242  // Iterate over the arguments of the original handshake.func and determine
243  // whether to grab operands from the arg replacements or the wrapper module.
244  unsigned wrapperArgIdx = 0;
246  for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
247  // Arg replacement indices refer to the original handshake.func argument
248  // index.
249  if (argReplacements.count(i)) {
250  // This index was originally a memref - pop the instance arguments for the
251  // next-in-line memory and add them.
252  auto &memArgs = instanceArgsForMem.front();
253  instanceArgs.append(memArgs.begin(), memArgs.end());
254  instanceArgsForMem.erase(instanceArgsForMem.begin());
255  } else {
256  // Add the argument from the wrapper mod. This is maintained by its own
257  // counter (memref arguments are removed, so if there was an argument at
258  // this point, it needs to come from the wrapper module).
259  instanceArgs.push_back(
260  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
261  }
262  }
264  // Add any missing arguments from the wrapper module (this will be clock and
265  // reset)
266  for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
267  ++wrapperArgIdx)
268  instanceArgs.push_back(
269  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
271  // Instantiate the inner module.
272  auto instance =
273  b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
275  // And resolve the backedges.
276  for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
277  be.setValue(res);
279  // Finally, grab the (non-memory) outputs from the inner module and return
280  // them through the wrapper.
281  auto outputOp =
282  cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
283  b.setInsertionPoint(outputOp);
284  b.create<hw::OutputOp>(
285  outputOp.getLoc(),
286  instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
287  outputOp.erase();
289  return success();
290 }
292 // Truncates the index-typed 'v' into an integer-type of the same width as the
293 // 'memref' argument.
294 // Uses arith operations since these are supported in the HandshakeToHW
295 // lowering.
296 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
297  MemRefType memRefType) {
298  assert(v.getType().isa<IndexType>() && "Expected an index-typed value");
299  auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
300  return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
301 }
303 static Value plumbLoadPort(Location loc, OpBuilder &b,
304  handshake::MemLoadInterface &ldif, Value loadData,
305  MemRefType memrefType) {
306  // We need to feed both the load data and the load done outputs.
307  // Fork the extracted load data into two, and 'join' the second one to
308  // generate a none-typed output to drive the load done.
309  auto dataFork = b.create<ForkOp>(loc, loadData, 2);
311  auto dataOut = dataFork.getResult()[0];
312  llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
313  auto dataDone = b.create<JoinOp>(loc, joinArgs);
315  ldif.dataOut.replaceAllUsesWith(dataOut);
316  ldif.doneOut.replaceAllUsesWith(dataDone);
318  // Return load address, to be fed to the top-level output, truncated to the
319  // width of the memory that is accessed.
320  return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
321 }
323 static Value plumbStorePort(Location loc, OpBuilder &b,
324  handshake::MemStoreInterface &stif, Value done,
325  Type outType, MemRefType memrefType) {
326  stif.doneOut.replaceAllUsesWith(done);
327  // Return the store address and data to be fed to the top-level output.
328  // Address is truncated to the width of the memory that is accessed.
329  llvm::SmallVector<Value> structArgs = {
330  truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
332  return b
333  .create<hw::StructCreateOp>(loc, outType.cast<hw::StructType>(),
334  structArgs)
335  .getResult();
336 }
338 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
339  StringRef attrVal) {
340  auto *ctx = op->getContext();
341  llvm::SmallVector<Attribute> newArr;
342  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
343  std::back_inserter(newArr));
344  newArr.push_back(StringAttr::get(ctx, attrVal));
345  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
346 }
348 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
349  StringRef attrVal, unsigned idx) {
350  auto *ctx = op->getContext();
351  llvm::SmallVector<Attribute> newArr;
352  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353  std::back_inserter(newArr));
354  newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
355  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
356 }
358 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
359  unsigned idx) {
360  auto *ctx = op->getContext();
361  llvm::SmallVector<Attribute> newArr;
362  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363  std::back_inserter(newArr));
364  newArr.erase(newArr.begin() + idx);
365  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
366 }
368 struct ArgTypeReplacement {
369  unsigned index;
370  TypeRange ins;
371  TypeRange outs;
372 };
374 LogicalResult
375 HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
376  // Gather memref ports to be converted. This is an ordered map, and will be
377  // iterated from lo to hi indices.
378  std::map<unsigned, Value> memrefArgs;
379  for (auto [i, arg] : llvm::enumerate(func.getArguments()))
380  if (arg.getType().isa<MemRefType>())
381  memrefArgs[i] = arg;
383  if (memrefArgs.empty())
384  return success(); // nothing to do.
386  // Record which arg indices were replaces with handshake memory ports.
387  // This is an ordered map, and will be iterated from lo to hi indices.
388  std::map<unsigned, HandshakeMemType> argReplacements;
390  // Record the hw.module i/o of the original func (used for ESI wrapper).
391  auto origPortInfo = handshake::getPortInfoForOpTypes(
392  func, func.getArgumentTypes(), func.getResultTypes());
394  OpBuilder b(func);
395  for (auto it : memrefArgs) {
396  // Do not use structured bindings for 'it' - cannot reference inside lambda.
397  unsigned i = it.first;
398  auto arg = it.second;
399  auto loc = arg.getLoc();
400  // Get the attached extmemory external module.
401  auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
402  b.setInsertionPoint(extmemOp);
404  // Add memory input - this is the output of the extmemory op.
405  auto memIOTypes = getMemTypeForExtmem(arg);
406  MemRefType memrefType = arg.getType().cast<MemRefType>();
408  auto oldReturnOp =
409  cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
410  llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
411  unsigned addedInPorts = 0;
412  auto memName = func.getArgName(i);
413  auto addArgRes = [&](unsigned id, NamedType &argType, NamedType &resType) {
414  // Function argument
415  unsigned newArgIdx = i + addedInPorts;
416  func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
417  insertInStringArrayAttr(func, "argNames",
418  memName.str() + "_" + argType.first.str(),
419  newArgIdx);
420  auto newInPort = func.getArgument(newArgIdx);
421  ++addedInPorts;
423  // Function result.
424  func.insertResult(func.getNumResults(), resType.second, {});
425  appendToStringArrayAttr(func, "resNames",
426  memName.str() + "_" + resType.first.str());
427  return newInPort;
428  };
430  // Plumb load ports.
431  unsigned portIdx = 0;
432  for (auto loadPort : extmemOp.getLoadPorts()) {
433  auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
434  memIOTypes.outputTypes[portIdx]);
435  newReturnOperands.push_back(
436  plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
437  ++portIdx;
438  }
440  // Plumb store ports.
441  for (auto storePort : extmemOp.getStorePorts()) {
442  auto newInPort =
443  addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
444  memIOTypes.outputTypes[portIdx]);
445  newReturnOperands.push_back(
446  plumbStorePort(loc, b, storePort, newInPort,
447  memIOTypes.outputTypes[portIdx].second, memrefType));
448  ++portIdx;
449  }
451  // Replace the return op of the function with a new one that returns the
452  // memory output struct.
453  b.setInsertionPoint(oldReturnOp);
454  b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
455  oldReturnOp.erase();
457  // Erase the extmemory operation since I/O plumbing has replaced all of its
458  // results.
459  extmemOp.erase();
461  // Erase the original memref argument of the top-level i/o now that it's
462  // use has been removed.
463  func.eraseArgument(i + addedInPorts);
464  eraseFromArrayAttr(func, "argNames", i + addedInPorts);
466  argReplacements[i] = memIOTypes;
467  }
469  if (createESIWrapper)
470  if (failed(wrapESI(func, origPortInfo, argReplacements)))
471  return failure();
473  return success();
474 }
476 } // namespace
478 std::unique_ptr<mlir::Pass>
480  std::optional<bool> createESIWrapper) {
481  return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
482 }
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
def name(self)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
static constexpr const char * kPredeclarationAttr
Definition: HandshakeToHW.h:37
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see for more details.
Definition: DebugAnalysis.h:21