CIRCT  20.0.0git
LowerExtmemToHW.cpp
Go to the documentation of this file.
1 //===- LowerExtmemToHW.cpp - Lower extmem to HW pass ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the lower extmem pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "circt/Dialect/HW/HWOps.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/Debug.h"
28 
29 namespace circt {
30 namespace handshake {
31 #define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
33 } // namespace handshake
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace handshake;
38 using namespace mlir;
39 namespace {
40 using NamedType = std::pair<StringAttr, Type>;
41 struct HandshakeMemType {
42  llvm::SmallVector<NamedType> inputTypes, outputTypes;
43  MemRefType memRefType;
44  unsigned loadPorts, storePorts;
45 };
46 
47 struct LoadName {
48  StringAttr dataIn;
49  StringAttr addrOut;
50 
51  static LoadName get(MLIRContext *ctx, unsigned idx) {
52  return {StringAttr::get(ctx, "ld" + std::to_string(idx) + ".data"),
53  StringAttr::get(ctx, "ld" + std::to_string(idx) + ".addr")};
54  }
55 };
56 
57 struct StoreNames {
58  StringAttr doneIn;
59  StringAttr out;
60 
61  static StoreNames get(MLIRContext *ctx, unsigned idx) {
62  return {StringAttr::get(ctx, "st" + std::to_string(idx) + ".done"),
63  StringAttr::get(ctx, "st" + std::to_string(idx))};
64  }
65 };
66 
67 } // namespace
68 
69 static Type indexToMemAddr(Type t, MemRefType memRef) {
70  assert(isa<IndexType>(t) && "Expected index type");
71  auto shape = memRef.getShape();
72  assert(shape.size() == 1 && "Expected 1D memref");
73  unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74  return IntegerType::get(t.getContext(), addrWidth);
75 }
76 
77 static HandshakeMemType getMemTypeForExtmem(Value v) {
78  auto *ctx = v.getContext();
79  assert(isa<mlir::MemRefType>(v.getType()) && "Value is not a memref type");
80  auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81  HandshakeMemType memType;
82  llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
83 
84  // Add memory type.
85  memType.memRefType = cast<MemRefType>(v.getType());
86  memType.loadPorts = extmemOp.getLdCount();
87  memType.storePorts = extmemOp.getStCount();
88 
89  // Add load ports.
90  for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91  auto names = LoadName::get(ctx, i);
92  memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93  memType.outputTypes.push_back(
94  {names.addrOut,
95  indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
96  }
97 
98  // Add store ports.
99  for (auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100  auto names = StoreNames::get(ctx, i);
101 
102  // Incoming store data and address
103  llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104  storeOutFields.push_back(
105  {StringAttr::get(ctx, "address"),
106  indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
107  storeOutFields.push_back(
108  {StringAttr::get(ctx, "data"), stif.dataIn.getType()});
109  auto inType = hw::StructType::get(ctx, storeOutFields);
110  memType.outputTypes.push_back({names.out, inType});
111  memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
112  }
113 
114  return memType;
115 }
116 
117 namespace {
118 struct HandshakeLowerExtmemToHWPass
119  : public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120  HandshakeLowerExtmemToHWPass> {
121 
122  HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123  if (createESIWrapper)
124  this->createESIWrapper = *createESIWrapper;
125  }
126 
127  void runOnOperation() override {
128  auto op = getOperation();
129  for (auto func : op.getOps<handshake::FuncOp>()) {
130  if (failed(lowerExtmemToHW(func))) {
131  signalPassFailure();
132  return;
133  }
134  }
135  };
136 
137  LogicalResult lowerExtmemToHW(handshake::FuncOp func);
138  LogicalResult
139  wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
140  const std::map<unsigned, HandshakeMemType> &argReplacements);
141 };
142 
143 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
144  handshake::FuncOp func, hw::ModulePortInfo origPorts,
145  const std::map<unsigned, HandshakeMemType> &argReplacements) {
146  auto *ctx = func.getContext();
147  OpBuilder b(func);
148  auto loc = func.getLoc();
149 
150  // Create external module which will match the interface of 'func' after it's
151  // been lowered to HW.
152  b.setInsertionPoint(func);
153  auto newPortInfo = handshake::getPortInfoForOpTypes(
154  func, func.getArgumentTypes(), func.getResultTypes());
155  auto extMod = b.create<hw::HWModuleExternOp>(
156  loc, StringAttr::get(ctx, "__" + func.getName() + "_hw"), newPortInfo);
157 
158  // Add an attribute to the original handshake function to indicate that it
159  // needs to resolve to extMod in a later pass.
160  func->setAttr(kPredeclarationAttr,
161  FlatSymbolRefAttr::get(ctx, extMod.getName()));
162 
163  // Create wrapper module. This will have the same ports as the original
164  // module, sans the replaced arguments.
165  auto wrapperModPortInfo = origPorts;
166  llvm::SmallVector<unsigned> argReplacementsIdxs;
167  llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168  [](auto &pair) { return pair.first; });
169  for (auto i : llvm::reverse(argReplacementsIdxs))
170  wrapperModPortInfo.eraseInput(i);
171  auto wrapperMod = b.create<hw::HWModuleOp>(
172  loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
173  wrapperModPortInfo);
174  Value clk = wrapperMod.getBodyBlock()->getArgument(
175  wrapperMod.getBodyBlock()->getNumArguments() - 2);
176  Value rst = wrapperMod.getBodyBlock()->getArgument(
177  wrapperMod.getBodyBlock()->getNumArguments() - 1);
178  SmallVector<Value> clkRes = {clk, rst};
179 
180  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
181  BackedgeBuilder bb(b, loc);
182 
183  // Create backedges for the results of the external module. These will be
184  // replaced by the service instance requests if associated with a memory.
185  llvm::SmallVector<Backedge> backedges;
186  for (auto resType : extMod.getOutputTypes())
187  backedges.push_back(bb.get(resType));
188 
189  // Maintain which index we're currently at in the lowered handshake module's
190  // return.
191  unsigned resIdx = origPorts.sizeOutputs();
192 
193  // Maintain the arguments which each memory will add to the inner module
194  // instance.
195  llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
196 
197  for (auto [i, memType] : argReplacements) {
198 
199  b.setInsertionPoint(wrapperMod);
200  // Create a memory service declaration for each memref argument that was
201  // served.
202  auto origPortInfo = origPorts.atInput(i);
203  auto memrefShape = memType.memRefType.getShape();
204  auto dataType = memType.memRefType.getElementType();
205  assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
206  unsigned memrefSize = memrefShape[0];
207  auto memServiceDecl = b.create<esi::RandomAccessMemoryDeclOp>(
208  loc, origPortInfo.name, TypeAttr::get(dataType),
209  b.getI64IntegerAttr(memrefSize));
210  esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
211  esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
212 
213  SmallVector<Value> instanceArgsFromThisMem;
214 
215  // Create service requests. This MUST follow the order of which ports were
216  // added in other parts of this pass (load ports first, then store ports).
217  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
218 
219  // Load ports:
220  for (unsigned i = 0; i < memType.loadPorts; ++i) {
221  auto req = b.create<esi::RequestConnectionOp>(
222  loc, readPortInfo.type, readPortInfo.port,
223  esi::AppIDAttr::get(ctx, b.getStringAttr("load"), {}));
224  auto reqUnpack = b.create<esi::UnpackBundleOp>(
225  loc, req.getToClient(), ValueRange{backedges[resIdx]});
226  instanceArgsFromThisMem.push_back(
227  reqUnpack.getToChannels()
228  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
229  ++resIdx;
230  }
231 
232  // Store ports:
233  for (unsigned i = 0; i < memType.storePorts; ++i) {
234  auto req = b.create<esi::RequestConnectionOp>(
235  loc, writePortInfo.type, writePortInfo.port,
236  esi::AppIDAttr::get(ctx, b.getStringAttr("store"), {}));
237  auto reqUnpack = b.create<esi::UnpackBundleOp>(
238  loc, req.getToClient(), ValueRange{backedges[resIdx]});
239  instanceArgsFromThisMem.push_back(
240  reqUnpack.getToChannels()
241  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
242  ++resIdx;
243  }
244 
245  instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
246  }
247 
248  // Stitch together arguments from the top-level ESI wrapper and the instance
249  // arguments generated from the service requests.
250  llvm::SmallVector<Value> instanceArgs;
251 
252  // Iterate over the arguments of the original handshake.func and determine
253  // whether to grab operands from the arg replacements or the wrapper module.
254  unsigned wrapperArgIdx = 0;
255 
256  for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
257  // Arg replacement indices refer to the original handshake.func argument
258  // index.
259  if (argReplacements.count(i)) {
260  // This index was originally a memref - pop the instance arguments for the
261  // next-in-line memory and add them.
262  auto &memArgs = instanceArgsForMem.front();
263  instanceArgs.append(memArgs.begin(), memArgs.end());
264  instanceArgsForMem.erase(instanceArgsForMem.begin());
265  } else {
266  // Add the argument from the wrapper mod. This is maintained by its own
267  // counter (memref arguments are removed, so if there was an argument at
268  // this point, it needs to come from the wrapper module).
269  instanceArgs.push_back(
270  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
271  }
272  }
273 
274  // Add any missing arguments from the wrapper module (this will be clock and
275  // reset)
276  for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
277  ++wrapperArgIdx)
278  instanceArgs.push_back(
279  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
280 
281  // Instantiate the inner module.
282  auto instance =
283  b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
284 
285  // And resolve the backedges.
286  for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
287  be.setValue(res);
288 
289  // Finally, grab the (non-memory) outputs from the inner module and return
290  // them through the wrapper.
291  auto outputOp =
292  cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293  b.setInsertionPoint(outputOp);
294  b.create<hw::OutputOp>(
295  outputOp.getLoc(),
296  instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
297  outputOp.erase();
298 
299  return success();
300 }
301 
302 // Truncates the index-typed 'v' into an integer-type of the same width as the
303 // 'memref' argument.
304 // Uses arith operations since these are supported in the HandshakeToHW
305 // lowering.
306 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307  MemRefType memRefType) {
308  assert(isa<IndexType>(v.getType()) && "Expected an index-typed value");
309  auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310  return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
311 }
312 
313 static Value plumbLoadPort(Location loc, OpBuilder &b,
314  handshake::MemLoadInterface &ldif, Value loadData,
315  MemRefType memrefType) {
316  // We need to feed both the load data and the load done outputs.
317  // Fork the extracted load data into two, and 'join' the second one to
318  // generate a none-typed output to drive the load done.
319  auto dataFork = b.create<ForkOp>(loc, loadData, 2);
320 
321  auto dataOut = dataFork.getResult()[0];
322  llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323  auto dataDone = b.create<JoinOp>(loc, joinArgs);
324 
325  ldif.dataOut.replaceAllUsesWith(dataOut);
326  ldif.doneOut.replaceAllUsesWith(dataDone);
327 
328  // Return load address, to be fed to the top-level output, truncated to the
329  // width of the memory that is accessed.
330  return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
331 }
332 
333 static Value plumbStorePort(Location loc, OpBuilder &b,
334  handshake::MemStoreInterface &stif, Value done,
335  Type outType, MemRefType memrefType) {
336  stif.doneOut.replaceAllUsesWith(done);
337  // Return the store address and data to be fed to the top-level output.
338  // Address is truncated to the width of the memory that is accessed.
339  llvm::SmallVector<Value> structArgs = {
340  truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
341 
342  return b
343  .create<hw::StructCreateOp>(loc, cast<hw::StructType>(outType),
344  structArgs)
345  .getResult();
346 }
347 
348 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
349  StringRef attrVal) {
350  auto *ctx = op->getContext();
351  llvm::SmallVector<Attribute> newArr;
352  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353  std::back_inserter(newArr));
354  newArr.push_back(StringAttr::get(ctx, attrVal));
355  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
356 }
357 
358 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
359  StringRef attrVal, unsigned idx) {
360  auto *ctx = op->getContext();
361  llvm::SmallVector<Attribute> newArr;
362  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363  std::back_inserter(newArr));
364  newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
365  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
366 }
367 
368 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
369  unsigned idx) {
370  auto *ctx = op->getContext();
371  llvm::SmallVector<Attribute> newArr;
372  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
373  std::back_inserter(newArr));
374  newArr.erase(newArr.begin() + idx);
375  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
376 }
377 
378 struct ArgTypeReplacement {
379  unsigned index;
380  TypeRange ins;
381  TypeRange outs;
382 };
383 
384 LogicalResult
385 HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
386  // Gather memref ports to be converted. This is an ordered map, and will be
387  // iterated from lo to hi indices.
388  std::map<unsigned, Value> memrefArgs;
389  for (auto [i, arg] : llvm::enumerate(func.getArguments()))
390  if (isa<MemRefType>(arg.getType()))
391  memrefArgs[i] = arg;
392 
393  if (memrefArgs.empty())
394  return success(); // nothing to do.
395 
396  // Record which arg indices were replaces with handshake memory ports.
397  // This is an ordered map, and will be iterated from lo to hi indices.
398  std::map<unsigned, HandshakeMemType> argReplacements;
399 
400  // Record the hw.module i/o of the original func (used for ESI wrapper).
401  auto origPortInfo = handshake::getPortInfoForOpTypes(
402  func, func.getArgumentTypes(), func.getResultTypes());
403 
404  OpBuilder b(func);
405  for (auto it : memrefArgs) {
406  // Do not use structured bindings for 'it' - cannot reference inside lambda.
407  unsigned i = it.first;
408  auto arg = it.second;
409  auto loc = arg.getLoc();
410  // Get the attached extmemory external module.
411  auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
412  b.setInsertionPoint(extmemOp);
413 
414  // Add memory input - this is the output of the extmemory op.
415  auto memIOTypes = getMemTypeForExtmem(arg);
416  MemRefType memrefType = cast<MemRefType>(arg.getType());
417 
418  auto oldReturnOp =
419  cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
420  llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
421  unsigned addedInPorts = 0;
422  auto memName = func.getArgName(i);
423  auto addArgRes = [&](unsigned id, NamedType &argType, NamedType &resType) {
424  // Function argument
425  unsigned newArgIdx = i + addedInPorts;
426  func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
427  insertInStringArrayAttr(func, "argNames",
428  memName.str() + "_" + argType.first.str(),
429  newArgIdx);
430  auto newInPort = func.getArgument(newArgIdx);
431  ++addedInPorts;
432 
433  // Function result.
434  func.insertResult(func.getNumResults(), resType.second, {});
435  appendToStringArrayAttr(func, "resNames",
436  memName.str() + "_" + resType.first.str());
437  return newInPort;
438  };
439 
440  // Plumb load ports.
441  unsigned portIdx = 0;
442  for (auto loadPort : extmemOp.getLoadPorts()) {
443  auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
444  memIOTypes.outputTypes[portIdx]);
445  newReturnOperands.push_back(
446  plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
447  ++portIdx;
448  }
449 
450  // Plumb store ports.
451  for (auto storePort : extmemOp.getStorePorts()) {
452  auto newInPort =
453  addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
454  memIOTypes.outputTypes[portIdx]);
455  newReturnOperands.push_back(
456  plumbStorePort(loc, b, storePort, newInPort,
457  memIOTypes.outputTypes[portIdx].second, memrefType));
458  ++portIdx;
459  }
460 
461  // Replace the return op of the function with a new one that returns the
462  // memory output struct.
463  b.setInsertionPoint(oldReturnOp);
464  b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
465  oldReturnOp.erase();
466 
467  // Erase the extmemory operation since I/O plumbing has replaced all of its
468  // results.
469  extmemOp.erase();
470 
471  // Erase the original memref argument of the top-level i/o now that it's
472  // use has been removed.
473  func.eraseArgument(i + addedInPorts);
474  eraseFromArrayAttr(func, "argNames", i + addedInPorts);
475 
476  argReplacements[i] = memIOTypes;
477  }
478 
479  if (createESIWrapper)
480  if (failed(wrapESI(func, origPortInfo, argReplacements)))
481  return failure();
482 
483  return success();
484 }
485 
486 } // namespace
487 
488 std::unique_ptr<mlir::Pass>
490  std::optional<bool> createESIWrapper) {
491  return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
492 }
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
FuncOp create(Union[StringAttr, str] sym_name, List[Tuple[str, Type]] args, List[Tuple[str, Type]] results, Dict[str, Attribute] attributes={}, loc=None, ip=None)
Definition: handshake.py:36
def name(self)
Definition: hw.py:329
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21