CIRCT 23.0.0git
Loading...
Searching...
No Matches
HandshakeToHW.cpp
Go to the documentation of this file.
1//===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7//===----------------------------------------------------------------------===//
8//
9// This is the main Handshake to HW Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
34#include <optional>
35
36namespace circt {
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
39} // namespace circt
40
41using namespace mlir;
42using namespace circt;
43using namespace circt::handshake;
44using namespace circt::hw;
45
46using NameUniquer = std::function<std::string(Operation *)>;
47
48namespace {
49
50static Type tupleToStruct(TypeRange types) {
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
52}
53
54// Shared state used by various functions; captured in a struct to reduce the
55// number of arguments that we have to pass around.
56struct HandshakeLoweringState {
// Module whose symbol table is searched for already-created submodules.
57 ModuleOp parentModule;
// Callback producing the instance name used when an op is replaced by an
// hw.instance.
58 NameUniquer nameUniquer;
59};
60
61// A type converter is needed to perform the in-flight materialization of "raw"
62// (non-ESI channel) types to their ESI channel correspondents. This comes into
63// effect when backedges exist in the input IR.
64class ESITypeConverter : public TypeConverter {
65public:
66 ESITypeConverter() {
67 addConversion([](Type type) -> Type { return esiWrapper(type); });
68
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
73 return Value();
74 return UnrealizedConversionCastOp::create(builder, loc, resultType,
75 inputs[0])
76 ->getResult(0);
77 });
78
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
83 return Value();
84 return UnrealizedConversionCastOp::create(builder, loc, resultType,
85 inputs[0])
86 ->getResult(0);
87 });
88 }
89};
90
91} // namespace
92
93/// Returns a submodule name resulting from an operation, without discriminating
94/// type information.
95static std::string getBareSubModuleName(Operation *oldOp) {
96 // The dialect name is separated from the operation name by '.', which is not
97 // valid in SystemVerilog module names. In case this name is used in
98 // SystemVerilog output, replace '.' with '_'.
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
101 return subModuleName;
102}
103
104static std::string getCallName(Operation *op) {
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
106 return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
107}
108
109/// Extracts the type of the data-carrying type of opType. If opType is an ESI
110/// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
111/// assume that opType itself is the data-carrying type.
112static Type getOperandDataType(Value op) {
113 auto opType = op.getType();
114 if (auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
116 return opType;
117}
118
119/// Returns a set of types which may uniquely identify the provided op. Return
120/// value is <inputTypes, outputTypes>.
121using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
123 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
124 .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
125 return DiscriminatingTypes{{},
126 {memOp.getMemRefType().getElementType()}};
127 })
128 .Default([&](auto) {
129 // By default, all in- and output types which is not a control type
130 // (NoneType) are discriminating types.
131 SmallVector<Type> inTypes, outTypes;
132 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
134 llvm::transform(op->getResults(), std::back_inserter(outTypes),
136 return DiscriminatingTypes{inTypes, outTypes};
137 });
138}
139
140/// Get type name. Currently we only support integer or index types.
141/// The emitted type aligns with the getFIRRTLType() method. Thus all integers
142/// other than signed integers will be emitted as unsigned.
143// NOLINTNEXTLINE(misc-no-recursion)
144static std::string getTypeName(Location loc, Type type) {
145 std::string typeName;
146 // Builtin types
147 if (type.isIntOrIndex()) {
148 if (auto indexType = dyn_cast<IndexType>(type))
149 typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
150 else if (type.isSignedInteger())
151 typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
152 else
153 typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
154 } else if (auto tupleType = dyn_cast<TupleType>(type)) {
155 typeName += "_tuple";
156 for (auto elementType : tupleType.getTypes())
157 typeName += getTypeName(loc, elementType);
158 } else if (auto structType = dyn_cast<hw::StructType>(type)) {
159 typeName += "_struct";
160 for (auto element : structType.getElements())
161 typeName += "_" + element.name.str() + getTypeName(loc, element.type);
162 } else if (isa<NoneType>(type)) {
163 typeName += "_none";
164 } else
165 emitError(loc) << "unsupported data type '" << type << "'";
166
167 return typeName;
168}
169
170/// Construct a name for creating HW sub-module.
171static std::string getSubModuleName(Operation *oldOp) {
172 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
173 return instanceOp.getModule().str();
174
175 std::string subModuleName = getBareSubModuleName(oldOp);
176
177 // Add value of the constant operation.
178 if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
179 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
180 auto intType = intAttr.getType();
181
182 if (intType.isSignedInteger())
183 subModuleName += "_c" + std::to_string(intAttr.getSInt());
184 else if (intType.isUnsignedInteger())
185 subModuleName += "_c" + std::to_string(intAttr.getUInt());
186 else
187 subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
188 } else
189 oldOp->emitError("unsupported constant type");
190 }
191
192 // Add discriminating in- and output types.
193 auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
194 if (!inTypes.empty())
195 subModuleName += "_in";
196 for (auto inType : inTypes)
197 subModuleName += getTypeName(oldOp->getLoc(), inType);
198
199 if (!outTypes.empty())
200 subModuleName += "_out";
201 for (auto outType : outTypes)
202 subModuleName += getTypeName(oldOp->getLoc(), outType);
203
204 // Add memory ID.
205 if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
206 subModuleName += "_id" + std::to_string(memOp.getId());
207
208 // Add compare kind.
209 if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
210 subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
211
212 // Add buffer information.
213 if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
214 subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
215 if (bufferOp.isSequential())
216 subModuleName += "_seq";
217 else
218 subModuleName += "_fifo";
219
220 if (auto initValues = bufferOp.getInitValues()) {
221 subModuleName += "_init";
222 for (const Attribute e : *initValues) {
223 assert(isa<IntegerAttr>(e));
224 subModuleName +=
225 "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
226 }
227 }
228 }
229
230 // Add control information.
231 if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
232 ctrlInterface && ctrlInterface.isControl()) {
233 // Add some additional discriminating info for non-typed operations.
234 subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
235 std::to_string(oldOp->getNumResults()) + "outs";
236 subModuleName += "_ctrl";
237 } else {
238 assert(
239 (!inTypes.empty() || !outTypes.empty()) &&
240 "Insufficient discriminating type info generated for the operation!");
241 }
242
243 return subModuleName;
244}
245
246//===----------------------------------------------------------------------===//
247// HW Sub-module Related Functions
248//===----------------------------------------------------------------------===//
249
250/// Check whether a submodule with the same name has been created elsewhere in
251/// the top level module. Return the matched module operation if true, otherwise
252/// return nullptr.
253static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
254 StringRef modName) {
255 if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
256 return mod;
257 if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
258 return mod;
259 return {};
260}
261
262static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
263 Operation *oldOp) {
264 HWModuleLike targetModule;
265 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
266 targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
267 else
268 targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
269
270 if (isa<handshake::InstanceOp>(oldOp))
271 assert(targetModule &&
272 "handshake.instance target modules should always have been lowered "
273 "before the modules that reference them!");
274 return targetModule;
275}
276
277/// Returns a vector of PortInfo's which defines the HW interface of the
278/// to-be-converted op.
279static ModulePortInfo getPortInfoForOp(Operation *op) {
280 return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
281}
282
283static llvm::SmallVector<hw::detail::FieldInfo>
284portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
285 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
286 for (auto port : portInfo)
287 fieldInfo.push_back({port.name, port.type});
288
289 return fieldInfo;
290}
291
292// Convert any handshake.extmemory operations and the top-level I/O
293// associated with these.
294static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
295 auto *ctx = mod.getContext();
296
297 // Gather memref ports to be converted.
298 llvm::DenseMap<unsigned, Value> memrefPorts;
299 for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
300 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
301 if (channel && isa<MemRefType>(channel.getInner()))
302 memrefPorts[i] = arg;
303 }
304
305 if (memrefPorts.empty())
306 return success(); // nothing to do.
307
308 OpBuilder b(mod);
309
310 auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
311 ArrayRef<hw::PortInfo> info,
312 hw::ModulePort::Direction direction) {
313 auto type = hw::StructType::get(ctx, portToFieldInfo(info));
314 auto portInfo =
315 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
316 return portInfo;
317 };
318
319 for (auto [i, arg] : memrefPorts) {
320 // Insert ports into the module
321 auto memName = mod.getArgName(i);
322
323 // Get the attached extmemory external module.
324 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
325 auto extmemMod =
326 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
327 extmemInstance, extmemInstance.getModuleNameAttr()));
328
329 ModulePortInfo portInfo(extmemMod.getPortList());
330
331 // The extmemory external module's interface is a direct wrapping of the
332 // original handshake.extmemory operation in- and output types. Remove the
333 // first input argument (the !esi.channel<memref> op) since that is what
334 // we're replacing with a materialized interface.
335 portInfo.eraseInput(0);
336
337 // Add memory input - this is the output of the extmemory op.
338 SmallVector<PortInfo> outputs(portInfo.getOutputs());
339 auto inPortInfo =
340 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
341 hw::ModulePort::Direction::Input);
342 mod.insertPorts({{i, inPortInfo}}, {});
343 auto newInPort = mod.getArgumentForInput(i);
344 // Replace the extmemory submodule outputs with the newly created inputs.
345 b.setInsertionPointToStart(mod.getBodyBlock());
346 auto newInPortExploded = hw::StructExplodeOp::create(
347 b, arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
348 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
349
350 // Add memory output - this is the inputs of the extmemory op (without the
351 // first argument);
352 unsigned outArgI = mod.getNumOutputPorts();
353 SmallVector<PortInfo> inputs(portInfo.getInputs());
354 auto outPortInfo =
355 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
356 inputs, hw::ModulePort::Direction::Output);
357
358 auto memOutputArgs = extmemInstance.getOperands().drop_front();
359 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
360 auto memOutputStruct = hw::StructCreateOp::create(
361 b, arg.getLoc(), outPortInfo.type, memOutputArgs);
362 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
363
364 // Erase the extmemory submodule instace since the i/o has now been
365 // plumbed.
366 extmemMod.erase();
367 extmemInstance.erase();
368
369 // Erase the original memref argument of the top-level i/o now that it's use
370 // has been removed.
371 mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
372 /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
373 }
374
375 return success();
376}
377
378namespace {
379
380// Input handshakes contain a resolved valid and (optional) data signal, and
381// a to-be-assigned ready signal.
382struct InputHandshake {
// Valid signal, available as a resolved value.
383 Value valid;
// Ready signal; a backedge that is assigned a value later.
384 std::shared_ptr<Backedge> ready;
// Data signal, available as a resolved value.
385 Value data;
386};
387
388// Output handshakes contain a resolved ready, and to-be-assigned valid and
389// (optional) data signals.
390struct OutputHandshake {
// Valid signal; a backedge that is assigned a value later.
391 std::shared_ptr<Backedge> valid;
// Ready signal, available as a resolved value.
392 Value ready;
// Data signal; a backedge that is assigned a value later.
393 std::shared_ptr<Backedge> data;
394};
395
396/// A helper struct that acts like a wire. Can be used to interact with the
397/// RTLBuilder when multiple built components should be connected.
398struct HandshakeWire {
399 HandshakeWire(BackedgeBuilder &bb, Type dataType) {
400 MLIRContext *ctx = dataType.getContext();
401 auto i1Type = IntegerType::get(ctx, 1);
402 valid = std::make_shared<Backedge>(bb.get(i1Type));
403 ready = std::make_shared<Backedge>(bb.get(i1Type));
404 data = std::make_shared<Backedge>(bb.get(dataType));
405 }
406
407 // Functions that allow to treat a wire like an input or output port.
408 // **Careful**: Such a port will not be updated when backedges are resolved.
409 InputHandshake getAsInput() { return {*valid, ready, *data}; }
410 OutputHandshake getAsOutput() { return {valid, *ready, data}; }
411
412 std::shared_ptr<Backedge> valid;
413 std::shared_ptr<Backedge> ready;
414 std::shared_ptr<Backedge> data;
415};
416
417template <typename T, typename TInner>
418llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
419 llvm::function_ref<T(TInner &)> extractor) {
420 llvm::SmallVector<T> result;
421 llvm::transform(container, std::back_inserter(result), extractor);
422 return result;
423}
424struct UnwrappedIO {
425 llvm::SmallVector<InputHandshake> inputs;
426 llvm::SmallVector<OutputHandshake> outputs;
427
428 llvm::SmallVector<Value> getInputValids() {
429 return extractValues<Value, InputHandshake>(
430 inputs, [](auto &hs) { return hs.valid; });
431 }
432 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
433 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
434 inputs, [](auto &hs) { return hs.ready; });
435 }
436 llvm::SmallVector<Value> getInputDatas() {
437 return extractValues<Value, InputHandshake>(
438 inputs, [](auto &hs) { return hs.data; });
439 }
440 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
441 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
442 outputs, [](auto &hs) { return hs.valid; });
443 }
444 llvm::SmallVector<Value> getOutputReadys() {
445 return extractValues<Value, OutputHandshake>(
446 outputs, [](auto &hs) { return hs.ready; });
447 }
448 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
449 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
450 outputs, [](auto &hs) { return hs.data; });
451 }
452};
453
454// A class containing a bunch of syntactic sugar to reduce builder function
455// verbosity.
456// @todo: should be moved to support.
457struct RTLBuilder {
458 RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
459 Value clk = Value(), Value rst = Value())
460 : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
461
462 Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
463 // Cannot use zero-width APInt's in DenseMap's, see
464 // https://github.com/llvm/llvm-project/issues/58013
465 bool isZeroWidth = apv.getBitWidth() == 0;
466 if (!isZeroWidth) {
467 auto it = constants.find(apv);
468 if (it != constants.end())
469 return it->second;
470 }
471
472 auto cval = hw::ConstantOp::create(b, loc, apv);
473 if (!isZeroWidth)
474 constants[apv] = cval;
475 return cval;
476 }
477
478 Value constant(unsigned width, int64_t value,
479 std::optional<StringRef> name = {}) {
480 return constant(
481 APInt(width, value, /*isSigned=*/false, /*implicitTrunc=*/true));
482 }
483 std::pair<Value, Value> wrap(Value data, Value valid,
484 std::optional<StringRef> name = {}) {
485 auto wrapOp = esi::WrapValidReadyOp::create(b, loc, data, valid);
486 return {wrapOp.getResult(0), wrapOp.getResult(1)};
487 }
488 std::pair<Value, Value> unwrap(Value channel, Value ready,
489 std::optional<StringRef> name = {}) {
490 auto unwrapOp = esi::UnwrapValidReadyOp::create(b, loc, channel, ready);
491 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
492 }
493
494 // Various syntactic sugar functions.
495 Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
496 Value rst = Value()) {
497 Value resolvedClk = clk ? clk : this->clk;
498 Value resolvedRst = rst ? rst : this->rst;
499 assert(resolvedClk &&
500 "No global clock provided to this RTLBuilder - a clock "
501 "signal must be provided to the reg(...) function.");
502 assert(resolvedRst &&
503 "No global reset provided to this RTLBuilder - a reset "
504 "signal must be provided to the reg(...) function.");
505
506 return seq::CompRegOp::create(b, loc, in, resolvedClk, resolvedRst,
507 rstValue, name);
508 }
509
510 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
511 std::optional<StringRef> name = {}) {
512 return comb::ICmpOp::create(b, loc, predicate, lhs, rhs);
513 }
514
515 Value buildNamedOp(llvm::function_ref<Value()> f,
516 std::optional<StringRef> name) {
517 Value v = f();
518 StringAttr nameAttr;
519 Operation *op = v.getDefiningOp();
520 if (name.has_value()) {
521 op->setAttr("sv.namehint", b.getStringAttr(*name));
522 nameAttr = b.getStringAttr(*name);
523 }
524 return v;
525 }
526
527 // Bitwise 'and'.
528 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
529 return buildNamedOp(
530 [&]() { return comb::AndOp::create(b, loc, values, false); }, name);
531 }
532
533 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
534 return buildNamedOp(
535 [&]() { return comb::OrOp::create(b, loc, values, false); }, name);
536 }
537
538 // Bitwise 'not'.
539 Value bNot(Value value, std::optional<StringRef> name = {}) {
540 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
541 std::string inferedName;
542 if (!name) {
543 // Try to create a name from the input value.
544 if (auto valueName =
545 value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
546 inferedName = ("not_" + valueName.getValue()).str();
547 name = inferedName;
548 }
549 }
550
551 return buildNamedOp(
552 [&]() { return comb::XorOp::create(b, loc, value, allOnes); }, name);
553
554 return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
555 }
556
557 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
558 return buildNamedOp(
559 [&]() { return comb::ShlOp::create(b, loc, value, shift); }, name);
560 }
561
562 Value concat(ValueRange values, std::optional<StringRef> name = {}) {
563 return buildNamedOp(
564 [&]() { return comb::ConcatOp::create(b, loc, values); }, name);
565 }
566
567 // Packs a list of values into a hw.struct.
568 Value pack(ValueRange values, Type structType = Type(),
569 std::optional<StringRef> name = {}) {
570 if (!structType)
571 structType = tupleToStruct(values.getTypes());
572 return buildNamedOp(
573 [&]() {
574 return hw::StructCreateOp::create(b, loc, structType, values);
575 },
576 name);
577 }
578
579 // Unpacks a hw.struct into a list of values.
580 ValueRange unpack(Value value) {
581 auto structType = cast<hw::StructType>(value.getType());
582 llvm::SmallVector<Type> innerTypes;
583 structType.getInnerTypes(innerTypes);
584 return hw::StructExplodeOp::create(b, loc, innerTypes, value).getResults();
585 }
586
587 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
588 llvm::SmallVector<Value> bits;
589 for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
590 bits.push_back(comb::ExtractOp::create(b, loc, v, i, /*bitWidth=*/1));
591 return bits;
592 }
593
594 // OR-reduction of the bits in 'v'.
595 Value rOr(Value v, std::optional<StringRef> name = {}) {
596 return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
597 }
598
599 // Extract bits v[hi:lo] (inclusive).
600 Value extract(Value v, unsigned lo, unsigned hi,
601 std::optional<StringRef> name = {}) {
602 unsigned width = hi - lo + 1;
603 return buildNamedOp(
604 [&]() { return comb::ExtractOp::create(b, loc, v, lo, width); }, name);
605 }
606
607 // Truncates 'value' to its lower 'width' bits.
608 Value truncate(Value value, unsigned width,
609 std::optional<StringRef> name = {}) {
610 return extract(value, 0, width - 1, name);
611 }
612
613 Value zext(Value value, unsigned outWidth,
614 std::optional<StringRef> name = {}) {
615 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
616 assert(inWidth <= outWidth && "zext: input width must be <- output width.");
617 if (inWidth == outWidth)
618 return value;
619 auto c0 = constant(outWidth - inWidth, 0);
620 return concat({c0, value}, name);
621 }
622
623 Value sext(Value value, unsigned outWidth,
624 std::optional<StringRef> name = {}) {
625 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
626 }
627
628 // Extracts a single bit v[bit].
629 Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
630 return extract(v, index, index, name);
631 }
632
633 // Creates a hw.array of the given values.
634 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
635 return buildNamedOp(
636 [&]() { return hw::ArrayCreateOp::create(b, loc, values); }, name);
637 }
638
639 // Extract the 'index'th value from the input array.
640 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
641 return buildNamedOp(
642 [&]() { return hw::ArrayGetOp::create(b, loc, array, index); }, name);
643 }
644
645 // Muxes a range of values.
646 // The select signal is expected to be a decimal value which selects starting
647 // from the lowest index of value.
648 Value mux(Value index, ValueRange values,
649 std::optional<StringRef> name = {}) {
650 if (values.size() == 2)
651 return comb::MuxOp::create(b, loc, index, values[1], values[0]);
652
653 return arrayGet(arrayCreate(values), index, name);
654 }
655
656 // Muxes a range of values. The select signal is expected to be a 1-hot
657 // encoded value.
658 Value ohMux(Value index, ValueRange inputs) {
659 // Confirm the select input can be a one-hot encoding for the inputs.
660 unsigned numInputs = inputs.size();
661 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
662 "one-hot select can't mux inputs");
663
664 // Start the mux tree with zero value.
665 // Todo: clean up when handshake supports i0.
666 auto dataType = inputs[0].getType();
667 unsigned width =
668 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
669 Value muxValue = constant(width, 0);
670
671 // Iteratively chain together muxes from the high bit to the low bit.
672 for (size_t i = numInputs - 1; i != 0; --i) {
673 Value input = inputs[i];
674 Value selectBit = bit(index, i);
675 muxValue = mux(selectBit, {muxValue, input});
676 }
677
678 return muxValue;
679 }
680
682 OpBuilder &b;
683 Location loc;
684 Value clk, rst;
685 DenseMap<APInt, Value> constants;
686};
687
688/// Creates a Value that has an assigned zero value. For structs, this
689/// corresponds to assigning zero to each element recursively.
690static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
691 return TypeSwitch<Type, Value>(type)
692 .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
693 .Case<IntType, IntegerType>([&](auto type) {
694 return s.constant(type.getIntOrFloatBitWidth(), 0);
695 })
696 .Case<hw::StructType>([&](auto structType) {
697 SmallVector<Value> zeroValues;
698 for (auto field : structType.getElements())
699 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
700 return hw::StructCreateOp::create(s.b, loc, structType, zeroValues);
701 })
702 .Default([&](Type) -> Value {
703 emitError(loc) << "unsupported type for zero value: " << type;
704 assert(false);
705 return {};
706 });
707}
708
709static void
710addSequentialIOOperandsIfNeeded(Operation *op,
711 llvm::SmallVectorImpl<Value> &operands) {
712 if (op->hasTrait<mlir::OpTrait::HasClock>()) {
713 // Parent should at this point be a hw.module and have clock and reset
714 // ports.
715 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
716 operands.push_back(
717 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
718 operands.push_back(
719 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
720 }
721}
722
723template <typename T>
724class HandshakeConversionPattern : public OpConversionPattern<T> {
725public:
// Stores references to the builder used for creating submodules and to the
// shared lowering state; the type converter and context are forwarded to the
// OpConversionPattern base.
726 HandshakeConversionPattern(ESITypeConverter &typeConverter,
727 MLIRContext *context, OpBuilder &submoduleBuilder,
728 HandshakeLoweringState &ls)
729 : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
730 submoduleBuilder(submoduleBuilder), ls(ls) {}
731
732 using OpAdaptor = typename T::Adaptor;
733
734 LogicalResult
735 matchAndRewrite(T op, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter) const override {
737
738 // Check if a submodule has already been created for the op. If so,
739 // instantiate the submodule. Else, run the pattern-defined module
740 // builder.
741 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
742 if (!implModule) {
743 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
744
745 submoduleBuilder.setInsertionPoint(op->getParentOp());
746 implModule = hw::HWModuleOp::create(
747 submoduleBuilder, op.getLoc(),
748 submoduleBuilder.getStringAttr(getSubModuleName(op)), portInfo,
749 [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
750 // if 'op' has clock trait, extract these and provide them to the
751 // RTL builder.
752 Value clk, rst;
753 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
754 clk = ports.getInput("clock");
755 rst = ports.getInput("reset");
756 }
757
758 BackedgeBuilder bb(b, op.getLoc());
759 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
760 this->buildModule(op, bb, s, ports);
761 });
762 }
763
764 // Instantiate the submodule.
765 llvm::SmallVector<Value> operands = adaptor.getOperands();
766 addSequentialIOOperandsIfNeeded(op, operands);
767 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
768 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
769 return success();
770 }
771
// Pattern-specific hook that builds the body of the submodule implementing
// `op`. Called by matchAndRewrite when no matching submodule exists yet.
772 virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
773 hw::HWModulePortAccessor &ports) const = 0;
774
775 // Syntactic sugar functions.
776 // Unwraps an ESI-interfaced module into its constituent handshake signals.
777 // Backedges are created for the to-be-resolved signals, and output ports
778 // are assigned to their wrapped counterparts.
779 UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
780 hw::HWModulePortAccessor &ports) const {
781 UnwrappedIO unwrapped;
782 for (auto port : ports.getInputs()) {
783 if (!isa<esi::ChannelType>(port.getType()))
784 continue;
785 InputHandshake hs;
786 auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
787 auto [data, valid] = s.unwrap(port, *ready);
788 hs.data = data;
789 hs.valid = valid;
790 hs.ready = ready;
791 unwrapped.inputs.push_back(hs);
792 }
793 for (auto &outputInfo : ports.getPortList().getOutputs()) {
794 esi::ChannelType channelType =
795 dyn_cast<esi::ChannelType>(outputInfo.type);
796 if (!channelType)
797 continue;
798 OutputHandshake hs;
799 Type innerType = channelType.getInner();
800 auto data = std::make_shared<Backedge>(bb.get(innerType));
801 auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
802 auto [dataCh, ready] = s.wrap(*data, *valid);
803 hs.data = data;
804 hs.valid = valid;
805 hs.ready = ready;
806 ports.setOutput(outputInfo.name, dataCh);
807 unwrapped.outputs.push_back(hs);
808 }
809 return unwrapped;
810 }
811
812 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
813 OutputHandshake &output, Value cond) const {
814 auto validAndReady = s.bAnd({output.ready, cond});
815 for (auto &input : inputs)
816 input.ready->setValue(validAndReady);
817 }
818
819 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820 OutputHandshake &output) const {
821 llvm::SmallVector<Value> valids;
822 for (auto &input : inputs)
823 valids.push_back(input.valid);
824 Value allValid = s.bAnd(valids);
825 output.valid->setValue(allValid);
826 setAllReadyWithCond(s, inputs, output, allValid);
827 }
828
829 // Builds mux logic for the given inputs and outputs.
830 // Note: it is assumed that the caller has removed the 'select' signal from
831 // the 'unwrapped' inputs and provide it as a separate argument.
832 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
833 InputHandshake &select) const {
834 // ============================= Control logic =============================
835 size_t numInputs = unwrapped.inputs.size();
836 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
837 Value truncatedSelect =
838 select.data.getType().getIntOrFloatBitWidth() > selectWidth
839 ? s.truncate(select.data, selectWidth)
840 : select.data;
841
842 // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
843 auto selectZext = s.zext(truncatedSelect, numInputs);
844 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
845 auto &res = unwrapped.outputs[0];
846
847 // Mux input valid signals.
848 auto selectedInputValid =
849 s.mux(truncatedSelect, unwrapped.getInputValids());
850 // Result is valid when the selected input and the select input is valid.
851 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
852 res.valid->setValue(selAndInputValid);
853 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
854
855 // Select is ready when result is valid and ready (result transacting).
856 select.ready->setValue(resValidAndReady);
857
858 // Assign each input ready signal if it is currently selected.
859 for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
860 // Extract the selection bit for this input.
861 auto isSelected = s.bit(select1h, inIdx);
862
863 // '&' that with the result valid and ready, and assign to the input
864 // ready signal.
865 auto activeAndResultValidAndReady =
866 s.bAnd({isSelected, resValidAndReady});
867 in.ready->setValue(activeAndResultValidAndReady);
868 }
869
870 // ============================== Data logic ===============================
871 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
872 }
873
874 // Builds fork logic between the single input and multiple outputs' control
875 // networks. Caller is expected to handle data separately.
876 void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
877 ArrayRef<OutputHandshake> outputs) const {
878 auto c0I1 = s.constant(1, 0);
879 llvm::SmallVector<Value> doneWires;
880 for (auto [i, output] : llvm::enumerate(outputs)) {
881 auto doneBE = bb.get(s.b.getI1Type());
882 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
883 auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
884 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
885 output.valid->setValue(outValid);
886 auto validReady = s.bAnd({output.ready, outValid});
887 auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
888 doneBE.setValue(done);
889 doneWires.push_back(done);
890 }
891 input.ready->setValue(s.bAnd(doneWires, "allDone"));
892 }
893
894 // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
895 // function which takes the set of unwrapped data inputs, and returns a
896 // value which should be assigned to the output data value.
897 void buildUnitRateJoinLogic(
898 RTLBuilder &s, UnwrappedIO &unwrappedIO,
899 llvm::function_ref<Value(ValueRange)> unitBuilder) const {
900 assert(unwrappedIO.outputs.size() == 1 &&
901 "Expected exactly one output for unit-rate join actor");
902 // Control logic.
903 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
904
905 // Data logic.
906 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
907 unwrappedIO.outputs[0].data->setValue(unitRes);
908 }
909
910 void buildUnitRateForkLogic(
911 RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
912 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
913 assert(unwrappedIO.inputs.size() == 1 &&
914 "Expected exactly one input for unit-rate fork actor");
915 // Control logic.
916 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
917
918 // Data logic.
919 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
920 assert(unitResults.size() == unwrappedIO.outputs.size() &&
921 "Expected unit builder to return one result per output");
922 for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
923 outport.data->setValue(res);
924 }
925
926 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
927 bool signExtend) const {
928 size_t outWidth =
929 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
930 .getIntOrFloatBitWidth();
931 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
932 if (signExtend)
933 return s.sext(inputs[0], outWidth);
934 return s.zext(inputs[0], outWidth);
935 });
936 }
937
938 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
939 unsigned targetWidth) const {
940 size_t outWidth =
941 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
942 .getIntOrFloatBitWidth();
943 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
944 return s.truncate(inputs[0], outWidth);
945 });
946 }
947
948 /// Return the number of bits needed to index the given number of values.
949 static size_t getNumIndexBits(uint64_t numValues) {
950 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
951 }
952
953 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
954 Value defaultValue,
955 DenseMap<size_t, Value> &indexMapping) const {
956 auto numInputs = inputs.size();
957 auto priorityArb = defaultValue;
958
959 for (size_t i = numInputs; i > 0; --i) {
960 size_t inputIndex = i - 1;
961 size_t oneHotIndex = size_t{1} << inputIndex;
962 auto constIndex = s.constant(numInputs, oneHotIndex);
963 indexMapping[inputIndex] = constIndex;
964 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
965 }
966 return priorityArb;
967 }
968
969private:
970 OpBuilder &submoduleBuilder;
971 HandshakeLoweringState &ls;
972};
973
974class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
975public:
976 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
977 void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
978 hw::HWModulePortAccessor &ports) const override {
979 auto unwrapped = unwrapIO(s, bb, ports);
980 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
981 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
982 });
983 }
984};
985
986class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
987public:
988 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
989 void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
990 hw::HWModulePortAccessor &ports) const override {
991 auto unwrappedIO = unwrapIO(s, bb, ports);
992 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
993 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
994 };
995};
996
997class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
998public:
999 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1000 void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1001 hw::HWModulePortAccessor &ports) const override {
1002 auto unwrappedIO = unwrapIO(s, bb, ports);
1003
1004 // A helper wire that will be used to connect the two built logics
1005 HandshakeWire wire(bb, s.b.getNoneType());
1006
1007 OutputHandshake output = wire.getAsOutput();
1008 buildJoinLogic(s, unwrappedIO.inputs, output);
1009
1010 InputHandshake input = wire.getAsInput();
1011
1012 // The state-keeping fork logic is required here, as the circuit isn't
1013 // allowed to wait for all the consumers to be ready. Connecting the ready
1014 // signals of the outputs to their corresponding valid signals leads to
1015 // combinatorial cycles. The paper which introduced compositional dataflow
1016 // circuits explicitly mentions this limitation:
1017 // http://arcade.cs.columbia.edu/df-memocode17.pdf
1018 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1019
1020 // Directly connect the data wires, only the control signals need to be
1021 // combined.
1022 for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1023 out.data->setValue(in.data);
1024 };
1025};
1026
1027class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1028public:
1029 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1030 void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1031 hw::HWModulePortAccessor &ports) const override {
1032 auto unwrappedIO = unwrapIO(s, bb, ports);
1033
1034 // Extract select signal from the unwrapped IO.
1035 auto select = unwrappedIO.inputs[0];
1036 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1037 buildMuxLogic(s, unwrappedIO, select);
1038 };
1039};
1040
1041class InstanceConversionPattern
1042 : public HandshakeConversionPattern<handshake::InstanceOp> {
1043public:
1044 using HandshakeConversionPattern<
1045 handshake::InstanceOp>::HandshakeConversionPattern;
1046 void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1047 hw::HWModulePortAccessor &ports) const override {
1048 assert(false &&
1049 "If we indeed perform conversion in post-order, this "
1050 "should never be called. The base HandshakeConversionPattern logic "
1051 "will instantiate the external module.");
1052 }
1053};
1054
1055class ESIInstanceConversionPattern
1056 : public OpConversionPattern<handshake::ESIInstanceOp> {
1057public:
1058 ESIInstanceConversionPattern(MLIRContext *context,
1059 const HWSymbolCache &symCache)
1060 : OpConversionPattern(context), symCache(symCache) {}
1061
1062 LogicalResult
1063 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter) const override {
1065 // The operand signature of this op is very similar to the lowered
1066 // `handshake.func`s (especially since handshake uses ESI channels
1067 // internally). Whereas ESIInstance ops have 'clk' and 'rst' at the
1068 // beginning, lowered `handshake.func`s have them at the end. So we've just
1069 // got to re-arrange them.
1070 SmallVector<Value> operands;
1071 for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1072 i < e; ++i)
1073 operands.push_back(adaptor.getOperands()[i]);
1074 operands.push_back(adaptor.getClk());
1075 operands.push_back(adaptor.getRst());
1076 // Locate the lowered module so the instance builder can get all the
1077 // metadata.
1078 Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
1079 // And replace the op with an instance of the target module.
1080 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1081 op.getInstNameAttr(), operands);
1082 return success();
1083 }
1084
1085private:
1086 const HWSymbolCache &symCache;
1087};
1088
1089class ReturnConversionPattern
1090 : public OpConversionPattern<handshake::ReturnOp> {
1091public:
1092 using OpConversionPattern::OpConversionPattern;
1093 LogicalResult
1094 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter) const override {
1096 // Locate existing output op, Append operands to output op, and move to
1097 // the end of the block.
1098 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1099 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1100 outputOp->setOperands(adaptor.getOperands());
1101 outputOp->moveAfter(&parent.getBodyBlock()->back());
1102 rewriter.eraseOp(op);
1103 return success();
1104 }
1105};
1106
1107// Converts an arbitrary operation into a unit rate actor. A unit rate actor
1108// will transact once all inputs are valid and its output is ready.
1109template <typename TIn, typename TOut = TIn>
1110class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1111public:
1112 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1113 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1114 hw::HWModulePortAccessor &ports) const override {
1115 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1116 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1117 // Create TOut - it is assumed that TOut trivially
1118 // constructs from the input data signals of TIn.
1119 // To disambiguate ambiguous builders with default arguments (e.g.,
1120 // twoState UnitAttr), specify attribute array explicitly.
1121 return TOut::create(s.b, op.getLoc(), inputs,
1122 /* attributes */ ArrayRef<NamedAttribute>{});
1123 });
1124 };
1125};
1126
1127class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1128public:
1129 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1130 void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1131 hw::HWModulePortAccessor &ports) const override {
1132 auto unwrappedIO = unwrapIO(s, bb, ports);
1133 buildUnitRateJoinLogic(s, unwrappedIO,
1134 [&](ValueRange inputs) { return s.pack(inputs); });
1135 };
1136};
1137
1138class StructCreateConversionPattern
1139 : public HandshakeConversionPattern<hw::StructCreateOp> {
1140public:
1141 using HandshakeConversionPattern<
1142 hw::StructCreateOp>::HandshakeConversionPattern;
1143 void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1144 hw::HWModulePortAccessor &ports) const override {
1145 auto unwrappedIO = unwrapIO(s, bb, ports);
1146 auto structType = op.getResult().getType();
1147 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1148 return s.pack(inputs, structType);
1149 });
1150 };
1151};
1152
1153class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1154public:
1155 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1156 void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1157 hw::HWModulePortAccessor &ports) const override {
1158 auto unwrappedIO = unwrapIO(s, bb, ports);
1159 buildUnitRateForkLogic(s, bb, unwrappedIO,
1160 [&](Value input) { return s.unpack(input); });
1161 };
1162};
1163
1164class ConditionalBranchConversionPattern
1165 : public HandshakeConversionPattern<ConditionalBranchOp> {
1166public:
1167 using HandshakeConversionPattern<
1168 ConditionalBranchOp>::HandshakeConversionPattern;
1169 void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1170 hw::HWModulePortAccessor &ports) const override {
1171 auto unwrappedIO = unwrapIO(s, bb, ports);
1172 auto cond = unwrappedIO.inputs[0];
1173 auto arg = unwrappedIO.inputs[1];
1174 auto trueRes = unwrappedIO.outputs[0];
1175 auto falseRes = unwrappedIO.outputs[1];
1176
1177 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1178
1179 // Connect valid signal of both results.
1180 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1181 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1182
1183 // Connecte data signals of both results.
1184 trueRes.data->setValue(arg.data);
1185 falseRes.data->setValue(arg.data);
1186
1187 // Connect ready signal of input and condition.
1188 auto selectedResultReady =
1189 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1190 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1191 arg.ready->setValue(condArgReady);
1192 cond.ready->setValue(condArgReady);
1193 };
1194};
1195
1196template <typename TIn, bool signExtend>
1197class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1198public:
1199 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1200 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1201 hw::HWModulePortAccessor &ports) const override {
1202 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1203 this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1204 };
1205};
1206
1207class ComparisonConversionPattern
1208 : public HandshakeConversionPattern<arith::CmpIOp> {
1209public:
1210 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1211 void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1212 hw::HWModulePortAccessor &ports) const override {
1213 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1214 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1215 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1216 return comb::ICmpOp::create(s.b, op.getLoc(), predicate, inputs[0],
1217 inputs[1]);
1218 });
1219 };
1220
1221 switch (op.getPredicate()) {
1222 case arith::CmpIPredicate::eq:
1223 return buildCompareLogic(comb::ICmpPredicate::eq);
1224 case arith::CmpIPredicate::ne:
1225 return buildCompareLogic(comb::ICmpPredicate::ne);
1226 case arith::CmpIPredicate::slt:
1227 return buildCompareLogic(comb::ICmpPredicate::slt);
1228 case arith::CmpIPredicate::ult:
1229 return buildCompareLogic(comb::ICmpPredicate::ult);
1230 case arith::CmpIPredicate::sle:
1231 return buildCompareLogic(comb::ICmpPredicate::sle);
1232 case arith::CmpIPredicate::ule:
1233 return buildCompareLogic(comb::ICmpPredicate::ule);
1234 case arith::CmpIPredicate::sgt:
1235 return buildCompareLogic(comb::ICmpPredicate::sgt);
1236 case arith::CmpIPredicate::ugt:
1237 return buildCompareLogic(comb::ICmpPredicate::ugt);
1238 case arith::CmpIPredicate::sge:
1239 return buildCompareLogic(comb::ICmpPredicate::sge);
1240 case arith::CmpIPredicate::uge:
1241 return buildCompareLogic(comb::ICmpPredicate::uge);
1242 }
1243 assert(false && "invalid CmpIOp");
1244 };
1245};
1246
1247class TruncateConversionPattern
1248 : public HandshakeConversionPattern<arith::TruncIOp> {
1249public:
1250 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1251 void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1252 hw::HWModulePortAccessor &ports) const override {
1253 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1254 unsigned targetBits =
1255 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1256 buildTruncateLogic(s, unwrappedIO, targetBits);
1257 };
1258};
1259
1260class ControlMergeConversionPattern
1261 : public HandshakeConversionPattern<ControlMergeOp> {
1262public:
1263 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1264 void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1265 hw::HWModulePortAccessor &ports) const override {
1266 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1267 auto resData = unwrappedIO.outputs[0];
1268 auto resIndex = unwrappedIO.outputs[1];
1269
1270 // Define some common types and values that will be used.
1271 unsigned numInputs = unwrappedIO.inputs.size();
1272 auto indexType = s.b.getIntegerType(numInputs);
1273 Value noWinner = s.constant(numInputs, 0);
1274 Value c0I1 = s.constant(1, 0);
1275
1276 // Declare register for storing arbitration winner.
1277 auto won = bb.get(indexType);
1278 Value wonReg = s.reg("won_reg", won, noWinner);
1279
1280 // Declare wire for arbitration winner.
1281 auto win = bb.get(indexType);
1282
1283 // Declare wire for whether the circuit just fired and emitted both
1284 // outputs.
1285 auto fired = bb.get(s.b.getI1Type());
1286
1287 // Declare registers for storing if each output has been emitted.
1288 auto resultEmitted = bb.get(s.b.getI1Type());
1289 Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1290 auto indexEmitted = bb.get(s.b.getI1Type());
1291 Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1292
1293 // Declare wires for if each output is done.
1294 auto resultDone = bb.get(s.b.getI1Type());
1295 auto indexDone = bb.get(s.b.getI1Type());
1296
1297 // Create predicates to assert if the win wire or won register hold a
1298 // valid index.
1299 auto hasWinnerCondition = s.rOr({win});
1300 auto hadWinnerCondition = s.rOr({wonReg});
1301
1302 // Create an arbiter based on a simple priority-encoding scheme to assign
1303 // an index to the win wire. If the won register is set, just use that. In
1304 // the case that won is not set and no input is valid, set a sentinel
1305 // value to indicate no winner was chosen. The constant values are
1306 // remembered in a map so they can be re-used later to assign the arg
1307 // ready outputs.
1308 DenseMap<size_t, Value> argIndexValues;
1309 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1310 noWinner, argIndexValues);
1311 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1312 win.setValue(priorityArb);
1313
1314 // Create the logic to assign the result and index outputs. The result
1315 // valid output will always be assigned, and if isControl is not set, the
1316 // result data output will also be assigned. The index valid and data
1317 // outputs will always be assigned. The win wire from the arbiter is used
1318 // to index into a tree of muxes to select the chosen input's signal(s),
1319 // and is fed directly to the index output. Both the result and index
1320 // valid outputs are gated on the win wire being set to something other
1321 // than the sentinel value.
1322 auto resultNotEmitted = s.bNot(resultEmittedReg);
1323 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1324 resData.valid->setValue(resultValid);
1325 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1326
1327 auto indexNotEmitted = s.bNot(indexEmittedReg);
1328 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1329 resIndex.valid->setValue(indexValid);
1330
1331 // Use the one-hot win wire to select the index to output in the index
1332 // data.
1333 SmallVector<Value, 8> indexOutputs;
1334 for (size_t i = 0; i < numInputs; ++i)
1335 indexOutputs.push_back(s.constant(64, i));
1336
1337 auto indexOutput = s.ohMux(win, indexOutputs);
1338 resIndex.data->setValue(indexOutput);
1339
1340 // Create the logic to set the won register. If the fired wire is
1341 // asserted, we have finished this round and can and reset the register to
1342 // the sentinel value that indicates there is no winner. Otherwise, we
1343 // need to hold the value of the win register until we can fire.
1344 won.setValue(s.mux(fired, {win, noWinner}));
1345
1346 // Create the logic to set the done wires for the result and index. For
1347 // both outputs, the done wire is asserted when the output is valid and
1348 // ready, or the emitted register for that output is set.
1349 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1350 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1351
1352 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1353 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1354
1355 // Create the logic to set the fired wire. It is asserted when both result
1356 // and index are done.
1357 fired.setValue(s.bAnd({resultDone, indexDone}));
1358
1359 // Create the logic to assign the emitted registers. If the fired wire is
1360 // asserted, we have finished this round and can reset the registers to 0.
1361 // Otherwise, we need to hold the values of the done registers until we
1362 // can fire.
1363 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1364 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1365
1366 // Create the logic to assign the arg ready outputs. The logic is
1367 // identical for each arg. If the fired wire is asserted, and the win wire
1368 // holds an arg's index, that arg is ready.
1369 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1370 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1371 auto &indexValue = argIndexValues[i];
1372 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1373 }
1374 };
1375};
1376
1377class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1378public:
1379 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1380 void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1381 hw::HWModulePortAccessor &ports) const override {
1382 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1383 auto resData = unwrappedIO.outputs[0];
1384
1385 // Define some common types and values that will be used.
1386 unsigned numInputs = unwrappedIO.inputs.size();
1387 auto indexType = s.b.getIntegerType(numInputs);
1388 Value noWinner = s.constant(numInputs, 0);
1389
1390 // Declare wire for arbitration winner.
1391 auto win = bb.get(indexType);
1392
1393 // Create predicates to assert if the win wire holds a valid index.
1394 auto hasWinnerCondition = s.rOr(win);
1395
1396 // Create an arbiter based on a simple priority-encoding scheme to assign an
1397 // index to the win wire. In the case that no input is valid, set a sentinel
1398 // value to indicate no winner was chosen. The constant values are
1399 // remembered in a map so they can be re-used later to assign the arg ready
1400 // outputs.
1401 DenseMap<size_t, Value> argIndexValues;
1402 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1403 noWinner, argIndexValues);
1404 win.setValue(priorityArb);
1405
1406 // Create the logic to assign the result outputs. The result valid and data
1407 // outputs will always be assigned. The win wire from the arbiter is used to
1408 // index into a tree of muxes to select the chosen input's signal(s). The
1409 // result outputs are gated on the win wire being non-zero.
1410
1411 resData.valid->setValue(hasWinnerCondition);
1412 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1413
1414 // Create the logic to set the done wires for the result. The done wire is
1415 // asserted when the output is valid and ready, or the emitted register is
1416 // set.
1417 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1418
1419 // Create the logic to assign the arg ready outputs. The logic is
1420 // identical for each arg. If the fired wire is asserted, and the win wire
1421 // holds an arg's index, that arg is ready.
1422 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1423 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1424 auto &indexValue = argIndexValues[i];
1425 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1426 }
1427 };
1428};
1429
1430class LoadConversionPattern
1431 : public HandshakeConversionPattern<handshake::LoadOp> {
1432public:
1433 using HandshakeConversionPattern<
1434 handshake::LoadOp>::HandshakeConversionPattern;
1435 void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1436 hw::HWModulePortAccessor &ports) const override {
1437 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1438 auto addrFromUser = unwrappedIO.inputs[0];
1439 auto dataFromMem = unwrappedIO.inputs[1];
1440 auto controlIn = unwrappedIO.inputs[2];
1441 auto dataToUser = unwrappedIO.outputs[0];
1442 auto addrToMem = unwrappedIO.outputs[1];
1443
1444 addrToMem.data->setValue(addrFromUser.data);
1445 dataToUser.data->setValue(dataFromMem.data);
1446
1447 // The valid/ready logic between user address/control to memoryAddr is
1448 // join logic.
1449 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1450
1451 // The valid/ready logic between memoryData and outputData is a direct
1452 // connection.
1453 dataToUser.valid->setValue(dataFromMem.valid);
1454 dataFromMem.ready->setValue(dataToUser.ready);
1455 };
1456};
1457
1458class StoreConversionPattern
1459 : public HandshakeConversionPattern<handshake::StoreOp> {
1460public:
1461 using HandshakeConversionPattern<
1462 handshake::StoreOp>::HandshakeConversionPattern;
1463 void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1464 hw::HWModulePortAccessor &ports) const override {
1465 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1466 auto addrFromUser = unwrappedIO.inputs[0];
1467 auto dataFromUser = unwrappedIO.inputs[1];
1468 auto controlIn = unwrappedIO.inputs[2];
1469 auto dataToMem = unwrappedIO.outputs[0];
1470 auto addrToMem = unwrappedIO.outputs[1];
1471
1472 // Create a gate that will be asserted when all outputs are ready.
1473 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1474
1475 // Build the standard join logic from the inputs to the inputsValid and
1476 // outputsReady signals.
1477 HandshakeWire joinWire(bb, s.b.getNoneType());
1478 joinWire.ready->setValue(outputsReady);
1479 OutputHandshake joinOutput = joinWire.getAsOutput();
1480 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1481
1482 // Output address and data signals are connected directly.
1483 addrToMem.data->setValue(addrFromUser.data);
1484 dataToMem.data->setValue(dataFromUser.data);
1485
1486 // Output valid signals are connected from the inputsValid wire.
1487 addrToMem.valid->setValue(*joinWire.valid);
1488 dataToMem.valid->setValue(*joinWire.valid);
1489 };
1490};
1491
1492class MemoryConversionPattern
1493 : public HandshakeConversionPattern<handshake::MemoryOp> {
1494public:
1495 using HandshakeConversionPattern<
1496 handshake::MemoryOp>::HandshakeConversionPattern;
1497 void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
1498 hw::HWModulePortAccessor &ports) const override {
1499 auto loc = op.getLoc();
1500
1501 // Gather up the load and store ports.
1502 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1503 struct LoadPort {
1504 InputHandshake &addr;
1505 OutputHandshake &data;
1506 OutputHandshake &done;
1507 };
1508 struct StorePort {
1509 InputHandshake &addr;
1510 InputHandshake &data;
1511 OutputHandshake &done;
1512 };
1513 SmallVector<LoadPort, 4> loadPorts;
1514 SmallVector<StorePort, 4> storePorts;
1515
1516 unsigned stCount = op.getStCount();
1517 unsigned ldCount = op.getLdCount();
1518 for (unsigned i = 0, e = ldCount; i != e; ++i) {
1519 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1520 unwrappedIO.outputs[i],
1521 unwrappedIO.outputs[ldCount + stCount + i]};
1522 loadPorts.push_back(port);
1523 }
1524
1525 for (unsigned i = 0, e = stCount; i != e; ++i) {
1526 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1527 unwrappedIO.inputs[i * 2],
1528 unwrappedIO.outputs[ldCount + i]};
1529 storePorts.push_back(port);
1530 }
1531
1532 // used to drive the data wire of the control-only channels.
1533 auto c0I0 = s.constant(0, 0);
1534
1535 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1536 auto hlmem = seq::HLMemOp::create(
1537 s.b, loc, s.clk, s.rst,
1538 "_handshake_memory_" + std::to_string(op.getId()),
1539 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1540
1541 // Create load ports...
1542 for (auto &ld : loadPorts) {
1543 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1544 auto readData = seq::ReadPortOp::create(s.b, loc, hlmem.getHandle(),
1545 addresses, ld.addr.valid,
1546 /*latency=*/0);
1547 ld.data.data->setValue(readData);
1548 ld.done.data->setValue(c0I0);
1549 // Create control fork for the load address valid and ready signals.
1550 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1551 }
1552
1553 // Create store ports...
1554 for (auto &st : storePorts) {
1555 // Create a register to buffer the valid path by 1 cycle, to match the
1556 // write latency of 1.
1557 auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
1558 auto writeValidBuffer =
1559 s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1560 st.done.valid->setValue(writeValidBuffer);
1561 st.done.data->setValue(c0I0);
1562
1563 // Create the logic for when both the buffered write valid signal and the
1564 // store complete ready signal are asserted.
1565 auto storeCompleted =
1566 s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");
1567
1568 // Create a signal for when the write valid buffer is empty or the output
1569 // is ready.
1570 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1571 auto emptyOrComplete =
1572 s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");
1573
1574 // Connect the gate to both the store address ready and store data ready
1575 st.addr.ready->setValue(emptyOrComplete);
1576 st.data.ready->setValue(emptyOrComplete);
1577
1578 // Create a wire for when both the store address and data are valid.
1579 auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");
1580
1581 // Create a mux that drives the buffer input. If the emptyOrComplete
1582 // signal is asserted, the mux selects the writeValid signal. Otherwise,
1583 // it selects the buffer output, keeping the output registered until the
1584 // emptyOrComplete signal is asserted.
1585 writeValidBufferMuxBE.setValue(
1586 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1587
1588 // Instantiate the write port operation - truncate address width to memory
1589 // width.
1590 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1591 seq::WritePortOp::create(s.b, loc, hlmem.getHandle(), addresses,
1592 st.data.data, writeValid,
1593 /*latency=*/1);
1594 }
1595 }
1596};
1597
1598class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1599public:
1600 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1601 void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1602 hw::HWModulePortAccessor &ports) const override {
1603 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1604 // A sink is always ready to accept a new value.
1605 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1606 };
1607};
1608
1609class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1610public:
1611 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1612 void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1613 hw::HWModulePortAccessor &ports) const override {
1614 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1615 // A source always provides a new (i0-typed) value.
1616 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1617 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1618 };
1619};
1620
1621class ConstantConversionPattern
1622 : public HandshakeConversionPattern<handshake::ConstantOp> {
1623public:
1624 using HandshakeConversionPattern<
1625 handshake::ConstantOp>::HandshakeConversionPattern;
1626 void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1627 hw::HWModulePortAccessor &ports) const override {
1628 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1629 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1630 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1631 auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1632 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1633 };
1634};
1635
1636class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
1637public:
1638 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1639 void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
1640 hw::HWModulePortAccessor &ports) const override {
1641 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1642 auto input = unwrappedIO.inputs[0];
1643 auto output = unwrappedIO.outputs[0];
1644 InputHandshake lastStage;
1645 SmallVector<int64_t> initValues;
1646
1647 // For now, always build seq buffers.
1648 if (op.getInitValues())
1649 initValues = op.getInitValueArray();
1650
1651 lastStage =
1652 buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
1653 op.getNumSlots(), input, output, initValues);
1654
1655 // Connect the last stage to the output handshake.
1656 output.data->setValue(lastStage.data);
1657 output.valid->setValue(lastStage.valid);
1658 lastStage.ready->setValue(output.ready);
1659 };
1660
1661 struct SeqBufferStage {
1662 SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
1663 RTLBuilder &s, size_t index,
1664 std::optional<int64_t> initValue)
1665 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1666
1667 // Todo: Change when i0 support is added.
1668 c0s = createZeroDataConst(s, s.loc, dataType);
1669 currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
1670
1671 auto hasInitValue = s.constant(1, initValue.has_value());
1672 auto validBE = bb.get(s.b.getI1Type());
1673 auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
1674 auto readyBE = bb.get(s.b.getI1Type());
1675
1676 Value initValueCs = c0s;
1677 if (initValue.has_value())
1678 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1679
1680 // This could/should be revised but needs a larger rethinking to avoid
1681 // introducing new bugs.
1682 Value dataReg =
1683 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1684 buildControlBufferLogic(validReg, readyBE, dataReg);
1685 }
1686
1687 StringAttr getRegName(StringRef name) {
1688 return s.b.getStringAttr(name + std::to_string(index) + "_reg");
1689 }
1690
1691 void buildControlBufferLogic(Value validReg, Backedge &readyBE,
1692 Value dataReg) {
1693 auto c0I1 = s.constant(1, 0);
1694 auto readyRegWire = bb.get(s.b.getI1Type());
1695 auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);
1696
1697 // Create the logic to drive the current stage valid and potentially
1698 // data.
1699 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1700 "controlValid" + std::to_string(index));
1701
1702 // Create the logic to drive the current stage ready.
1703 auto notReadyReg = s.bNot(readyReg);
1704 readyBE.setValue(notReadyReg);
1705
1706 auto succNotReady = s.bNot(*currentStage.ready);
1707 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1708 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1709 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1710
1711 // Create a mux for emptying the register when both are ready.
1712 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1713 readyRegWire.setValue(resetSignal);
1714
1715 // Add same logic for the data path if necessary.
1716 auto ctrlDataRegBE = bb.get(dataType);
1717 auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
1718 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1719 currentStage.data = dataResult;
1720
1721 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1722 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1723 ctrlDataRegBE.setValue(dataResetSignal);
1724 }
1725
1726 Value buildDataBufferLogic(Value validReg, Value initValue,
1727 Backedge &validBE, Backedge &readyBE) {
1728 // Create a signal for when the valid register is empty or the successor
1729 // is ready to accept new token.
1730 auto notValidReg = s.bNot(validReg);
1731 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1732 preStage.ready->setValue(emptyOrReady);
1733
1734 // Create a mux that drives the register input. If the emptyOrReady
1735 // signal is asserted, the mux selects the predValid signal. Otherwise,
1736 // it selects the register output, keeping the output registered
1737 // unchanged.
1738 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1739
1740 // Now we can drive the valid register.
1741 validBE.setValue(validRegMux);
1742
1743 // Create a mux that drives the date register.
1744 auto dataRegBE = bb.get(dataType);
1745 auto dataReg =
1746 s.reg(getRegName("data"),
1747 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1748 dataRegBE.setValue(dataReg);
1749 return dataReg;
1750 }
1751
1752 InputHandshake getOutput() { return currentStage; }
1753
1754 Type dataType;
1755 InputHandshake &preStage;
1756 InputHandshake currentStage;
1757 RTLBuilder &s;
1758 BackedgeBuilder &bb;
1759 size_t index;
1760
1761 // A zero-valued constant of equal type as the data type of this buffer.
1762 Value c0s;
1763 };
1764
1765 InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
1766 Type dataType, unsigned size,
1767 InputHandshake &input,
1768 OutputHandshake &output,
1769 llvm::ArrayRef<int64_t> initValues) const {
1770 // Prime the buffer building logic with an initial stage, which just
1771 // wraps the input handshake.
1772 InputHandshake currentStage = input;
1773
1774 for (unsigned i = 0; i < size; ++i) {
1775 bool isInitialized = i < initValues.size();
1776 auto initValue =
1777 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1778 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1779 .getOutput();
1780 }
1781
1782 return currentStage;
1783 };
1784};
1785
1786class IndexCastConversionPattern
1787 : public HandshakeConversionPattern<arith::IndexCastOp> {
1788public:
1789 using HandshakeConversionPattern<
1790 arith::IndexCastOp>::HandshakeConversionPattern;
1791 void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1792 hw::HWModulePortAccessor &ports) const override {
1793 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1794 unsigned sourceBits =
1795 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1796 unsigned targetBits =
1797 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1798 if (targetBits < sourceBits)
1799 buildTruncateLogic(s, unwrappedIO, targetBits);
1800 else
1801 buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1802 };
1803};
1804
1805template <typename T>
1806class ExtModuleConversionPattern : public OpConversionPattern<T> {
1807public:
1808 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1809 MLIRContext *context, OpBuilder &submoduleBuilder,
1810 HandshakeLoweringState &ls)
1811 : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1812 submoduleBuilder(submoduleBuilder), ls(ls) {}
1813 using OpAdaptor = typename T::Adaptor;
1814
1815 LogicalResult
1816 matchAndRewrite(T op, OpAdaptor adaptor,
1817 ConversionPatternRewriter &rewriter) const override {
1818
1819 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1820 if (!implModule) {
1821 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1822 implModule = hw::HWModuleExternOp::create(
1823 submoduleBuilder, op.getLoc(),
1824 submoduleBuilder.getStringAttr(getSubModuleName(op)), portInfo);
1825 }
1826
1827 llvm::SmallVector<Value> operands = adaptor.getOperands();
1828 addSequentialIOOperandsIfNeeded(op, operands);
1829 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1830 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1831 return success();
1832 }
1833
1834private:
1835 OpBuilder &submoduleBuilder;
1836 HandshakeLoweringState &ls;
1837};
1838
1839class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1840public:
1841 using OpConversionPattern::OpConversionPattern;
1842
1843 LogicalResult
1844 matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1845 ConversionPatternRewriter &rewriter) const override {
1846 ModulePortInfo ports =
1847 getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1848
1849 HWModuleLike hwModule;
1850 if (op.isExternal()) {
1851 hwModule = hw::HWModuleExternOp::create(
1852 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1853 } else {
1854 auto hwModuleOp = hw::HWModuleOp::create(
1855 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1856 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1857 rewriter.inlineBlockBefore(&op.getBody().front(),
1858 hwModuleOp.getBodyBlock()->getTerminator(),
1859 args);
1860 hwModule = hwModuleOp;
1861 }
1862
1863 // Was any predeclaration associated with this func? If so, replace uses
1864 // with the newly created module and erase the predeclaration.
1865 if (auto predecl =
1866 op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1867 auto *parentOp = op->getParentOp();
1868 auto *predeclModule =
1869 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1870 if (predeclModule) {
1871 if (failed(SymbolTable::replaceAllSymbolUses(
1872 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1873 return failure();
1874 rewriter.eraseOp(predeclModule);
1875 }
1876 }
1877
1878 rewriter.eraseOp(op);
1879 return success();
1880 }
1881};
1882
1883} // namespace
1884
1885//===----------------------------------------------------------------------===//
1886// HW Top-module Related Functions
1887//===----------------------------------------------------------------------===//
1888
1889static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1890 ConversionTarget &target,
1892 OpBuilder &moduleBuilder) {
1893
1894 std::map<std::string, unsigned> instanceNameCntr;
1895 NameUniquer instanceUniquer = [&](Operation *op) {
1896 std::string instName = getCallName(op);
1897 if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1898 // We use a special naming convention for operations which have a
1899 // 'handshake_id' attribute.
1900 instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1901 } else {
1902 // Fallback to just prefixing with an integer.
1903 instName += std::to_string(instanceNameCntr[instName]++);
1904 }
1905 return instName;
1906 };
1907
1908 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1909 instanceUniquer};
1910 RewritePatternSet patterns(op.getContext());
1911 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1912 op.getContext());
1913 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1914 SyncConversionPattern>(typeConverter, op.getContext(),
1915 moduleBuilder, ls);
1916
1917 patterns.insert<
1918 // Comb operations.
1919 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1920 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1921 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1922 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1923 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1924 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1925 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1926 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1927 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1928 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1929 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1930 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1931 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1932 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1933 // HW operations.
1934 StructCreateConversionPattern,
1935 // Handshake operations.
1936 ConditionalBranchConversionPattern, MuxConversionPattern,
1937 PackConversionPattern, UnpackConversionPattern,
1938 ComparisonConversionPattern, BufferConversionPattern,
1939 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1940 MergeConversionPattern, ControlMergeConversionPattern,
1941 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1942 InstanceConversionPattern,
1943 // Arith operations.
1944 ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1945 ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1946 TruncateConversionPattern, IndexCastConversionPattern>(
1947 typeConverter, op.getContext(), moduleBuilder, ls);
1948
1949 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1950 return op->emitOpError() << "error during conversion";
1951 return success();
1952}
1953
1954namespace {
1955class HandshakeToHWPass
1956 : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1957public:
1958 void runOnOperation() override {
1959 mlir::ModuleOp mod = getOperation();
1960
1961 // Lowering to HW requires that every value is used exactly once. Check
1962 // whether this precondition is met, and if not, exit.
1963 for (auto f : mod.getOps<handshake::FuncOp>()) {
1964 if (failed(verifyAllValuesHasOneUse(f))) {
1965 f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1966 "are used exactly once. Remember to run the "
1967 "fork/sink materialization pass before HW lowering.";
1968 signalPassFailure();
1969 return;
1970 }
1971 }
1972
1973 // Resolve the instance graph to get a top-level module.
1974 std::string topLevel;
1976 SmallVector<std::string> sortedFuncs;
1977 if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1978 signalPassFailure();
1979 return;
1980 }
1981
1982 ESITypeConverter typeConverter;
1983 ConversionTarget target(getContext());
1984 // All top-level logic of a handshake module will be the interconnectivity
1985 // between instantiated modules.
1986 target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1987 hw::InstanceOp>();
1988 target
1989 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1990
1991 // Convert the handshake.func operations in post-order wrt. the instance
1992 // graph. This ensures that any referenced submodules (through
1993 // handshake.instance) has already been lowered, and their HW module
1994 // equivalents are available.
1995 OpBuilder submoduleBuilder(mod.getContext());
1996 submoduleBuilder.setInsertionPointToStart(mod.getBody());
1997 for (auto &funcName : llvm::reverse(sortedFuncs)) {
1998 auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
1999 assert(funcOp && "handshake.func not found in module!");
2000 if (failed(
2001 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2002 signalPassFailure();
2003 return;
2004 }
2005 }
2006
2007 // Second stage: Convert any handshake.extmemory operations and the
2008 // top-level I/O associated with these.
2009 for (auto hwModule : mod.getOps<hw::HWModuleOp>())
2010 if (failed(convertExtMemoryOps(hwModule)))
2011 return signalPassFailure();
2012
2013 // Run conversions which need see everything.
2014 HWSymbolCache symbolCache;
2015 symbolCache.addDefinitions(mod);
2016 symbolCache.freeze();
2017 RewritePatternSet patterns(mod.getContext());
2018 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2019 symbolCache);
2020 if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
2021 mod->emitOpError() << "error during conversion";
2022 signalPassFailure();
2023 }
2024 }
2025};
2026} // end anonymous namespace
2027
2028std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
2029 return std::make_unique<HandshakeToHWPass>();
2030}
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition CalyxOps.cpp:121
static Type tupleToStruct(TupleType tuple)
Definition DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition DCToHW.cpp:45
static std::unique_ptr< Context > context
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1428
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition OM.cpp:111
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
const ModulePortInfo & getPortList() const
Definition HWOps.h:111
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition HWSymCache.h:75
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
create(low_bit, result_type, input=None)
Definition comb.py:187
Channels are the basic communication primitives.
Definition Types.h:120
const Type * getInner() const
Definition Types.h:124
create(elements, Type result_type=None)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
create(data_type, value)
Definition hw.py:433
create(elements, Type result_type=None)
Definition hw.py:544
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition seq.py:157
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:420
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
Definition hw.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.