22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Rewrite/FrozenRewritePatternSet.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/Support/Debug.h"
31#define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
/// A port name paired with its type. HandshakeMemType keeps its input and
/// output port lists as vectors of these.
40using NamedType = std::pair<StringAttr, Type>;
41struct HandshakeMemType {
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
43 MemRefType memRefType;
44 unsigned loadPorts, storePorts;
51 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
52 return {StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".data"),
53 StringAttr::get(ctx,
"ld" + std::to_string(idx) +
".addr")};
61 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
62 return {StringAttr::get(ctx,
"st" + std::to_string(idx) +
".done"),
63 StringAttr::get(ctx,
"st" + std::to_string(idx))};
70 assert(isa<IndexType>(t) &&
"Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 &&
"Expected 1D memref");
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74 return IntegerType::get(t.getContext(), addrWidth);
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) &&
"Value is not a memref type");
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
90 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91 auto names = LoadName::get(ctx, i);
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
99 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100 auto names = StoreNames::get(ctx, i);
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
105 {StringAttr::get(ctx,
"address"),
107 storeOutFields.push_back(
108 {StringAttr::get(ctx,
"data"), stif.dataIn.getType()});
109 auto inType = hw::StructType::get(ctx, storeOutFields);
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
118struct HandshakeLowerExtmemToHWPass
119 :
public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
127 void runOnOperation()
override {
128 auto op = getOperation();
130 if (failed(lowerExtmemToHW(func))) {
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
143LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
148 auto loc = func.getLoc();
152 b.setInsertionPoint(func);
153 auto newPortInfo = handshake::getPortInfoForOpTypes(
154 func, func.getArgumentTypes(), func.getResultTypes());
155 auto extMod = hw::HWModuleExternOp::create(
156 b, loc, StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
160 func->setAttr(kPredeclarationAttr,
161 FlatSymbolRefAttr::get(ctx, extMod.getName()));
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](
auto &pair) {
return pair.first; });
169 for (
auto i :
llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
171 auto wrapperMod = hw::HWModuleOp::create(
172 b, loc, StringAttr::get(ctx, func.getName() +
"_esi_wrapper"),
174 Value
clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {
clk, rst};
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
185 llvm::SmallVector<Backedge> backedges;
186 for (
auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.
get(resType));
195 llvm::SmallVector<llvm::SmallVector<Value>> instanceArgsForMem;
197 for (
auto [i, memType] : argReplacements) {
199 b.setInsertionPoint(wrapperMod);
202 auto origPortInfo = origPorts.
atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
207 auto memServiceDecl = esi::RandomAccessMemoryDeclOp::create(
208 b, loc, origPortInfo.name, TypeAttr::get(dataType),
209 b.getI64IntegerAttr(memrefSize));
213 SmallVector<Value> instanceArgsFromThisMem;
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
220 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
221 auto req = esi::RequestConnectionOp::create(
222 b, loc, readPortInfo.
type, readPortInfo.
port,
223 esi::AppIDAttr::get(ctx,
b.getStringAttr(
"load"), resIdx));
224 auto reqUnpack = esi::UnpackBundleOp::create(
225 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
233 for (
unsigned i = 0; i < memType.storePorts; ++i) {
234 auto req = esi::RequestConnectionOp::create(
235 b, loc, writePortInfo.
type, writePortInfo.
port,
236 esi::AppIDAttr::get(ctx,
b.getStringAttr(
"store"), resIdx));
237 auto reqUnpack = esi::UnpackBundleOp::create(
238 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
245 instanceArgsForMem.emplace_back(std::move(instanceArgsFromThisMem));
250 llvm::SmallVector<Value> instanceArgs;
254 unsigned wrapperArgIdx = 0;
256 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
259 if (argReplacements.count(i)) {
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
283 hw::InstanceOp::create(b, loc, extMod, func.getName(), instanceArgs);
286 for (
auto [res, be] :
llvm::zip(instance.getResults(), backedges))
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 hw::OutputOp::create(
295 b, outputOp.getLoc(),
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
/// Converts an index-typed address value `v` into an integer of exactly the
/// width needed to address `memRefType`. That width is ceil(log2(first
/// dimension)) and is computed with llvm::Log2_64_Ceil.
/// - If the width is 0, a handshake::JoinOp is built over `v` and an i0-typed
///   handshake::ConstantOp with value 0 is returned.
///   NOTE(review): the join result `ctrl` is presumably the constant's control
///   operand, passed on a line not visible here. Confirm.
/// - Otherwise, `v` is narrowed with arith::IndexCastOp to i<width>.
/// Asserts that `v` has index type.
306static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) &&
"Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
// A single-element memref (shape[0] == 1) gives Log2_64_Ceil(1) == 0, so no
// address bits are needed. The address value is still consumed via the join.
310 if (addrWidth == 0) {
313 auto ctrl = handshake::JoinOp::create(b, loc, v).getResult();
314 return handshake::ConstantOp::create(
315 b, loc,
b.getIntegerType(0),
b.getIntegerAttr(
b.getIntegerType(0), 0),
318 return arith::IndexCastOp::create(b, loc,
b.getIntegerType(addrWidth), v);
/// Rewires one load port (`ldif`) of an external memory onto explicit ports.
/// `loadData` is the incoming load data. The caller passes the function
/// argument it just inserted for this port. The data is forked in two:
/// - one copy replaces every use of `ldif.dataOut`;
/// - the other is joined to produce the token that replaces `ldif.doneOut`.
/// Returns `ldif.addressIn` truncated to the memory's address width. The
/// caller appends this to the function's return operands.
/// NOTE(review): the `ldif`/`loadData` parameter line is not visible here.
321static Value plumbLoadPort(Location loc, OpBuilder &b,
323 MemRefType memrefType) {
// Two copies of the load data: one for the data output, one for the done
// signal.
327 auto dataFork = ForkOp::create(b, loc, loadData, 2);
329 auto dataOut = dataFork.getResult()[0];
330 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
331 auto dataDone = JoinOp::create(b, loc, joinArgs);
// Redirect users of the original port results to the newly built values.
333 ldif.
dataOut.replaceAllUsesWith(dataOut);
334 ldif.
doneOut.replaceAllUsesWith(dataDone);
// The address leaves the function as an output, narrowed to the memory width.
338 return truncateToMemoryWidth(loc, b, ldif.
addressIn, memrefType);
/// Rewires one store port (`stif`) of an external memory onto explicit ports.
/// Visible behavior:
/// - every use of `stif.doneOut` is replaced with `done`;
/// - a value list {address truncated to the memory width, `stif.dataIn`} is
///   built.
/// `outType` is the hw struct type with "address" and "data" fields built by
/// getMemTypeForExtmem. The caller pushes the returned value onto the
/// function's return operands.
/// NOTE(review): `done` is presumably the newly inserted done-input argument
/// passed by the caller. The code that packs `structArgs` into `outType` and
/// returns it is not visible here. Confirm.
341static Value plumbStorePort(Location loc, OpBuilder &b,
343 Type outType, MemRefType memrefType) {
344 stif.
doneOut.replaceAllUsesWith(done);
// Field order must match outType: address first, then data.
347 llvm::SmallVector<Value> structArgs = {
348 truncateToMemoryWidth(loc, b, stif.
addressIn, memrefType), stif.
dataIn};
/// Appends `attrVal`, as a StringAttr, to the end of the ArrayAttr named
/// `attrName` on `op`. The attribute is rebuilt and replaced. This pass uses it
/// to keep the "resNames" list in sync when a result is added.
/// The attribute must already exist as an ArrayAttr. getAttrOfType returns a
/// null attribute otherwise, and no null check is performed.
355static void appendToStringArrayAttr(Operation *op, StringRef attrName,
357 auto *ctx = op->getContext();
358 llvm::SmallVector<Attribute> newArr;
359 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
360 std::back_inserter(newArr));
361 newArr.push_back(StringAttr::get(ctx, attrVal));
362 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
/// Inserts `attrVal`, as a StringAttr, at position `idx` of the ArrayAttr named
/// `attrName` on `op`. The attribute is rebuilt and replaced. This pass uses it
/// to keep "argNames" in sync when a function argument is inserted.
/// The attribute must already exist as an ArrayAttr, and `idx` must be
/// <= its size. Neither is checked here.
365static void insertInStringArrayAttr(Operation *op, StringRef attrName,
366 StringRef attrVal,
unsigned idx) {
367 auto *ctx = op->getContext();
368 llvm::SmallVector<Attribute> newArr;
369 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
370 std::back_inserter(newArr));
371 newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
372 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
/// Removes the element at position `idx` from the ArrayAttr named `attrName` on
/// `op`. The attribute is rebuilt and replaced. This pass uses it to drop the
/// "argNames" entry of an erased memref argument.
/// The attribute must already exist as an ArrayAttr, and `idx` must be < its
/// size. Neither is checked here.
375static void eraseFromArrayAttr(Operation *op, StringRef attrName,
377 auto *ctx = op->getContext();
378 llvm::SmallVector<Attribute> newArr;
379 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
380 std::back_inserter(newArr));
381 newArr.erase(newArr.begin() + idx);
382 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
385struct ArgTypeReplacement {
395 std::map<unsigned, Value> memrefArgs;
396 for (
auto [i, arg] :
llvm::enumerate(func.getArguments()))
397 if (isa<MemRefType>(arg.getType()))
400 if (memrefArgs.empty())
405 std::map<unsigned, HandshakeMemType> argReplacements;
408 auto origPortInfo = handshake::getPortInfoForOpTypes(
409 func, func.getArgumentTypes(), func.getResultTypes());
412 for (
auto it : memrefArgs) {
414 unsigned i = it.first;
415 auto arg = it.second;
416 auto loc = arg.getLoc();
418 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
419 b.setInsertionPoint(extmemOp);
423 MemRefType memrefType = cast<MemRefType>(arg.getType());
426 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
427 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
428 unsigned addedInPorts = 0;
429 auto memName = func.getArgName(i);
430 auto addArgRes = [&](
unsigned id, NamedType &argType,
431 NamedType &resType) -> FailureOr<Value> {
433 unsigned newArgIdx = i + addedInPorts;
435 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc())))
437 insertInStringArrayAttr(func,
"argNames",
438 memName.str() +
"_" + argType.first.str(),
440 auto newInPort = func.getArgument(newArgIdx);
444 if (failed(func.insertResult(func.getNumResults(), resType.second, {})))
446 appendToStringArrayAttr(func,
"resNames",
447 memName.str() +
"_" + resType.first.str());
452 unsigned portIdx = 0;
454 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
455 memIOTypes.outputTypes[portIdx]);
456 if (failed(newInPort))
458 newReturnOperands.push_back(
459 plumbLoadPort(loc, b, loadPort, *newInPort, memrefType));
466 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
467 memIOTypes.outputTypes[portIdx]);
468 if (failed(newInPort))
470 newReturnOperands.push_back(
471 plumbStorePort(loc, b, storePort, *newInPort,
472 memIOTypes.outputTypes[portIdx].second, memrefType));
478 b.setInsertionPoint(oldReturnOp);
479 ReturnOp::create(b, arg.getLoc(), newReturnOperands);
488 if (failed(func.eraseArgument(i + addedInPorts)))
490 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
492 argReplacements[i] = memIOTypes;
495 if (createESIWrapper)
496 if (failed(wrapESI(func, origPortInfo, argReplacements)))
504std::unique_ptr<mlir::Pass>
506 std::optional<bool> createESIWrapper) {
507 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
create(elements, Type result_type=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Describes a service port.
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & atInput(size_t idx)
size_t sizeOutputs() const