CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerExtmemToHW.cpp
Go to the documentation of this file.
//===- LowerExtmemToHW.cpp - Lower extmem to HW pass ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the lower extmem pass.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Rewrite/FrozenRewritePatternSet.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/Support/Debug.h"
28
29namespace circt {
30namespace handshake {
31#define GEN_PASS_DEF_HANDSHAKELOWEREXTMEMTOHW
32#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
33} // namespace handshake
34} // namespace circt
35
36using namespace circt;
37using namespace handshake;
38using namespace mlir;
39namespace {
40using NamedType = std::pair<StringAttr, Type>;
41struct HandshakeMemType {
42 llvm::SmallVector<NamedType> inputTypes, outputTypes;
43 MemRefType memRefType;
44 unsigned loadPorts, storePorts;
45};
46
47struct LoadName {
48 StringAttr dataIn;
49 StringAttr addrOut;
50
51 static LoadName get(MLIRContext *ctx, unsigned idx) {
52 return {StringAttr::get(ctx, "ld" + std::to_string(idx) + ".data"),
53 StringAttr::get(ctx, "ld" + std::to_string(idx) + ".addr")};
54 }
55};
56
57struct StoreNames {
58 StringAttr doneIn;
59 StringAttr out;
60
61 static StoreNames get(MLIRContext *ctx, unsigned idx) {
62 return {StringAttr::get(ctx, "st" + std::to_string(idx) + ".done"),
63 StringAttr::get(ctx, "st" + std::to_string(idx))};
64 }
65};
66
67} // namespace
68
69static Type indexToMemAddr(Type t, MemRefType memRef) {
70 assert(isa<IndexType>(t) && "Expected index type");
71 auto shape = memRef.getShape();
72 assert(shape.size() == 1 && "Expected 1D memref");
73 unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
74 return IntegerType::get(t.getContext(), addrWidth);
75}
76
77static HandshakeMemType getMemTypeForExtmem(Value v) {
78 auto *ctx = v.getContext();
79 assert(isa<mlir::MemRefType>(v.getType()) && "Value is not a memref type");
80 auto extmemOp = cast<handshake::ExternalMemoryOp>(*v.getUsers().begin());
81 HandshakeMemType memType;
82 llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
83
84 // Add memory type.
85 memType.memRefType = cast<MemRefType>(v.getType());
86 memType.loadPorts = extmemOp.getLdCount();
87 memType.storePorts = extmemOp.getStCount();
88
89 // Add load ports.
90 for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
91 auto names = LoadName::get(ctx, i);
92 memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
93 memType.outputTypes.push_back(
94 {names.addrOut,
95 indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
96 }
97
98 // Add store ports.
99 for (auto [i, stif] : llvm::enumerate(extmemOp.getStorePorts())) {
100 auto names = StoreNames::get(ctx, i);
101
102 // Incoming store data and address
103 llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
104 storeOutFields.push_back(
105 {StringAttr::get(ctx, "address"),
106 indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
107 storeOutFields.push_back(
108 {StringAttr::get(ctx, "data"), stif.dataIn.getType()});
109 auto inType = hw::StructType::get(ctx, storeOutFields);
110 memType.outputTypes.push_back({names.out, inType});
111 memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
112 }
113
114 return memType;
115}
116
117namespace {
118struct HandshakeLowerExtmemToHWPass
119 : public circt::handshake::impl::HandshakeLowerExtmemToHWBase<
120 HandshakeLowerExtmemToHWPass> {
121
122 HandshakeLowerExtmemToHWPass(std::optional<bool> createESIWrapper) {
123 if (createESIWrapper)
124 this->createESIWrapper = *createESIWrapper;
125 }
126
127 void runOnOperation() override {
128 auto op = getOperation();
129 for (auto func : op.getOps<handshake::FuncOp>()) {
130 if (failed(lowerExtmemToHW(func))) {
131 signalPassFailure();
132 return;
133 }
134 }
135 };
136
137 LogicalResult lowerExtmemToHW(handshake::FuncOp func);
138 LogicalResult
139 wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
140 const std::map<unsigned, HandshakeMemType> &argReplacements);
141};
142
143LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
145 const std::map<unsigned, HandshakeMemType> &argReplacements) {
146 auto *ctx = func.getContext();
147 OpBuilder b(func);
148 auto loc = func.getLoc();
149
150 // Create external module which will match the interface of 'func' after it's
151 // been lowered to HW.
152 b.setInsertionPoint(func);
153 auto newPortInfo = handshake::getPortInfoForOpTypes(
154 func, func.getArgumentTypes(), func.getResultTypes());
155 auto extMod = hw::HWModuleExternOp::create(
156 b, loc, StringAttr::get(ctx, "__" + func.getName() + "_hw"), newPortInfo);
157
158 // Add an attribute to the original handshake function to indicate that it
159 // needs to resolve to extMod in a later pass.
160 func->setAttr(kPredeclarationAttr,
161 FlatSymbolRefAttr::get(ctx, extMod.getName()));
162
163 // Create wrapper module. This will have the same ports as the original
164 // module, sans the replaced arguments.
165 auto wrapperModPortInfo = origPorts;
166 llvm::SmallVector<unsigned> argReplacementsIdxs;
167 llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
168 [](auto &pair) { return pair.first; });
169 for (auto i : llvm::reverse(argReplacementsIdxs))
170 wrapperModPortInfo.eraseInput(i);
171 auto wrapperMod = hw::HWModuleOp::create(
172 b, loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
173 wrapperModPortInfo);
174 Value clk = wrapperMod.getBodyBlock()->getArgument(
175 wrapperMod.getBodyBlock()->getNumArguments() - 2);
176 Value rst = wrapperMod.getBodyBlock()->getArgument(
177 wrapperMod.getBodyBlock()->getNumArguments() - 1);
178 SmallVector<Value> clkRes = {clk, rst};
179
180 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
181 BackedgeBuilder bb(b, loc);
182
183 // Create backedges for the results of the external module. These will be
184 // replaced by the service instance requests if associated with a memory.
185 llvm::SmallVector<Backedge> backedges;
186 for (auto resType : extMod.getOutputTypes())
187 backedges.push_back(bb.get(resType));
188
189 // Maintain which index we're currently at in the lowered handshake module's
190 // return.
191 unsigned resIdx = origPorts.sizeOutputs();
192
193 // Maintain the arguments which each memory will add to the inner module
194 // instance.
195 llvm::SmallVector<llvm::SmallVector<Value>> instanceArgsForMem;
196
197 for (auto [i, memType] : argReplacements) {
198
199 b.setInsertionPoint(wrapperMod);
200 // Create a memory service declaration for each memref argument that was
201 // served.
202 auto origPortInfo = origPorts.atInput(i);
203 auto memrefShape = memType.memRefType.getShape();
204 auto dataType = memType.memRefType.getElementType();
205 assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
206 unsigned memrefSize = memrefShape[0];
207 auto memServiceDecl = esi::RandomAccessMemoryDeclOp::create(
208 b, loc, origPortInfo.name, TypeAttr::get(dataType),
209 b.getI64IntegerAttr(memrefSize));
210 esi::ServicePortInfo writePortInfo = memServiceDecl.writePortInfo();
211 esi::ServicePortInfo readPortInfo = memServiceDecl.readPortInfo();
212
213 SmallVector<Value> instanceArgsFromThisMem;
214
215 // Create service requests. This MUST follow the order of which ports were
216 // added in other parts of this pass (load ports first, then store ports).
217 b.setInsertionPointToStart(wrapperMod.getBodyBlock());
218
219 // Load ports:
220 for (unsigned i = 0; i < memType.loadPorts; ++i) {
221 auto req = esi::RequestConnectionOp::create(
222 b, loc, readPortInfo.type, readPortInfo.port,
223 esi::AppIDAttr::get(ctx, b.getStringAttr("load"), resIdx));
224 auto reqUnpack = esi::UnpackBundleOp::create(
225 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
226 instanceArgsFromThisMem.push_back(
227 reqUnpack.getToChannels()
228 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
229 ++resIdx;
230 }
231
232 // Store ports:
233 for (unsigned i = 0; i < memType.storePorts; ++i) {
234 auto req = esi::RequestConnectionOp::create(
235 b, loc, writePortInfo.type, writePortInfo.port,
236 esi::AppIDAttr::get(ctx, b.getStringAttr("store"), resIdx));
237 auto reqUnpack = esi::UnpackBundleOp::create(
238 b, loc, req.getToClient(), ValueRange{backedges[resIdx]});
239 instanceArgsFromThisMem.push_back(
240 reqUnpack.getToChannels()
241 [esi::RandomAccessMemoryDeclOp::RespDirChannelIdx]);
242 ++resIdx;
243 }
244
245 instanceArgsForMem.emplace_back(std::move(instanceArgsFromThisMem));
246 }
247
248 // Stitch together arguments from the top-level ESI wrapper and the instance
249 // arguments generated from the service requests.
250 llvm::SmallVector<Value> instanceArgs;
251
252 // Iterate over the arguments of the original handshake.func and determine
253 // whether to grab operands from the arg replacements or the wrapper module.
254 unsigned wrapperArgIdx = 0;
255
256 for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
257 // Arg replacement indices refer to the original handshake.func argument
258 // index.
259 if (argReplacements.count(i)) {
260 // This index was originally a memref - pop the instance arguments for the
261 // next-in-line memory and add them.
262 auto &memArgs = instanceArgsForMem.front();
263 instanceArgs.append(memArgs.begin(), memArgs.end());
264 instanceArgsForMem.erase(instanceArgsForMem.begin());
265 } else {
266 // Add the argument from the wrapper mod. This is maintained by its own
267 // counter (memref arguments are removed, so if there was an argument at
268 // this point, it needs to come from the wrapper module).
269 instanceArgs.push_back(
270 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx++));
271 }
272 }
273
274 // Add any missing arguments from the wrapper module (this will be clock and
275 // reset)
276 for (; wrapperArgIdx < wrapperMod.getBodyBlock()->getNumArguments();
277 ++wrapperArgIdx)
278 instanceArgs.push_back(
279 wrapperMod.getBodyBlock()->getArgument(wrapperArgIdx));
280
281 // Instantiate the inner module.
282 auto instance =
283 hw::InstanceOp::create(b, loc, extMod, func.getName(), instanceArgs);
284
285 // And resolve the backedges.
286 for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
287 be.setValue(res);
288
289 // Finally, grab the (non-memory) outputs from the inner module and return
290 // them through the wrapper.
291 auto outputOp =
292 cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
293 b.setInsertionPoint(outputOp);
294 hw::OutputOp::create(
295 b, outputOp.getLoc(),
296 instance.getResults().take_front(wrapperMod.getNumOutputPorts()));
297 outputOp.erase();
298
299 return success();
300}
301
302// Truncates the index-typed 'v' into an integer-type of the same width as the
303// 'memref' argument.
304// Uses arith operations since these are supported in the HandshakeToHW
305// lowering.
306static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
307 MemRefType memRefType) {
308 assert(isa<IndexType>(v.getType()) && "Expected an index-typed value");
309 auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
310 if (addrWidth == 0) {
311 // Arith doesn't support i0, just create a constant i0 with control
312 // dependency on the value.
313 auto ctrl = handshake::JoinOp::create(b, loc, v).getResult();
314 return handshake::ConstantOp::create(
315 b, loc, b.getIntegerType(0), b.getIntegerAttr(b.getIntegerType(0), 0),
316 ctrl);
317 }
318 return arith::IndexCastOp::create(b, loc, b.getIntegerType(addrWidth), v);
319}
320
321static Value plumbLoadPort(Location loc, OpBuilder &b,
322 handshake::MemLoadInterface &ldif, Value loadData,
323 MemRefType memrefType) {
324 // We need to feed both the load data and the load done outputs.
325 // Fork the extracted load data into two, and 'join' the second one to
326 // generate a none-typed output to drive the load done.
327 auto dataFork = ForkOp::create(b, loc, loadData, 2);
328
329 auto dataOut = dataFork.getResult()[0];
330 llvm::SmallVector<Value> joinArgs = {dataFork.getResult()[1]};
331 auto dataDone = JoinOp::create(b, loc, joinArgs);
332
333 ldif.dataOut.replaceAllUsesWith(dataOut);
334 ldif.doneOut.replaceAllUsesWith(dataDone);
335
336 // Return load address, to be fed to the top-level output, truncated to the
337 // width of the memory that is accessed.
338 return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
339}
340
341static Value plumbStorePort(Location loc, OpBuilder &b,
342 handshake::MemStoreInterface &stif, Value done,
343 Type outType, MemRefType memrefType) {
344 stif.doneOut.replaceAllUsesWith(done);
345 // Return the store address and data to be fed to the top-level output.
346 // Address is truncated to the width of the memory that is accessed.
347 llvm::SmallVector<Value> structArgs = {
348 truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
349
350 return hw::StructCreateOp::create(b, loc, cast<hw::StructType>(outType),
351 structArgs)
352 .getResult();
353}
354
355static void appendToStringArrayAttr(Operation *op, StringRef attrName,
356 StringRef attrVal) {
357 auto *ctx = op->getContext();
358 llvm::SmallVector<Attribute> newArr;
359 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
360 std::back_inserter(newArr));
361 newArr.push_back(StringAttr::get(ctx, attrVal));
362 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
363}
364
365static void insertInStringArrayAttr(Operation *op, StringRef attrName,
366 StringRef attrVal, unsigned idx) {
367 auto *ctx = op->getContext();
368 llvm::SmallVector<Attribute> newArr;
369 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
370 std::back_inserter(newArr));
371 newArr.insert(newArr.begin() + idx, StringAttr::get(ctx, attrVal));
372 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
373}
374
375static void eraseFromArrayAttr(Operation *op, StringRef attrName,
376 unsigned idx) {
377 auto *ctx = op->getContext();
378 llvm::SmallVector<Attribute> newArr;
379 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
380 std::back_inserter(newArr));
381 newArr.erase(newArr.begin() + idx);
382 op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
383}
384
/// An argument index paired with two type ranges, presumably the input and
/// output types that replace that argument.
/// NOTE(review): this struct is not referenced anywhere else in this file;
/// confirm it is unused before keeping it. TypeRange is non-owning, so any
/// stored range must outlive this struct.
struct ArgTypeReplacement {
  unsigned index;
  TypeRange ins;
  TypeRange outs;
};
390
391LogicalResult
392HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
393 // Gather memref ports to be converted. This is an ordered map, and will be
394 // iterated from lo to hi indices.
395 std::map<unsigned, Value> memrefArgs;
396 for (auto [i, arg] : llvm::enumerate(func.getArguments()))
397 if (isa<MemRefType>(arg.getType()))
398 memrefArgs[i] = arg;
399
400 if (memrefArgs.empty())
401 return success(); // nothing to do.
402
403 // Record which arg indices were replaces with handshake memory ports.
404 // This is an ordered map, and will be iterated from lo to hi indices.
405 std::map<unsigned, HandshakeMemType> argReplacements;
406
407 // Record the hw.module i/o of the original func (used for ESI wrapper).
408 auto origPortInfo = handshake::getPortInfoForOpTypes(
409 func, func.getArgumentTypes(), func.getResultTypes());
410
411 OpBuilder b(func);
412 for (auto it : memrefArgs) {
413 // Do not use structured bindings for 'it' - cannot reference inside lambda.
414 unsigned i = it.first;
415 auto arg = it.second;
416 auto loc = arg.getLoc();
417 // Get the attached extmemory external module.
418 auto extmemOp = cast<handshake::ExternalMemoryOp>(*arg.getUsers().begin());
419 b.setInsertionPoint(extmemOp);
420
421 // Add memory input - this is the output of the extmemory op.
422 auto memIOTypes = getMemTypeForExtmem(arg);
423 MemRefType memrefType = cast<MemRefType>(arg.getType());
424
425 auto oldReturnOp =
426 cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
427 llvm::SmallVector<Value> newReturnOperands = oldReturnOp.getOperands();
428 unsigned addedInPorts = 0;
429 auto memName = func.getArgName(i);
430 auto addArgRes = [&](unsigned id, NamedType &argType,
431 NamedType &resType) -> FailureOr<Value> {
432 // Function argument
433 unsigned newArgIdx = i + addedInPorts;
434 if (failed(
435 func.insertArgument(newArgIdx, argType.second, {}, arg.getLoc())))
436 return failure();
437 insertInStringArrayAttr(func, "argNames",
438 memName.str() + "_" + argType.first.str(),
439 newArgIdx);
440 auto newInPort = func.getArgument(newArgIdx);
441 ++addedInPorts;
442
443 // Function result.
444 if (failed(func.insertResult(func.getNumResults(), resType.second, {})))
445 return failure();
446 appendToStringArrayAttr(func, "resNames",
447 memName.str() + "_" + resType.first.str());
448 return newInPort;
449 };
450
451 // Plumb load ports.
452 unsigned portIdx = 0;
453 for (auto loadPort : extmemOp.getLoadPorts()) {
454 auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
455 memIOTypes.outputTypes[portIdx]);
456 if (failed(newInPort))
457 return failure();
458 newReturnOperands.push_back(
459 plumbLoadPort(loc, b, loadPort, *newInPort, memrefType));
460 ++portIdx;
461 }
462
463 // Plumb store ports.
464 for (auto storePort : extmemOp.getStorePorts()) {
465 auto newInPort =
466 addArgRes(storePort.index, memIOTypes.inputTypes[portIdx],
467 memIOTypes.outputTypes[portIdx]);
468 if (failed(newInPort))
469 return failure();
470 newReturnOperands.push_back(
471 plumbStorePort(loc, b, storePort, *newInPort,
472 memIOTypes.outputTypes[portIdx].second, memrefType));
473 ++portIdx;
474 }
475
476 // Replace the return op of the function with a new one that returns the
477 // memory output struct.
478 b.setInsertionPoint(oldReturnOp);
479 ReturnOp::create(b, arg.getLoc(), newReturnOperands);
480 oldReturnOp.erase();
481
482 // Erase the extmemory operation since I/O plumbing has replaced all of its
483 // results.
484 extmemOp.erase();
485
486 // Erase the original memref argument of the top-level i/o now that it's
487 // use has been removed.
488 if (failed(func.eraseArgument(i + addedInPorts)))
489 return failure();
490 eraseFromArrayAttr(func, "argNames", i + addedInPorts);
491
492 argReplacements[i] = memIOTypes;
493 }
494
495 if (createESIWrapper)
496 if (failed(wrapESI(func, origPortInfo, argReplacements)))
497 return failure();
498
499 return success();
500}
501
502} // namespace
503
504std::unique_ptr<mlir::Pass>
506 std::optional<bool> createESIWrapper) {
507 return std::make_unique<HandshakeLowerExtmemToHWPass>(createESIWrapper);
508}
assert(baseType &&"element must be base type")
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static Type indexToMemAddr(Type t, MemRefType memRef)
static HandshakeMemType getMemTypeForExtmem(Value v)
Instantiate one of these and use it to build typed backedges.
create(elements, Type result_type=None)
Definition hw.py:544
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createHandshakeLowerExtmemToHWPass(std::optional< bool > createESIWrapper={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Describes a service port.
Definition ESIOps.h:31
ChannelBundleType type
Definition ESIOps.h:33
hw::InnerRefAttr port
Definition ESIOps.h:32
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & atInput(size_t idx)