21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
27 using namespace circt;
28 using namespace handshake;
31 using NamedType = std::pair<StringAttr, Type>;
/// Bundles what this file records about one external memory: the named port
/// types that are added as function inputs/outputs for it, the memref type
/// itself, and how many load and store ports it has.
32 struct HandshakeMemType {
// Named (StringAttr, Type) pairs; elsewhere in this file they are filled one
// entry at a time while walking the memory's load and store ports.
33 llvm::SmallVector<NamedType> inputTypes, outputTypes;
// Type of the memref value the external memory was created from.
34 MemRefType memRefType;
// Port counts, taken from the external memory op's getLdCount()/getStCount().
35 unsigned loadPorts, storePorts;
42 static LoadName
get(MLIRContext *ctx,
unsigned idx) {
52 static StoreNames
get(MLIRContext *ctx,
unsigned idx) {
61 assert(t.isa<IndexType>() &&
"Expected index type");
62 auto shape = memRef.getShape();
63 assert(shape.size() == 1 &&
"Expected 1D memref");
64 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
69 auto *ctx = v.getContext();
70 assert(v.getType().isa<mlir::MemRefType>() &&
"Value is not a memref type");
71 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
72 HandshakeMemType memType;
73 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
76 memType.memRefType = v.getType().cast<MemRefType>();
77 memType.loadPorts = extmemOp.getLdCount();
78 memType.storePorts = extmemOp.getStCount();
81 for (
auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
83 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
84 memType.outputTypes.push_back(
90 for (
auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
94 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
95 storeOutFields.push_back(
98 storeOutFields.push_back(
101 memType.outputTypes.push_back({names.out, inType});
102 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
/// Pass that runs lowerExtmemToHW on every handshake::FuncOp nested in the
/// operation it is applied to.
109 struct HandshakeLowerExtmemToHWPass
110 :
public HandshakeLowerExtmemToHWBase<HandshakeLowerExtmemToHWPass> {
// Only overrides the `createESIWrapper` member when the caller actually
// supplied a value; an empty optional leaves the member as it was.
112 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
113 if (createESIWrapper)
114 this->createESIWrapper = *createESIWrapper;
// Pass entry point: visits each handshake::FuncOp of the current operation
// and lowers it. What happens when lowering fails is on lines not visible in
// this view.
117 void runOnOperation()
override {
118 auto op = getOperation();
119 for (
auto func : op.getOps<handshake::FuncOp>()) {
120 if (failed(lowerExtmemToHW(func))) {
// Per-function lowering; defined out of line further down in this file.
127 LogicalResult lowerExtmemToHW(handshake::FuncOp func);
129 wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
130 const std::map<unsigned, HandshakeMemType> &argReplacements);
/// Builds a wrapper module around the lowered function that connects each
/// replaced memref argument to ESI memory service requests.
///
/// \param func            The handshake function being wrapped.
/// \param origPorts       Port info of the function before its memref
///                        arguments were replaced.
/// \param argReplacements Map from original memref argument index to the
///                        memory description recorded for it.
///
/// NOTE(review): several statements of this function (the builder setup and
/// the creation of `extMod`, `wrapperMod`, `bb`, `memServiceDecl` and
/// `instance`) are on lines not visible in this view; the comments below
/// describe only what the visible lines do.
133 LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
134 handshake::FuncOp func, hw::ModulePortInfo origPorts,
135 const std::map<unsigned, HandshakeMemType> &argReplacements) {
136 auto *ctx = func.getContext();
138 auto loc = func.getLoc();
142 b.setInsertionPoint(func);
144 func, func.getArgumentTypes(), func.getResultTypes());
146 loc,
StringAttr::get(ctx,
"__" + func.getName() +
"_hw"), newPortInfo);
// The wrapper's ports start as a copy of the original ports with the
// replaced (memref) inputs erased. Indices come out of the std::map in
// ascending order, so erasing in reverse keeps the remaining indices valid.
155 auto wrapperModPortInfo = origPorts;
156 llvm::SmallVector<unsigned> argReplacementsIdxs;
157 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
158 [](
auto &pair) {
return pair.first; });
159 for (
auto i : llvm::reverse(argReplacementsIdxs))
160 wrapperModPortInfo.eraseInput(i);
// The last two block arguments of the wrapper body are picked up as clock
// and reset.
164 Value
clk = wrapperMod.getBodyBlock()->getArgument(
165 wrapperMod.getBodyBlock()->getNumArguments() - 2);
166 Value rst = wrapperMod.getBodyBlock()->getArgument(
167 wrapperMod.getBodyBlock()->getNumArguments() - 1);
168 SmallVector<Value> clkRes = {
clk, rst};
170 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// One backedge is created per output type of `extMod`. They are used below
// as placeholder values before the instance producing them is created.
175 llvm::SmallVector<Backedge> backedges;
176 for (
auto resType : extMod.getOutputTypes())
177 backedges.push_back(bb.get(resType));
// resIdx starts just past the original outputs. Presumably it is advanced
// once per memory port on lines not visible here; confirm.
181 unsigned resIdx = origPorts.sizeOutputs();
// Collects, per replaced memory, the values to feed to the instance in place
// of that memory's argument.
185 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
187 for (
auto [i, memType] : argReplacements) {
189 b.setInsertionPoint(wrapperMod);
192 auto origPortInfo = origPorts.atInput(i);
193 auto memrefShape = memType.memRefType.getShape();
194 auto dataType = memType.memRefType.getElementType();
195 assert(memrefShape.size() == 1 &&
"Only 1D memrefs are supported");
196 unsigned memrefSize = memrefShape[0];
199 b.getI64IntegerAttr(memrefSize));
200 esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
201 esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
203 SmallVector<Value> instanceArgsFromThisMem;
207 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// Load ports: pack a bundle of the read port's type from the backedge at
// resIdx, and use the pack's channel at RespDirChannelIdx as an instance
// argument.
210 for (
unsigned i = 0; i < memType.loadPorts; ++i) {
211 auto reqPack = b.create<esi::PackBundleOp>(loc, readPortInfo.type,
212 (Value)backedges[resIdx]);
214 loc, readPortInfo.port, reqPack.getBundle(),
216 instanceArgsFromThisMem.push_back(
217 reqPack.getFromChannels()
218 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
// Store ports: same shape as the load ports, using the write port's type and
// port.
223 for (
unsigned i = 0; i < memType.storePorts; ++i) {
224 auto reqPack = b.create<esi::PackBundleOp>(loc, writePortInfo.type,
225 (Value)backedges[resIdx]);
227 loc, writePortInfo.port, reqPack.getBundle(),
229 instanceArgsFromThisMem.push_back(
230 reqPack.getFromChannels()
231 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
235 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
// Assemble the operand list of the instance in original argument order.
240 llvm::SmallVector<Value> instanceArgs;
244 unsigned wrapperArgIdx = 0;
246 for (
unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
// A replaced index consumes the next per-memory group, front first; the
// groups were appended in the same ascending key order.
249 if (argReplacements.count(i)) {
252 auto &memArgs = instanceArgsForMem.front();
253 instanceArgs.append(memArgs.begin(), memArgs.end());
254 instanceArgsForMem.erase(instanceArgsForMem.begin());
// Presumably the non-replaced case (the branch keyword is not visible here):
// forward the next wrapper block argument unchanged.
259 instanceArgs.push_back(
260 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
// Forward the wrapper block arguments that have not been consumed yet. The
// loop increment is on a line not visible here.
266 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
268 instanceArgs.push_back(
269 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
273 b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
// Walk instance results and backedges pairwise. The loop body is not visible;
// presumably each backedge is resolved to its result. Confirm.
276 for (
auto [res, be] : llvm::zip(instance.getResults(), backedges))
282 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
// A new output op is created right before the existing terminator; it
// returns only the first getNumOutputPorts() instance results.
283 b.setInsertionPoint(outputOp);
284 b.create<hw::OutputOp>(
286 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
/// Casts an index-typed address to an integer just wide enough to address the
/// given memory.
///
/// \param loc        Location attached to the created cast op.
/// \param b          Builder used to create the cast op.
/// \param v          Address value; asserted to be of IndexType.
/// \param memRefType Memref whose first dimension determines the address
///                   width.
/// \returns The arith::IndexCastOp converting `v` to an integer type of
///          Log2_64_Ceil(first dimension) bits.
296 static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
297 MemRefType memRefType) {
298 assert(v.getType().isa<IndexType>() &&
"Expected an index-typed value");
// Only the first dimension is consulted and the rank is not checked here.
// NOTE(review): a first dimension of 1 gives Log2_64_Ceil(1) == 0, i.e. a
// zero-bit integer type; confirm that this is intended.
299 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
300 return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
/// Rewires one load port so that its data comes from `loadData`, and returns
/// the port's address truncated to the memory's address width.
///
/// `ldif` is the load port being rewired; its parameter declaration is on a
/// line not visible in this view.
///
/// \param loc        Location used for all created ops.
/// \param b          Builder used to create the fork, join and cast ops.
/// \param loadData   Value that supplies the loaded data.
/// \param memrefType Memref type used to size the returned address.
303 static Value plumbLoadPort(Location loc, OpBuilder &b,
305 Value loadData, MemRefType memrefType) {
// Fork the incoming data two ways: result 0 becomes the port's data output,
// result 1 is fed through a join whose result becomes the port's done output.
309 auto dataFork = b.create<ForkOp>(loc, loadData, 2);
311 auto dataOut = dataFork.getResult()[0];
312 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
313 auto dataDone = b.create<JoinOp>(loc, joinArgs);
// Redirect all existing users of the port's outputs to the new values.
315 ldif.
dataOut.replaceAllUsesWith(dataOut);
316 ldif.
doneOut.replaceAllUsesWith(dataDone);
320 return truncateToMemoryWidth(loc, b, ldif.
addressIn, memrefType);
/// Rewires one store port: users of the port's done output are redirected to
/// `done`, and the truncated address plus the store data are gathered as the
/// fields of the outgoing request.
///
/// `stif` is the store port being rewired; its parameter declaration is on a
/// line not visible in this view. How `structArgs` is turned into the
/// returned Value is also not visible (presumably a struct of type `outType`
/// is created from it; confirm).
///
/// \param loc        Location used for created ops.
/// \param b          Builder used to create ops.
/// \param done       Value that replaces the port's done output.
/// \param outType    Type associated with the outgoing request.
/// \param memrefType Memref type used to size the address field.
323 static Value plumbStorePort(Location loc, OpBuilder &b,
325 Value done, Type outType, MemRefType memrefType) {
326 stif.
doneOut.replaceAllUsesWith(done);
// Field order is address first, then data.
329 llvm::SmallVector<Value> structArgs = {
330 truncateToMemoryWidth(loc, b, stif.
addressIn, memrefType), stif.
dataIn};
/// Helper working on the array attribute `attrName` of `op`. The visible
/// lines copy the attribute's current elements into a mutable vector; the
/// statements that add the new element and store the array back on `op` are
/// on lines not visible in this view.
///
/// NOTE(review): the result of getAttrOfType<ArrayAttr>() is dereferenced
/// without a check, so this assumes `op` already carries an ArrayAttr under
/// `attrName`; verify against callers.
338 static void appendToStringArrayAttr(Operation *op, StringRef attrName,
340 auto *ctx = op->getContext();
341 llvm::SmallVector<Attribute> newArr;
342 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
343 std::back_inserter(newArr));
/// Helper working on the array attribute `attrName` of `op`, taking a string
/// value and a position. The visible lines copy the attribute's current
/// elements into a mutable vector; the statements that insert `attrVal` at
/// `idx` and store the array back on `op` are on lines not visible in this
/// view.
///
/// NOTE(review): the result of getAttrOfType<ArrayAttr>() is dereferenced
/// without a check, so this assumes `op` already carries an ArrayAttr under
/// `attrName`; verify against callers.
348 static void insertInStringArrayAttr(Operation *op, StringRef attrName,
349 StringRef attrVal,
unsigned idx) {
350 auto *ctx = op->getContext();
351 llvm::SmallVector<Attribute> newArr;
352 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353 std::back_inserter(newArr));
/// Removes one element from the array attribute `attrName` of `op`. The
/// current elements are copied into a mutable vector and the element at `idx`
/// is erased; the statement that stores the shortened array back on `op` is
/// on a line not visible in this view, as is the declaration of `idx`.
///
/// NOTE(review): neither the attribute lookup nor `idx` is checked here, so
/// this assumes the attribute exists and `idx` is within its bounds; verify
/// against callers.
358 static void eraseFromArrayAttr(Operation *op, StringRef attrName,
360 auto *ctx = op->getContext();
361 llvm::SmallVector<Attribute> newArr;
362 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363 std::back_inserter(newArr));
364 newArr.erase(newArr.begin() + idx);
368 struct ArgTypeReplacement {
/// Replaces every memref-typed argument of `func` with explicit per-port
/// arguments and results, rewires the external memory's ports to them, and
/// optionally wraps the result for ESI.
///
/// NOTE(review): several statements (filling `memrefArgs`, creating the
/// builder `b`, `origPortInfo`, `memIOTypes` and `oldReturnOp`, and the
/// updates of `addedInPorts`/`portIdx`) are on lines not visible in this
/// view; the comments below describe only what the visible lines do.
375 HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
// Scan the function arguments for memref-typed ones; the map is keyed by an
// unsigned index and holds the argument value.
378 std::map<unsigned, Value> memrefArgs;
379 for (
auto [i, arg] : llvm::enumerate(func.getArguments()))
380 if (arg.getType().isa<MemRefType>())
383 if (memrefArgs.empty())
388 std::map<unsigned, HandshakeMemType> argReplacements;
392 func, func.getArgumentTypes(), func.getResultTypes());
395 for (
auto it : memrefArgs) {
397 unsigned i = it.first;
398 auto arg = it.second;
399 auto loc = arg.getLoc();
// Takes the first user of the memref argument and casts it to the external
// memory op. This assumes the argument has at least one user and that the
// first one is an ExternalMemoryOp; TODO confirm this is guaranteed.
401 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
402 b.setInsertionPoint(extmemOp);
406 MemRefType memrefType = arg.getType().cast<MemRefType>();
409 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
// New return operands start as a copy of the existing ones and grow by one
// value per memory port below.
410 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
411 unsigned addedInPorts = 0;
412 auto memName = func.getArgName(i);
// Adds one argument/result pair for a memory port: the argument is inserted
// at index i + addedInPorts and named "<memName>_<arg name>", the result is
// appended at the end and named "<memName>_<res name>".
413 auto addArgRes = [&](
unsigned id, NamedType &argType, NamedType &resType) {
415 unsigned newArgIdx = i + addedInPorts;
416 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
417 insertInStringArrayAttr(func,
"argNames",
418 memName.str() +
"_" + argType.first.str(),
420 auto newInPort = func.getArgument(newArgIdx);
424 func.insertResult(func.getNumResults(), resType.second, {});
425 appendToStringArrayAttr(func,
"resNames",
426 memName.str() +
"_" + resType.first.str());
// Load ports: each gets an argument/result pair, and the value produced by
// plumbLoadPort is appended to the return operands.
431 unsigned portIdx = 0;
432 for (
auto loadPort : extmemOp.getLoadPorts()) {
433 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
434 memIOTypes.outputTypes[portIdx]);
435 newReturnOperands.push_back(
436 plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
// Store ports: same scheme, continuing with the same portIdx counter.
441 for (
auto storePort : extmemOp.getStorePorts()) {
443 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
444 memIOTypes.outputTypes[portIdx]);
445 newReturnOperands.push_back(
446 plumbStorePort(loc, b, storePort, newInPort,
447 memIOTypes.outputTypes[portIdx].second, memrefType));
// A replacement return carrying the extended operand list is created right
// before the old return op.
453 b.setInsertionPoint(oldReturnOp);
454 b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
// Erase the argument at i + addedInPorts together with its name entry.
// Presumably this is the original memref argument, shifted back by the
// arguments inserted in front of it; confirm.
463 func.eraseArgument(i + addedInPorts);
464 eraseFromArrayAttr(func,
"argNames", i + addedInPorts);
466 argReplacements[i] = memIOTypes;
// ESI wrapping is only attempted when the pass's createESIWrapper is set.
469 if (createESIWrapper)
470 if (failed(wrapESI(func, origPortInfo, argReplacements)))
/// Factory for the pass: constructs a HandshakeLowerExtmemToHWPass and
/// forwards the optional `createESIWrapper` override to its constructor.
/// (The line carrying the function name is not visible in this view.)
478 std::unique_ptr<mlir::Pass>
480 std::optional<bool> createESIWrapper) {
481 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
assert(baseType &&"element must be base type")
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
static constexpr const char * kPredeclarationAttr
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & outs()