22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Rewrite/FrozenRewritePatternSet.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/Support/Debug.h"
31#define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
/// A port described as a (name, type) pair. The StringAttr is the port name
/// (e.g. "ld0.data") and the Type is the type carried on that port.
40using NamedType = std::pair<StringAttr, Type>;
41struct HandshakeMemType {
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
43 MemRefType memRefType;
44 unsigned loadPorts, storePorts;
51 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
52 return {StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".data"),
53 StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".addr")};
61 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
62 return {StringAttr::get(ctx,
"st" + std::to_string(idx) +
".done"),
63 StringAttr::get(ctx,
"st" + std::to_string(idx))};
70 assert(isa<IndexType>(t) &&
"Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 &&
"Expected 1D memref");
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74 return IntegerType::get(t.getContext(), addrWidth);
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) &&
"Value is not a memref type");
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
90 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91 auto names = LoadName::get(ctx, i);
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
99 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100 auto names = StoreNames::get(ctx, i);
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
105 {StringAttr::get(ctx,
"address"),
107 storeOutFields.push_back(
108 {StringAttr::get(ctx,
"data"), stif.dataIn.getType()});
109 auto inType = hw::StructType::get(ctx, storeOutFields);
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
118struct HandshakeLowerExtmemToHWPass
119 :
public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
127 void runOnOperation()
override {
128 auto op = getOperation();
130 if (failed(lowerExtmemToHW(func))) {
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
143LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
148 auto loc = func.getLoc();
152 b.setInsertionPoint(func);
153 auto newPortInfo = handshake::getPortInfoForOpTypes(
154 func, func.getArgumentTypes(), func.getResultTypes());
155 auto extMod = hw::HWModuleExternOp::create(
156 b, loc, StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
160 func->setAttr(kPredeclarationAttr,
161 FlatSymbolRefAttr::get(ctx, extMod.getName()));
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](
auto &pair) {
return pair.first; });
169 for (
auto i :
llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
171 auto wrapperMod = hw::HWModuleOp::create(
172 b, loc, StringAttr::get(ctx, func.getName() +
"_esi_wrapper"),
174 Value
clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {
clk, rst};
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
185 llvm::SmallVector<Backedge> backedges;
186 for (
auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.
get(resType));
195 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
197 for (
auto [i, memType] : argReplacements) {
199 b.setInsertionPoint(wrapperMod);
202 auto origPortInfo = origPorts.
atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
207 auto memServiceDecl = esi::RandomAccessMemoryDeclOp::create(
208 b, loc, origPortInfo.name, TypeAttr::get(dataType),
209 b.getI64IntegerAttr(memrefSize));
213 SmallVector<Value> instanceArgsFromThisMem;
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
220 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
221 auto req = esi::RequestConnectionOp::create(
222 b, loc, readPortInfo.
type, readPortInfo.
port,
223 esi::AppIDAttr::get(ctx, b.getStringAttr(
"load"), resIdx));
224 auto reqUnpack = esi::UnpackBundleOp::create(
225 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
233 for (
unsigned i = 0; i < memType.storePorts; ++i) {
234 auto req = esi::RequestConnectionOp::create(
235 b, loc, writePortInfo.
type, writePortInfo.
port,
236 esi::AppIDAttr::get(ctx, b.getStringAttr(
"store"), resIdx));
237 auto reqUnpack = esi::UnpackBundleOp::create(
238 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
245 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
250 llvm::SmallVector<Value> instanceArgs;
254 unsigned wrapperArgIdx = 0;
256 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
259 if (argReplacements.count(i)) {
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
283 hw::InstanceOp::create(b, loc, extMod, func.getName(), instanceArgs);
286 for (
auto [res, be] :
llvm::zip(instance.getResults(), backedges))
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 hw::OutputOp::create(
295 b, outputOp.getLoc(),
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
306static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) &&
"Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 return arith::IndexCastOp::create(b, loc, b.getIntegerType(addrWidth), v);
313static Value plumbLoadPort(Location loc, OpBuilder &b,
315 MemRefType memrefType) {
319 auto dataFork = ForkOp::create(b, loc, loadData, 2);
321 auto dataOut = dataFork.getResult()[0];
322 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323 auto dataDone = JoinOp::create(b, loc, joinArgs);
325 ldif.
dataOut.replaceAllUsesWith(dataOut);
326 ldif.
doneOut.replaceAllUsesWith(dataDone);
330 return truncateToMemoryWidth(loc, b, ldif.
addressIn, memrefType);
333static Value plumbStorePort(Location loc, OpBuilder &b,
335 Type outType, MemRefType memrefType) {
336 stif.
doneOut.replaceAllUsesWith(done);
339 llvm::SmallVector<Value> structArgs = {
340 truncateToMemoryWidth(loc, b, stif.
addressIn, memrefType), stif.
dataIn};
347static void appendToStringArrayAttr(Operation *op, StringRef attrName,
349 auto *ctx = op->getContext();
350 llvm::SmallVector<Attribute> newArr;
351 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
352 std::back_inserter(newArr));
353 newArr.push_back(StringAttr::get(ctx, attrVal));
354 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
357static void insertInStringArrayAttr(Operation *op, StringRef attrName,
358 StringRef attrVal,
unsigned idx) {
359 auto *ctx = op->getContext();
360 llvm::SmallVector<Attribute> newArr;
361 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
362 std::back_inserter(newArr));
363 newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
364 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
367static void eraseFromArrayAttr(Operation *op, StringRef attrName,
369 auto *ctx = op->getContext();
370 llvm::SmallVector<Attribute> newArr;
371 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
372 std::back_inserter(newArr));
373 newArr.erase(newArr.begin() + idx);
374 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
377struct ArgTypeReplacement {
387 std::map<unsigned, Value> memrefArgs;
388 for (
auto [i, arg] :
llvm::enumerate(func.getArguments()))
389 if (isa<MemRefType>(arg.getType()))
392 if (memrefArgs.empty())
397 std::map<unsigned, HandshakeMemType> argReplacements;
400 auto origPortInfo = handshake::getPortInfoForOpTypes(
401 func, func.getArgumentTypes(), func.getResultTypes());
404 for (
auto it : memrefArgs) {
406 unsigned i = it.first;
407 auto arg = it.second;
408 auto loc = arg.getLoc();
410 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
411 b.setInsertionPoint(extmemOp);
415 MemRefType memrefType = cast<MemRefType>(arg.getType());
418 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
419 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
420 unsigned addedInPorts = 0;
421 auto memName = func.getArgName(i);
422 auto addArgRes = [&](
unsigned id, NamedType &argType,
423 NamedType &resType) -> FailureOr<Value> {
425 unsigned newArgIdx = i + addedInPorts;
427 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc())))
429 insertInStringArrayAttr(func,
"argNames",
430 memName.str() +
"_" + argType.first.str(),
432 auto newInPort = func.getArgument(newArgIdx);
436 if (failed(func.insertResult(func.getNumResults(), resType.second, {})))
438 appendToStringArrayAttr(func,
"resNames",
439 memName.str() +
"_" + resType.first.str());
444 unsigned portIdx = 0;
446 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
447 memIOTypes.outputTypes[portIdx]);
448 if (failed(newInPort))
450 newReturnOperands.push_back(
451 plumbLoadPort(loc, b, loadPort, *newInPort, memrefType));
458 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
459 memIOTypes.outputTypes[portIdx]);
460 if (failed(newInPort))
462 newReturnOperands.push_back(
463 plumbStorePort(loc, b, storePort, *newInPort,
464 memIOTypes.outputTypes[portIdx].second, memrefType));
470 b.setInsertionPoint(oldReturnOp);
471 ReturnOp::create(b, arg.getLoc(), newReturnOperands);
480 if (failed(func.eraseArgument(i + addedInPorts)))
482 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
484 argReplacements[i] = memIOTypes;
487 if (createESIWrapper)
488 if (failed(wrapESI(func, origPortInfo, argReplacements)))
496std::unique_ptr<mlir::Pass>
498 std::optional<bool> createESIWrapper) {
499 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
create(elements, Type result_type=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Describes a service port.
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & atInput(size_t idx)
size_t sizeOutputs() const