Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HandshakeUtils.cpp
Go to the documentation of this file.
1//===- HandshakeUtils.cpp - handshake pass helper functions -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions for various helper functions used in handshake
10// passes.
11//
12//===----------------------------------------------------------------------===//
13
19#include "circt/Support/LLVM.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/IR/PatternMatch.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26using namespace circt;
27using namespace handshake;
28using namespace mlir;
29
30/// Iterates over the handshake::FuncOp's in the program to build an instance
31/// graph. In doing so, we detect whether there are any cycles in this graph, as
32/// well as infer a top module for the design by performing a topological sort
33/// of the instance graph. The result of this sort is placed in sortedFuncs.
35 ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel,
36 SmallVectorImpl<std::string> &sortedFuncs) {
37 // Create use graph
38 auto walkFuncOps = [&](handshake::FuncOp funcOp) {
39 auto &funcUses = instanceGraph[funcOp.getName().str()];
40 funcOp.walk([&](handshake::InstanceOp instanceOp) {
41 funcUses.insert(instanceOp.getModule().str());
42 });
43 };
44 moduleOp.walk(walkFuncOps);
45
46 // find top-level (and cycles) using a topological sort. Initialize all
47 // instances as candidate top level modules; these will be pruned whenever
48 // they are referenced by another module.
49 std::set<std::string> visited, marked, candidateTopLevel;
50 SmallVector<std::string> cycleTrace;
51 bool cyclic = false;
52 llvm::transform(instanceGraph,
53 std::inserter(candidateTopLevel, candidateTopLevel.begin()),
54 [](auto it) { return it.first; });
55 std::function<void(const std::string &, SmallVector<std::string>)> cycleUtil =
56 [&](const std::string &node, SmallVector<std::string> trace) {
57 if (cyclic || visited.count(node))
58 return;
59 trace.push_back(node);
60 if (marked.count(node)) {
61 cyclic = true;
62 cycleTrace = trace;
63 return;
64 }
65 marked.insert(node);
66 for (auto use : instanceGraph[node]) {
67 candidateTopLevel.erase(use);
68 cycleUtil(use, trace);
69 }
70 marked.erase(node);
71 visited.insert(node);
72 sortedFuncs.insert(sortedFuncs.begin(), node);
73 };
74 for (auto it : instanceGraph) {
75 if (visited.count(it.first) == 0)
76 cycleUtil(it.first, {});
77 if (cyclic)
78 break;
79 }
80
81 if (cyclic) {
82 auto err = moduleOp.emitOpError();
83 err << "cannot deduce top level function - cycle "
84 "detected in instance graph (";
85 llvm::interleave(
86 cycleTrace, err, [&](auto node) { err << node; }, "->");
87 err << ").";
88 return err;
89 }
90 assert(!candidateTopLevel.empty() &&
91 "if non-cyclic, there should be at least 1 candidate top level");
92
93 if (candidateTopLevel.size() > 1) {
94 auto err = moduleOp.emitOpError();
95 err << "multiple candidate top-level modules detected (";
96 llvm::interleaveComma(candidateTopLevel, err,
97 [&](auto topLevel) { err << topLevel; });
98 err << "). Please remove one of these from the source program.";
99 return err;
100 }
101 topLevel = *candidateTopLevel.begin();
102 return success();
103}
104
105LogicalResult
107 if (funcOp.isExternal())
108 return success();
109
110 auto checkUseFunc = [&](Operation *op, Value v, StringRef desc,
111 unsigned idx) -> LogicalResult {
112 auto numUses = std::distance(v.getUses().begin(), v.getUses().end());
113 if (numUses == 0)
114 return op->emitOpError() << desc << " " << idx << " has no uses.";
115 if (numUses > 1)
116 return op->emitOpError() << desc << " " << idx << " has multiple uses.";
117 return success();
118 };
119
120 for (auto &subOp : funcOp.getOps()) {
121 for (auto res : llvm::enumerate(subOp.getResults())) {
122 if (failed(checkUseFunc(&subOp, res.value(), "result", res.index())))
123 return failure();
124 }
125 }
126
127 Block &entryBlock = funcOp.front();
128 for (auto barg : enumerate(entryBlock.getArguments())) {
129 if (failed(checkUseFunc(funcOp.getOperation(), barg.value(), "argument",
130 barg.index())))
131 return failure();
132 }
133 return success();
134}
135
136// NOLINTNEXTLINE(misc-no-recursion)
137static Type tupleToStruct(TupleType tuple) {
138 auto *ctx = tuple.getContext();
139 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
140 for (auto [i, innerType] : llvm::enumerate(tuple)) {
141 Type convertedInnerType = innerType;
142 if (auto tupleInnerType = dyn_cast<TupleType>(innerType))
143 convertedInnerType = tupleToStruct(tupleInnerType);
144 hwfields.push_back({StringAttr::get(ctx, "field" + std::to_string(i)),
145 convertedInnerType});
146 }
147
148 return hw::StructType::get(ctx, hwfields);
149}
150
151// Converts 't' into a valid HW type. This is strictly used for converting
152// 'index' types into a fixed-width type.
154 return TypeSwitch<Type, Type>(t)
155 .Case<IndexType>(
156 [&](IndexType it) { return IntegerType::get(it.getContext(), 64); })
157 .Case<TupleType>([&](TupleType tt) {
158 llvm::SmallVector<Type> types;
159 for (auto innerType : tt)
160 types.push_back(toValidType(innerType));
161 return tupleToStruct(
162 mlir::TupleType::get(types[0].getContext(), types));
163 })
164 .Case<hw::StructType>([&](auto st) {
165 llvm::SmallVector<hw::StructType::FieldInfo> structFields(
166 st.getElements());
167 for (auto &field : structFields)
168 field.type = toValidType(field.type);
169 return hw::StructType::get(st.getContext(), structFields);
170 })
171 .Case<NoneType>(
172 [&](NoneType nt) { return IntegerType::get(nt.getContext(), 0); })
173 .Default([&](Type t) { return t; });
174}
175
176namespace {
177
178/// A class to be used with getPortInfoForOp. Provides an opaque interface for
179/// generating the port names of an operation; handshake operations generate
180/// names by the Handshake NamedIOInterface; and other operations, such as
181/// arith ops, are assigned default names.
182class HandshakePortNameGenerator {
183public:
184 explicit HandshakePortNameGenerator(Operation *op)
185 : builder(op->getContext()) {
186 auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
187 if (namedOpInterface)
188 inferFromNamedOpInterface(namedOpInterface);
189 else if (auto funcOp = dyn_cast<handshake::FuncOp>(op))
190 inferFromFuncOp(funcOp);
191 else
192 inferDefault(op);
193 }
194
195 StringAttr inputName(unsigned idx) { return inputs[idx]; }
196 StringAttr outputName(unsigned idx) { return outputs[idx]; }
197
198private:
199 using IdxToStrF = const std::function<std::string(unsigned)> &;
200 void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
201 llvm::transform(
202 llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
203 [&](auto it) { return builder.getStringAttr(inF(it.index())); });
204 llvm::transform(
205 llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
206 [&](auto it) { return builder.getStringAttr(outF(it.index())); });
207 }
208
209 void inferDefault(Operation *op) {
210 infer(
211 op, [](unsigned idx) { return "in" + std::to_string(idx); },
212 [](unsigned idx) { return "out" + std::to_string(idx); });
213 }
214
215 void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
216 infer(
217 op, [&](unsigned idx) { return op.getOperandName(idx); },
218 [&](unsigned idx) { return op.getResultName(idx); });
219 }
220
221 void inferFromFuncOp(handshake::FuncOp op) {
222 auto inF = [&](unsigned idx) { return op.getArgName(idx).str(); };
223 auto outF = [&](unsigned idx) { return op.getResName(idx).str(); };
224 llvm::transform(
225 llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
226 [&](auto it) { return builder.getStringAttr(inF(it.index())); });
227 llvm::transform(
228 llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
229 [&](auto it) { return builder.getStringAttr(outF(it.index())); });
230 }
231
232 Builder builder;
233 llvm::SmallVector<StringAttr> inputs;
234 llvm::SmallVector<StringAttr> outputs;
235};
236} // namespace
237
238static void replaceFirstUse(Operation *op, Value oldVal, Value newVal) {
239 for (int i = 0, e = op->getNumOperands(); i < e; ++i)
240 if (op->getOperand(i) == oldVal) {
241 op->setOperand(i, newVal);
242 break;
243 }
244}
245
246void circt::handshake::insertFork(Value result, bool isLazy,
247 OpBuilder &rewriter) {
248 // Get successor operations
249 std::vector<Operation *> opsToProcess;
250 for (auto &u : result.getUses())
251 opsToProcess.push_back(u.getOwner());
252
253 // Insert fork after op
254 rewriter.setInsertionPointAfterValue(result);
255 auto forkSize = opsToProcess.size();
256 Operation *newOp;
257 if (isLazy)
258 newOp = LazyForkOp::create(rewriter, result.getLoc(), result, forkSize);
259 else
260 newOp = ForkOp::create(rewriter, result.getLoc(), result, forkSize);
261
262 // Modify operands of successor
263 // opsToProcess may have multiple instances of same operand
264 // Replace uses one by one to assign different fork outputs to them
265 for (int i = 0, e = forkSize; i < e; ++i)
266 replaceFirstUse(opsToProcess[i], result, newOp->getResult(i));
267}
268
269// NOLINTNEXTLINE(misc-no-recursion)
271 return TypeSwitch<Type, esi::ChannelType>(t)
272 .Case<esi::ChannelType>([](auto t) { return t; })
273 .Case<TupleType>(
274 [&](TupleType tt) { return esiWrapper(tupleToStruct(tt)); })
275 .Case<NoneType>([](NoneType nt) {
276 // todo: change when handshake switches to i0
277 return esiWrapper(IntegerType::get(nt.getContext(), 0));
278 })
279 .Default([](auto t) {
280 return esi::ChannelType::get(t.getContext(), toValidType(t));
281 });
282}
283
285 TypeRange inputs,
286 TypeRange outputs) {
287 SmallVector<hw::PortInfo> pinputs, poutputs;
288
289 HandshakePortNameGenerator portNames(op);
290 auto *ctx = op->getContext();
291
292 Type i1Type = IntegerType::get(ctx, 1);
293 Type clkType = seq::ClockType::get(ctx);
294
295 // Add all inputs of funcOp.
296 unsigned inIdx = 0;
297 for (auto arg : llvm::enumerate(inputs)) {
298 pinputs.push_back(
299 {{portNames.inputName(arg.index()), esiWrapper(arg.value()),
300 hw::ModulePort::Direction::Input},
301 arg.index(),
302 {}});
303 inIdx++;
304 }
305
306 // Add all outputs of funcOp.
307 for (auto res : llvm::enumerate(outputs)) {
308 poutputs.push_back(
309 {{portNames.outputName(res.index()), esiWrapper(res.value()),
310 hw::ModulePort::Direction::Output},
311 res.index(),
312 {}});
313 }
314
315 // Add clock and reset signals.
316 if (op->hasTrait<mlir::OpTrait::HasClock>()) {
317 pinputs.push_back({{StringAttr::get(ctx, "clock"), clkType,
318 hw::ModulePort::Direction::Input},
319 inIdx++,
320 {}});
321 pinputs.push_back({{StringAttr::get(ctx, "reset"), i1Type,
322 hw::ModulePort::Direction::Input},
323 inIdx,
324 {}});
325 }
326
327 return hw::ModulePortInfo{pinputs, poutputs};
328}
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
static void replaceFirstUse(Operation *op, Value oldVal, Value newVal)
Channels are the basic communication primitives.
Definition Types.h:70
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its input and output types.
std::map< std::string, std::set< std::string > > InstanceGraph
Maps the name of each handshake function to the set of module names it instantiates.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
void insertFork(Value result, bool isLazy, OpBuilder &rewriter)
Inserts a fork (a lazy fork if isLazy is set) after result and gives each use of result its own fork output.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.