CIRCT  20.0.0git
HandshakeToHW.cpp
Go to the documentation of this file.
1 //===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Handshake to HW Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
26 #include "mlir/Dialect/Arith/IR/Arith.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Pass/PassManager.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/MathExtras.h"
34 #include <optional>
35 
36 namespace circt {
37 #define GEN_PASS_DEF_HANDSHAKETOHW
38 #include "circt/Conversion/Passes.h.inc"
39 } // namespace circt
40 
41 using namespace mlir;
42 using namespace circt;
43 using namespace circt::handshake;
44 using namespace circt::hw;
45 
// Callback type that maps an operation to a string. In this file it is stored
// in HandshakeLoweringState and its result is used as the instance name when a
// converted op is replaced by an hw.instance.
46 using NameUniquer = std::function<std::string(Operation *)>;
47 
48 namespace {
49 
50 static Type tupleToStruct(TypeRange types) {
51  return toValidType(mlir::TupleType::get(types[0].getContext(), types));
52 }
53 
54 // Shared state used by various functions; captured in a struct to reduce the
55 // number of arguments that we have to pass around.
56 struct HandshakeLoweringState {
57  ModuleOp parentModule;
58  NameUniquer nameUniquer;
59 };
60 
61 // A type converter is needed to perform the in-flight materialization of "raw"
62 // (non-ESI channel) types to their ESI channel correspondents. This comes into
63 // effect when backedges exist in the input IR.
64 class ESITypeConverter : public TypeConverter {
65 public:
66  ESITypeConverter() {
67  addConversion([](Type type) -> Type { return esiWrapper(type); });
68 
69  addTargetMaterialization([&](mlir::OpBuilder &builder,
70  mlir::Type resultType, mlir::ValueRange inputs,
71  mlir::Location loc) -> mlir::Value {
72  if (inputs.size() != 1)
73  return Value();
74  return builder
75  .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
76  ->getResult(0);
77  });
78 
79  addSourceMaterialization([&](mlir::OpBuilder &builder,
80  mlir::Type resultType, mlir::ValueRange inputs,
81  mlir::Location loc) -> mlir::Value {
82  if (inputs.size() != 1)
83  return Value();
84  return builder
85  .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
86  ->getResult(0);
87  });
88  }
89 };
90 
91 } // namespace
92 
93 /// Returns a submodule name resulting from an operation, without discriminating
94 /// type information.
95 static std::string getBareSubModuleName(Operation *oldOp) {
96  // The dialect name is separated from the operation name by '.', which is not
97  // valid in SystemVerilog module names. In case this name is used in
98  // SystemVerilog output, replace '.' with '_'.
99  std::string subModuleName = oldOp->getName().getStringRef().str();
100  std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
101  return subModuleName;
102 }
103 
104 static std::string getCallName(Operation *op) {
105  auto callOp = dyn_cast<handshake::InstanceOp>(op);
106  return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
107 }
108 
109 /// Extracts the type of the data-carrying type of opType. If opType is an ESI
110 /// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
111 /// assume that opType itself is the data-carrying type.
112 static Type getOperandDataType(Value op) {
113  auto opType = op.getType();
114  if (auto channelType = dyn_cast<esi::ChannelType>(opType))
115  return channelType.getInner();
116  return opType;
117 }
118 
119 /// Filters NoneType's from the input.
120 static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
121  SmallVector<Type> filterRes;
122  llvm::copy_if(input, std::back_inserter(filterRes),
123  [](Type type) { return !isa<NoneType>(type); });
124  return filterRes;
125 }
126 
127 /// Returns a set of types which may uniquely identify the provided op. Return
128 /// value is <inputTypes, outputTypes>.
// DiscriminatingTypes pairs the discriminating input types (first) with the
// discriminating output types (second).
129 using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
// NOTE(review): the function signature (listing line 130) and the final
// argument of both llvm::transform calls below (listing lines 141 and 143) are
// absent from this listing — recover them from the original file before
// editing this function.
131  return TypeSwitch<Operation *, DiscriminatingTypes>(op)
// Memory ops contribute no input types; their only discriminating type is the
// element type of the memref, reported in the output list.
132  .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
133  return DiscriminatingTypes{{},
134  {memOp.getMemRefType().getElementType()}};
135  })
136  .Default([&](auto) {
137  // By default, all in- and output types which is not a control type
138  // (NoneType) are discriminating types.
139  std::vector<Type> inTypes, outTypes;
// One type is collected per operand and per result; NoneType entries are
// dropped afterwards by filterNoneTypes.
140  llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142  llvm::transform(op->getResults(), std::back_inserter(outTypes),
144  return DiscriminatingTypes{filterNoneTypes(inTypes),
145  filterNoneTypes(outTypes)};
146  });
147 }
148 
149 /// Get type name. Currently we only support integer or index types.
150 /// The emitted type aligns with the getFIRRTLType() method. Thus all integers
151 /// other than signed integers will be emitted as unsigned.
152 // NOLINTNEXTLINE(misc-no-recursion)
153 static std::string getTypeName(Location loc, Type type) {
154  std::string typeName;
155  // Builtin types
156  if (type.isIntOrIndex()) {
157  if (auto indexType = dyn_cast<IndexType>(type))
158  typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159  else if (type.isSignedInteger())
160  typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
161  else
162  typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
163  } else if (auto tupleType = dyn_cast<TupleType>(type)) {
164  typeName += "_tuple";
165  for (auto elementType : tupleType.getTypes())
166  typeName += getTypeName(loc, elementType);
167  } else if (auto structType = dyn_cast<hw::StructType>(type)) {
168  typeName += "_struct";
169  for (auto element : structType.getElements())
170  typeName += "_" + element.name.str() + getTypeName(loc, element.type);
171  } else
172  emitError(loc) << "unsupported data type '" << type << "'";
173 
174  return typeName;
175 }
176 
177 /// Construct a name for creating HW sub-module.
178 static std::string getSubModuleName(Operation *oldOp) {
179  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180  return instanceOp.getModule().str();
181 
182  std::string subModuleName = getBareSubModuleName(oldOp);
183 
184  // Add value of the constant operation.
185  if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186  if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187  auto intType = intAttr.getType();
188 
189  if (intType.isSignedInteger())
190  subModuleName += "_c" + std::to_string(intAttr.getSInt());
191  else if (intType.isUnsignedInteger())
192  subModuleName += "_c" + std::to_string(intAttr.getUInt());
193  else
194  subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
195  } else
196  oldOp->emitError("unsupported constant type");
197  }
198 
199  // Add discriminating in- and output types.
200  auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
201  if (!inTypes.empty())
202  subModuleName += "_in";
203  for (auto inType : inTypes)
204  subModuleName += getTypeName(oldOp->getLoc(), inType);
205 
206  if (!outTypes.empty())
207  subModuleName += "_out";
208  for (auto outType : outTypes)
209  subModuleName += getTypeName(oldOp->getLoc(), outType);
210 
211  // Add memory ID.
212  if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213  subModuleName += "_id" + std::to_string(memOp.getId());
214 
215  // Add compare kind.
216  if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217  subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
218 
219  // Add buffer information.
220  if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221  subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
222  if (bufferOp.isSequential())
223  subModuleName += "_seq";
224  else
225  subModuleName += "_fifo";
226 
227  if (auto initValues = bufferOp.getInitValues()) {
228  subModuleName += "_init";
229  for (const Attribute e : *initValues) {
230  assert(isa<IntegerAttr>(e));
231  subModuleName +=
232  "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
233  }
234  }
235  }
236 
237  // Add control information.
238  if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239  ctrlInterface && ctrlInterface.isControl()) {
240  // Add some additional discriminating info for non-typed operations.
241  subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
242  std::to_string(oldOp->getNumResults()) + "outs";
243  subModuleName += "_ctrl";
244  } else {
245  assert(
246  (!inTypes.empty() || !outTypes.empty()) &&
247  "Insufficient discriminating type info generated for the operation!");
248  }
249 
250  return subModuleName;
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // HW Sub-module Related Functions
255 //===----------------------------------------------------------------------===//
256 
257 /// Check whether a submodule with the same name has been created elsewhere in
258 /// the top level module. Return the matched module operation if true, otherwise
259 /// return nullptr.
260 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
261  StringRef modName) {
262  if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
263  return mod;
264  if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
265  return mod;
266  return {};
267 }
268 
269 static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
270  Operation *oldOp) {
271  HWModuleLike targetModule;
272  if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
273  targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
274  else
275  targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
276 
277  if (isa<handshake::InstanceOp>(oldOp))
278  assert(targetModule &&
279  "handshake.instance target modules should always have been lowered "
280  "before the modules that reference them!");
281  return targetModule;
282 }
283 
284 /// Returns a vector of PortInfo's which defines the HW interface of the
285 /// to-be-converted op.
286 static ModulePortInfo getPortInfoForOp(Operation *op) {
287  return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
288 }
289 
290 static llvm::SmallVector<hw::detail::FieldInfo>
291 portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
292  llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293  for (auto port : portInfo)
294  fieldInfo.push_back({port.name, port.type});
295 
296  return fieldInfo;
297 }
298 
299 // Convert any handshake.extmemory operations and the top-level I/O
300 // associated with these.
// For every module input whose type is an ESI channel wrapping a memref, the
// hw.instance using that input is replaced by explicit top-level I/O: a
// struct-typed input "<arg>_in" that is exploded into the values formerly
// produced by the instance, and a struct-typed output "<arg>_out" created from
// the former instance operands (minus the memref itself). The instance, the
// external module it referenced and the original memref input are erased.
// Every visible path returns success().
301 static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
302  auto *ctx = mod.getContext();
303 
304  // Gather memref ports to be converted.
// Maps the argument index to the block argument itself.
305  llvm::DenseMap<unsigned, Value> memrefPorts;
306  for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307  auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308  if (channel && isa<MemRefType>(channel.getInner()))
309  memrefPorts[i] = arg;
310  }
311 
312  if (memrefPorts.empty())
313  return success(); // nothing to do.
314 
315  OpBuilder b(mod);
316 
// Builds the PortInfo of a single struct-typed port named `portName` at index
// `argIdx`; the struct has one field per entry of `info` (same name and type).
// NOTE(review): the `loc` parameter is not used inside the lambda.
317  auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
318  ArrayRef<hw::PortInfo> info,
319  hw::ModulePort::Direction direction) {
320  auto type = hw::StructType::get(ctx, portToFieldInfo(info));
321  auto portInfo =
322  hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
323  return portInfo;
324  };
325 
326  for (auto [i, arg] : memrefPorts) {
327  // Insert ports into the module
328  auto memName = mod.getArgName(i);
329 
330  // Get the attached extmemory external module.
// NOTE(review): only the first user of the argument is inspected, and both
// cast<> calls require it to be an hw.instance of an external module —
// presumably each memref argument has exactly one such user; confirm.
331  auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
332  auto extmemMod =
333  cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334  extmemInstance, extmemInstance.getModuleNameAttr()));
335 
336  ModulePortInfo portInfo(extmemMod.getPortList());
337 
338  // The extmemory external module's interface is a direct wrapping of the
339  // original handshake.extmemory operation in- and output types. Remove the
340  // first input argument (the !esi.channel<memref> op) since that is what
341  // we're replacing with a materialized interface.
342  portInfo.eraseInput(0);
343 
344  // Add memory input - this is the output of the extmemory op.
345  SmallVector<PortInfo> outputs(portInfo.getOutputs());
// NOTE(review): the end of the following statement (listing line 348, which
// should carry the port direction) is absent from this listing.
346  auto inPortInfo =
347  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
349  mod.insertPorts({{i, inPortInfo}}, {});
350  auto newInPort = mod.getArgumentForInput(i);
351  // Replace the extmemory submodule outputs with the newly created inputs.
352  b.setInsertionPointToStart(mod.getBodyBlock());
353  auto newInPortExploded = b.create<hw::StructExplodeOp>(
354  arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355  extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
356 
357  // Add memory output - this is the inputs of the extmemory op (without the
358  // first argument);
359  unsigned outArgI = mod.getNumOutputPorts();
360  SmallVector<PortInfo> inputs(portInfo.getInputs());
// NOTE(review): the end of the following statement (listing line 363) is
// absent from this listing.
361  auto outPortInfo =
362  getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
364 
// drop_front() skips the memref operand, matching the eraseInput(0) above.
365  auto memOutputArgs = extmemInstance.getOperands().drop_front();
366  b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
367  auto memOutputStruct = b.create<hw::StructCreateOp>(
368  arg.getLoc(), outPortInfo.type, memOutputArgs);
369  mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
370 
371  // Erase the extmemory submodule instace since the i/o has now been
372  // plumbed.
373  extmemMod.erase();
374  extmemInstance.erase();
375 
376  // Erase the original memref argument of the top-level i/o now that it's use
377  // has been removed.
// The replacement input was inserted at index i above, so the original memref
// argument is presumably the one now found at index i + 1.
378  mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
379  /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
380  }
381 
382  return success();
383 }
384 
385 namespace {
386 
387 // Input handshakes contain a resolved valid and (optional )data signal, and
388 // a to-be-assigned ready signal.
389 struct InputHandshake {
390  Value valid;
391  std::shared_ptr<Backedge> ready;
392  Value data;
393 };
394 
395 // Output handshakes contain a resolved ready, and to-be-assigned valid and
396 // (optional) data signals.
397 struct OutputHandshake {
398  std::shared_ptr<Backedge> valid;
399  Value ready;
400  std::shared_ptr<Backedge> data;
401 };
402 
403 /// A helper struct that acts like a wire. Can be used to interact with the
404 /// RTLBuilder when multiple built components should be connected.
405 struct HandshakeWire {
406  HandshakeWire(BackedgeBuilder &bb, Type dataType) {
407  MLIRContext *ctx = dataType.getContext();
408  auto i1Type = IntegerType::get(ctx, 1);
409  valid = std::make_shared<Backedge>(bb.get(i1Type));
410  ready = std::make_shared<Backedge>(bb.get(i1Type));
411  data = std::make_shared<Backedge>(bb.get(dataType));
412  }
413 
414  // Functions that allow to treat a wire like an input or output port.
415  // **Careful**: Such a port will not be updated when backedges are resolved.
416  InputHandshake getAsInput() { return {*valid, ready, *data}; }
417  OutputHandshake getAsOutput() { return {valid, *ready, data}; }
418 
419  std::shared_ptr<Backedge> valid;
420  std::shared_ptr<Backedge> ready;
421  std::shared_ptr<Backedge> data;
422 };
423 
424 template <typename T, typename TInner>
425 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426  llvm::function_ref<T(TInner &)> extractor) {
427  llvm::SmallVector<T> result;
428  llvm::transform(container, std::back_inserter(result), extractor);
429  return result;
430 }
431 struct UnwrappedIO {
432  llvm::SmallVector<InputHandshake> inputs;
433  llvm::SmallVector<OutputHandshake> outputs;
434 
435  llvm::SmallVector<Value> getInputValids() {
436  return extractValues<Value, InputHandshake>(
437  inputs, [](auto &hs) { return hs.valid; });
438  }
439  llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440  return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441  inputs, [](auto &hs) { return hs.ready; });
442  }
443  llvm::SmallVector<Value> getInputDatas() {
444  return extractValues<Value, InputHandshake>(
445  inputs, [](auto &hs) { return hs.data; });
446  }
447  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449  outputs, [](auto &hs) { return hs.valid; });
450  }
451  llvm::SmallVector<Value> getOutputReadys() {
452  return extractValues<Value, OutputHandshake>(
453  outputs, [](auto &hs) { return hs.ready; });
454  }
455  llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456  return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457  outputs, [](auto &hs) { return hs.data; });
458  }
459 };
460 
461 // A class containing a bunch of syntactic sugar to reduce builder function
462 // verbosity.
463 // @todo: should be moved to support.
464 struct RTLBuilder {
465  RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
466  Value clk = Value(), Value rst = Value())
467  : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
468 
469  Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
470  // Cannot use zero-width APInt's in DenseMap's, see
471  // https://github.com/llvm/llvm-project/issues/58013
472  bool isZeroWidth = apv.getBitWidth() == 0;
473  if (!isZeroWidth) {
474  auto it = constants.find(apv);
475  if (it != constants.end())
476  return it->second;
477  }
478 
479  auto cval = b.create<hw::ConstantOp>(loc, apv);
480  if (!isZeroWidth)
481  constants[apv] = cval;
482  return cval;
483  }
484 
485  Value constant(unsigned width, int64_t value,
486  std::optional<StringRef> name = {}) {
487  return constant(
488  APInt(width, value, /*isSigned=*/false, /*implicitTrunc=*/true));
489  }
490  std::pair<Value, Value> wrap(Value data, Value valid,
491  std::optional<StringRef> name = {}) {
492  auto wrapOp = b.create<esi::WrapValidReadyOp>(loc, data, valid);
493  return {wrapOp.getResult(0), wrapOp.getResult(1)};
494  }
495  std::pair<Value, Value> unwrap(Value channel, Value ready,
496  std::optional<StringRef> name = {}) {
497  auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
498  return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
499  }
500 
501  // Various syntactic sugar functions.
502  Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503  Value rst = Value()) {
504  Value resolvedClk = clk ? clk : this->clk;
505  Value resolvedRst = rst ? rst : this->rst;
506  assert(resolvedClk &&
507  "No global clock provided to this RTLBuilder - a clock "
508  "signal must be provided to the reg(...) function.");
509  assert(resolvedRst &&
510  "No global reset provided to this RTLBuilder - a reset "
511  "signal must be provided to the reg(...) function.");
512 
513  return b.create<seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
514  name);
515  }
516 
517  Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518  std::optional<StringRef> name = {}) {
519  return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
520  }
521 
522  Value buildNamedOp(llvm::function_ref<Value()> f,
523  std::optional<StringRef> name) {
524  Value v = f();
525  StringAttr nameAttr;
526  Operation *op = v.getDefiningOp();
527  if (name.has_value()) {
528  op->setAttr("sv.namehint", b.getStringAttr(*name));
529  nameAttr = b.getStringAttr(*name);
530  }
531  return v;
532  }
533 
534  // Bitwise 'and'.
535  Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
536  return buildNamedOp(
537  [&]() { return b.create<comb::AndOp>(loc, values, false); }, name);
538  }
539 
540  Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
541  return buildNamedOp(
542  [&]() { return b.create<comb::OrOp>(loc, values, false); }, name);
543  }
544 
545  // Bitwise 'not'.
546  Value bNot(Value value, std::optional<StringRef> name = {}) {
547  auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548  std::string inferedName;
549  if (!name) {
550  // Try to create a name from the input value.
551  if (auto valueName =
552  value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
553  inferedName = ("not_" + valueName.getValue()).str();
554  name = inferedName;
555  }
556  }
557 
558  return buildNamedOp(
559  [&]() { return b.create<comb::XorOp>(loc, value, allOnes); }, name);
560 
561  return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
562  }
563 
564  Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
565  return buildNamedOp(
566  [&]() { return b.create<comb::ShlOp>(loc, value, shift); }, name);
567  }
568 
569  Value concat(ValueRange values, std::optional<StringRef> name = {}) {
570  return buildNamedOp([&]() { return b.create<comb::ConcatOp>(loc, values); },
571  name);
572  }
573 
574  // Packs a list of values into a hw.struct.
575  Value pack(ValueRange values, Type structType = Type(),
576  std::optional<StringRef> name = {}) {
577  if (!structType)
578  structType = tupleToStruct(values.getTypes());
579  return buildNamedOp(
580  [&]() { return b.create<hw::StructCreateOp>(loc, structType, values); },
581  name);
582  }
583 
584  // Unpacks a hw.struct into a list of values.
585  ValueRange unpack(Value value) {
586  auto structType = cast<hw::StructType>(value.getType());
587  llvm::SmallVector<Type> innerTypes;
588  structType.getInnerTypes(innerTypes);
589  return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
590  }
591 
592  llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
593  llvm::SmallVector<Value> bits;
594  for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
595  bits.push_back(b.create<comb::ExtractOp>(loc, v, i, /*bitWidth=*/1));
596  return bits;
597  }
598 
599  // OR-reduction of the bits in 'v'.
600  Value rOr(Value v, std::optional<StringRef> name = {}) {
601  return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
602  }
603 
604  // Extract bits v[hi:lo] (inclusive).
605  Value extract(Value v, unsigned lo, unsigned hi,
606  std::optional<StringRef> name = {}) {
607  unsigned width = hi - lo + 1;
608  return buildNamedOp(
609  [&]() { return b.create<comb::ExtractOp>(loc, v, lo, width); }, name);
610  }
611 
612  // Truncates 'value' to its lower 'width' bits.
613  Value truncate(Value value, unsigned width,
614  std::optional<StringRef> name = {}) {
615  return extract(value, 0, width - 1, name);
616  }
617 
618  Value zext(Value value, unsigned outWidth,
619  std::optional<StringRef> name = {}) {
620  unsigned inWidth = value.getType().getIntOrFloatBitWidth();
621  assert(inWidth <= outWidth && "zext: input width must be <- output width.");
622  if (inWidth == outWidth)
623  return value;
624  auto c0 = constant(outWidth - inWidth, 0);
625  return concat({c0, value}, name);
626  }
627 
628  Value sext(Value value, unsigned outWidth,
629  std::optional<StringRef> name = {}) {
630  return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
631  }
632 
633  // Extracts a single bit v[bit].
634  Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
635  return extract(v, index, index, name);
636  }
637 
638  // Creates a hw.array of the given values.
639  Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
640  return buildNamedOp(
641  [&]() { return b.create<hw::ArrayCreateOp>(loc, values); }, name);
642  }
643 
644  // Extract the 'index'th value from the input array.
645  Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
646  return buildNamedOp(
647  [&]() { return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
648  }
649 
650  // Muxes a range of values.
651  // The select signal is expected to be a decimal value which selects starting
652  // from the lowest index of value.
653  Value mux(Value index, ValueRange values,
654  std::optional<StringRef> name = {}) {
655  if (values.size() == 2)
656  return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
657 
658  return arrayGet(arrayCreate(values), index, name);
659  }
660 
661  // Muxes a range of values. The select signal is expected to be a 1-hot
662  // encoded value.
663  Value ohMux(Value index, ValueRange inputs) {
664  // Confirm the select input can be a one-hot encoding for the inputs.
665  unsigned numInputs = inputs.size();
666  assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
667  "one-hot select can't mux inputs");
668 
669  // Start the mux tree with zero value.
670  // Todo: clean up when handshake supports i0.
671  auto dataType = inputs[0].getType();
672  unsigned width =
673  isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
674  Value muxValue = constant(width, 0);
675 
676  // Iteratively chain together muxes from the high bit to the low bit.
677  for (size_t i = numInputs - 1; i != 0; --i) {
678  Value input = inputs[i];
679  Value selectBit = bit(index, i);
680  muxValue = mux(selectBit, {muxValue, input});
681  }
682 
683  return muxValue;
684  }
685 
686  hw::ModulePortInfo info;
687  OpBuilder &b;
688  Location loc;
689  Value clk, rst;
690  DenseMap<APInt, Value> constants;
691 };
692 
693 /// Creates a Value that has an assigned zero value. For structs, this
694 /// corresponds to assigning zero to each element recursively.
695 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
696  return TypeSwitch<Type, Value>(type)
697  .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
698  .Case<IntType, IntegerType>([&](auto type) {
699  return s.constant(type.getIntOrFloatBitWidth(), 0);
700  })
701  .Case<hw::StructType>([&](auto structType) {
702  SmallVector<Value> zeroValues;
703  for (auto field : structType.getElements())
704  zeroValues.push_back(createZeroDataConst(s, loc, field.type));
705  return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
706  })
707  .Default([&](Type) -> Value {
708  emitError(loc) << "unsupported type for zero value: " << type;
709  assert(false);
710  return {};
711  });
712 }
713 
714 static void
715 addSequentialIOOperandsIfNeeded(Operation *op,
716  llvm::SmallVectorImpl<Value> &operands) {
717  if (op->hasTrait<mlir::OpTrait::HasClock>()) {
718  // Parent should at this point be a hw.module and have clock and reset
719  // ports.
720  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
721  operands.push_back(
722  parent.getArgumentForInput(parent.getNumInputPorts() - 2));
723  operands.push_back(
724  parent.getArgumentForInput(parent.getNumInputPorts() - 1));
725  }
726 }
727 
728 template <typename T>
729 class HandshakeConversionPattern : public OpConversionPattern<T> {
730 public:
731  HandshakeConversionPattern(ESITypeConverter &typeConverter,
732  MLIRContext *context, OpBuilder &submoduleBuilder,
733  HandshakeLoweringState &ls)
734  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
735  submoduleBuilder(submoduleBuilder), ls(ls) {}
736 
737  using OpAdaptor = typename T::Adaptor;
738 
739  LogicalResult
740  matchAndRewrite(T op, OpAdaptor adaptor,
741  ConversionPatternRewriter &rewriter) const override {
742 
743  // Check if a submodule has already been created for the op. If so,
744  // instantiate the submodule. Else, run the pattern-defined module
745  // builder.
746  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
747  if (!implModule) {
748  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
749 
750  submoduleBuilder.setInsertionPoint(op->getParentOp());
751  implModule = submoduleBuilder.create<hw::HWModuleOp>(
752  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
753  portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
754  // if 'op' has clock trait, extract these and provide them to the
755  // RTL builder.
756  Value clk, rst;
757  if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
758  clk = ports.getInput("clock");
759  rst = ports.getInput("reset");
760  }
761 
762  BackedgeBuilder bb(b, op.getLoc());
763  RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
764  this->buildModule(op, bb, s, ports);
765  });
766  }
767 
768  // Instantiate the submodule.
769  llvm::SmallVector<Value> operands = adaptor.getOperands();
770  addSequentialIOOperandsIfNeeded(op, operands);
771  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
772  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
773  return success();
774  }
775 
// Pattern-specific hook that emits the body of the submodule implementing
// `op`. matchAndRewrite calls it while creating a new hw.module, i.e. only when
// no module with the op's submodule name exists yet. `bb` creates backedges,
// `builder` is the RTLBuilder set up for the new module (with clock and reset
// when the op has the HasClock trait), and `ports` gives access to the
// module's ports.
776  virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
777  hw::HWModulePortAccessor &ports) const = 0;
778 
779  // Syntactic sugar functions.
780  // Unwraps an ESI-interfaced module into its constituent handshake signals.
781  // Backedges are created for the to-be-resolved signals, and output ports
782  // are assigned to their wrapped counterparts.
783  UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
784  hw::HWModulePortAccessor &ports) const {
785  UnwrappedIO unwrapped;
786  for (auto port : ports.getInputs()) {
787  if (!isa<esi::ChannelType>(port.getType()))
788  continue;
789  InputHandshake hs;
790  auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
791  auto [data, valid] = s.unwrap(port, *ready);
792  hs.data = data;
793  hs.valid = valid;
794  hs.ready = ready;
795  unwrapped.inputs.push_back(hs);
796  }
797  for (auto &outputInfo : ports.getPortList().getOutputs()) {
798  esi::ChannelType channelType =
799  dyn_cast<esi::ChannelType>(outputInfo.type);
800  if (!channelType)
801  continue;
802  OutputHandshake hs;
803  Type innerType = channelType.getInner();
804  auto data = std::make_shared<Backedge>(bb.get(innerType));
805  auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
806  auto [dataCh, ready] = s.wrap(*data, *valid);
807  hs.data = data;
808  hs.valid = valid;
809  hs.ready = ready;
810  ports.setOutput(outputInfo.name, dataCh);
811  unwrapped.outputs.push_back(hs);
812  }
813  return unwrapped;
814  }
815 
816  void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
817  OutputHandshake &output, Value cond) const {
818  auto validAndReady = s.bAnd({output.ready, cond});
819  for (auto &input : inputs)
820  input.ready->setValue(validAndReady);
821  }
822 
823  void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
824  OutputHandshake &output) const {
825  llvm::SmallVector<Value> valids;
826  for (auto &input : inputs)
827  valids.push_back(input.valid);
828  Value allValid = s.bAnd(valids);
829  output.valid->setValue(allValid);
830  setAllReadyWithCond(s, inputs, output, allValid);
831  }
832 
833  // Builds mux logic for the given inputs and outputs.
834  // Note: it is assumed that the caller has removed the 'select' signal from
835  // the 'unwrapped' inputs and provide it as a separate argument.
836  void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
837  InputHandshake &select) const {
838  // ============================= Control logic =============================
839  size_t numInputs = unwrapped.inputs.size();
840  size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
841  Value truncatedSelect =
842  select.data.getType().getIntOrFloatBitWidth() > selectWidth
843  ? s.truncate(select.data, selectWidth)
844  : select.data;
845 
846  // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
847  auto selectZext = s.zext(truncatedSelect, numInputs);
848  auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
849  auto &res = unwrapped.outputs[0];
850 
851  // Mux input valid signals.
852  auto selectedInputValid =
853  s.mux(truncatedSelect, unwrapped.getInputValids());
854  // Result is valid when the selected input and the select input is valid.
855  auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
856  res.valid->setValue(selAndInputValid);
857  auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
858 
859  // Select is ready when result is valid and ready (result transacting).
860  select.ready->setValue(resValidAndReady);
861 
862  // Assign each input ready signal if it is currently selected.
863  for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
864  // Extract the selection bit for this input.
865  auto isSelected = s.bit(select1h, inIdx);
866 
867  // '&' that with the result valid and ready, and assign to the input
868  // ready signal.
869  auto activeAndResultValidAndReady =
870  s.bAnd({isSelected, resValidAndReady});
871  in.ready->setValue(activeAndResultValidAndReady);
872  }
873 
874  // ============================== Data logic ===============================
875  res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
876  }
877 
878  // Builds fork logic between the single input and multiple outputs' control
879  // networks. Caller is expected to handle data separately.
880  void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
881  ArrayRef<OutputHandshake> outputs) const {
882  auto c0I1 = s.constant(1, 0);
883  llvm::SmallVector<Value> doneWires;
884  for (auto [i, output] : llvm::enumerate(outputs)) {
885  auto doneBE = bb.get(s.b.getI1Type());
886  auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
887  auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
888  auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
889  output.valid->setValue(outValid);
890  auto validReady = s.bAnd({output.ready, outValid});
891  auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
892  doneBE.setValue(done);
893  doneWires.push_back(done);
894  }
895  input.ready->setValue(s.bAnd(doneWires, "allDone"));
896  }
897 
898  // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
899  // function which takes the set of unwrapped data inputs, and returns a
900  // value which should be assigned to the output data value.
901  void buildUnitRateJoinLogic(
902  RTLBuilder &s, UnwrappedIO &unwrappedIO,
903  llvm::function_ref<Value(ValueRange)> unitBuilder) const {
904  assert(unwrappedIO.outputs.size() == 1 &&
905  "Expected exactly one output for unit-rate join actor");
906  // Control logic.
907  this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
908 
909  // Data logic.
910  auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
911  unwrappedIO.outputs[0].data->setValue(unitRes);
912  }
913 
914  void buildUnitRateForkLogic(
915  RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
916  llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
917  assert(unwrappedIO.inputs.size() == 1 &&
918  "Expected exactly one input for unit-rate fork actor");
919  // Control logic.
920  this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
921 
922  // Data logic.
923  auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
924  assert(unitResults.size() == unwrappedIO.outputs.size() &&
925  "Expected unit builder to return one result per output");
926  for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
927  outport.data->setValue(res);
928  }
929 
930  void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
931  bool signExtend) const {
932  size_t outWidth =
933  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
934  .getIntOrFloatBitWidth();
935  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
936  if (signExtend)
937  return s.sext(inputs[0], outWidth);
938  return s.zext(inputs[0], outWidth);
939  });
940  }
941 
942  void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
943  unsigned targetWidth) const {
944  size_t outWidth =
945  toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
946  .getIntOrFloatBitWidth();
947  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
948  return s.truncate(inputs[0], outWidth);
949  });
950  }
951 
952  /// Return the number of bits needed to index the given number of values.
953  static size_t getNumIndexBits(uint64_t numValues) {
954  return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
955  }
956 
957  Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
958  Value defaultValue,
959  DenseMap<size_t, Value> &indexMapping) const {
960  auto numInputs = inputs.size();
961  auto priorityArb = defaultValue;
962 
963  for (size_t i = numInputs; i > 0; --i) {
964  size_t inputIndex = i - 1;
965  size_t oneHotIndex = size_t{1} << inputIndex;
966  auto constIndex = s.constant(numInputs, oneHotIndex);
967  indexMapping[inputIndex] = constIndex;
968  priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
969  }
970  return priorityArb;
971  }
972 
973 private:
974  OpBuilder &submoduleBuilder;
975  HandshakeLoweringState &ls;
976 };
977 
978 class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
979 public:
980  using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
981  void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
982  hw::HWModulePortAccessor &ports) const override {
983  auto unwrapped = unwrapIO(s, bb, ports);
984  buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
985  return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
986  });
987  }
988 };
989 
990 class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
991 public:
992  using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
993  void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
994  hw::HWModulePortAccessor &ports) const override {
995  auto unwrappedIO = unwrapIO(s, bb, ports);
996  buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
997  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
998  };
999 };
1000 
1001 class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
1002 public:
1003  using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1004  void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1005  hw::HWModulePortAccessor &ports) const override {
1006  auto unwrappedIO = unwrapIO(s, bb, ports);
1007 
1008  // A helper wire that will be used to connect the two built logics
1009  HandshakeWire wire(bb, s.b.getNoneType());
1010 
1011  OutputHandshake output = wire.getAsOutput();
1012  buildJoinLogic(s, unwrappedIO.inputs, output);
1013 
1014  InputHandshake input = wire.getAsInput();
1015 
1016  // The state-keeping fork logic is required here, as the circuit isn't
1017  // allowed to wait for all the consumers to be ready. Connecting the ready
1018  // signals of the outputs to their corresponding valid signals leads to
1019  // combinatorial cycles. The paper which introduced compositional dataflow
1020  // circuits explicitly mentions this limitation:
1021  // http://arcade.cs.columbia.edu/df-memocode17.pdf
1022  buildForkLogic(s, bb, input, unwrappedIO.outputs);
1023 
1024  // Directly connect the data wires, only the control signals need to be
1025  // combined.
1026  for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1027  out.data->setValue(in.data);
1028  };
1029 };
1030 
1031 class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1032 public:
1033  using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1034  void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1035  hw::HWModulePortAccessor &ports) const override {
1036  auto unwrappedIO = unwrapIO(s, bb, ports);
1037 
1038  // Extract select signal from the unwrapped IO.
1039  auto select = unwrappedIO.inputs[0];
1040  unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1041  buildMuxLogic(s, unwrappedIO, select);
1042  };
1043 };
1044 
1045 class InstanceConversionPattern
1046  : public HandshakeConversionPattern<handshake::InstanceOp> {
1047 public:
1048  using HandshakeConversionPattern<
1049  handshake::InstanceOp>::HandshakeConversionPattern;
1050  void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1051  hw::HWModulePortAccessor &ports) const override {
1052  assert(false &&
1053  "If we indeed perform conversion in post-order, this "
1054  "should never be called. The base HandshakeConversionPattern logic "
1055  "will instantiate the external module.");
1056  }
1057 };
1058 
1059 class ESIInstanceConversionPattern
1060  : public OpConversionPattern<handshake::ESIInstanceOp> {
1061 public:
1062  ESIInstanceConversionPattern(MLIRContext *context,
1063  const HWSymbolCache &symCache)
1064  : OpConversionPattern(context), symCache(symCache) {}
1065 
1066  LogicalResult
1067  matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1068  ConversionPatternRewriter &rewriter) const override {
1069  // The operand signature of this op is very similar to the lowered
1070  // `handshake.func`s (especially since handshake uses ESI channels
1071  // internally). Whereas ESIInstance ops have 'clk' and 'rst' at the
1072  // beginning, lowered `handshake.func`s have them at the end. So we've just
1073  // got to re-arrange them.
1074  SmallVector<Value> operands;
1075  for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1076  i < e; ++i)
1077  operands.push_back(adaptor.getOperands()[i]);
1078  operands.push_back(adaptor.getClk());
1079  operands.push_back(adaptor.getRst());
1080  // Locate the lowered module so the instance builder can get all the
1081  // metadata.
1082  Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
1083  // And replace the op with an instance of the target module.
1084  rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1085  op.getInstNameAttr(), operands);
1086  return success();
1087  }
1088 
1089 private:
1090  const HWSymbolCache &symCache;
1091 };
1092 
1093 class ReturnConversionPattern
1094  : public OpConversionPattern<handshake::ReturnOp> {
1095 public:
1096  using OpConversionPattern::OpConversionPattern;
1097  LogicalResult
1098  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1099  ConversionPatternRewriter &rewriter) const override {
1100  // Locate existing output op, Append operands to output op, and move to
1101  // the end of the block.
1102  auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1103  auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1104  outputOp->setOperands(adaptor.getOperands());
1105  outputOp->moveAfter(&parent.getBodyBlock()->back());
1106  rewriter.eraseOp(op);
1107  return success();
1108  }
1109 };
1110 
1111 // Converts an arbitrary operation into a unit rate actor. A unit rate actor
1112 // will transact once all inputs are valid and its output is ready.
1113 template <typename TIn, typename TOut = TIn>
1114 class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1115 public:
1116  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1117  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1118  hw::HWModulePortAccessor &ports) const override {
1119  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1120  this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1121  // Create TOut - it is assumed that TOut trivially
1122  // constructs from the input data signals of TIn.
1123  // To disambiguate ambiguous builders with default arguments (e.g.,
1124  // twoState UnitAttr), specify attribute array explicitly.
1125  return s.b.create<TOut>(op.getLoc(), inputs,
1126  /* attributes */ ArrayRef<NamedAttribute>{});
1127  });
1128  };
1129 };
1130 
1131 class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1132 public:
1133  using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1134  void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1135  hw::HWModulePortAccessor &ports) const override {
1136  auto unwrappedIO = unwrapIO(s, bb, ports);
1137  buildUnitRateJoinLogic(s, unwrappedIO,
1138  [&](ValueRange inputs) { return s.pack(inputs); });
1139  };
1140 };
1141 
1142 class StructCreateConversionPattern
1143  : public HandshakeConversionPattern<hw::StructCreateOp> {
1144 public:
1145  using HandshakeConversionPattern<
1146  hw::StructCreateOp>::HandshakeConversionPattern;
1147  void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1148  hw::HWModulePortAccessor &ports) const override {
1149  auto unwrappedIO = unwrapIO(s, bb, ports);
1150  auto structType = op.getResult().getType();
1151  buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1152  return s.pack(inputs, structType);
1153  });
1154  };
1155 };
1156 
1157 class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1158 public:
1159  using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1160  void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1161  hw::HWModulePortAccessor &ports) const override {
1162  auto unwrappedIO = unwrapIO(s, bb, ports);
1163  buildUnitRateForkLogic(s, bb, unwrappedIO,
1164  [&](Value input) { return s.unpack(input); });
1165  };
1166 };
1167 
1168 class ConditionalBranchConversionPattern
1169  : public HandshakeConversionPattern<ConditionalBranchOp> {
1170 public:
1171  using HandshakeConversionPattern<
1172  ConditionalBranchOp>::HandshakeConversionPattern;
1173  void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1174  hw::HWModulePortAccessor &ports) const override {
1175  auto unwrappedIO = unwrapIO(s, bb, ports);
1176  auto cond = unwrappedIO.inputs[0];
1177  auto arg = unwrappedIO.inputs[1];
1178  auto trueRes = unwrappedIO.outputs[0];
1179  auto falseRes = unwrappedIO.outputs[1];
1180 
1181  auto condArgValid = s.bAnd({cond.valid, arg.valid});
1182 
1183  // Connect valid signal of both results.
1184  trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1185  falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1186 
1187  // Connecte data signals of both results.
1188  trueRes.data->setValue(arg.data);
1189  falseRes.data->setValue(arg.data);
1190 
1191  // Connect ready signal of input and condition.
1192  auto selectedResultReady =
1193  s.mux(cond.data, {falseRes.ready, trueRes.ready});
1194  auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1195  arg.ready->setValue(condArgReady);
1196  cond.ready->setValue(condArgReady);
1197  };
1198 };
1199 
1200 template <typename TIn, bool signExtend>
1201 class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1202 public:
1203  using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1204  void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1205  hw::HWModulePortAccessor &ports) const override {
1206  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1207  this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1208  };
1209 };
1210 
1211 class ComparisonConversionPattern
1212  : public HandshakeConversionPattern<arith::CmpIOp> {
1213 public:
1214  using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1215  void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1216  hw::HWModulePortAccessor &ports) const override {
1217  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1218  auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1219  return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1220  return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1221  inputs[1]);
1222  });
1223  };
1224 
1225  switch (op.getPredicate()) {
1226  case arith::CmpIPredicate::eq:
1227  return buildCompareLogic(comb::ICmpPredicate::eq);
1228  case arith::CmpIPredicate::ne:
1229  return buildCompareLogic(comb::ICmpPredicate::ne);
1230  case arith::CmpIPredicate::slt:
1231  return buildCompareLogic(comb::ICmpPredicate::slt);
1232  case arith::CmpIPredicate::ult:
1233  return buildCompareLogic(comb::ICmpPredicate::ult);
1234  case arith::CmpIPredicate::sle:
1235  return buildCompareLogic(comb::ICmpPredicate::sle);
1236  case arith::CmpIPredicate::ule:
1237  return buildCompareLogic(comb::ICmpPredicate::ule);
1238  case arith::CmpIPredicate::sgt:
1239  return buildCompareLogic(comb::ICmpPredicate::sgt);
1240  case arith::CmpIPredicate::ugt:
1241  return buildCompareLogic(comb::ICmpPredicate::ugt);
1242  case arith::CmpIPredicate::sge:
1243  return buildCompareLogic(comb::ICmpPredicate::sge);
1244  case arith::CmpIPredicate::uge:
1245  return buildCompareLogic(comb::ICmpPredicate::uge);
1246  }
1247  assert(false && "invalid CmpIOp");
1248  };
1249 };
1250 
1251 class TruncateConversionPattern
1252  : public HandshakeConversionPattern<arith::TruncIOp> {
1253 public:
1254  using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1255  void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1256  hw::HWModulePortAccessor &ports) const override {
1257  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1258  unsigned targetBits =
1259  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1260  buildTruncateLogic(s, unwrappedIO, targetBits);
1261  };
1262 };
1263 
1264 class ControlMergeConversionPattern
1265  : public HandshakeConversionPattern<ControlMergeOp> {
1266 public:
1267  using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1268  void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1269  hw::HWModulePortAccessor &ports) const override {
1270  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1271  auto resData = unwrappedIO.outputs[0];
1272  auto resIndex = unwrappedIO.outputs[1];
1273 
1274  // Define some common types and values that will be used.
1275  unsigned numInputs = unwrappedIO.inputs.size();
1276  auto indexType = s.b.getIntegerType(numInputs);
1277  Value noWinner = s.constant(numInputs, 0);
1278  Value c0I1 = s.constant(1, 0);
1279 
1280  // Declare register for storing arbitration winner.
1281  auto won = bb.get(indexType);
1282  Value wonReg = s.reg("won_reg", won, noWinner);
1283 
1284  // Declare wire for arbitration winner.
1285  auto win = bb.get(indexType);
1286 
1287  // Declare wire for whether the circuit just fired and emitted both
1288  // outputs.
1289  auto fired = bb.get(s.b.getI1Type());
1290 
1291  // Declare registers for storing if each output has been emitted.
1292  auto resultEmitted = bb.get(s.b.getI1Type());
1293  Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1294  auto indexEmitted = bb.get(s.b.getI1Type());
1295  Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1296 
1297  // Declare wires for if each output is done.
1298  auto resultDone = bb.get(s.b.getI1Type());
1299  auto indexDone = bb.get(s.b.getI1Type());
1300 
1301  // Create predicates to assert if the win wire or won register hold a
1302  // valid index.
1303  auto hasWinnerCondition = s.rOr({win});
1304  auto hadWinnerCondition = s.rOr({wonReg});
1305 
1306  // Create an arbiter based on a simple priority-encoding scheme to assign
1307  // an index to the win wire. If the won register is set, just use that. In
1308  // the case that won is not set and no input is valid, set a sentinel
1309  // value to indicate no winner was chosen. The constant values are
1310  // remembered in a map so they can be re-used later to assign the arg
1311  // ready outputs.
1312  DenseMap<size_t, Value> argIndexValues;
1313  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1314  noWinner, argIndexValues);
1315  priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1316  win.setValue(priorityArb);
1317 
1318  // Create the logic to assign the result and index outputs. The result
1319  // valid output will always be assigned, and if isControl is not set, the
1320  // result data output will also be assigned. The index valid and data
1321  // outputs will always be assigned. The win wire from the arbiter is used
1322  // to index into a tree of muxes to select the chosen input's signal(s),
1323  // and is fed directly to the index output. Both the result and index
1324  // valid outputs are gated on the win wire being set to something other
1325  // than the sentinel value.
1326  auto resultNotEmitted = s.bNot(resultEmittedReg);
1327  auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1328  resData.valid->setValue(resultValid);
1329  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1330 
1331  auto indexNotEmitted = s.bNot(indexEmittedReg);
1332  auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1333  resIndex.valid->setValue(indexValid);
1334 
1335  // Use the one-hot win wire to select the index to output in the index
1336  // data.
1337  SmallVector<Value, 8> indexOutputs;
1338  for (size_t i = 0; i < numInputs; ++i)
1339  indexOutputs.push_back(s.constant(64, i));
1340 
1341  auto indexOutput = s.ohMux(win, indexOutputs);
1342  resIndex.data->setValue(indexOutput);
1343 
1344  // Create the logic to set the won register. If the fired wire is
1345  // asserted, we have finished this round and can and reset the register to
1346  // the sentinel value that indicates there is no winner. Otherwise, we
1347  // need to hold the value of the win register until we can fire.
1348  won.setValue(s.mux(fired, {win, noWinner}));
1349 
1350  // Create the logic to set the done wires for the result and index. For
1351  // both outputs, the done wire is asserted when the output is valid and
1352  // ready, or the emitted register for that output is set.
1353  auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1354  resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1355 
1356  auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1357  indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1358 
1359  // Create the logic to set the fired wire. It is asserted when both result
1360  // and index are done.
1361  fired.setValue(s.bAnd({resultDone, indexDone}));
1362 
1363  // Create the logic to assign the emitted registers. If the fired wire is
1364  // asserted, we have finished this round and can reset the registers to 0.
1365  // Otherwise, we need to hold the values of the done registers until we
1366  // can fire.
1367  resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1368  indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1369 
1370  // Create the logic to assign the arg ready outputs. The logic is
1371  // identical for each arg. If the fired wire is asserted, and the win wire
1372  // holds an arg's index, that arg is ready.
1373  auto winnerOrDefault = s.mux(fired, {noWinner, win});
1374  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1375  auto &indexValue = argIndexValues[i];
1376  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1377  }
1378  };
1379 };
1380 
1381 class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1382 public:
1383  using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1384  void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1385  hw::HWModulePortAccessor &ports) const override {
1386  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1387  auto resData = unwrappedIO.outputs[0];
1388 
1389  // Define some common types and values that will be used.
1390  unsigned numInputs = unwrappedIO.inputs.size();
1391  auto indexType = s.b.getIntegerType(numInputs);
1392  Value noWinner = s.constant(numInputs, 0);
1393 
1394  // Declare wire for arbitration winner.
1395  auto win = bb.get(indexType);
1396 
1397  // Create predicates to assert if the win wire holds a valid index.
1398  auto hasWinnerCondition = s.rOr(win);
1399 
1400  // Create an arbiter based on a simple priority-encoding scheme to assign an
1401  // index to the win wire. In the case that no input is valid, set a sentinel
1402  // value to indicate no winner was chosen. The constant values are
1403  // remembered in a map so they can be re-used later to assign the arg ready
1404  // outputs.
1405  DenseMap<size_t, Value> argIndexValues;
1406  Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1407  noWinner, argIndexValues);
1408  win.setValue(priorityArb);
1409 
1410  // Create the logic to assign the result outputs. The result valid and data
1411  // outputs will always be assigned. The win wire from the arbiter is used to
1412  // index into a tree of muxes to select the chosen input's signal(s). The
1413  // result outputs are gated on the win wire being non-zero.
1414 
1415  resData.valid->setValue(hasWinnerCondition);
1416  resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1417 
1418  // Create the logic to set the done wires for the result. The done wire is
1419  // asserted when the output is valid and ready, or the emitted register is
1420  // set.
1421  auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1422 
1423  // Create the logic to assign the arg ready outputs. The logic is
1424  // identical for each arg. If the fired wire is asserted, and the win wire
1425  // holds an arg's index, that arg is ready.
1426  auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1427  for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1428  auto &indexValue = argIndexValues[i];
1429  ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1430  }
1431  };
1432 };
1433 
1434 class LoadConversionPattern
1435  : public HandshakeConversionPattern<handshake::LoadOp> {
1436 public:
1437  using HandshakeConversionPattern<
1438  handshake::LoadOp>::HandshakeConversionPattern;
1439  void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1440  hw::HWModulePortAccessor &ports) const override {
1441  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1442  auto addrFromUser = unwrappedIO.inputs[0];
1443  auto dataFromMem = unwrappedIO.inputs[1];
1444  auto controlIn = unwrappedIO.inputs[2];
1445  auto dataToUser = unwrappedIO.outputs[0];
1446  auto addrToMem = unwrappedIO.outputs[1];
1447 
1448  addrToMem.data->setValue(addrFromUser.data);
1449  dataToUser.data->setValue(dataFromMem.data);
1450 
1451  // The valid/ready logic between user address/control to memoryAddr is
1452  // join logic.
1453  buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1454 
1455  // The valid/ready logic between memoryData and outputData is a direct
1456  // connection.
1457  dataToUser.valid->setValue(dataFromMem.valid);
1458  dataFromMem.ready->setValue(dataToUser.ready);
1459  };
1460 };
1461 
1462 class StoreConversionPattern
1463  : public HandshakeConversionPattern<handshake::StoreOp> {
1464 public:
1465  using HandshakeConversionPattern<
1466  handshake::StoreOp>::HandshakeConversionPattern;
1467  void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1468  hw::HWModulePortAccessor &ports) const override {
1469  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1470  auto addrFromUser = unwrappedIO.inputs[0];
1471  auto dataFromUser = unwrappedIO.inputs[1];
1472  auto controlIn = unwrappedIO.inputs[2];
1473  auto dataToMem = unwrappedIO.outputs[0];
1474  auto addrToMem = unwrappedIO.outputs[1];
1475 
1476  // Create a gate that will be asserted when all outputs are ready.
1477  auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1478 
1479  // Build the standard join logic from the inputs to the inputsValid and
1480  // outputsReady signals.
1481  HandshakeWire joinWire(bb, s.b.getNoneType());
1482  joinWire.ready->setValue(outputsReady);
1483  OutputHandshake joinOutput = joinWire.getAsOutput();
1484  buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1485 
1486  // Output address and data signals are connected directly.
1487  addrToMem.data->setValue(addrFromUser.data);
1488  dataToMem.data->setValue(dataFromUser.data);
1489 
1490  // Output valid signals are connected from the inputsValid wire.
1491  addrToMem.valid->setValue(*joinWire.valid);
1492  dataToMem.valid->setValue(*joinWire.valid);
1493  };
1494 };
1495 
1496 class MemoryConversionPattern
1497  : public HandshakeConversionPattern<handshake::MemoryOp> {
1498 public:
1499  using HandshakeConversionPattern<
1500  handshake::MemoryOp>::HandshakeConversionPattern;
1501  void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
1502  hw::HWModulePortAccessor &ports) const override {
1503  auto loc = op.getLoc();
1504 
1505  // Gather up the load and store ports.
1506  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1507  struct LoadPort {
1508  InputHandshake &addr;
1509  OutputHandshake &data;
1510  OutputHandshake &done;
1511  };
1512  struct StorePort {
1513  InputHandshake &addr;
1514  InputHandshake &data;
1515  OutputHandshake &done;
1516  };
1517  SmallVector<LoadPort, 4> loadPorts;
1518  SmallVector<StorePort, 4> storePorts;
1519 
1520  unsigned stCount = op.getStCount();
1521  unsigned ldCount = op.getLdCount();
1522  for (unsigned i = 0, e = ldCount; i != e; ++i) {
1523  LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1524  unwrappedIO.outputs[i],
1525  unwrappedIO.outputs[ldCount + stCount + i]};
1526  loadPorts.push_back(port);
1527  }
1528 
1529  for (unsigned i = 0, e = stCount; i != e; ++i) {
1530  StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1531  unwrappedIO.inputs[i * 2],
1532  unwrappedIO.outputs[ldCount + i]};
1533  storePorts.push_back(port);
1534  }
1535 
1536  // used to drive the data wire of the control-only channels.
1537  auto c0I0 = s.constant(0, 0);
1538 
1539  auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1540  auto hlmem = s.b.create<seq::HLMemOp>(
1541  loc, s.clk, s.rst, "_handshake_memory_" + std::to_string(op.getId()),
1542  op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1543 
1544  // Create load ports...
1545  for (auto &ld : loadPorts) {
1546  llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1547  auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1548  addresses, ld.addr.valid,
1549  /*latency=*/0);
1550  ld.data.data->setValue(readData);
1551  ld.done.data->setValue(c0I0);
1552  // Create control fork for the load address valid and ready signals.
1553  buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1554  }
1555 
1556  // Create store ports...
1557  for (auto &st : storePorts) {
1558  // Create a register to buffer the valid path by 1 cycle, to match the
1559  // write latency of 1.
1560  auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
1561  auto writeValidBuffer =
1562  s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1563  st.done.valid->setValue(writeValidBuffer);
1564  st.done.data->setValue(c0I0);
1565 
1566  // Create the logic for when both the buffered write valid signal and the
1567  // store complete ready signal are asserted.
1568  auto storeCompleted =
1569  s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");
1570 
1571  // Create a signal for when the write valid buffer is empty or the output
1572  // is ready.
1573  auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1574  auto emptyOrComplete =
1575  s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");
1576 
1577  // Connect the gate to both the store address ready and store data ready
1578  st.addr.ready->setValue(emptyOrComplete);
1579  st.data.ready->setValue(emptyOrComplete);
1580 
1581  // Create a wire for when both the store address and data are valid.
1582  auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");
1583 
1584  // Create a mux that drives the buffer input. If the emptyOrComplete
1585  // signal is asserted, the mux selects the writeValid signal. Otherwise,
1586  // it selects the buffer output, keeping the output registered until the
1587  // emptyOrComplete signal is asserted.
1588  writeValidBufferMuxBE.setValue(
1589  s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1590 
1591  // Instantiate the write port operation - truncate address width to memory
1592  // width.
1593  llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1594  s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1595  st.data.data, writeValid,
1596  /*latency=*/1);
1597  }
1598  }
1599 }; // namespace
1600 
1601 class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1602 public:
1603  using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1604  void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1605  hw::HWModulePortAccessor &ports) const override {
1606  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1607  // A sink is always ready to accept a new value.
1608  unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1609  };
1610 };
1611 
1612 class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1613 public:
1614  using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1615  void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1616  hw::HWModulePortAccessor &ports) const override {
1617  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1618  // A source always provides a new (i0-typed) value.
1619  unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1620  unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1621  };
1622 };
1623 
1624 class ConstantConversionPattern
1625  : public HandshakeConversionPattern<handshake::ConstantOp> {
1626 public:
1627  using HandshakeConversionPattern<
1628  handshake::ConstantOp>::HandshakeConversionPattern;
1629  void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1630  hw::HWModulePortAccessor &ports) const override {
1631  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1632  unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1633  unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1634  auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1635  unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1636  };
1637 };
1638 
1639 class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
1640 public:
1641  using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1642  void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
1643  hw::HWModulePortAccessor &ports) const override {
1644  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1645  auto input = unwrappedIO.inputs[0];
1646  auto output = unwrappedIO.outputs[0];
1647  InputHandshake lastStage;
1648  SmallVector<int64_t> initValues;
1649 
1650  // For now, always build seq buffers.
1651  if (op.getInitValues())
1652  initValues = op.getInitValueArray();
1653 
1654  lastStage =
1655  buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
1656  op.getNumSlots(), input, output, initValues);
1657 
1658  // Connect the last stage to the output handshake.
1659  output.data->setValue(lastStage.data);
1660  output.valid->setValue(lastStage.valid);
1661  lastStage.ready->setValue(output.ready);
1662  };
1663 
1664  struct SeqBufferStage {
1665  SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
1666  RTLBuilder &s, size_t index,
1667  std::optional<int64_t> initValue)
1668  : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1669 
1670  // Todo: Change when i0 support is added.
1671  c0s = createZeroDataConst(s, s.loc, dataType);
1672  currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
1673 
1674  auto hasInitValue = s.constant(1, initValue.has_value());
1675  auto validBE = bb.get(s.b.getI1Type());
1676  auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
1677  auto readyBE = bb.get(s.b.getI1Type());
1678 
1679  Value initValueCs = c0s;
1680  if (initValue.has_value())
1681  initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1682 
1683  // This could/should be revised but needs a larger rethinking to avoid
1684  // introducing new bugs.
1685  Value dataReg =
1686  buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1687  buildControlBufferLogic(validReg, readyBE, dataReg);
1688  }
1689 
1690  StringAttr getRegName(StringRef name) {
1691  return s.b.getStringAttr(name + std::to_string(index) + "_reg");
1692  }
1693 
1694  void buildControlBufferLogic(Value validReg, Backedge &readyBE,
1695  Value dataReg) {
1696  auto c0I1 = s.constant(1, 0);
1697  auto readyRegWire = bb.get(s.b.getI1Type());
1698  auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);
1699 
1700  // Create the logic to drive the current stage valid and potentially
1701  // data.
1702  currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1703  "controlValid" + std::to_string(index));
1704 
1705  // Create the logic to drive the current stage ready.
1706  auto notReadyReg = s.bNot(readyReg);
1707  readyBE.setValue(notReadyReg);
1708 
1709  auto succNotReady = s.bNot(*currentStage.ready);
1710  auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1711  auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1712  auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1713 
1714  // Create a mux for emptying the register when both are ready.
1715  auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1716  readyRegWire.setValue(resetSignal);
1717 
1718  // Add same logic for the data path if necessary.
1719  auto ctrlDataRegBE = bb.get(dataType);
1720  auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
1721  auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1722  currentStage.data = dataResult;
1723 
1724  auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1725  auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1726  ctrlDataRegBE.setValue(dataResetSignal);
1727  }
1728 
1729  Value buildDataBufferLogic(Value validReg, Value initValue,
1730  Backedge &validBE, Backedge &readyBE) {
1731  // Create a signal for when the valid register is empty or the successor
1732  // is ready to accept new token.
1733  auto notValidReg = s.bNot(validReg);
1734  auto emptyOrReady = s.bOr({notValidReg, readyBE});
1735  preStage.ready->setValue(emptyOrReady);
1736 
1737  // Create a mux that drives the register input. If the emptyOrReady
1738  // signal is asserted, the mux selects the predValid signal. Otherwise,
1739  // it selects the register output, keeping the output registered
1740  // unchanged.
1741  auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1742 
1743  // Now we can drive the valid register.
1744  validBE.setValue(validRegMux);
1745 
1746  // Create a mux that drives the date register.
1747  auto dataRegBE = bb.get(dataType);
1748  auto dataReg =
1749  s.reg(getRegName("data"),
1750  s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1751  dataRegBE.setValue(dataReg);
1752  return dataReg;
1753  }
1754 
1755  InputHandshake getOutput() { return currentStage; }
1756 
1757  Type dataType;
1758  InputHandshake &preStage;
1759  InputHandshake currentStage;
1760  RTLBuilder &s;
1761  BackedgeBuilder &bb;
1762  size_t index;
1763 
1764  // A zero-valued constant of equal type as the data type of this buffer.
1765  Value c0s;
1766  };
1767 
1768  InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
1769  Type dataType, unsigned size,
1770  InputHandshake &input,
1771  OutputHandshake &output,
1772  llvm::ArrayRef<int64_t> initValues) const {
1773  // Prime the buffer building logic with an initial stage, which just
1774  // wraps the input handshake.
1775  InputHandshake currentStage = input;
1776 
1777  for (unsigned i = 0; i < size; ++i) {
1778  bool isInitialized = i < initValues.size();
1779  auto initValue =
1780  isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1781  currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1782  .getOutput();
1783  }
1784 
1785  return currentStage;
1786  };
1787 };
1788 
1789 class IndexCastConversionPattern
1790  : public HandshakeConversionPattern<arith::IndexCastOp> {
1791 public:
1792  using HandshakeConversionPattern<
1793  arith::IndexCastOp>::HandshakeConversionPattern;
1794  void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1795  hw::HWModulePortAccessor &ports) const override {
1796  auto unwrappedIO = this->unwrapIO(s, bb, ports);
1797  unsigned sourceBits =
1798  toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1799  unsigned targetBits =
1800  toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1801  if (targetBits < sourceBits)
1802  buildTruncateLogic(s, unwrappedIO, targetBits);
1803  else
1804  buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1805  };
1806 };
1807 
1808 template <typename T>
1809 class ExtModuleConversionPattern : public OpConversionPattern<T> {
1810 public:
1811  ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1812  MLIRContext *context, OpBuilder &submoduleBuilder,
1813  HandshakeLoweringState &ls)
1814  : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1815  submoduleBuilder(submoduleBuilder), ls(ls) {}
1816  using OpAdaptor = typename T::Adaptor;
1817 
1818  LogicalResult
1819  matchAndRewrite(T op, OpAdaptor adaptor,
1820  ConversionPatternRewriter &rewriter) const override {
1821 
1822  hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1823  if (!implModule) {
1824  auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1825  implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1826  op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1827  portInfo);
1828  }
1829 
1830  llvm::SmallVector<Value> operands = adaptor.getOperands();
1831  addSequentialIOOperandsIfNeeded(op, operands);
1832  rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1833  op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1834  return success();
1835  }
1836 
1837 private:
1838  OpBuilder &submoduleBuilder;
1839  HandshakeLoweringState &ls;
1840 };
1841 
1842 class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1843 public:
1844  using OpConversionPattern::OpConversionPattern;
1845 
1846  LogicalResult
1847  matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1848  ConversionPatternRewriter &rewriter) const override {
1849  ModulePortInfo ports =
1850  getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1851 
1852  HWModuleLike hwModule;
1853  if (op.isExternal()) {
1854  hwModule = rewriter.create<hw::HWModuleExternOp>(
1855  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1856  } else {
1857  auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1858  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1859  auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1860  rewriter.inlineBlockBefore(&op.getBody().front(),
1861  hwModuleOp.getBodyBlock()->getTerminator(),
1862  args);
1863  hwModule = hwModuleOp;
1864  }
1865 
1866  // Was any predeclaration associated with this func? If so, replace uses
1867  // with the newly created module and erase the predeclaration.
1868  if (auto predecl =
1869  op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1870  auto *parentOp = op->getParentOp();
1871  auto *predeclModule =
1872  SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1873  if (predeclModule) {
1874  if (failed(SymbolTable::replaceAllSymbolUses(
1875  predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1876  return failure();
1877  rewriter.eraseOp(predeclModule);
1878  }
1879  }
1880 
1881  rewriter.eraseOp(op);
1882  return success();
1883  }
1884 };
1885 
1886 } // namespace
1887 
1888 //===----------------------------------------------------------------------===//
1889 // HW Top-module Related Functions
1890 //===----------------------------------------------------------------------===//
1891 
1892 static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1893  ConversionTarget &target,
1894  handshake::FuncOp op,
1895  OpBuilder &moduleBuilder) {
1896 
1897  std::map<std::string, unsigned> instanceNameCntr;
1898  NameUniquer instanceUniquer = [&](Operation *op) {
1899  std::string instName = getCallName(op);
1900  if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1901  // We use a special naming convention for operations which have a
1902  // 'handshake_id' attribute.
1903  instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1904  } else {
1905  // Fallback to just prefixing with an integer.
1906  instName += std::to_string(instanceNameCntr[instName]++);
1907  }
1908  return instName;
1909  };
1910 
1911  auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1912  instanceUniquer};
1913  RewritePatternSet patterns(op.getContext());
1914  patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1915  op.getContext());
1916  patterns.insert<JoinConversionPattern, ForkConversionPattern,
1917  SyncConversionPattern>(typeConverter, op.getContext(),
1918  moduleBuilder, ls);
1919 
1920  patterns.insert<
1921  // Comb operations.
1922  UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1923  UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1924  UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1925  UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1926  UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1927  UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1928  UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1929  UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1930  UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1931  UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1932  UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1933  UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1934  UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1935  UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1936  // HW operations.
1937  StructCreateConversionPattern,
1938  // Handshake operations.
1939  ConditionalBranchConversionPattern, MuxConversionPattern,
1940  PackConversionPattern, UnpackConversionPattern,
1941  ComparisonConversionPattern, BufferConversionPattern,
1942  SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1943  MergeConversionPattern, ControlMergeConversionPattern,
1944  LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1945  InstanceConversionPattern,
1946  // Arith operations.
1947  ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1948  ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1949  TruncateConversionPattern, IndexCastConversionPattern>(
1950  typeConverter, op.getContext(), moduleBuilder, ls);
1951 
1952  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1953  return op->emitOpError() << "error during conversion";
1954  return success();
1955 }
1956 
1957 namespace {
1958 class HandshakeToHWPass
1959  : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1960 public:
1961  void runOnOperation() override {
1962  mlir::ModuleOp mod = getOperation();
1963 
1964  // Lowering to HW requires that every value is used exactly once. Check
1965  // whether this precondition is met, and if not, exit.
1966  for (auto f : mod.getOps<handshake::FuncOp>()) {
1967  if (failed(verifyAllValuesHasOneUse(f))) {
1968  f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1969  "are used exactly once. Remember to run the "
1970  "fork/sink materialization pass before HW lowering.";
1971  signalPassFailure();
1972  return;
1973  }
1974  }
1975 
1976  // Resolve the instance graph to get a top-level module.
1977  std::string topLevel;
1979  SmallVector<std::string> sortedFuncs;
1980  if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1981  signalPassFailure();
1982  return;
1983  }
1984 
1985  ESITypeConverter typeConverter;
1986  ConversionTarget target(getContext());
1987  // All top-level logic of a handshake module will be the interconnectivity
1988  // between instantiated modules.
1989  target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1990  hw::InstanceOp>();
1991  target
1992  .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1993 
1994  // Convert the handshake.func operations in post-order wrt. the instance
1995  // graph. This ensures that any referenced submodules (through
1996  // handshake.instance) has already been lowered, and their HW module
1997  // equivalents are available.
1998  OpBuilder submoduleBuilder(mod.getContext());
1999  submoduleBuilder.setInsertionPointToStart(mod.getBody());
2000  for (auto &funcName : llvm::reverse(sortedFuncs)) {
2001  auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
2002  assert(funcOp && "handshake.func not found in module!");
2003  if (failed(
2004  convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2005  signalPassFailure();
2006  return;
2007  }
2008  }
2009 
2010  // Second stage: Convert any handshake.extmemory operations and the
2011  // top-level I/O associated with these.
2012  for (auto hwModule : mod.getOps<hw::HWModuleOp>())
2013  if (failed(convertExtMemoryOps(hwModule)))
2014  return signalPassFailure();
2015 
2016  // Run conversions which need see everything.
2017  HWSymbolCache symbolCache;
2018  symbolCache.addDefinitions(mod);
2019  symbolCache.freeze();
2020  RewritePatternSet patterns(mod.getContext());
2021  patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2022  symbolCache);
2023  if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
2024  mod->emitOpError() << "error during conversion";
2025  signalPassFailure();
2026  }
2027  }
2028 };
2029 } // end anonymous namespace
2030 
2031 std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
2032  return std::make_unique<HandshakeToHWPass>();
2033 }
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition: CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition: CalyxOps.cpp:121
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:540
static Type tupleToStruct(TupleType tuple)
Definition: DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition: DCToHW.cpp:45
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
Definition: FIRRTLOps.cpp:1013
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition: OM.cpp:113
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition: HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: HWSymCache.h:56
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition: HWSymCache.h:75
Channels are the basic communication primitives.
Definition: Types.h:63
const Type * getInner() const
Definition: Types.h:66
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition: seq.py:157
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition: CombOps.cpp:25
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()