20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/IR/PatternMatch.h"
24#include "llvm/ADT/TypeSwitch.h"
35 ModuleOp moduleOp,
InstanceGraph &instanceGraph, std::string &topLevel,
36 SmallVectorImpl<std::string> &sortedFuncs) {
39 auto &funcUses = instanceGraph[funcOp.getName().str()];
40 funcOp.walk([&](handshake::InstanceOp instanceOp) {
41 funcUses.insert(instanceOp.getModule().str());
44 moduleOp.walk(walkFuncOps);
49 std::set<std::string> visited, marked, candidateTopLevel;
50 SmallVector<std::string> cycleTrace;
52 llvm::transform(instanceGraph,
53 std::inserter(candidateTopLevel, candidateTopLevel.begin()),
54 [](
auto it) { return it.first; });
55 std::function<void(
const std::string &, SmallVector<std::string>)> cycleUtil =
56 [&](
const std::string &node, SmallVector<std::string> trace) {
57 if (cyclic || visited.count(node))
59 trace.push_back(node);
60 if (marked.count(node)) {
66 for (
auto use : instanceGraph[node]) {
67 candidateTopLevel.erase(use);
68 cycleUtil(use, trace);
72 sortedFuncs.insert(sortedFuncs.begin(), node);
74 for (
auto it : instanceGraph) {
75 if (visited.count(it.first) == 0)
76 cycleUtil(it.first, {});
82 auto err = moduleOp.emitOpError();
83 err <<
"cannot deduce top level function - cycle "
84 "detected in instance graph (";
86 cycleTrace, err, [&](
auto node) { err << node; },
"->");
90 assert(!candidateTopLevel.empty() &&
91 "if non-cyclic, there should be at least 1 candidate top level");
93 if (candidateTopLevel.size() > 1) {
94 auto err = moduleOp.emitOpError();
95 err <<
"multiple candidate top-level modules detected (";
96 llvm::interleaveComma(candidateTopLevel, err,
97 [&](
auto topLevel) { err << topLevel; });
98 err <<
"). Please remove one of these from the source program.";
101 topLevel = *candidateTopLevel.begin();
107 if (funcOp.isExternal())
110 auto checkUseFunc = [&](Operation *op, Value v, StringRef desc,
111 unsigned idx) -> LogicalResult {
112 auto numUses = std::distance(v.getUses().begin(), v.getUses().end());
114 return op->emitOpError() << desc <<
" " << idx <<
" has no uses.";
116 return op->emitOpError() << desc <<
" " << idx <<
" has multiple uses.";
120 for (
auto &subOp : funcOp.getOps()) {
121 for (
auto res : llvm::enumerate(subOp.getResults())) {
122 if (failed(checkUseFunc(&subOp, res.value(),
"result", res.index())))
127 Block &entryBlock = funcOp.front();
128 for (
auto barg : enumerate(entryBlock.getArguments())) {
129 if (failed(checkUseFunc(funcOp.getOperation(), barg.value(),
"argument",
138 auto *ctx = tuple.getContext();
139 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
140 for (
auto [i, innerType] : llvm::enumerate(tuple)) {
141 Type convertedInnerType = innerType;
142 if (
auto tupleInnerType = dyn_cast<TupleType>(innerType))
144 hwfields.push_back({StringAttr::get(ctx,
"field" + std::to_string(i)),
145 convertedInnerType});
148 return hw::StructType::get(ctx, hwfields);
154 return TypeSwitch<Type, Type>(t)
156 [&](IndexType it) {
return IntegerType::get(it.getContext(), 64); })
157 .Case<TupleType>([&](TupleType tt) {
158 llvm::SmallVector<Type> types;
159 for (
auto innerType : tt)
162 mlir::TupleType::get(types[0].getContext(), types));
164 .Case<hw::StructType>([&](
auto st) {
165 llvm::SmallVector<hw::StructType::FieldInfo> structFields(
167 for (
auto &field : structFields)
169 return hw::StructType::get(st.getContext(), structFields);
172 [&](NoneType nt) {
return IntegerType::get(nt.getContext(), 0); })
173 .Default([&](Type t) {
return t; });
182class HandshakePortNameGenerator {
184 explicit HandshakePortNameGenerator(Operation *op)
185 : builder(op->getContext()) {
186 auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
187 if (namedOpInterface)
188 inferFromNamedOpInterface(namedOpInterface);
189 else if (
auto funcOp = dyn_cast<handshake::FuncOp>(op))
190 inferFromFuncOp(funcOp);
195 StringAttr inputName(
unsigned idx) {
return inputs[idx]; }
196 StringAttr outputName(
unsigned idx) {
return outputs[idx]; }
199 using IdxToStrF =
const std::function<std::string(
unsigned)> &;
200 void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
202 llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
203 [&](
auto it) { return builder.getStringAttr(inF(it.index())); });
205 llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
206 [&](
auto it) { return builder.getStringAttr(outF(it.index())); });
209 void inferDefault(Operation *op) {
211 op, [](
unsigned idx) {
return "in" + std::to_string(idx); },
212 [](
unsigned idx) {
return "out" + std::to_string(idx); });
215 void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
217 op, [&](
unsigned idx) {
return op.getOperandName(idx); },
218 [&](
unsigned idx) {
return op.getResultName(idx); });
222 auto inF = [&](
unsigned idx) {
return op.getArgName(idx).str(); };
223 auto outF = [&](
unsigned idx) {
return op.getResName(idx).str(); };
225 llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
226 [&](
auto it) { return builder.getStringAttr(inF(it.index())); });
228 llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
229 [&](
auto it) { return builder.getStringAttr(outF(it.index())); });
233 llvm::SmallVector<StringAttr> inputs;
234 llvm::SmallVector<StringAttr> outputs;
239 for (
int i = 0, e = op->getNumOperands(); i < e; ++i)
240 if (op->getOperand(i) == oldVal) {
241 op->setOperand(i, newVal);
247 OpBuilder &rewriter) {
249 std::vector<Operation *> opsToProcess;
250 for (
auto &u : result.getUses())
251 opsToProcess.push_back(u.getOwner());
254 rewriter.setInsertionPointAfterValue(result);
255 auto forkSize = opsToProcess.size();
258 newOp = LazyForkOp::create(rewriter, result.getLoc(), result, forkSize);
260 newOp = ForkOp::create(rewriter, result.getLoc(), result, forkSize);
265 for (
int i = 0, e = forkSize; i < e; ++i)
271 return TypeSwitch<Type, esi::ChannelType>(t)
275 .Case<NoneType>([](NoneType nt) {
277 return esiWrapper(IntegerType::get(nt.getContext(), 0));
279 .Default([](
auto t) {
280 return esi::ChannelType::get(t.getContext(),
toValidType(t));
287 SmallVector<hw::PortInfo> pinputs, poutputs;
289 HandshakePortNameGenerator portNames(op);
290 auto *ctx = op->getContext();
292 Type i1Type = IntegerType::get(ctx, 1);
293 Type clkType = seq::ClockType::get(ctx);
297 for (
auto arg :
llvm::enumerate(inputs)) {
299 {{portNames.inputName(arg.index()), esiWrapper(arg.value()),
300 hw::ModulePort::Direction::Input},
307 for (
auto res :
llvm::enumerate(outputs)) {
309 {{portNames.outputName(res.index()),
esiWrapper(res.value()),
310 hw::ModulePort::Direction::Output},
317 pinputs.push_back({{StringAttr::get(ctx,
"clock"), clkType,
318 hw::ModulePort::Direction::Input},
321 pinputs.push_back({{StringAttr::get(ctx,
"reset"), i1Type,
322 hw::ModulePort::Direction::Input},
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
static void replaceFirstUse(Operation *op, Value oldVal, Value newVal)
Channels are the basic communication primitives.
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its input and output types.
std::map< std::string, std::set< std::string > > InstanceGraph
Maps each function name to the set of function (module) names it instantiates.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
void insertFork(Value result, bool isLazy, OpBuilder &rewriter)
Adds a fork operation (a LazyForkOp when isLazy is set, otherwise a ForkOp) directly after `result`, with one output per use of `result`.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.