CIRCT  19.0.0git
HandshakeToHW.cpp
Go to the documentation of this file.
1 //===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Handshake to HW Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Pass/PassManager.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/MathExtras.h"
33 #include <optional>
34 
35 namespace circt {
36 #define GEN_PASS_DEF_HANDSHAKETOHW
37 #include "circt/Conversion/Passes.h.inc"
38 } // namespace circt
39 
40 using namespace mlir;
41 using namespace circt;
42 using namespace circt::handshake;
43 using namespace circt::hw;
44 
45 using NameUniquer = std::function<std::string(Operation *)>;
46 
47 namespace {
48 
49 static Type tupleToStruct(TypeRange types) {
50  return toValidType(mlir::TupleType::get(types[0].getContext(), types));
51 }
52 
53 // Shared state used by various functions; captured in a struct to reduce the
54 // number of arguments that we have to pass around.
55 struct HandshakeLoweringState {
56  ModuleOp parentModule;
57  NameUniquer nameUniquer;
58 };
59 
60 // A type converter is needed to perform the in-flight materialization of "raw"
61 // (non-ESI channel) types to their ESI channel correspondents. This comes into
62 // effect when backedges exist in the input IR.
63 class ESITypeConverter : public TypeConverter {
64 public:
65  ESITypeConverter() {
66  addConversion([](Type type) -> Type { return esiWrapper(type); });
67 
68  addTargetMaterialization(
69  [&](mlir::OpBuilder &builder, mlir::Type resultType,
70  mlir::ValueRange inputs,
71  mlir::Location loc) -> std::optional<mlir::Value> {
72  if (inputs.size() != 1)
73  return std::nullopt;
74  return inputs[0];
75  });
76 
77  addSourceMaterialization(
78  [&](mlir::OpBuilder &builder, mlir::Type resultType,
79  mlir::ValueRange inputs,
80  mlir::Location loc) -> std::optional<mlir::Value> {
81  if (inputs.size() != 1)
82  return std::nullopt;
83  return inputs[0];
84  });
85  }
86 };
87 
88 } // namespace
89 
90 /// Returns a submodule name resulting from an operation, without discriminating
91 /// type information.
92 static std::string getBareSubModuleName(Operation *oldOp) {
93  // The dialect name is separated from the operation name by '.', which is not
94  // valid in SystemVerilog module names. In case this name is used in
95  // SystemVerilog output, replace '.' with '_'.
96  std::string subModuleName = oldOp->getName().getStringRef().str();
97  std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
98  return subModuleName;
99 }
100 
101 static std::string getCallName(Operation *op) {
102  auto callOp = dyn_cast<handshake::InstanceOp>(op);
103  return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
104 }
105 
106 /// Extracts the type of the data-carrying type of opType. If opType is an ESI
107 /// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
108 /// assume that opType itself is the data-carrying type.
109 static Type getOperandDataType(Value op) {
110  auto opType = op.getType();
111  if (auto channelType = dyn_cast<esi::ChannelType>(opType))
112  return channelType.getInner();
113  return opType;
114 }
115 
116 /// Filters NoneType's from the input.
117 static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
118  SmallVector<Type> filterRes;
119  llvm::copy_if(input, std::back_inserter(filterRes),
120  [](Type type) { return !isa<NoneType>(type); });
121  return filterRes;
122 }
123 
124 /// Returns a set of types which may uniquely identify the provided op. Return
125 /// value is <inputTypes, outputTypes>.
126 using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
128  return TypeSwitch<Operation *, DiscriminatingTypes>(op)
129  .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
130  return DiscriminatingTypes{{},
131  {memOp.getMemRefType().getElementType()}};
132  })
133  .Default([&](auto) {
134  // By default, all in- and output types which is not a control type
135  // (NoneType) are discriminating types.
136  std::vector<Type> inTypes, outTypes;
137  llvm::transform(op->getOperands(), std::back_inserter(inTypes),
139  llvm::transform(op->getResults(), std::back_inserter(outTypes),
141  return DiscriminatingTypes{filterNoneTypes(inTypes),
142  filterNoneTypes(outTypes)};
143  });
144 }
145 
146 /// Get type name. Currently we only support integer or index types.
147 /// The emitted type aligns with the getFIRRTLType() method. Thus all integers
148 /// other than signed integers will be emitted as unsigned.
149 // NOLINTNEXTLINE(misc-no-recursion)
150 static std::string getTypeName(Location loc, Type type) {
151  std::string typeName;
152  // Builtin types
153  if (type.isIntOrIndex()) {
154  if (auto indexType = dyn_cast<IndexType>(type))
155  typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
156  else if (type.isSignedInteger())
157  typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
158  else
159  typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
160  } else if (auto tupleType = dyn_cast<TupleType>(type)) {
161  typeName += "_tuple";
162  for (auto elementType : tupleType.getTypes())
163  typeName += getTypeName(loc, elementType);
164  } else if (auto structType = dyn_cast<hw::StructType>(type)) {
165  typeName += "_struct";
166  for (auto element : structType.getElements())
167  typeName += "_" + element.name.str() + getTypeName(loc, element.type);
168  } else
169  emitError(loc) << "unsupported data type '" << type << "'";
170 
171  return typeName;
172 }
173 
174 /// Construct a name for creating HW sub-module.
175 static std::string getSubModuleName(Operation *oldOp) {
176  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
177  return instanceOp.getModule().str();
178 
179  std::string subModuleName = getBareSubModuleName(oldOp);
180 
181  // Add value of the constant operation.
182  if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
183  if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
184  auto intType = intAttr.getType();
185 
186  if (intType.isSignedInteger())
187  subModuleName += "_c" + std::to_string(intAttr.getSInt());
188  else if (intType.isUnsignedInteger())
189  subModuleName += "_c" + std::to_string(intAttr.getUInt());
190  else
191  subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
192  } else
193  oldOp->emitError("unsupported constant type");
194  }
195 
196  // Add discriminating in- and output types.
197  auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
198  if (!inTypes.empty())
199  subModuleName += "_in";
200  for (auto inType : inTypes)
201  subModuleName += getTypeName(oldOp->getLoc(), inType);
202 
203  if (!outTypes.empty())
204  subModuleName += "_out";
205  for (auto outType : outTypes)
206  subModuleName += getTypeName(oldOp->getLoc(), outType);
207 
208  // Add memory ID.
209  if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
210  subModuleName += "_id" + std::to_string(memOp.getId());
211 
212  // Add compare kind.
213  if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
214  subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
215 
216  // Add buffer information.
217  if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
218  subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
219  if (bufferOp.isSequential())
220  subModuleName += "_seq";
221  else
222  subModuleName += "_fifo";
223 
224  if (auto initValues = bufferOp.getInitValues()) {
225  subModuleName += "_init";
226  for (const Attribute e : *initValues) {
227  assert(isa<IntegerAttr>(e));
228  subModuleName +=
229  "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
230  }
231  }
232  }
233 
234  // Add control information.
235  if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
236  ctrlInterface && ctrlInterface.isControl()) {
237  // Add some additional discriminating info for non-typed operations.
238  subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
239  std::to_string(oldOp->getNumResults()) + "outs";
240  subModuleName += "_ctrl";
241  } else {
242  assert(
243  (!inTypes.empty() || !outTypes.empty()) &&
244  "Insufficient discriminating type info generated for the operation!");
245  }
246 
247  return subModuleName;
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // HW Sub-module Related Functions
252 //===----------------------------------------------------------------------===//
253 
254 /// Check whether a submodule with the same name has been created elsewhere in
255 /// the top level module. Return the matched module operation if true, otherwise
256 /// return nullptr.
257 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
258  StringRef modName) {
259  if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
260  return mod;
261  if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
262  return mod;
263  return {};
264 }
265 
266 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
267  Operation *oldOp) {
268  HWModuleLike targetModule;
269  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
270  targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
271  else
272  targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
273 
274  if (isa<handshake::InstanceOp>(oldOp))
275  assert(targetModule &&
276  "handshake.instance target modules should always have been lowered "
277  "before the modules that reference them!");
278  return targetModule;
279 }
280 
281 /// Returns a vector of PortInfo's which defines the HW interface of the
282 /// to-be-converted op.
283 static ModulePortInfo getPortInfoForOp(Operation *op) {
284  return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
285 }
286 
287 static llvm::SmallVector<hw::detail::FieldInfo>
288 portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
289  llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
290  for (auto port : portInfo)
291  fieldInfo.push_back({port.name, port.type});
292 
293  return fieldInfo;
294 }
295 
296 // Convert any handshake.extmemory operations and the top-level I/O
297 // associated with these.
298 static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
299  auto *ctx = mod.getContext();
300 
301  // Gather memref ports to be converted.
302  llvm::DenseMap<unsigned, Value> memrefPorts;
303  for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
304  auto channel = dyn_cast<esi::ChannelType>(arg.getType());
305  if (channel && isa<MemRefType>(channel.getInner()))
306  memrefPorts[i] = arg;
307  }
308 
309  if (memrefPorts.empty())
310  return success(); // nothing to do.
311 
312  OpBuilder b(mod);
313 
314  auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
315  ArrayRef<hw::PortInfo> info,
316  hw::ModulePort::Direction direction) {
317  auto type = hw::StructType::get(ctx, portToFieldInfo(info));
318  auto portInfo =
319  hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
320  return portInfo;
321  };
322 
323  for (auto [i, arg] : memrefPorts) {
324  // Insert ports into the module
325  auto memName = mod.getArgName(i);
326 
327  // Get the attached extmemory external module.
328  auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
329  auto extmemMod =
330  cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
331  extmemInstance, extmemInstance.getModuleNameAttr()));
332 
333  ModulePortInfo portInfo(extmemMod.getPortList());
334 
335  // The extmemory external module's interface is a direct wrapping of the
336  // original handshake.extmemory operation in- and output types. Remove the
337  // first input argument (the !esi.channel<memref> op) since that is what
338  // we're replacing with a materialized interface.
339  portInfo.eraseInput(0);
340 
341  // Add memory input - this is the output of the extmemory op.
342  SmallVector<PortInfo> outputs(portInfo.getOutputs());
343  auto inPortInfo =
344  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
346  mod.insertPorts({{i, inPortInfo}}, {});
347  auto newInPort = mod.getArgumentForInput(i);
348  // Replace the extmemory submodule outputs with the newly created inputs.
349  b.setInsertionPointToStart(mod.getBodyBlock());
350  auto newInPortExploded = b.create<hw::StructExplodeOp>(
351  arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
352  extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
353 
354  // Add memory output - this is the inputs of the extmemory op (without the
355  // first argument);
356  unsigned outArgI = mod.getNumOutputPorts();
357  SmallVector<PortInfo> inputs(portInfo.getInputs());
358  auto outPortInfo =
359  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
361 
362  auto memOutputArgs = extmemInstance.getOperands().drop_front();
363  b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
364  auto memOutputStruct = b.create<hw::StructCreateOp>(
365  arg.getLoc(), outPortInfo.type, memOutputArgs);
366  mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
367 
368  // Erase the extmemory submodule instace since the i/o has now been
369  // plumbed.
370  extmemMod.erase();
371  extmemInstance.erase();
372 
373  // Erase the original memref argument of the top-level i/o now that it's use
374  // has been removed.
375  mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
376  /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
377  }
378 
379  return success();
380 }
381 
382 namespace {
383 
384 // Input handshakes contain a resolved valid and (optional )data signal, and
385 // a to-be-assigned ready signal.
386 struct InputHandshake {
387  Value valid;
388  std::shared_ptr<Backedge> ready;
389  Value data;
390 };
391 
392 // Output handshakes contain a resolved ready, and to-be-assigned valid and
393 // (optional) data signals.
394 struct OutputHandshake {
395  std::shared_ptr<Backedge> valid;
396  Value ready;
397  std::shared_ptr<Backedge> data;
398 };
399 
400 /// A helper struct that acts like a wire. Can be used to interact with the
401 /// RTLBuilder when multiple built components should be connected.
402 struct HandshakeWire {
403  HandshakeWire(BackedgeBuilder &bb, Type dataType) {
404  MLIRContext *ctx = dataType.getContext();
405  auto i1Type = IntegerType::get(ctx, 1);
406  valid = std::make_shared<Backedge>(bb.get(i1Type));
407  ready = std::make_shared<Backedge>(bb.get(i1Type));
408  data = std::make_shared<Backedge>(bb.get(dataType));
409  }
410 
411  // Functions that allow to treat a wire like an input or output port.
412  // **Careful**: Such a port will not be updated when backedges are resolved.
413  InputHandshake getAsInput() { return {*valid, ready, *data}; }
414  OutputHandshake getAsOutput() { return {valid, *ready, data}; }
415 
416  std::shared_ptr<Backedge> valid;
417  std::shared_ptr<Backedge> ready;
418  std::shared_ptr<Backedge> data;
419 };
420 
421 template <typename T, typename TInner>
422 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
423  llvm::function_ref<T(TInner &)> extractor) {
424  llvm::SmallVector<T> result;
425  llvm::transform(container, std::back_inserter(result), extractor);
426  return result;
427 }
428 struct UnwrappedIO {
429  llvm::SmallVector<InputHandshake> inputs;
430  llvm::SmallVector<OutputHandshake> outputs;
431 
432  llvm::SmallVector<Value> getInputValids() {
433  return extractValues<Value, InputHandshake>(
434  inputs, [](auto &hs) { return hs.valid; });
435  }
436  llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
437  return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
438  inputs, [](auto &hs) { return hs.ready; });
439  }
440  llvm::SmallVector<Value> getInputDatas() {
441  return extractValues<Value, InputHandshake>(
442  inputs, [](auto &hs) { return hs.data; });
443  }
444  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
445  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
446  outputs, [](auto &hs) { return hs.valid; });
447  }
448  llvm::SmallVector<Value> getOutputReadys() {
449  return extractValues<Value, OutputHandshake>(
450  outputs, [](auto &hs) { return hs.ready; });
451  }
452  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
453  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
454  outputs, [](auto &hs) { return hs.data; });
455  }
456 };
457 
458 // A class containing a bunch of syntactic sugar to reduce builder function
459 // verbosity.
460 // @todo: should be moved to support.
461 struct RTLBuilder {
462  RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
463  Value clk = Value(), Value rst = Value())
464  : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
465 
466  Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
467  // Cannot use zero-width APInt's in DenseMap's, see
468  // https://github.com/llvm/llvm-project/issues/58013
469  bool isZeroWidth = apv.getBitWidth() == 0;
470  if (!isZeroWidth) {
471  auto it = constants.find(apv);
472  if (it != constants.end())
473  return it->second;
474  }
475 
476  auto cval = b.create<hw::ConstantOp>(loc, apv);
477  if (!isZeroWidth)
478  constants[apv] = cval;
479  return cval;
480  }
481 
482  Value constant(unsigned width, int64_t value,
483  std::optional<StringRef> name = {}) {
484  return constant(APInt(width, value));
485  }
486  std::pair<Value, Value> wrap(Value data, Value valid,
487  std::optional<StringRef> name = {}) {
488  auto wrapOp = b.create<esi::WrapValidReadyOp>(loc, data, valid);
489  return {wrapOp.getResult(0), wrapOp.getResult(1)};
490  }
491  std::pair<Value, Value> unwrap(Value channel, Value ready,
492  std::optional<StringRef> name = {}) {
493  auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
494  return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
495  }
496 
497  // Various syntactic sugar functions.
498  Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
499  Value rst = Value()) {
500  Value resolvedClk = clk ? clk : this->clk;
501  Value resolvedRst = rst ? rst : this->rst;
502  assert(resolvedClk &&
503  "No global clock provided to this RTLBuilder - a clock "
504  "signal must be provided to the reg(...) function.");
505  assert(resolvedRst &&
506  "No global reset provided to this RTLBuilder - a reset "
507  "signal must be provided to the reg(...) function.");
508 
509  return b.create<seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
510  name);
511  }
512 
513  Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
514  std::optional<StringRef> name = {}) {
515  return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
516  }
517 
518  Value buildNamedOp(llvm::function_ref<Value()> f,
519  std::optional<StringRef> name) {
520  Value v = f();
521  StringAttr nameAttr;
522  Operation *op = v.getDefiningOp();
523  if (name.has_value()) {
524  op->setAttr("sv.namehint", b.getStringAttr(*name));
525  nameAttr = b.getStringAttr(*name);
526  }
527  return v;
528  }
529 
530  // Bitwise 'and'.
531  Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
532  return buildNamedOp(
533  [&]() { return b.create<comb::AndOp>(loc, values, false); }, name);
534  }
535 
536  Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
537  return buildNamedOp(
538  [&]() { return b.create<comb::OrOp>(loc, values, false); }, name);
539  }
540 
541  // Bitwise 'not'.
542  Value bNot(Value value, std::optional<StringRef> name = {}) {
543  auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
544  std::string inferedName;
545  if (!name) {
546  // Try to create a name from the input value.
547  if (auto valueName =
548  value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
549  inferedName = ("not_" + valueName.getValue()).str();
550  name = inferedName;
551  }
552  }
553 
554  return buildNamedOp(
555  [&]() { return b.create<comb::XorOp>(loc, value, allOnes); }, name);
556 
557  return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
558  }
559 
560  Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
561  return buildNamedOp(
562  [&]() { return b.create<comb::ShlOp>(loc, value, shift); }, name);
563  }
564 
565  Value concat(ValueRange values, std::optional<StringRef> name = {}) {
566  return buildNamedOp([&]() { return b.create<comb::ConcatOp>(loc, values); },
567  name);
568  }
569 
570  // Packs a list of values into a hw.struct.
571  Value pack(ValueRange values, Type structType = Type(),
572  std::optional<StringRef> name = {}) {
573  if (!structType)
574  structType = tupleToStruct(values.getTypes());
575  return buildNamedOp(
576  [&]() { return b.create<hw::StructCreateOp>(loc, structType, values); },
577  name);
578  }
579 
580  // Unpacks a hw.struct into a list of values.
581  ValueRange unpack(Value value) {
582  auto structType = cast<hw::StructType>(value.getType());
583  llvm::SmallVector<Type> innerTypes;
584  structType.getInnerTypes(innerTypes);
585  return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
586  }
587 
588  llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
589  llvm::SmallVector<Value> bits;
590  for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
591  bits.push_back(b.create<comb::ExtractOp>(loc, v, i, /*bitWidth=*/1));
592  return bits;
593  }
594 
595  // OR-reduction of the bits in 'v'.
596  Value rOr(Value v, std::optional<StringRef> name = {}) {
597  return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
598  }
599 
600  // Extract bits v[hi:lo] (inclusive).
601  Value extract(Value v, unsigned lo, unsigned hi,
602  std::optional<StringRef> name = {}) {
603  unsigned width = hi - lo + 1;
604  return buildNamedOp(
605  [&]() { return b.create<comb::ExtractOp>(loc, v, lo, width); }, name);
606  }
607 
608  // Truncates 'value' to its lower 'width' bits.
609  Value truncate(Value value, unsigned width,
610  std::optional<StringRef> name = {}) {
611  return extract(value, 0, width - 1, name);
612  }
613 
614  Value zext(Value value, unsigned outWidth,
615  std::optional<StringRef> name = {}) {
616  unsigned inWidth = value.getType().getIntOrFloatBitWidth();
617  assert(inWidth <= outWidth && "zext: input width must be <- output width.");
618  if (inWidth == outWidth)
619  return value;
620  auto c0 = constant(outWidth - inWidth, 0);
621  return concat({c0, value}, name);
622  }
623 
624  Value sext(Value value, unsigned outWidth,
625  std::optional<StringRef> name = {}) {
626  return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
627  }
628 
629  // Extracts a single bit v[bit].
630  Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
631  return extract(v, index, index, name);
632  }
633 
634  // Creates a hw.array of the given values.
635  Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
636  return buildNamedOp(
637  [&]() { return b.create<hw::ArrayCreateOp>(loc, values); }, name);
638  }
639 
640  // Extract the 'index'th value from the input array.
641  Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
642  return buildNamedOp(
643  [&]() { return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
644  }
645 
646  // Muxes a range of values.
647  // The select signal is expected to be a decimal value which selects starting
648  // from the lowest index of value.
649  Value mux(Value index, ValueRange values,
650  std::optional<StringRef> name = {}) {
651  if (values.size() == 2)
652  return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
653 
654  return arrayGet(arrayCreate(values), index, name);
655  }
656 
657  // Muxes a range of values. The select signal is expected to be a 1-hot
658  // encoded value.
659  Value ohMux(Value index, ValueRange inputs) {
660  // Confirm the select input can be a one-hot encoding for the inputs.
661  unsigned numInputs = inputs.size();
662  assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
663  "one-hot select can't mux inputs");
664 
665  // Start the mux tree with zero value.
666  // Todo: clean up when handshake supports i0.
667  auto dataType = inputs[0].getType();
668  unsigned width =
669  isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
670  Value muxValue = constant(width, 0);
671 
672  // Iteratively chain together muxes from the high bit to the low bit.
673  for (size_t i = numInputs - 1; i != 0; --i) {
674  Value input = inputs[i];
675  Value selectBit = bit(index, i);
676  muxValue = mux(selectBit, {muxValue, input});
677  }
678 
679  return muxValue;
680  }
681 
682  hw::ModulePortInfo info;
683  OpBuilder &b;
684  Location loc;
685  Value clk, rst;
686  DenseMap<APInt, Value> constants;
687 };
688 
689 /// Creates a Value that has an assigned zero value. For structs, this
690 /// corresponds to assigning zero to each element recursively.
691 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
692  return TypeSwitch<Type, Value>(type)
693  .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
694  .Case<IntType, IntegerType>([&](auto type) {
695  return s.constant(type.getIntOrFloatBitWidth(), 0);
696  })
697  .Case<hw::StructType>([&](auto structType) {
698  SmallVector<Value> zeroValues;
699  for (auto field : structType.getElements())
700  zeroValues.push_back(createZeroDataConst(s, loc, field.type));
701  return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
702  })
703  .Default([&](Type) -> Value {
704  emitError(loc) << "unsupported type for zero value: " << type;
705  assert(false);
706  return {};
707  });
708 }
709 
710 static void
711 addSequentialIOOperandsIfNeeded(Operation *op,
712  llvm::SmallVectorImpl<Value> &operands) {
713  if (op->hasTrait<mlir::OpTrait::HasClock>()) {
714  // Parent should at this point be a hw.module and have clock and reset
715  // ports.
716  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
717  operands.push_back(
718  parent.getArgumentForInput(parent.getNumInputPorts() - 2));
719  operands.push_back(
720  parent.getArgumentForInput(parent.getNumInputPorts() - 1));
721  }
722 }
723 
724 template <typename T>
725 class HandshakeConversionPattern : public OpConversionPattern<T> {
726 public:
727  HandshakeConversionPattern(ESITypeConverter &typeConverter,
728  MLIRContext *context, OpBuilder &submoduleBuilder,
729  HandshakeLoweringState &ls)
730  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
731  submoduleBuilder(submoduleBuilder), ls(ls) {}
732 
733  using OpAdaptor = typename T::Adaptor;
734 
735  LogicalResult
736  matchAndRewrite(T op, OpAdaptor adaptor,
737  ConversionPatternRewriter &rewriter) const override {
738 
739  // Check if a submodule has already been created for the op. If so,
740  // instantiate the submodule. Else, run the pattern-defined module
741  // builder.
742  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
743  if (!implModule) {
744  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
745 
746  submoduleBuilder.setInsertionPoint(op->getParentOp());
747  implModule = submoduleBuilder.create<hw::HWModuleOp>(
748  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
749  portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
750  // if 'op' has clock trait, extract these and provide them to the
751  // RTL builder.
752  Value clk, rst;
753  if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
754  clk = ports.getInput("clock");
755  rst = ports.getInput("reset");
756  }
757 
758  BackedgeBuilder bb(b, op.getLoc());
759  RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
760  this->buildModule(op, bb, s, ports);
761  });
762  }
763 
764  // Instantiate the submodule.
765  llvm::SmallVector<Value> operands = adaptor.getOperands();
766  addSequentialIOOperandsIfNeeded(op, operands);
767  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
768  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
769  return success();
770  }
771 
772  virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
773  hw::HWModulePortAccessor &ports) const = 0;
774 
775  // Syntactic sugar functions.
776  // Unwraps an ESI-interfaced module into its constituent handshake signals.
777  // Backedges are created for the to-be-resolved signals, and output ports
778  // are assigned to their wrapped counterparts.
779  UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
780  hw::HWModulePortAccessor &ports) const {
781  UnwrappedIO unwrapped;
782  for (auto port : ports.getInputs()) {
783  if (!isa<esi::ChannelType>(port.getType()))
784  continue;
785  InputHandshake hs;
786  auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
787  auto [data, valid] = s.unwrap(port, *ready);
788  hs.data = data;
789  hs.valid = valid;
790  hs.ready = ready;
791  unwrapped.inputs.push_back(hs);
792  }
793  for (auto &outputInfo : ports.getPortList().getOutputs()) {
794  esi::ChannelType channelType =
795  dyn_cast<esi::ChannelType>(outputInfo.type);
796  if (!channelType)
797  continue;
798  OutputHandshake hs;
799  Type innerType = channelType.getInner();
800  auto data = std::make_shared<Backedge>(bb.get(innerType));
801  auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
802  auto [dataCh, ready] = s.wrap(*data, *valid);
803  hs.data = data;
804  hs.valid = valid;
805  hs.ready = ready;
806  ports.setOutput(outputInfo.name, dataCh);
807  unwrapped.outputs.push_back(hs);
808  }
809  return unwrapped;
810  }
811 
812  void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
813  OutputHandshake &output, Value cond) const {
814  auto validAndReady = s.bAnd({output.ready, cond});
815  for (auto &input : inputs)
816  input.ready->setValue(validAndReady);
817  }
818 
819  void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820  OutputHandshake &output) const {
821  llvm::SmallVector<Value> valids;
822  for (auto &input : inputs)
823  valids.push_back(input.valid);
824  Value allValid = s.bAnd(valids);
825  output.valid->setValue(allValid);
826  setAllReadyWithCond(s, inputs, output, allValid);
827  }
828 
  // Builds mux logic for the given inputs and outputs.
  // Note: it is assumed that the caller has removed the 'select' signal from
  // the 'unwrapped' inputs and provide it as a separate argument.
  //
  // Control: the result is valid when both the select token and the data input
  // it addresses are valid. The select token is consumed whenever the result
  // transacts, and only the addressed data input is acknowledged.
  // Data: the result data is the addressed input's data.
  void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
                     InputHandshake &select) const {
    // ============================= Control logic =============================
    size_t numInputs = unwrapped.inputs.size();
    // Number of select bits needed to address numInputs data inputs.
    // NOTE(review): for numInputs == 1 this is 0, so the select is truncated
    // to a zero-width value; confirm s.truncate/s.zext/s.mux handle i0 as
    // intended.
    size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
    // Drop any excess high bits of the select. A select that is already no
    // wider than selectWidth is used unchanged.
    Value truncatedSelect =
        select.data.getType().getIntOrFloatBitWidth() > selectWidth
            ? s.truncate(select.data, selectWidth)
            : select.data;

    // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
    // select1h = 1 << select, numInputs bits wide; bit i is set iff data
    // input i is the addressed one.
    auto selectZext = s.zext(truncatedSelect, numInputs);
    auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
    auto &res = unwrapped.outputs[0];

    // Mux input valid signals.
    auto selectedInputValid =
        s.mux(truncatedSelect, unwrapped.getInputValids());
    // Result is valid when the selected input and the select input is valid.
    auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
    res.valid->setValue(selAndInputValid);
    // The result transacts in cycles where it is valid and accepted.
    auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});

    // Select is ready when result is valid and ready (result transacting).
    select.ready->setValue(resValidAndReady);

    // Assign each input ready signal if it is currently selected.
    for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
      // Extract the selection bit for this input.
      auto isSelected = s.bit(select1h, inIdx);

      // '&' that with the result valid and ready, and assign to the input
      // ready signal.
      auto activeAndResultValidAndReady =
          s.bAnd({isSelected, resValidAndReady});
      in.ready->setValue(activeAndResultValidAndReady);
    }

    // ============================== Data logic ===============================
    res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
  }
