CIRCT  19.0.0git
PassHelpers.cpp
Go to the documentation of this file.
1 //===- PassHelpers.cpp - handshake pass helper functions --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions for various helper functions used in handshake
10 // passes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
19 #include "circt/Support/LLVM.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace circt;
27 using namespace handshake;
28 using namespace mlir;
29 
30 namespace circt {
31 
32 namespace handshake {
33 
34 /// Iterates over the handshake::FuncOp's in the program to build an instance
35 /// graph. In doing so, we detect whether there are any cycles in this graph, as
36 /// well as infer a top module for the design by performing a topological sort
37 /// of the instance graph. The result of this sort is placed in sortedFuncs.
38 LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
39  InstanceGraph &instanceGraph,
40  std::string &topLevel,
41  SmallVectorImpl<std::string> &sortedFuncs) {
42  // Create use graph
43  auto walkFuncOps = [&](handshake::FuncOp funcOp) {
44  auto &funcUses = instanceGraph[funcOp.getName().str()];
45  funcOp.walk([&](handshake::InstanceOp instanceOp) {
46  funcUses.insert(instanceOp.getModule().str());
47  });
48  };
49  moduleOp.walk(walkFuncOps);
50 
51  // find top-level (and cycles) using a topological sort. Initialize all
52  // instances as candidate top level modules; these will be pruned whenever
53  // they are referenced by another module.
54  std::set<std::string> visited, marked, candidateTopLevel;
55  SmallVector<std::string> cycleTrace;
56  bool cyclic = false;
57  llvm::transform(instanceGraph,
58  std::inserter(candidateTopLevel, candidateTopLevel.begin()),
59  [](auto it) { return it.first; });
60  std::function<void(const std::string &, SmallVector<std::string>)> cycleUtil =
61  [&](const std::string &node, SmallVector<std::string> trace) {
62  if (cyclic || visited.count(node))
63  return;
64  trace.push_back(node);
65  if (marked.count(node)) {
66  cyclic = true;
67  cycleTrace = trace;
68  return;
69  }
70  marked.insert(node);
71  for (auto use : instanceGraph[node]) {
72  candidateTopLevel.erase(use);
73  cycleUtil(use, trace);
74  }
75  marked.erase(node);
76  visited.insert(node);
77  sortedFuncs.insert(sortedFuncs.begin(), node);
78  };
79  for (auto it : instanceGraph) {
80  if (visited.count(it.first) == 0)
81  cycleUtil(it.first, {});
82  if (cyclic)
83  break;
84  }
85 
86  if (cyclic) {
87  auto err = moduleOp.emitOpError();
88  err << "cannot deduce top level function - cycle "
89  "detected in instance graph (";
90  llvm::interleave(
91  cycleTrace, err, [&](auto node) { err << node; }, "->");
92  err << ").";
93  return err;
94  }
95  assert(!candidateTopLevel.empty() &&
96  "if non-cyclic, there should be at least 1 candidate top level");
97 
98  if (candidateTopLevel.size() > 1) {
99  auto err = moduleOp.emitOpError();
100  err << "multiple candidate top-level modules detected (";
101  llvm::interleaveComma(candidateTopLevel, err,
102  [&](auto topLevel) { err << topLevel; });
103  err << "). Please remove one of these from the source program.";
104  return err;
105  }
106  topLevel = *candidateTopLevel.begin();
107  return success();
108 }
109 
110 LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp funcOp) {
111  if (funcOp.isExternal())
112  return success();
113 
114  auto checkUseFunc = [&](Operation *op, Value v, StringRef desc,
115  unsigned idx) -> LogicalResult {
116  auto numUses = std::distance(v.getUses().begin(), v.getUses().end());
117  if (numUses == 0)
118  return op->emitOpError() << desc << " " << idx << " has no uses.";
119  if (numUses > 1)
120  return op->emitOpError() << desc << " " << idx << " has multiple uses.";
121  return success();
122  };
123 
124  for (auto &subOp : funcOp.getOps()) {
125  for (auto res : llvm::enumerate(subOp.getResults())) {
126  if (failed(checkUseFunc(&subOp, res.value(), "result", res.index())))
127  return failure();
128  }
129  }
130 
131  Block &entryBlock = funcOp.front();
132  for (auto barg : enumerate(entryBlock.getArguments())) {
133  if (failed(checkUseFunc(funcOp.getOperation(), barg.value(), "argument",
134  barg.index())))
135  return failure();
136  }
137  return success();
138 }
139 
140 // NOLINTNEXTLINE(misc-no-recursion)
141 static Type tupleToStruct(TupleType tuple) {
142  auto *ctx = tuple.getContext();
143  mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
144  for (auto [i, innerType] : llvm::enumerate(tuple)) {
145  Type convertedInnerType = innerType;
146  if (auto tupleInnerType = innerType.dyn_cast<TupleType>())
147  convertedInnerType = tupleToStruct(tupleInnerType);
148  hwfields.push_back({StringAttr::get(ctx, "field" + std::to_string(i)),
149  convertedInnerType});
150  }
151 
152  return hw::StructType::get(ctx, hwfields);
153 }
154 
155 // Converts 't' into a valid HW type. This is strictly used for converting
156 // 'index' types into a fixed-width type.
157 Type toValidType(Type t) {
158  return TypeSwitch<Type, Type>(t)
159  .Case<IndexType>(
160  [&](IndexType it) { return IntegerType::get(it.getContext(), 64); })
161  .Case<TupleType>([&](TupleType tt) {
162  llvm::SmallVector<Type> types;
163  for (auto innerType : tt)
164  types.push_back(toValidType(innerType));
165  return tupleToStruct(
166  mlir::TupleType::get(types[0].getContext(), types));
167  })
168  .Case<hw::StructType>([&](auto st) {
169  llvm::SmallVector<hw::StructType::FieldInfo> structFields(
170  st.getElements());
171  for (auto &field : structFields)
172  field.type = toValidType(field.type);
173  return hw::StructType::get(st.getContext(), structFields);
174  })
175  .Case<NoneType>(
176  [&](NoneType nt) { return IntegerType::get(nt.getContext(), 0); })
177  .Default([&](Type t) { return t; });
178 }
179 
180 // Wraps a type into an ESI ChannelType type. The inner type is converted to
181 // ensure comprehensability by the RTL dialects.
183  return TypeSwitch<Type, esi::ChannelType>(t)
184  .Case<esi::ChannelType>([](auto t) { return t; })
185  .Case<TupleType>(
186  [&](TupleType tt) { return esiWrapper(tupleToStruct(tt)); })
187  .Case<NoneType>([](NoneType nt) {
188  // todo: change when handshake switches to i0
189  return esiWrapper(IntegerType::get(nt.getContext(), 0));
190  })
191  .Default([](auto t) {
192  return esi::ChannelType::get(t.getContext(), toValidType(t));
193  });
194 }
195 
196 namespace {
197 
198 /// A class to be used with getPortInfoForOp. Provides an opaque interface for
199 /// generating the port names of an operation; handshake operations generate
200 /// names by the Handshake NamedIOInterface; and other operations, such as
201 /// arith ops, are assigned default names.
202 class HandshakePortNameGenerator {
203 public:
204  explicit HandshakePortNameGenerator(Operation *op)
205  : builder(op->getContext()) {
206  auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
207  if (namedOpInterface)
208  inferFromNamedOpInterface(namedOpInterface);
209  else if (auto funcOp = dyn_cast<handshake::FuncOp>(op))
210  inferFromFuncOp(funcOp);
211  else
212  inferDefault(op);
213  }
214 
215  StringAttr inputName(unsigned idx) { return inputs[idx]; }
216  StringAttr outputName(unsigned idx) { return outputs[idx]; }
217 
218 private:
219  using IdxToStrF = const std::function<std::string(unsigned)> &;
220  void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
221  llvm::transform(
222  llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
223  [&](auto it) { return builder.getStringAttr(inF(it.index())); });
224  llvm::transform(
225  llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
226  [&](auto it) { return builder.getStringAttr(outF(it.index())); });
227  }
228 
229  void inferDefault(Operation *op) {
230  infer(
231  op, [](unsigned idx) { return "in" + std::to_string(idx); },
232  [](unsigned idx) { return "out" + std::to_string(idx); });
233  }
234 
235  void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
236  infer(
237  op, [&](unsigned idx) { return op.getOperandName(idx); },
238  [&](unsigned idx) { return op.getResultName(idx); });
239  }
240 
241  void inferFromFuncOp(handshake::FuncOp op) {
242  auto inF = [&](unsigned idx) { return op.getArgName(idx).str(); };
243  auto outF = [&](unsigned idx) { return op.getResName(idx).str(); };
244  llvm::transform(
245  llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
246  [&](auto it) { return builder.getStringAttr(inF(it.index())); });
247  llvm::transform(
248  llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
249  [&](auto it) { return builder.getStringAttr(outF(it.index())); });
250  }
251 
252  Builder builder;
253  llvm::SmallVector<StringAttr> inputs;
254  llvm::SmallVector<StringAttr> outputs;
255 };
256 } // namespace
257 
259  TypeRange outputs) {
260  SmallVector<hw::PortInfo> pinputs, poutputs;
261 
262  HandshakePortNameGenerator portNames(op);
263  auto *ctx = op->getContext();
264 
265  Type i1Type = IntegerType::get(ctx, 1);
266  Type clkType = seq::ClockType::get(ctx);
267 
268  // Add all inputs of funcOp.
269  unsigned inIdx = 0;
270  for (auto arg : llvm::enumerate(inputs)) {
271  pinputs.push_back(
272  {{portNames.inputName(arg.index()), esiWrapper(arg.value()),
274  arg.index(),
275  {}});
276  inIdx++;
277  }
278 
279  // Add all outputs of funcOp.
280  for (auto res : llvm::enumerate(outputs)) {
281  poutputs.push_back(
282  {{portNames.outputName(res.index()), esiWrapper(res.value()),
284  res.index(),
285  {}});
286  }
287 
288  // Add clock and reset signals.
289  if (op->hasTrait<mlir::OpTrait::HasClock>()) {
290  pinputs.push_back({{StringAttr::get(ctx, "clock"), clkType,
292  inIdx++,
293  {}});
294  pinputs.push_back({{StringAttr::get(ctx, "reset"), i1Type,
296  inIdx,
297  {}});
298  }
299 
300  return hw::ModulePortInfo{pinputs, poutputs};
301 }
302 
303 } // namespace handshake
304 } // namespace circt
assert(baseType &&"element must be base type")
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
Channels are the basic communication primitives.
Definition: Types.h:63
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
Definition: PassHelpers.cpp:38
static Type tupleToStruct(TupleType tuple)
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Type toValidType(Type t)
esi::ChannelType esiWrapper(mlir::Type t)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
This holds a decoded list of input/inout and output ports for a module or instance.