CIRCT  19.0.0git
HandshakeToHW.cpp
Go to the documentation of this file.
1 //===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Handshake to HW Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "../PassDetail.h"
17 #include "circt/Dialect/HW/HWOps.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace circt;
36 using namespace circt::handshake;
37 using namespace circt::hw;
38 
39 using NameUniquer = std::function<std::string(Operation *)>;
40 
41 namespace {
42 
43 static Type tupleToStruct(TypeRange types) {
44  return toValidType(mlir::TupleType::get(types[0].getContext(), types));
45 }
46 
47 // Shared state used by various functions; captured in a struct to reduce the
48 // number of arguments that we have to pass around.
49 struct HandshakeLoweringState {
50  ModuleOp parentModule;
51  NameUniquer nameUniquer;
52 };
53 
54 // A type converter is needed to perform the in-flight materialization of "raw"
55 // (non-ESI channel) types to their ESI channel correspondents. This comes into
56 // effect when backedges exist in the input IR.
57 class ESITypeConverter : public TypeConverter {
58 public:
59  ESITypeConverter() {
60  addConversion([](Type type) -> Type { return esiWrapper(type); });
61 
62  addTargetMaterialization(
63  [&](mlir::OpBuilder &builder, mlir::Type resultType,
64  mlir::ValueRange inputs,
65  mlir::Location loc) -> std::optional<mlir::Value> {
66  if (inputs.size() != 1)
67  return std::nullopt;
68  return inputs[0];
69  });
70 
71  addSourceMaterialization(
72  [&](mlir::OpBuilder &builder, mlir::Type resultType,
73  mlir::ValueRange inputs,
74  mlir::Location loc) -> std::optional<mlir::Value> {
75  if (inputs.size() != 1)
76  return std::nullopt;
77  return inputs[0];
78  });
79  }
80 };
81 
82 } // namespace
83 
84 /// Returns a submodule name resulting from an operation, without discriminating
85 /// type information.
86 static std::string getBareSubModuleName(Operation *oldOp) {
87  // The dialect name is separated from the operation name by '.', which is not
88  // valid in SystemVerilog module names. In case this name is used in
89  // SystemVerilog output, replace '.' with '_'.
90  std::string subModuleName = oldOp->getName().getStringRef().str();
91  std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
92  return subModuleName;
93 }
94 
95 static std::string getCallName(Operation *op) {
96  auto callOp = dyn_cast<handshake::InstanceOp>(op);
97  return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
98 }
99 
100 /// Extracts the type of the data-carrying type of opType. If opType is an ESI
101 /// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
102 /// assume that opType itself is the data-carrying type.
103 static Type getOperandDataType(Value op) {
104  auto opType = op.getType();
105  if (auto channelType = opType.dyn_cast<esi::ChannelType>())
106  return channelType.getInner();
107  return opType;
108 }
109 
110 /// Filters NoneType's from the input.
111 static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
112  SmallVector<Type> filterRes;
113  llvm::copy_if(input, std::back_inserter(filterRes),
114  [](Type type) { return !type.isa<NoneType>(); });
115  return filterRes;
116 }
117 
118 /// Returns a set of types which may uniquely identify the provided op. Return
119 /// value is <inputTypes, outputTypes>.
120 using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
122  return TypeSwitch<Operation *, DiscriminatingTypes>(op)
123  .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
124  return DiscriminatingTypes{{},
125  {memOp.getMemRefType().getElementType()}};
126  })
127  .Default([&](auto) {
128  // By default, all in- and output types which is not a control type
129  // (NoneType) are discriminating types.
130  std::vector<Type> inTypes, outTypes;
131  llvm::transform(op->getOperands(), std::back_inserter(inTypes),
133  llvm::transform(op->getResults(), std::back_inserter(outTypes),
135  return DiscriminatingTypes{filterNoneTypes(inTypes),
136  filterNoneTypes(outTypes)};
137  });
138 }
139 
140 /// Get type name. Currently we only support integer or index types.
141 /// The emitted type aligns with the getFIRRTLType() method. Thus all integers
142 /// other than signed integers will be emitted as unsigned.
143 // NOLINTNEXTLINE(misc-no-recursion)
144 static std::string getTypeName(Location loc, Type type) {
145  std::string typeName;
146  // Builtin types
147  if (type.isIntOrIndex()) {
148  if (auto indexType = type.dyn_cast<IndexType>())
149  typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
150  else if (type.isSignedInteger())
151  typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
152  else
153  typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
154  } else if (auto tupleType = type.dyn_cast<TupleType>()) {
155  typeName += "_tuple";
156  for (auto elementType : tupleType.getTypes())
157  typeName += getTypeName(loc, elementType);
158  } else if (auto structType = type.dyn_cast<hw::StructType>()) {
159  typeName += "_struct";
160  for (auto element : structType.getElements())
161  typeName += "_" + element.name.str() + getTypeName(loc, element.type);
162  } else
163  emitError(loc) << "unsupported data type '" << type << "'";
164 
165  return typeName;
166 }
167 
168 /// Construct a name for creating HW sub-module.
169 static std::string getSubModuleName(Operation *oldOp) {
170  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
171  return instanceOp.getModule().str();
172 
173  std::string subModuleName = getBareSubModuleName(oldOp);
174 
175  // Add value of the constant operation.
176  if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
177  if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>()) {
178  auto intType = intAttr.getType();
179 
180  if (intType.isSignedInteger())
181  subModuleName += "_c" + std::to_string(intAttr.getSInt());
182  else if (intType.isUnsignedInteger())
183  subModuleName += "_c" + std::to_string(intAttr.getUInt());
184  else
185  subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
186  } else
187  oldOp->emitError("unsupported constant type");
188  }
189 
190  // Add discriminating in- and output types.
191  auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
192  if (!inTypes.empty())
193  subModuleName += "_in";
194  for (auto inType : inTypes)
195  subModuleName += getTypeName(oldOp->getLoc(), inType);
196 
197  if (!outTypes.empty())
198  subModuleName += "_out";
199  for (auto outType : outTypes)
200  subModuleName += getTypeName(oldOp->getLoc(), outType);
201 
202  // Add memory ID.
203  if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
204  subModuleName += "_id" + std::to_string(memOp.getId());
205 
206  // Add compare kind.
207  if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
208  subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
209 
210  // Add buffer information.
211  if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
212  subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
213  if (bufferOp.isSequential())
214  subModuleName += "_seq";
215  else
216  subModuleName += "_fifo";
217 
218  if (auto initValues = bufferOp.getInitValues()) {
219  subModuleName += "_init";
220  for (const Attribute e : *initValues) {
221  assert(e.isa<IntegerAttr>());
222  subModuleName +=
223  "_" + std::to_string(e.dyn_cast<IntegerAttr>().getInt());
224  }
225  }
226  }
227 
228  // Add control information.
229  if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
230  ctrlInterface && ctrlInterface.isControl()) {
231  // Add some additional discriminating info for non-typed operations.
232  subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
233  std::to_string(oldOp->getNumResults()) + "outs";
234  subModuleName += "_ctrl";
235  } else {
236  assert(
237  (!inTypes.empty() || !outTypes.empty()) &&
238  "Insufficient discriminating type info generated for the operation!");
239  }
240 
241  return subModuleName;
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // HW Sub-module Related Functions
246 //===----------------------------------------------------------------------===//
247 
248 /// Check whether a submodule with the same name has been created elsewhere in
249 /// the top level module. Return the matched module operation if true, otherwise
250 /// return nullptr.
251 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
252  StringRef modName) {
253  if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
254  return mod;
255  if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
256  return mod;
257  return {};
258 }
259 
260 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
261  Operation *oldOp) {
262  HWModuleLike targetModule;
263  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
264  targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
265  else
266  targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
267 
268  if (isa<handshake::InstanceOp>(oldOp))
269  assert(targetModule &&
270  "handshake.instance target modules should always have been lowered "
271  "before the modules that reference them!");
272  return targetModule;
273 }
274 
275 /// Returns a vector of PortInfo's which defines the HW interface of the
276 /// to-be-converted op.
277 static ModulePortInfo getPortInfoForOp(Operation *op) {
278  return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
279 }
280 
281 static llvm::SmallVector<hw::detail::FieldInfo>
282 portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
283  llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
284  for (auto port : portInfo)
285  fieldInfo.push_back({port.name, port.type});
286 
287  return fieldInfo;
288 }
289 
290 // Convert any handshake.extmemory operations and the top-level I/O
291 // associated with these.
292 static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
293  auto ports = mod.getPortList();
294  auto *ctx = mod.getContext();
295 
296  // Gather memref ports to be converted.
297  llvm::DenseMap<unsigned, Value> memrefPorts;
298  for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
299  auto channel = arg.getType().dyn_cast<esi::ChannelType>();
300  if (channel && channel.getInner().isa<MemRefType>())
301  memrefPorts[i] = arg;
302  }
303 
304  if (memrefPorts.empty())
305  return success(); // nothing to do.
306 
307  OpBuilder b(mod);
308 
309  auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
310  ArrayRef<hw::PortInfo> info,
311  hw::ModulePort::Direction direction) {
312  auto type = hw::StructType::get(ctx, portToFieldInfo(info));
313  auto portInfo =
314  hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
315  return portInfo;
316  };
317 
318  for (auto [i, arg] : memrefPorts) {
319  // Insert ports into the module
320  auto memName = mod.getArgName(i);
321 
322  // Get the attached extmemory external module.
323  auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
324  auto extmemMod =
325  cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
326  extmemInstance, extmemInstance.getModuleNameAttr()));
327 
328  ModulePortInfo portInfo(extmemMod.getPortList());
329 
330  // The extmemory external module's interface is a direct wrapping of the
331  // original handshake.extmemory operation in- and output types. Remove the
332  // first input argument (the !esi.channel<memref> op) since that is what
333  // we're replacing with a materialized interface.
334  portInfo.eraseInput(0);
335 
336  // Add memory input - this is the output of the extmemory op.
337  SmallVector<PortInfo> outputs(portInfo.getOutputs());
338  auto inPortInfo =
339  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
341  mod.insertPorts({{i, inPortInfo}}, {});
342  auto newInPort = mod.getArgumentForInput(i);
343  // Replace the extmemory submodule outputs with the newly created inputs.
344  b.setInsertionPointToStart(mod.getBodyBlock());
345  auto newInPortExploded = b.create<hw::StructExplodeOp>(
346  arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
347  extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
348 
349  // Add memory output - this is the inputs of the extmemory op (without the
350  // first argument);
351  unsigned outArgI = mod.getNumOutputPorts();
352  SmallVector<PortInfo> inputs(portInfo.getInputs());
353  auto outPortInfo =
354  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
356 
357  auto memOutputArgs = extmemInstance.getOperands().drop_front();
358  b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
359  auto memOutputStruct = b.create<hw::StructCreateOp>(
360  arg.getLoc(), outPortInfo.type, memOutputArgs);
361  mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
362 
363  // Erase the extmemory submodule instace since the i/o has now been
364  // plumbed.
365  extmemMod.erase();
366  extmemInstance.erase();
367 
368  // Erase the original memref argument of the top-level i/o now that it's use
369  // has been removed.
370  mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
371  /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
372  }
373 
374  return success();
375 }
376 
377 namespace {
378 
379 // Input handshakes contain a resolved valid and (optional )data signal, and
380 // a to-be-assigned ready signal.
381 struct InputHandshake {
382  Value valid;
383  std::shared_ptr<Backedge> ready;
384  Value data;
385 };
386 
387 // Output handshakes contain a resolved ready, and to-be-assigned valid and
388 // (optional) data signals.
389 struct OutputHandshake {
390  std::shared_ptr<Backedge> valid;
391  Value ready;
392  std::shared_ptr<Backedge> data;
393 };
394 
395 /// A helper struct that acts like a wire. Can be used to interact with the
396 /// RTLBuilder when multiple built components should be connected.
397 struct HandshakeWire {
398  HandshakeWire(BackedgeBuilder &bb, Type dataType) {
399  MLIRContext *ctx = dataType.getContext();
400  auto i1Type = IntegerType::get(ctx, 1);
401  valid = std::make_shared<Backedge>(bb.get(i1Type));
402  ready = std::make_shared<Backedge>(bb.get(i1Type));
403  data = std::make_shared<Backedge>(bb.get(dataType));
404  }
405 
406  // Functions that allow to treat a wire like an input or output port.
407  // **Careful**: Such a port will not be updated when backedges are resolved.
408  InputHandshake getAsInput() { return {*valid, ready, *data}; }
409  OutputHandshake getAsOutput() { return {valid, *ready, data}; }
410 
411  std::shared_ptr<Backedge> valid;
412  std::shared_ptr<Backedge> ready;
413  std::shared_ptr<Backedge> data;
414 };
415 
416 template <typename T, typename TInner>
417 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
418  llvm::function_ref<T(TInner &)> extractor) {
419  llvm::SmallVector<T> result;
420  llvm::transform(container, std::back_inserter(result), extractor);
421  return result;
422 }
423 struct UnwrappedIO {
424  llvm::SmallVector<InputHandshake> inputs;
425  llvm::SmallVector<OutputHandshake> outputs;
426 
427  llvm::SmallVector<Value> getInputValids() {
428  return extractValues<Value, InputHandshake>(
429  inputs, [](auto &hs) { return hs.valid; });
430  }
431  llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
432  return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
433  inputs, [](auto &hs) { return hs.ready; });
434  }
435  llvm::SmallVector<Value> getInputDatas() {
436  return extractValues<Value, InputHandshake>(
437  inputs, [](auto &hs) { return hs.data; });
438  }
439  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
440  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
441  outputs, [](auto &hs) { return hs.valid; });
442  }
443  llvm::SmallVector<Value> getOutputReadys() {
444  return extractValues<Value, OutputHandshake>(
445  outputs, [](auto &hs) { return hs.ready; });
446  }
447  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
448  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449  outputs, [](auto &hs) { return hs.data; });
450  }
451 };
452 
453 // A class containing a bunch of syntactic sugar to reduce builder function
454 // verbosity.
455 // @todo: should be moved to support.
456 struct RTLBuilder {
457  RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
458  Value clk = Value(), Value rst = Value())
459  : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
460 
461  Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
462  // Cannot use zero-width APInt's in DenseMap's, see
463  // https://github.com/llvm/llvm-project/issues/58013
464  bool isZeroWidth = apv.getBitWidth() == 0;
465  if (!isZeroWidth) {
466  auto it = constants.find(apv);
467  if (it != constants.end())
468  return it->second;
469  }
470 
471  auto cval = b.create<hw::ConstantOp>(loc, apv);
472  if (!isZeroWidth)
473  constants[apv] = cval;
474  return cval;
475  }
476 
477  Value constant(unsigned width, int64_t value,
478  std::optional<StringRef> name = {}) {
479  return constant(APInt(width, value));
480  }
481  std::pair<Value, Value> wrap(Value data, Value valid,
482  std::optional<StringRef> name = {}) {
483  auto wrapOp = b.create<esi::WrapValidReadyOp>(loc, data, valid);
484  return {wrapOp.getResult(0), wrapOp.getResult(1)};
485  }
486  std::pair<Value, Value> unwrap(Value channel, Value ready,
487  std::optional<StringRef> name = {}) {
488  auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
489  return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
490  }
491 
492  // Various syntactic sugar functions.
493  Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
494  Value rst = Value()) {
495  Value resolvedClk = clk ? clk : this->clk;
496  Value resolvedRst = rst ? rst : this->rst;
497  assert(resolvedClk &&
498  "No global clock provided to this RTLBuilder - a clock "
499  "signal must be provided to the reg(...) function.");
500  assert(resolvedRst &&
501  "No global reset provided to this RTLBuilder - a reset "
502  "signal must be provided to the reg(...) function.");
503 
504  return b.create<seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
505  name);
506  }
507 
508  Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
509  std::optional<StringRef> name = {}) {
510  return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
511  }
512 
513  Value buildNamedOp(llvm::function_ref<Value()> f,
514  std::optional<StringRef> name) {
515  Value v = f();
516  StringAttr nameAttr;
517  Operation *op = v.getDefiningOp();
518  if (name.has_value()) {
519  op->setAttr("sv.namehint", b.getStringAttr(*name));
520  nameAttr = b.getStringAttr(*name);
521  }
522  return v;
523  }
524 
525  // Bitwise 'and'.
526  Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
527  return buildNamedOp(
528  [&]() { return b.create<comb::AndOp>(loc, values, false); }, name);
529  }
530 
531  Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
532  return buildNamedOp(
533  [&]() { return b.create<comb::OrOp>(loc, values, false); }, name);
534  }
535 
536  // Bitwise 'not'.
537  Value bNot(Value value, std::optional<StringRef> name = {}) {
538  auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
539  std::string inferedName;
540  if (!name) {
541  // Try to create a name from the input value.
542  if (auto valueName =
543  value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
544  inferedName = ("not_" + valueName.getValue()).str();
545  name = inferedName;
546  }
547  }
548 
549  return buildNamedOp(
550  [&]() { return b.create<comb::XorOp>(loc, value, allOnes); }, name);
551 
552  return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
553  }
554 
555  Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
556  return buildNamedOp(
557  [&]() { return b.create<comb::ShlOp>(loc, value, shift); }, name);
558  }
559 
560  Value concat(ValueRange values, std::optional<StringRef> name = {}) {
561  return buildNamedOp([&]() { return b.create<comb::ConcatOp>(loc, values); },
562  name);
563  }
564 
565  // Packs a list of values into a hw.struct.
566  Value pack(ValueRange values, Type structType = Type(),
567  std::optional<StringRef> name = {}) {
568  if (!structType)
569  structType = tupleToStruct(values.getTypes());
570  return buildNamedOp(
571  [&]() { return b.create<hw::StructCreateOp>(loc, structType, values); },
572  name);
573  }
574 
575  // Unpacks a hw.struct into a list of values.
576  ValueRange unpack(Value value) {
577  auto structType = value.getType().cast<hw::StructType>();
578  llvm::SmallVector<Type> innerTypes;
579  structType.getInnerTypes(innerTypes);
580  return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
581  }
582 
583  llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
584  llvm::SmallVector<Value> bits;
585  for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
586  bits.push_back(b.create<comb::ExtractOp>(loc, v, i, /*bitWidth=*/1));
587  return bits;
588  }
589 
590  // OR-reduction of the bits in 'v'.
591  Value rOr(Value v, std::optional<StringRef> name = {}) {
592  return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
593  }
594 
595  // Extract bits v[hi:lo] (inclusive).
596  Value extract(Value v, unsigned lo, unsigned hi,
597  std::optional<StringRef> name = {}) {
598  unsigned width = hi - lo + 1;
599  return buildNamedOp(
600  [&]() { return b.create<comb::ExtractOp>(loc, v, lo, width); }, name);
601  }
602 
603  // Truncates 'value' to its lower 'width' bits.
604  Value truncate(Value value, unsigned width,
605  std::optional<StringRef> name = {}) {
606  return extract(value, 0, width - 1, name);
607  }
608 
609  Value zext(Value value, unsigned outWidth,
610  std::optional<StringRef> name = {}) {
611  unsigned inWidth = value.getType().getIntOrFloatBitWidth();
612  assert(inWidth <= outWidth && "zext: input width must be <- output width.");
613  if (inWidth == outWidth)
614  return value;
615  auto c0 = constant(outWidth - inWidth, 0);
616  return concat({c0, value}, name);
617  }
618 
619  Value sext(Value value, unsigned outWidth,
620  std::optional<StringRef> name = {}) {
621  return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
622  }
623 
624  // Extracts a single bit v[bit].
625  Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
626  return extract(v, index, index, name);
627  }
628 
629  // Creates a hw.array of the given values.
630  Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
631  return buildNamedOp(
632  [&]() { return b.create<hw::ArrayCreateOp>(loc, values); }, name);
633  }
634 
635  // Extract the 'index'th value from the input array.
636  Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
637  return buildNamedOp(
638  [&]() { return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
639  }
640 
641  // Muxes a range of values.
642  // The select signal is expected to be a decimal value which selects starting
643  // from the lowest index of value.
644  Value mux(Value index, ValueRange values,
645  std::optional<StringRef> name = {}) {
646  if (values.size() == 2)
647  return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
648 
649  return arrayGet(arrayCreate(values), index, name);
650  }
651 
652  // Muxes a range of values. The select signal is expected to be a 1-hot
653  // encoded value.
654  Value ohMux(Value index, ValueRange inputs) {
655  // Confirm the select input can be a one-hot encoding for the inputs.
656  unsigned numInputs = inputs.size();
657  assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
658  "one-hot select can't mux inputs");
659 
660  // Start the mux tree with zero value.
661  // Todo: clean up when handshake supports i0.
662  auto dataType = inputs[0].getType();
663  unsigned width =
664  dataType.isa<NoneType>() ? 0 : dataType.getIntOrFloatBitWidth();
665  Value muxValue = constant(width, 0);
666 
667  // Iteratively chain together muxes from the high bit to the low bit.
668  for (size_t i = numInputs - 1; i != 0; --i) {
669  Value input = inputs[i];
670  Value selectBit = bit(index, i);
671  muxValue = mux(selectBit, {muxValue, input});
672  }
673 
674  return muxValue;
675  }
676 
677  hw::ModulePortInfo info;
678  OpBuilder &b;
679  Location loc;
680  Value clk, rst;
681  DenseMap<APInt, Value> constants;
682 };
683 
684 /// Creates a Value that has an assigned zero value. For structs, this
685 /// corresponds to assigning zero to each element recursively.
686 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
687  return TypeSwitch<Type, Value>(type)
688  .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
689  .Case<IntType, IntegerType>([&](auto type) {
690  return s.constant(type.getIntOrFloatBitWidth(), 0);
691  })
692  .Case<hw::StructType>([&](auto structType) {
693  SmallVector<Value> zeroValues;
694  for (auto field : structType.getElements())
695  zeroValues.push_back(createZeroDataConst(s, loc, field.type));
696  return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
697  })
698  .Default([&](Type) -> Value {
699  emitError(loc) << "unsupported type for zero value: " << type;
700  assert(false);
701  return {};
702  });
703 }
704 
705 static void
706 addSequentialIOOperandsIfNeeded(Operation *op,
707  llvm::SmallVectorImpl<Value> &operands) {
708  if (op->hasTrait<mlir::OpTrait::HasClock>()) {
709  // Parent should at this point be a hw.module and have clock and reset
710  // ports.
711  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
712  operands.push_back(
713  parent.getArgumentForInput(parent.getNumInputPorts() - 2));
714  operands.push_back(
715  parent.getArgumentForInput(parent.getNumInputPorts() - 1));
716  }
717 }
718 
719 template <typename T>
720 class HandshakeConversionPattern : public OpConversionPattern<T> {
721 public:
722  HandshakeConversionPattern(ESITypeConverter &typeConverter,
723  MLIRContext *context, OpBuilder &submoduleBuilder,
724  HandshakeLoweringState &ls)
725  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
726  submoduleBuilder(submoduleBuilder), ls(ls) {}
727 
728  using OpAdaptor = typename T::Adaptor;
729 
730  LogicalResult
731  matchAndRewrite(T op, OpAdaptor adaptor,
732  ConversionPatternRewriter &rewriter) const override {
733 
734  // Check if a submodule has already been created for the op. If so,
735  // instantiate the submodule. Else, run the pattern-defined module
736  // builder.
737  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
738  if (!implModule) {
739  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
740 
741  submoduleBuilder.setInsertionPoint(op->getParentOp());
742  implModule = submoduleBuilder.create<hw::HWModuleOp>(
743  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
744  portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
745  // if 'op' has clock trait, extract these and provide them to the
746  // RTL builder.
747  Value clk, rst;
748  if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
749  clk = ports.getInput("clock");
750  rst = ports.getInput("reset");
751  }
752 
753  BackedgeBuilder bb(b, op.getLoc());
754  RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
755  this->buildModule(op, bb, s, ports);
756  });
757  }
758 
759  // Instantiate the submodule.
760  llvm::SmallVector<Value> operands = adaptor.getOperands();
761  addSequentialIOOperandsIfNeeded(op, operands);
762  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
763  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
764  return success();
765  }
766 
767  virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
768  hw::HWModulePortAccessor &ports) const = 0;
769 
770  // Syntactic sugar functions.
771  // Unwraps an ESI-interfaced module into its constituent handshake signals.
772  // Backedges are created for the to-be-resolved signals, and output ports
773  // are assigned to their wrapped counterparts.
774  UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
775  hw::HWModulePortAccessor &ports) const {
776  UnwrappedIO unwrapped;
777  for (auto port : ports.getInputs()) {
778  if (!isa<esi::ChannelType>(port.getType()))
779  continue;
780  InputHandshake hs;
781  auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
782  auto [data, valid] = s.unwrap(port, *ready);
783  hs.data = data;
784  hs.valid = valid;
785  hs.ready = ready;
786  unwrapped.inputs.push_back(hs);
787  }
788  for (auto &outputInfo : ports.getPortList().getOutputs()) {
789  esi::ChannelType channelType =
790  dyn_cast<esi::ChannelType>(outputInfo.type);
791  if (!channelType)
792  continue;
793  OutputHandshake hs;
794  Type innerType = channelType.getInner();
795  auto data = std::make_shared<Backedge>(bb.get(innerType));
796  auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
797  auto [dataCh, ready] = s.wrap(*data, *valid);
798  hs.data = data;
799  hs.valid = valid;
800  hs.ready = ready;
801  ports.setOutput(outputInfo.name, dataCh);
802  unwrapped.outputs.push_back(hs);
803  }
804  return unwrapped;
805  }
806 
807  void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
808  OutputHandshake &output, Value cond) const {
809  auto validAndReady = s.bAnd({output.ready, cond});
810  for (auto &input : inputs)
811  input.ready->setValue(validAndReady);
812  }
813 
814  void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
815  OutputHandshake &output) const {
816  llvm::SmallVector<Value> valids;
817  for (auto &input : inputs)
818  valids.push_back(input.valid);
819  Value allValid = s.bAnd(valids);
820  output.valid->setValue(allValid);
821  setAllReadyWithCond(s, inputs, output, allValid);
822  }
823 
824  // Builds mux logic for the given inputs and outputs.
825  // Note: it is assumed that the caller has removed the 'select' signal from
826  // the 'unwrapped' inputs and provide it as a separate argument.
827  void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
828  InputHandshake &select) const {
829  // ============================= Control logic =============================
830  size_t numInputs = unwrapped.inputs.size();
831  size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
832  Value truncatedSelect =
833  select.data.getType().getIntOrFloatBitWidth() > selectWidth
834  ? s.truncate(select.data, selectWidth)
835  : select.data;
836 
837  // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
838  auto selectZext = s.zext(truncatedSelect, numInputs);
839  auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
840  auto &res = unwrapped.outputs[0];
841 
842  // Mux input valid signals.
843  auto selectedInputValid =
844  s.mux(truncatedSelect, unwrapped.getInputValids());
845  // Result is valid when the selected input and the select input is valid.
846  auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
847  res.valid->setValue(selAndInputValid);
848  auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
849 
850  // Select is ready when result is valid and ready (result transacting).
851  select.ready->setValue(resValidAndReady);
852 
853  // Assign each input ready signal if it is currently selected.
854  for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
855  // Extract the selection bit for this input.
856  auto isSelected = s.bit(select1h, inIdx);
857 
858  // '&' that with the result valid and ready, and assign to the input
859  // ready signal.
860  auto activeAndResultValidAndReady =
861  s.bAnd({isSelected, resValidAndReady});
862  in.ready->setValue(activeAndResultValidAndReady);
863  }
864 
865  // ============================== Data logic ===============================
866  res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
867  }
868 
869  // Builds fork logic between the single input and multiple outputs' control
870  // networks. Caller is expected to handle data separately.
871  void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
872  ArrayRef<OutputHandshake> outputs) const {
873  auto c0I1 = s.constant(1, 0);
874  llvm::SmallVector<Value> doneWires;
875  for (auto [i, output] : llvm::enumerate(outputs)) {
876  auto doneBE = bb.get(s.b.getI1Type());
877  auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
878  auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
879  auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
880  output.valid->setValue(outValid);
881  auto validReady = s.bAnd({output.ready, outValid});
882  auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
883  doneBE.setValue(done);
884  doneWires.push_back(done);
885  }
886  input.ready->setValue(s.bAnd(doneWires, "allDone"));
887  }
888 
889  // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
890  // function which takes the set of unwrapped data inputs, and returns a
891  // value which should be assigned to the output data value.
892  void buildUnitRateJoinLogic(
893  RTLBuilder &s, UnwrappedIO &unwrappedIO,
894  llvm::function_ref<Value(ValueRange)> unitBuilder) const {
895  assert(unwrappedIO.outputs.size() == 1 &&
896  "Expected exactly one output for unit-rate join actor");
897  // Control logic.
898  this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
899 
900  // Data logic.
901  auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
902  unwrappedIO.outputs[0].data->setValue(unitRes);
903  }
904 
905  void buildUnitRateForkLogic(
906  RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
907  llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
908  assert(unwrappedIO.inputs.size() == 1 &&
909  "Expected exactly one input for unit-rate fork actor");
910  // Control logic.
911  this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
912 
913  // Data logic.
914  auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
915  assert(unitResults.size() == unwrappedIO.outputs.size() &&
916  "Expected unit builder to return one result per output");
917  for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
918  outport.data->setValue(res);
919  }
920 
921  void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
922  bool signExtend) const {
923  size_t outWidth =
924  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
925  .getIntOrFloatBitWidth();
926  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
927  if (signExtend)
928  return s.sext(inputs[0], outWidth);
929  return s.zext(inputs[0], outWidth);
930  });
931  }
932 
933  void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
934  unsigned targetWidth) const {
935  size_t outWidth =
936  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
937  .getIntOrFloatBitWidth();
938  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
939  return s.truncate(inputs[0], outWidth);
940  });
941  }
942 
943  /// Return the number of bits needed to index the given number of values.
944  static size_t getNumIndexBits(uint64_t numValues) {
945  return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
946  }
947 
948  Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
949  Value defaultValue,
950  DenseMap<size_t, Value> &indexMapping) const {
951  auto numInputs = inputs.size();
952  auto priorityArb = defaultValue;
953 
954  for (size_t i = numInputs; i > 0; --i) {
955  size_t inputIndex = i - 1;
956  size_t oneHotIndex = size_t{1} << inputIndex;
957  auto constIndex = s.constant(numInputs, oneHotIndex);
958  indexMapping[inputIndex] = constIndex;
959  priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
960  }
961  return priorityArb;
962  }
963 
964 private:
965  OpBuilder &submoduleBuilder;
966  HandshakeLoweringState &ls;
967 };
968 
969 class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
970 public:
971  using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
972  void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
973  hw::HWModulePortAccessor &ports) const override {
974  auto unwrapped = unwrapIO(s, bb, ports);
975  buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
976  return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
977  });
978  }
979 };
980 
981 class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
982 public:
983  using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
984  void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
985  hw::HWModulePortAccessor &ports) const override {
986  auto unwrappedIO = unwrapIO(s, bb, ports);
987  buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
988  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
989  };
990 };
991 
992 class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
993 public:
994  using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
995  void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
996  hw::HWModulePortAccessor &ports) const override {
997  auto unwrappedIO = unwrapIO(s, bb, ports);
998 
999  // A helper wire that will be used to connect the two built logics
1000  HandshakeWire wire(bb, s.b.getNoneType());
1001 
1002  OutputHandshake output = wire.getAsOutput();
1003  buildJoinLogic(s, unwrappedIO.inputs, output);
1004 
1005  InputHandshake input = wire.getAsInput();
1006 
1007  // The state-keeping fork logic is required here, as the circuit isn't
1008  // allowed to wait for all the consumers to be ready. Connecting the ready
1009  // signals of the outputs to their corresponding valid signals leads to
1010  // combinatorial cycles. The paper which introduced compositional dataflow
1011  // circuits explicitly mentions this limitation:
1012  // http://arcade.cs.columbia.edu/df-memocode17.pdf
1013  buildForkLogic(s, bb, input, unwrappedIO.outputs);
1014 
1015  // Directly connect the data wires, only the control signals need to be
1016  // combined.
1017  for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1018  out.data->setValue(in.data);
1019  };
1020 };
1021 
1022 class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1023 public:
1024  using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1025  void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1026  hw::HWModulePortAccessor &ports) const override {
1027  auto unwrappedIO = unwrapIO(s, bb, ports);
1028 
1029  // Extract select signal from the unwrapped IO.
1030  auto select = unwrappedIO.inputs[0];
1031  unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1032  buildMuxLogic(s, unwrappedIO, select);
1033  };
1034 };
1035 
1036 class InstanceConversionPattern
1037  : public HandshakeConversionPattern<handshake::InstanceOp> {
1038 public:
1039  using HandshakeConversionPattern<
1040  handshake::InstanceOp>::HandshakeConversionPattern;
1041  void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1042  hw::HWModulePortAccessor &ports) const override {
1043  assert(false &&
1044  "If we indeed perform conversion in post-order, this "
1045  "should never be called. The base HandshakeConversionPattern logic "
1046  "will instantiate the external module.");
1047  }
1048 };
1049 
1050 class ReturnConversionPattern
1051  : public OpConversionPattern<handshake::ReturnOp> {
1052 public:
1053  using OpConversionPattern::OpConversionPattern;
1054  LogicalResult
1055  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1056  ConversionPatternRewriter &rewriter) const override {
1057  // Locate existing output op, Append operands to output op, and move to
1058  // the end of the block.
1059  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1060  auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1061  outputOp->setOperands(adaptor.getOperands());
1062  outputOp->moveAfter(&parent.getBodyBlock()->back());
1063  rewriter.eraseOp(op);
1064  return success();
1065  }
1066 };
1067 
1068 // Converts an arbitrary operation into a unit rate actor. A unit rate actor
1069 // will transact once all inputs are valid and its output is ready.
1070 template <typename TIn, typename TOut = TIn>
1071 class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1072 public:
1073  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1074  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1075  hw::HWModulePortAccessor &ports) const override {
1076  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1077  this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1078  // Create TOut - it is assumed that TOut trivially
1079  // constructs from the input data signals of TIn.
1080  // To disambiguate ambiguous builders with default arguments (e.g.,
1081  // twoState UnitAttr), specify attribute array explicitly.
1082  return s.b.create<TOut>(op.getLoc(), inputs,
1083  /* attributes */ ArrayRef<NamedAttribute>{});
1084  });
1085  };
1086 };
1087 
1088 class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1089 public:
1090  using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1091  void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1092  hw::HWModulePortAccessor &ports) const override {
1093  auto unwrappedIO = unwrapIO(s, bb, ports);
1094  buildUnitRateJoinLogic(s, unwrappedIO,
1095  [&](ValueRange inputs) { return s.pack(inputs); });
1096  };
1097 };
1098 
1099 class StructCreateConversionPattern
1100  : public HandshakeConversionPattern<hw::StructCreateOp> {
1101 public:
1102  using HandshakeConversionPattern<
1103  hw::StructCreateOp>::HandshakeConversionPattern;
1104  void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1105  hw::HWModulePortAccessor &ports) const override {
1106  auto unwrappedIO = unwrapIO(s, bb, ports);
1107  auto structType = op.getResult().getType();
1108  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1109  return s.pack(inputs, structType);
1110  });
1111  };
1112 };
1113 
1114 class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1115 public:
1116  using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1117  void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1118  hw::HWModulePortAccessor &ports) const override {
1119  auto unwrappedIO = unwrapIO(s, bb, ports);
1120  buildUnitRateForkLogic(s, bb, unwrappedIO,
1121  [&](Value input) { return s.unpack(input); });
1122  };
1123 };
1124 
1125 class ConditionalBranchConversionPattern
1126  : public HandshakeConversionPattern<ConditionalBranchOp> {
1127 public:
1128  using HandshakeConversionPattern<
1129  ConditionalBranchOp>::HandshakeConversionPattern;
1130  void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1131  hw::HWModulePortAccessor &ports) const override {
1132  auto unwrappedIO = unwrapIO(s, bb, ports);
1133  auto cond = unwrappedIO.inputs[0];
1134  auto arg = unwrappedIO.inputs[1];
1135  auto trueRes = unwrappedIO.outputs[0];
1136  auto falseRes = unwrappedIO.outputs[1];
1137 
1138  auto condArgValid = s.bAnd({cond.valid, arg.valid});
1139 
1140  // Connect valid signal of both results.
1141  trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1142  falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1143 
1144  // Connecte data signals of both results.
1145  trueRes.data->setValue(arg.data);
1146  falseRes.data->setValue(arg.data);
1147 
1148  // Connect ready signal of input and condition.
1149  auto selectedResultReady =
1150  s.mux(cond.data, {falseRes.ready, trueRes.ready});
1151  auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1152  arg.ready->setValue(condArgReady);
1153  cond.ready->setValue(condArgReady);
1154  };
1155 };
1156 
1157 template <typename TIn, bool signExtend>
1158 class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1159 public:
1160  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1161  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1162  hw::HWModulePortAccessor &ports) const override {
1163  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1164  this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1165  };
1166 };
1167 
1168 class ComparisonConversionPattern
1169  : public HandshakeConversionPattern<arith::CmpIOp> {
1170 public:
1171  using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1172  void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1173  hw::HWModulePortAccessor &ports) const override {
1174  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1175  auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1176  return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1177  return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1178  inputs[1]);
1179  });
1180  };
1181 
1182  switch (op.getPredicate()) {
1183  case arith::CmpIPredicate::eq:
1184  return buildCompareLogic(comb::ICmpPredicate::eq);
1185  case arith::CmpIPredicate::ne:
1186  return buildCompareLogic(comb::ICmpPredicate::ne);
1187  case arith::CmpIPredicate::slt:
1188  return buildCompareLogic(comb::ICmpPredicate::slt);
1189  case arith::CmpIPredicate::ult:
1190  return buildCompareLogic(comb::ICmpPredicate::ult);
1191  case arith::CmpIPredicate::sle:
1192  return buildCompareLogic(comb::ICmpPredicate::sle);
1193  case arith::CmpIPredicate::ule:
1194  return buildCompareLogic(comb::ICmpPredicate::ule);
1195  case arith::CmpIPredicate::sgt:
1196  return buildCompareLogic(comb::ICmpPredicate::sgt);
1197  case arith::CmpIPredicate::ugt:
1198  return buildCompareLogic(comb::ICmpPredicate::ugt);
1199  case arith::CmpIPredicate::sge:
1200  return buildCompareLogic(comb::ICmpPredicate::sge);
1201  case arith::CmpIPredicate::uge:
1202  return buildCompareLogic(comb::ICmpPredicate::uge);
1203  }
1204  assert(false && "invalid CmpIOp");
1205  };
1206 };
1207 
1208 class TruncateConversionPattern
1209  : public HandshakeConversionPattern<arith::TruncIOp> {
1210 public:
1211  using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1212  void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1213  hw::HWModulePortAccessor &ports) const override {
1214  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1215  unsigned targetBits =
1216  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1217  buildTruncateLogic(s, unwrappedIO, targetBits);
1218  };
1219 };
1220 
1221 class ControlMergeConversionPattern
1222  : public HandshakeConversionPattern<ControlMergeOp> {
1223 public:
1224  using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1225  void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1226  hw::HWModulePortAccessor &ports) const override {
1227  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1228  auto resData = unwrappedIO.outputs[0];
1229  auto resIndex = unwrappedIO.outputs[1];
1230 
1231  // Define some common types and values that will be used.
1232  unsigned numInputs = unwrappedIO.inputs.size();
1233  auto indexType = s.b.getIntegerType(numInputs);
1234  Value noWinner = s.constant(numInputs, 0);
1235  Value c0I1 = s.constant(1, 0);
1236 
1237  // Declare register for storing arbitration winner.
1238  auto won = bb.get(indexType);
1239  Value wonReg = s.reg("won_reg", won, noWinner);
1240 
1241  // Declare wire for arbitration winner.
1242  auto win = bb.get(indexType);
1243 
1244  // Declare wire for whether the circuit just fired and emitted both
1245  // outputs.
1246  auto fired = bb.get(s.b.getI1Type());
1247 
1248  // Declare registers for storing if each output has been emitted.
1249  auto resultEmitted = bb.get(s.b.getI1Type());
1250  Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1251  auto indexEmitted = bb.get(s.b.getI1Type());
1252  Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1253 
1254  // Declare wires for if each output is done.
1255  auto resultDone = bb.get(s.b.getI1Type());
1256  auto indexDone = bb.get(s.b.getI1Type());
1257 
1258  // Create predicates to assert if the win wire or won register hold a
1259  // valid index.
1260  auto hasWinnerCondition = s.rOr({win});
1261  auto hadWinnerCondition = s.rOr({wonReg});
1262 
1263  // Create an arbiter based on a simple priority-encoding scheme to assign
1264  // an index to the win wire. If the won register is set, just use that. In
1265  // the case that won is not set and no input is valid, set a sentinel
1266  // value to indicate no winner was chosen. The constant values are
1267  // remembered in a map so they can be re-used later to assign the arg
1268  // ready outputs.
1269  DenseMap<size_t, Value> argIndexValues;
1270  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1271  noWinner, argIndexValues);
1272  priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1273  win.setValue(priorityArb);
1274 
1275  // Create the logic to assign the result and index outputs. The result
1276  // valid output will always be assigned, and if isControl is not set, the
1277  // result data output will also be assigned. The index valid and data
1278  // outputs will always be assigned. The win wire from the arbiter is used
1279  // to index into a tree of muxes to select the chosen input's signal(s),
1280  // and is fed directly to the index output. Both the result and index
1281  // valid outputs are gated on the win wire being set to something other
1282  // than the sentinel value.
1283  auto resultNotEmitted = s.bNot(resultEmittedReg);
1284  auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1285  resData.valid->setValue(resultValid);
1286  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1287 
1288  auto indexNotEmitted = s.bNot(indexEmittedReg);
1289  auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1290  resIndex.valid->setValue(indexValid);
1291 
1292  // Use the one-hot win wire to select the index to output in the index
1293  // data.
1294  SmallVector<Value, 8> indexOutputs;
1295  for (size_t i = 0; i < numInputs; ++i)
1296  indexOutputs.push_back(s.constant(64, i));
1297 
1298  auto indexOutput = s.ohMux(win, indexOutputs);
1299  resIndex.data->setValue(indexOutput);
1300 
1301  // Create the logic to set the won register. If the fired wire is
1302  // asserted, we have finished this round and can and reset the register to
1303  // the sentinel value that indicates there is no winner. Otherwise, we
1304  // need to hold the value of the win register until we can fire.
1305  won.setValue(s.mux(fired, {win, noWinner}));
1306 
1307  // Create the logic to set the done wires for the result and index. For
1308  // both outputs, the done wire is asserted when the output is valid and
1309  // ready, or the emitted register for that output is set.
1310  auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1311  resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1312 
1313  auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1314  indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1315 
1316  // Create the logic to set the fired wire. It is asserted when both result
1317  // and index are done.
1318  fired.setValue(s.bAnd({resultDone, indexDone}));
1319 
1320  // Create the logic to assign the emitted registers. If the fired wire is
1321  // asserted, we have finished this round and can reset the registers to 0.
1322  // Otherwise, we need to hold the values of the done registers until we
1323  // can fire.
1324  resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1325  indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1326 
1327  // Create the logic to assign the arg ready outputs. The logic is
1328  // identical for each arg. If the fired wire is asserted, and the win wire
1329  // holds an arg's index, that arg is ready.
1330  auto winnerOrDefault = s.mux(fired, {noWinner, win});
1331  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1332  auto &indexValue = argIndexValues[i];
1333  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1334  }
1335  };
1336 };
1337 
1338 class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1339 public:
1340  using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1341  void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1342  hw::HWModulePortAccessor &ports) const override {
1343  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1344  auto resData = unwrappedIO.outputs[0];
1345 
1346  // Define some common types and values that will be used.
1347  unsigned numInputs = unwrappedIO.inputs.size();
1348  auto indexType = s.b.getIntegerType(numInputs);
1349  Value noWinner = s.constant(numInputs, 0);
1350 
1351  // Declare wire for arbitration winner.
1352  auto win = bb.get(indexType);
1353 
1354  // Create predicates to assert if the win wire holds a valid index.
1355  auto hasWinnerCondition = s.rOr(win);
1356 
1357  // Create an arbiter based on a simple priority-encoding scheme to assign an
1358  // index to the win wire. In the case that no input is valid, set a sentinel
1359  // value to indicate no winner was chosen. The constant values are
1360  // remembered in a map so they can be re-used later to assign the arg ready
1361  // outputs.
1362  DenseMap<size_t, Value> argIndexValues;
1363  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1364  noWinner, argIndexValues);
1365  win.setValue(priorityArb);
1366 
1367  // Create the logic to assign the result outputs. The result valid and data
1368  // outputs will always be assigned. The win wire from the arbiter is used to
1369  // index into a tree of muxes to select the chosen input's signal(s). The
1370  // result outputs are gated on the win wire being non-zero.
1371 
1372  resData.valid->setValue(hasWinnerCondition);
1373  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1374 
1375  // Create the logic to set the done wires for the result. The done wire is
1376  // asserted when the output is valid and ready, or the emitted register is
1377  // set.
1378  auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1379 
1380  // Create the logic to assign the arg ready outputs. The logic is
1381  // identical for each arg. If the fired wire is asserted, and the win wire
1382  // holds an arg's index, that arg is ready.
1383  auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1384  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1385  auto &indexValue = argIndexValues[i];
1386  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1387  }
1388  };
1389 };
1390 
1391 class LoadConversionPattern
1392  : public HandshakeConversionPattern<handshake::LoadOp> {
1393 public:
1394  using HandshakeConversionPattern<
1395  handshake::LoadOp>::HandshakeConversionPattern;
1396  void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1397  hw::HWModulePortAccessor &ports) const override {
1398  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1399  auto addrFromUser = unwrappedIO.inputs[0];
1400  auto dataFromMem = unwrappedIO.inputs[1];
1401  auto controlIn = unwrappedIO.inputs[2];
1402  auto dataToUser = unwrappedIO.outputs[0];
1403  auto addrToMem = unwrappedIO.outputs[1];
1404 
1405  addrToMem.data->setValue(addrFromUser.data);
1406  dataToUser.data->setValue(dataFromMem.data);
1407 
1408  // The valid/ready logic between user address/control to memoryAddr is
1409  // join logic.
1410  buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1411 
1412  // The valid/ready logic between memoryData and outputData is a direct
1413  // connection.
1414  dataToUser.valid->setValue(dataFromMem.valid);
1415  dataFromMem.ready->setValue(dataToUser.ready);
1416  };
1417 };
1418 
1419 class StoreConversionPattern
1420  : public HandshakeConversionPattern<handshake::StoreOp> {
1421 public:
1422  using HandshakeConversionPattern<
1423  handshake::StoreOp>::HandshakeConversionPattern;
1424  void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1425  hw::HWModulePortAccessor &ports) const override {
1426  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1427  auto addrFromUser = unwrappedIO.inputs[0];
1428  auto dataFromUser = unwrappedIO.inputs[1];
1429  auto controlIn = unwrappedIO.inputs[2];
1430  auto dataToMem = unwrappedIO.outputs[0];
1431  auto addrToMem = unwrappedIO.outputs[1];
1432 
1433  // Create a gate that will be asserted when all outputs are ready.
1434  auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1435 
1436  // Build the standard join logic from the inputs to the inputsValid and
1437  // outputsReady signals.
1438  HandshakeWire joinWire(bb, s.b.getNoneType());
1439  joinWire.ready->setValue(outputsReady);
1440  OutputHandshake joinOutput = joinWire.getAsOutput();
1441  buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1442 
1443  // Output address and data signals are connected directly.
1444  addrToMem.data->setValue(addrFromUser.data);
1445  dataToMem.data->setValue(dataFromUser.data);
1446 
1447  // Output valid signals are connected from the inputsValid wire.
1448  addrToMem.valid->setValue(*joinWire.valid);
1449  dataToMem.valid->setValue(*joinWire.valid);
1450  };
1451 };
1452 
1453 class MemoryConversionPattern
1454  : public HandshakeConversionPattern<handshake::MemoryOp> {
1455 public:
1456  using HandshakeConversionPattern<
1457  handshake::MemoryOp>::HandshakeConversionPattern;
1458  void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
1459  hw::HWModulePortAccessor &ports) const override {
1460  auto loc = op.getLoc();
1461 
1462  // Gather up the load and store ports.
1463  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1464  struct LoadPort {
1465  InputHandshake &addr;
1466  OutputHandshake &data;
1467  OutputHandshake &done;
1468  };
1469  struct StorePort {
1470  InputHandshake &addr;
1471  InputHandshake &data;
1472  OutputHandshake &done;
1473  };
1474  SmallVector<LoadPort, 4> loadPorts;
1475  SmallVector<StorePort, 4> storePorts;
1476 
1477  unsigned stCount = op.getStCount();
1478  unsigned ldCount = op.getLdCount();
1479  for (unsigned i = 0, e = ldCount; i != e; ++i) {
1480  LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1481  unwrappedIO.outputs[i],
1482  unwrappedIO.outputs[ldCount + stCount + i]};
1483  loadPorts.push_back(port);
1484  }
1485 
1486  for (unsigned i = 0, e = stCount; i != e; ++i) {
1487  StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1488  unwrappedIO.inputs[i * 2],
1489  unwrappedIO.outputs[ldCount + i]};
1490  storePorts.push_back(port);
1491  }
1492 
1493  // used to drive the data wire of the control-only channels.
1494  auto c0I0 = s.constant(0, 0);
1495 
1496  auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1497  auto hlmem = s.b.create<seq::HLMemOp>(
1498  loc, s.clk, s.rst, "_handshake_memory_" + std::to_string(op.getId()),
1499  op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1500 
1501  // Create load ports...
1502  for (auto &ld : loadPorts) {
1503  llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1504  auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1505  addresses, ld.addr.valid,
1506  /*latency=*/0);
1507  ld.data.data->setValue(readData);
1508  ld.done.data->setValue(c0I0);
1509  // Create control fork for the load address valid and ready signals.
1510  buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1511  }
1512 
1513  // Create store ports...
1514  for (auto &st : storePorts) {
1515  // Create a register to buffer the valid path by 1 cycle, to match the
1516  // write latency of 1.
1517  auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
1518  auto writeValidBuffer =
1519  s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1520  st.done.valid->setValue(writeValidBuffer);
1521  st.done.data->setValue(c0I0);
1522 
1523  // Create the logic for when both the buffered write valid signal and the
1524  // store complete ready signal are asserted.
1525  auto storeCompleted =
1526  s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");
1527 
1528  // Create a signal for when the write valid buffer is empty or the output
1529  // is ready.
1530  auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1531  auto emptyOrComplete =
1532  s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");
1533 
1534  // Connect the gate to both the store address ready and store data ready
1535  st.addr.ready->setValue(emptyOrComplete);
1536  st.data.ready->setValue(emptyOrComplete);
1537 
1538  // Create a wire for when both the store address and data are valid.
1539  auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");
1540 
1541  // Create a mux that drives the buffer input. If the emptyOrComplete
1542  // signal is asserted, the mux selects the writeValid signal. Otherwise,
1543  // it selects the buffer output, keeping the output registered until the
1544  // emptyOrComplete signal is asserted.
1545  writeValidBufferMuxBE.setValue(
1546  s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1547 
1548  // Instantiate the write port operation - truncate address width to memory
1549  // width.
1550  llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1551  s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1552  st.data.data, writeValid,
1553  /*latency=*/1);
1554  }
1555  }
1556 }; // namespace
1557 
1558 class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1559 public:
1560  using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1561  void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1562  hw::HWModulePortAccessor &ports) const override {
1563  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1564  // A sink is always ready to accept a new value.
1565  unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1566  };
1567 };
1568 
1569 class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1570 public:
1571  using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1572  void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1573  hw::HWModulePortAccessor &ports) const override {
1574  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1575  // A source always provides a new (i0-typed) value.
1576  unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1577  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1578  };
1579 };
1580 
1581 class ConstantConversionPattern
1582  : public HandshakeConversionPattern<handshake::ConstantOp> {
1583 public:
1584  using HandshakeConversionPattern<
1585  handshake::ConstantOp>::HandshakeConversionPattern;
1586  void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1587  hw::HWModulePortAccessor &ports) const override {
1588  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1589  unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1590  unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1591  auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1592  unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1593  };
1594 };
1595 
1596 class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
1597 public:
1598  using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1599  void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
1600  hw::HWModulePortAccessor &ports) const override {
1601  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1602  auto input = unwrappedIO.inputs[0];
1603  auto output = unwrappedIO.outputs[0];
1604  InputHandshake lastStage;
1605  SmallVector<int64_t> initValues;
1606 
1607  // For now, always build seq buffers.
1608  if (op.getInitValues())
1609  initValues = op.getInitValueArray();
1610 
1611  lastStage =
1612  buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
1613  op.getNumSlots(), input, output, initValues);
1614 
1615  // Connect the last stage to the output handshake.
1616  output.data->setValue(lastStage.data);
1617  output.valid->setValue(lastStage.valid);
1618  lastStage.ready->setValue(output.ready);
1619  };
1620 
1621  struct SeqBufferStage {
1622  SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
1623  RTLBuilder &s, size_t index,
1624  std::optional<int64_t> initValue)
1625  : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1626 
1627  // Todo: Change when i0 support is added.
1628  c0s = createZeroDataConst(s, s.loc, dataType);
1629  currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
1630 
1631  auto hasInitValue = s.constant(1, initValue.has_value());
1632  auto validBE = bb.get(s.b.getI1Type());
1633  auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
1634  auto readyBE = bb.get(s.b.getI1Type());
1635 
1636  Value initValueCs = c0s;
1637  if (initValue.has_value())
1638  initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1639 
1640  // This could/should be revised but needs a larger rethinking to avoid
1641  // introducing new bugs.
1642  Value dataReg =
1643  buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1644  buildControlBufferLogic(validReg, readyBE, dataReg);
1645  }
1646 
1647  StringAttr getRegName(StringRef name) {
1648  return s.b.getStringAttr(name + std::to_string(index) + "_reg");
1649  }
1650 
1651  void buildControlBufferLogic(Value validReg, Backedge &readyBE,
1652  Value dataReg) {
1653  auto c0I1 = s.constant(1, 0);
1654  auto readyRegWire = bb.get(s.b.getI1Type());
1655  auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);
1656 
1657  // Create the logic to drive the current stage valid and potentially
1658  // data.
1659  currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1660  "controlValid" + std::to_string(index));
1661 
1662  // Create the logic to drive the current stage ready.
1663  auto notReadyReg = s.bNot(readyReg);
1664  readyBE.setValue(notReadyReg);
1665 
1666  auto succNotReady = s.bNot(*currentStage.ready);
1667  auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1668  auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1669  auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1670 
1671  // Create a mux for emptying the register when both are ready.
1672  auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1673  readyRegWire.setValue(resetSignal);
1674 
1675  // Add same logic for the data path if necessary.
1676  auto ctrlDataRegBE = bb.get(dataType);
1677  auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
1678  auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1679  currentStage.data = dataResult;
1680 
1681  auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1682  auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1683  ctrlDataRegBE.setValue(dataResetSignal);
1684  }
1685 
1686  Value buildDataBufferLogic(Value validReg, Value initValue,
1687  Backedge &validBE, Backedge &readyBE) {
1688  // Create a signal for when the valid register is empty or the successor
1689  // is ready to accept new token.
1690  auto notValidReg = s.bNot(validReg);
1691  auto emptyOrReady = s.bOr({notValidReg, readyBE});
1692  preStage.ready->setValue(emptyOrReady);
1693 
1694  // Create a mux that drives the register input. If the emptyOrReady
1695  // signal is asserted, the mux selects the predValid signal. Otherwise,
1696  // it selects the register output, keeping the output registered
1697  // unchanged.
1698  auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1699 
1700  // Now we can drive the valid register.
1701  validBE.setValue(validRegMux);
1702 
1703  // Create a mux that drives the date register.
1704  auto dataRegBE = bb.get(dataType);
1705  auto dataReg =
1706  s.reg(getRegName("data"),
1707  s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1708  dataRegBE.setValue(dataReg);
1709  return dataReg;
1710  }
1711 
1712  InputHandshake getOutput() { return currentStage; }
1713 
1714  Type dataType;
1715  InputHandshake &preStage;
1716  InputHandshake currentStage;
1717  RTLBuilder &s;
1718  BackedgeBuilder &bb;
1719  size_t index;
1720 
1721  // A zero-valued constant of equal type as the data type of this buffer.
1722  Value c0s;
1723  };
1724 
1725  InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
1726  Type dataType, unsigned size,
1727  InputHandshake &input,
1728  OutputHandshake &output,
1729  llvm::ArrayRef<int64_t> initValues) const {
1730  // Prime the buffer building logic with an initial stage, which just
1731  // wraps the input handshake.
1732  InputHandshake currentStage = input;
1733 
1734  for (unsigned i = 0; i < size; ++i) {
1735  bool isInitialized = i < initValues.size();
1736  auto initValue =
1737  isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1738  currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1739  .getOutput();
1740  }
1741 
1742  return currentStage;
1743  };
1744 };
1745 
1746 class IndexCastConversionPattern
1747  : public HandshakeConversionPattern<arith::IndexCastOp> {
1748 public:
1749  using HandshakeConversionPattern<
1750  arith::IndexCastOp>::HandshakeConversionPattern;
1751  void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1752  hw::HWModulePortAccessor &ports) const override {
1753  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1754  unsigned sourceBits =
1755  toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1756  unsigned targetBits =
1757  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1758  if (targetBits < sourceBits)
1759  buildTruncateLogic(s, unwrappedIO, targetBits);
1760  else
1761  buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1762  };
1763 };
1764 
1765 template <typename T>
1766 class ExtModuleConversionPattern : public OpConversionPattern<T> {
1767 public:
1768  ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1769  MLIRContext *context, OpBuilder &submoduleBuilder,
1770  HandshakeLoweringState &ls)
1771  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1772  submoduleBuilder(submoduleBuilder), ls(ls) {}
1773  using OpAdaptor = typename T::Adaptor;
1774 
1775  LogicalResult
1776  matchAndRewrite(T op, OpAdaptor adaptor,
1777  ConversionPatternRewriter &rewriter) const override {
1778 
1779  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1780  if (!implModule) {
1781  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1782  implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1783  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1784  portInfo);
1785  }
1786 
1787  llvm::SmallVector<Value> operands = adaptor.getOperands();
1788  addSequentialIOOperandsIfNeeded(op, operands);
1789  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1790  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1791  return success();
1792  }
1793 
1794 private:
1795  OpBuilder &submoduleBuilder;
1796  HandshakeLoweringState &ls;
1797 };
1798 
1799 class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1800 public:
1801  using OpConversionPattern::OpConversionPattern;
1802 
1803  LogicalResult
1804  matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1805  ConversionPatternRewriter &rewriter) const override {
1806  ModulePortInfo ports =
1807  getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1808 
1809  HWModuleLike hwModule;
1810  if (op.isExternal()) {
1811  hwModule = rewriter.create<hw::HWModuleExternOp>(
1812  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1813  } else {
1814  auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1815  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1816  auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1817  rewriter.inlineBlockBefore(&op.getBody().front(),
1818  hwModuleOp.getBodyBlock()->getTerminator(),
1819  args);
1820  hwModule = hwModuleOp;
1821  }
1822 
1823  // Was any predeclaration associated with this func? If so, replace uses
1824  // with the newly created module and erase the predeclaration.
1825  if (auto predecl =
1826  op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1827  auto *parentOp = op->getParentOp();
1828  auto *predeclModule =
1829  SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1830  if (predeclModule) {
1831  if (failed(SymbolTable::replaceAllSymbolUses(
1832  predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1833  return failure();
1834  rewriter.eraseOp(predeclModule);
1835  }
1836  }
1837 
1838  rewriter.eraseOp(op);
1839  return success();
1840  }
1841 };
1842 
1843 } // namespace
1844 
1845 //===----------------------------------------------------------------------===//
1846 // HW Top-module Related Functions
1847 //===----------------------------------------------------------------------===//
1848 
1849 static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1850  ConversionTarget &target,
1851  handshake::FuncOp op,
1852  OpBuilder &moduleBuilder) {
1853 
1854  std::map<std::string, unsigned> instanceNameCntr;
1855  NameUniquer instanceUniquer = [&](Operation *op) {
1856  std::string instName = getCallName(op);
1857  if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1858  // We use a special naming convention for operations which have a
1859  // 'handshake_id' attribute.
1860  instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1861  } else {
1862  // Fallback to just prefixing with an integer.
1863  instName += std::to_string(instanceNameCntr[instName]++);
1864  }
1865  return instName;
1866  };
1867 
1868  auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1869  instanceUniquer};
1870  RewritePatternSet patterns(op.getContext());
1871  patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1872  op.getContext());
1873  patterns.insert<JoinConversionPattern, ForkConversionPattern,
1874  SyncConversionPattern>(typeConverter, op.getContext(),
1875  moduleBuilder, ls);
1876 
1877  patterns.insert<
1878  // Comb operations.
1879  UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1880  UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1881  UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1882  UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1883  UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1884  UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1885  UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1886  UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1887  UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1888  UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1889  UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1890  UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1891  UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1892  UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1893  // HW operations.
1894  StructCreateConversionPattern,
1895  // Handshake operations.
1896  ConditionalBranchConversionPattern, MuxConversionPattern,
1897  PackConversionPattern, UnpackConversionPattern,
1898  ComparisonConversionPattern, BufferConversionPattern,
1899  SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1900  MergeConversionPattern, ControlMergeConversionPattern,
1901  LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1902  InstanceConversionPattern,
1903  // Arith operations.
1904  ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1905  ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1906  TruncateConversionPattern, IndexCastConversionPattern>(
1907  typeConverter, op.getContext(), moduleBuilder, ls);
1908 
1909  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1910  return op->emitOpError() << "error during conversion";
1911  return success();
1912 }
1913 
1914 namespace {
1915 class HandshakeToHWPass : public HandshakeToHWBase<HandshakeToHWPass> {
1916 public:
1917  void runOnOperation() override {
1918  mlir::ModuleOp mod = getOperation();
1919 
1920  // Lowering to HW requires that every value is used exactly once. Check
1921  // whether this precondition is met, and if not, exit.
1922  for (auto f : mod.getOps<handshake::FuncOp>()) {
1923  if (failed(verifyAllValuesHasOneUse(f))) {
1924  f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1925  "are used exactly once. Remember to run the "
1926  "fork/sink materialization pass before HW lowering.";
1927  signalPassFailure();
1928  return;
1929  }
1930  }
1931 
1932  // Resolve the instance graph to get a top-level module.
1933  std::string topLevel;
1935  SmallVector<std::string> sortedFuncs;
1936  if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1937  signalPassFailure();
1938  return;
1939  }
1940 
1941  ESITypeConverter typeConverter;
1942  ConversionTarget target(getContext());
1943  // All top-level logic of a handshake module will be the interconnectivity
1944  // between instantiated modules.
1945  target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1946  hw::InstanceOp>();
1947  target
1948  .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1949 
1950  // Convert the handshake.func operations in post-order wrt. the instance
1951  // graph. This ensures that any referenced submodules (through
1952  // handshake.instance) has already been lowered, and their HW module
1953  // equivalents are available.
1954  OpBuilder submoduleBuilder(mod.getContext());
1955  submoduleBuilder.setInsertionPointToStart(mod.getBody());
1956  for (auto &funcName : llvm::reverse(sortedFuncs)) {
1957  auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
1958  assert(funcOp && "handshake.func not found in module!");
1959  if (failed(
1960  convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
1961  signalPassFailure();
1962  return;
1963  }
1964  }
1965 
1966  // Second stage: Convert any handshake.extmemory operations and the
1967  // top-level I/O associated with these.
1968  for (auto hwModule : mod.getOps<hw::HWModuleOp>())
1969  if (failed(convertExtMemoryOps(hwModule)))
1970  return signalPassFailure();
1971  }
1972 };
1973 } // end anonymous namespace
1974 
1975 std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
1976  return std::make_unique<HandshakeToHWPass>();
1977 }
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition: CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition: CalyxOps.cpp:119
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:538
static Type tupleToStruct(TupleType tuple)
Definition: DCToHW.cpp:43
std::function< std::string(Operation *)> NameUniquer
Definition: DCToHW.cpp:40
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
Definition: FIRRTLOps.cpp:979
int32_t width
Definition: FIRRTL.cpp:36
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition: OM.cpp:96
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Channels are the basic communication primitives.
Definition: Types.h:63
const Type * getInner() const
Definition: Types.h:66
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition: seq.py:137
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition: CombOps.cpp:25
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
Definition: PassHelpers.cpp:38
static constexpr const char * kPredeclarationAttr
Definition: HandshakeToHW.h:37
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Type toValidType(Type t)
esi::ChannelType esiWrapper(mlir::Type t)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()