873 
  // Builds fork logic between the single input and multiple outputs' control
  // networks. Caller is expected to handle data separately.
  //
  // Each output i gets an "emitted_<i>" register that remembers the output has
  // already transacted while the input is still waiting on other outputs. An
  // output is offered the token (valid) only while the input is valid and it
  // has not yet emitted. The input is acknowledged ("allDone") once every
  // output is done, meaning it either transacts this cycle or emitted earlier.
  void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
                      ArrayRef<OutputHandshake> outputs) const {
    // 1-bit zero; passed as the third argument of s.reg below (presumably the
    // register's reset value; confirm against RTLBuilder::reg).
    auto c0I1 = s.constant(1, 0);
    llvm::SmallVector<Value> doneWires;
    for (auto [i, output] : llvm::enumerate(outputs)) {
      // Placeholder for done<i>, which is only computed below; the emitted
      // register's next-state depends on it.
      auto doneBE = bb.get(s.b.getI1Type());
      // Next state of emitted_<i>: this output is done but the input is not
      // being acknowledged this cycle, so remember it.
      auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
      auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
      // Offer the token only if this output has not already emitted it.
      auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
      output.valid->setValue(outValid);
      auto validReady = s.bAnd({output.ready, outValid});
      // Done: transacting now, or already emitted in an earlier cycle.
      auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
      doneBE.setValue(done);
      doneWires.push_back(done);
    }
    // The input is consumed once every output is done.
    input.ready->setValue(s.bAnd(doneWires, "allDone"));
  }
