CIRCT  19.0.0git
LowerExtmemToHW.cpp
Go to the documentation of this file.
1 //===- LowerExtmemToHW.cpp - Lower extmem to HW pass ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the lower extmem pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
16 #include "circt/Dialect/HW/HWOps.h"
21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 
27 using namespace circt;
28 using namespace handshake;
29 using namespace mlir;
30 namespace {
31 using NamedType = std::pair<StringAttr, Type>;
32 struct HandshakeMemType {
33  llvm::SmallVector<NamedType> inputTypes, outputTypes;
34  MemRefType memRefType;
35  unsigned loadPorts, storePorts;
36 };
37 
38 struct LoadName {
39  StringAttr dataIn;
40  StringAttr addrOut;
41 
42  static LoadName get(MLIRContext *ctx, unsigned idx) {
43  return {StringAttr::get(ctx, "ld" + std::to_string(idx) + ".data"),
44  StringAttr::get(ctx, "ld" + std::to_string(idx) + ".addr")};
45  }
46 };
47 
48 struct StoreNames {
49  StringAttr doneIn;
50  StringAttr out;
51 
52  static StoreNames get(MLIRContext *ctx, unsigned idx) {
53  return {StringAttr::get(ctx, "st" + std::to_string(idx) + ".done"),
54  StringAttr::get(ctx, "st" + std::to_string(idx))};
55  }
56 };
57 
58 } // namespace
59 
60 static Type indexToMemAddr(Type t, MemRefType memRef) {
61  assert(isa<IndexType>(t) && "Expected index type");
62  auto shape = memRef.getShape();
63  assert(shape.size() == 1 && "Expected 1D memref");
64  unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
65  return IntegerType::get(t.getContext(), addrWidth);
66 }
67 
68 static HandshakeMemType getMemTypeForExtmem(Value v) {
69  auto *ctx = v.getContext();
70  assert(isa<mlir::MemRefType>(v.getType()) && "Value is not a memref type");
71  auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
72  HandshakeMemType memType;
73  llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
74 
75  // Add memory type.
76  memType.memRefType = cast<MemRefType>(v.getType());
77  memType.loadPorts = extmemOp.getLdCount();
78  memType.storePorts = extmemOp.getStCount();
79 
80  // Add load ports.
81  for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
82  auto names = LoadName::get(ctx, i);
83  memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
84  memType.outputTypes.push_back(
85  {names.addrOut,
86  indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
87  }
88 
89  // Add store ports.
90  for (auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
91  auto names = StoreNames::get(ctx, i);
92 
93  // Incoming store data and address
94  llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
95  storeOutFields.push_back(
96  {StringAttr::get(ctx, "address"),
97  indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
98  storeOutFields.push_back(
99  {StringAttr::get(ctx, "data"), stif.dataIn.getType()});
100  auto inType = hw::StructType::get(ctx, storeOutFields);
101  memType.outputTypes.push_back({names.out, inType});
102  memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
103  }
104 
105  return memType;
106 }
107 
108 namespace {
109 struct HandshakeLowerExtmemToHWPass
110  : public HandshakeLowerExtmemToHWBase<HandshakeLowerExtmemToHWPass> {
111 
112  HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
113  if (createESIWrapper)
114  this->createESIWrapper = *createESIWrapper;
115  }
116 
117  void runOnOperation() override {
118  auto op = getOperation();
119  for (auto func : op.getOps<handshake::FuncOp>()) {
120  if (failed(lowerExtmemToHW(func))) {
121  signalPassFailure();
122  return;
123  }
124  }
125  };
126 
127  LogicalResult lowerExtmemToHW(handshake::FuncOp func);
128  LogicalResult
129  wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
130  const std::map<unsigned, HandshakeMemType> &argReplacements);
131 };
132 
133 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
134  handshake::FuncOp func, hw::ModulePortInfo origPorts,
135  const std::map<unsigned, HandshakeMemType> &argReplacements) {
136  auto *ctx = func.getContext();
137  OpBuilder b(func);
138  auto loc = func.getLoc();
139 
140  // Create external module which will match the interface of 'func' after it's
141  // been lowered to HW.
142  b.setInsertionPoint(func);
143  auto newPortInfo = handshake::getPortInfoForOpTypes(
144  func, func.getArgumentTypes(), func.getResultTypes());
145  auto extMod = b.create<hw::HWModuleExternOp>(
146  loc, StringAttr::get(ctx, "__" + func.getName() + "_hw"), newPortInfo);
147 
148  // Add an attribute to the original handshake function to indicate that it
149  // needs to resolve to extMod in a later pass.
150  func->setAttr(kPredeclarationAttr,
151  FlatSymbolRefAttr::get(ctx, extMod.getName()));
152 
153  // Create wrapper module. This will have the same ports as the original
154  // module, sans the replaced arguments.
155  auto wrapperModPortInfo = origPorts;
156  llvm::SmallVector<unsigned> argReplacementsIdxs;
157  llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
158  [](auto &pair) { return pair.first; });
159  for (auto i : llvm::reverse(argReplacementsIdxs))
160  wrapperModPortInfo.eraseInput(i);
161  auto wrapperMod = b.create<hw::HWModuleOp>(
162  loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
163  wrapperModPortInfo);
164  Value clk = wrapperMod.getBodyBlock()->getArgument(
165  wrapperMod.getBodyBlock()->getNumArguments() - 2);
166  Value rst = wrapperMod.getBodyBlock()->getArgument(
167  wrapperMod.getBodyBlock()->getNumArguments() - 1);
168  SmallVector<Value> clkRes = {clk, rst};
169 
170  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
171  BackedgeBuilder bb(b, loc);
172 
173  // Create backedges for the results of the external module. These will be
174  // replaced by the service instance requests if associated with a memory.
175  llvm::SmallVector<Backedge> backedges;
176  for (auto resType : extMod.getOutputTypes())
177  backedges.push_back(bb.get(resType));
178 
179  // Maintain which index we're currently at in the lowered handshake module's
180  // return.
181  unsigned resIdx = origPorts.sizeOutputs();
182 
183  // Maintain the arguments which each memory will add to the inner module
184  // instance.
185  llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
186 
187  for (auto [i, memType] : argReplacements) {
188 
189  b.setInsertionPoint(wrapperMod);
190  // Create a memory service declaration for each memref argument that was
191  // served.
192  auto origPortInfo = origPorts.atInput(i);
193  auto memrefShape = memType.memRefType.getShape();
194  auto dataType = memType.memRefType.getElementType();
195  assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
196  unsigned memrefSize = memrefShape[0];
197  auto memServiceDecl = b.create<esi::RandomAccessMemoryDeclOp>(
198  loc, origPortInfo.name, TypeAttr::get(dataType),
199  b.getI64IntegerAttr(memrefSize));
200  esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
201  esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
202 
203  SmallVector<Value> instanceArgsFromThisMem;
204 
205  // Create service requests. This MUST follow the order of which ports were
206  // added in other parts of this pass (load ports first, then store ports).
207  b.setInsertionPointToStart(wrapperMod.getBodyBlock());
208 
209  // Load ports:
210  for (unsigned i = 0; i < memType.loadPorts; ++i) {
211  auto req = b.create<esi::RequestConnectionOp>(
212  loc, readPortInfo.type, readPortInfo.port,
213  esi::AppIDAttr::get(ctx, b.getStringAttr("load"), {}));
214  auto reqUnpack = b.create<esi::UnpackBundleOp>(
215  loc, req.getToClient(), ValueRange{backedges[resIdx]});
216  instanceArgsFromThisMem.push_back(
217  reqUnpack.getToChannels()
218  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
219  ++resIdx;
220  }
221 
222  // Store ports:
223  for (unsigned i = 0; i < memType.storePorts; ++i) {
224  auto req = b.create<esi::RequestConnectionOp>(
225  loc, writePortInfo.type, writePortInfo.port,
226  esi::AppIDAttr::get(ctx, b.getStringAttr("store"), {}));
227  auto reqUnpack = b.create<esi::UnpackBundleOp>(
228  loc, req.getToClient(), ValueRange{backedges[resIdx]});
229  instanceArgsFromThisMem.push_back(
230  reqUnpack.getToChannels()
231  [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
232  ++resIdx;
233  }
234 
235  instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
236  }
237 
238  // Stitch together arguments from the top-level ESI wrapper and the instance
239  // arguments generated from the service requests.
240  llvm::SmallVector<Value> instanceArgs;
241 
242  // Iterate over the arguments of the original handshake.func and determine
243  // whether to grab operands from the arg replacements or the wrapper module.
244  unsigned wrapperArgIdx = 0;
245 
246  for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
247  // Arg replacement indices refer to the original handshake.func argument
248  // index.
249  if (argReplacements.count(i)) {
250  // This index was originally a memref - pop the instance arguments for the
251  // next-in-line memory and add them.
252  auto &memArgs = instanceArgsForMem.front();
253  instanceArgs.append(memArgs.begin(), memArgs.end());
254  instanceArgsForMem.erase(instanceArgsForMem.begin());
255  } else {
256  // Add the argument from the wrapper mod. This is maintained by its own
257  // counter (memref arguments are removed, so if there was an argument at
258  // this point, it needs to come from the wrapper module).
259  instanceArgs.push_back(
260  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
261  }
262  }
263 
264  // Add any missing arguments from the wrapper module (this will be clock and
265  // reset)
266  for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
267  ++wrapperArgIdx)
268  instanceArgs.push_back(
269  wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
270 
271  // Instantiate the inner module.
272  auto instance =
273  b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
274 
275  // And resolve the backedges.
276  for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
277  be.setValue(res);
278 
279  // Finally, grab the (non-memory) outputs from the inner module and return
280  // them through the wrapper.
281  auto outputOp =
282  cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
283  b.setInsertionPoint(outputOp);
284  b.create<hw::OutputOp>(
285  outputOp.getLoc(),
286  instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
287  outputOp.erase();
288 
289  return success();
290 }
291 
292 // Truncates the index-typed 'v' into an integer-type of the same width as the
293 // 'memref' argument.
294 // Uses arith operations since these are supported in the HandshakeToHW
295 // lowering.
296 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
297  MemRefType memRefType) {
298  assert(isa<IndexType>(v.getType()) && "Expected an index-typed value");
299  auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
300  return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
301 }
302 
303 static Value plumbLoadPort(Location loc, OpBuilder &b,
304  handshake::MemLoadInterface &ldif, Value loadData,
305  MemRefType memrefType) {
306  // We need to feed both the load data and the load done outputs.
307  // Fork the extracted load data into two, and 'join' the second one to
308  // generate a none-typed output to drive the load done.
309  auto dataFork = b.create<ForkOp>(loc, loadData, 2);
310 
311  auto dataOut = dataFork.getResult()[0];
312  llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
313  auto dataDone = b.create<JoinOp>(loc, joinArgs);
314 
315  ldif.dataOut.replaceAllUsesWith(dataOut);
316  ldif.doneOut.replaceAllUsesWith(dataDone);
317 
318  // Return load address, to be fed to the top-level output, truncated to the
319  // width of the memory that is accessed.
320  return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
321 }
322 
323 static Value plumbStorePort(Location loc, OpBuilder &b,
324  handshake::MemStoreInterface &stif, Value done,
325  Type outType, MemRefType memrefType) {
326  stif.doneOut.replaceAllUsesWith(done);
327  // Return the store address and data to be fed to the top-level output.
328  // Address is truncated to the width of the memory that is accessed.
329  llvm::SmallVector<Value> structArgs = {
330  truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
331 
332  return b
333  .create<hw::StructCreateOp>(loc, cast<hw::StructType>(outType),
334  structArgs)
335  .getResult();
336 }
337 
338 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
339  StringRef attrVal) {
340  auto *ctx = op->getContext();
341  llvm::SmallVector<Attribute> newArr;
342  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
343  std::back_inserter(newArr));
344  newArr.push_back(StringAttr::get(ctx, attrVal));
345  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
346 }
347 
348 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
349  StringRef attrVal, unsigned idx) {
350  auto *ctx = op->getContext();
351  llvm::SmallVector<Attribute> newArr;
352  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353  std::back_inserter(newArr));
354  newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
355  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
356 }
357 
358 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
359  unsigned idx) {
360  auto *ctx = op->getContext();
361  llvm::SmallVector<Attribute> newArr;
362  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363  std::back_inserter(newArr));
364  newArr.erase(newArr.begin() + idx);
365  op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
366 }
367 
368 struct ArgTypeReplacement {
369  unsigned index;
370  TypeRange ins;
371  TypeRange outs;
372 };
373 
374 LogicalResult
375 HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
376  // Gather memref ports to be converted. This is an ordered map, and will be
377  // iterated from lo to hi indices.
378  std::map<unsigned, Value> memrefArgs;
379  for (auto [i, arg] : llvm::enumerate(func.getArguments()))
380  if (isa<MemRefType>(arg.getType()))
381  memrefArgs[i] = arg;
382 
383  if (memrefArgs.empty())
384  return success(); // nothing to do.
385 
386  // Record which arg indices were replaces with handshake memory ports.
387  // This is an ordered map, and will be iterated from lo to hi indices.
388  std::map<unsigned, HandshakeMemType> argReplacements;
389 
390  // Record the hw.module i/o of the original func (used for ESI wrapper).
391  auto origPortInfo = handshake::getPortInfoForOpTypes(
392  func, func.getArgumentTypes(), func.getResultTypes());
393 
394  OpBuilder b(func);
395  for (auto it : memrefArgs) {
396  // Do not use structured bindings for 'it' - cannot reference inside lambda.
397  unsigned i = it.first;
398  auto arg = it.second;
399  auto loc = arg.getLoc();
400  // Get the attached extmemory external module.
401  auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
402  b.setInsertionPoint(extmemOp);
403 
404  // Add memory input - this is the output of the extmemory op.
405  auto memIOTypes = getMemTypeForExtmem(arg);
406  MemRefType memrefType = cast<MemRefType>(arg.getType());
407 
408  auto oldReturnOp =
409  cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
410  llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
411  unsigned addedInPorts = 0;
412  auto memName = func.getArgName(i);
413  auto addArgRes = [&](unsigned id, NamedType &argType, NamedType &resType) {
414  // Function argument
415  unsigned newArgIdx = i + addedInPorts;
416  func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
417  insertInStringArrayAttr(func, "argNames",
418  memName.str() + "_" + argType.first.str(),
419  newArgIdx);
420  auto newInPort = func.getArgument(newArgIdx);
421  ++addedInPorts;
422 
423  // Function result.
424  func.insertResult(func.getNumResults(), resType.second, {});
425  appendToStringArrayAttr(func, "resNames",
426  memName.str() + "_" + resType.first.str());
427  return newInPort;
428  };
429 
430  // Plumb load ports.
431  unsigned portIdx = 0;
432  for (auto loadPort : extmemOp.getLoadPorts()) {
433  auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
434  memIOTypes.outputTypes[portIdx]);
435  newReturnOperands.push_back(
436  plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
437  ++portIdx;
438  }
439 
440  // Plumb store ports.
441  for (auto storePort : extmemOp.getStorePorts()) {
442  auto newInPort =
443  addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
444  memIOTypes.outputTypes[portIdx]);
445  newReturnOperands.push_back(
446  plumbStorePort(loc, b, storePort, newInPort,
447  memIOTypes.outputTypes[portIdx].second, memrefType));
448  ++portIdx;
449  }
450 
451  // Replace the return op of the function with a new one that returns the
452  // memory output struct.
453  b.setInsertionPoint(oldReturnOp);
454  b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
455  oldReturnOp.erase();
456 
457  // Erase the extmemory operation since I/O plumbing has replaced all of its
458  // results.
459  extmemOp.erase();
460 
461  // Erase the original memref argument of the top-level i/o now that it's
462  // use has been removed.
463  func.eraseArgument(i + addedInPorts);
464  eraseFromArrayAttr(func, "argNames", i + addedInPorts);
465 
466  argReplacements[i] = memIOTypes;
467  }
468 
469  if (createESIWrapper)
470  if (failed(wrapESI(func, origPortInfo, argReplacements)))
471  return failure();
472 
473  return success();
474 }
475 
476 } // namespace
477 
478 std::unique_ptr<mlir::Pass>
480  std::optional<bool> createESIWrapper) {
481  return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
482 }
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
def name(self)
Definition: hw.py:195
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
static constexpr const char * kPredeclarationAttr
Definition: HandshakeToHW.h:37
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21