22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/Debug.h"
31 #define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
36 using namespace circt;
37 using namespace handshake;
// A named port: the port's name paired with its type.
40 using NamedType = std::pair<StringAttr, Type>;
// Describes the port interface that replaces one memref argument of a
// handshake function. It is filled in by getMemTypeForExtmem from the
// handshake.extmemory op attached to that argument, and lowerExtmemToHW stores
// one per replaced argument (keyed by argument index) in argReplacements.
41 struct HandshakeMemType {
// New function inputs/outputs for the memory ports. getMemTypeForExtmem
// appends all load-port entries first, then all store-port entries;
// lowerExtmemToHW indexes both lists with the same portIdx.
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
// Type of the memref argument being replaced.
43 MemRefType memRefType;
// Number of load and store ports, copied from the extmemory op's ld/st counts.
44 unsigned loadPorts, storePorts;
// Factory for load port `idx`.
// NOTE(review): the enclosing struct and the body are not visible here;
// presumably this builds the port names (e.g. `dataIn`) that
// getMemTypeForExtmem uses for load ports — confirm.
51 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
// Factory for store port `idx`.
// NOTE(review): body not visible here; presumably builds the port names
// (e.g. `out`, `doneIn`) that getMemTypeForExtmem uses for store ports —
// confirm.
61 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
// Body fragment of indexToMemAddr(Type t, MemRefType memRef): computes the
// address width needed for an index-typed address into a 1D memref.
// NOTE(review): the return statement is not visible here; presumably an
// integer type of width `addrWidth` is returned — confirm.
// Only index-typed addresses and 1D memrefs are accepted.
70 assert(isa<IndexType>(t) &&
"Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 &&
"Expected 1D memref");
// Address width is ceil(log2(number of elements)).
// NOTE(review): a single-element memref yields a width of 0.
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
// Body fragment of getMemTypeForExtmem(Value v): derives the HandshakeMemType
// for memref value `v` from the handshake.extmemory op that uses it.
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) &&
"Value is not a memref type");
// Assumes `v` has at least one user and that its first user is an
// ExternalMemoryOp (cast<> asserts on a type mismatch; dereferencing the
// begin iterator of an empty user list is invalid).
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
// Load ports are recorded first. Each adds one input (`dataIn`, typed like
// the extmemory op's data output for that port) and one output.
90 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
// Store ports are recorded after all load ports. Each adds one output
// (`out`, presumably the struct type built from storeOutFields; its
// construction is not fully visible here) and one input (`doneIn`, typed
// like the extmemory op's done output for that port).
99 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
107 storeOutFields.push_back(
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
// Pass that replaces the memref arguments of each handshake.func (and the
// handshake.extmemory ops using them) with explicit per-port arguments and
// results. When the createESIWrapper option is set, it also builds an ESI
// wrapper module around the lowered function.
118 struct HandshakeLowerExtmemToHWPass
119 :
public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
// The pass option is overwritten only when a value is explicitly provided.
// Otherwise it keeps the value set by the generated base class.
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
// Lowers every handshake.func found directly in the operation's body.
// NOTE(review): the failure handling after lowerExtmemToHW is not visible
// here; presumably the pass is signalled as failed — confirm.
127 void runOnOperation()
override {
128 auto op = getOperation();
129 for (
auto func : op.getOps<handshake::FuncOp>()) {
130 if (failed(lowerExtmemToHW(func))) {
// Lowers the memref arguments of a single function.
137 LogicalResult lowerExtmemToHW(handshake::FuncOp func);
// Builds the ESI wrapper for a function already processed by lowerExtmemToHW.
// `origPorts` are the function's ports before lowering; `argReplacements`
// maps original argument indices to the memory interface that replaced them.
139 wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
// Builds a wrapper hw.module around the lowered function.
// - The wrapper takes the function's original ports, minus the replaced
//   memref inputs.
// - Each replaced memref is backed by an ESI random-access memory service.
// - The lowered function (`extMod`) is instantiated inside the wrapper, and
//   its memory ports are wired to service requests.
143 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
144 handshake::FuncOp func, hw::ModulePortInfo origPorts,
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
148 auto loc = func.getLoc();
152 b.setInsertionPoint(func);
// Port info of the already-lowered function. A module named "__<func>_hw" is
// created from it (presumably the `extMod` instantiated below; the creating
// op is not visible here).
154 func, func.getArgumentTypes(), func.getResultTypes());
156 loc,
StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
// The wrapper keeps the original ports minus the memref inputs, which the
// memory services now provide. Inputs are erased from the highest index
// down so the remaining indices stay valid.
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](
auto &pair) {
return pair.first; });
169 for (
auto i : llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
// Clock and reset are taken to be the last two arguments of the wrapper body.
174 Value
clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {
clk, rst};
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// One backedge per extMod result. The service requests below consume these
// before the instance producing the real results exists; they are resolved
// once the instance is created.
185 llvm::SmallVector<Backedge> backedges;
186 for (
auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.get(resType));
// extMod results after the original outputs are the memory ports' outputs,
// and resIdx walks them. NOTE(review): the increments of resIdx are not
// visible here.
191 unsigned resIdx = origPorts.sizeOutputs();
// Per replaced memref, in ascending argument order (std::map is ordered):
// the service response values to pass to the instance.
195 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
197 for (
auto [i, memType] : argReplacements) {
// The memory declaration is created right before the wrapper module.
199 b.setInsertionPoint(wrapperMod);
202 auto origPortInfo = origPorts.atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
// Presumably creates the memory declaration `memServiceDecl`, sized by
// memrefSize; its creation is only partly visible here.
209 b.getI64IntegerAttr(memrefSize));
210 esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
211 esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
213 SmallVector<Value> instanceArgsFromThisMem;
// Service requests are created inside the wrapper body.
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// Per load port: a read-port request whose bundle is unpacked. The backedge
// for the matching extMod result feeds the request side, and the response
// channel becomes an instance argument.
220 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
222 loc, readPortInfo.type, readPortInfo.port,
224 auto reqUnpack = b.create<esi::UnpackBundleOp>(
225 loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
// Per store port: the same scheme, using the memory's write port.
233 for (
unsigned i = 0; i < memType.storePorts; ++i) {
235 loc, writePortInfo.type, writePortInfo.port,
237 auto reqUnpack = b.create<esi::UnpackBundleOp>(
238 loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
245 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
// Assemble the instance operands in argument order.
250 llvm::SmallVector<Value> instanceArgs;
254 unsigned wrapperArgIdx = 0;
// NOTE(review): `func` has already been rewritten by lowerExtmemToHW, so
// func.getNumArguments() includes the new memory ports. The keys of
// argReplacements, however, are indices into the original signature.
// Confirm that this bound and the argReplacements.count(i) lookup agree
// when a memref was replaced by more (or fewer) than one port.
256 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
259 if (argReplacements.count(i)) {
// Splice in this memref's service responses. instanceArgsForMem is
// consumed front-to-back, matching the ascending order it was built in.
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
// Any other argument is forwarded from the wrapper's own arguments.
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
// Forward the remaining wrapper arguments (presumably clk/rst).
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
283 b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
// Pair each instance result with its backedge (loop body not visible here;
// presumably it resolves the backedge to that result).
286 for (
auto [res, be] : llvm::zip(instance.getResults(), backedges))
// A new hw.output is created before the existing terminator. It returns only
// the leading instance results, one per wrapper output port. Removal of the
// old terminator is not visible here.
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 b.create<hw::OutputOp>(
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
// Casts the index-typed address `v` to an integer just wide enough to address
// every element along the leading dimension of `memRefType`. The width is
// ceil(log2(size)).
// NOTE(review): a single-element memref yields a 0-bit address. Unlike
// indexToMemAddr, this helper does not itself assert that the memref is 1D.
306 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) &&
"Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
// Rewires one extmemory load port onto the function's new ports.
// - The loaded data `loadData` (lowerExtmemToHW passes the newly inserted
//   argument here) is forked in two.
// - One copy replaces the extmemory op's data output.
// - The other copy goes through a join and replaces its done output
//   (presumably to turn the data token into a control-only token).
// Returns the load address, truncated to the memory's address width;
// lowerExtmemToHW appends it as a new function result.
313 static Value plumbLoadPort(Location loc, OpBuilder &b,
315 MemRefType memrefType) {
319 auto dataFork = b.create<ForkOp>(loc, loadData, 2);
321 auto dataOut = dataFork.getResult()[0];
322 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323 auto dataDone = b.create<JoinOp>(loc, joinArgs);
325 ldif.
dataOut.replaceAllUsesWith(dataOut);
326 ldif.
doneOut.replaceAllUsesWith(dataDone);
330 return truncateToMemoryWidth(loc, b, ldif.
addressIn, memrefType);
// Rewires one extmemory store port onto the function's new ports.
// - `done` (presumably the newly inserted argument) replaces the extmemory
//   op's done output.
// - The truncated store address and the store data are collected as the
//   operands of a value of type `outType`; its construction is not visible
//   here.
// lowerExtmemToHW appends the returned value as a new function result.
333 static Value plumbStorePort(Location loc, OpBuilder &b,
335 Type outType, MemRefType memrefType) {
336 stif.
doneOut.replaceAllUsesWith(done);
339 llvm::SmallVector<Value> structArgs = {
340 truncateToMemoryWidth(loc, b, stif.
addressIn, memrefType), stif.
dataIn};
// Appends a string to the ArrayAttr attribute `attrName` of `op`.
// Attributes are immutable, so the existing elements are first copied into a
// new vector. The append and the write-back are not visible here.
// NOTE(review): assumes the attribute exists and is an ArrayAttr;
// getAttrOfType returns a null attribute otherwise.
348 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
350 auto *ctx = op->getContext();
351 llvm::SmallVector<Attribute> newArr;
352 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353 std::back_inserter(newArr));
// Inserts the string `attrVal` at position `idx` of the ArrayAttr attribute
// `attrName` of `op`. The existing elements are first copied into a new
// vector; the insertion and the write-back are not visible here.
// NOTE(review): assumes the attribute exists and is an ArrayAttr, and that
// `idx` is at most its current size.
358 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
359 StringRef attrVal,
unsigned idx) {
360 auto *ctx = op->getContext();
361 llvm::SmallVector<Attribute> newArr;
362 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363 std::back_inserter(newArr));
// Removes the element at position `idx` from the ArrayAttr attribute
// `attrName` of `op`. The elements are copied, the entry is erased from the
// copy, and the copy is written back (write-back not visible here).
// NOTE(review): `idx` is not bounds-checked; it must be below the array size.
368 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
370 auto *ctx = op->getContext();
371 llvm::SmallVector<Attribute> newArr;
372 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
373 std::back_inserter(newArr));
374 newArr.erase(newArr.begin() + idx);
// NOTE(review): the members of this struct, and any use of it, are not
// visible in this excerpt.
378 struct ArgTypeReplacement {
// Replaces each memref argument of `func` (and the handshake.extmemory op
// using it) with explicit ports. For every load/store port of the extmemory
// op, one argument is inserted where the memref was and one result is
// appended. The memref argument is then erased. When createESIWrapper is
// set, wrapESI is run on the result.
385 HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
// Memref arguments keyed by their original index (std::map keeps them
// sorted by index).
388 std::map<unsigned, Value> memrefArgs;
389 for (
auto [i, arg] : llvm::enumerate(func.getArguments()))
390 if (isa<MemRefType>(arg.getType()))
393 if (memrefArgs.empty())
// Original argument index -> memory interface that replaced it; passed to
// wrapESI.
398 std::map<unsigned, HandshakeMemType> argReplacements;
// Port info captured before any rewriting (passed to wrapESI as
// origPortInfo).
402 func, func.getArgumentTypes(), func.getResultTypes());
// NOTE(review): memrefArgs is walked in ascending index order, but `i` is
// the index in the *original* signature. Each processed memref inserts one
// argument per port and erases one. For a later memref, `i` can therefore be
// stale, which affects getArgName(i), the insertion index and the
// eraseArgument index. This holds unless it is compensated in lines not
// shown here. Iterating llvm::reverse(memrefArgs) would keep the lower
// indices valid — confirm.
405 for (
auto it : memrefArgs) {
407 unsigned i = it.first;
408 auto arg = it.second;
409 auto loc = arg.getLoc();
// The extmemory op is assumed to be the first user of the memref argument.
411 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
412 b.setInsertionPoint(extmemOp);
416 MemRefType memrefType = cast<MemRefType>(arg.getType());
// The return op is rebuilt with one extra operand per memory port.
419 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
420 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
421 unsigned addedInPorts = 0;
422 auto memName = func.getArgName(i);
// addArgRes does three things:
// - inserts a new argument at i + addedInPorts, named "<mem>_<port>";
// - appends a new result, named "<mem>_<port>";
// - yields the new argument.
// The increment of addedInPorts and the return statement are not visible
// here.
423 auto addArgRes = [&](
unsigned id, NamedType &argType, NamedType &resType) {
425 unsigned newArgIdx = i + addedInPorts;
426 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
427 insertInStringArrayAttr(func,
"argNames",
428 memName.str() +
"_" + argType.first.str(),
430 auto newInPort = func.getArgument(newArgIdx);
434 func.insertResult(func.getNumResults(), resType.second, {});
435 appendToStringArrayAttr(func,
"resNames",
436 memName.str() +
"_" + resType.first.str());
// memIOTypes (presumably from getMemTypeForExtmem) lists the load ports
// first, then the store ports; portIdx indexes both lists.
441 unsigned portIdx = 0;
442 for (
auto loadPort : extmemOp.getLoadPorts()) {
443 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
444 memIOTypes.outputTypes[portIdx]);
445 newReturnOperands.push_back(
446 plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
// NOTE(review): in the visible lines, `newInPort` (used below) is not bound
// in this loop, and the load loop's binding is out of scope here. Confirm
// that addArgRes's result is assigned to it.
451 for (
auto storePort : extmemOp.getStorePorts()) {
453 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
454 memIOTypes.outputTypes[portIdx]);
455 newReturnOperands.push_back(
456 plumbStorePort(loc, b, storePort, newInPort,
457 memIOTypes.outputTypes[portIdx].second, memrefType));
463 b.setInsertionPoint(oldReturnOp);
464 b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
// The new arguments were inserted in front of the memref argument, which
// now sits at i + addedInPorts.
473 func.eraseArgument(i + addedInPorts);
474 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
476 argReplacements[i] = memIOTypes;
479 if (createESIWrapper)
480 if (failed(wrapESI(func, origPortInfo, argReplacements)))
// Creates the pass. If `createESIWrapper` holds a value, it overrides the
// pass's createESIWrapper option.
488 std::unique_ptr<mlir::Pass>
490 std::optional<bool> createESIWrapper) {
491 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.