CIRCT 20.0.0git
Loading...
Searching...
No Matches
LowerExtmemToHW.cpp
Go to the documentation of this file.
//===- LowerExtmemToHW.cpp - Lower extmem to HW pass ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the lower extmem pass.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Rewrite/FrozenRewritePatternSet.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/Support/Debug.h"
28
29namespace circt {
30namespace handshake {
31#define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
33} // namespace handshake
34} // namespace circt
35
36using namespace circt;
37using namespace handshake;
38using namespace mlir;
39namespace {
40using NamedType = std::pair<StringAttr, Type>;
41struct HandshakeMemType {
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
43 MemRefType memRefType;
44 unsigned loadPorts, storePorts;
45};
46
47struct LoadName {
48 StringAttr dataIn;
49 StringAttr addrOut;
50
51 static LoadName get(MLIRContext *ctx, unsigned idx) {
52 return {StringAttr::get(ctx, "ld" + std::to_string(idx) + ".data"),
53 StringAttr::get(ctx, "ld" + std::to_string(idx) + ".addr")};
54 }
55};
56
57struct StoreNames {
58 StringAttr doneIn;
59 StringAttr out;
60
61 static StoreNames get(MLIRContext *ctx, unsigned idx) {
62 return {StringAttr::get(ctx, "st" + std::to_string(idx) + ".done"),
63 StringAttr::get(ctx, "st" + std::to_string(idx))};
64 }
65};
66
67} // namespace
68
69static Type indexToMemAddr(Type t, MemRefType memRef) {
70 assert(isa<IndexType>(t) && "Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 && "Expected 1D memref");
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74 return IntegerType::get(t.getContext(), addrWidth);
75}
76
77static HandshakeMemType getMemTypeForExtmem(Value v) {
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) && "Value is not a memref type");
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
83
84 // Add memory type.
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
88
89 // Add load ports.
90 for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91 auto names = LoadName::get(ctx, i);
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
94 {names.addrOut,
95 indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
96 }
97
98 // Add store ports.
99 for (auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100 auto names = StoreNames::get(ctx, i);
101
102 // Incoming store data and address
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
105 {StringAttr::get(ctx, "address"),
106 indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
107 storeOutFields.push_back(
108 {StringAttr::get(ctx, "data"), stif.dataIn.getType()});
109 auto inType = hw::StructType::get(ctx, storeOutFields);
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
112 }
113
114 return memType;
115}
116
117namespace {
118struct HandshakeLowerExtmemToHWPass
119 : public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
121
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
125 }
126
127 void runOnOperation() override {
128 auto op = getOperation();
129 for (auto func : op.getOps<handshake::FuncOp>()) {
130 if (failed(lowerExtmemToHW(func))) {
131 signalPassFailure();
132 return;
133 }
134 }
135 };
136
137 LogicalResult lowerExtmemToHW(handshake::FuncOp func);
138 LogicalResult
139 wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
141};
142
143LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
147 OpBuilder b(func);
148 auto loc = func.getLoc();
149
150 // Create external module which will match the interface of 'func' after it's
151 // been lowered to HW.
152 b.setInsertionPoint(func);
153 auto newPortInfo = handshake::getPortInfoForOpTypes(
154 func, func.getArgumentTypes(), func.getResultTypes());
155 auto extMod = b.create<hw::HWModuleExternOp>(
156 loc, StringAttr::get(ctx, "__" + func.getName() + "_hw"), newPortInfo);
157
158 // Add an attribute to the original handshake function to indicate that it
159 // needs to resolve to extMod in a later pass.
160 func->setAttr(kPredeclarationAttr,
161 FlatSymbolRefAttr::get(ctx, extMod.getName()));
162
163 // Create wrapper module. This will have the same ports as the original
164 // module, sans the replaced arguments.
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](auto &pair) { return pair.first; });
169 for (auto i : llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
171 auto wrapperMod = b.create<hw::HWModuleOp>(
172 loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
173 wrapperModPortInfo);
174 Value clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {clk, rst};
179
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
181 BackedgeBuilder bb(b, loc);
182
183 // Create backedges for the results of the external module. These will be
184 // replaced by the service instance requests if associated with a memory.
185 llvm::SmallVector<Backedge> backedges;
186 for (auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.get(resType));
188
189 // Maintain which index we're currently at in the lowered handshake module's
190 // return.
191 unsigned resIdx = origPorts.sizeOutputs();
192
193 // Maintain the arguments which each memory will add to the inner module
194 // instance.
195 llvm::SmallVector<llvm::OwningArrayRef<Value>> instanceArgsForMem;
196
197 for (auto [i, memType] : argReplacements) {
198
199 b.setInsertionPoint(wrapperMod);
200 // Create a memory service declaration for each memref argument that was
201 // served.
202 auto origPortInfo = origPorts.atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
207 auto memServiceDecl = b.create<esi::RandomAccessMemoryDeclOp>(
208 loc, origPortInfo.name, TypeAttr::get(dataType),
209 b.getI64IntegerAttr(memrefSize));
210 esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
211 esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
212
213 SmallVector<Value> instanceArgsFromThisMem;
214
215 // Create service requests. This MUST follow the order of which ports were
216 // added in other parts of this pass (load ports first, then store ports).
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
218
219 // Load ports:
220 for (unsigned i = 0; i < memType.loadPorts; ++i) {
221 auto req = b.create<esi::RequestConnectionOp>(
222 loc, readPortInfo.type, readPortInfo.port,
223 esi::AppIDAttr::get(ctx, b.getStringAttr("load"), {}));
224 auto reqUnpack = b.create<esi::UnpackBundleOp>(
225 loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
229 ++resIdx;
230 }
231
232 // Store ports:
233 for (unsigned i = 0; i < memType.storePorts; ++i) {
234 auto req = b.create<esi::RequestConnectionOp>(
235 loc, writePortInfo.type, writePortInfo.port,
236 esi::AppIDAttr::get(ctx, b.getStringAttr("store"), {}));
237 auto reqUnpack = b.create<esi::UnpackBundleOp>(
238 loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
242 ++resIdx;
243 }
244
245 instanceArgsForMem.emplace_back(instanceArgsFromThisMem);
246 }
247
248 // Stitch together arguments from the top-level ESI wrapper and the instance
249 // arguments generated from the service requests.
250 llvm::SmallVector<Value> instanceArgs;
251
252 // Iterate over the arguments of the original handshake.func and determine
253 // whether to grab operands from the arg replacements or the wrapper module.
254 unsigned wrapperArgIdx = 0;
255
256 for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
257 // Arg replacement indices refer to the original handshake.func argument
258 // index.
259 if (argReplacements.count(i)) {
260 // This index was originally a memref - pop the instance arguments for the
261 // next-in-line memory and add them.
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
265 } else {
266 // Add the argument from the wrapper mod. This is maintained by its own
267 // counter (memref arguments are removed, so if there was an argument at
268 // this point, it needs to come from the wrapper module).
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
271 }
272 }
273
274 // Add any missing arguments from the wrapper module (this will be clock and
275 // reset)
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
277 ++wrapperArgIdx)
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
280
281 // Instantiate the inner module.
282 auto instance =
283 b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
284
285 // And resolve the backedges.
286 for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
287 be.setValue(res);
288
289 // Finally, grab the (non-memory) outputs from the inner module and return
290 // them through the wrapper.
291 auto outputOp =
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 b.create<hw::OutputOp>(
295 outputOp.getLoc(),
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
297 outputOp.erase();
298
299 return success();
300}
301
302// Truncates the index-typed 'v' into an integer-type of the same width as the
303// 'memref' argument.
304// Uses arith operations since these are supported in the HandshakeToHW
305// lowering.
306static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) && "Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
311}
312
313static Value plumbLoadPort(Location loc, OpBuilder &b,
314 handshake::MemLoadInterface &ldif, Value loadData,
315 MemRefType memrefType) {
316 // We need to feed both the load data and the load done outputs.
317 // Fork the extracted load data into two, and 'join' the second one to
318 // generate a none-typed output to drive the load done.
319 auto dataFork = b.create<ForkOp>(loc, loadData, 2);
320
321 auto dataOut = dataFork.getResult()[0];
322 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
323 auto dataDone = b.create<JoinOp>(loc, joinArgs);
324
325 ldif.dataOut.replaceAllUsesWith(dataOut);
326 ldif.doneOut.replaceAllUsesWith(dataDone);
327
328 // Return load address, to be fed to the top-level output, truncated to the
329 // width of the memory that is accessed.
330 return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
331}
332
333static Value plumbStorePort(Location loc, OpBuilder &b,
334 handshake::MemStoreInterface &stif, Value done,
335 Type outType, MemRefType memrefType) {
336 stif.doneOut.replaceAllUsesWith(done);
337 // Return the store address and data to be fed to the top-level output.
338 // Address is truncated to the width of the memory that is accessed.
339 llvm::SmallVector<Value> structArgs = {
340 truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
341
342 return b
343 .create<hw::StructCreateOp>(loc, cast<hw::StructType>(outType),
344 structArgs)
345 .getResult();
346}
347
348static void appendToStringArrayAttr(Operation *op, StringRef attrName,
349 StringRef attrVal) {
350 auto *ctx = op->getContext();
351 llvm::SmallVector<Attribute> newArr;
352 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
353 std::back_inserter(newArr));
354 newArr.push_back(StringAttr::get(ctx, attrVal));
355 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
356}
357
358static void insertInStringArrayAttr(Operation *op, StringRef attrName,
359 StringRef attrVal, unsigned idx) {
360 auto *ctx = op->getContext();
361 llvm::SmallVector<Attribute> newArr;
362 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
363 std::back_inserter(newArr));
364 newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
365 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
366}
367
368static void eraseFromArrayAttr(Operation *op, StringRef attrName,
369 unsigned idx) {
370 auto *ctx = op->getContext();
371 llvm::SmallVector<Attribute> newArr;
372 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
373 std::back_inserter(newArr));
374 newArr.erase(newArr.begin() + idx);
375 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
376}
377
378struct ArgTypeReplacement {
379 unsigned index;
380 TypeRange ins;
381 TypeRange outs;
382};
383
384LogicalResult
385HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
386 // Gather memref ports to be converted. This is an ordered map, and will be
387 // iterated from lo to hi indices.
388 std::map<unsigned, Value> memrefArgs;
389 for (auto [i, arg] : llvm::enumerate(func.getArguments()))
390 if (isa<MemRefType>(arg.getType()))
391 memrefArgs[i] = arg;
392
393 if (memrefArgs.empty())
394 return success(); // nothing to do.
395
396 // Record which arg indices were replaces with handshake memory ports.
397 // This is an ordered map, and will be iterated from lo to hi indices.
398 std::map<unsigned, HandshakeMemType> argReplacements;
399
400 // Record the hw.module i/o of the original func (used for ESI wrapper).
401 auto origPortInfo = handshake::getPortInfoForOpTypes(
402 func, func.getArgumentTypes(), func.getResultTypes());
403
404 OpBuilder b(func);
405 for (auto it : memrefArgs) {
406 // Do not use structured bindings for 'it' - cannot reference inside lambda.
407 unsigned i = it.first;
408 auto arg = it.second;
409 auto loc = arg.getLoc();
410 // Get the attached extmemory external module.
411 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
412 b.setInsertionPoint(extmemOp);
413
414 // Add memory input - this is the output of the extmemory op.
415 auto memIOTypes = getMemTypeForExtmem(arg);
416 MemRefType memrefType = cast<MemRefType>(arg.getType());
417
418 auto oldReturnOp =
419 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
420 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
421 unsigned addedInPorts = 0;
422 auto memName = func.getArgName(i);
423 auto addArgRes = [&](unsigned id, NamedType &argType, NamedType &resType) {
424 // Function argument
425 unsigned newArgIdx = i + addedInPorts;
426 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc());
427 insertInStringArrayAttr(func, "argNames",
428 memName.str() + "_" + argType.first.str(),
429 newArgIdx);
430 auto newInPort = func.getArgument(newArgIdx);
431 ++addedInPorts;
432
433 // Function result.
434 func.insertResult(func.getNumResults(), resType.second, {});
435 appendToStringArrayAttr(func, "resNames",
436 memName.str() + "_" + resType.first.str());
437 return newInPort;
438 };
439
440 // Plumb load ports.
441 unsigned portIdx = 0;
442 for (auto loadPort : extmemOp.getLoadPorts()) {
443 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
444 memIOTypes.outputTypes[portIdx]);
445 newReturnOperands.push_back(
446 plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
447 ++portIdx;
448 }
449
450 // Plumb store ports.
451 for (auto storePort : extmemOp.getStorePorts()) {
452 auto newInPort =
453 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
454 memIOTypes.outputTypes[portIdx]);
455 newReturnOperands.push_back(
456 plumbStorePort(loc, b, storePort, newInPort,
457 memIOTypes.outputTypes[portIdx].second, memrefType));
458 ++portIdx;
459 }
460
461 // Replace the return op of the function with a new one that returns the
462 // memory output struct.
463 b.setInsertionPoint(oldReturnOp);
464 b.create<ReturnOp>(arg.getLoc(), newReturnOperands);
465 oldReturnOp.erase();
466
467 // Erase the extmemory operation since I/O plumbing has replaced all of its
468 // results.
469 extmemOp.erase();
470
471 // Erase the original memref argument of the top-level i/o now that it's
472 // use has been removed.
473 func.eraseArgument(i + addedInPorts);
474 eraseFromArrayAttr(func, "argNames", i + addedInPorts);
475
476 argReplacements[i] = memIOTypes;
477 }
478
479 if (createESIWrapper)
480 if (failed(wrapESI(func, origPortInfo, argReplacements)))
481 return failure();
482
483 return success();
484}
485
486} // namespace
487
488std::unique_ptr<mlir::Pass>
490 std::optional<bool> createESIWrapper) {
491 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
492}
assert(baseType &&"element must be base type")
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
name(self)
Definition hw.py:329
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Describes a service port.
Definition ESIOps.h:31
ChannelBundleType type
Definition ESIOps.h:33
hw::InnerRefAttr port
Definition ESIOps.h:32
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & atInput(size_t idx)