CIRCT  19.0.0git
HandshakeToHW.cpp
Go to the documentation of this file.
1 //===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Handshake to HW Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
24 #include "mlir/Dialect/Arith/IR/Arith.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <optional>
33 
34 namespace circt {
35 #define GEN_PASS_DEF_HANDSHAKETOHW
36 #include "circt/Conversion/Passes.h.inc"
37 } // namespace circt
38 
39 using namespace mlir;
40 using namespace circt;
41 using namespace circt::handshake;
42 using namespace circt::hw;
43 
44 using NameUniquer = std::function<std::string(Operation *)>;
45 
46 namespace {
47 
48 static Type tupleToStruct(TypeRange types) {
49  return toValidType(mlir::TupleType::get(types[0].getContext(), types));
50 }
51 
52 // Shared state used by various functions; captured in a struct to reduce the
53 // number of arguments that we have to pass around.
54 struct HandshakeLoweringState {
55  ModuleOp parentModule;
56  NameUniquer nameUniquer;
57 };
58 
59 // A type converter is needed to perform the in-flight materialization of "raw"
60 // (non-ESI channel) types to their ESI channel correspondents. This comes into
61 // effect when backedges exist in the input IR.
62 class ESITypeConverter : public TypeConverter {
63 public:
64  ESITypeConverter() {
65  addConversion([](Type type) -> Type { return esiWrapper(type); });
66 
67  addTargetMaterialization(
68  [&](mlir::OpBuilder &builder, mlir::Type resultType,
69  mlir::ValueRange inputs,
70  mlir::Location loc) -> std::optional<mlir::Value> {
71  if (inputs.size() != 1)
72  return std::nullopt;
73  return inputs[0];
74  });
75 
76  addSourceMaterialization(
77  [&](mlir::OpBuilder &builder, mlir::Type resultType,
78  mlir::ValueRange inputs,
79  mlir::Location loc) -> std::optional<mlir::Value> {
80  if (inputs.size() != 1)
81  return std::nullopt;
82  return inputs[0];
83  });
84  }
85 };
86 
87 } // namespace
88 
89 /// Returns a submodule name resulting from an operation, without discriminating
90 /// type information.
91 static std::string getBareSubModuleName(Operation *oldOp) {
92  // The dialect name is separated from the operation name by '.', which is not
93  // valid in SystemVerilog module names. In case this name is used in
94  // SystemVerilog output, replace '.' with '_'.
95  std::string subModuleName = oldOp->getName().getStringRef().str();
96  std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
97  return subModuleName;
98 }
99 
100 static std::string getCallName(Operation *op) {
101  auto callOp = dyn_cast<handshake::InstanceOp>(op);
102  return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
103 }
104 
105 /// Extracts the type of the data-carrying type of opType. If opType is an ESI
106 /// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
107 /// assume that opType itself is the data-carrying type.
108 static Type getOperandDataType(Value op) {
109  auto opType = op.getType();
110  if (auto channelType = dyn_cast<esi::ChannelType>(opType))
111  return channelType.getInner();
112  return opType;
113 }
114 
115 /// Filters NoneType's from the input.
116 static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
117  SmallVector<Type> filterRes;
118  llvm::copy_if(input, std::back_inserter(filterRes),
119  [](Type type) { return !isa<NoneType>(type); });
120  return filterRes;
121 }
122 
123 /// Returns a set of types which may uniquely identify the provided op. Return
124 /// value is <inputTypes, outputTypes>.
125 using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
127  return TypeSwitch<Operation *, DiscriminatingTypes>(op)
128  .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
129  return DiscriminatingTypes{{},
130  {memOp.getMemRefType().getElementType()}};
131  })
132  .Default([&](auto) {
133  // By default, all in- and output types which is not a control type
134  // (NoneType) are discriminating types.
135  std::vector<Type> inTypes, outTypes;
136  llvm::transform(op->getOperands(), std::back_inserter(inTypes),
138  llvm::transform(op->getResults(), std::back_inserter(outTypes),
140  return DiscriminatingTypes{filterNoneTypes(inTypes),
141  filterNoneTypes(outTypes)};
142  });
143 }
144 
145 /// Get type name. Currently we only support integer or index types.
146 /// The emitted type aligns with the getFIRRTLType() method. Thus all integers
147 /// other than signed integers will be emitted as unsigned.
148 // NOLINTNEXTLINE(misc-no-recursion)
149 static std::string getTypeName(Location loc, Type type) {
150  std::string typeName;
151  // Builtin types
152  if (type.isIntOrIndex()) {
153  if (auto indexType = dyn_cast<IndexType>(type))
154  typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
155  else if (type.isSignedInteger())
156  typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
157  else
158  typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
159  } else if (auto tupleType = dyn_cast<TupleType>(type)) {
160  typeName += "_tuple";
161  for (auto elementType : tupleType.getTypes())
162  typeName += getTypeName(loc, elementType);
163  } else if (auto structType = dyn_cast<hw::StructType>(type)) {
164  typeName += "_struct";
165  for (auto element : structType.getElements())
166  typeName += "_" + element.name.str() + getTypeName(loc, element.type);
167  } else
168  emitError(loc) << "unsupported data type '" << type << "'";
169 
170  return typeName;
171 }
172 
173 /// Construct a name for creating HW sub-module.
174 static std::string getSubModuleName(Operation *oldOp) {
175  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
176  return instanceOp.getModule().str();
177 
178  std::string subModuleName = getBareSubModuleName(oldOp);
179 
180  // Add value of the constant operation.
181  if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
182  if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
183  auto intType = intAttr.getType();
184 
185  if (intType.isSignedInteger())
186  subModuleName += "_c" + std::to_string(intAttr.getSInt());
187  else if (intType.isUnsignedInteger())
188  subModuleName += "_c" + std::to_string(intAttr.getUInt());
189  else
190  subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
191  } else
192  oldOp->emitError("unsupported constant type");
193  }
194 
195  // Add discriminating in- and output types.
196  auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
197  if (!inTypes.empty())
198  subModuleName += "_in";
199  for (auto inType : inTypes)
200  subModuleName += getTypeName(oldOp->getLoc(), inType);
201 
202  if (!outTypes.empty())
203  subModuleName += "_out";
204  for (auto outType : outTypes)
205  subModuleName += getTypeName(oldOp->getLoc(), outType);
206 
207  // Add memory ID.
208  if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
209  subModuleName += "_id" + std::to_string(memOp.getId());
210 
211  // Add compare kind.
212  if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
213  subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
214 
215  // Add buffer information.
216  if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
217  subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
218  if (bufferOp.isSequential())
219  subModuleName += "_seq";
220  else
221  subModuleName += "_fifo";
222 
223  if (auto initValues = bufferOp.getInitValues()) {
224  subModuleName += "_init";
225  for (const Attribute e : *initValues) {
226  assert(isa<IntegerAttr>(e));
227  subModuleName +=
228  "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
229  }
230  }
231  }
232 
233  // Add control information.
234  if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
235  ctrlInterface && ctrlInterface.isControl()) {
236  // Add some additional discriminating info for non-typed operations.
237  subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
238  std::to_string(oldOp->getNumResults()) + "outs";
239  subModuleName += "_ctrl";
240  } else {
241  assert(
242  (!inTypes.empty() || !outTypes.empty()) &&
243  "Insufficient discriminating type info generated for the operation!");
244  }
245 
246  return subModuleName;
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // HW Sub-module Related Functions
251 //===----------------------------------------------------------------------===//
252 
253 /// Check whether a submodule with the same name has been created elsewhere in
254 /// the top level module. Return the matched module operation if true, otherwise
255 /// return nullptr.
256 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
257  StringRef modName) {
258  if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
259  return mod;
260  if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
261  return mod;
262  return {};
263 }
264 
265 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
266  Operation *oldOp) {
267  HWModuleLike targetModule;
268  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
269  targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
270  else
271  targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
272 
273  if (isa<handshake::InstanceOp>(oldOp))
274  assert(targetModule &&
275  "handshake.instance target modules should always have been lowered "
276  "before the modules that reference them!");
277  return targetModule;
278 }
279 
280 /// Returns a vector of PortInfo's which defines the HW interface of the
281 /// to-be-converted op.
282 static ModulePortInfo getPortInfoForOp(Operation *op) {
283  return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
284 }
285 
286 static llvm::SmallVector<hw::detail::FieldInfo>
287 portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
288  llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
289  for (auto port : portInfo)
290  fieldInfo.push_back({port.name, port.type});
291 
292  return fieldInfo;
293 }
294 
295 // Convert any handshake.extmemory operations and the top-level I/O
296 // associated with these.
297 static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
298  auto *ctx = mod.getContext();
299 
300  // Gather memref ports to be converted.
301  llvm::DenseMap<unsigned, Value> memrefPorts;
302  for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
303  auto channel = dyn_cast<esi::ChannelType>(arg.getType());
304  if (channel && isa<MemRefType>(channel.getInner()))
305  memrefPorts[i] = arg;
306  }
307 
308  if (memrefPorts.empty())
309  return success(); // nothing to do.
310 
311  OpBuilder b(mod);
312 
313  auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
314  ArrayRef<hw::PortInfo> info,
315  hw::ModulePort::Direction direction) {
316  auto type = hw::StructType::get(ctx, portToFieldInfo(info));
317  auto portInfo =
318  hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
319  return portInfo;
320  };
321 
322  for (auto [i, arg] : memrefPorts) {
323  // Insert ports into the module
324  auto memName = mod.getArgName(i);
325 
326  // Get the attached extmemory external module.
327  auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
328  auto extmemMod =
329  cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
330  extmemInstance, extmemInstance.getModuleNameAttr()));
331 
332  ModulePortInfo portInfo(extmemMod.getPortList());
333 
334  // The extmemory external module's interface is a direct wrapping of the
335  // original handshake.extmemory operation in- and output types. Remove the
336  // first input argument (the !esi.channel<memref> op) since that is what
337  // we're replacing with a materialized interface.
338  portInfo.eraseInput(0);
339 
340  // Add memory input - this is the output of the extmemory op.
341  SmallVector<PortInfo> outputs(portInfo.getOutputs());
342  auto inPortInfo =
343  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
345  mod.insertPorts({{i, inPortInfo}}, {});
346  auto newInPort = mod.getArgumentForInput(i);
347  // Replace the extmemory submodule outputs with the newly created inputs.
348  b.setInsertionPointToStart(mod.getBodyBlock());
349  auto newInPortExploded = b.create<hw::StructExplodeOp>(
350  arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
351  extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
352 
353  // Add memory output - this is the inputs of the extmemory op (without the
354  // first argument);
355  unsigned outArgI = mod.getNumOutputPorts();
356  SmallVector<PortInfo> inputs(portInfo.getInputs());
357  auto outPortInfo =
358  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
360 
361  auto memOutputArgs = extmemInstance.getOperands().drop_front();
362  b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
363  auto memOutputStruct = b.create<hw::StructCreateOp>(
364  arg.getLoc(), outPortInfo.type, memOutputArgs);
365  mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
366 
367  // Erase the extmemory submodule instace since the i/o has now been
368  // plumbed.
369  extmemMod.erase();
370  extmemInstance.erase();
371 
372  // Erase the original memref argument of the top-level i/o now that it's use
373  // has been removed.
374  mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
375  /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
376  }
377 
378  return success();
379 }
380 
381 namespace {
382 
383 // Input handshakes contain a resolved valid and (optional )data signal, and
384 // a to-be-assigned ready signal.
385 struct InputHandshake {
386  Value valid;
387  std::shared_ptr<Backedge> ready;
388  Value data;
389 };
390 
391 // Output handshakes contain a resolved ready, and to-be-assigned valid and
392 // (optional) data signals.
393 struct OutputHandshake {
394  std::shared_ptr<Backedge> valid;
395  Value ready;
396  std::shared_ptr<Backedge> data;
397 };
398 
399 /// A helper struct that acts like a wire. Can be used to interact with the
400 /// RTLBuilder when multiple built components should be connected.
401 struct HandshakeWire {
402  HandshakeWire(BackedgeBuilder &bb, Type dataType) {
403  MLIRContext *ctx = dataType.getContext();
404  auto i1Type = IntegerType::get(ctx, 1);
405  valid = std::make_shared<Backedge>(bb.get(i1Type));
406  ready = std::make_shared<Backedge>(bb.get(i1Type));
407  data = std::make_shared<Backedge>(bb.get(dataType));
408  }
409 
410  // Functions that allow to treat a wire like an input or output port.
411  // **Careful**: Such a port will not be updated when backedges are resolved.
412  InputHandshake getAsInput() { return {*valid, ready, *data}; }
413  OutputHandshake getAsOutput() { return {valid, *ready, data}; }
414 
415  std::shared_ptr<Backedge> valid;
416  std::shared_ptr<Backedge> ready;
417  std::shared_ptr<Backedge> data;
418 };
419 
420 template <typename T, typename TInner>
421 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
422  llvm::function_ref<T(TInner &)> extractor) {
423  llvm::SmallVector<T> result;
424  llvm::transform(container, std::back_inserter(result), extractor);
425  return result;
426 }
427 struct UnwrappedIO {
428  llvm::SmallVector<InputHandshake> inputs;
429  llvm::SmallVector<OutputHandshake> outputs;
430 
431  llvm::SmallVector<Value> getInputValids() {
432  return extractValues<Value, InputHandshake>(
433  inputs, [](auto &hs) { return hs.valid; });
434  }
435  llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
436  return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
437  inputs, [](auto &hs) { return hs.ready; });
438  }
439  llvm::SmallVector<Value> getInputDatas() {
440  return extractValues<Value, InputHandshake>(
441  inputs, [](auto &hs) { return hs.data; });
442  }
443  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
444  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
445  outputs, [](auto &hs) { return hs.valid; });
446  }
447  llvm::SmallVector<Value> getOutputReadys() {
448  return extractValues<Value, OutputHandshake>(
449  outputs, [](auto &hs) { return hs.ready; });
450  }
451  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
452  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
453  outputs, [](auto &hs) { return hs.data; });
454  }
455 };
456 
457 // A class containing a bunch of syntactic sugar to reduce builder function
458 // verbosity.
459 // @todo: should be moved to support.
460 struct RTLBuilder {
461  RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
462  Value clk = Value(), Value rst = Value())
463  : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
464 
465  Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
466  // Cannot use zero-width APInt's in DenseMap's, see
467  // https://github.com/llvm/llvm-project/issues/58013
468  bool isZeroWidth = apv.getBitWidth() == 0;
469  if (!isZeroWidth) {
470  auto it = constants.find(apv);
471  if (it != constants.end())
472  return it->second;
473  }
474 
475  auto cval = b.create<hw::ConstantOp>(loc, apv);
476  if (!isZeroWidth)
477  constants[apv] = cval;
478  return cval;
479  }
480 
481  Value constant(unsigned width, int64_t value,
482  std::optional<StringRef> name = {}) {
483  return constant(APInt(width, value));
484  }
485  std::pair<Value, Value> wrap(Value data, Value valid,
486  std::optional<StringRef> name = {}) {
487  auto wrapOp = b.create<esi::WrapValidReadyOp>(loc, data, valid);
488  return {wrapOp.getResult(0), wrapOp.getResult(1)};
489  }
490  std::pair<Value, Value> unwrap(Value channel, Value ready,
491  std::optional<StringRef> name = {}) {
492  auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
493  return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
494  }
495 
496  // Various syntactic sugar functions.
497  Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
498  Value rst = Value()) {
499  Value resolvedClk = clk ? clk : this->clk;
500  Value resolvedRst = rst ? rst : this->rst;
501  assert(resolvedClk &&
502  "No global clock provided to this RTLBuilder - a clock "
503  "signal must be provided to the reg(...) function.");
504  assert(resolvedRst &&
505  "No global reset provided to this RTLBuilder - a reset "
506  "signal must be provided to the reg(...) function.");
507 
508  return b.create<seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
509  name);
510  }
511 
512  Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
513  std::optional<StringRef> name = {}) {
514  return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
515  }
516 
517  Value buildNamedOp(llvm::function_ref<Value()> f,
518  std::optional<StringRef> name) {
519  Value v = f();
520  StringAttr nameAttr;
521  Operation *op = v.getDefiningOp();
522  if (name.has_value()) {
523  op->setAttr("sv.namehint", b.getStringAttr(*name));
524  nameAttr = b.getStringAttr(*name);
525  }
526  return v;
527  }
528 
529  // Bitwise 'and'.
530  Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
531  return buildNamedOp(
532  [&]() { return b.create<comb::AndOp>(loc, values, false); }, name);
533  }
534 
535  Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
536  return buildNamedOp(
537  [&]() { return b.create<comb::OrOp>(loc, values, false); }, name);
538  }
539 
540  // Bitwise 'not'.
541  Value bNot(Value value, std::optional<StringRef> name = {}) {
542  auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
543  std::string inferedName;
544  if (!name) {
545  // Try to create a name from the input value.
546  if (auto valueName =
547  value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
548  inferedName = ("not_" + valueName.getValue()).str();
549  name = inferedName;
550  }
551  }
552 
553  return buildNamedOp(
554  [&]() { return b.create<comb::XorOp>(loc, value, allOnes); }, name);
555 
556  return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
557  }
558 
559  Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
560  return buildNamedOp(
561  [&]() { return b.create<comb::ShlOp>(loc, value, shift); }, name);
562  }
563 
564  Value concat(ValueRange values, std::optional<StringRef> name = {}) {
565  return buildNamedOp([&]() { return b.create<comb::ConcatOp>(loc, values); },
566  name);
567  }
568 
569  // Packs a list of values into a hw.struct.
570  Value pack(ValueRange values, Type structType = Type(),
571  std::optional<StringRef> name = {}) {
572  if (!structType)
573  structType = tupleToStruct(values.getTypes());
574  return buildNamedOp(
575  [&]() { return b.create<hw::StructCreateOp>(loc, structType, values); },
576  name);
577  }
578 
579  // Unpacks a hw.struct into a list of values.
580  ValueRange unpack(Value value) {
581  auto structType = cast<hw::StructType>(value.getType());
582  llvm::SmallVector<Type> innerTypes;
583  structType.getInnerTypes(innerTypes);
584  return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
585  }
586 
587  llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
588  llvm::SmallVector<Value> bits;
589  for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
590  bits.push_back(b.create<comb::ExtractOp>(loc, v, i, /*bitWidth=*/1));
591  return bits;
592  }
593 
594  // OR-reduction of the bits in 'v'.
595  Value rOr(Value v, std::optional<StringRef> name = {}) {
596  return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
597  }
598 
599  // Extract bits v[hi:lo] (inclusive).
600  Value extract(Value v, unsigned lo, unsigned hi,
601  std::optional<StringRef> name = {}) {
602  unsigned width = hi - lo + 1;
603  return buildNamedOp(
604  [&]() { return b.create<comb::ExtractOp>(loc, v, lo, width); }, name);
605  }
606 
607  // Truncates 'value' to its lower 'width' bits.
608  Value truncate(Value value, unsigned width,
609  std::optional<StringRef> name = {}) {
610  return extract(value, 0, width - 1, name);
611  }
612 
613  Value zext(Value value, unsigned outWidth,
614  std::optional<StringRef> name = {}) {
615  unsigned inWidth = value.getType().getIntOrFloatBitWidth();
616  assert(inWidth <= outWidth && "zext: input width must be <- output width.");
617  if (inWidth == outWidth)
618  return value;
619  auto c0 = constant(outWidth - inWidth, 0);
620  return concat({c0, value}, name);
621  }
622 
623  Value sext(Value value, unsigned outWidth,
624  std::optional<StringRef> name = {}) {
625  return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
626  }
627 
628  // Extracts a single bit v[bit].
629  Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
630  return extract(v, index, index, name);
631  }
632 
633  // Creates a hw.array of the given values.
634  Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
635  return buildNamedOp(
636  [&]() { return b.create<hw::ArrayCreateOp>(loc, values); }, name);
637  }
638 
639  // Extract the 'index'th value from the input array.
640  Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
641  return buildNamedOp(
642  [&]() { return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
643  }
644 
645  // Muxes a range of values.
646  // The select signal is expected to be a decimal value which selects starting
647  // from the lowest index of value.
648  Value mux(Value index, ValueRange values,
649  std::optional<StringRef> name = {}) {
650  if (values.size() == 2)
651  return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
652 
653  return arrayGet(arrayCreate(values), index, name);
654  }
655 
656  // Muxes a range of values. The select signal is expected to be a 1-hot
657  // encoded value.
658  Value ohMux(Value index, ValueRange inputs) {
659  // Confirm the select input can be a one-hot encoding for the inputs.
660  unsigned numInputs = inputs.size();
661  assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
662  "one-hot select can't mux inputs");
663 
664  // Start the mux tree with zero value.
665  // Todo: clean up when handshake supports i0.
666  auto dataType = inputs[0].getType();
667  unsigned width =
668  isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
669  Value muxValue = constant(width, 0);
670 
671  // Iteratively chain together muxes from the high bit to the low bit.
672  for (size_t i = numInputs - 1; i != 0; --i) {
673  Value input = inputs[i];
674  Value selectBit = bit(index, i);
675  muxValue = mux(selectBit, {muxValue, input});
676  }
677 
678  return muxValue;
679  }
680 
681  hw::ModulePortInfo info;
682  OpBuilder &b;
683  Location loc;
684  Value clk, rst;
685  DenseMap<APInt, Value> constants;
686 };
687 
688 /// Creates a Value that has an assigned zero value. For structs, this
689 /// corresponds to assigning zero to each element recursively.
690 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
691  return TypeSwitch<Type, Value>(type)
692  .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
693  .Case<IntType, IntegerType>([&](auto type) {
694  return s.constant(type.getIntOrFloatBitWidth(), 0);
695  })
696  .Case<hw::StructType>([&](auto structType) {
697  SmallVector<Value> zeroValues;
698  for (auto field : structType.getElements())
699  zeroValues.push_back(createZeroDataConst(s, loc, field.type));
700  return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
701  })
702  .Default([&](Type) -> Value {
703  emitError(loc) << "unsupported type for zero value: " << type;
704  assert(false);
705  return {};
706  });
707 }
708 
709 static void
710 addSequentialIOOperandsIfNeeded(Operation *op,
711  llvm::SmallVectorImpl<Value> &operands) {
712  if (op->hasTrait<mlir::OpTrait::HasClock>()) {
713  // Parent should at this point be a hw.module and have clock and reset
714  // ports.
715  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
716  operands.push_back(
717  parent.getArgumentForInput(parent.getNumInputPorts() - 2));
718  operands.push_back(
719  parent.getArgumentForInput(parent.getNumInputPorts() - 1));
720  }
721 }
722 
723 template <typename T>
724 class HandshakeConversionPattern : public OpConversionPattern<T> {
725 public:
726  HandshakeConversionPattern(ESITypeConverter &typeConverter,
727  MLIRContext *context, OpBuilder &submoduleBuilder,
728  HandshakeLoweringState &ls)
729  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
730  submoduleBuilder(submoduleBuilder), ls(ls) {}
731 
732  using OpAdaptor = typename T::Adaptor;
733 
734  LogicalResult
735  matchAndRewrite(T op, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const override {
737 
738  // Check if a submodule has already been created for the op. If so,
739  // instantiate the submodule. Else, run the pattern-defined module
740  // builder.
741  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
742  if (!implModule) {
743  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
744 
745  submoduleBuilder.setInsertionPoint(op->getParentOp());
746  implModule = submoduleBuilder.create<hw::HWModuleOp>(
747  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
748  portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
749  // if 'op' has clock trait, extract these and provide them to the
750  // RTL builder.
751  Value clk, rst;
752  if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
753  clk = ports.getInput("clock");
754  rst = ports.getInput("reset");
755  }
756 
757  BackedgeBuilder bb(b, op.getLoc());
758  RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
759  this->buildModule(op, bb, s, ports);
760  });
761  }
762 
763  // Instantiate the submodule.
764  llvm::SmallVector<Value> operands = adaptor.getOperands();
765  addSequentialIOOperandsIfNeeded(op, operands);
766  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
767  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
768  return success();
769  }
770 
771  virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
772  hw::HWModulePortAccessor &ports) const = 0;
773 
774  // Syntactic sugar functions.
775  // Unwraps an ESI-interfaced module into its constituent handshake signals.
776  // Backedges are created for the to-be-resolved signals, and output ports
777  // are assigned to their wrapped counterparts.
778  UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
779  hw::HWModulePortAccessor &ports) const {
780  UnwrappedIO unwrapped;
781  for (auto port : ports.getInputs()) {
782  if (!isa<esi::ChannelType>(port.getType()))
783  continue;
784  InputHandshake hs;
785  auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
786  auto [data, valid] = s.unwrap(port, *ready);
787  hs.data = data;
788  hs.valid = valid;
789  hs.ready = ready;
790  unwrapped.inputs.push_back(hs);
791  }
792  for (auto &outputInfo : ports.getPortList().getOutputs()) {
793  esi::ChannelType channelType =
794  dyn_cast<esi::ChannelType>(outputInfo.type);
795  if (!channelType)
796  continue;
797  OutputHandshake hs;
798  Type innerType = channelType.getInner();
799  auto data = std::make_shared<Backedge>(bb.get(innerType));
800  auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
801  auto [dataCh, ready] = s.wrap(*data, *valid);
802  hs.data = data;
803  hs.valid = valid;
804  hs.ready = ready;
805  ports.setOutput(outputInfo.name, dataCh);
806  unwrapped.outputs.push_back(hs);
807  }
808  return unwrapped;
809  }
810 
811  void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
812  OutputHandshake &output, Value cond) const {
813  auto validAndReady = s.bAnd({output.ready, cond});
814  for (auto &input : inputs)
815  input.ready->setValue(validAndReady);
816  }
817 
818  void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
819  OutputHandshake &output) const {
820  llvm::SmallVector<Value> valids;
821  for (auto &input : inputs)
822  valids.push_back(input.valid);
823  Value allValid = s.bAnd(valids);
824  output.valid->setValue(allValid);
825  setAllReadyWithCond(s, inputs, output, allValid);
826  }
827 
828  // Builds mux logic for the given inputs and outputs.
829  // Note: it is assumed that the caller has removed the 'select' signal from
830  // the 'unwrapped' inputs and provide it as a separate argument.
831  void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
832  InputHandshake &select) const {
833  // ============================= Control logic =============================
834  size_t numInputs = unwrapped.inputs.size();
835  size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
836  Value truncatedSelect =
837  select.data.getType().getIntOrFloatBitWidth() > selectWidth
838  ? s.truncate(select.data, selectWidth)
839  : select.data;
840 
841  // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
842  auto selectZext = s.zext(truncatedSelect, numInputs);
843  auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
844  auto &res = unwrapped.outputs[0];
845 
846  // Mux input valid signals.
847  auto selectedInputValid =
848  s.mux(truncatedSelect, unwrapped.getInputValids());
849  // Result is valid when the selected input and the select input is valid.
850  auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
851  res.valid->setValue(selAndInputValid);
852  auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
853 
854  // Select is ready when result is valid and ready (result transacting).
855  select.ready->setValue(resValidAndReady);
856 
857  // Assign each input ready signal if it is currently selected.
858  for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
859  // Extract the selection bit for this input.
860  auto isSelected = s.bit(select1h, inIdx);
861 
862  // '&' that with the result valid and ready, and assign to the input
863  // ready signal.
864  auto activeAndResultValidAndReady =
865  s.bAnd({isSelected, resValidAndReady});
866  in.ready->setValue(activeAndResultValidAndReady);
867  }
868 
869  // ============================== Data logic ===============================
870  res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
871  }
872 
873  // Builds fork logic between the single input and multiple outputs' control
874  // networks. Caller is expected to handle data separately.
875  void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
876  ArrayRef<OutputHandshake> outputs) const {
877  auto c0I1 = s.constant(1, 0);
878  llvm::SmallVector<Value> doneWires;
879  for (auto [i, output] : llvm::enumerate(outputs)) {
880  auto doneBE = bb.get(s.b.getI1Type());
881  auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
882  auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
883  auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
884  output.valid->setValue(outValid);
885  auto validReady = s.bAnd({output.ready, outValid});
886  auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
887  doneBE.setValue(done);
888  doneWires.push_back(done);
889  }
890  input.ready->setValue(s.bAnd(doneWires, "allDone"));
891  }
892 
893  // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
894  // function which takes the set of unwrapped data inputs, and returns a
895  // value which should be assigned to the output data value.
896  void buildUnitRateJoinLogic(
897  RTLBuilder &s, UnwrappedIO &unwrappedIO,
898  llvm::function_ref<Value(ValueRange)> unitBuilder) const {
899  assert(unwrappedIO.outputs.size() == 1 &&
900  "Expected exactly one output for unit-rate join actor");
901  // Control logic.
902  this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
903 
904  // Data logic.
905  auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
906  unwrappedIO.outputs[0].data->setValue(unitRes);
907  }
908 
909  void buildUnitRateForkLogic(
910  RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
911  llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
912  assert(unwrappedIO.inputs.size() == 1 &&
913  "Expected exactly one input for unit-rate fork actor");
914  // Control logic.
915  this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
916 
917  // Data logic.
918  auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
919  assert(unitResults.size() == unwrappedIO.outputs.size() &&
920  "Expected unit builder to return one result per output");
921  for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
922  outport.data->setValue(res);
923  }
924 
925  void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
926  bool signExtend) const {
927  size_t outWidth =
928  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
929  .getIntOrFloatBitWidth();
930  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
931  if (signExtend)
932  return s.sext(inputs[0], outWidth);
933  return s.zext(inputs[0], outWidth);
934  });
935  }
936 
937  void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
938  unsigned targetWidth) const {
939  size_t outWidth =
940  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
941  .getIntOrFloatBitWidth();
942  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
943  return s.truncate(inputs[0], outWidth);
944  });
945  }
946 
947  /// Return the number of bits needed to index the given number of values.
948  static size_t getNumIndexBits(uint64_t numValues) {
949  return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
950  }
951 
952  Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
953  Value defaultValue,
954  DenseMap<size_t, Value> &indexMapping) const {
955  auto numInputs = inputs.size();
956  auto priorityArb = defaultValue;
957 
958  for (size_t i = numInputs; i > 0; --i) {
959  size_t inputIndex = i - 1;
960  size_t oneHotIndex = size_t{1} << inputIndex;
961  auto constIndex = s.constant(numInputs, oneHotIndex);
962  indexMapping[inputIndex] = constIndex;
963  priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
964  }
965  return priorityArb;
966  }
967 
968 private:
969  OpBuilder &submoduleBuilder;
970  HandshakeLoweringState &ls;
971 };
972 
973 class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
974 public:
975  using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
976  void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
977  hw::HWModulePortAccessor &ports) const override {
978  auto unwrapped = unwrapIO(s, bb, ports);
979  buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
980  return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
981  });
982  }
983 };
984 
985 class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
986 public:
987  using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
988  void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
989  hw::HWModulePortAccessor &ports) const override {
990  auto unwrappedIO = unwrapIO(s, bb, ports);
991  buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
992  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
993  };
994 };
995 
996 class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
997 public:
998  using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
999  void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1000  hw::HWModulePortAccessor &ports) const override {
1001  auto unwrappedIO = unwrapIO(s, bb, ports);
1002 
1003  // A helper wire that will be used to connect the two built logics
1004  HandshakeWire wire(bb, s.b.getNoneType());
1005 
1006  OutputHandshake output = wire.getAsOutput();
1007  buildJoinLogic(s, unwrappedIO.inputs, output);
1008 
1009  InputHandshake input = wire.getAsInput();
1010 
1011  // The state-keeping fork logic is required here, as the circuit isn't
1012  // allowed to wait for all the consumers to be ready. Connecting the ready
1013  // signals of the outputs to their corresponding valid signals leads to
1014  // combinatorial cycles. The paper which introduced compositional dataflow
1015  // circuits explicitly mentions this limitation:
1016  // http://arcade.cs.columbia.edu/df-memocode17.pdf
1017  buildForkLogic(s, bb, input, unwrappedIO.outputs);
1018 
1019  // Directly connect the data wires, only the control signals need to be
1020  // combined.
1021  for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1022  out.data->setValue(in.data);
1023  };
1024 };
1025 
1026 class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1027 public:
1028  using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1029  void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1030  hw::HWModulePortAccessor &ports) const override {
1031  auto unwrappedIO = unwrapIO(s, bb, ports);
1032 
1033  // Extract select signal from the unwrapped IO.
1034  auto select = unwrappedIO.inputs[0];
1035  unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1036  buildMuxLogic(s, unwrappedIO, select);
1037  };
1038 };
1039 
1040 class InstanceConversionPattern
1041  : public HandshakeConversionPattern<handshake::InstanceOp> {
1042 public:
1043  using HandshakeConversionPattern<
1044  handshake::InstanceOp>::HandshakeConversionPattern;
1045  void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1046  hw::HWModulePortAccessor &ports) const override {
1047  assert(false &&
1048  "If we indeed perform conversion in post-order, this "
1049  "should never be called. The base HandshakeConversionPattern logic "
1050  "will instantiate the external module.");
1051  }
1052 };
1053 
1054 class ReturnConversionPattern
1055  : public OpConversionPattern<handshake::ReturnOp> {
1056 public:
1057  using OpConversionPattern::OpConversionPattern;
1058  LogicalResult
1059  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1060  ConversionPatternRewriter &rewriter) const override {
1061  // Locate existing output op, Append operands to output op, and move to
1062  // the end of the block.
1063  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1064  auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1065  outputOp->setOperands(adaptor.getOperands());
1066  outputOp->moveAfter(&parent.getBodyBlock()->back());
1067  rewriter.eraseOp(op);
1068  return success();
1069  }
1070 };
1071 
1072 // Converts an arbitrary operation into a unit rate actor. A unit rate actor
1073 // will transact once all inputs are valid and its output is ready.
1074 template <typename TIn, typename TOut = TIn>
1075 class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1076 public:
1077  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1078  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1079  hw::HWModulePortAccessor &ports) const override {
1080  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1081  this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1082  // Create TOut - it is assumed that TOut trivially
1083  // constructs from the input data signals of TIn.
1084  // To disambiguate ambiguous builders with default arguments (e.g.,
1085  // twoState UnitAttr), specify attribute array explicitly.
1086  return s.b.create<TOut>(op.getLoc(), inputs,
1087  /* attributes */ ArrayRef<NamedAttribute>{});
1088  });
1089  };
1090 };
1091 
1092 class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1093 public:
1094  using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1095  void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1096  hw::HWModulePortAccessor &ports) const override {
1097  auto unwrappedIO = unwrapIO(s, bb, ports);
1098  buildUnitRateJoinLogic(s, unwrappedIO,
1099  [&](ValueRange inputs) { return s.pack(inputs); });
1100  };
1101 };
1102 
1103 class StructCreateConversionPattern
1104  : public HandshakeConversionPattern<hw::StructCreateOp> {
1105 public:
1106  using HandshakeConversionPattern<
1107  hw::StructCreateOp>::HandshakeConversionPattern;
1108  void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1109  hw::HWModulePortAccessor &ports) const override {
1110  auto unwrappedIO = unwrapIO(s, bb, ports);
1111  auto structType = op.getResult().getType();
1112  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1113  return s.pack(inputs, structType);
1114  });
1115  };
1116 };
1117 
1118 class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1119 public:
1120  using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1121  void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1122  hw::HWModulePortAccessor &ports) const override {
1123  auto unwrappedIO = unwrapIO(s, bb, ports);
1124  buildUnitRateForkLogic(s, bb, unwrappedIO,
1125  [&](Value input) { return s.unpack(input); });
1126  };
1127 };
1128 
1129 class ConditionalBranchConversionPattern
1130  : public HandshakeConversionPattern<ConditionalBranchOp> {
1131 public:
1132  using HandshakeConversionPattern<
1133  ConditionalBranchOp>::HandshakeConversionPattern;
1134  void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1135  hw::HWModulePortAccessor &ports) const override {
1136  auto unwrappedIO = unwrapIO(s, bb, ports);
1137  auto cond = unwrappedIO.inputs[0];
1138  auto arg = unwrappedIO.inputs[1];
1139  auto trueRes = unwrappedIO.outputs[0];
1140  auto falseRes = unwrappedIO.outputs[1];
1141 
1142  auto condArgValid = s.bAnd({cond.valid, arg.valid});
1143 
1144  // Connect valid signal of both results.
1145  trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1146  falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1147 
1148  // Connecte data signals of both results.
1149  trueRes.data->setValue(arg.data);
1150  falseRes.data->setValue(arg.data);
1151 
1152  // Connect ready signal of input and condition.
1153  auto selectedResultReady =
1154  s.mux(cond.data, {falseRes.ready, trueRes.ready});
1155  auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1156  arg.ready->setValue(condArgReady);
1157  cond.ready->setValue(condArgReady);
1158  };
1159 };
1160 
1161 template <typename TIn, bool signExtend>
1162 class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1163 public:
1164  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1165  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1166  hw::HWModulePortAccessor &ports) const override {
1167  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1168  this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1169  };
1170 };
1171 
1172 class ComparisonConversionPattern
1173  : public HandshakeConversionPattern<arith::CmpIOp> {
1174 public:
1175  using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1176  void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1177  hw::HWModulePortAccessor &ports) const override {
1178  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1179  auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1180  return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1181  return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1182  inputs[1]);
1183  });
1184  };
1185 
1186  switch (op.getPredicate()) {
1187  case arith::CmpIPredicate::eq:
1188  return buildCompareLogic(comb::ICmpPredicate::eq);
1189  case arith::CmpIPredicate::ne:
1190  return buildCompareLogic(comb::ICmpPredicate::ne);
1191  case arith::CmpIPredicate::slt:
1192  return buildCompareLogic(comb::ICmpPredicate::slt);
1193  case arith::CmpIPredicate::ult:
1194  return buildCompareLogic(comb::ICmpPredicate::ult);
1195  case arith::CmpIPredicate::sle:
1196  return buildCompareLogic(comb::ICmpPredicate::sle);
1197  case arith::CmpIPredicate::ule:
1198  return buildCompareLogic(comb::ICmpPredicate::ule);
1199  case arith::CmpIPredicate::sgt:
1200  return buildCompareLogic(comb::ICmpPredicate::sgt);
1201  case arith::CmpIPredicate::ugt:
1202  return buildCompareLogic(comb::ICmpPredicate::ugt);
1203  case arith::CmpIPredicate::sge:
1204  return buildCompareLogic(comb::ICmpPredicate::sge);
1205  case arith::CmpIPredicate::uge:
1206  return buildCompareLogic(comb::ICmpPredicate::uge);
1207  }
1208  assert(false && "invalid CmpIOp");
1209  };
1210 };
1211 
1212 class TruncateConversionPattern
1213  : public HandshakeConversionPattern<arith::TruncIOp> {
1214 public:
1215  using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1216  void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1217  hw::HWModulePortAccessor &ports) const override {
1218  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1219  unsigned targetBits =
1220  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1221  buildTruncateLogic(s, unwrappedIO, targetBits);
1222  };
1223 };
1224 
1225 class ControlMergeConversionPattern
1226  : public HandshakeConversionPattern<ControlMergeOp> {
1227 public:
1228  using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1229  void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1230  hw::HWModulePortAccessor &ports) const override {
1231  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1232  auto resData = unwrappedIO.outputs[0];
1233  auto resIndex = unwrappedIO.outputs[1];
1234 
1235  // Define some common types and values that will be used.
1236  unsigned numInputs = unwrappedIO.inputs.size();
1237  auto indexType = s.b.getIntegerType(numInputs);
1238  Value noWinner = s.constant(numInputs, 0);
1239  Value c0I1 = s.constant(1, 0);
1240 
1241  // Declare register for storing arbitration winner.
1242  auto won = bb.get(indexType);
1243  Value wonReg = s.reg("won_reg", won, noWinner);
1244 
1245  // Declare wire for arbitration winner.
1246  auto win = bb.get(indexType);
1247 
1248  // Declare wire for whether the circuit just fired and emitted both
1249  // outputs.
1250  auto fired = bb.get(s.b.getI1Type());
1251 
1252  // Declare registers for storing if each output has been emitted.
1253  auto resultEmitted = bb.get(s.b.getI1Type());
1254  Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1255  auto indexEmitted = bb.get(s.b.getI1Type());
1256  Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1257 
1258  // Declare wires for if each output is done.
1259  auto resultDone = bb.get(s.b.getI1Type());
1260  auto indexDone = bb.get(s.b.getI1Type());
1261 
1262  // Create predicates to assert if the win wire or won register hold a
1263  // valid index.
1264  auto hasWinnerCondition = s.rOr({win});
1265  auto hadWinnerCondition = s.rOr({wonReg});
1266 
1267  // Create an arbiter based on a simple priority-encoding scheme to assign
1268  // an index to the win wire. If the won register is set, just use that. In
1269  // the case that won is not set and no input is valid, set a sentinel
1270  // value to indicate no winner was chosen. The constant values are
1271  // remembered in a map so they can be re-used later to assign the arg
1272  // ready outputs.
1273  DenseMap<size_t, Value> argIndexValues;
1274  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1275  noWinner, argIndexValues);
1276  priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1277  win.setValue(priorityArb);
1278 
1279  // Create the logic to assign the result and index outputs. The result
1280  // valid output will always be assigned, and if isControl is not set, the
1281  // result data output will also be assigned. The index valid and data
1282  // outputs will always be assigned. The win wire from the arbiter is used
1283  // to index into a tree of muxes to select the chosen input's signal(s),
1284  // and is fed directly to the index output. Both the result and index
1285  // valid outputs are gated on the win wire being set to something other
1286  // than the sentinel value.
1287  auto resultNotEmitted = s.bNot(resultEmittedReg);
1288  auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1289  resData.valid->setValue(resultValid);
1290  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1291 
1292  auto indexNotEmitted = s.bNot(indexEmittedReg);
1293  auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1294  resIndex.valid->setValue(indexValid);
1295 
1296  // Use the one-hot win wire to select the index to output in the index
1297  // data.
1298  SmallVector<Value, 8> indexOutputs;
1299  for (size_t i = 0; i < numInputs; ++i)
1300  indexOutputs.push_back(s.constant(64, i));
1301 
1302  auto indexOutput = s.ohMux(win, indexOutputs);
1303  resIndex.data->setValue(indexOutput);
1304 
1305  // Create the logic to set the won register. If the fired wire is
1306  // asserted, we have finished this round and can and reset the register to
1307  // the sentinel value that indicates there is no winner. Otherwise, we
1308  // need to hold the value of the win register until we can fire.
1309  won.setValue(s.mux(fired, {win, noWinner}));
1310 
1311  // Create the logic to set the done wires for the result and index. For
1312  // both outputs, the done wire is asserted when the output is valid and
1313  // ready, or the emitted register for that output is set.
1314  auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1315  resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1316 
1317  auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1318  indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1319 
1320  // Create the logic to set the fired wire. It is asserted when both result
1321  // and index are done.
1322  fired.setValue(s.bAnd({resultDone, indexDone}));
1323 
1324  // Create the logic to assign the emitted registers. If the fired wire is
1325  // asserted, we have finished this round and can reset the registers to 0.
1326  // Otherwise, we need to hold the values of the done registers until we
1327  // can fire.
1328  resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1329  indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1330 
1331  // Create the logic to assign the arg ready outputs. The logic is
1332  // identical for each arg. If the fired wire is asserted, and the win wire
1333  // holds an arg's index, that arg is ready.
1334  auto winnerOrDefault = s.mux(fired, {noWinner, win});
1335  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1336  auto &indexValue = argIndexValues[i];
1337  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1338  }
1339  };
1340 };
1341 
1342 class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1343 public:
1344  using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1345  void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1346  hw::HWModulePortAccessor &ports) const override {
1347  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1348  auto resData = unwrappedIO.outputs[0];
1349 
1350  // Define some common types and values that will be used.
1351  unsigned numInputs = unwrappedIO.inputs.size();
1352  auto indexType = s.b.getIntegerType(numInputs);
1353  Value noWinner = s.constant(numInputs, 0);
1354 
1355  // Declare wire for arbitration winner.
1356  auto win = bb.get(indexType);
1357 
1358  // Create predicates to assert if the win wire holds a valid index.
1359  auto hasWinnerCondition = s.rOr(win);
1360 
1361  // Create an arbiter based on a simple priority-encoding scheme to assign an
1362  // index to the win wire. In the case that no input is valid, set a sentinel
1363  // value to indicate no winner was chosen. The constant values are
1364  // remembered in a map so they can be re-used later to assign the arg ready
1365  // outputs.
1366  DenseMap<size_t, Value> argIndexValues;
1367  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1368  noWinner, argIndexValues);
1369  win.setValue(priorityArb);
1370 
1371  // Create the logic to assign the result outputs. The result valid and data
1372  // outputs will always be assigned. The win wire from the arbiter is used to
1373  // index into a tree of muxes to select the chosen input's signal(s). The
1374  // result outputs are gated on the win wire being non-zero.
1375 
1376  resData.valid->setValue(hasWinnerCondition);
1377  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1378 
1379  // Create the logic to set the done wires for the result. The done wire is
1380  // asserted when the output is valid and ready, or the emitted register is
1381  // set.
1382  auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1383 
1384  // Create the logic to assign the arg ready outputs. The logic is
1385  // identical for each arg. If the fired wire is asserted, and the win wire
1386  // holds an arg's index, that arg is ready.
1387  auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1388  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1389  auto &indexValue = argIndexValues[i];
1390  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1391  }
1392  };
1393 };
1394 
1395 class LoadConversionPattern
1396  : public HandshakeConversionPattern<handshake::LoadOp> {
1397 public:
1398  using HandshakeConversionPattern<
1399  handshake::LoadOp>::HandshakeConversionPattern;
1400  void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1401  hw::HWModulePortAccessor &ports) const override {
1402  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1403  auto addrFromUser = unwrappedIO.inputs[0];
1404  auto dataFromMem = unwrappedIO.inputs[1];
1405  auto controlIn = unwrappedIO.inputs[2];
1406  auto dataToUser = unwrappedIO.outputs[0];
1407  auto addrToMem = unwrappedIO.outputs[1];
1408 
1409  addrToMem.data->setValue(addrFromUser.data);
1410  dataToUser.data->setValue(dataFromMem.data);
1411 
1412  // The valid/ready logic between user address/control to memoryAddr is
1413  // join logic.
1414  buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1415 
1416  // The valid/ready logic between memoryData and outputData is a direct
1417  // connection.
1418  dataToUser.valid->setValue(dataFromMem.valid);
1419  dataFromMem.ready->setValue(dataToUser.ready);
1420  };
1421 };
1422 
1423 class StoreConversionPattern
1424  : public HandshakeConversionPattern<handshake::StoreOp> {
1425 public:
1426  using HandshakeConversionPattern<
1427  handshake::StoreOp>::HandshakeConversionPattern;
1428  void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1429  hw::HWModulePortAccessor &ports) const override {
1430  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1431  auto addrFromUser = unwrappedIO.inputs[0];
1432  auto dataFromUser = unwrappedIO.inputs[1];
1433  auto controlIn = unwrappedIO.inputs[2];
1434  auto dataToMem = unwrappedIO.outputs[0];
1435  auto addrToMem = unwrappedIO.outputs[1];
1436 
1437  // Create a gate that will be asserted when all outputs are ready.
1438  auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1439 
1440  // Build the standard join logic from the inputs to the inputsValid and
1441  // outputsReady signals.
1442  HandshakeWire joinWire(bb, s.b.getNoneType());
1443  joinWire.ready->setValue(outputsReady);
1444  OutputHandshake joinOutput = joinWire.getAsOutput();
1445  buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1446 
1447  // Output address and data signals are connected directly.
1448  addrToMem.data->setValue(addrFromUser.data);
1449  dataToMem.data->setValue(dataFromUser.data);
1450 
1451  // Output valid signals are connected from the inputsValid wire.
1452  addrToMem.valid->setValue(*joinWire.valid);
1453  dataToMem.valid->setValue(*joinWire.valid);
1454  };
1455 };
1456 
/// Lowers handshake.memory to a seq::HLMemOp, with one read port per load and
/// one write port per store.
///
/// Port layout, as indexed in buildModule:
///   inputs:  [st0.data, st0.addr, st1.data, st1.addr, ..., ld0.addr, ...]
///   outputs: [ld0.data, ..., st0.done, ..., ld0.done, ...]
class MemoryConversionPattern
    : public HandshakeConversionPattern<handshake::MemoryOp> {
public:
  using HandshakeConversionPattern<
      handshake::MemoryOp>::HandshakeConversionPattern;
  void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
                   hw::HWModulePortAccessor &ports) const override {
    auto loc = op.getLoc();

    // Gather up the load and store ports. LoadPort/StorePort hold references
    // into `unwrappedIO`, so they are only valid while it is alive.
    auto unwrappedIO = this->unwrapIO(s, bb, ports);
    struct LoadPort {
      InputHandshake &addr;
      OutputHandshake &data;
      OutputHandshake &done;
    };
    struct StorePort {
      InputHandshake &addr;
      InputHandshake &data;
      OutputHandshake &done;
    };
    SmallVector<LoadPort, 4> loadPorts;
    SmallVector<StorePort, 4> storePorts;

    unsigned stCount = op.getStCount();
    unsigned ldCount = op.getLdCount();
    // Load i: its address input follows all 2*stCount store inputs. Its data
    // output is outputs[i], and its done output comes after the load-data and
    // store-done outputs.
    for (unsigned i = 0, e = ldCount; i != e; ++i) {
      LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
                       unwrappedIO.outputs[i],
                       unwrappedIO.outputs[ldCount + stCount + i]};
      loadPorts.push_back(port);
    }

    // Store i: its inputs come as a (data, addr) pair at 2*i and 2*i+1, and
    // its done output is outputs[ldCount + i].
    for (unsigned i = 0, e = stCount; i != e; ++i) {
      StorePort port = {unwrappedIO.inputs[i * 2 + 1],
                        unwrappedIO.inputs[i * 2],
                        unwrappedIO.outputs[ldCount + i]};
      storePorts.push_back(port);
    }

    // Used to drive the (i0) data wire of the control-only done channels.
    auto c0I0 = s.constant(0, 0);

    // Address width is ceil(log2) of the first memref dimension only.
    // NOTE(review): this presumably assumes a 1-D memref - confirm that
    // multi-dimensional memrefs are flattened or rejected earlier.
    auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
    auto hlmem = s.b.create<seq::HLMemOp>(
        loc, s.clk, s.rst, "_handshake_memory_" + std::to_string(op.getId()),
        op.getMemRefType().getShape(), op.getMemRefType().getElementType());

    // Create load ports. Each read port has latency 0 and is enabled by the
    // address valid; the load-done channel carries no data.
    for (auto &ld : loadPorts) {
      llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
      auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
                                                  addresses, ld.addr.valid,
                                                  /*latency=*/0);
      ld.data.data->setValue(readData);
      ld.done.data->setValue(c0I0);
      // Create control fork for the load address valid and ready signals.
      buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
    }

    // Create store ports...
    for (auto &st : storePorts) {
      // Create a register to buffer the valid path by 1 cycle, to match the
      // write latency of 1.
      auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
      auto writeValidBuffer =
          s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
      st.done.valid->setValue(writeValidBuffer);
      st.done.data->setValue(c0I0);

      // Create the logic for when both the buffered write valid signal and the
      // store complete ready signal are asserted.
      auto storeCompleted =
          s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");

      // Create a signal for when the write valid buffer is empty or the output
      // is ready.
      auto notWriteValidBuffer = s.bNot(writeValidBuffer);
      auto emptyOrComplete =
          s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");

      // Connect the gate to both the store address ready and store data ready
      st.addr.ready->setValue(emptyOrComplete);
      st.data.ready->setValue(emptyOrComplete);

      // Create a wire for when both the store address and data are valid.
      auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");

      // Create a mux that drives the buffer input. If the emptyOrComplete
      // signal is asserted, the mux selects the writeValid signal. Otherwise,
      // it selects the buffer output, keeping the output registered until the
      // emptyOrComplete signal is asserted.
      writeValidBufferMuxBE.setValue(
          s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));

      // Instantiate the write port operation - truncate address width to memory
      // width.
      // NOTE(review): the write enable is `writeValid`, not gated by
      // `emptyOrComplete`. While the buffer is held, a still-valid (and,
      // per the handshake protocol, unchanged) address/data pair is written
      // again - presumably benign; confirm.
      llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
      s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
                                   st.data.data, writeValid,
                                   /*latency=*/1);
    }
  }
}; // class MemoryConversionPattern
1561 
1562 class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1563 public:
1564  using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1565  void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1566  hw::HWModulePortAccessor &ports) const override {
1567  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1568  // A sink is always ready to accept a new value.
1569  unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1570  };
1571 };
1572 
1573 class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1574 public:
1575  using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1576  void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1577  hw::HWModulePortAccessor &ports) const override {
1578  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1579  // A source always provides a new (i0-typed) value.
1580  unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1581  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1582  };
1583 };
1584 
1585 class ConstantConversionPattern
1586  : public HandshakeConversionPattern<handshake::ConstantOp> {
1587 public:
1588  using HandshakeConversionPattern<
1589  handshake::ConstantOp>::HandshakeConversionPattern;
1590  void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1591  hw::HWModulePortAccessor &ports) const override {
1592  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1593  unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1594  unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1595  auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1596  unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1597  };
1598 };
1599 
/// Lowers `handshake.buffer` into a chain of `op.getNumSlots()` sequential
/// buffer stages (see SeqBufferStage), starting from the module's input
/// handshake and ending in its output handshake. When the op carries init
/// values, stage i (for i < number of init values) gets a valid register
/// seeded with 1 and a data register seeded with the i-th init value (via
/// the third argument of RTLBuilder::reg, presumably the reset value --
/// confirm against RTLBuilder).
class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
public:
  using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
  /// Unwraps the module I/O, builds the stage chain and connects the last
  /// stage to the output handshake.
  void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
                   hw::HWModulePortAccessor &ports) const override {
    auto unwrappedIO = this->unwrapIO(s, bb, ports);
    auto input = unwrappedIO.inputs[0];
    auto output = unwrappedIO.outputs[0];
    InputHandshake lastStage;
    SmallVector<int64_t> initValues;

    // For now, always build seq buffers.
    if (op.getInitValues())
      initValues = op.getInitValueArray();

    lastStage =
        buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
                            op.getNumSlots(), input, output, initValues);

    // Connect the last stage to the output handshake.
    output.data->setValue(lastStage.data);
    output.valid->setValue(lastStage.valid);
    lastStage.ready->setValue(output.ready);
  };

  /// One stage of the sequential buffer.
  ///
  /// A stage consists of two parts:
  ///  - a data/valid register pair fed by the predecessor
  ///    (buildDataBufferLogic);
  ///  - a second "control" valid/data register pair (buildControlBufferLogic)
  ///    that takes over the token held in the data registers when the
  ///    successor is not ready.
  ///
  /// The ready signal returned to the predecessor is computed from register
  /// outputs only (!validReg | !readyReg). It therefore does not depend
  /// combinationally on the successor's ready signal.
  ///
  /// All logic is built in the constructor; the resulting handshake towards
  /// the successor is obtained with getOutput().
  struct SeqBufferStage {
    /// \param dataType  type of the buffered data.
    /// \param preStage  handshake of the previous stage (or the buffer input);
    ///                  its `ready` backedge is driven by this stage.
    /// \param index     stage number, used to name the registers.
    /// \param initValue if set, the stage starts out holding this value.
    SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
                   RTLBuilder &s, size_t index,
                   std::optional<int64_t> initValue)
        : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {

      // Todo: Change when i0 support is added.
      c0s = createZeroDataConst(s, s.loc, dataType);
      // Backedge for the successor's ready; it is driven by whoever consumes
      // getOutput() (the next stage, or buildModule for the last stage).
      currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));

      // The valid register starts at 1 iff this stage has an initial value.
      auto hasInitValue = s.constant(1, initValue.has_value());
      auto validBE = bb.get(s.b.getI1Type());
      auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
      // Downstream readiness as seen by the data registers; driven by
      // buildControlBufferLogic (as !readyReg).
      auto readyBE = bb.get(s.b.getI1Type());

      // Seed value for the data register: the init value if given, else zero.
      Value initValueCs = c0s;
      if (initValue.has_value())
        initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);

      // This could/should be revised but needs a larger rethinking to avoid
      // introducing new bugs.
      Value dataReg =
          buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
      buildControlBufferLogic(validReg, readyBE, dataReg);
    }

    /// Returns a register name of the form "<name><index>_reg".
    StringAttr getRegName(StringRef name) {
      return s.b.getStringAttr(name + std::to_string(index) + "_reg");
    }

    /// Builds the control registers and drives `currentStage` and `readyBE`.
    ///
    ///  - readyReg (next value): 0 when both readyReg and the successor's
    ///    ready are set; validReg when neither is set; otherwise unchanged.
    ///  - ctrl_data follows the same rule, capturing dataReg or clearing to 0.
    ///  - currentStage.valid = readyReg ? 1 : validReg.
    ///  - currentStage.data  = readyReg ? ctrl_data : dataReg.
    ///  - readyBE = !readyReg.
    void buildControlBufferLogic(Value validReg, Backedge &readyBE,
                                 Value dataReg) {
      auto c0I1 = s.constant(1, 0);
      auto readyRegWire = bb.get(s.b.getI1Type());
      auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);

      // Create the logic to drive the current stage valid and potentially
      // data.
      currentStage.valid = s.mux(readyReg, {validReg, readyReg},
                                 "controlValid" + std::to_string(index));

      // Create the logic to drive the current stage ready.
      auto notReadyReg = s.bNot(readyReg);
      readyBE.setValue(notReadyReg);

      // neitherReady: successor stalls while the control register is empty,
      // so the data-register token (if valid) is captured.
      auto succNotReady = s.bNot(*currentStage.ready);
      auto neitherReady = s.bAnd({succNotReady, notReadyReg});
      auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
      // bothReady: the control register's token is consumed this cycle.
      auto bothReady = s.bAnd({*currentStage.ready, readyReg});

      // Create a mux for emptying the register when both are ready.
      auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
      readyRegWire.setValue(resetSignal);

      // Add same logic for the data path if necessary.
      auto ctrlDataRegBE = bb.get(dataType);
      auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
      auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
      currentStage.data = dataResult;

      auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
      auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
      ctrlDataRegBE.setValue(dataResetSignal);
    }

    /// Builds the data/valid registers fed by `preStage` and drives
    /// `preStage.ready` and `validBE`.
    ///
    /// The registers load the predecessor's valid/data whenever the stage is
    /// empty (!validReg) or downstream is ready (readyBE). Otherwise they hold.
    /// \returns the data register value.
    Value buildDataBufferLogic(Value validReg, Value initValue,
                               Backedge &validBE, Backedge &readyBE) {
      // Create a signal for when the valid register is empty or the successor
      // is ready to accept new token.
      auto notValidReg = s.bNot(validReg);
      auto emptyOrReady = s.bOr({notValidReg, readyBE});
      preStage.ready->setValue(emptyOrReady);

      // Create a mux that drives the register input. If the emptyOrReady
      // signal is asserted, the mux selects the predValid signal. Otherwise,
      // it selects the register output, keeping the output registered
      // unchanged.
      auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});

      // Now we can drive the valid register.
      validBE.setValue(validRegMux);

      // Create a mux that drives the data register.
      auto dataRegBE = bb.get(dataType);
      auto dataReg =
          s.reg(getRegName("data"),
                s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
      dataRegBE.setValue(dataReg);
      return dataReg;
    }

    /// Handshake produced by this stage, to be consumed by the next stage.
    InputHandshake getOutput() { return currentStage; }

    Type dataType;
    InputHandshake &preStage;
    InputHandshake currentStage;
    RTLBuilder &s;
    BackedgeBuilder &bb;
    size_t index;

    // A zero-valued constant of equal type as the data type of this buffer.
    Value c0s;
  };

  /// Chains `size` SeqBufferStages, starting from `input`, and returns the
  /// handshake of the last stage. Stage i receives initValues[i] when
  /// i < initValues.size(). `output` is not used here; the caller connects
  /// it to the returned stage.
  InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
                                     Type dataType, unsigned size,
                                     InputHandshake &input,
                                     OutputHandshake &output,
                                     llvm::ArrayRef<int64_t> initValues) const {
    // Prime the buffer building logic with an initial stage, which just
    // wraps the input handshake.
    InputHandshake currentStage = input;

    for (unsigned i = 0; i < size; ++i) {
      bool isInitialized = i < initValues.size();
      auto initValue =
          isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
      // The stage only uses `currentStage` during construction (it drives
      // currentStage.ready through the shared backedge), so overwriting it
      // with the new stage's output afterwards is safe.
      currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
                         .getOutput();
    }

    return currentStage;
  };
};
1749 
1750 class IndexCastConversionPattern
1751  : public HandshakeConversionPattern<arith::IndexCastOp> {
1752 public:
1753  using HandshakeConversionPattern<
1754  arith::IndexCastOp>::HandshakeConversionPattern;
1755  void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1756  hw::HWModulePortAccessor &ports) const override {
1757  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1758  unsigned sourceBits =
1759  toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1760  unsigned targetBits =
1761  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1762  if (targetBits < sourceBits)
1763  buildTruncateLogic(s, unwrappedIO, targetBits);
1764  else
1765  buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1766  };
1767 };
1768 
1769 template <typename T>
1770 class ExtModuleConversionPattern : public OpConversionPattern<T> {
1771 public:
1772  ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1773  MLIRContext *context, OpBuilder &submoduleBuilder,
1774  HandshakeLoweringState &ls)
1775  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1776  submoduleBuilder(submoduleBuilder), ls(ls) {}
1777  using OpAdaptor = typename T::Adaptor;
1778 
1779  LogicalResult
1780  matchAndRewrite(T op, OpAdaptor adaptor,
1781  ConversionPatternRewriter &rewriter) const override {
1782 
1783  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1784  if (!implModule) {
1785  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1786  implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1787  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1788  portInfo);
1789  }
1790 
1791  llvm::SmallVector<Value> operands = adaptor.getOperands();
1792  addSequentialIOOperandsIfNeeded(op, operands);
1793  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1794  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1795  return success();
1796  }
1797 
1798 private:
1799  OpBuilder &submoduleBuilder;
1800  HandshakeLoweringState &ls;
1801 };
1802 
1803 class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1804 public:
1805  using OpConversionPattern::OpConversionPattern;
1806 
1807  LogicalResult
1808  matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1809  ConversionPatternRewriter &rewriter) const override {
1810  ModulePortInfo ports =
1811  getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1812 
1813  HWModuleLike hwModule;
1814  if (op.isExternal()) {
1815  hwModule = rewriter.create<hw::HWModuleExternOp>(
1816  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1817  } else {
1818  auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1819  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1820  auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1821  rewriter.inlineBlockBefore(&op.getBody().front(),
1822  hwModuleOp.getBodyBlock()->getTerminator(),
1823  args);
1824  hwModule = hwModuleOp;
1825  }
1826 
1827  // Was any predeclaration associated with this func? If so, replace uses
1828  // with the newly created module and erase the predeclaration.
1829  if (auto predecl =
1830  op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1831  auto *parentOp = op->getParentOp();
1832  auto *predeclModule =
1833  SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1834  if (predeclModule) {
1835  if (failed(SymbolTable::replaceAllSymbolUses(
1836  predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1837  return failure();
1838  rewriter.eraseOp(predeclModule);
1839  }
1840  }
1841 
1842  rewriter.eraseOp(op);
1843  return success();
1844  }
1845 };
1846 
1847 } // namespace
1848 
1849 //===----------------------------------------------------------------------===//
1850 // HW Top-module Related Functions
1851 //===----------------------------------------------------------------------===//
1852 
1853 static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1854  ConversionTarget &target,
1855  handshake::FuncOp op,
1856  OpBuilder &moduleBuilder) {
1857 
1858  std::map<std::string, unsigned> instanceNameCntr;
1859  NameUniquer instanceUniquer = [&](Operation *op) {
1860  std::string instName = getCallName(op);
1861  if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1862  // We use a special naming convention for operations which have a
1863  // 'handshake_id' attribute.
1864  instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1865  } else {
1866  // Fallback to just prefixing with an integer.
1867  instName += std::to_string(instanceNameCntr[instName]++);
1868  }
1869  return instName;
1870  };
1871 
1872  auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1873  instanceUniquer};
1874  RewritePatternSet patterns(op.getContext());
1875  patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1876  op.getContext());
1877  patterns.insert<JoinConversionPattern, ForkConversionPattern,
1878  SyncConversionPattern>(typeConverter, op.getContext(),
1879  moduleBuilder, ls);
1880 
1881  patterns.insert<
1882  // Comb operations.
1883  UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1884  UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1885  UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1886  UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1887  UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1888  UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1889  UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1890  UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1891  UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1892  UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1893  UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1894  UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1895  UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1896  UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1897  // HW operations.
1898  StructCreateConversionPattern,
1899  // Handshake operations.
1900  ConditionalBranchConversionPattern, MuxConversionPattern,
1901  PackConversionPattern, UnpackConversionPattern,
1902  ComparisonConversionPattern, BufferConversionPattern,
1903  SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1904  MergeConversionPattern, ControlMergeConversionPattern,
1905  LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1906  InstanceConversionPattern,
1907  // Arith operations.
1908  ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1909  ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1910  TruncateConversionPattern, IndexCastConversionPattern>(
1911  typeConverter, op.getContext(), moduleBuilder, ls);
1912 
1913  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1914  return op->emitOpError() << "error during conversion";
1915  return success();
1916 }
1917 
1918 namespace {
1919 class HandshakeToHWPass
1920  : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1921 public:
1922  void runOnOperation() override {
1923  mlir::ModuleOp mod = getOperation();
1924 
1925  // Lowering to HW requires that every value is used exactly once. Check
1926  // whether this precondition is met, and if not, exit.
1927  for (auto f : mod.getOps<handshake::FuncOp>()) {
1928  if (failed(verifyAllValuesHasOneUse(f))) {
1929  f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1930  "are used exactly once. Remember to run the "
1931  "fork/sink materialization pass before HW lowering.";
1932  signalPassFailure();
1933  return;
1934  }
1935  }
1936 
1937  // Resolve the instance graph to get a top-level module.
1938  std::string topLevel;
1940  SmallVector<std::string> sortedFuncs;
1941  if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1942  signalPassFailure();
1943  return;
1944  }
1945 
1946  ESITypeConverter typeConverter;
1947  ConversionTarget target(getContext());
1948  // All top-level logic of a handshake module will be the interconnectivity
1949  // between instantiated modules.
1950  target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1951  hw::InstanceOp>();
1952  target
1953  .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1954 
1955  // Convert the handshake.func operations in post-order wrt. the instance
1956  // graph. This ensures that any referenced submodules (through
1957  // handshake.instance) has already been lowered, and their HW module
1958  // equivalents are available.
1959  OpBuilder submoduleBuilder(mod.getContext());
1960  submoduleBuilder.setInsertionPointToStart(mod.getBody());
1961  for (auto &funcName : llvm::reverse(sortedFuncs)) {
1962  auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
1963  assert(funcOp && "handshake.func not found in module!");
1964  if (failed(
1965  convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
1966  signalPassFailure();
1967  return;
1968  }
1969  }
1970 
1971  // Second stage: Convert any handshake.extmemory operations and the
1972  // top-level I/O associated with these.
1973  for (auto hwModule : mod.getOps<hw::HWModuleOp>())
1974  if (failed(convertExtMemoryOps(hwModule)))
1975  return signalPassFailure();
1976  }
1977 };
1978 } // end anonymous namespace
1979 
1980 std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
1981  return std::make_unique<HandshakeToHWPass>();
1982 }
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition: CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition: CalyxOps.cpp:120
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:539
static Type tupleToStruct(TupleType tuple)
Definition: DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition: DCToHW.cpp:45
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
Definition: FIRRTLOps.cpp:1004
int32_t width
Definition: FIRRTL.cpp:36
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition: OM.cpp:96
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Channels are the basic communication primitives.
Definition: Types.h:63
const Type * getInner() const
Definition: Types.h:66
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition: seq.py:137
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition: CombOps.cpp:25
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
Definition: PassHelpers.cpp:38
static constexpr const char * kPredeclarationAttr
Definition: HandshakeToHW.h:40
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Type toValidType(Type t)
esi::ChannelType esiWrapper(mlir::Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()