22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Rewrite/FrozenRewritePatternSet.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/Support/Debug.h"
31#define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
/// A (port name, port type) pair describing one port of a lowered memory.
40using NamedType = std::pair<StringAttr, Type>;
/// Port layout of an external memory once its memref function argument has
/// been replaced by explicit port arguments and results (filled in by
/// getMemTypeForExtmem).
41struct HandshakeMemType {
// inputTypes become new function arguments (load data, store done);
// outputTypes become new function results (load address, store
// {address, data} struct). Load-port entries precede store-port entries.
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
// Type of the memref argument being replaced.
43 MemRefType memRefType;
// Port counts, copied from the handshake.extmemory op's load/store counts.
44 unsigned loadPorts, storePorts;
/// Builds the port names for load port `idx`: "ld<idx>.data" names the data
/// coming back from memory (first member) and "ld<idx>.addr" names the address
/// sent to memory (second member).
51 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
52 return {StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".data"),
53 StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".addr")};
/// Builds the port names for store port `idx`: "st<idx>.done" names the
/// completion signal coming back from memory (first member) and "st<idx>"
/// names the outgoing {address, data} store request (second member).
61 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
62 return {StringAttr::get(ctx,
"st" + std::to_string(idx) +
".done"),
63 StringAttr::get(ctx,
"st" + std::to_string(idx))};
// Map an `index` address type to an integer type just wide enough to address
// every element of the 1-D memref: ceil(log2(number of elements)) bits.
70 assert(isa<IndexType>(t) &&
"Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 &&
"Expected 1D memref");
// NOTE(review): a single-element memref yields a 0-bit address type, and a
// dynamic dimension is not rejected before being used as a size here —
// confirm callers only pass static shapes.
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74 return IntegerType::get(t.getContext(), addrWidth);
// Derive the lowered port layout of memref value `v` from the
// handshake.extmemory op that consumes it.
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) &&
"Value is not a memref type");
// NOTE(review): assumes `v` has at least one user and that the first user is
// the handshake.extmemory op. With no users, dereferencing begin() is
// undefined; with a different first user, cast<> asserts.
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
// NOTE(review): inFields/outFields look unused — confirm and remove.
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
// Each load port contributes a data input (typed like the port's data output)
// and an address output. Load entries are pushed before store entries.
90 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91 auto names = LoadName::get(ctx, i);
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
99 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100 auto names = StoreNames::get(ctx, i);
// Each store port contributes an output struct {address, data} and a done
// input. Despite its name, `inType` below is the type of that output struct.
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
105 {StringAttr::get(ctx,
"address"),
107 storeOutFields.push_back(
108 {StringAttr::get(ctx,
"data"), stif.dataIn.getType()});
109 auto inType = hw::StructType::get(ctx, storeOutFields);
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
/// Pass that replaces memref arguments backed by handshake.extmemory ops with
/// explicit load/store port arguments and results, and can optionally wrap the
/// lowered function in an ESI wrapper module.
118struct HandshakeLowerExtmemToHWPass
119 :
public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
// Override the pass option only when the caller supplied a value.
// NOTE(review): this single-argument constructor is not marked `explicit`.
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
// Applies lowerExtmemToHW to the functions under the root operation
// (presumably every handshake function).
127 void runOnOperation()
override {
128 auto op = getOperation();
130 if (failed(lowerExtmemToHW(func))) {
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
/// Wraps the lowered handshake function `func` for ESI integration:
///  - declares an external module `__<func>_hw` with the lowered function's
///    ports and records its symbol on `func` under kPredeclarationAttr;
///  - builds a `<func>_esi_wrapper` module whose ports are `origPorts` minus
///    the memref inputs whose indices are keys of `argReplacements`;
///  - for each memory, creates a declaration named after the original port,
///    typed by the memref element type and sized by its element count
///    (presumably an ESI random-access-memory declaration), plus one ESI
///    request per port (AppID "load" / "store");
///  - instantiates the external module: each request's response channel
///    feeds an instance input, and backedges standing in for the instance
///    results feed the requests.
143LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
148 auto loc = func.getLoc();
152 b.setInsertionPoint(func);
153 auto newPortInfo = handshake::getPortInfoForOpTypes(
154 func, func.getArgumentTypes(), func.getResultTypes());
156 loc, StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
160 func->setAttr(kPredeclarationAttr,
161 FlatSymbolRefAttr::get(ctx, extMod.getName()));
// The wrapper keeps the original ports except the memref inputs, which are
// served inside the wrapper instead. Inputs are erased in descending index
// order (std::map keys ascend; we iterate in reverse) so the indices not yet
// erased stay valid.
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](
auto &pair) {
return pair.first; });
169 for (
auto i :
llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
172 loc, StringAttr::get(ctx, func.getName() +
"_esi_wrapper"),
// The last two wrapper body arguments are taken as clock and reset.
174 Value
clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {
clk, rst};
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// One placeholder per external-module output. The instance does not exist
// yet, so backedges let the requests consume its results early.
185 llvm::SmallVector<Backedge> backedges;
186 for (
auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.
get(resType));
// For each memory argument, the values to pass to the instance's memory-port
// inputs. Entries are appended in ascending argument-index order, because
// std::map iterates its keys in ascending order.
195 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
197 for (
auto [i, memType] : argReplacements) {
// The memory declaration goes at module scope, just before the wrapper.
199 b.setInsertionPoint(wrapperMod);
202 auto origPortInfo = origPorts.
atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
208 loc, origPortInfo.
name, TypeAttr::get(dataType),
209 b.getI64IntegerAttr(memrefSize));
213 SmallVector<Value> instanceArgsFromThisMem;
// Service requests are created inside the wrapper body.
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// Load ports: each request's response channel becomes an instance input.
220 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
222 loc, readPortInfo.
type, readPortInfo.
port,
223 esi::AppIDAttr::get(ctx, b.getStringAttr(
"load"), {}));
224 auto reqUnpack = b.create<esi::UnpackBundleOp>(
225 loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
// Store ports: handled the same way as load ports.
233 for (
unsigned i = 0; i < memType.storePorts; ++i) {
235 loc, writePortInfo.type, writePortInfo.port,
236 esi::AppIDAttr::get(ctx, b.getStringAttr(
"store"), {}));
237 auto reqUnpack = b.create<esi::UnpackBundleOp>(
238 loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
245 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
250 llvm::SmallVector<Value> instanceArgs;
254 unsigned wrapperArgIdx = 0;
// Build the instance operands. Memory operands are consumed front-to-back
// from instanceArgsForMem. This is correct only because this loop and the
// argReplacements loop above both visit memory arguments in ascending
// index order.
256 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
259 if (argReplacements.count(i)) {
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
// Forward any remaining wrapper arguments (presumably clock/reset) unchanged.
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
283 b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
// Presumably resolves each backedge to the matching instance result.
286 for (
auto [res, be] :
llvm::zip(instance.getResults(), backedges))
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
// The wrapper outputs are the leading instance results. The trailing results
// (presumably the memory-port outputs appended during lowering) reach the
// memories only through the backedges.
293 b.setInsertionPoint(outputOp);
294 b.create<hw::OutputOp>(
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
/// Casts the index-typed address `v` to an integer just wide enough to address
/// every element of `memRefType`. The width is ceil(log2) of the first
/// dimension, the same width indexToMemAddr computes for the memory's address
/// ports.
306static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) &&
"Expected an index-typed value");
// Only the first dimension is used; 1-D memrefs are asserted elsewhere.
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
/// Connects load port `ldif` to `loadData`, the loaded value arriving on a new
/// function argument. The data is forked into two copies:
///  - one replaces the port's data output;
///  - the other, through a join, replaces its done output, so a load counts as
///    done once its data has arrived.
/// Returns the port's address, truncated to the memory address width, for the
/// caller to emit as a function result.
313static Value plumbLoadPort(Location loc, OpBuilder &b,
315 MemRefType memrefType) {
319 auto dataFork = b.create<ForkOp>(loc, loadData, 2);
321 auto dataOut = dataFork.getResult()[0];
322 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323 auto dataDone = b.create<JoinOp>(loc, joinArgs);
325 ldif.
dataOut.replaceAllUsesWith(dataOut);
326 ldif.
doneOut.replaceAllUsesWith(dataDone);
330 return truncateToMemoryWidth(loc, b, ldif.
addressIn, memrefType);
/// Connects store port `stif`. Its done output is replaced by `done`, the
/// store-completion signal arriving on a new function argument. Its truncated
/// address and its data are gathered as the fields of an `outType` value (the
/// {address, data} struct), which the caller returns from the function.
333static Value plumbStorePort(Location loc, OpBuilder &b,
335 Type outType, MemRefType memrefType) {
336 stif.
doneOut.replaceAllUsesWith(done);
// Field order must match the struct built in getMemTypeForExtmem:
// "address" first, then "data".
339 llvm::SmallVector<Value> structArgs = {
340 truncateToMemoryWidth(loc, b, stif.
addressIn, memrefType), stif.
dataIn};
/// Appends the string `attrVal` to the ArrayAttr named `attrName` on `op`.
/// Attributes are immutable, so a new array is built and set in place of the
/// old one.
/// NOTE(review): the attribute must already exist. Otherwise getAttrOfType
/// returns a null attribute and calling getValue() on it is invalid.
348static void appendToStringArrayAttr(Operation *op, StringRef attrName,
350 auto *ctx = op->getContext();
351 llvm::SmallVector<Attribute> newArr;
352 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353 std::back_inserter(newArr));
354 newArr.push_back(StringAttr::get(ctx, attrVal));
355 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
/// Inserts the string `attrVal` at position `idx` (which must not exceed the
/// current length) of the ArrayAttr named `attrName` on `op`, replacing the
/// attribute with the rebuilt array.
/// NOTE(review): the attribute must already exist. Otherwise getAttrOfType
/// returns a null attribute and calling getValue() on it is invalid.
358static void insertInStringArrayAttr(Operation *op, StringRef attrName,
359 StringRef attrVal,
unsigned idx) {
360 auto *ctx = op->getContext();
361 llvm::SmallVector<Attribute> newArr;
362 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363 std::back_inserter(newArr));
364 newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
365 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
/// Removes the element at position `idx` (which must be in range) from the
/// ArrayAttr named `attrName` on `op`, replacing the attribute with the
/// rebuilt array.
/// NOTE(review): the attribute must already exist. Otherwise getAttrOfType
/// returns a null attribute and calling getValue() on it is invalid.
368static void eraseFromArrayAttr(Operation *op, StringRef attrName,
370 auto *ctx = op->getContext();
371 llvm::SmallVector<Attribute> newArr;
372 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
373 std::back_inserter(newArr));
374 newArr.erase(newArr.begin() + idx);
375 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
// NOTE(review): lowerExtmemToHW tracks replacements with HandshakeMemType
// (std::map<unsigned, HandshakeMemType>), not this struct — confirm
// ArgTypeReplacement is still used anywhere.
378struct ArgTypeReplacement {
// Replace each memref argument of `func` with explicit port arguments and
// results. Each memref is expected to be consumed by a handshake.extmemory op.
// First, collect the memref arguments keyed by argument index.
388 std::map<unsigned, Value> memrefArgs;
389 for (
auto [i, arg] :
llvm::enumerate(func.getArguments()))
390 if (isa<MemRefType>(arg.getType()))
393 if (memrefArgs.empty())
// Port layout of each lowered memref, keyed by the index recorded above.
398 std::map<unsigned, HandshakeMemType> argReplacements;
// Capture the port list before the signature changes; the ESI wrapper is
// built from these original ports.
401 auto origPortInfo = handshake::getPortInfoForOpTypes(
402 func, func.getArgumentTypes(), func.getResultTypes());
// NOTE(review): `i` is the index recorded before any argument was lowered.
// Each iteration inserts `addedInPorts` arguments in front of the memref and
// then erases the memref. Later arguments therefore shift by
// (addedInPorts - 1) per lowered memref. With several memref arguments, if an
// earlier one had other than exactly one port, `i` (and getArgName(i)) no
// longer points at this memref. Confirm this, e.g. by using the argument's
// current position (cast<BlockArgument>(arg).getArgNumber()) instead.
405 for (
auto it : memrefArgs) {
407 unsigned i = it.first;
408 auto arg = it.second;
409 auto loc = arg.getLoc();
// NOTE(review): assumes the memref's first user is its handshake.extmemory op.
// A memref with no users would dereference an empty range.
411 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
412 b.setInsertionPoint(extmemOp);
416 MemRefType memrefType = cast<MemRefType>(arg.getType());
419 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
420 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
// Number of port arguments inserted so far for this memref. They go in front
// of it, so (assuming this is incremented per inserted argument) the memref
// argument itself sits at i + addedInPorts.
421 unsigned addedInPorts = 0;
422 auto memName = func.getArgName(i);
// Adds one port:
//  - a new argument named "<mem>_<port>", inserted at the memref's current
//    position;
//  - a new result named likewise, appended after all existing results.
// Yields the new argument value, which the caller uses to plumb the port.
423 auto addArgRes = [&](
unsigned id, NamedType &argType, NamedType &resType) {
425 unsigned newArgIdx = i + addedInPorts;
426 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
427 insertInStringArrayAttr(func,
"argNames",
428 memName.str() +
"_" + argType.first.str(),
430 auto newInPort = func.getArgument(newArgIdx);
434 func.insertResult(func.getNumResults(), resType.second, {});
435 appendToStringArrayAttr(func,
"resNames",
436 memName.str() +
"_" + resType.first.str());
// Load ports are handled first, then store ports. This matches the order in
// which getMemTypeForExtmem filled inputTypes/outputTypes, assuming portIdx
// keeps counting across both loops.
441 unsigned portIdx = 0;
443 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
444 memIOTypes.outputTypes[portIdx]);
445 newReturnOperands.push_back(
446 plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
453 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
454 memIOTypes.outputTypes[portIdx]);
455 newReturnOperands.push_back(
456 plumbStorePort(loc, b, storePort, newInPort,
457 memIOTypes.outputTypes[portIdx].second, memrefType));
// Replace the return with one that also returns the new port outputs.
463 b.setInsertionPoint(oldReturnOp);
464 b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
// The memref argument has been pushed back by the inserted port arguments.
473 func.eraseArgument(i + addedInPorts);
474 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
476 argReplacements[i] = memIOTypes;
// Optionally wrap the lowered function for ESI, built from the original ports.
479 if (createESIWrapper)
480 if (failed(wrapESI(func, origPortInfo, argReplacements)))
/// Creates the extmem-to-HW lowering pass. If `createESIWrapper` holds a value,
/// it overrides the pass's createESIWrapper option; otherwise the option is
/// left unchanged.
488std::unique_ptr<mlir::Pass>
490 std::optional<bool> createESIWrapper) {
491 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Describes a service port.
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & atInput(size_t idx)
size_t sizeOutputs() const