CIRCT 20.0.0git
Loading...
Searching...
No Matches
HandshakeToHW.cpp
Go to the documentation of this file.
1//===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7//===----------------------------------------------------------------------===//
8//
9// This is the main Handshake to HW Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
34#include <optional>
35
36namespace circt {
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
39} // namespace circt
40
41using namespace mlir;
42using namespace circt;
43using namespace circt::handshake;
44using namespace circt::hw;
45
46using NameUniquer = std::function<std::string(Operation *)>;
47
48namespace {
49
50static Type tupleToStruct(TypeRange types) {
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
52}
53
// Shared state used by various functions; captured in a struct to reduce the
// number of arguments that we have to pass around.
struct HandshakeLoweringState {
  // Module that is searched for already-created HW submodules.
  ModuleOp parentModule;
  // Callback producing the name given to the hw.instance that replaces a
  // converted operation.
  NameUniquer nameUniquer;
};
60
61// A type converter is needed to perform the in-flight materialization of "raw"
62// (non-ESI channel) types to their ESI channel correspondents. This comes into
63// effect when backedges exist in the input IR.
64class ESITypeConverter : public TypeConverter {
65public:
66 ESITypeConverter() {
67 addConversion([](Type type) -> Type { return esiWrapper(type); });
68
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
73 return Value();
74 return builder
75 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
76 ->getResult(0);
77 });
78
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
83 return Value();
84 return builder
85 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
86 ->getResult(0);
87 });
88 }
89};
90
91} // namespace
92
93/// Returns a submodule name resulting from an operation, without discriminating
94/// type information.
95static std::string getBareSubModuleName(Operation *oldOp) {
96 // The dialect name is separated from the operation name by '.', which is not
97 // valid in SystemVerilog module names. In case this name is used in
98 // SystemVerilog output, replace '.' with '_'.
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
101 return subModuleName;
102}
103
104static std::string getCallName(Operation *op) {
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
106 return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
107}
108
109/// Extracts the type of the data-carrying type of opType. If opType is an ESI
110/// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
111/// assume that opType itself is the data-carrying type.
112static Type getOperandDataType(Value op) {
113 auto opType = op.getType();
114 if (auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
116 return opType;
117}
118
119/// Filters NoneType's from the input.
120static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
121 SmallVector<Type> filterRes;
122 llvm::copy_if(input, std::back_inserter(filterRes),
123 [](Type type) { return !isa<NoneType>(type); });
124 return filterRes;
125}
126
127/// Returns a set of types which may uniquely identify the provided op. Return
128/// value is <inputTypes, outputTypes>.
129using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
131 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
132 .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
133 return DiscriminatingTypes{{},
134 {memOp.getMemRefType().getElementType()}};
135 })
136 .Default([&](auto) {
137 // By default, all in- and output types which is not a control type
138 // (NoneType) are discriminating types.
139 std::vector<Type> inTypes, outTypes;
140 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142 llvm::transform(op->getResults(), std::back_inserter(outTypes),
144 return DiscriminatingTypes{filterNoneTypes(inTypes),
145 filterNoneTypes(outTypes)};
146 });
147}
148
149/// Get type name. Currently we only support integer or index types.
150/// The emitted type aligns with the getFIRRTLType() method. Thus all integers
151/// other than signed integers will be emitted as unsigned.
152// NOLINTNEXTLINE(misc-no-recursion)
153static std::string getTypeName(Location loc, Type type) {
154 std::string typeName;
155 // Builtin types
156 if (type.isIntOrIndex()) {
157 if (auto indexType = dyn_cast<IndexType>(type))
158 typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159 else if (type.isSignedInteger())
160 typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
161 else
162 typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
163 } else if (auto tupleType = dyn_cast<TupleType>(type)) {
164 typeName += "_tuple";
165 for (auto elementType : tupleType.getTypes())
166 typeName += getTypeName(loc, elementType);
167 } else if (auto structType = dyn_cast<hw::StructType>(type)) {
168 typeName += "_struct";
169 for (auto element : structType.getElements())
170 typeName += "_" + element.name.str() + getTypeName(loc, element.type);
171 } else
172 emitError(loc) << "unsupported data type '" << type << "'";
173
174 return typeName;
175}
176
177/// Construct a name for creating HW sub-module.
178static std::string getSubModuleName(Operation *oldOp) {
179 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180 return instanceOp.getModule().str();
181
182 std::string subModuleName = getBareSubModuleName(oldOp);
183
184 // Add value of the constant operation.
185 if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187 auto intType = intAttr.getType();
188
189 if (intType.isSignedInteger())
190 subModuleName += "_c" + std::to_string(intAttr.getSInt());
191 else if (intType.isUnsignedInteger())
192 subModuleName += "_c" + std::to_string(intAttr.getUInt());
193 else
194 subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
195 } else
196 oldOp->emitError("unsupported constant type");
197 }
198
199 // Add discriminating in- and output types.
200 auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
201 if (!inTypes.empty())
202 subModuleName += "_in";
203 for (auto inType : inTypes)
204 subModuleName += getTypeName(oldOp->getLoc(), inType);
205
206 if (!outTypes.empty())
207 subModuleName += "_out";
208 for (auto outType : outTypes)
209 subModuleName += getTypeName(oldOp->getLoc(), outType);
210
211 // Add memory ID.
212 if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213 subModuleName += "_id" + std::to_string(memOp.getId());
214
215 // Add compare kind.
216 if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217 subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
218
219 // Add buffer information.
220 if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221 subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
222 if (bufferOp.isSequential())
223 subModuleName += "_seq";
224 else
225 subModuleName += "_fifo";
226
227 if (auto initValues = bufferOp.getInitValues()) {
228 subModuleName += "_init";
229 for (const Attribute e : *initValues) {
230 assert(isa<IntegerAttr>(e));
231 subModuleName +=
232 "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
233 }
234 }
235 }
236
237 // Add control information.
238 if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239 ctrlInterface && ctrlInterface.isControl()) {
240 // Add some additional discriminating info for non-typed operations.
241 subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
242 std::to_string(oldOp->getNumResults()) + "outs";
243 subModuleName += "_ctrl";
244 } else {
245 assert(
246 (!inTypes.empty() || !outTypes.empty()) &&
247 "Insufficient discriminating type info generated for the operation!");
248 }
249
250 return subModuleName;
251}
252
253//===----------------------------------------------------------------------===//
254// HW Sub-module Related Functions
255//===----------------------------------------------------------------------===//
256
257/// Check whether a submodule with the same name has been created elsewhere in
258/// the top level module. Return the matched module operation if true, otherwise
259/// return nullptr.
260static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
261 StringRef modName) {
262 if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
263 return mod;
264 if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
265 return mod;
266 return {};
267}
268
269static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
270 Operation *oldOp) {
271 HWModuleLike targetModule;
272 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
273 targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
274 else
275 targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
276
277 if (isa<handshake::InstanceOp>(oldOp))
278 assert(targetModule &&
279 "handshake.instance target modules should always have been lowered "
280 "before the modules that reference them!");
281 return targetModule;
282}
283
284/// Returns a vector of PortInfo's which defines the HW interface of the
285/// to-be-converted op.
286static ModulePortInfo getPortInfoForOp(Operation *op) {
287 return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
288}
289
290static llvm::SmallVector<hw::detail::FieldInfo>
291portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
292 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293 for (auto port : portInfo)
294 fieldInfo.push_back({port.name, port.type});
295
296 return fieldInfo;
297}
298
299// Convert any handshake.extmemory operations and the top-level I/O
300// associated with these.
301static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
302 auto *ctx = mod.getContext();
303
304 // Gather memref ports to be converted.
305 llvm::DenseMap<unsigned, Value> memrefPorts;
306 for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308 if (channel && isa<MemRefType>(channel.getInner()))
309 memrefPorts[i] = arg;
310 }
311
312 if (memrefPorts.empty())
313 return success(); // nothing to do.
314
315 OpBuilder b(mod);
316
317 auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
318 ArrayRef<hw::PortInfo> info,
319 hw::ModulePort::Direction direction) {
320 auto type = hw::StructType::get(ctx, portToFieldInfo(info));
321 auto portInfo =
322 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
323 return portInfo;
324 };
325
326 for (auto [i, arg] : memrefPorts) {
327 // Insert ports into the module
328 auto memName = mod.getArgName(i);
329
330 // Get the attached extmemory external module.
331 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
332 auto extmemMod =
333 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334 extmemInstance, extmemInstance.getModuleNameAttr()));
335
336 ModulePortInfo portInfo(extmemMod.getPortList());
337
338 // The extmemory external module's interface is a direct wrapping of the
339 // original handshake.extmemory operation in- and output types. Remove the
340 // first input argument (the !esi.channel<memref> op) since that is what
341 // we're replacing with a materialized interface.
342 portInfo.eraseInput(0);
343
344 // Add memory input - this is the output of the extmemory op.
345 SmallVector<PortInfo> outputs(portInfo.getOutputs());
346 auto inPortInfo =
347 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
348 hw::ModulePort::Direction::Input);
349 mod.insertPorts({{i, inPortInfo}}, {});
350 auto newInPort = mod.getArgumentForInput(i);
351 // Replace the extmemory submodule outputs with the newly created inputs.
352 b.setInsertionPointToStart(mod.getBodyBlock());
353 auto newInPortExploded = b.create<hw::StructExplodeOp>(
354 arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
356
357 // Add memory output - this is the inputs of the extmemory op (without the
358 // first argument);
359 unsigned outArgI = mod.getNumOutputPorts();
360 SmallVector<PortInfo> inputs(portInfo.getInputs());
361 auto outPortInfo =
362 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
363 inputs, hw::ModulePort::Direction::Output);
364
365 auto memOutputArgs = extmemInstance.getOperands().drop_front();
366 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
367 auto memOutputStruct = b.create<hw::StructCreateOp>(
368 arg.getLoc(), outPortInfo.type, memOutputArgs);
369 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
370
371 // Erase the extmemory submodule instace since the i/o has now been
372 // plumbed.
373 extmemMod.erase();
374 extmemInstance.erase();
375
376 // Erase the original memref argument of the top-level i/o now that it's use
377 // has been removed.
378 mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
379 /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
380 }
381
382 return success();
383}
384
385namespace {
386
// Input handshakes contain a resolved valid and (optional) data signal, and
// a to-be-assigned ready signal.
struct InputHandshake {
  // Resolved valid signal.
  Value valid;
  // Ready signal; a backedge whose value is still to be set.
  std::shared_ptr<Backedge> ready;
  // Resolved (optional) data signal.
  Value data;
};
394
// Output handshakes contain a resolved ready, and to-be-assigned valid and
// (optional) data signals.
struct OutputHandshake {
  // Valid signal; a backedge whose value is still to be set.
  std::shared_ptr<Backedge> valid;
  // Resolved ready signal.
  Value ready;
  // (Optional) data signal; a backedge whose value is still to be set.
  std::shared_ptr<Backedge> data;
};
402
403/// A helper struct that acts like a wire. Can be used to interact with the
404/// RTLBuilder when multiple built components should be connected.
405struct HandshakeWire {
406 HandshakeWire(BackedgeBuilder &bb, Type dataType) {
407 MLIRContext *ctx = dataType.getContext();
408 auto i1Type = IntegerType::get(ctx, 1);
409 valid = std::make_shared<Backedge>(bb.get(i1Type));
410 ready = std::make_shared<Backedge>(bb.get(i1Type));
411 data = std::make_shared<Backedge>(bb.get(dataType));
412 }
413
414 // Functions that allow to treat a wire like an input or output port.
415 // **Careful**: Such a port will not be updated when backedges are resolved.
416 InputHandshake getAsInput() { return {*valid, ready, *data}; }
417 OutputHandshake getAsOutput() { return {valid, *ready, data}; }
418
419 std::shared_ptr<Backedge> valid;
420 std::shared_ptr<Backedge> ready;
421 std::shared_ptr<Backedge> data;
422};
423
424template <typename T, typename TInner>
425llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426 llvm::function_ref<T(TInner &)> extractor) {
427 llvm::SmallVector<T> result;
428 llvm::transform(container, std::back_inserter(result), extractor);
429 return result;
430}
431struct UnwrappedIO {
432 llvm::SmallVector<InputHandshake> inputs;
433 llvm::SmallVector<OutputHandshake> outputs;
434
435 llvm::SmallVector<Value> getInputValids() {
436 return extractValues<Value, InputHandshake>(
437 inputs, [](auto &hs) { return hs.valid; });
438 }
439 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441 inputs, [](auto &hs) { return hs.ready; });
442 }
443 llvm::SmallVector<Value> getInputDatas() {
444 return extractValues<Value, InputHandshake>(
445 inputs, [](auto &hs) { return hs.data; });
446 }
447 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449 outputs, [](auto &hs) { return hs.valid; });
450 }
451 llvm::SmallVector<Value> getOutputReadys() {
452 return extractValues<Value, OutputHandshake>(
453 outputs, [](auto &hs) { return hs.ready; });
454 }
455 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457 outputs, [](auto &hs) { return hs.data; });
458 }
459};
460
461// A class containing a bunch of syntactic sugar to reduce builder function
462// verbosity.
463// @todo: should be moved to support.
464struct RTLBuilder {
465 RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
466 Value clk = Value(), Value rst = Value())
467 : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
468
469 Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
470 // Cannot use zero-width APInt's in DenseMap's, see
471 // https://github.com/llvm/llvm-project/issues/58013
472 bool isZeroWidth = apv.getBitWidth() == 0;
473 if (!isZeroWidth) {
474 auto it = constants.find(apv);
475 if (it != constants.end())
476 return it->second;
477 }
478
479 auto cval = b.create<hw::ConstantOp>(loc, apv);
480 if (!isZeroWidth)
481 constants[apv] = cval;
482 return cval;
483 }
484
485 Value constant(unsigned width, int64_t value,
486 std::optional<StringRef> name = {}) {
487 return constant(
488 APInt(width, value, /*isSigned=*/false, /*implicitTrunc=*/true));
489 }
490 std::pair<Value, Value> wrap(Value data, Value valid,
491 std::optional<StringRef> name = {}) {
492 auto wrapOp = b.create<esi::WrapValidReadyOp>(loc, data, valid);
493 return {wrapOp.getResult(0), wrapOp.getResult(1)};
494 }
495 std::pair<Value, Value> unwrap(Value channel, Value ready,
496 std::optional<StringRef> name = {}) {
497 auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
498 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
499 }
500
501 // Various syntactic sugar functions.
502 Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503 Value rst = Value()) {
504 Value resolvedClk = clk ? clk : this->clk;
505 Value resolvedRst = rst ? rst : this->rst;
506 assert(resolvedClk &&
507 "No global clock provided to this RTLBuilder - a clock "
508 "signal must be provided to the reg(...) function.");
509 assert(resolvedRst &&
510 "No global reset provided to this RTLBuilder - a reset "
511 "signal must be provided to the reg(...) function.");
512
513 return b.create<seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
514 name);
515 }
516
517 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518 std::optional<StringRef> name = {}) {
519 return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
520 }
521
522 Value buildNamedOp(llvm::function_ref<Value()> f,
523 std::optional<StringRef> name) {
524 Value v = f();
525 StringAttr nameAttr;
526 Operation *op = v.getDefiningOp();
527 if (name.has_value()) {
528 op->setAttr("sv.namehint", b.getStringAttr(*name));
529 nameAttr = b.getStringAttr(*name);
530 }
531 return v;
532 }
533
534 // Bitwise 'and'.
535 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
536 return buildNamedOp(
537 [&]() { return b.create<comb::AndOp>(loc, values, false); }, name);
538 }
539
540 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
541 return buildNamedOp(
542 [&]() { return b.create<comb::OrOp>(loc, values, false); }, name);
543 }
544
545 // Bitwise 'not'.
546 Value bNot(Value value, std::optional<StringRef> name = {}) {
547 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548 std::string inferedName;
549 if (!name) {
550 // Try to create a name from the input value.
551 if (auto valueName =
552 value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
553 inferedName = ("not_" + valueName.getValue()).str();
554 name = inferedName;
555 }
556 }
557
558 return buildNamedOp(
559 [&]() { return b.create<comb::XorOp>(loc, value, allOnes); }, name);
560
561 return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
562 }
563
564 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
565 return buildNamedOp(
566 [&]() { return b.create<comb::ShlOp>(loc, value, shift); }, name);
567 }
568
569 Value concat(ValueRange values, std::optional<StringRef> name = {}) {
570 return buildNamedOp([&]() { return b.create<comb::ConcatOp>(loc, values); },
571 name);
572 }
573
574 // Packs a list of values into a hw.struct.
575 Value pack(ValueRange values, Type structType = Type(),
576 std::optional<StringRef> name = {}) {
577 if (!structType)
578 structType = tupleToStruct(values.getTypes());
579 return buildNamedOp(
580 [&]() { return b.create<hw::StructCreateOp>(loc, structType, values); },
581 name);
582 }
583
584 // Unpacks a hw.struct into a list of values.
585 ValueRange unpack(Value value) {
586 auto structType = cast<hw::StructType>(value.getType());
587 llvm::SmallVector<Type> innerTypes;
588 structType.getInnerTypes(innerTypes);
589 return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
590 }
591
592 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
593 llvm::SmallVector<Value> bits;
594 for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
595 bits.push_back(b.create<comb::ExtractOp>(loc, v, i, /*bitWidth=*/1));
596 return bits;
597 }
598
599 // OR-reduction of the bits in 'v'.
600 Value rOr(Value v, std::optional<StringRef> name = {}) {
601 return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
602 }
603
604 // Extract bits v[hi:lo] (inclusive).
605 Value extract(Value v, unsigned lo, unsigned hi,
606 std::optional<StringRef> name = {}) {
607 unsigned width = hi - lo + 1;
608 return buildNamedOp(
609 [&]() { return b.create<comb::ExtractOp>(loc, v, lo, width); }, name);
610 }
611
612 // Truncates 'value' to its lower 'width' bits.
613 Value truncate(Value value, unsigned width,
614 std::optional<StringRef> name = {}) {
615 return extract(value, 0, width - 1, name);
616 }
617
618 Value zext(Value value, unsigned outWidth,
619 std::optional<StringRef> name = {}) {
620 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
621 assert(inWidth <= outWidth && "zext: input width must be <- output width.");
622 if (inWidth == outWidth)
623 return value;
624 auto c0 = constant(outWidth - inWidth, 0);
625 return concat({c0, value}, name);
626 }
627
628 Value sext(Value value, unsigned outWidth,
629 std::optional<StringRef> name = {}) {
630 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
631 }
632
633 // Extracts a single bit v[bit].
634 Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
635 return extract(v, index, index, name);
636 }
637
638 // Creates a hw.array of the given values.
639 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
640 return buildNamedOp(
641 [&]() { return b.create<hw::ArrayCreateOp>(loc, values); }, name);
642 }
643
644 // Extract the 'index'th value from the input array.
645 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
646 return buildNamedOp(
647 [&]() { return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
648 }
649
650 // Muxes a range of values.
651 // The select signal is expected to be a decimal value which selects starting
652 // from the lowest index of value.
653 Value mux(Value index, ValueRange values,
654 std::optional<StringRef> name = {}) {
655 if (values.size() == 2)
656 return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
657
658 return arrayGet(arrayCreate(values), index, name);
659 }
660
661 // Muxes a range of values. The select signal is expected to be a 1-hot
662 // encoded value.
663 Value ohMux(Value index, ValueRange inputs) {
664 // Confirm the select input can be a one-hot encoding for the inputs.
665 unsigned numInputs = inputs.size();
666 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
667 "one-hot select can't mux inputs");
668
669 // Start the mux tree with zero value.
670 // Todo: clean up when handshake supports i0.
671 auto dataType = inputs[0].getType();
672 unsigned width =
673 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
674 Value muxValue = constant(width, 0);
675
676 // Iteratively chain together muxes from the high bit to the low bit.
677 for (size_t i = numInputs - 1; i != 0; --i) {
678 Value input = inputs[i];
679 Value selectBit = bit(index, i);
680 muxValue = mux(selectBit, {muxValue, input});
681 }
682
683 return muxValue;
684 }
685
687 OpBuilder &b;
688 Location loc;
689 Value clk, rst;
690 DenseMap<APInt, Value> constants;
691};
692
693/// Creates a Value that has an assigned zero value. For structs, this
694/// corresponds to assigning zero to each element recursively.
695static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
696 return TypeSwitch<Type, Value>(type)
697 .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
698 .Case<IntType, IntegerType>([&](auto type) {
699 return s.constant(type.getIntOrFloatBitWidth(), 0);
700 })
701 .Case<hw::StructType>([&](auto structType) {
702 SmallVector<Value> zeroValues;
703 for (auto field : structType.getElements())
704 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
705 return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
706 })
707 .Default([&](Type) -> Value {
708 emitError(loc) << "unsupported type for zero value: " << type;
709 assert(false);
710 return {};
711 });
712}
713
714static void
715addSequentialIOOperandsIfNeeded(Operation *op,
716 llvm::SmallVectorImpl<Value> &operands) {
717 if (op->hasTrait<mlir::OpTrait::HasClock>()) {
718 // Parent should at this point be a hw.module and have clock and reset
719 // ports.
720 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
721 operands.push_back(
722 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
723 operands.push_back(
724 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
725 }
726}
727
728template <typename T>
729class HandshakeConversionPattern : public OpConversionPattern<T> {
730public:
// Creates the pattern. 'typeConverter' and 'context' are forwarded to the
// OpConversionPattern base. 'submoduleBuilder' is the builder used to create
// the HW modules implementing converted operations, and 'ls' is the shared
// lowering state; both are stored by reference.
HandshakeConversionPattern(ESITypeConverter &typeConverter,
                           MLIRContext *context, OpBuilder &submoduleBuilder,
                           HandshakeLoweringState &ls)
    : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
      submoduleBuilder(submoduleBuilder), ls(ls) {}
736
737 using OpAdaptor = typename T::Adaptor;
738
739 LogicalResult
740 matchAndRewrite(T op, OpAdaptor adaptor,
741 ConversionPatternRewriter &rewriter) const override {
742
743 // Check if a submodule has already been created for the op. If so,
744 // instantiate the submodule. Else, run the pattern-defined module
745 // builder.
746 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
747 if (!implModule) {
748 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
749
750 submoduleBuilder.setInsertionPoint(op->getParentOp());
751 implModule = submoduleBuilder.create<hw::HWModuleOp>(
752 op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
753 portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
754 // if 'op' has clock trait, extract these and provide them to the
755 // RTL builder.
756 Value clk, rst;
757 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
758 clk = ports.getInput("clock");
759 rst = ports.getInput("reset");
760 }
761
762 BackedgeBuilder bb(b, op.getLoc());
763 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
764 this->buildModule(op, bb, s, ports);
765 });
766 }
767
768 // Instantiate the submodule.
769 llvm::SmallVector<Value> operands = adaptor.getOperands();
770 addSequentialIOOperandsIfNeeded(op, operands);
771 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
772 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
773 return success();
774 }
775
776 virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
777 hw::HWModulePortAccessor &ports) const = 0;
778
779 // Syntactic sugar functions.
780 // Unwraps an ESI-interfaced module into its constituent handshake signals.
781 // Backedges are created for the to-be-resolved signals, and output ports
782 // are assigned to their wrapped counterparts.
783 UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
784 hw::HWModulePortAccessor &ports) const {
785 UnwrappedIO unwrapped;
786 for (auto port : ports.getInputs()) {
787 if (!isa<esi::ChannelType>(port.getType()))
788 continue;
789 InputHandshake hs;
790 auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
791 auto [data, valid] = s.unwrap(port, *ready);
792 hs.data = data;
793 hs.valid = valid;
794 hs.ready = ready;
795 unwrapped.inputs.push_back(hs);
796 }
797 for (auto &outputInfo : ports.getPortList().getOutputs()) {
798 esi::ChannelType channelType =
799 dyn_cast<esi::ChannelType>(outputInfo.type);
800 if (!channelType)
801 continue;
802 OutputHandshake hs;
803 Type innerType = channelType.getInner();
804 auto data = std::make_shared<Backedge>(bb.get(innerType));
805 auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
806 auto [dataCh, ready] = s.wrap(*data, *valid);
807 hs.data = data;
808 hs.valid = valid;
809 hs.ready = ready;
810 ports.setOutput(outputInfo.name, dataCh);
811 unwrapped.outputs.push_back(hs);
812 }
813 return unwrapped;
814 }
815
816 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
817 OutputHandshake &output, Value cond) const {
818 auto validAndReady = s.bAnd({output.ready, cond});
819 for (auto &input : inputs)
820 input.ready->setValue(validAndReady);
821 }
822
823 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
824 OutputHandshake &output) const {
825 llvm::SmallVector<Value> valids;
826 for (auto &input : inputs)
827 valids.push_back(input.valid);
828 Value allValid = s.bAnd(valids);
829 output.valid->setValue(allValid);
830 setAllReadyWithCond(s, inputs, output, allValid);
831 }
832
833 // Builds mux logic for the given inputs and outputs.
834 // Note: it is assumed that the caller has removed the 'select' signal from
835 // the 'unwrapped' inputs and provide it as a separate argument.
836 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
837 InputHandshake &select) const {
838 // ============================= Control logic =============================
839 size_t numInputs = unwrapped.inputs.size();
840 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
841 Value truncatedSelect =
842 select.data.getType().getIntOrFloatBitWidth() > selectWidth
843 ? s.truncate(select.data, selectWidth)
844 : select.data;
845
846 // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
847 auto selectZext = s.zext(truncatedSelect, numInputs);
848 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
849 auto &res = unwrapped.outputs[0];
850
851 // Mux input valid signals.
852 auto selectedInputValid =
853 s.mux(truncatedSelect, unwrapped.getInputValids());
854 // Result is valid when the selected input and the select input is valid.
855 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
856 res.valid->setValue(selAndInputValid);
857 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
858
859 // Select is ready when result is valid and ready (result transacting).
860 select.ready->setValue(resValidAndReady);
861
862 // Assign each input ready signal if it is currently selected.
863 for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
864 // Extract the selection bit for this input.
865 auto isSelected = s.bit(select1h, inIdx);
866
867 // '&' that with the result valid and ready, and assign to the input
868 // ready signal.
869 auto activeAndResultValidAndReady =
870 s.bAnd({isSelected, resValidAndReady});
871 in.ready->setValue(activeAndResultValidAndReady);
872 }
873
874 // ============================== Data logic ===============================
875 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
876 }
877
878 // Builds fork logic between the single input and multiple outputs' control
879 // networks. Caller is expected to handle data separately.
880 void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
881 ArrayRef<OutputHandshake> outputs) const {
882 auto c0I1 = s.constant(1, 0);
883 llvm::SmallVector<Value> doneWires;
884 for (auto [i, output] : llvm::enumerate(outputs)) {
885 auto doneBE = bb.get(s.b.getI1Type());
886 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
887 auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
888 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
889 output.valid->setValue(outValid);
890 auto validReady = s.bAnd({output.ready, outValid});
891 auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
892 doneBE.setValue(done);
893 doneWires.push_back(done);
894 }
895 input.ready->setValue(s.bAnd(doneWires, "allDone"));
896 }
897
898 // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
899 // function which takes the set of unwrapped data inputs, and returns a
900 // value which should be assigned to the output data value.
901 void buildUnitRateJoinLogic(
902 RTLBuilder &s, UnwrappedIO &unwrappedIO,
903 llvm::function_ref<Value(ValueRange)> unitBuilder) const {
904 assert(unwrappedIO.outputs.size() == 1 &&
905 "Expected exactly one output for unit-rate join actor");
906 // Control logic.
907 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
908
909 // Data logic.
910 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
911 unwrappedIO.outputs[0].data->setValue(unitRes);
912 }
913
914 void buildUnitRateForkLogic(
915 RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
916 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
917 assert(unwrappedIO.inputs.size() == 1 &&
918 "Expected exactly one input for unit-rate fork actor");
919 // Control logic.
920 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
921
922 // Data logic.
923 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
924 assert(unitResults.size() == unwrappedIO.outputs.size() &&
925 "Expected unit builder to return one result per output");
926 for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
927 outport.data->setValue(res);
928 }
929
930 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
931 bool signExtend) const {
932 size_t outWidth =
933 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
934 .getIntOrFloatBitWidth();
935 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
936 if (signExtend)
937 return s.sext(inputs[0], outWidth);
938 return s.zext(inputs[0], outWidth);
939 });
940 }
941
942 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
943 unsigned targetWidth) const {
944 size_t outWidth =
945 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
946 .getIntOrFloatBitWidth();
947 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
948 return s.truncate(inputs[0], outWidth);
949 });
950 }
951
952 /// Return the number of bits needed to index the given number of values.
953 static size_t getNumIndexBits(uint64_t numValues) {
954 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
955 }
956
957 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
958 Value defaultValue,
959 DenseMap<size_t, Value> &indexMapping) const {
960 auto numInputs = inputs.size();
961 auto priorityArb = defaultValue;
962
963 for (size_t i = numInputs; i > 0; --i) {
964 size_t inputIndex = i - 1;
965 size_t oneHotIndex = size_t{1} << inputIndex;
966 auto constIndex = s.constant(numInputs, oneHotIndex);
967 indexMapping[inputIndex] = constIndex;
968 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
969 }
970 return priorityArb;
971 }
972
973private:
974 OpBuilder &submoduleBuilder;
975 HandshakeLoweringState &ls;
976};
977
978class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
979public:
980 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
981 void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
982 hw::HWModulePortAccessor &ports) const override {
983 auto unwrapped = unwrapIO(s, bb, ports);
984 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
985 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
986 });
987 }
988};
989
990class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
991public:
992 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
993 void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
994 hw::HWModulePortAccessor &ports) const override {
995 auto unwrappedIO = unwrapIO(s, bb, ports);
996 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
997 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
998 };
999};
1000
1001class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
1002public:
1003 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1004 void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1005 hw::HWModulePortAccessor &ports) const override {
1006 auto unwrappedIO = unwrapIO(s, bb, ports);
1007
1008 // A helper wire that will be used to connect the two built logics
1009 HandshakeWire wire(bb, s.b.getNoneType());
1010
1011 OutputHandshake output = wire.getAsOutput();
1012 buildJoinLogic(s, unwrappedIO.inputs, output);
1013
1014 InputHandshake input = wire.getAsInput();
1015
1016 // The state-keeping fork logic is required here, as the circuit isn't
1017 // allowed to wait for all the consumers to be ready. Connecting the ready
1018 // signals of the outputs to their corresponding valid signals leads to
1019 // combinatorial cycles. The paper which introduced compositional dataflow
1020 // circuits explicitly mentions this limitation:
1021 // http://arcade.cs.columbia.edu/df-memocode17.pdf
1022 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1023
1024 // Directly connect the data wires, only the control signals need to be
1025 // combined.
1026 for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1027 out.data->setValue(in.data);
1028 };
1029};
1030
1031class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1032public:
1033 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1034 void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1035 hw::HWModulePortAccessor &ports) const override {
1036 auto unwrappedIO = unwrapIO(s, bb, ports);
1037
1038 // Extract select signal from the unwrapped IO.
1039 auto select = unwrappedIO.inputs[0];
1040 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1041 buildMuxLogic(s, unwrappedIO, select);
1042 };
1043};
1044
1045class InstanceConversionPattern
1046 : public HandshakeConversionPattern<handshake::InstanceOp> {
1047public:
1048 using HandshakeConversionPattern<
1049 handshake::InstanceOp>::HandshakeConversionPattern;
1050 void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1051 hw::HWModulePortAccessor &ports) const override {
1052 assert(false &&
1053 "If we indeed perform conversion in post-order, this "
1054 "should never be called. The base HandshakeConversionPattern logic "
1055 "will instantiate the external module.");
1056 }
1057};
1058
1059class ESIInstanceConversionPattern
1060 : public OpConversionPattern<handshake::ESIInstanceOp> {
1061public:
1062 ESIInstanceConversionPattern(MLIRContext *context,
1063 const HWSymbolCache &symCache)
1064 : OpConversionPattern(context), symCache(symCache) {}
1065
1066 LogicalResult
1067 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1068 ConversionPatternRewriter &rewriter) const override {
1069 // The operand signature of this op is very similar to the lowered
1070 // `handshake.func`s (especially since handshake uses ESI channels
1071 // internally). Whereas ESIInstance ops have 'clk' and 'rst' at the
1072 // beginning, lowered `handshake.func`s have them at the end. So we've just
1073 // got to re-arrange them.
1074 SmallVector<Value> operands;
1075 for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1076 i < e; ++i)
1077 operands.push_back(adaptor.getOperands()[i]);
1078 operands.push_back(adaptor.getClk());
1079 operands.push_back(adaptor.getRst());
1080 // Locate the lowered module so the instance builder can get all the
1081 // metadata.
1082 Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
1083 // And replace the op with an instance of the target module.
1084 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1085 op.getInstNameAttr(), operands);
1086 return success();
1087 }
1088
1089private:
1090 const HWSymbolCache &symCache;
1091};
1092
1093class ReturnConversionPattern
1094 : public OpConversionPattern<handshake::ReturnOp> {
1095public:
1096 using OpConversionPattern::OpConversionPattern;
1097 LogicalResult
1098 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter) const override {
1100 // Locate existing output op, Append operands to output op, and move to
1101 // the end of the block.
1102 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1103 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1104 outputOp->setOperands(adaptor.getOperands());
1105 outputOp->moveAfter(&parent.getBodyBlock()->back());
1106 rewriter.eraseOp(op);
1107 return success();
1108 }
1109};
1110
1111// Converts an arbitrary operation into a unit rate actor. A unit rate actor
1112// will transact once all inputs are valid and its output is ready.
1113template <typename TIn, typename TOut = TIn>
1114class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1115public:
1116 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1117 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1118 hw::HWModulePortAccessor &ports) const override {
1119 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1120 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1121 // Create TOut - it is assumed that TOut trivially
1122 // constructs from the input data signals of TIn.
1123 // To disambiguate ambiguous builders with default arguments (e.g.,
1124 // twoState UnitAttr), specify attribute array explicitly.
1125 return s.b.create<TOut>(op.getLoc(), inputs,
1126 /* attributes */ ArrayRef<NamedAttribute>{});
1127 });
1128 };
1129};
1130
1131class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1132public:
1133 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1134 void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1135 hw::HWModulePortAccessor &ports) const override {
1136 auto unwrappedIO = unwrapIO(s, bb, ports);
1137 buildUnitRateJoinLogic(s, unwrappedIO,
1138 [&](ValueRange inputs) { return s.pack(inputs); });
1139 };
1140};
1141
1142class StructCreateConversionPattern
1143 : public HandshakeConversionPattern<hw::StructCreateOp> {
1144public:
1145 using HandshakeConversionPattern<
1146 hw::StructCreateOp>::HandshakeConversionPattern;
1147 void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1148 hw::HWModulePortAccessor &ports) const override {
1149 auto unwrappedIO = unwrapIO(s, bb, ports);
1150 auto structType = op.getResult().getType();
1151 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1152 return s.pack(inputs, structType);
1153 });
1154 };
1155};
1156
1157class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1158public:
1159 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1160 void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1161 hw::HWModulePortAccessor &ports) const override {
1162 auto unwrappedIO = unwrapIO(s, bb, ports);
1163 buildUnitRateForkLogic(s, bb, unwrappedIO,
1164 [&](Value input) { return s.unpack(input); });
1165 };
1166};
1167
1168class ConditionalBranchConversionPattern
1169 : public HandshakeConversionPattern<ConditionalBranchOp> {
1170public:
1171 using HandshakeConversionPattern<
1172 ConditionalBranchOp>::HandshakeConversionPattern;
1173 void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1174 hw::HWModulePortAccessor &ports) const override {
1175 auto unwrappedIO = unwrapIO(s, bb, ports);
1176 auto cond = unwrappedIO.inputs[0];
1177 auto arg = unwrappedIO.inputs[1];
1178 auto trueRes = unwrappedIO.outputs[0];
1179 auto falseRes = unwrappedIO.outputs[1];
1180
1181 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1182
1183 // Connect valid signal of both results.
1184 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1185 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1186
1187 // Connecte data signals of both results.
1188 trueRes.data->setValue(arg.data);
1189 falseRes.data->setValue(arg.data);
1190
1191 // Connect ready signal of input and condition.
1192 auto selectedResultReady =
1193 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1194 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1195 arg.ready->setValue(condArgReady);
1196 cond.ready->setValue(condArgReady);
1197 };
1198};
1199
1200template <typename TIn, bool signExtend>
1201class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1202public:
1203 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1204 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1205 hw::HWModulePortAccessor &ports) const override {
1206 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1207 this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1208 };
1209};
1210
1211class ComparisonConversionPattern
1212 : public HandshakeConversionPattern<arith::CmpIOp> {
1213public:
1214 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1215 void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1216 hw::HWModulePortAccessor &ports) const override {
1217 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1218 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1219 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1220 return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1221 inputs[1]);
1222 });
1223 };
1224
1225 switch (op.getPredicate()) {
1226 case arith::CmpIPredicate::eq:
1227 return buildCompareLogic(comb::ICmpPredicate::eq);
1228 case arith::CmpIPredicate::ne:
1229 return buildCompareLogic(comb::ICmpPredicate::ne);
1230 case arith::CmpIPredicate::slt:
1231 return buildCompareLogic(comb::ICmpPredicate::slt);
1232 case arith::CmpIPredicate::ult:
1233 return buildCompareLogic(comb::ICmpPredicate::ult);
1234 case arith::CmpIPredicate::sle:
1235 return buildCompareLogic(comb::ICmpPredicate::sle);
1236 case arith::CmpIPredicate::ule:
1237 return buildCompareLogic(comb::ICmpPredicate::ule);
1238 case arith::CmpIPredicate::sgt:
1239 return buildCompareLogic(comb::ICmpPredicate::sgt);
1240 case arith::CmpIPredicate::ugt:
1241 return buildCompareLogic(comb::ICmpPredicate::ugt);
1242 case arith::CmpIPredicate::sge:
1243 return buildCompareLogic(comb::ICmpPredicate::sge);
1244 case arith::CmpIPredicate::uge:
1245 return buildCompareLogic(comb::ICmpPredicate::uge);
1246 }
1247 assert(false && "invalid CmpIOp");
1248 };
1249};
1250
1251class TruncateConversionPattern
1252 : public HandshakeConversionPattern<arith::TruncIOp> {
1253public:
1254 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1255 void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1256 hw::HWModulePortAccessor &ports) const override {
1257 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1258 unsigned targetBits =
1259 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1260 buildTruncateLogic(s, unwrappedIO, targetBits);
1261 };
1262};
1263
1264class ControlMergeConversionPattern
1265 : public HandshakeConversionPattern<ControlMergeOp> {
1266public:
1267 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1268 void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1269 hw::HWModulePortAccessor &ports) const override {
1270 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1271 auto resData = unwrappedIO.outputs[0];
1272 auto resIndex = unwrappedIO.outputs[1];
1273
1274 // Define some common types and values that will be used.
1275 unsigned numInputs = unwrappedIO.inputs.size();
1276 auto indexType = s.b.getIntegerType(numInputs);
1277 Value noWinner = s.constant(numInputs, 0);
1278 Value c0I1 = s.constant(1, 0);
1279
1280 // Declare register for storing arbitration winner.
1281 auto won = bb.get(indexType);
1282 Value wonReg = s.reg("won_reg", won, noWinner);
1283
1284 // Declare wire for arbitration winner.
1285 auto win = bb.get(indexType);
1286
1287 // Declare wire for whether the circuit just fired and emitted both
1288 // outputs.
1289 auto fired = bb.get(s.b.getI1Type());
1290
1291 // Declare registers for storing if each output has been emitted.
1292 auto resultEmitted = bb.get(s.b.getI1Type());
1293 Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1294 auto indexEmitted = bb.get(s.b.getI1Type());
1295 Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1296
1297 // Declare wires for if each output is done.
1298 auto resultDone = bb.get(s.b.getI1Type());
1299 auto indexDone = bb.get(s.b.getI1Type());
1300
1301 // Create predicates to assert if the win wire or won register hold a
1302 // valid index.
1303 auto hasWinnerCondition = s.rOr({win});
1304 auto hadWinnerCondition = s.rOr({wonReg});
1305
1306 // Create an arbiter based on a simple priority-encoding scheme to assign
1307 // an index to the win wire. If the won register is set, just use that. In
1308 // the case that won is not set and no input is valid, set a sentinel
1309 // value to indicate no winner was chosen. The constant values are
1310 // remembered in a map so they can be re-used later to assign the arg
1311 // ready outputs.
1312 DenseMap<size_t, Value> argIndexValues;
1313 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1314 noWinner, argIndexValues);
1315 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1316 win.setValue(priorityArb);
1317
1318 // Create the logic to assign the result and index outputs. The result
1319 // valid output will always be assigned, and if isControl is not set, the
1320 // result data output will also be assigned. The index valid and data
1321 // outputs will always be assigned. The win wire from the arbiter is used
1322 // to index into a tree of muxes to select the chosen input's signal(s),
1323 // and is fed directly to the index output. Both the result and index
1324 // valid outputs are gated on the win wire being set to something other
1325 // than the sentinel value.
1326 auto resultNotEmitted = s.bNot(resultEmittedReg);
1327 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1328 resData.valid->setValue(resultValid);
1329 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1330
1331 auto indexNotEmitted = s.bNot(indexEmittedReg);
1332 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1333 resIndex.valid->setValue(indexValid);
1334
1335 // Use the one-hot win wire to select the index to output in the index
1336 // data.
1337 SmallVector<Value, 8> indexOutputs;
1338 for (size_t i = 0; i < numInputs; ++i)
1339 indexOutputs.push_back(s.constant(64, i));
1340
1341 auto indexOutput = s.ohMux(win, indexOutputs);
1342 resIndex.data->setValue(indexOutput);
1343
1344 // Create the logic to set the won register. If the fired wire is
1345 // asserted, we have finished this round and can and reset the register to
1346 // the sentinel value that indicates there is no winner. Otherwise, we
1347 // need to hold the value of the win register until we can fire.
1348 won.setValue(s.mux(fired, {win, noWinner}));
1349
1350 // Create the logic to set the done wires for the result and index. For
1351 // both outputs, the done wire is asserted when the output is valid and
1352 // ready, or the emitted register for that output is set.
1353 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1354 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1355
1356 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1357 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1358
1359 // Create the logic to set the fired wire. It is asserted when both result
1360 // and index are done.
1361 fired.setValue(s.bAnd({resultDone, indexDone}));
1362
1363 // Create the logic to assign the emitted registers. If the fired wire is
1364 // asserted, we have finished this round and can reset the registers to 0.
1365 // Otherwise, we need to hold the values of the done registers until we
1366 // can fire.
1367 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1368 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1369
1370 // Create the logic to assign the arg ready outputs. The logic is
1371 // identical for each arg. If the fired wire is asserted, and the win wire
1372 // holds an arg's index, that arg is ready.
1373 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1374 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1375 auto &indexValue = argIndexValues[i];
1376 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1377 }
1378 };
1379};
1380
1381class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1382public:
1383 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1384 void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1385 hw::HWModulePortAccessor &ports) const override {
1386 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1387 auto resData = unwrappedIO.outputs[0];
1388
1389 // Define some common types and values that will be used.
1390 unsigned numInputs = unwrappedIO.inputs.size();
1391 auto indexType = s.b.getIntegerType(numInputs);
1392 Value noWinner = s.constant(numInputs, 0);
1393
1394 // Declare wire for arbitration winner.
1395 auto win = bb.get(indexType);
1396
1397 // Create predicates to assert if the win wire holds a valid index.
1398 auto hasWinnerCondition = s.rOr(win);
1399
1400 // Create an arbiter based on a simple priority-encoding scheme to assign an
1401 // index to the win wire. In the case that no input is valid, set a sentinel
1402 // value to indicate no winner was chosen. The constant values are
1403 // remembered in a map so they can be re-used later to assign the arg ready
1404 // outputs.
1405 DenseMap<size_t, Value> argIndexValues;
1406 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1407 noWinner, argIndexValues);
1408 win.setValue(priorityArb);
1409
1410 // Create the logic to assign the result outputs. The result valid and data
1411 // outputs will always be assigned. The win wire from the arbiter is used to
1412 // index into a tree of muxes to select the chosen input's signal(s). The
1413 // result outputs are gated on the win wire being non-zero.
1414
1415 resData.valid->setValue(hasWinnerCondition);
1416 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1417
1418 // Create the logic to set the done wires for the result. The done wire is
1419 // asserted when the output is valid and ready, or the emitted register is
1420 // set.
1421 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1422
1423 // Create the logic to assign the arg ready outputs. The logic is
1424 // identical for each arg. If the fired wire is asserted, and the win wire
1425 // holds an arg's index, that arg is ready.
1426 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1427 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1428 auto &indexValue = argIndexValues[i];
1429 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1430 }
1431 };
1432};
1433
1434class LoadConversionPattern
1435 : public HandshakeConversionPattern<handshake::LoadOp> {
1436public:
1437 using HandshakeConversionPattern<
1438 handshake::LoadOp>::HandshakeConversionPattern;
1439 void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1440 hw::HWModulePortAccessor &ports) const override {
1441 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1442 auto addrFromUser = unwrappedIO.inputs[0];
1443 auto dataFromMem = unwrappedIO.inputs[1];
1444 auto controlIn = unwrappedIO.inputs[2];
1445 auto dataToUser = unwrappedIO.outputs[0];
1446 auto addrToMem = unwrappedIO.outputs[1];
1447
1448 addrToMem.data->setValue(addrFromUser.data);
1449 dataToUser.data->setValue(dataFromMem.data);
1450
1451 // The valid/ready logic between user address/control to memoryAddr is
1452 // join logic.
1453 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1454
1455 // The valid/ready logic between memoryData and outputData is a direct
1456 // connection.
1457 dataToUser.valid->setValue(dataFromMem.valid);
1458 dataFromMem.ready->setValue(dataToUser.ready);
1459 };
1460};
1461
1462class StoreConversionPattern
1463 : public HandshakeConversionPattern<handshake::StoreOp> {
1464public:
1465 using HandshakeConversionPattern<
1466 handshake::StoreOp>::HandshakeConversionPattern;
1467 void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1468 hw::HWModulePortAccessor &ports) const override {
1469 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1470 auto addrFromUser = unwrappedIO.inputs[0];
1471 auto dataFromUser = unwrappedIO.inputs[1];
1472 auto controlIn = unwrappedIO.inputs[2];
1473 auto dataToMem = unwrappedIO.outputs[0];
1474 auto addrToMem = unwrappedIO.outputs[1];
1475
1476 // Create a gate that will be asserted when all outputs are ready.
1477 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1478
1479 // Build the standard join logic from the inputs to the inputsValid and
1480 // outputsReady signals.
1481 HandshakeWire joinWire(bb, s.b.getNoneType());
1482 joinWire.ready->setValue(outputsReady);
1483 OutputHandshake joinOutput = joinWire.getAsOutput();
1484 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1485
1486 // Output address and data signals are connected directly.
1487 addrToMem.data->setValue(addrFromUser.data);
1488 dataToMem.data->setValue(dataFromUser.data);
1489
1490 // Output valid signals are connected from the inputsValid wire.
1491 addrToMem.valid->setValue(*joinWire.valid);
1492 dataToMem.valid->setValue(*joinWire.valid);
1493 };
1494};
1495
1496class MemoryConversionPattern
1497 : public HandshakeConversionPattern<handshake::MemoryOp> {
1498public:
1499 using HandshakeConversionPattern<
1500 handshake::MemoryOp>::HandshakeConversionPattern;
1501 void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
1502 hw::HWModulePortAccessor &ports) const override {
1503 auto loc = op.getLoc();
1504
1505 // Gather up the load and store ports.
1506 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1507 struct LoadPort {
1508 InputHandshake &addr;
1509 OutputHandshake &data;
1510 OutputHandshake &done;
1511 };
1512 struct StorePort {
1513 InputHandshake &addr;
1514 InputHandshake &data;
1515 OutputHandshake &done;
1516 };
1517 SmallVector<LoadPort, 4> loadPorts;
1518 SmallVector<StorePort, 4> storePorts;
1519
1520 unsigned stCount = op.getStCount();
1521 unsigned ldCount = op.getLdCount();
1522 for (unsigned i = 0, e = ldCount; i != e; ++i) {
1523 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1524 unwrappedIO.outputs[i],
1525 unwrappedIO.outputs[ldCount + stCount + i]};
1526 loadPorts.push_back(port);
1527 }
1528
1529 for (unsigned i = 0, e = stCount; i != e; ++i) {
1530 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1531 unwrappedIO.inputs[i * 2],
1532 unwrappedIO.outputs[ldCount + i]};
1533 storePorts.push_back(port);
1534 }
1535
1536 // used to drive the data wire of the control-only channels.
1537 auto c0I0 = s.constant(0, 0);
1538
1539 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1540 auto hlmem = s.b.create<seq::HLMemOp>(
1541 loc, s.clk, s.rst, "_handshake_memory_" + std::to_string(op.getId()),
1542 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1543
1544 // Create load ports...
1545 for (auto &ld : loadPorts) {
1546 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1547 auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1548 addresses, ld.addr.valid,
1549 /*latency=*/0);
1550 ld.data.data->setValue(readData);
1551 ld.done.data->setValue(c0I0);
1552 // Create control fork for the load address valid and ready signals.
1553 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1554 }
1555
1556 // Create store ports...
1557 for (auto &st : storePorts) {
1558 // Create a register to buffer the valid path by 1 cycle, to match the
1559 // write latency of 1.
1560 auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
1561 auto writeValidBuffer =
1562 s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1563 st.done.valid->setValue(writeValidBuffer);
1564 st.done.data->setValue(c0I0);
1565
1566 // Create the logic for when both the buffered write valid signal and the
1567 // store complete ready signal are asserted.
1568 auto storeCompleted =
1569 s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");
1570
1571 // Create a signal for when the write valid buffer is empty or the output
1572 // is ready.
1573 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1574 auto emptyOrComplete =
1575 s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");
1576
1577 // Connect the gate to both the store address ready and store data ready
1578 st.addr.ready->setValue(emptyOrComplete);
1579 st.data.ready->setValue(emptyOrComplete);
1580
1581 // Create a wire for when both the store address and data are valid.
1582 auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");
1583
1584 // Create a mux that drives the buffer input. If the emptyOrComplete
1585 // signal is asserted, the mux selects the writeValid signal. Otherwise,
1586 // it selects the buffer output, keeping the output registered until the
1587 // emptyOrComplete signal is asserted.
1588 writeValidBufferMuxBE.setValue(
1589 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1590
1591 // Instantiate the write port operation - truncate address width to memory
1592 // width.
1593 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1594 s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1595 st.data.data, writeValid,
1596 /*latency=*/1);
1597 }
1598 }
1599}; // namespace
1600
1601class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1602public:
1603 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1604 void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1605 hw::HWModulePortAccessor &ports) const override {
1606 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1607 // A sink is always ready to accept a new value.
1608 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1609 };
1610};
1611
1612class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1613public:
1614 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1615 void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1616 hw::HWModulePortAccessor &ports) const override {
1617 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1618 // A source always provides a new (i0-typed) value.
1619 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1620 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1621 };
1622};
1623
1624class ConstantConversionPattern
1625 : public HandshakeConversionPattern<handshake::ConstantOp> {
1626public:
1627 using HandshakeConversionPattern<
1628 handshake::ConstantOp>::HandshakeConversionPattern;
1629 void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1630 hw::HWModulePortAccessor &ports) const override {
1631 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1632 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1633 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1634 auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1635 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1636 };
1637};
1638
1639class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
1640public:
1641 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1642 void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
1643 hw::HWModulePortAccessor &ports) const override {
1644 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1645 auto input = unwrappedIO.inputs[0];
1646 auto output = unwrappedIO.outputs[0];
1647 InputHandshake lastStage;
1648 SmallVector<int64_t> initValues;
1649
1650 // For now, always build seq buffers.
1651 if (op.getInitValues())
1652 initValues = op.getInitValueArray();
1653
1654 lastStage =
1655 buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
1656 op.getNumSlots(), input, output, initValues);
1657
1658 // Connect the last stage to the output handshake.
1659 output.data->setValue(lastStage.data);
1660 output.valid->setValue(lastStage.valid);
1661 lastStage.ready->setValue(output.ready);
1662 };
1663
1664 struct SeqBufferStage {
1665 SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
1666 RTLBuilder &s, size_t index,
1667 std::optional<int64_t> initValue)
1668 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1669
1670 // Todo: Change when i0 support is added.
1671 c0s = createZeroDataConst(s, s.loc, dataType);
1672 currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
1673
1674 auto hasInitValue = s.constant(1, initValue.has_value());
1675 auto validBE = bb.get(s.b.getI1Type());
1676 auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
1677 auto readyBE = bb.get(s.b.getI1Type());
1678
1679 Value initValueCs = c0s;
1680 if (initValue.has_value())
1681 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1682
1683 // This could/should be revised but needs a larger rethinking to avoid
1684 // introducing new bugs.
1685 Value dataReg =
1686 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1687 buildControlBufferLogic(validReg, readyBE, dataReg);
1688 }
1689
1690 StringAttr getRegName(StringRef name) {
1691 return s.b.getStringAttr(name + std::to_string(index) + "_reg");
1692 }
1693
1694 void buildControlBufferLogic(Value validReg, Backedge &readyBE,
1695 Value dataReg) {
1696 auto c0I1 = s.constant(1, 0);
1697 auto readyRegWire = bb.get(s.b.getI1Type());
1698 auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);
1699
1700 // Create the logic to drive the current stage valid and potentially
1701 // data.
1702 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1703 "controlValid" + std::to_string(index));
1704
1705 // Create the logic to drive the current stage ready.
1706 auto notReadyReg = s.bNot(readyReg);
1707 readyBE.setValue(notReadyReg);
1708
1709 auto succNotReady = s.bNot(*currentStage.ready);
1710 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1711 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1712 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1713
1714 // Create a mux for emptying the register when both are ready.
1715 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1716 readyRegWire.setValue(resetSignal);
1717
1718 // Add same logic for the data path if necessary.
1719 auto ctrlDataRegBE = bb.get(dataType);
1720 auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
1721 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1722 currentStage.data = dataResult;
1723
1724 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1725 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1726 ctrlDataRegBE.setValue(dataResetSignal);
1727 }
1728
1729 Value buildDataBufferLogic(Value validReg, Value initValue,
1730 Backedge &validBE, Backedge &readyBE) {
1731 // Create a signal for when the valid register is empty or the successor
1732 // is ready to accept new token.
1733 auto notValidReg = s.bNot(validReg);
1734 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1735 preStage.ready->setValue(emptyOrReady);
1736
1737 // Create a mux that drives the register input. If the emptyOrReady
1738 // signal is asserted, the mux selects the predValid signal. Otherwise,
1739 // it selects the register output, keeping the output registered
1740 // unchanged.
1741 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1742
1743 // Now we can drive the valid register.
1744 validBE.setValue(validRegMux);
1745
1746 // Create a mux that drives the date register.
1747 auto dataRegBE = bb.get(dataType);
1748 auto dataReg =
1749 s.reg(getRegName("data"),
1750 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1751 dataRegBE.setValue(dataReg);
1752 return dataReg;
1753 }
1754
1755 InputHandshake getOutput() { return currentStage; }
1756
1757 Type dataType;
1758 InputHandshake &preStage;
1759 InputHandshake currentStage;
1760 RTLBuilder &s;
1761 BackedgeBuilder &bb;
1762 size_t index;
1763
1764 // A zero-valued constant of equal type as the data type of this buffer.
1765 Value c0s;
1766 };
1767
1768 InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
1769 Type dataType, unsigned size,
1770 InputHandshake &input,
1771 OutputHandshake &output,
1772 llvm::ArrayRef<int64_t> initValues) const {
1773 // Prime the buffer building logic with an initial stage, which just
1774 // wraps the input handshake.
1775 InputHandshake currentStage = input;
1776
1777 for (unsigned i = 0; i < size; ++i) {
1778 bool isInitialized = i < initValues.size();
1779 auto initValue =
1780 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1781 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1782 .getOutput();
1783 }
1784
1785 return currentStage;
1786 };
1787};
1788
1789class IndexCastConversionPattern
1790 : public HandshakeConversionPattern<arith::IndexCastOp> {
1791public:
1792 using HandshakeConversionPattern<
1793 arith::IndexCastOp>::HandshakeConversionPattern;
1794 void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1795 hw::HWModulePortAccessor &ports) const override {
1796 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1797 unsigned sourceBits =
1798 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1799 unsigned targetBits =
1800 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1801 if (targetBits < sourceBits)
1802 buildTruncateLogic(s, unwrappedIO, targetBits);
1803 else
1804 buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1805 };
1806};
1807
1808template <typename T>
1809class ExtModuleConversionPattern : public OpConversionPattern<T> {
1810public:
1811 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1812 MLIRContext *context, OpBuilder &submoduleBuilder,
1813 HandshakeLoweringState &ls)
1814 : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1815 submoduleBuilder(submoduleBuilder), ls(ls) {}
1816 using OpAdaptor = typename T::Adaptor;
1817
1818 LogicalResult
1819 matchAndRewrite(T op, OpAdaptor adaptor,
1820 ConversionPatternRewriter &rewriter) const override {
1821
1822 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1823 if (!implModule) {
1824 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1825 implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1826 op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1827 portInfo);
1828 }
1829
1830 llvm::SmallVector<Value> operands = adaptor.getOperands();
1831 addSequentialIOOperandsIfNeeded(op, operands);
1832 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1833 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1834 return success();
1835 }
1836
1837private:
1838 OpBuilder &submoduleBuilder;
1839 HandshakeLoweringState &ls;
1840};
1841
1842class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1843public:
1844 using OpConversionPattern::OpConversionPattern;
1845
1846 LogicalResult
1847 matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1848 ConversionPatternRewriter &rewriter) const override {
1849 ModulePortInfo ports =
1850 getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1851
1852 HWModuleLike hwModule;
1853 if (op.isExternal()) {
1854 hwModule = rewriter.create<hw::HWModuleExternOp>(
1855 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1856 } else {
1857 auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1858 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1859 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1860 rewriter.inlineBlockBefore(&op.getBody().front(),
1861 hwModuleOp.getBodyBlock()->getTerminator(),
1862 args);
1863 hwModule = hwModuleOp;
1864 }
1865
1866 // Was any predeclaration associated with this func? If so, replace uses
1867 // with the newly created module and erase the predeclaration.
1868 if (auto predecl =
1869 op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1870 auto *parentOp = op->getParentOp();
1871 auto *predeclModule =
1872 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1873 if (predeclModule) {
1874 if (failed(SymbolTable::replaceAllSymbolUses(
1875 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1876 return failure();
1877 rewriter.eraseOp(predeclModule);
1878 }
1879 }
1880
1881 rewriter.eraseOp(op);
1882 return success();
1883 }
1884};
1885
1886} // namespace
1887
1888//===----------------------------------------------------------------------===//
1889// HW Top-module Related Functions
1890//===----------------------------------------------------------------------===//
1891
1892static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1893 ConversionTarget &target,
1895 OpBuilder &moduleBuilder) {
1896
1897 std::map<std::string, unsigned> instanceNameCntr;
1898 NameUniquer instanceUniquer = [&](Operation *op) {
1899 std::string instName = getCallName(op);
1900 if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1901 // We use a special naming convention for operations which have a
1902 // 'handshake_id' attribute.
1903 instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1904 } else {
1905 // Fallback to just prefixing with an integer.
1906 instName += std::to_string(instanceNameCntr[instName]++);
1907 }
1908 return instName;
1909 };
1910
1911 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1912 instanceUniquer};
1913 RewritePatternSet patterns(op.getContext());
1914 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1915 op.getContext());
1916 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1917 SyncConversionPattern>(typeConverter, op.getContext(),
1918 moduleBuilder, ls);
1919
1920 patterns.insert<
1921 // Comb operations.
1922 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1923 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1924 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1925 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1926 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1927 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1928 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1929 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1930 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1931 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1932 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1933 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1934 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1935 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1936 // HW operations.
1937 StructCreateConversionPattern,
1938 // Handshake operations.
1939 ConditionalBranchConversionPattern, MuxConversionPattern,
1940 PackConversionPattern, UnpackConversionPattern,
1941 ComparisonConversionPattern, BufferConversionPattern,
1942 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1943 MergeConversionPattern, ControlMergeConversionPattern,
1944 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1945 InstanceConversionPattern,
1946 // Arith operations.
1947 ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1948 ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1949 TruncateConversionPattern, IndexCastConversionPattern>(
1950 typeConverter, op.getContext(), moduleBuilder, ls);
1951
1952 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1953 return op->emitOpError() << "error during conversion";
1954 return success();
1955}
1956
1957namespace {
1958class HandshakeToHWPass
1959 : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1960public:
1961 void runOnOperation() override {
1962 mlir::ModuleOp mod = getOperation();
1963
1964 // Lowering to HW requires that every value is used exactly once. Check
1965 // whether this precondition is met, and if not, exit.
1966 for (auto f : mod.getOps<handshake::FuncOp>()) {
1967 if (failed(verifyAllValuesHasOneUse(f))) {
1968 f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1969 "are used exactly once. Remember to run the "
1970 "fork/sink materialization pass before HW lowering.";
1971 signalPassFailure();
1972 return;
1973 }
1974 }
1975
1976 // Resolve the instance graph to get a top-level module.
1977 std::string topLevel;
1979 SmallVector<std::string> sortedFuncs;
1980 if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1981 signalPassFailure();
1982 return;
1983 }
1984
1985 ESITypeConverter typeConverter;
1986 ConversionTarget target(getContext());
1987 // All top-level logic of a handshake module will be the interconnectivity
1988 // between instantiated modules.
1989 target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1990 hw::InstanceOp>();
1991 target
1992 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1993
1994 // Convert the handshake.func operations in post-order wrt. the instance
1995 // graph. This ensures that any referenced submodules (through
1996 // handshake.instance) has already been lowered, and their HW module
1997 // equivalents are available.
1998 OpBuilder submoduleBuilder(mod.getContext());
1999 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2000 for (auto &funcName : llvm::reverse(sortedFuncs)) {
2001 auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
2002 assert(funcOp && "handshake.func not found in module!");
2003 if (failed(
2004 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2005 signalPassFailure();
2006 return;
2007 }
2008 }
2009
2010 // Second stage: Convert any handshake.extmemory operations and the
2011 // top-level I/O associated with these.
2012 for (auto hwModule : mod.getOps<hw::HWModuleOp>())
2013 if (failed(convertExtMemoryOps(hwModule)))
2014 return signalPassFailure();
2015
2016 // Run conversions which need see everything.
2017 HWSymbolCache symbolCache;
2018 symbolCache.addDefinitions(mod);
2019 symbolCache.freeze();
2020 RewritePatternSet patterns(mod.getContext());
2021 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2022 symbolCache);
2023 if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
2024 mod->emitOpError() << "error during conversion";
2025 signalPassFailure();
2026 }
2027 }
2028};
2029} // end anonymous namespace
2030
2031std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
2032 return std::make_unique<HandshakeToHWPass>();
2033}
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
MlirType elementType
Definition CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition CalyxOps.cpp:121
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static Type tupleToStruct(TupleType tuple)
Definition DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition DCToHW.cpp:45
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1420
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition OM.cpp:113
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition HWSymCache.h:75
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
Channels are the basic communication primitives.
Definition Types.h:63
const Type * getInner() const
Definition Types.h:66
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition seq.py:157
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
Definition hw.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.