CIRCT 20.0.0git
Loading...
Searching...
No Matches
HandshakeUtils.cpp
Go to the documentation of this file.
1//===- HandshakeUtils.cpp - handshake pass helper functions -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions for various helper functions used in handshake
10// passes.
11//
12//===----------------------------------------------------------------------===//
13
20#include "circt/Support/LLVM.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/OperationSupport.h"
24#include "mlir/IR/PatternMatch.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace handshake;
29using namespace mlir;
30
31/// Iterates over the handshake::FuncOp's in the program to build an instance
32/// graph. In doing so, we detect whether there are any cycles in this graph, as
33/// well as infer a top module for the design by performing a topological sort
34/// of the instance graph. The result of this sort is placed in sortedFuncs.
36 ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel,
37 SmallVectorImpl<std::string> &sortedFuncs) {
38 // Create use graph
39 auto walkFuncOps = [&](handshake::FuncOp funcOp) {
40 auto &funcUses = instanceGraph[funcOp.getName().str()];
41 funcOp.walk([&](handshake::InstanceOp instanceOp) {
42 funcUses.insert(instanceOp.getModule().str());
43 });
44 };
45 moduleOp.walk(walkFuncOps);
46
47 // find top-level (and cycles) using a topological sort. Initialize all
48 // instances as candidate top level modules; these will be pruned whenever
49 // they are referenced by another module.
50 std::set<std::string> visited, marked, candidateTopLevel;
51 SmallVector<std::string> cycleTrace;
52 bool cyclic = false;
53 llvm::transform(instanceGraph,
54 std::inserter(candidateTopLevel, candidateTopLevel.begin()),
55 [](auto it) { return it.first; });
56 std::function<void(const std::string &, SmallVector<std::string>)> cycleUtil =
57 [&](const std::string &node, SmallVector<std::string> trace) {
58 if (cyclic || visited.count(node))
59 return;
60 trace.push_back(node);
61 if (marked.count(node)) {
62 cyclic = true;
63 cycleTrace = trace;
64 return;
65 }
66 marked.insert(node);
67 for (auto use : instanceGraph[node]) {
68 candidateTopLevel.erase(use);
69 cycleUtil(use, trace);
70 }
71 marked.erase(node);
72 visited.insert(node);
73 sortedFuncs.insert(sortedFuncs.begin(), node);
74 };
75 for (auto it : instanceGraph) {
76 if (visited.count(it.first) == 0)
77 cycleUtil(it.first, {});
78 if (cyclic)
79 break;
80 }
81
82 if (cyclic) {
83 auto err = moduleOp.emitOpError();
84 err << "cannot deduce top level function - cycle "
85 "detected in instance graph (";
86 llvm::interleave(
87 cycleTrace, err, [&](auto node) { err << node; }, "->");
88 err << ").";
89 return err;
90 }
91 assert(!candidateTopLevel.empty() &&
92 "if non-cyclic, there should be at least 1 candidate top level");
93
94 if (candidateTopLevel.size() > 1) {
95 auto err = moduleOp.emitOpError();
96 err << "multiple candidate top-level modules detected (";
97 llvm::interleaveComma(candidateTopLevel, err,
98 [&](auto topLevel) { err << topLevel; });
99 err << "). Please remove one of these from the source program.";
100 return err;
101 }
102 topLevel = *candidateTopLevel.begin();
103 return success();
104}
105
106LogicalResult
108 if (funcOp.isExternal())
109 return success();
110
111 auto checkUseFunc = [&](Operation *op, Value v, StringRef desc,
112 unsigned idx) -> LogicalResult {
113 auto numUses = std::distance(v.getUses().begin(), v.getUses().end());
114 if (numUses == 0)
115 return op->emitOpError() << desc << " " << idx << " has no uses.";
116 if (numUses > 1)
117 return op->emitOpError() << desc << " " << idx << " has multiple uses.";
118 return success();
119 };
120
121 for (auto &subOp : funcOp.getOps()) {
122 for (auto res : llvm::enumerate(subOp.getResults())) {
123 if (failed(checkUseFunc(&subOp, res.value(), "result", res.index())))
124 return failure();
125 }
126 }
127
128 Block &entryBlock = funcOp.front();
129 for (auto barg : enumerate(entryBlock.getArguments())) {
130 if (failed(checkUseFunc(funcOp.getOperation(), barg.value(), "argument",
131 barg.index())))
132 return failure();
133 }
134 return success();
135}
136
137// NOLINTNEXTLINE(misc-no-recursion)
138static Type tupleToStruct(TupleType tuple) {
139 auto *ctx = tuple.getContext();
140 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
141 for (auto [i, innerType] : llvm::enumerate(tuple)) {
142 Type convertedInnerType = innerType;
143 if (auto tupleInnerType = dyn_cast<TupleType>(innerType))
144 convertedInnerType = tupleToStruct(tupleInnerType);
145 hwfields.push_back({StringAttr::get(ctx, "field" + std::to_string(i)),
146 convertedInnerType});
147 }
148
149 return hw::StructType::get(ctx, hwfields);
150}
151
152// Converts 't' into a valid HW type. This is strictly used for converting
153// 'index' types into a fixed-width type.
155 return TypeSwitch<Type, Type>(t)
156 .Case<IndexType>(
157 [&](IndexType it) { return IntegerType::get(it.getContext(), 64); })
158 .Case<TupleType>([&](TupleType tt) {
159 llvm::SmallVector<Type> types;
160 for (auto innerType : tt)
161 types.push_back(toValidType(innerType));
162 return tupleToStruct(
163 mlir::TupleType::get(types[0].getContext(), types));
164 })
165 .Case<hw::StructType>([&](auto st) {
166 llvm::SmallVector<hw::StructType::FieldInfo> structFields(
167 st.getElements());
168 for (auto &field : structFields)
169 field.type = toValidType(field.type);
170 return hw::StructType::get(st.getContext(), structFields);
171 })
172 .Case<NoneType>(
173 [&](NoneType nt) { return IntegerType::get(nt.getContext(), 0); })
174 .Default([&](Type t) { return t; });
175}
176
177namespace {
178
179/// A class to be used with getPortInfoForOp. Provides an opaque interface for
180/// generating the port names of an operation; handshake operations generate
181/// names by the Handshake NamedIOInterface; and other operations, such as
182/// arith ops, are assigned default names.
183class HandshakePortNameGenerator {
184public:
185 explicit HandshakePortNameGenerator(Operation *op)
186 : builder(op->getContext()) {
187 auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
188 if (namedOpInterface)
189 inferFromNamedOpInterface(namedOpInterface);
190 else if (auto funcOp = dyn_cast<handshake::FuncOp>(op))
191 inferFromFuncOp(funcOp);
192 else
193 inferDefault(op);
194 }
195
196 StringAttr inputName(unsigned idx) { return inputs[idx]; }
197 StringAttr outputName(unsigned idx) { return outputs[idx]; }
198
199private:
200 using IdxToStrF = const std::function<std::string(unsigned)> &;
201 void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
202 llvm::transform(
203 llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
204 [&](auto it) { return builder.getStringAttr(inF(it.index())); });
205 llvm::transform(
206 llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
207 [&](auto it) { return builder.getStringAttr(outF(it.index())); });
208 }
209
210 void inferDefault(Operation *op) {
211 infer(
212 op, [](unsigned idx) { return "in" + std::to_string(idx); },
213 [](unsigned idx) { return "out" + std::to_string(idx); });
214 }
215
216 void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
217 infer(
218 op, [&](unsigned idx) { return op.getOperandName(idx); },
219 [&](unsigned idx) { return op.getResultName(idx); });
220 }
221
222 void inferFromFuncOp(handshake::FuncOp op) {
223 auto inF = [&](unsigned idx) { return op.getArgName(idx).str(); };
224 auto outF = [&](unsigned idx) { return op.getResName(idx).str(); };
225 llvm::transform(
226 llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
227 [&](auto it) { return builder.getStringAttr(inF(it.index())); });
228 llvm::transform(
229 llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
230 [&](auto it) { return builder.getStringAttr(outF(it.index())); });
231 }
232
233 Builder builder;
234 llvm::SmallVector<StringAttr> inputs;
235 llvm::SmallVector<StringAttr> outputs;
236};
237} // namespace
238
239static void replaceFirstUse(Operation *op, Value oldVal, Value newVal) {
240 for (int i = 0, e = op->getNumOperands(); i < e; ++i)
241 if (op->getOperand(i) == oldVal) {
242 op->setOperand(i, newVal);
243 break;
244 }
245}
246
247void circt::handshake::insertFork(Value result, bool isLazy,
248 OpBuilder &rewriter) {
249 // Get successor operations
250 std::vector<Operation *> opsToProcess;
251 for (auto &u : result.getUses())
252 opsToProcess.push_back(u.getOwner());
253
254 // Insert fork after op
255 rewriter.setInsertionPointAfterValue(result);
256 auto forkSize = opsToProcess.size();
257 Operation *newOp;
258 if (isLazy)
259 newOp = rewriter.create<LazyForkOp>(result.getLoc(), result, forkSize);
260 else
261 newOp = rewriter.create<ForkOp>(result.getLoc(), result, forkSize);
262
263 // Modify operands of successor
264 // opsToProcess may have multiple instances of same operand
265 // Replace uses one by one to assign different fork outputs to them
266 for (int i = 0, e = forkSize; i < e; ++i)
267 replaceFirstUse(opsToProcess[i], result, newOp->getResult(i));
268}
269
270// NOLINTNEXTLINE(misc-no-recursion)
272 return TypeSwitch<Type, esi::ChannelType>(t)
273 .Case<esi::ChannelType>([](auto t) { return t; })
274 .Case<TupleType>(
275 [&](TupleType tt) { return esiWrapper(tupleToStruct(tt)); })
276 .Case<NoneType>([](NoneType nt) {
277 // todo: change when handshake switches to i0
278 return esiWrapper(IntegerType::get(nt.getContext(), 0));
279 })
280 .Default([](auto t) {
281 return esi::ChannelType::get(t.getContext(), toValidType(t));
282 });
283}
284
286 TypeRange inputs,
287 TypeRange outputs) {
288 SmallVector<hw::PortInfo> pinputs, poutputs;
289
290 HandshakePortNameGenerator portNames(op);
291 auto *ctx = op->getContext();
292
293 Type i1Type = IntegerType::get(ctx, 1);
294 Type clkType = seq::ClockType::get(ctx);
295
296 // Add all inputs of funcOp.
297 unsigned inIdx = 0;
298 for (auto arg : llvm::enumerate(inputs)) {
299 pinputs.push_back(
300 {{portNames.inputName(arg.index()), esiWrapper(arg.value()),
301 hw::ModulePort::Direction::Input},
302 arg.index(),
303 {}});
304 inIdx++;
305 }
306
307 // Add all outputs of funcOp.
308 for (auto res : llvm::enumerate(outputs)) {
309 poutputs.push_back(
310 {{portNames.outputName(res.index()), esiWrapper(res.value()),
311 hw::ModulePort::Direction::Output},
312 res.index(),
313 {}});
314 }
315
316 // Add clock and reset signals.
317 if (op->hasTrait<mlir::OpTrait::HasClock>()) {
318 pinputs.push_back({{StringAttr::get(ctx, "clock"), clkType,
319 hw::ModulePort::Direction::Input},
320 inIdx++,
321 {}});
322 pinputs.push_back({{StringAttr::get(ctx, "reset"), i1Type,
323 hw::ModulePort::Direction::Input},
324 inIdx,
325 {}});
326 }
327
328 return hw::ModulePortInfo{pinputs, poutputs};
329}
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
static void replaceFirstUse(Operation *op, Value oldVal, Value newVal)
Channels are the basic communication primitives.
Definition Types.h:63
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its input and output types.
std::map< std::string, std::set< std::string > > InstanceGraph
Maps each handshake function name to the set of module names it instantiates.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOps in the program to build an instance graph.
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
void insertFork(Value result, bool isLazy, OpBuilder &rewriter)
Inserts a fork (a lazy fork if isLazy is set) directly after result and rewires each use of result to a distinct fork output.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.