22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/Debug.h"
31 #define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
36 using namespace circt;
/// A (port name, port type) pair describing one port that replaces a memref
/// function argument.
40 using NamedType = std::pair<StringAttr, Type>;
/// Describes the memory-interface ports that replace one memref argument
/// connected to a handshake.extmemory op.
41 struct HandshakeMemType {
// Ports added to the function inputs (e.g. load data, store done) and to the
// function outputs (one per load/store port), in the order they are appended:
// all load ports first, then all store ports.
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
// The memref type that is being replaced.
43 MemRefType memRefType;
// Number of load / store ports of the external memory op.
44 unsigned loadPorts, storePorts;
/// Factory for the port names used by load port `idx` (presumably the names
/// referenced as `names.dataIn` etc. when building HandshakeMemType). The
/// body is not visible in this excerpt.
51 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
/// Factory for the port names used by store port `idx` (presumably the names
/// referenced as `names.out` / `names.doneIn` when building
/// HandshakeMemType). The body is not visible in this excerpt.
61 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
// Body fragment of indexToMemAddr(Type t, MemRefType memRef); the signature
// and the return statement are elided from this excerpt.
// Only index-typed addresses into 1-D memrefs are supported.
70 assert(isa<IndexType>(t) &&
"Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 &&
"Expected 1D memref");
// The address width is ceil(log2(number of elements)).
// NOTE(review): llvm::Log2_64_Ceil(1) == 0, so a single-element memref gets a
// zero-width address. The code also assumes a static shape (a dynamic
// dimension would not give a meaningful width). Confirm both are acceptable.
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
// Body fragment of getMemTypeForExtmem(Value v): computes the
// HandshakeMemType for the memref value `v`. Several lines are elided from
// this excerpt.
// Requires `v` to have at least one user. That first user must be its
// handshake.extmemory op, and cast<> asserts that it is.
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) &&
"Value is not a memref type");
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
// Record the memref type and the port counts of the external memory op.
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
// Load ports are appended first. Each adds a data input (typed like the load
// port's dataOut) and one output (its construction is elided here).
// Store ports follow. Each adds an output built from two struct fields
// (storeOutFields) and a done input typed like the store port's doneOut.
90 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
99 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
107 storeOutFields.push_back(
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
/// Pass that replaces the memref arguments of handshake functions (each
/// connected to a handshake.extmemory op) with explicit load/store port
/// arguments and results. It can optionally emit an ESI wrapper module
/// around the lowered function.
118 struct HandshakeLowerExtmemToHWPass
119 :
public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
// Overrides the `createESIWrapper` pass option only when a value is given.
// Otherwise the option keeps its default/command-line value.
// NOTE(review): single-argument constructor is not `explicit`.
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
// Runs lowerExtmemToHW on the handshake functions of the operation and
// fails the pass if any lowering fails (the iteration is elided here).
127 void runOnOperation()
override {
128 auto op = getOperation();
130 if (failed(lowerExtmemToHW(func))) {
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
/// Builds an ESI wrapper around the lowered function `func`.
///
/// - The lowered function is declared as a module named `__<func>_hw`
///   (`extMod`).
/// - A wrapper hw module exposes the original ports (`origPorts`), minus the
///   memref inputs listed in `argReplacements`.
/// - Inside the wrapper, each replaced memref is backed by an ESI
///   random-access memory declaration: one read request per load port and one
///   write request per store port.
/// - `extMod` is instantiated with the memory response channels in place of
///   the memref arguments. Its memory-port outputs drive the requests through
///   backedges.
///
/// Several lines of this function are elided from this excerpt.
143 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
148 auto loc = func.getLoc();
// Declare the lowered function as `__<func>_hw`, placed before `func`
// (the op kind is on an elided line; presumably an external module).
152 b.setInsertionPoint(func);
154 func, func.getArgumentTypes(), func.getResultTypes());
156 loc,
StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
// The wrapper exposes the original ports minus the replaced memref inputs.
// The indices come from an ordered map (ascending) and are erased in
// reverse, so no erasure shifts an index that is still to be erased.
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](
auto &pair) {
return pair.first; });
169 for (
auto i : llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
// Clock and reset are taken to be the last two wrapper body arguments.
174 Value
clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {
clk, rst};
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// One backedge per output of the lowered module. They stand in for the
// instance results, which are created further below.
185 llvm::SmallVector<Backedge> backedges;
186 for (
auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.get(resType));
// The lowered module's memory-port outputs follow its original outputs.
191 unsigned resIdx = origPorts.sizeOutputs();
// For each replaced memref: the values passed to the instance in its place.
195 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
197 for (
auto [i, memType] : argReplacements) {
// The memory declaration is emitted in front of the wrapper module.
199 b.setInsertionPoint(wrapperMod);
202 auto origPortInfo = origPorts.atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
// Only 1-D memrefs are supported; the memory depth is the memref size.
205 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
209 b.getI64IntegerAttr(memrefSize));
210 esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
211 esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
213 SmallVector<Value> instanceArgsFromThisMem;
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// One read request per load port. The request channel is fed from the
// matching lowered-module output (backedge). The response channel becomes an
// instance argument. Advancing resIdx happens on an elided line.
220 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
222 loc, readPortInfo.type, readPortInfo.port,
224 auto reqUnpack = b.create<esi::UnpackBundleOp>(
225 loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
// One write request per store port, wired the same way as the reads.
233 for (
unsigned i = 0; i < memType.storePorts; ++i) {
235 loc, writePortInfo.type, writePortInfo.port,
237 auto reqUnpack = b.create<esi::UnpackBundleOp>(
238 loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
245 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
// Assemble the instance operands in argument order. Replaced memref
// arguments contribute their memory response channels. Every other argument
// maps to the next wrapper argument.
250 llvm::SmallVector<Value> instanceArgs;
254 unsigned wrapperArgIdx = 0;
256 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
259 if (argReplacements.count(i)) {
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
// Any remaining wrapper arguments are appended last (presumably clock/reset).
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
// Instantiate the lowered module under the original function's name.
283 b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
// Resolve the backedges with the real instance results.
286 for (
auto [res, be] : llvm::zip(instance.getResults(), backedges))
// Replace the wrapper terminator so it returns only the first
// getNumOutputPorts() instance results. The remaining results are the
// memory-port outputs consumed through the backedges.
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 b.create<hw::OutputOp>(
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
/// Casts the index-typed address `v` to an integer just wide enough to
/// address every element of `memRefType`: ceil(log2(first dimension)) bits.
/// NOTE(review): unlike indexToMemAddr, this does not assert that the memref
/// is 1-D; it only reads the first dimension. A single-element memref yields
/// a zero-width (i0) address, since llvm::Log2_64_Ceil(1) == 0.
306 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) &&
"Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
/// Connects load port `ldif` to the new load-data input `loadData`.
///
/// `loadData` is forked in two:
/// - one copy replaces all uses of the port's data output;
/// - the other drives a join whose result replaces all uses of the port's
///   done output, so "done" follows the arrival of the data.
///
/// Returns the load address truncated to the memory's address width. The
/// caller appends it to the function's return operands as the load-request
/// output. Some lines are elided from this excerpt.
313 static Value plumbLoadPort(Location loc, OpBuilder &b,
314 handshake::MemLoadInterface &ldif, Value loadData,
315 MemRefType memrefType) {
319 auto dataFork = b.create<ForkOp>(loc, loadData, 2);
321 auto dataOut = dataFork.getResult()[0];
322 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323 auto dataDone = b.create<JoinOp>(loc, joinArgs);
325 ldif.dataOut.replaceAllUsesWith(dataOut);
326 ldif.doneOut.replaceAllUsesWith(dataDone);
330 return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
/// Connects store port `stif` to the new store-done input `done` and builds
/// the store-request output value of type `outType`.
/// All uses of the port's done output are replaced by `done`. The request
/// carries the truncated address followed by the store data. Packing these
/// into `outType` is elided from this excerpt (presumably a struct creation).
333 static Value plumbStorePort(Location loc, OpBuilder &b,
334 handshake::MemStoreInterface &stif, Value done,
335 Type outType, MemRefType memrefType) {
336 stif.doneOut.replaceAllUsesWith(done);
339 llvm::SmallVector<Value> structArgs = {
340 truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
/// Appends a string to the ArrayAttr named `attrName` on `op`. The rest of
/// the body is elided from this excerpt; presumably it appends the new
/// string and re-sets the attribute.
/// NOTE(review): getAttrOfType returns a null attribute when `attrName` is
/// absent or not an ArrayAttr, so the `.getValue()` call below assumes the
/// attribute exists.
348 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
350 auto *ctx = op->getContext();
351 llvm::SmallVector<Attribute> newArr;
352 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353 std::back_inserter(newArr));
/// Inserts the string `attrVal` into the ArrayAttr named `attrName` on `op`,
/// presumably at position `idx`. The insertion and the attribute update are
/// elided from this excerpt.
/// NOTE(review): like the other array-attr helpers, this assumes the
/// attribute exists (getAttrOfType returns null otherwise).
358 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
359 StringRef attrVal,
unsigned idx) {
360 auto *ctx = op->getContext();
361 llvm::SmallVector<Attribute> newArr;
362 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363 std::back_inserter(newArr));
/// Removes element `idx` from the ArrayAttr named `attrName` on `op`. The
/// attribute update itself is elided from this excerpt.
/// `idx` must be a valid index into the existing array; vector::erase is not
/// bounds-checked.
368 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
370 auto *ctx = op->getContext();
371 llvm::SmallVector<Attribute> newArr;
372 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
373 std::back_inserter(newArr));
374 newArr.erase(newArr.begin() + idx);
// The members of ArgTypeReplacement are elided from this excerpt, and the
// type is not referenced by any visible code.
378 struct ArgTypeReplacement {
// Body of lowerExtmemToHW(func). The signature and several lines are elided
// from this excerpt.
// Collect the memref arguments keyed by argument index. std::map keeps them
// sorted, so they are processed in ascending index order.
388 std::map<unsigned, Value> memrefArgs;
389 for (
auto [i, arg] : llvm::enumerate(func.getArguments()))
390 if (isa<MemRefType>(arg.getType()))
// Nothing to lower when the function has no memref arguments.
393 if (memrefArgs.empty())
// Memory port description per replaced argument index (ordered). The port
// info of the unmodified function is captured next (presumably as
// `origPortInfo`, which is later handed to wrapESI).
398 std::map<unsigned, HandshakeMemType> argReplacements;
402 func, func.getArgumentTypes(), func.getResultTypes());
// Lower each memref argument in turn.
// NOTE(review): `i` is the argument's ORIGINAL index. Each iteration inserts
// `addedInPorts` new arguments before the memref and then erases it, which
// shifts later arguments by (addedInPorts - 1). No compensation for that
// shift is visible in this excerpt. Confirm for functions with several
// memref arguments.
405 for (
auto it : memrefArgs) {
407 unsigned i = it.first;
408 auto arg = it.second;
409 auto loc = arg.getLoc();
// The memref's first user must be its handshake.extmemory op. New
// operations are inserted in front of it.
411 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
412 b.setInsertionPoint(extmemOp);
416 MemRefType memrefType = cast<MemRefType>(arg.getType());
419 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
// Start from the existing return operands; memory-port outputs are appended.
420 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
// Number of new inputs inserted so far in front of the memref argument.
421 unsigned addedInPorts = 0;
422 auto memName = func.getArgName(i);
// Adds one input port named "<mem>_<argName>", inserted in front of the memref
// argument. Adds one output port named "<mem>_<resName>", appended after the
// existing results. The new input argument is handed back to the caller.
423 auto addArgRes = [&](
unsigned id, NamedType &argType, NamedType &resType) {
425 unsigned newArgIdx = i + addedInPorts;
426 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
427 insertInStringArrayAttr(func,
"argNames",
428 memName.str() +
"_" + argType.first.str(),
430 auto newInPort = func.getArgument(newArgIdx);
434 func.insertResult(func.getNumResults(), resType.second, {});
435 appendToStringArrayAttr(func,
"resNames",
436 memName.str() +
"_" + resType.first.str());
// memIOTypes (computed on an elided line, presumably by getMemTypeForExtmem)
// lists all load ports first, then all store ports; portIdx walks both.
441 unsigned portIdx = 0;
442 for (
auto loadPort : extmemOp.getLoadPorts()) {
443 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
444 memIOTypes.outputTypes[portIdx]);
445 newReturnOperands.push_back(
446 plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
// Store ports: the new input carries the store-done signal, and the new
// output carries the store request.
451 for (
auto storePort : extmemOp.getStorePorts()) {
453 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
454 memIOTypes.outputTypes[portIdx]);
455 newReturnOperands.push_back(
456 plumbStorePort(loc, b, storePort, newInPort,
457 memIOTypes.outputTypes[portIdx].second, memrefType));
// Emit a return that also yields the new memory-port outputs.
463 b.setInsertionPoint(oldReturnOp);
464 b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
// Drop the original memref argument. It now sits after the `addedInPorts`
// ports that were inserted in front of it.
473 func.eraseArgument(i + addedInPorts);
474 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
// Remember the replacement, keyed by the original argument index.
476 argReplacements[i] = memIOTypes;
// Optionally wrap the lowered function in an ESI wrapper module.
479 if (createESIWrapper)
480 if (failed(wrapESI(func, origPortInfo, argReplacements)))
/// Creates the extmem-to-HW lowering pass. When `createESIWrapper` holds a
/// value, it overrides the pass option of the same name.
488 std::unique_ptr<mlir::Pass>
490 std::optional<bool> createESIWrapper) {
491 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
FuncOp create(Union[StringAttr, str] sym_name, List[Tuple[str, Type]] args, List[Tuple[str, Type]] results, Dict[str, Attribute] attributes={}, loc=None, ip=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and output types.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake function.
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.