893 
894  // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
895  // function which takes the set of unwrapped data inputs, and returns a
896  // value which should be assigned to the output data value.
897  void buildUnitRateJoinLogic(
898  RTLBuilder &s, UnwrappedIO &unwrappedIO,
899  llvm::function_ref<Value(ValueRange)> unitBuilder) const {
900  assert(unwrappedIO.outputs.size() == 1 &&
901  "Expected exactly one output for unit-rate join actor");
902  // Control logic.
903  this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
904 
905  // Data logic.
906  auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
907  unwrappedIO.outputs[0].data->setValue(unitRes);
908  }
909 
910  void buildUnitRateForkLogic(
911  RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
912  llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
913  assert(unwrappedIO.inputs.size() == 1 &&
914  "Expected exactly one input for unit-rate fork actor");
915  // Control logic.
916  this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
917 
918  // Data logic.
919  auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
920  assert(unitResults.size() == unwrappedIO.outputs.size() &&
921  "Expected unit builder to return one result per output");
922  for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
923  outport.data->setValue(res);
924  }
925 
926  void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
927  bool signExtend) const {
928  size_t outWidth =
929  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
930  .getIntOrFloatBitWidth();
931  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
932  if (signExtend)
933  return s.sext(inputs[0], outWidth);
934  return s.zext(inputs[0], outWidth);
935  });
936  }
937 
938  void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
939  unsigned targetWidth) const {
940  size_t outWidth =
941  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
942  .getIntOrFloatBitWidth();
943  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
944  return s.truncate(inputs[0], outWidth);
945  });
946  }
947 
948  /// Return the number of bits needed to index the given number of values.
949  static size_t getNumIndexBits(uint64_t numValues) {
950  return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
951  }
952 
953  Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
954  Value defaultValue,
955  DenseMap<size_t, Value> &indexMapping) const {
956  auto numInputs = inputs.size();
957  auto priorityArb = defaultValue;
958 
959  for (size_t i = numInputs; i > 0; --i) {
960  size_t inputIndex = i - 1;
961  size_t oneHotIndex = size_t{1} << inputIndex;
962  auto constIndex = s.constant(numInputs, oneHotIndex);
963  indexMapping[inputIndex] = constIndex;
964  priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
965  }
966  return priorityArb;
967  }
968 
969 private:
970  OpBuilder &submoduleBuilder;
971  HandshakeLoweringState &ls;
972 };
973 
974 class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
975 public:
976  using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
977  void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
978  hw::HWModulePortAccessor &ports) const override {
979  auto unwrapped = unwrapIO(s, bb, ports);
980  buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
981  return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
982  });
983  }
984 };
985 
986 class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
987 public:
988  using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
989  void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
990  hw::HWModulePortAccessor &ports) const override {
991  auto unwrappedIO = unwrapIO(s, bb, ports);
992  buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
993  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
994  };
995 };
996 
997 class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
998 public:
999  using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1000  void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1001  hw::HWModulePortAccessor &ports) const override {
1002  auto unwrappedIO = unwrapIO(s, bb, ports);
1003 
1004  // A helper wire that will be used to connect the two built logics
1005  HandshakeWire wire(bb, s.b.getNoneType());
1006 
1007  OutputHandshake output = wire.getAsOutput();
1008  buildJoinLogic(s, unwrappedIO.inputs, output);
1009 
1010  InputHandshake input = wire.getAsInput();
1011 
1012  // The state-keeping fork logic is required here, as the circuit isn't
1013  // allowed to wait for all the consumers to be ready. Connecting the ready
1014  // signals of the outputs to their corresponding valid signals leads to
1015  // combinatorial cycles. The paper which introduced compositional dataflow
1016  // circuits explicitly mentions this limitation:
1017  // http://arcade.cs.columbia.edu/df-memocode17.pdf
1018  buildForkLogic(s, bb, input, unwrappedIO.outputs);
1019 
1020  // Directly connect the data wires, only the control signals need to be
1021  // combined.
1022  for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1023  out.data->setValue(in.data);
1024  };
1025 };
1026 
1027 class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1028 public:
1029  using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1030  void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1031  hw::HWModulePortAccessor &ports) const override {
1032  auto unwrappedIO = unwrapIO(s, bb, ports);
1033 
1034  // Extract select signal from the unwrapped IO.
1035  auto select = unwrappedIO.inputs[0];
1036  unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1037  buildMuxLogic(s, unwrappedIO, select);
1038  };
1039 };
1040 
1041 class InstanceConversionPattern
1042  : public HandshakeConversionPattern<handshake::InstanceOp> {
1043 public:
1044  using HandshakeConversionPattern<
1045  handshake::InstanceOp>::HandshakeConversionPattern;
1046  void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1047  hw::HWModulePortAccessor &ports) const override {
1048  assert(false &&
1049  "If we indeed perform conversion in post-order, this "
1050  "should never be called. The base HandshakeConversionPattern logic "
1051  "will instantiate the external module.");
1052  }
1053 };
1054 
1055 class ReturnConversionPattern
1056  : public OpConversionPattern<handshake::ReturnOp> {
1057 public:
1058  using OpConversionPattern::OpConversionPattern;
1059  LogicalResult
1060  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1061  ConversionPatternRewriter &rewriter) const override {
1062  // Locate existing output op, Append operands to output op, and move to
1063  // the end of the block.
1064  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1065  auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1066  outputOp->setOperands(adaptor.getOperands());
1067  outputOp->moveAfter(&parent.getBodyBlock()->back());
1068  rewriter.eraseOp(op);
1069  return success();
1070  }
1071 };
1072 
1073 // Converts an arbitrary operation into a unit rate actor. A unit rate actor
1074 // will transact once all inputs are valid and its output is ready.
1075 template <typename TIn, typename TOut = TIn>
1076 class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1077 public:
1078  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1079  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1080  hw::HWModulePortAccessor &ports) const override {
1081  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1082  this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1083  // Create TOut - it is assumed that TOut trivially
1084  // constructs from the input data signals of TIn.
1085  // To disambiguate ambiguous builders with default arguments (e.g.,
1086  // twoState UnitAttr), specify attribute array explicitly.
1087  return s.b.create<TOut>(op.getLoc(), inputs,
1088  /* attributes */ ArrayRef<NamedAttribute>{});
1089  });
1090  };
1091 };
1092 
1093 class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1094 public:
1095  using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1096  void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1097  hw::HWModulePortAccessor &ports) const override {
1098  auto unwrappedIO = unwrapIO(s, bb, ports);
1099  buildUnitRateJoinLogic(s, unwrappedIO,
1100  [&](ValueRange inputs) { return s.pack(inputs); });
1101  };
1102 };
1103 
1104 class StructCreateConversionPattern
1105  : public HandshakeConversionPattern<hw::StructCreateOp> {
1106 public:
1107  using HandshakeConversionPattern<
1108  hw::StructCreateOp>::HandshakeConversionPattern;
1109  void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1110  hw::HWModulePortAccessor &ports) const override {
1111  auto unwrappedIO = unwrapIO(s, bb, ports);
1112  auto structType = op.getResult().getType();
1113  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1114  return s.pack(inputs, structType);
1115  });
1116  };
1117 };
1118 
1119 class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1120 public:
1121  using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1122  void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1123  hw::HWModulePortAccessor &ports) const override {
1124  auto unwrappedIO = unwrapIO(s, bb, ports);
1125  buildUnitRateForkLogic(s, bb, unwrappedIO,
1126  [&](Value input) { return s.unpack(input); });
1127  };
1128 };
1129 
1130 class ConditionalBranchConversionPattern
1131  : public HandshakeConversionPattern<ConditionalBranchOp> {
1132 public:
1133  using HandshakeConversionPattern<
1134  ConditionalBranchOp>::HandshakeConversionPattern;
1135  void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1136  hw::HWModulePortAccessor &ports) const override {
1137  auto unwrappedIO = unwrapIO(s, bb, ports);
1138  auto cond = unwrappedIO.inputs[0];
1139  auto arg = unwrappedIO.inputs[1];
1140  auto trueRes = unwrappedIO.outputs[0];
1141  auto falseRes = unwrappedIO.outputs[1];
1142 
1143  auto condArgValid = s.bAnd({cond.valid, arg.valid});
1144 
1145  // Connect valid signal of both results.
1146  trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1147  falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1148 
1149  // Connecte data signals of both results.
1150  trueRes.data->setValue(arg.data);
1151  falseRes.data->setValue(arg.data);
1152 
1153  // Connect ready signal of input and condition.
1154  auto selectedResultReady =
1155  s.mux(cond.data, {falseRes.ready, trueRes.ready});
1156  auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1157  arg.ready->setValue(condArgReady);
1158  cond.ready->setValue(condArgReady);
1159  };
1160 };
1161 
1162 template <typename TIn, bool signExtend>
1163 class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1164 public:
1165  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1166  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1167  hw::HWModulePortAccessor &ports) const override {
1168  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1169  this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1170  };
1171 };
1172 
1173 class ComparisonConversionPattern
1174  : public HandshakeConversionPattern<arith::CmpIOp> {
1175 public:
1176  using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1177  void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1178  hw::HWModulePortAccessor &ports) const override {
1179  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1180  auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1181  return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1182  return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1183  inputs[1]);
1184  });
1185  };
1186 
1187  switch (op.getPredicate()) {
1188  case arith::CmpIPredicate::eq:
1189  return buildCompareLogic(comb::ICmpPredicate::eq);
1190  case arith::CmpIPredicate::ne:
1191  return buildCompareLogic(comb::ICmpPredicate::ne);
1192  case arith::CmpIPredicate::slt:
1193  return buildCompareLogic(comb::ICmpPredicate::slt);
1194  case arith::CmpIPredicate::ult:
1195  return buildCompareLogic(comb::ICmpPredicate::ult);
1196  case arith::CmpIPredicate::sle:
1197  return buildCompareLogic(comb::ICmpPredicate::sle);
1198  case arith::CmpIPredicate::ule:
1199  return buildCompareLogic(comb::ICmpPredicate::ule);
1200  case arith::CmpIPredicate::sgt:
1201  return buildCompareLogic(comb::ICmpPredicate::sgt);
1202  case arith::CmpIPredicate::ugt:
1203  return buildCompareLogic(comb::ICmpPredicate::ugt);
1204  case arith::CmpIPredicate::sge:
1205  return buildCompareLogic(comb::ICmpPredicate::sge);
1206  case arith::CmpIPredicate::uge:
1207  return buildCompareLogic(comb::ICmpPredicate::uge);
1208  }
1209  assert(false && "invalid CmpIOp");
1210  };
1211 };
1212 
1213 class TruncateConversionPattern
1214  : public HandshakeConversionPattern<arith::TruncIOp> {
1215 public:
1216  using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1217  void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1218  hw::HWModulePortAccessor &ports) const override {
1219  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1220  unsigned targetBits =
1221  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1222  buildTruncateLogic(s, unwrappedIO, targetBits);
1223  };
1224 };
1225 
1226 class ControlMergeConversionPattern
1227  : public HandshakeConversionPattern<ControlMergeOp> {
1228 public:
1229  using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1230  void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1231  hw::HWModulePortAccessor &ports) const override {
1232  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1233  auto resData = unwrappedIO.outputs[0];
1234  auto resIndex = unwrappedIO.outputs[1];
1235 
1236  // Define some common types and values that will be used.
1237  unsigned numInputs = unwrappedIO.inputs.size();
1238  auto indexType = s.b.getIntegerType(numInputs);
1239  Value noWinner = s.constant(numInputs, 0);
1240  Value c0I1 = s.constant(1, 0);
1241 
1242  // Declare register for storing arbitration winner.
1243  auto won = bb.get(indexType);
1244  Value wonReg = s.reg("won_reg", won, noWinner);
1245 
1246  // Declare wire for arbitration winner.
1247  auto win = bb.get(indexType);
1248 
1249  // Declare wire for whether the circuit just fired and emitted both
1250  // outputs.
1251  auto fired = bb.get(s.b.getI1Type());
1252 
1253  // Declare registers for storing if each output has been emitted.
1254  auto resultEmitted = bb.get(s.b.getI1Type());
1255  Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1256  auto indexEmitted = bb.get(s.b.getI1Type());
1257  Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1258 
1259  // Declare wires for if each output is done.
1260  auto resultDone = bb.get(s.b.getI1Type());
1261  auto indexDone = bb.get(s.b.getI1Type());
1262 
1263  // Create predicates to assert if the win wire or won register hold a
1264  // valid index.
1265  auto hasWinnerCondition = s.rOr({win});
1266  auto hadWinnerCondition = s.rOr({wonReg});
1267 
1268  // Create an arbiter based on a simple priority-encoding scheme to assign
1269  // an index to the win wire. If the won register is set, just use that. In
1270  // the case that won is not set and no input is valid, set a sentinel
1271  // value to indicate no winner was chosen. The constant values are
1272  // remembered in a map so they can be re-used later to assign the arg
1273  // ready outputs.
1274  DenseMap<size_t, Value> argIndexValues;
1275  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1276  noWinner, argIndexValues);
1277  priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1278  win.setValue(priorityArb);
1279 
1280  // Create the logic to assign the result and index outputs. The result
1281  // valid output will always be assigned, and if isControl is not set, the
1282  // result data output will also be assigned. The index valid and data
1283  // outputs will always be assigned. The win wire from the arbiter is used
1284  // to index into a tree of muxes to select the chosen input's signal(s),
1285  // and is fed directly to the index output. Both the result and index
1286  // valid outputs are gated on the win wire being set to something other
1287  // than the sentinel value.
1288  auto resultNotEmitted = s.bNot(resultEmittedReg);
1289  auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1290  resData.valid->setValue(resultValid);
1291  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1292 
1293  auto indexNotEmitted = s.bNot(indexEmittedReg);
1294  auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1295  resIndex.valid->setValue(indexValid);
1296 
1297  // Use the one-hot win wire to select the index to output in the index
1298  // data.
1299  SmallVector<Value, 8> indexOutputs;
1300  for (size_t i = 0; i < numInputs; ++i)
1301  indexOutputs.push_back(s.constant(64, i));
1302 
1303  auto indexOutput = s.ohMux(win, indexOutputs);
1304  resIndex.data->setValue(indexOutput);
1305 
1306  // Create the logic to set the won register. If the fired wire is
1307  // asserted, we have finished this round and can and reset the register to
1308  // the sentinel value that indicates there is no winner. Otherwise, we
1309  // need to hold the value of the win register until we can fire.
1310  won.setValue(s.mux(fired, {win, noWinner}));
1311 
1312  // Create the logic to set the done wires for the result and index. For
1313  // both outputs, the done wire is asserted when the output is valid and
1314  // ready, or the emitted register for that output is set.
1315  auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1316  resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1317 
1318  auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1319  indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1320 
1321  // Create the logic to set the fired wire. It is asserted when both result
1322  // and index are done.
1323  fired.setValue(s.bAnd({resultDone, indexDone}));
1324 
1325  // Create the logic to assign the emitted registers. If the fired wire is
1326  // asserted, we have finished this round and can reset the registers to 0.
1327  // Otherwise, we need to hold the values of the done registers until we
1328  // can fire.
1329  resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1330  indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1331 
1332  // Create the logic to assign the arg ready outputs. The logic is
1333  // identical for each arg. If the fired wire is asserted, and the win wire
1334  // holds an arg's index, that arg is ready.
1335  auto winnerOrDefault = s.mux(fired, {noWinner, win});
1336  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1337  auto &indexValue = argIndexValues[i];
1338  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1339  }
1340  };
1341 };
1342 
1343 class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1344 public:
1345  using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1346  void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1347  hw::HWModulePortAccessor &ports) const override {
1348  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1349  auto resData = unwrappedIO.outputs[0];
1350 
1351  // Define some common types and values that will be used.
1352  unsigned numInputs = unwrappedIO.inputs.size();
1353  auto indexType = s.b.getIntegerType(numInputs);
1354  Value noWinner = s.constant(numInputs, 0);
1355 
1356  // Declare wire for arbitration winner.
1357  auto win = bb.get(indexType);
1358 
1359  // Create predicates to assert if the win wire holds a valid index.
1360  auto hasWinnerCondition = s.rOr(win);
1361 
1362  // Create an arbiter based on a simple priority-encoding scheme to assign an
1363  // index to the win wire. In the case that no input is valid, set a sentinel
1364  // value to indicate no winner was chosen. The constant values are
1365  // remembered in a map so they can be re-used later to assign the arg ready
1366  // outputs.
1367  DenseMap<size_t, Value> argIndexValues;
1368  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1369  noWinner, argIndexValues);
1370  win.setValue(priorityArb);
1371 
1372  // Create the logic to assign the result outputs. The result valid and data
1373  // outputs will always be assigned. The win wire from the arbiter is used to
1374  // index into a tree of muxes to select the chosen input's signal(s). The
1375  // result outputs are gated on the win wire being non-zero.
1376 
1377  resData.valid->setValue(hasWinnerCondition);
1378  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1379 
1380  // Create the logic to set the done wires for the result. The done wire is
1381  // asserted when the output is valid and ready, or the emitted register is
1382  // set.
1383  auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1384 
1385  // Create the logic to assign the arg ready outputs. The logic is
1386  // identical for each arg. If the fired wire is asserted, and the win wire
1387  // holds an arg's index, that arg is ready.
1388  auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1389  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1390  auto &indexValue = argIndexValues[i];
1391  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1392  }
1393  };
1394 };
1395 
1396 class LoadConversionPattern
1397  : public HandshakeConversionPattern<handshake::LoadOp> {
1398 public:
1399  using HandshakeConversionPattern<
1400  handshake::LoadOp>::HandshakeConversionPattern;
1401  void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1402  hw::HWModulePortAccessor &ports) const override {
1403  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1404  auto addrFromUser = unwrappedIO.inputs[0];
1405  auto dataFromMem = unwrappedIO.inputs[1];
1406  auto controlIn = unwrappedIO.inputs[2];
1407  auto dataToUser = unwrappedIO.outputs[0];
1408  auto addrToMem = unwrappedIO.outputs[1];
1409 
1410  addrToMem.data->setValue(addrFromUser.data);
1411  dataToUser.data->setValue(dataFromMem.data);
1412 
1413  // The valid/ready logic between user address/control to memoryAddr is
1414  // join logic.
1415  buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1416 
1417  // The valid/ready logic between memoryData and outputData is a direct
1418  // connection.
1419  dataToUser.valid->setValue(dataFromMem.valid);
1420  dataFromMem.ready->setValue(dataToUser.ready);
1421  };
1422 };
1423 
1424 class StoreConversionPattern
1425  : public HandshakeConversionPattern<handshake::StoreOp> {
1426 public:
1427  using HandshakeConversionPattern<
1428  handshake::StoreOp>::HandshakeConversionPattern;
1429  void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1430  hw::HWModulePortAccessor &ports) const override {
1431  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1432  auto addrFromUser = unwrappedIO.inputs[0];
1433  auto dataFromUser = unwrappedIO.inputs[1];
1434  auto controlIn = unwrappedIO.inputs[2];
1435  auto dataToMem = unwrappedIO.outputs[0];
1436  auto addrToMem = unwrappedIO.outputs[1];
1437 
1438  // Create a gate that will be asserted when all outputs are ready.
1439  auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1440 
1441  // Build the standard join logic from the inputs to the inputsValid and
1442  // outputsReady signals.
1443  HandshakeWire joinWire(bb, s.b.getNoneType());
1444  joinWire.ready->setValue(outputsReady);
1445  OutputHandshake joinOutput = joinWire.getAsOutput();
1446  buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1447 
1448  // Output address and data signals are connected directly.
1449  addrToMem.data->setValue(addrFromUser.data);
1450  dataToMem.data->setValue(dataFromUser.data);
1451 
1452  // Output valid signals are connected from the inputsValid wire.
1453  addrToMem.valid->setValue(*joinWire.valid);
1454  dataToMem.valid->setValue(*joinWire.valid);
1455  };
1456 };
1457 
/// Lowers handshake.memory to a seq.hlmem with one read port per load port and
/// one write port per store port.
///
/// Port layout, as indexed below:
///   inputs[2*i]           store i data
///   inputs[2*i + 1]       store i address
///   inputs[2*stCount + i] load i address
///   outputs[i]                    load i data
///   outputs[ldCount + i]          store i done (control-only)
///   outputs[ldCount + stCount + i] load i done (control-only)
class MemoryConversionPattern
    : public HandshakeConversionPattern<handshake::MemoryOp> {
public:
  using HandshakeConversionPattern<
      handshake::MemoryOp>::HandshakeConversionPattern;
  void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
                   hw::HWModulePortAccessor &ports) const override {
    auto loc = op.getLoc();

    // Gather up the load and store ports.
    auto unwrappedIO = this->unwrapIO(s, bb, ports);
    // The port structs hold references into unwrappedIO.
    struct LoadPort {
      InputHandshake &addr;
      OutputHandshake &data;
      OutputHandshake &done;
    };
    struct StorePort {
      InputHandshake &addr;
      InputHandshake &data;
      OutputHandshake &done;
    };
    SmallVector<LoadPort, 4> loadPorts;
    SmallVector<StorePort, 4> storePorts;

    unsigned stCount = op.getStCount();
    unsigned ldCount = op.getLdCount();
    for (unsigned i = 0, e = ldCount; i != e; ++i) {
      LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
                       unwrappedIO.outputs[i],
                       unwrappedIO.outputs[ldCount + stCount + i]};
      loadPorts.push_back(port);
    }

    for (unsigned i = 0, e = stCount; i != e; ++i) {
      StorePort port = {unwrappedIO.inputs[i * 2 + 1],
                        unwrappedIO.inputs[i * 2],
                        unwrappedIO.outputs[ldCount + i]};
      storePorts.push_back(port);
    }

    // used to drive the data wire of the control-only channels.
    auto c0I0 = s.constant(0, 0);

    // Address width derived from the first memref dimension only.
    // NOTE(review): this assumes a 1-D memref; confirm multi-dimensional
    // memories are rejected earlier.
    auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
    auto hlmem = s.b.create<seq::HLMemOp>(
        loc, s.clk, s.rst, "_handshake_memory_" + std::to_string(op.getId()),
        op.getMemRefType().getShape(), op.getMemRefType().getElementType());

    // Create load ports...
    for (auto &ld : loadPorts) {
      // Read port with latency 0, enabled by the address valid.
      llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
      auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
                                                  addresses, ld.addr.valid,
                                                  /*latency=*/0);
      ld.data.data->setValue(readData);
      ld.done.data->setValue(c0I0);
      // Create control fork for the load address valid and ready signals.
      buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
    }

    // Create store ports...
    for (auto &st : storePorts) {
      // Create a register to buffer the valid path by 1 cycle, to match the
      // write latency of 1.
      auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
      auto writeValidBuffer =
          s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
      st.done.valid->setValue(writeValidBuffer);
      st.done.data->setValue(c0I0);

      // Create the logic for when both the buffered write valid signal and the
      // store complete ready signal are asserted.
      auto storeCompleted =
          s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");

      // Create a signal for when the write valid buffer is empty or the output
      // is ready.
      auto notWriteValidBuffer = s.bNot(writeValidBuffer);
      auto emptyOrComplete =
          s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");

      // Connect the gate to both the store address ready and store data ready
      st.addr.ready->setValue(emptyOrComplete);
      st.data.ready->setValue(emptyOrComplete);

      // Create a wire for when both the store address and data are valid.
      auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");

      // Create a mux that drives the buffer input. If the emptyOrComplete
      // signal is asserted, the mux selects the writeValid signal. Otherwise,
      // it selects the buffer output, keeping the output registered until the
      // emptyOrComplete signal is asserted.
      writeValidBufferMuxBE.setValue(
          s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));

      // Instantiate the write port operation - truncate address width to memory
      // width.
      // NOTE(review): the write enable is writeValid alone, not gated by
      // emptyOrComplete. A write can therefore be issued in cycles where the
      // inputs are not yet accepted; confirm the resulting repeated writes are
      // intended.
      llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
      s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
                                   st.data.data, writeValid,
                                   /*latency=*/1);
    }
  }
}; // MemoryConversionPattern
1562 
1563 class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1564 public:
1565  using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1566  void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1567  hw::HWModulePortAccessor &ports) const override {
1568  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1569  // A sink is always ready to accept a new value.
1570  unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1571  };
1572 };
1573 
1574 class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1575 public:
1576  using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1577  void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1578  hw::HWModulePortAccessor &ports) const override {
1579  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1580  // A source always provides a new (i0-typed) value.
1581  unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1582  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1583  };
1584 };
1585 
1586 class ConstantConversionPattern
1587  : public HandshakeConversionPattern<handshake::ConstantOp> {
1588 public:
1589  using HandshakeConversionPattern<
1590  handshake::ConstantOp>::HandshakeConversionPattern;
1591  void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1592  hw::HWModulePortAccessor &ports) const override {
1593  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1594  unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1595  unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1596  auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1597  unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1598  };
1599 };
1600 
/// Lowers handshake.buffer into a chain of `op.getNumSlots()` sequential
/// buffer stages (see SeqBufferStage). When the op carries init values, the
/// i-th init value preloads stage i; stages beyond the provided values start
/// empty.
class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
public:
  using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
  void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
                   hw::HWModulePortAccessor &ports) const override {
    auto unwrappedIO = this->unwrapIO(s, bb, ports);
    auto input = unwrappedIO.inputs[0];
    auto output = unwrappedIO.outputs[0];
    InputHandshake lastStage;
    SmallVector<int64_t> initValues;

    // For now, always build seq buffers.
    if (op.getInitValues())
      initValues = op.getInitValueArray();

    lastStage =
        buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
                            op.getNumSlots(), input, output, initValues);

    // Connect the last stage to the output handshake.
    output.data->setValue(lastStage.data);
    output.valid->setValue(lastStage.valid);
    lastStage.ready->setValue(output.ready);
  };

  /// One stage of the sequential buffer. Each stage owns two register slots:
  ///  - a main slot (`valid<i>_reg` / `data<i>_reg`, built by
  ///    buildDataBufferLogic) that loads tokens from the predecessor stage;
  ///  - a holding slot (`ready<i>_reg` / `ctrl_data<i>_reg`, built by
  ///    buildControlBufferLogic) that captures the main slot's token when the
  ///    successor is not ready, and is cleared once the successor accepts it.
  /// The stage's output handshake is returned by getOutput().
  struct SeqBufferStage {
    /// Builds all logic for stage `index`, fed by `preStage`.
    /// If `initValue` is set, the valid register is given the 1-bit value 1
    /// and the data register `*initValue` as the value passed to `s.reg`
    /// (presumably the reset value; confirm against RTLBuilder::reg).
    /// Otherwise they are given 0 and a zero data constant.
    /// NOTE: `preStage` is stored by reference. buildSeqBufferLogic reassigns
    /// the referenced handshake right after construction, so it must only be
    /// used inside the constructor.
    SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
                   RTLBuilder &s, size_t index,
                   std::optional<int64_t> initValue)
        : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {

      // Todo: Change when i0 support is added.
      c0s = createZeroDataConst(s, s.loc, dataType);
      // The successor drives this backedge later: either the next stage's
      // buildDataBufferLogic (as its preStage.ready) or buildModule for the
      // last stage.
      currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));

      auto hasInitValue = s.constant(1, initValue.has_value());
      auto validBE = bb.get(s.b.getI1Type());
      auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
      // Driven by buildControlBufferLogic with !readyReg.
      auto readyBE = bb.get(s.b.getI1Type());

      // getIntOrFloatBitWidth assumes an int/float data type whenever an init
      // value is provided.
      Value initValueCs = c0s;
      if (initValue.has_value())
        initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);

      // This could/should be revised but needs a larger rethinking to avoid
      // introducing new bugs.
      Value dataReg =
          buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
      buildControlBufferLogic(validReg, readyBE, dataReg);
    }

    /// Returns "<name><index>_reg", e.g. "valid0_reg".
    StringAttr getRegName(StringRef name) {
      return s.b.getStringAttr(name + std::to_string(index) + "_reg");
    }

    /// Builds the holding slot and drives the stage outputs:
    ///   valid out = readyReg ? 1 : validReg
    ///   data out  = readyReg ? ctrlDataReg : dataReg
    /// readyReg / ctrlDataReg capture validReg / dataReg when neither the
    /// successor nor the holding slot is ready. They are cleared to 0 when
    /// both are ready, and otherwise keep their value.
    /// `readyBE` (the main slot's load enable) is driven with !readyReg.
    void buildControlBufferLogic(Value validReg, Backedge &readyBE,
                                 Value dataReg) {
      auto c0I1 = s.constant(1, 0);
      auto readyRegWire = bb.get(s.b.getI1Type());
      auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);

      // Create the logic to drive the current stage valid and potentially
      // data.
      currentStage.valid = s.mux(readyReg, {validReg, readyReg},
                                 "controlValid" + std::to_string(index));

      // Create the logic to drive the current stage ready.
      auto notReadyReg = s.bNot(readyReg);
      readyBE.setValue(notReadyReg);

      auto succNotReady = s.bNot(*currentStage.ready);
      auto neitherReady = s.bAnd({succNotReady, notReadyReg});
      auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
      auto bothReady = s.bAnd({*currentStage.ready, readyReg});

      // Create a mux for emptying the register when both are ready.
      auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
      readyRegWire.setValue(resetSignal);

      // Add same logic for the data path if necessary.
      auto ctrlDataRegBE = bb.get(dataType);
      auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
      auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
      currentStage.data = dataResult;

      auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
      auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
      ctrlDataRegBE.setValue(dataResetSignal);
    }

    /// Builds the main slot (valid and data registers) and drives
    /// `preStage.ready`. The slot loads the predecessor's token whenever it is
    /// empty or `readyBE` (which is !readyReg) is asserted, and otherwise keeps
    /// its contents.
    /// Returns the data register.
    Value buildDataBufferLogic(Value validReg, Value initValue,
                               Backedge &validBE, Backedge &readyBE) {
      // Create a signal for when the valid register is empty or the holding
      // slot is free (readyBE) so a new token can be accepted.
      auto notValidReg = s.bNot(validReg);
      auto emptyOrReady = s.bOr({notValidReg, readyBE});
      preStage.ready->setValue(emptyOrReady);

      // Create a mux that drives the register input. If the emptyOrReady
      // signal is asserted, the mux selects the predValid signal. Otherwise,
      // it selects the register output, keeping the output registered
      // unchanged.
      auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});

      // Now we can drive the valid register.
      validBE.setValue(validRegMux);

      // Create a mux that drives the data register.
      auto dataRegBE = bb.get(dataType);
      auto dataReg =
          s.reg(getRegName("data"),
                s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
      dataRegBE.setValue(dataReg);
      return dataReg;
    }

    /// The handshake this stage presents to its successor.
    InputHandshake getOutput() { return currentStage; }

    Type dataType;
    InputHandshake &preStage;
    InputHandshake currentStage;
    RTLBuilder &s;
    BackedgeBuilder &bb;
    size_t index;

    // A zero-valued constant of equal type as the data type of this buffer.
    Value c0s;
  };

  /// Chains `size` SeqBufferStages starting from `input`. Stage i is
  /// preloaded with initValues[i] when i < initValues.size().
  /// Returns the last stage's output handshake; the caller connects it to
  /// `output`. With size == 0, `input` itself is returned.
  InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
                                     Type dataType, unsigned size,
                                     InputHandshake &input,
                                     OutputHandshake &output,
                                     llvm::ArrayRef<int64_t> initValues) const {
    // Prime the buffer building logic with an initial stage, which just
    // wraps the input handshake.
    InputHandshake currentStage = input;

    for (unsigned i = 0; i < size; ++i) {
      bool isInitialized = i < initValues.size();
      auto initValue =
          isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
      currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
                         .getOutput();
    }

    return currentStage;
  };
};
1750 
1751 class IndexCastConversionPattern
1752  : public HandshakeConversionPattern<arith::IndexCastOp> {
1753 public:
1754  using HandshakeConversionPattern<
1755  arith::IndexCastOp>::HandshakeConversionPattern;
1756  void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1757  hw::HWModulePortAccessor &ports) const override {
1758  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1759  unsigned sourceBits =
1760  toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1761  unsigned targetBits =
1762  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1763  if (targetBits < sourceBits)
1764  buildTruncateLogic(s, unwrappedIO, targetBits);
1765  else
1766  buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1767  };
1768 };
1769 
1770 template <typename T>
1771 class ExtModuleConversionPattern : public OpConversionPattern<T> {
1772 public:
1773  ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1774  MLIRContext *context, OpBuilder &submoduleBuilder,
1775  HandshakeLoweringState &ls)
1776  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1777  submoduleBuilder(submoduleBuilder), ls(ls) {}
1778  using OpAdaptor = typename T::Adaptor;
1779 
1780  LogicalResult
1781  matchAndRewrite(T op, OpAdaptor adaptor,
1782  ConversionPatternRewriter &rewriter) const override {
1783 
1784  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1785  if (!implModule) {
1786  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1787  implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1788  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1789  portInfo);
1790  }
1791 
1792  llvm::SmallVector<Value> operands = adaptor.getOperands();
1793  addSequentialIOOperandsIfNeeded(op, operands);
1794  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1795  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1796  return success();
1797  }
1798 
1799 private:
1800  OpBuilder &submoduleBuilder;
1801  HandshakeLoweringState &ls;
1802 };
1803 
1804 class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1805 public:
1806  using OpConversionPattern::OpConversionPattern;
1807 
1808  LogicalResult
1809  matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1810  ConversionPatternRewriter &rewriter) const override {
1811  ModulePortInfo ports =
1812  getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1813 
1814  HWModuleLike hwModule;
1815  if (op.isExternal()) {
1816  hwModule = rewriter.create<hw::HWModuleExternOp>(
1817  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1818  } else {
1819  auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1820  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1821  auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1822  rewriter.inlineBlockBefore(&op.getBody().front(),
1823  hwModuleOp.getBodyBlock()->getTerminator(),
1824  args);
1825  hwModule = hwModuleOp;
1826  }
1827 
1828  // Was any predeclaration associated with this func? If so, replace uses
1829  // with the newly created module and erase the predeclaration.
1830  if (auto predecl =
1831  op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1832  auto *parentOp = op->getParentOp();
1833  auto *predeclModule =
1834  SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1835  if (predeclModule) {
1836  if (failed(SymbolTable::replaceAllSymbolUses(
1837  predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1838  return failure();
1839  rewriter.eraseOp(predeclModule);
1840  }
1841  }
1842 
1843  rewriter.eraseOp(op);
1844  return success();
1845  }
1846 };
1847 
1848 } // namespace
1849 
1850 //===----------------------------------------------------------------------===//
1851 // HW Top-module Related Functions
1852 //===----------------------------------------------------------------------===//
1853 
1854 static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1855  ConversionTarget &target,
1856  handshake::FuncOp op,
1857  OpBuilder &moduleBuilder) {
1858 
1859  std::map<std::string, unsigned> instanceNameCntr;
1860  NameUniquer instanceUniquer = [&](Operation *op) {
1861  std::string instName = getCallName(op);
1862  if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1863  // We use a special naming convention for operations which have a
1864  // 'handshake_id' attribute.
1865  instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1866  } else {
1867  // Fallback to just prefixing with an integer.
1868  instName += std::to_string(instanceNameCntr[instName]++);
1869  }
1870  return instName;
1871  };
1872 
1873  auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1874  instanceUniquer};
1875  RewritePatternSet patterns(op.getContext());
1876  patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1877  op.getContext());
1878  patterns.insert<JoinConversionPattern, ForkConversionPattern,
1879  SyncConversionPattern>(typeConverter, op.getContext(),
1880  moduleBuilder, ls);
1881 
1882  patterns.insert<
1883  // Comb operations.
1884  UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1885  UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1886  UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1887  UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1888  UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1889  UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1890  UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1891  UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1892  UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1893  UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1894  UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1895  UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1896  UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1897  UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1898  // HW operations.
1899  StructCreateConversionPattern,
1900  // Handshake operations.
1901  ConditionalBranchConversionPattern, MuxConversionPattern,
1902  PackConversionPattern, UnpackConversionPattern,
1903  ComparisonConversionPattern, BufferConversionPattern,
1904  SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1905  MergeConversionPattern, ControlMergeConversionPattern,
1906  LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1907  InstanceConversionPattern,
1908  // Arith operations.
1909  ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1910  ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1911  TruncateConversionPattern, IndexCastConversionPattern>(
1912  typeConverter, op.getContext(), moduleBuilder, ls);
1913 
1914  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1915  return op->emitOpError() << "error during conversion";
1916  return success();
1917 }
1918 
1919 namespace {
1920 class HandshakeToHWPass
1921  : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1922 public:
1923  void runOnOperation() override {
1924  mlir::ModuleOp mod = getOperation();
1925 
1926  // Lowering to HW requires that every value is used exactly once. Check
1927  // whether this precondition is met, and if not, exit.
1928  for (auto f : mod.getOps<handshake::FuncOp>()) {
1929  if (failed(verifyAllValuesHasOneUse(f))) {
1930  f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1931  "are used exactly once. Remember to run the "
1932  "fork/sink materialization pass before HW lowering.";
1933  signalPassFailure();
1934  return;
1935  }
1936  }
1937 
1938  // Resolve the instance graph to get a top-level module.
1939  std::string topLevel;
1941  SmallVector<std::string> sortedFuncs;
1942  if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1943  signalPassFailure();
1944  return;
1945  }
1946 
1947  ESITypeConverter typeConverter;
1948  ConversionTarget target(getContext());
1949  // All top-level logic of a handshake module will be the interconnectivity
1950  // between instantiated modules.
1951  target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1952  hw::InstanceOp>();
1953  target
1954  .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1955 
1956  // Convert the handshake.func operations in post-order wrt. the instance
1957  // graph. This ensures that any referenced submodules (through
1958  // handshake.instance) has already been lowered, and their HW module
1959  // equivalents are available.
1960  OpBuilder submoduleBuilder(mod.getContext());
1961  submoduleBuilder.setInsertionPointToStart(mod.getBody());
1962  for (auto &funcName : llvm::reverse(sortedFuncs)) {
1963  auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
1964  assert(funcOp && "handshake.func not found in module!");
1965  if (failed(
1966  convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
1967  signalPassFailure();
1968  return;
1969  }
1970  }
1971 
1972  // Second stage: Convert any handshake.extmemory operations and the
1973  // top-level I/O associated with these.
1974  for (auto hwModule : mod.getOps<hw::HWModuleOp>())
1975  if (failed(convertExtMemoryOps(hwModule)))
1976  return signalPassFailure();
1977  }
1978 };
1979 } // end anonymous namespace
1980 
1981 std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
1982  return std::make_unique<HandshakeToHWPass>();
1983 }
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition: CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition: CalyxOps.cpp:120
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:539
static Type tupleToStruct(TupleType tuple)
Definition: DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition: DCToHW.cpp:45
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
Definition: FIRRTLOps.cpp:1005
int32_t width
Definition: FIRRTL.cpp:36
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition: OM.cpp:96
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Channels are the basic communication primitives.
Definition: Types.h:63
const Type * getInner() const
Definition: Types.h:66
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition: seq.py:137
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition: CombOps.cpp:25
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()