21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
36 ModuleOp moduleOp,
InstanceGraph &instanceGraph, std::string &topLevel,
37 SmallVectorImpl<std::string> &sortedFuncs) {
40 auto &funcUses = instanceGraph[funcOp.getName().str()];
41 funcOp.walk([&](handshake::InstanceOp instanceOp) {
42 funcUses.insert(instanceOp.getModule().str());
45 moduleOp.walk(walkFuncOps);
50 std::set<std::string> visited, marked, candidateTopLevel;
51 SmallVector<std::string> cycleTrace;
53 llvm::transform(instanceGraph,
54 std::inserter(candidateTopLevel, candidateTopLevel.begin()),
55 [](
auto it) { return it.first; });
56 std::function<void(
const std::string &, SmallVector<std::string>)> cycleUtil =
57 [&](
const std::string &node, SmallVector<std::string> trace) {
58 if (cyclic || visited.count(node))
60 trace.push_back(node);
61 if (marked.count(node)) {
67 for (
auto use : instanceGraph[node]) {
68 candidateTopLevel.erase(use);
69 cycleUtil(use, trace);
73 sortedFuncs.insert(sortedFuncs.begin(), node);
75 for (
auto it : instanceGraph) {
76 if (visited.count(it.first) == 0)
77 cycleUtil(it.first, {});
83 auto err = moduleOp.emitOpError();
84 err <<
"cannot deduce top level function - cycle "
85 "detected in instance graph (";
87 cycleTrace, err, [&](
auto node) { err << node; },
"->");
91 assert(!candidateTopLevel.empty() &&
92 "if non-cyclic, there should be at least 1 candidate top level");
94 if (candidateTopLevel.size() > 1) {
95 auto err = moduleOp.emitOpError();
96 err <<
"multiple candidate top-level modules detected (";
97 llvm::interleaveComma(candidateTopLevel, err,
98 [&](
auto topLevel) { err << topLevel; });
99 err <<
"). Please remove one of these from the source program.";
102 topLevel = *candidateTopLevel.begin();
108 if (funcOp.isExternal())
111 auto checkUseFunc = [&](Operation *op, Value v, StringRef desc,
112 unsigned idx) -> LogicalResult {
113 auto numUses = std::distance(v.getUses().begin(), v.getUses().end());
115 return op->emitOpError() << desc <<
" " << idx <<
" has no uses.";
117 return op->emitOpError() << desc <<
" " << idx <<
" has multiple uses.";
121 for (
auto &subOp : funcOp.getOps()) {
122 for (
auto res : llvm::enumerate(subOp.getResults())) {
123 if (failed(checkUseFunc(&subOp, res.value(),
"result", res.index())))
128 Block &entryBlock = funcOp.front();
129 for (
auto barg : enumerate(entryBlock.getArguments())) {
130 if (failed(checkUseFunc(funcOp.getOperation(), barg.value(),
"argument",
139 auto *ctx = tuple.getContext();
140 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
141 for (
auto [i,
innerType] : llvm::enumerate(tuple)) {
143 if (
auto tupleInnerType = dyn_cast<TupleType>(
innerType))
146 convertedInnerType});
155 return TypeSwitch<Type, Type>(t)
158 .Case<TupleType>([&](TupleType tt) {
159 llvm::SmallVector<Type> types;
165 .Case<hw::StructType>([&](
auto st) {
166 llvm::SmallVector<hw::StructType::FieldInfo> structFields(
168 for (
auto &field : structFields)
174 .Default([&](Type t) {
return t; });
183 class HandshakePortNameGenerator {
185 explicit HandshakePortNameGenerator(Operation *op)
186 : builder(op->getContext()) {
187 auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
188 if (namedOpInterface)
189 inferFromNamedOpInterface(namedOpInterface);
190 else if (
auto funcOp = dyn_cast<handshake::FuncOp>(op))
191 inferFromFuncOp(funcOp);
196 StringAttr inputName(
unsigned idx) {
return inputs[idx]; }
197 StringAttr outputName(
unsigned idx) {
return outputs[idx]; }
200 using IdxToStrF =
const std::function<std::string(
unsigned)> &;
201 void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
203 llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
204 [&](
auto it) { return builder.getStringAttr(inF(it.index())); });
206 llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
207 [&](
auto it) { return builder.getStringAttr(outF(it.index())); });
210 void inferDefault(Operation *op) {
212 op, [](
unsigned idx) {
return "in" + std::to_string(idx); },
213 [](
unsigned idx) {
return "out" + std::to_string(idx); });
216 void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
218 op, [&](
unsigned idx) {
return op.getOperandName(idx); },
219 [&](
unsigned idx) {
return op.getResultName(idx); });
223 auto inF = [&](
unsigned idx) {
return op.getArgName(idx).str(); };
224 auto outF = [&](
unsigned idx) {
return op.getResName(idx).str(); };
226 llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
227 [&](
auto it) { return builder.getStringAttr(inF(it.index())); });
229 llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
230 [&](
auto it) { return builder.getStringAttr(outF(it.index())); });
234 llvm::SmallVector<StringAttr> inputs;
235 llvm::SmallVector<StringAttr> outputs;
240 for (
int i = 0, e = op->getNumOperands(); i < e; ++i)
241 if (op->getOperand(i) == oldVal) {
242 op->setOperand(i, newVal);
248 OpBuilder &rewriter) {
250 std::vector<Operation *> opsToProcess;
251 for (
auto &u : result.getUses())
252 opsToProcess.push_back(u.getOwner());
255 rewriter.setInsertionPointAfterValue(result);
256 auto forkSize = opsToProcess.size();
259 newOp = rewriter.create<LazyForkOp>(result.getLoc(), result, forkSize);
261 newOp = rewriter.create<ForkOp>(result.getLoc(), result, forkSize);
266 for (
int i = 0, e = forkSize; i < e; ++i)
272 return TypeSwitch<Type, esi::ChannelType>(t)
276 .Case<NoneType>([](NoneType nt) {
280 .Default([](
auto t) {
288 SmallVector<hw::PortInfo> pinputs, poutputs;
290 HandshakePortNameGenerator portNames(op);
291 auto *ctx = op->getContext();
298 for (
auto arg : llvm::enumerate(inputs)) {
300 {{portNames.inputName(arg.index()),
esiWrapper(arg.value()),
308 for (
auto res : llvm::enumerate(outputs)) {
310 {{portNames.outputName(res.index()),
esiWrapper(res.value()),
328 return hw::ModulePortInfo{pinputs, poutputs};
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
static void replaceFirstUse(Operation *op, Value oldVal, Value newVal)
Channels are the basic communication primitives.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its input and output types.
std::map< std::string, std::set< std::string > > InstanceGraph
Instance graph of the program: maps each handshake function name to the set of module names it instantiates.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
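Iterates over the handshake::FuncOps in the program to build an instance graph, then uses it to deduce the single top-level function and a sorted list of functions in which every function precedes the functions it instantiates. Fails if the instance graph contains a cycle or if more than one top-level candidate exists.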
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
void insertFork(Value result, bool isLazy, OpBuilder &rewriter)
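Inserts a fork operation directly after result: a LazyForkOp or a ForkOp, depending on isLazy. The fork has one output per existing use of result, so that each use can be fed by its own fork output.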
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.