CIRCT 22.0.0git
Loading...
Searching...
No Matches
HandshakeToHW.cpp
Go to the documentation of this file.
1//===- HandshakeToHW.cpp - Translate Handshake into HW ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7//===----------------------------------------------------------------------===//
8//
9// This is the main Handshake to HW Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
34#include <optional>
35
36namespace circt {
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
39} // namespace circt
40
41using namespace mlir;
42using namespace circt;
43using namespace circt::handshake;
44using namespace circt::hw;
45
46using NameUniquer = std::function<std::string(Operation *)>;
47
48namespace {
49
50static Type tupleToStruct(TypeRange types) {
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
52}
53
54// Shared state used by various functions; captured in a struct to reduce the
55// number of arguments that we have to pass around.
56struct HandshakeLoweringState {
57 ModuleOp parentModule;
58 NameUniquer nameUniquer;
59};
60
61// A type converter is needed to perform the in-flight materialization of "raw"
62// (non-ESI channel) types to their ESI channel correspondents. This comes into
63// effect when backedges exist in the input IR.
64class ESITypeConverter : public TypeConverter {
65public:
66 ESITypeConverter() {
67 addConversion([](Type type) -> Type { return esiWrapper(type); });
68
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
73 return Value();
74 return UnrealizedConversionCastOp::create(builder, loc, resultType,
75 inputs[0])
76 ->getResult(0);
77 });
78
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
83 return Value();
84 return UnrealizedConversionCastOp::create(builder, loc, resultType,
85 inputs[0])
86 ->getResult(0);
87 });
88 }
89};
90
91} // namespace
92
93/// Returns a submodule name resulting from an operation, without discriminating
94/// type information.
95static std::string getBareSubModuleName(Operation *oldOp) {
96 // The dialect name is separated from the operation name by '.', which is not
97 // valid in SystemVerilog module names. In case this name is used in
98 // SystemVerilog output, replace '.' with '_'.
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(), '.', '_');
101 return subModuleName;
102}
103
104static std::string getCallName(Operation *op) {
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
106 return callOp ? callOp.getModule().str() : getBareSubModuleName(op);
107}
108
109/// Extracts the type of the data-carrying type of opType. If opType is an ESI
110/// channel, getHandshakeBundleDataType extracts the data-carrying type, else,
111/// assume that opType itself is the data-carrying type.
112static Type getOperandDataType(Value op) {
113 auto opType = op.getType();
114 if (auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
116 return opType;
117}
118
119/// Filters NoneType's from the input.
120static SmallVector<Type> filterNoneTypes(ArrayRef<Type> input) {
121 SmallVector<Type> filterRes;
122 llvm::copy_if(input, std::back_inserter(filterRes),
123 [](Type type) { return !isa<NoneType>(type); });
124 return filterRes;
125}
126
127/// Returns a set of types which may uniquely identify the provided op. Return
128/// value is <inputTypes, outputTypes>.
129using DiscriminatingTypes = std::pair<SmallVector<Type>, SmallVector<Type>>;
131 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
132 .Case<MemoryOp, ExternalMemoryOp>([&](auto memOp) {
133 return DiscriminatingTypes{{},
134 {memOp.getMemRefType().getElementType()}};
135 })
136 .Default([&](auto) {
137 // By default, all in- and output types which is not a control type
138 // (NoneType) are discriminating types.
139 std::vector<Type> inTypes, outTypes;
140 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142 llvm::transform(op->getResults(), std::back_inserter(outTypes),
144 return DiscriminatingTypes{filterNoneTypes(inTypes),
145 filterNoneTypes(outTypes)};
146 });
147}
148
149/// Get type name. Currently we only support integer or index types.
150/// The emitted type aligns with the getFIRRTLType() method. Thus all integers
151/// other than signed integers will be emitted as unsigned.
152// NOLINTNEXTLINE(misc-no-recursion)
153static std::string getTypeName(Location loc, Type type) {
154 std::string typeName;
155 // Builtin types
156 if (type.isIntOrIndex()) {
157 if (auto indexType = dyn_cast<IndexType>(type))
158 typeName += "_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159 else if (type.isSignedInteger())
160 typeName += "_si" + std::to_string(type.getIntOrFloatBitWidth());
161 else
162 typeName += "_ui" + std::to_string(type.getIntOrFloatBitWidth());
163 } else if (auto tupleType = dyn_cast<TupleType>(type)) {
164 typeName += "_tuple";
165 for (auto elementType : tupleType.getTypes())
166 typeName += getTypeName(loc, elementType);
167 } else if (auto structType = dyn_cast<hw::StructType>(type)) {
168 typeName += "_struct";
169 for (auto element : structType.getElements())
170 typeName += "_" + element.name.str() + getTypeName(loc, element.type);
171 } else
172 emitError(loc) << "unsupported data type '" << type << "'";
173
174 return typeName;
175}
176
177/// Construct a name for creating HW sub-module.
178static std::string getSubModuleName(Operation *oldOp) {
179 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180 return instanceOp.getModule().str();
181
182 std::string subModuleName = getBareSubModuleName(oldOp);
183
184 // Add value of the constant operation.
185 if (auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187 auto intType = intAttr.getType();
188
189 if (intType.isSignedInteger())
190 subModuleName += "_c" + std::to_string(intAttr.getSInt());
191 else if (intType.isUnsignedInteger())
192 subModuleName += "_c" + std::to_string(intAttr.getUInt());
193 else
194 subModuleName += "_c" + std::to_string((uint64_t)intAttr.getInt());
195 } else
196 oldOp->emitError("unsupported constant type");
197 }
198
199 // Add discriminating in- and output types.
200 auto [inTypes, outTypes] = getHandshakeDiscriminatingTypes(oldOp);
201 if (!inTypes.empty())
202 subModuleName += "_in";
203 for (auto inType : inTypes)
204 subModuleName += getTypeName(oldOp->getLoc(), inType);
205
206 if (!outTypes.empty())
207 subModuleName += "_out";
208 for (auto outType : outTypes)
209 subModuleName += getTypeName(oldOp->getLoc(), outType);
210
211 // Add memory ID.
212 if (auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213 subModuleName += "_id" + std::to_string(memOp.getId());
214
215 // Add compare kind.
216 if (auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217 subModuleName += "_" + stringifyEnum(comOp.getPredicate()).str();
218
219 // Add buffer information.
220 if (auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221 subModuleName += "_" + std::to_string(bufferOp.getNumSlots()) + "slots";
222 if (bufferOp.isSequential())
223 subModuleName += "_seq";
224 else
225 subModuleName += "_fifo";
226
227 if (auto initValues = bufferOp.getInitValues()) {
228 subModuleName += "_init";
229 for (const Attribute e : *initValues) {
230 assert(isa<IntegerAttr>(e));
231 subModuleName +=
232 "_" + std::to_string(dyn_cast<IntegerAttr>(e).getInt());
233 }
234 }
235 }
236
237 // Add control information.
238 if (auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239 ctrlInterface && ctrlInterface.isControl()) {
240 // Add some additional discriminating info for non-typed operations.
241 subModuleName += "_" + std::to_string(oldOp->getNumOperands()) + "ins_" +
242 std::to_string(oldOp->getNumResults()) + "outs";
243 subModuleName += "_ctrl";
244 } else {
245 assert(
246 (!inTypes.empty() || !outTypes.empty()) &&
247 "Insufficient discriminating type info generated for the operation!");
248 }
249
250 return subModuleName;
251}
252
253//===----------------------------------------------------------------------===//
254// HW Sub-module Related Functions
255//===----------------------------------------------------------------------===//
256
257/// Check whether a submodule with the same name has been created elsewhere in
258/// the top level module. Return the matched module operation if true, otherwise
259/// return nullptr.
260static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
261 StringRef modName) {
262 if (auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
263 return mod;
264 if (auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
265 return mod;
266 return {};
267}
268
269static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule,
270 Operation *oldOp) {
271 HWModuleLike targetModule;
272 if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
273 targetModule = checkSubModuleOp(parentModule, instanceOp.getModule());
274 else
275 targetModule = checkSubModuleOp(parentModule, getSubModuleName(oldOp));
276
277 if (isa<handshake::InstanceOp>(oldOp))
278 assert(targetModule &&
279 "handshake.instance target modules should always have been lowered "
280 "before the modules that reference them!");
281 return targetModule;
282}
283
284/// Returns a vector of PortInfo's which defines the HW interface of the
285/// to-be-converted op.
286static ModulePortInfo getPortInfoForOp(Operation *op) {
287 return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
288}
289
290static llvm::SmallVector<hw::detail::FieldInfo>
291portToFieldInfo(llvm::ArrayRef<hw::PortInfo> portInfo) {
292 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293 for (auto port : portInfo)
294 fieldInfo.push_back({port.name, port.type});
295
296 return fieldInfo;
297}
298
299// Convert any handshake.extmemory operations and the top-level I/O
300// associated with these.
301static LogicalResult convertExtMemoryOps(HWModuleOp mod) {
302 auto *ctx = mod.getContext();
303
304 // Gather memref ports to be converted.
305 llvm::DenseMap<unsigned, Value> memrefPorts;
306 for (auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308 if (channel && isa<MemRefType>(channel.getInner()))
309 memrefPorts[i] = arg;
310 }
311
312 if (memrefPorts.empty())
313 return success(); // nothing to do.
314
315 OpBuilder b(mod);
316
317 auto getMemoryIOInfo = [&](Location loc, Twine portName, unsigned argIdx,
318 ArrayRef<hw::PortInfo> info,
319 hw::ModulePort::Direction direction) {
320 auto type = hw::StructType::get(ctx, portToFieldInfo(info));
321 auto portInfo =
322 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
323 return portInfo;
324 };
325
326 for (auto [i, arg] : memrefPorts) {
327 // Insert ports into the module
328 auto memName = mod.getArgName(i);
329
330 // Get the attached extmemory external module.
331 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
332 auto extmemMod =
333 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334 extmemInstance, extmemInstance.getModuleNameAttr()));
335
336 ModulePortInfo portInfo(extmemMod.getPortList());
337
338 // The extmemory external module's interface is a direct wrapping of the
339 // original handshake.extmemory operation in- and output types. Remove the
340 // first input argument (the !esi.channel<memref> op) since that is what
341 // we're replacing with a materialized interface.
342 portInfo.eraseInput(0);
343
344 // Add memory input - this is the output of the extmemory op.
345 SmallVector<PortInfo> outputs(portInfo.getOutputs());
346 auto inPortInfo =
347 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_in", i, outputs,
348 hw::ModulePort::Direction::Input);
349 mod.insertPorts({{i, inPortInfo}}, {});
350 auto newInPort = mod.getArgumentForInput(i);
351 // Replace the extmemory submodule outputs with the newly created inputs.
352 b.setInsertionPointToStart(mod.getBodyBlock());
353 auto newInPortExploded = hw::StructExplodeOp::create(
354 b, arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
356
357 // Add memory output - this is the inputs of the extmemory op (without the
358 // first argument);
359 unsigned outArgI = mod.getNumOutputPorts();
360 SmallVector<PortInfo> inputs(portInfo.getInputs());
361 auto outPortInfo =
362 getMemoryIOInfo(arg.getLoc(), memName.strref() + "_out", outArgI,
363 inputs, hw::ModulePort::Direction::Output);
364
365 auto memOutputArgs = extmemInstance.getOperands().drop_front();
366 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
367 auto memOutputStruct = hw::StructCreateOp::create(
368 b, arg.getLoc(), outPortInfo.type, memOutputArgs);
369 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
370
371 // Erase the extmemory submodule instace since the i/o has now been
372 // plumbed.
373 extmemMod.erase();
374 extmemInstance.erase();
375
376 // Erase the original memref argument of the top-level i/o now that it's use
377 // has been removed.
378 mod.modifyPorts(/*insertInputs*/ {}, /*insertOutputs*/ {},
379 /*eraseInputs*/ {i + 1}, /*eraseOutputs*/ {});
380 }
381
382 return success();
383}
384
385namespace {
386
387// Input handshakes contain a resolved valid and (optional )data signal, and
388// a to-be-assigned ready signal.
389struct InputHandshake {
390 Value valid;
391 std::shared_ptr<Backedge> ready;
392 Value data;
393};
394
395// Output handshakes contain a resolved ready, and to-be-assigned valid and
396// (optional) data signals.
397struct OutputHandshake {
398 std::shared_ptr<Backedge> valid;
399 Value ready;
400 std::shared_ptr<Backedge> data;
401};
402
403/// A helper struct that acts like a wire. Can be used to interact with the
404/// RTLBuilder when multiple built components should be connected.
405struct HandshakeWire {
406 HandshakeWire(BackedgeBuilder &bb, Type dataType) {
407 MLIRContext *ctx = dataType.getContext();
408 auto i1Type = IntegerType::get(ctx, 1);
409 valid = std::make_shared<Backedge>(bb.get(i1Type));
410 ready = std::make_shared<Backedge>(bb.get(i1Type));
411 data = std::make_shared<Backedge>(bb.get(dataType));
412 }
413
414 // Functions that allow to treat a wire like an input or output port.
415 // **Careful**: Such a port will not be updated when backedges are resolved.
416 InputHandshake getAsInput() { return {*valid, ready, *data}; }
417 OutputHandshake getAsOutput() { return {valid, *ready, data}; }
418
419 std::shared_ptr<Backedge> valid;
420 std::shared_ptr<Backedge> ready;
421 std::shared_ptr<Backedge> data;
422};
423
424template <typename T, typename TInner>
425llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426 llvm::function_ref<T(TInner &)> extractor) {
427 llvm::SmallVector<T> result;
428 llvm::transform(container, std::back_inserter(result), extractor);
429 return result;
430}
431struct UnwrappedIO {
432 llvm::SmallVector<InputHandshake> inputs;
433 llvm::SmallVector<OutputHandshake> outputs;
434
435 llvm::SmallVector<Value> getInputValids() {
436 return extractValues<Value, InputHandshake>(
437 inputs, [](auto &hs) { return hs.valid; });
438 }
439 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441 inputs, [](auto &hs) { return hs.ready; });
442 }
443 llvm::SmallVector<Value> getInputDatas() {
444 return extractValues<Value, InputHandshake>(
445 inputs, [](auto &hs) { return hs.data; });
446 }
447 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449 outputs, [](auto &hs) { return hs.valid; });
450 }
451 llvm::SmallVector<Value> getOutputReadys() {
452 return extractValues<Value, OutputHandshake>(
453 outputs, [](auto &hs) { return hs.ready; });
454 }
455 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457 outputs, [](auto &hs) { return hs.data; });
458 }
459};
460
461// A class containing a bunch of syntactic sugar to reduce builder function
462// verbosity.
463// @todo: should be moved to support.
464struct RTLBuilder {
465 RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
466 Value clk = Value(), Value rst = Value())
467 : info(std::move(info)), b(builder), loc(loc), clk(clk), rst(rst) {}
468
469 Value constant(const APInt &apv, std::optional<StringRef> name = {}) {
470 // Cannot use zero-width APInt's in DenseMap's, see
471 // https://github.com/llvm/llvm-project/issues/58013
472 bool isZeroWidth = apv.getBitWidth() == 0;
473 if (!isZeroWidth) {
474 auto it = constants.find(apv);
475 if (it != constants.end())
476 return it->second;
477 }
478
479 auto cval = hw::ConstantOp::create(b, loc, apv);
480 if (!isZeroWidth)
481 constants[apv] = cval;
482 return cval;
483 }
484
485 Value constant(unsigned width, int64_t value,
486 std::optional<StringRef> name = {}) {
487 return constant(
488 APInt(width, value, /*isSigned=*/false, /*implicitTrunc=*/true));
489 }
490 std::pair<Value, Value> wrap(Value data, Value valid,
491 std::optional<StringRef> name = {}) {
492 auto wrapOp = esi::WrapValidReadyOp::create(b, loc, data, valid);
493 return {wrapOp.getResult(0), wrapOp.getResult(1)};
494 }
495 std::pair<Value, Value> unwrap(Value channel, Value ready,
496 std::optional<StringRef> name = {}) {
497 auto unwrapOp = esi::UnwrapValidReadyOp::create(b, loc, channel, ready);
498 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
499 }
500
501 // Various syntactic sugar functions.
502 Value reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503 Value rst = Value()) {
504 Value resolvedClk = clk ? clk : this->clk;
505 Value resolvedRst = rst ? rst : this->rst;
506 assert(resolvedClk &&
507 "No global clock provided to this RTLBuilder - a clock "
508 "signal must be provided to the reg(...) function.");
509 assert(resolvedRst &&
510 "No global reset provided to this RTLBuilder - a reset "
511 "signal must be provided to the reg(...) function.");
512
513 return seq::CompRegOp::create(b, loc, in, resolvedClk, resolvedRst,
514 rstValue, name);
515 }
516
517 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518 std::optional<StringRef> name = {}) {
519 return comb::ICmpOp::create(b, loc, predicate, lhs, rhs);
520 }
521
522 Value buildNamedOp(llvm::function_ref<Value()> f,
523 std::optional<StringRef> name) {
524 Value v = f();
525 StringAttr nameAttr;
526 Operation *op = v.getDefiningOp();
527 if (name.has_value()) {
528 op->setAttr("sv.namehint", b.getStringAttr(*name));
529 nameAttr = b.getStringAttr(*name);
530 }
531 return v;
532 }
533
534 // Bitwise 'and'.
535 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
536 return buildNamedOp(
537 [&]() { return comb::AndOp::create(b, loc, values, false); }, name);
538 }
539
540 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
541 return buildNamedOp(
542 [&]() { return comb::OrOp::create(b, loc, values, false); }, name);
543 }
544
545 // Bitwise 'not'.
546 Value bNot(Value value, std::optional<StringRef> name = {}) {
547 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548 std::string inferedName;
549 if (!name) {
550 // Try to create a name from the input value.
551 if (auto valueName =
552 value.getDefiningOp()->getAttrOfType<StringAttr>("sv.namehint")) {
553 inferedName = ("not_" + valueName.getValue()).str();
554 name = inferedName;
555 }
556 }
557
558 return buildNamedOp(
559 [&]() { return comb::XorOp::create(b, loc, value, allOnes); }, name);
560
561 return b.createOrFold<comb::XorOp>(loc, value, allOnes, false);
562 }
563
564 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
565 return buildNamedOp(
566 [&]() { return comb::ShlOp::create(b, loc, value, shift); }, name);
567 }
568
569 Value concat(ValueRange values, std::optional<StringRef> name = {}) {
570 return buildNamedOp(
571 [&]() { return comb::ConcatOp::create(b, loc, values); }, name);
572 }
573
574 // Packs a list of values into a hw.struct.
575 Value pack(ValueRange values, Type structType = Type(),
576 std::optional<StringRef> name = {}) {
577 if (!structType)
578 structType = tupleToStruct(values.getTypes());
579 return buildNamedOp(
580 [&]() {
581 return hw::StructCreateOp::create(b, loc, structType, values);
582 },
583 name);
584 }
585
586 // Unpacks a hw.struct into a list of values.
587 ValueRange unpack(Value value) {
588 auto structType = cast<hw::StructType>(value.getType());
589 llvm::SmallVector<Type> innerTypes;
590 structType.getInnerTypes(innerTypes);
591 return hw::StructExplodeOp::create(b, loc, innerTypes, value).getResults();
592 }
593
594 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
595 llvm::SmallVector<Value> bits;
596 for (unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
597 bits.push_back(comb::ExtractOp::create(b, loc, v, i, /*bitWidth=*/1));
598 return bits;
599 }
600
601 // OR-reduction of the bits in 'v'.
602 Value rOr(Value v, std::optional<StringRef> name = {}) {
603 return buildNamedOp([&]() { return bOr(toBits(v)); }, name);
604 }
605
606 // Extract bits v[hi:lo] (inclusive).
607 Value extract(Value v, unsigned lo, unsigned hi,
608 std::optional<StringRef> name = {}) {
609 unsigned width = hi - lo + 1;
610 return buildNamedOp(
611 [&]() { return comb::ExtractOp::create(b, loc, v, lo, width); }, name);
612 }
613
614 // Truncates 'value' to its lower 'width' bits.
615 Value truncate(Value value, unsigned width,
616 std::optional<StringRef> name = {}) {
617 return extract(value, 0, width - 1, name);
618 }
619
620 Value zext(Value value, unsigned outWidth,
621 std::optional<StringRef> name = {}) {
622 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
623 assert(inWidth <= outWidth && "zext: input width must be <- output width.");
624 if (inWidth == outWidth)
625 return value;
626 auto c0 = constant(outWidth - inWidth, 0);
627 return concat({c0, value}, name);
628 }
629
630 Value sext(Value value, unsigned outWidth,
631 std::optional<StringRef> name = {}) {
632 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
633 }
634
635 // Extracts a single bit v[bit].
636 Value bit(Value v, unsigned index, std::optional<StringRef> name = {}) {
637 return extract(v, index, index, name);
638 }
639
640 // Creates a hw.array of the given values.
641 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
642 return buildNamedOp(
643 [&]() { return hw::ArrayCreateOp::create(b, loc, values); }, name);
644 }
645
646 // Extract the 'index'th value from the input array.
647 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
648 return buildNamedOp(
649 [&]() { return hw::ArrayGetOp::create(b, loc, array, index); }, name);
650 }
651
652 // Muxes a range of values.
653 // The select signal is expected to be a decimal value which selects starting
654 // from the lowest index of value.
655 Value mux(Value index, ValueRange values,
656 std::optional<StringRef> name = {}) {
657 if (values.size() == 2)
658 return comb::MuxOp::create(b, loc, index, values[1], values[0]);
659
660 return arrayGet(arrayCreate(values), index, name);
661 }
662
663 // Muxes a range of values. The select signal is expected to be a 1-hot
664 // encoded value.
665 Value ohMux(Value index, ValueRange inputs) {
666 // Confirm the select input can be a one-hot encoding for the inputs.
667 unsigned numInputs = inputs.size();
668 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
669 "one-hot select can't mux inputs");
670
671 // Start the mux tree with zero value.
672 // Todo: clean up when handshake supports i0.
673 auto dataType = inputs[0].getType();
674 unsigned width =
675 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
676 Value muxValue = constant(width, 0);
677
678 // Iteratively chain together muxes from the high bit to the low bit.
679 for (size_t i = numInputs - 1; i != 0; --i) {
680 Value input = inputs[i];
681 Value selectBit = bit(index, i);
682 muxValue = mux(selectBit, {muxValue, input});
683 }
684
685 return muxValue;
686 }
687
689 OpBuilder &b;
690 Location loc;
691 Value clk, rst;
692 DenseMap<APInt, Value> constants;
693};
694
695/// Creates a Value that has an assigned zero value. For structs, this
696/// corresponds to assigning zero to each element recursively.
697static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
698 return TypeSwitch<Type, Value>(type)
699 .Case<NoneType>([&](NoneType) { return s.constant(0, 0); })
700 .Case<IntType, IntegerType>([&](auto type) {
701 return s.constant(type.getIntOrFloatBitWidth(), 0);
702 })
703 .Case<hw::StructType>([&](auto structType) {
704 SmallVector<Value> zeroValues;
705 for (auto field : structType.getElements())
706 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
707 return hw::StructCreateOp::create(s.b, loc, structType, zeroValues);
708 })
709 .Default([&](Type) -> Value {
710 emitError(loc) << "unsupported type for zero value: " << type;
711 assert(false);
712 return {};
713 });
714}
715
716static void
717addSequentialIOOperandsIfNeeded(Operation *op,
718 llvm::SmallVectorImpl<Value> &operands) {
719 if (op->hasTrait<mlir::OpTrait::HasClock>()) {
720 // Parent should at this point be a hw.module and have clock and reset
721 // ports.
722 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
723 operands.push_back(
724 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
725 operands.push_back(
726 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
727 }
728}
729
730template <typename T>
731class HandshakeConversionPattern : public OpConversionPattern<T> {
732public:
733 HandshakeConversionPattern(ESITypeConverter &typeConverter,
734 MLIRContext *context, OpBuilder &submoduleBuilder,
735 HandshakeLoweringState &ls)
736 : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
737 submoduleBuilder(submoduleBuilder), ls(ls) {}
738
739 using OpAdaptor = typename T::Adaptor;
740
741 LogicalResult
742 matchAndRewrite(T op, OpAdaptor adaptor,
743 ConversionPatternRewriter &rewriter) const override {
744
745 // Check if a submodule has already been created for the op. If so,
746 // instantiate the submodule. Else, run the pattern-defined module
747 // builder.
748 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
749 if (!implModule) {
750 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
751
752 submoduleBuilder.setInsertionPoint(op->getParentOp());
753 implModule = hw::HWModuleOp::create(
754 submoduleBuilder, op.getLoc(),
755 submoduleBuilder.getStringAttr(getSubModuleName(op)), portInfo,
756 [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
757 // if 'op' has clock trait, extract these and provide them to the
758 // RTL builder.
759 Value clk, rst;
760 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
761 clk = ports.getInput("clock");
762 rst = ports.getInput("reset");
763 }
764
765 BackedgeBuilder bb(b, op.getLoc());
766 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
767 this->buildModule(op, bb, s, ports);
768 });
769 }
770
771 // Instantiate the submodule.
772 llvm::SmallVector<Value> operands = adaptor.getOperands();
773 addSequentialIOOperandsIfNeeded(op, operands);
774 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
775 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
776 return success();
777 }
778
779 virtual void buildModule(T op, BackedgeBuilder &bb, RTLBuilder &builder,
780 hw::HWModulePortAccessor &ports) const = 0;
781
782 // Syntactic sugar functions.
783 // Unwraps an ESI-interfaced module into its constituent handshake signals.
784 // Backedges are created for the to-be-resolved signals, and output ports
785 // are assigned to their wrapped counterparts.
786 UnwrappedIO unwrapIO(RTLBuilder &s, BackedgeBuilder &bb,
787 hw::HWModulePortAccessor &ports) const {
788 UnwrappedIO unwrapped;
789 for (auto port : ports.getInputs()) {
790 if (!isa<esi::ChannelType>(port.getType()))
791 continue;
792 InputHandshake hs;
793 auto ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
794 auto [data, valid] = s.unwrap(port, *ready);
795 hs.data = data;
796 hs.valid = valid;
797 hs.ready = ready;
798 unwrapped.inputs.push_back(hs);
799 }
800 for (auto &outputInfo : ports.getPortList().getOutputs()) {
801 esi::ChannelType channelType =
802 dyn_cast<esi::ChannelType>(outputInfo.type);
803 if (!channelType)
804 continue;
805 OutputHandshake hs;
806 Type innerType = channelType.getInner();
807 auto data = std::make_shared<Backedge>(bb.get(innerType));
808 auto valid = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
809 auto [dataCh, ready] = s.wrap(*data, *valid);
810 hs.data = data;
811 hs.valid = valid;
812 hs.ready = ready;
813 ports.setOutput(outputInfo.name, dataCh);
814 unwrapped.outputs.push_back(hs);
815 }
816 return unwrapped;
817 }
818
819 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820 OutputHandshake &output, Value cond) const {
821 auto validAndReady = s.bAnd({output.ready, cond});
822 for (auto &input : inputs)
823 input.ready->setValue(validAndReady);
824 }
825
826 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
827 OutputHandshake &output) const {
828 llvm::SmallVector<Value> valids;
829 for (auto &input : inputs)
830 valids.push_back(input.valid);
831 Value allValid = s.bAnd(valids);
832 output.valid->setValue(allValid);
833 setAllReadyWithCond(s, inputs, output, allValid);
834 }
835
836 // Builds mux logic for the given inputs and outputs.
837 // Note: it is assumed that the caller has removed the 'select' signal from
838 // the 'unwrapped' inputs and provide it as a separate argument.
839 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
840 InputHandshake &select) const {
841 // ============================= Control logic =============================
842 size_t numInputs = unwrapped.inputs.size();
843 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
844 Value truncatedSelect =
845 select.data.getType().getIntOrFloatBitWidth() > selectWidth
846 ? s.truncate(select.data, selectWidth)
847 : select.data;
848
849 // Decimal-to-1-hot decoder. 'shl' operands must be identical in size.
850 auto selectZext = s.zext(truncatedSelect, numInputs);
851 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
852 auto &res = unwrapped.outputs[0];
853
854 // Mux input valid signals.
855 auto selectedInputValid =
856 s.mux(truncatedSelect, unwrapped.getInputValids());
857 // Result is valid when the selected input and the select input is valid.
858 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
859 res.valid->setValue(selAndInputValid);
860 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
861
862 // Select is ready when result is valid and ready (result transacting).
863 select.ready->setValue(resValidAndReady);
864
865 // Assign each input ready signal if it is currently selected.
866 for (auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
867 // Extract the selection bit for this input.
868 auto isSelected = s.bit(select1h, inIdx);
869
870 // '&' that with the result valid and ready, and assign to the input
871 // ready signal.
872 auto activeAndResultValidAndReady =
873 s.bAnd({isSelected, resValidAndReady});
874 in.ready->setValue(activeAndResultValidAndReady);
875 }
876
877 // ============================== Data logic ===============================
878 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
879 }
880
881 // Builds fork logic between the single input and multiple outputs' control
882 // networks. Caller is expected to handle data separately.
883 void buildForkLogic(RTLBuilder &s, BackedgeBuilder &bb, InputHandshake &input,
884 ArrayRef<OutputHandshake> outputs) const {
885 auto c0I1 = s.constant(1, 0);
886 llvm::SmallVector<Value> doneWires;
887 for (auto [i, output] : llvm::enumerate(outputs)) {
888 auto doneBE = bb.get(s.b.getI1Type());
889 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
890 auto emittedReg = s.reg("emitted_" + std::to_string(i), emitted, c0I1);
891 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
892 output.valid->setValue(outValid);
893 auto validReady = s.bAnd({output.ready, outValid});
894 auto done = s.bOr({validReady, emittedReg}, "done" + std::to_string(i));
895 doneBE.setValue(done);
896 doneWires.push_back(done);
897 }
898 input.ready->setValue(s.bAnd(doneWires, "allDone"));
899 }
900
901 // Builds a unit-rate actor around an inner operation. 'unitBuilder' is a
902 // function which takes the set of unwrapped data inputs, and returns a
903 // value which should be assigned to the output data value.
904 void buildUnitRateJoinLogic(
905 RTLBuilder &s, UnwrappedIO &unwrappedIO,
906 llvm::function_ref<Value(ValueRange)> unitBuilder) const {
907 assert(unwrappedIO.outputs.size() == 1 &&
908 "Expected exactly one output for unit-rate join actor");
909 // Control logic.
910 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
911
912 // Data logic.
913 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
914 unwrappedIO.outputs[0].data->setValue(unitRes);
915 }
916
917 void buildUnitRateForkLogic(
918 RTLBuilder &s, BackedgeBuilder &bb, UnwrappedIO &unwrappedIO,
919 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder) const {
920 assert(unwrappedIO.inputs.size() == 1 &&
921 "Expected exactly one input for unit-rate fork actor");
922 // Control logic.
923 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
924
925 // Data logic.
926 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
927 assert(unitResults.size() == unwrappedIO.outputs.size() &&
928 "Expected unit builder to return one result per output");
929 for (auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
930 outport.data->setValue(res);
931 }
932
933 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
934 bool signExtend) const {
935 size_t outWidth =
936 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
937 .getIntOrFloatBitWidth();
938 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
939 if (signExtend)
940 return s.sext(inputs[0], outWidth);
941 return s.zext(inputs[0], outWidth);
942 });
943 }
944
945 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
946 unsigned targetWidth) const {
947 size_t outWidth =
948 toValidType(static_cast<Value>(*unwrappedIO.outputs[0].data).getType())
949 .getIntOrFloatBitWidth();
950 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
951 return s.truncate(inputs[0], outWidth);
952 });
953 }
954
955 /// Return the number of bits needed to index the given number of values.
956 static size_t getNumIndexBits(uint64_t numValues) {
957 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
958 }
959
960 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
961 Value defaultValue,
962 DenseMap<size_t, Value> &indexMapping) const {
963 auto numInputs = inputs.size();
964 auto priorityArb = defaultValue;
965
966 for (size_t i = numInputs; i > 0; --i) {
967 size_t inputIndex = i - 1;
968 size_t oneHotIndex = size_t{1} << inputIndex;
969 auto constIndex = s.constant(numInputs, oneHotIndex);
970 indexMapping[inputIndex] = constIndex;
971 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
972 }
973 return priorityArb;
974 }
975
976private:
977 OpBuilder &submoduleBuilder;
978 HandshakeLoweringState &ls;
979};
980
981class ForkConversionPattern : public HandshakeConversionPattern<ForkOp> {
982public:
983 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
984 void buildModule(ForkOp op, BackedgeBuilder &bb, RTLBuilder &s,
985 hw::HWModulePortAccessor &ports) const override {
986 auto unwrapped = unwrapIO(s, bb, ports);
987 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
988 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
989 });
990 }
991};
992
993class JoinConversionPattern : public HandshakeConversionPattern<JoinOp> {
994public:
995 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
996 void buildModule(JoinOp op, BackedgeBuilder &bb, RTLBuilder &s,
997 hw::HWModulePortAccessor &ports) const override {
998 auto unwrappedIO = unwrapIO(s, bb, ports);
999 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
1000 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1001 };
1002};
1003
1004class SyncConversionPattern : public HandshakeConversionPattern<SyncOp> {
1005public:
1006 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1007 void buildModule(SyncOp op, BackedgeBuilder &bb, RTLBuilder &s,
1008 hw::HWModulePortAccessor &ports) const override {
1009 auto unwrappedIO = unwrapIO(s, bb, ports);
1010
1011 // A helper wire that will be used to connect the two built logics
1012 HandshakeWire wire(bb, s.b.getNoneType());
1013
1014 OutputHandshake output = wire.getAsOutput();
1015 buildJoinLogic(s, unwrappedIO.inputs, output);
1016
1017 InputHandshake input = wire.getAsInput();
1018
1019 // The state-keeping fork logic is required here, as the circuit isn't
1020 // allowed to wait for all the consumers to be ready. Connecting the ready
1021 // signals of the outputs to their corresponding valid signals leads to
1022 // combinatorial cycles. The paper which introduced compositional dataflow
1023 // circuits explicitly mentions this limitation:
1024 // http://arcade.cs.columbia.edu/df-memocode17.pdf
1025 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1026
1027 // Directly connect the data wires, only the control signals need to be
1028 // combined.
1029 for (auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1030 out.data->setValue(in.data);
1031 };
1032};
1033
1034class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
1035public:
1036 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1037 void buildModule(MuxOp op, BackedgeBuilder &bb, RTLBuilder &s,
1038 hw::HWModulePortAccessor &ports) const override {
1039 auto unwrappedIO = unwrapIO(s, bb, ports);
1040
1041 // Extract select signal from the unwrapped IO.
1042 auto select = unwrappedIO.inputs[0];
1043 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1044 buildMuxLogic(s, unwrappedIO, select);
1045 };
1046};
1047
1048class InstanceConversionPattern
1049 : public HandshakeConversionPattern<handshake::InstanceOp> {
1050public:
1051 using HandshakeConversionPattern<
1052 handshake::InstanceOp>::HandshakeConversionPattern;
1053 void buildModule(handshake::InstanceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1054 hw::HWModulePortAccessor &ports) const override {
1055 assert(false &&
1056 "If we indeed perform conversion in post-order, this "
1057 "should never be called. The base HandshakeConversionPattern logic "
1058 "will instantiate the external module.");
1059 }
1060};
1061
1062class ESIInstanceConversionPattern
1063 : public OpConversionPattern<handshake::ESIInstanceOp> {
1064public:
1065 ESIInstanceConversionPattern(MLIRContext *context,
1066 const HWSymbolCache &symCache)
1067 : OpConversionPattern(context), symCache(symCache) {}
1068
1069 LogicalResult
1070 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter) const override {
1072 // The operand signature of this op is very similar to the lowered
1073 // `handshake.func`s (especially since handshake uses ESI channels
1074 // internally). Whereas ESIInstance ops have 'clk' and 'rst' at the
1075 // beginning, lowered `handshake.func`s have them at the end. So we've just
1076 // got to re-arrange them.
1077 SmallVector<Value> operands;
1078 for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1079 i < e; ++i)
1080 operands.push_back(adaptor.getOperands()[i]);
1081 operands.push_back(adaptor.getClk());
1082 operands.push_back(adaptor.getRst());
1083 // Locate the lowered module so the instance builder can get all the
1084 // metadata.
1085 Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
1086 // And replace the op with an instance of the target module.
1087 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1088 op.getInstNameAttr(), operands);
1089 return success();
1090 }
1091
1092private:
1093 const HWSymbolCache &symCache;
1094};
1095
1096class ReturnConversionPattern
1097 : public OpConversionPattern<handshake::ReturnOp> {
1098public:
1099 using OpConversionPattern::OpConversionPattern;
1100 LogicalResult
1101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1102 ConversionPatternRewriter &rewriter) const override {
1103 // Locate existing output op, Append operands to output op, and move to
1104 // the end of the block.
1105 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1106 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1107 outputOp->setOperands(adaptor.getOperands());
1108 outputOp->moveAfter(&parent.getBodyBlock()->back());
1109 rewriter.eraseOp(op);
1110 return success();
1111 }
1112};
1113
1114// Converts an arbitrary operation into a unit rate actor. A unit rate actor
1115// will transact once all inputs are valid and its output is ready.
1116template <typename TIn, typename TOut = TIn>
1117class UnitRateConversionPattern : public HandshakeConversionPattern<TIn> {
1118public:
1119 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1120 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1121 hw::HWModulePortAccessor &ports) const override {
1122 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1123 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1124 // Create TOut - it is assumed that TOut trivially
1125 // constructs from the input data signals of TIn.
1126 // To disambiguate ambiguous builders with default arguments (e.g.,
1127 // twoState UnitAttr), specify attribute array explicitly.
1128 return TOut::create(s.b, op.getLoc(), inputs,
1129 /* attributes */ ArrayRef<NamedAttribute>{});
1130 });
1131 };
1132};
1133
1134class PackConversionPattern : public HandshakeConversionPattern<PackOp> {
1135public:
1136 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1137 void buildModule(PackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1138 hw::HWModulePortAccessor &ports) const override {
1139 auto unwrappedIO = unwrapIO(s, bb, ports);
1140 buildUnitRateJoinLogic(s, unwrappedIO,
1141 [&](ValueRange inputs) { return s.pack(inputs); });
1142 };
1143};
1144
1145class StructCreateConversionPattern
1146 : public HandshakeConversionPattern<hw::StructCreateOp> {
1147public:
1148 using HandshakeConversionPattern<
1149 hw::StructCreateOp>::HandshakeConversionPattern;
1150 void buildModule(hw::StructCreateOp op, BackedgeBuilder &bb, RTLBuilder &s,
1151 hw::HWModulePortAccessor &ports) const override {
1152 auto unwrappedIO = unwrapIO(s, bb, ports);
1153 auto structType = op.getResult().getType();
1154 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1155 return s.pack(inputs, structType);
1156 });
1157 };
1158};
1159
1160class UnpackConversionPattern : public HandshakeConversionPattern<UnpackOp> {
1161public:
1162 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1163 void buildModule(UnpackOp op, BackedgeBuilder &bb, RTLBuilder &s,
1164 hw::HWModulePortAccessor &ports) const override {
1165 auto unwrappedIO = unwrapIO(s, bb, ports);
1166 buildUnitRateForkLogic(s, bb, unwrappedIO,
1167 [&](Value input) { return s.unpack(input); });
1168 };
1169};
1170
1171class ConditionalBranchConversionPattern
1172 : public HandshakeConversionPattern<ConditionalBranchOp> {
1173public:
1174 using HandshakeConversionPattern<
1175 ConditionalBranchOp>::HandshakeConversionPattern;
1176 void buildModule(ConditionalBranchOp op, BackedgeBuilder &bb, RTLBuilder &s,
1177 hw::HWModulePortAccessor &ports) const override {
1178 auto unwrappedIO = unwrapIO(s, bb, ports);
1179 auto cond = unwrappedIO.inputs[0];
1180 auto arg = unwrappedIO.inputs[1];
1181 auto trueRes = unwrappedIO.outputs[0];
1182 auto falseRes = unwrappedIO.outputs[1];
1183
1184 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1185
1186 // Connect valid signal of both results.
1187 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1188 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1189
1190 // Connecte data signals of both results.
1191 trueRes.data->setValue(arg.data);
1192 falseRes.data->setValue(arg.data);
1193
1194 // Connect ready signal of input and condition.
1195 auto selectedResultReady =
1196 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1197 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1198 arg.ready->setValue(condArgReady);
1199 cond.ready->setValue(condArgReady);
1200 };
1201};
1202
1203template <typename TIn, bool signExtend>
1204class ExtendConversionPattern : public HandshakeConversionPattern<TIn> {
1205public:
1206 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1207 void buildModule(TIn op, BackedgeBuilder &bb, RTLBuilder &s,
1208 hw::HWModulePortAccessor &ports) const override {
1209 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1210 this->buildExtendLogic(s, unwrappedIO, /*signExtend=*/signExtend);
1211 };
1212};
1213
1214class ComparisonConversionPattern
1215 : public HandshakeConversionPattern<arith::CmpIOp> {
1216public:
1217 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1218 void buildModule(arith::CmpIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1219 hw::HWModulePortAccessor &ports) const override {
1220 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1221 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1222 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1223 return comb::ICmpOp::create(s.b, op.getLoc(), predicate, inputs[0],
1224 inputs[1]);
1225 });
1226 };
1227
1228 switch (op.getPredicate()) {
1229 case arith::CmpIPredicate::eq:
1230 return buildCompareLogic(comb::ICmpPredicate::eq);
1231 case arith::CmpIPredicate::ne:
1232 return buildCompareLogic(comb::ICmpPredicate::ne);
1233 case arith::CmpIPredicate::slt:
1234 return buildCompareLogic(comb::ICmpPredicate::slt);
1235 case arith::CmpIPredicate::ult:
1236 return buildCompareLogic(comb::ICmpPredicate::ult);
1237 case arith::CmpIPredicate::sle:
1238 return buildCompareLogic(comb::ICmpPredicate::sle);
1239 case arith::CmpIPredicate::ule:
1240 return buildCompareLogic(comb::ICmpPredicate::ule);
1241 case arith::CmpIPredicate::sgt:
1242 return buildCompareLogic(comb::ICmpPredicate::sgt);
1243 case arith::CmpIPredicate::ugt:
1244 return buildCompareLogic(comb::ICmpPredicate::ugt);
1245 case arith::CmpIPredicate::sge:
1246 return buildCompareLogic(comb::ICmpPredicate::sge);
1247 case arith::CmpIPredicate::uge:
1248 return buildCompareLogic(comb::ICmpPredicate::uge);
1249 }
1250 assert(false && "invalid CmpIOp");
1251 };
1252};
1253
1254class TruncateConversionPattern
1255 : public HandshakeConversionPattern<arith::TruncIOp> {
1256public:
1257 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1258 void buildModule(arith::TruncIOp op, BackedgeBuilder &bb, RTLBuilder &s,
1259 hw::HWModulePortAccessor &ports) const override {
1260 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1261 unsigned targetBits =
1262 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1263 buildTruncateLogic(s, unwrappedIO, targetBits);
1264 };
1265};
1266
1267class ControlMergeConversionPattern
1268 : public HandshakeConversionPattern<ControlMergeOp> {
1269public:
1270 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1271 void buildModule(ControlMergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1272 hw::HWModulePortAccessor &ports) const override {
1273 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1274 auto resData = unwrappedIO.outputs[0];
1275 auto resIndex = unwrappedIO.outputs[1];
1276
1277 // Define some common types and values that will be used.
1278 unsigned numInputs = unwrappedIO.inputs.size();
1279 auto indexType = s.b.getIntegerType(numInputs);
1280 Value noWinner = s.constant(numInputs, 0);
1281 Value c0I1 = s.constant(1, 0);
1282
1283 // Declare register for storing arbitration winner.
1284 auto won = bb.get(indexType);
1285 Value wonReg = s.reg("won_reg", won, noWinner);
1286
1287 // Declare wire for arbitration winner.
1288 auto win = bb.get(indexType);
1289
1290 // Declare wire for whether the circuit just fired and emitted both
1291 // outputs.
1292 auto fired = bb.get(s.b.getI1Type());
1293
1294 // Declare registers for storing if each output has been emitted.
1295 auto resultEmitted = bb.get(s.b.getI1Type());
1296 Value resultEmittedReg = s.reg("result_emitted_reg", resultEmitted, c0I1);
1297 auto indexEmitted = bb.get(s.b.getI1Type());
1298 Value indexEmittedReg = s.reg("index_emitted_reg", indexEmitted, c0I1);
1299
1300 // Declare wires for if each output is done.
1301 auto resultDone = bb.get(s.b.getI1Type());
1302 auto indexDone = bb.get(s.b.getI1Type());
1303
1304 // Create predicates to assert if the win wire or won register hold a
1305 // valid index.
1306 auto hasWinnerCondition = s.rOr({win});
1307 auto hadWinnerCondition = s.rOr({wonReg});
1308
1309 // Create an arbiter based on a simple priority-encoding scheme to assign
1310 // an index to the win wire. If the won register is set, just use that. In
1311 // the case that won is not set and no input is valid, set a sentinel
1312 // value to indicate no winner was chosen. The constant values are
1313 // remembered in a map so they can be re-used later to assign the arg
1314 // ready outputs.
1315 DenseMap<size_t, Value> argIndexValues;
1316 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1317 noWinner, argIndexValues);
1318 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1319 win.setValue(priorityArb);
1320
1321 // Create the logic to assign the result and index outputs. The result
1322 // valid output will always be assigned, and if isControl is not set, the
1323 // result data output will also be assigned. The index valid and data
1324 // outputs will always be assigned. The win wire from the arbiter is used
1325 // to index into a tree of muxes to select the chosen input's signal(s),
1326 // and is fed directly to the index output. Both the result and index
1327 // valid outputs are gated on the win wire being set to something other
1328 // than the sentinel value.
1329 auto resultNotEmitted = s.bNot(resultEmittedReg);
1330 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1331 resData.valid->setValue(resultValid);
1332 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1333
1334 auto indexNotEmitted = s.bNot(indexEmittedReg);
1335 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1336 resIndex.valid->setValue(indexValid);
1337
1338 // Use the one-hot win wire to select the index to output in the index
1339 // data.
1340 SmallVector<Value, 8> indexOutputs;
1341 for (size_t i = 0; i < numInputs; ++i)
1342 indexOutputs.push_back(s.constant(64, i));
1343
1344 auto indexOutput = s.ohMux(win, indexOutputs);
1345 resIndex.data->setValue(indexOutput);
1346
1347 // Create the logic to set the won register. If the fired wire is
1348 // asserted, we have finished this round and can and reset the register to
1349 // the sentinel value that indicates there is no winner. Otherwise, we
1350 // need to hold the value of the win register until we can fire.
1351 won.setValue(s.mux(fired, {win, noWinner}));
1352
1353 // Create the logic to set the done wires for the result and index. For
1354 // both outputs, the done wire is asserted when the output is valid and
1355 // ready, or the emitted register for that output is set.
1356 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1357 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1358
1359 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1360 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1361
1362 // Create the logic to set the fired wire. It is asserted when both result
1363 // and index are done.
1364 fired.setValue(s.bAnd({resultDone, indexDone}));
1365
1366 // Create the logic to assign the emitted registers. If the fired wire is
1367 // asserted, we have finished this round and can reset the registers to 0.
1368 // Otherwise, we need to hold the values of the done registers until we
1369 // can fire.
1370 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1371 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1372
1373 // Create the logic to assign the arg ready outputs. The logic is
1374 // identical for each arg. If the fired wire is asserted, and the win wire
1375 // holds an arg's index, that arg is ready.
1376 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1377 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1378 auto &indexValue = argIndexValues[i];
1379 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1380 }
1381 };
1382};
1383
1384class MergeConversionPattern : public HandshakeConversionPattern<MergeOp> {
1385public:
1386 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1387 void buildModule(MergeOp op, BackedgeBuilder &bb, RTLBuilder &s,
1388 hw::HWModulePortAccessor &ports) const override {
1389 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1390 auto resData = unwrappedIO.outputs[0];
1391
1392 // Define some common types and values that will be used.
1393 unsigned numInputs = unwrappedIO.inputs.size();
1394 auto indexType = s.b.getIntegerType(numInputs);
1395 Value noWinner = s.constant(numInputs, 0);
1396
1397 // Declare wire for arbitration winner.
1398 auto win = bb.get(indexType);
1399
1400 // Create predicates to assert if the win wire holds a valid index.
1401 auto hasWinnerCondition = s.rOr(win);
1402
1403 // Create an arbiter based on a simple priority-encoding scheme to assign an
1404 // index to the win wire. In the case that no input is valid, set a sentinel
1405 // value to indicate no winner was chosen. The constant values are
1406 // remembered in a map so they can be re-used later to assign the arg ready
1407 // outputs.
1408 DenseMap<size_t, Value> argIndexValues;
1409 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1410 noWinner, argIndexValues);
1411 win.setValue(priorityArb);
1412
1413 // Create the logic to assign the result outputs. The result valid and data
1414 // outputs will always be assigned. The win wire from the arbiter is used to
1415 // index into a tree of muxes to select the chosen input's signal(s). The
1416 // result outputs are gated on the win wire being non-zero.
1417
1418 resData.valid->setValue(hasWinnerCondition);
1419 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1420
1421 // Create the logic to set the done wires for the result. The done wire is
1422 // asserted when the output is valid and ready, or the emitted register is
1423 // set.
1424 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1425
1426 // Create the logic to assign the arg ready outputs. The logic is
1427 // identical for each arg. If the fired wire is asserted, and the win wire
1428 // holds an arg's index, that arg is ready.
1429 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1430 for (auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1431 auto &indexValue = argIndexValues[i];
1432 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1433 }
1434 };
1435};
1436
1437class LoadConversionPattern
1438 : public HandshakeConversionPattern<handshake::LoadOp> {
1439public:
1440 using HandshakeConversionPattern<
1441 handshake::LoadOp>::HandshakeConversionPattern;
1442 void buildModule(handshake::LoadOp op, BackedgeBuilder &bb, RTLBuilder &s,
1443 hw::HWModulePortAccessor &ports) const override {
1444 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1445 auto addrFromUser = unwrappedIO.inputs[0];
1446 auto dataFromMem = unwrappedIO.inputs[1];
1447 auto controlIn = unwrappedIO.inputs[2];
1448 auto dataToUser = unwrappedIO.outputs[0];
1449 auto addrToMem = unwrappedIO.outputs[1];
1450
1451 addrToMem.data->setValue(addrFromUser.data);
1452 dataToUser.data->setValue(dataFromMem.data);
1453
1454 // The valid/ready logic between user address/control to memoryAddr is
1455 // join logic.
1456 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1457
1458 // The valid/ready logic between memoryData and outputData is a direct
1459 // connection.
1460 dataToUser.valid->setValue(dataFromMem.valid);
1461 dataFromMem.ready->setValue(dataToUser.ready);
1462 };
1463};
1464
1465class StoreConversionPattern
1466 : public HandshakeConversionPattern<handshake::StoreOp> {
1467public:
1468 using HandshakeConversionPattern<
1469 handshake::StoreOp>::HandshakeConversionPattern;
1470 void buildModule(handshake::StoreOp op, BackedgeBuilder &bb, RTLBuilder &s,
1471 hw::HWModulePortAccessor &ports) const override {
1472 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1473 auto addrFromUser = unwrappedIO.inputs[0];
1474 auto dataFromUser = unwrappedIO.inputs[1];
1475 auto controlIn = unwrappedIO.inputs[2];
1476 auto dataToMem = unwrappedIO.outputs[0];
1477 auto addrToMem = unwrappedIO.outputs[1];
1478
1479 // Create a gate that will be asserted when all outputs are ready.
1480 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1481
1482 // Build the standard join logic from the inputs to the inputsValid and
1483 // outputsReady signals.
1484 HandshakeWire joinWire(bb, s.b.getNoneType());
1485 joinWire.ready->setValue(outputsReady);
1486 OutputHandshake joinOutput = joinWire.getAsOutput();
1487 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1488
1489 // Output address and data signals are connected directly.
1490 addrToMem.data->setValue(addrFromUser.data);
1491 dataToMem.data->setValue(dataFromUser.data);
1492
1493 // Output valid signals are connected from the inputsValid wire.
1494 addrToMem.valid->setValue(*joinWire.valid);
1495 dataToMem.valid->setValue(*joinWire.valid);
1496 };
1497};
1498
1499class MemoryConversionPattern
1500 : public HandshakeConversionPattern<handshake::MemoryOp> {
1501public:
1502 using HandshakeConversionPattern<
1503 handshake::MemoryOp>::HandshakeConversionPattern;
1504 void buildModule(handshake::MemoryOp op, BackedgeBuilder &bb, RTLBuilder &s,
1505 hw::HWModulePortAccessor &ports) const override {
1506 auto loc = op.getLoc();
1507
1508 // Gather up the load and store ports.
1509 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1510 struct LoadPort {
1511 InputHandshake &addr;
1512 OutputHandshake &data;
1513 OutputHandshake &done;
1514 };
1515 struct StorePort {
1516 InputHandshake &addr;
1517 InputHandshake &data;
1518 OutputHandshake &done;
1519 };
1520 SmallVector<LoadPort, 4> loadPorts;
1521 SmallVector<StorePort, 4> storePorts;
1522
1523 unsigned stCount = op.getStCount();
1524 unsigned ldCount = op.getLdCount();
1525 for (unsigned i = 0, e = ldCount; i != e; ++i) {
1526 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1527 unwrappedIO.outputs[i],
1528 unwrappedIO.outputs[ldCount + stCount + i]};
1529 loadPorts.push_back(port);
1530 }
1531
1532 for (unsigned i = 0, e = stCount; i != e; ++i) {
1533 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1534 unwrappedIO.inputs[i * 2],
1535 unwrappedIO.outputs[ldCount + i]};
1536 storePorts.push_back(port);
1537 }
1538
1539 // used to drive the data wire of the control-only channels.
1540 auto c0I0 = s.constant(0, 0);
1541
1542 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1543 auto hlmem = seq::HLMemOp::create(
1544 s.b, loc, s.clk, s.rst,
1545 "_handshake_memory_" + std::to_string(op.getId()),
1546 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1547
1548 // Create load ports...
1549 for (auto &ld : loadPorts) {
1550 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1551 auto readData = seq::ReadPortOp::create(s.b, loc, hlmem.getHandle(),
1552 addresses, ld.addr.valid,
1553 /*latency=*/0);
1554 ld.data.data->setValue(readData);
1555 ld.done.data->setValue(c0I0);
1556 // Create control fork for the load address valid and ready signals.
1557 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1558 }
1559
1560 // Create store ports...
1561 for (auto &st : storePorts) {
1562 // Create a register to buffer the valid path by 1 cycle, to match the
1563 // write latency of 1.
1564 auto writeValidBufferMuxBE = bb.get(s.b.getI1Type());
1565 auto writeValidBuffer =
1566 s.reg("writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1567 st.done.valid->setValue(writeValidBuffer);
1568 st.done.data->setValue(c0I0);
1569
1570 // Create the logic for when both the buffered write valid signal and the
1571 // store complete ready signal are asserted.
1572 auto storeCompleted =
1573 s.bAnd({st.done.ready, writeValidBuffer}, "storeCompleted");
1574
1575 // Create a signal for when the write valid buffer is empty or the output
1576 // is ready.
1577 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1578 auto emptyOrComplete =
1579 s.bOr({notWriteValidBuffer, storeCompleted}, "emptyOrComplete");
1580
1581 // Connect the gate to both the store address ready and store data ready
1582 st.addr.ready->setValue(emptyOrComplete);
1583 st.data.ready->setValue(emptyOrComplete);
1584
1585 // Create a wire for when both the store address and data are valid.
1586 auto writeValid = s.bAnd({st.addr.valid, st.data.valid}, "writeValid");
1587
1588 // Create a mux that drives the buffer input. If the emptyOrComplete
1589 // signal is asserted, the mux selects the writeValid signal. Otherwise,
1590 // it selects the buffer output, keeping the output registered until the
1591 // emptyOrComplete signal is asserted.
1592 writeValidBufferMuxBE.setValue(
1593 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1594
1595 // Instantiate the write port operation - truncate address width to memory
1596 // width.
1597 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1598 seq::WritePortOp::create(s.b, loc, hlmem.getHandle(), addresses,
1599 st.data.data, writeValid,
1600 /*latency=*/1);
1601 }
1602 }
1603}; // namespace
1604
1605class SinkConversionPattern : public HandshakeConversionPattern<SinkOp> {
1606public:
1607 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1608 void buildModule(SinkOp op, BackedgeBuilder &bb, RTLBuilder &s,
1609 hw::HWModulePortAccessor &ports) const override {
1610 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1611 // A sink is always ready to accept a new value.
1612 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1613 };
1614};
1615
1616class SourceConversionPattern : public HandshakeConversionPattern<SourceOp> {
1617public:
1618 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1619 void buildModule(SourceOp op, BackedgeBuilder &bb, RTLBuilder &s,
1620 hw::HWModulePortAccessor &ports) const override {
1621 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1622 // A source always provides a new (i0-typed) value.
1623 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1624 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1625 };
1626};
1627
1628class ConstantConversionPattern
1629 : public HandshakeConversionPattern<handshake::ConstantOp> {
1630public:
1631 using HandshakeConversionPattern<
1632 handshake::ConstantOp>::HandshakeConversionPattern;
1633 void buildModule(handshake::ConstantOp op, BackedgeBuilder &bb, RTLBuilder &s,
1634 hw::HWModulePortAccessor &ports) const override {
1635 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1636 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1637 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1638 auto constantValue = op->getAttrOfType<IntegerAttr>("value").getValue();
1639 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1640 };
1641};
1642
/// Lowers handshake.buffer into a chain of sequential buffer stages, one per
/// slot reported by `op.getNumSlots()`. Slots for which the op provides an
/// init value are built with that value; the remaining slots start out with a
/// zero data constant.
1643class BufferConversionPattern : public HandshakeConversionPattern<BufferOp> {
1644public:
1645 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
  /// Emits the module body: unwraps the first input and output handshakes,
  /// builds the stage chain between them and connects the final stage to the
  /// output.
1646 void buildModule(BufferOp op, BackedgeBuilder &bb, RTLBuilder &s,
1647 hw::HWModulePortAccessor &ports) const override {
1648 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1649 auto input = unwrappedIO.inputs[0];
1650 auto output = unwrappedIO.outputs[0];
1651 InputHandshake lastStage;
1652 SmallVector<int64_t> initValues;
1653
1654 // For now, always build seq buffers.
    // Init values are optional on the op; without them the list stays empty
    // and every stage is built without an initial token.
1655 if (op.getInitValues())
1656 initValues = op.getInitValueArray();
1657
1658 lastStage =
1659 buildSeqBufferLogic(s, bb, toValidType(op.getDataType()),
1660 op.getNumSlots(), input, output, initValues);
1661
1662 // Connect the last stage to the output handshake.
    // data/valid flow towards the output; ready flows back into the backedge
    // that the last stage created for its successor.
1663 output.data->setValue(lastStage.data);
1664 output.valid->setValue(lastStage.valid);
1665 lastStage.ready->setValue(output.ready);
1666 };
1667
  /// A single buffer slot. The constructor emits all logic of the stage;
  /// `getOutput()` then returns the handshake that the next stage (or the
  /// module output) has to be connected to.
1668 struct SeqBufferStage {
    /// \param dataType  Type of the buffered data.
    /// \param preStage  Handshake of the preceding stage (or the module
    ///                  input); its `ready` backedge is driven by this stage.
    /// \param bb        Builder for the backedges used to close feedback loops.
    /// \param s         Builder used to emit registers, muxes and gates.
    /// \param index     Position of the stage; used to derive unique names.
    /// \param initValue Optional initial content of the slot.
1669 SeqBufferStage(Type dataType, InputHandshake &preStage, BackedgeBuilder &bb,
1670 RTLBuilder &s, size_t index,
1671 std::optional<int64_t> initValue)
1672 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1673
1674 // Todo: Change when i0 support is added.
1675 c0s = createZeroDataConst(s, s.loc, dataType);
      // The ready signal of this stage's output is not known yet; whoever
      // consumes `getOutput()` drives this backedge.
1676 currentStage.ready = std::make_shared<Backedge>(bb.get(s.b.getI1Type()));
1677
      // 1-bit constant that is one iff an init value was supplied. It is the
      // third argument of the valid register (presumably its reset value --
      // confirm against RTLBuilder::reg).
1678 auto hasInitValue = s.constant(1, initValue.has_value());
1679 auto validBE = bb.get(s.b.getI1Type());
1680 auto validReg = s.reg(getRegName("valid"), validBE, hasInitValue);
      // Driven in buildControlBufferLogic, read in buildDataBufferLogic.
1681 auto readyBE = bb.get(s.b.getI1Type());
1682
      // Use the supplied init value for the data register, falling back to
      // the zero constant when there is none.
1683 Value initValueCs = c0s;
1684 if (initValue.has_value())
1685 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1686
1687 // This could/should be revised but needs a larger rethinking to avoid
1688 // introducing new bugs.
1689 Value dataReg =
1690 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1691 buildControlBufferLogic(validReg, readyBE, dataReg);
1692 }
1693
    /// Returns the register name "<name><index>_reg" for this stage.
1694 StringAttr getRegName(StringRef name) {
1695 return s.b.getStringAttr(name + std::to_string(index) + "_reg");
1696 }
1697
    /// Emits the ready/ctrl_data registers and the muxes around them, and
    /// sets `currentStage.valid` and `currentStage.data`. Also drives
    /// `readyBE` with the inverted ready register.
    ///
    /// Muxes below follow the convention spelled out in buildDataBufferLogic:
    /// `s.mux(sel, {a, b})` yields `b` when `sel` is asserted and `a`
    /// otherwise.
1698 void buildControlBufferLogic(Value validReg, Backedge &readyBE,
1699 Value dataReg) {
1700 auto c0I1 = s.constant(1, 0);
1701 auto readyRegWire = bb.get(s.b.getI1Type());
1702 auto readyReg = s.reg(getRegName("ready"), readyRegWire, c0I1);
1703
1704 // Create the logic to drive the current stage valid and potentially
1705 // data.
1706 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1707 "controlValid" + std::to_string(index));
1708
1709 // Create the logic to drive the current stage ready.
1710 auto notReadyReg = s.bNot(readyReg);
1711 readyBE.setValue(notReadyReg);
1712
      // Next value of the ready register: cleared when both the successor and
      // the register are ready, loaded from validReg when neither is, and
      // held otherwise.
1713 auto succNotReady = s.bNot(*currentStage.ready);
1714 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1715 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1716 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1717
1718 // Create a mux for emptying the register when both are ready.
1719 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1720 readyRegWire.setValue(resetSignal);
1721
1722 // Add same logic for the data path if necessary.
1723 auto ctrlDataRegBE = bb.get(dataType);
1724 auto ctrlDataReg = s.reg(getRegName("ctrl_data"), ctrlDataRegBE, c0s);
1725 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1726 currentStage.data = dataResult;
1727
      // Same select structure as for the ready register above, applied to the
      // data path with the zero data constant as the cleared value.
1728 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1729 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1730 ctrlDataRegBE.setValue(dataResetSignal);
1731 }
1732
    /// Emits the data register and the input of the valid register, and
    /// drives the ready signal of the preceding stage.
    ///
    /// \param validReg  Output of this stage's valid register.
    /// \param initValue Third argument of the data register (the init value
    ///                  or the zero constant, see the constructor).
    /// \param validBE   Backedge feeding the valid register; set here.
    /// \param readyBE   Backedge set later by buildControlBufferLogic.
    /// \returns the output of the data register.
1733 Value buildDataBufferLogic(Value validReg, Value initValue,
1734 Backedge &validBE, Backedge &readyBE) {
1735 // Create a signal for when the valid register is empty or the successor
1736 // is ready to accept new token.
1737 auto notValidReg = s.bNot(validReg);
1738 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1739 preStage.ready->setValue(emptyOrReady);
1740
1741 // Create a mux that drives the register input. If the emptyOrReady
1742 // signal is asserted, the mux selects the predValid signal. Otherwise,
1743 // it selects the register output, keeping the output registered
1744 // unchanged.
1745 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1746
1747 // Now we can drive the valid register.
1748 validBE.setValue(validRegMux);
1749
1750 // Create a mux that drives the data register.
      // The register feeds back into its own input mux through dataRegBE so
      // that it holds its value while emptyOrReady is deasserted.
1751 auto dataRegBE = bb.get(dataType);
1752 auto dataReg =
1753 s.reg(getRegName("data"),
1754 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1755 dataRegBE.setValue(dataReg);
1756 return dataReg;
1757 }
1758
    /// Returns (a copy of) the handshake produced by this stage.
1759 InputHandshake getOutput() { return currentStage; }
1760
    // Type of the buffered data.
1761 Type dataType;
    // Handshake of the preceding stage; held by reference.
1762 InputHandshake &preStage;
    // Handshake produced by this stage; filled in by the build* functions.
1763 InputHandshake currentStage;
1764 RTLBuilder &s;
1765 BackedgeBuilder &bb;
    // Position of this stage in the chain; part of all generated names.
1766 size_t index;
1767
1768 // A zero-valued constant of equal type as the data type of this buffer.
1769 Value c0s;
1770 };
1771
  /// Chains `size` buffer stages starting at `input` and returns the
  /// handshake of the last one (or `input` itself when `size` is zero).
  /// Stage `i` receives `initValues[i]` when that element exists and no init
  /// value otherwise. `output` is not referenced in this function.
1772 InputHandshake buildSeqBufferLogic(RTLBuilder &s, BackedgeBuilder &bb,
1773 Type dataType, unsigned size,
1774 InputHandshake &input,
1775 OutputHandshake &output,
1776 llvm::ArrayRef<int64_t> initValues) const {
1777 // Prime the buffer building logic with an initial stage, which just
1778 // wraps the input handshake.
1779 InputHandshake currentStage = input;
1780
1781 for (unsigned i = 0; i < size; ++i) {
1782 bool isInitialized = i < initValues.size();
1783 auto initValue =
1784 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
      // The temporary stage refers to `currentStage` as its predecessor while
      // it is being constructed; `currentStage` is then replaced by the
      // stage's own output for the next iteration.
1785 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1786 .getOutput();
1787 }
1788
1789 return currentStage;
1790 };
1791};
1792
1793class IndexCastConversionPattern
1794 : public HandshakeConversionPattern<arith::IndexCastOp> {
1795public:
1796 using HandshakeConversionPattern<
1797 arith::IndexCastOp>::HandshakeConversionPattern;
1798 void buildModule(arith::IndexCastOp op, BackedgeBuilder &bb, RTLBuilder &s,
1799 hw::HWModulePortAccessor &ports) const override {
1800 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1801 unsigned sourceBits =
1802 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1803 unsigned targetBits =
1804 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1805 if (targetBits < sourceBits)
1806 buildTruncateLogic(s, unwrappedIO, targetBits);
1807 else
1808 buildExtendLogic(s, unwrappedIO, /*signExtend=*/true);
1809 };
1810};
1811
1812template <typename T>
1813class ExtModuleConversionPattern : public OpConversionPattern<T> {
1814public:
1815 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1816 MLIRContext *context, OpBuilder &submoduleBuilder,
1817 HandshakeLoweringState &ls)
1818 : OpConversionPattern<T>::OpConversionPattern(typeConverter, context),
1819 submoduleBuilder(submoduleBuilder), ls(ls) {}
1820 using OpAdaptor = typename T::Adaptor;
1821
1822 LogicalResult
1823 matchAndRewrite(T op, OpAdaptor adaptor,
1824 ConversionPatternRewriter &rewriter) const override {
1825
1826 hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
1827 if (!implModule) {
1828 auto portInfo = ModulePortInfo(getPortInfoForOp(op));
1829 implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
1830 op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
1831 portInfo);
1832 }
1833
1834 llvm::SmallVector<Value> operands = adaptor.getOperands();
1835 addSequentialIOOperandsIfNeeded(op, operands);
1836 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1837 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1838 return success();
1839 }
1840
1841private:
1842 OpBuilder &submoduleBuilder;
1843 HandshakeLoweringState &ls;
1844};
1845
1846class FuncOpConversionPattern : public OpConversionPattern<handshake::FuncOp> {
1847public:
1848 using OpConversionPattern::OpConversionPattern;
1849
1850 LogicalResult
1851 matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1852 ConversionPatternRewriter &rewriter) const override {
1853 ModulePortInfo ports =
1854 getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
1855
1856 HWModuleLike hwModule;
1857 if (op.isExternal()) {
1858 hwModule = hw::HWModuleExternOp::create(
1859 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1860 } else {
1861 auto hwModuleOp = hw::HWModuleOp::create(
1862 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1863 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1864 rewriter.inlineBlockBefore(&op.getBody().front(),
1865 hwModuleOp.getBodyBlock()->getTerminator(),
1866 args);
1867 hwModule = hwModuleOp;
1868 }
1869
1870 // Was any predeclaration associated with this func? If so, replace uses
1871 // with the newly created module and erase the predeclaration.
1872 if (auto predecl =
1873 op->getAttrOfType<FlatSymbolRefAttr>(kPredeclarationAttr)) {
1874 auto *parentOp = op->getParentOp();
1875 auto *predeclModule =
1876 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1877 if (predeclModule) {
1878 if (failed(SymbolTable::replaceAllSymbolUses(
1879 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1880 return failure();
1881 rewriter.eraseOp(predeclModule);
1882 }
1883 }
1884
1885 rewriter.eraseOp(op);
1886 return success();
1887 }
1888};
1889
1890} // namespace
1891
1892//===----------------------------------------------------------------------===//
1893// HW Top-module Related Functions
1894//===----------------------------------------------------------------------===//
1895
1896static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
1897 ConversionTarget &target,
1899 OpBuilder &moduleBuilder) {
1900
1901 std::map<std::string, unsigned> instanceNameCntr;
1902 NameUniquer instanceUniquer = [&](Operation *op) {
1903 std::string instName = getCallName(op);
1904 if (auto idAttr = op->getAttrOfType<IntegerAttr>("handshake_id"); idAttr) {
1905 // We use a special naming convention for operations which have a
1906 // 'handshake_id' attribute.
1907 instName += "_id" + std::to_string(idAttr.getValue().getZExtValue());
1908 } else {
1909 // Fallback to just prefixing with an integer.
1910 instName += std::to_string(instanceNameCntr[instName]++);
1911 }
1912 return instName;
1913 };
1914
1915 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1916 instanceUniquer};
1917 RewritePatternSet patterns(op.getContext());
1918 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1919 op.getContext());
1920 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1921 SyncConversionPattern>(typeConverter, op.getContext(),
1922 moduleBuilder, ls);
1923
1924 patterns.insert<
1925 // Comb operations.
1926 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1927 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1928 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1929 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1930 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1931 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1932 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1933 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1934 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1935 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1936 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1937 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1938 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1939 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1940 // HW operations.
1941 StructCreateConversionPattern,
1942 // Handshake operations.
1943 ConditionalBranchConversionPattern, MuxConversionPattern,
1944 PackConversionPattern, UnpackConversionPattern,
1945 ComparisonConversionPattern, BufferConversionPattern,
1946 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1947 MergeConversionPattern, ControlMergeConversionPattern,
1948 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1949 InstanceConversionPattern,
1950 // Arith operations.
1951 ExtendConversionPattern<arith::ExtUIOp, /*signExtend=*/false>,
1952 ExtendConversionPattern<arith::ExtSIOp, /*signExtend=*/true>,
1953 TruncateConversionPattern, IndexCastConversionPattern>(
1954 typeConverter, op.getContext(), moduleBuilder, ls);
1955
1956 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1957 return op->emitOpError() << "error during conversion";
1958 return success();
1959}
1960
1961namespace {
1962class HandshakeToHWPass
1963 : public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1964public:
1965 void runOnOperation() override {
1966 mlir::ModuleOp mod = getOperation();
1967
1968 // Lowering to HW requires that every value is used exactly once. Check
1969 // whether this precondition is met, and if not, exit.
1970 for (auto f : mod.getOps<handshake::FuncOp>()) {
1971 if (failed(verifyAllValuesHasOneUse(f))) {
1972 f.emitOpError() << "HandshakeToHW: failed to verify that all values "
1973 "are used exactly once. Remember to run the "
1974 "fork/sink materialization pass before HW lowering.";
1975 signalPassFailure();
1976 return;
1977 }
1978 }
1979
1980 // Resolve the instance graph to get a top-level module.
1981 std::string topLevel;
1983 SmallVector<std::string> sortedFuncs;
1984 if (resolveInstanceGraph(mod, uses, topLevel, sortedFuncs).failed()) {
1985 signalPassFailure();
1986 return;
1987 }
1988
1989 ESITypeConverter typeConverter;
1990 ConversionTarget target(getContext());
1991 // All top-level logic of a handshake module will be the interconnectivity
1992 // between instantiated modules.
1993 target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1994 hw::InstanceOp>();
1995 target
1996 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1997
1998 // Convert the handshake.func operations in post-order wrt. the instance
1999 // graph. This ensures that any referenced submodules (through
2000 // handshake.instance) has already been lowered, and their HW module
2001 // equivalents are available.
2002 OpBuilder submoduleBuilder(mod.getContext());
2003 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2004 for (auto &funcName : llvm::reverse(sortedFuncs)) {
2005 auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
2006 assert(funcOp && "handshake.func not found in module!");
2007 if (failed(
2008 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2009 signalPassFailure();
2010 return;
2011 }
2012 }
2013
2014 // Second stage: Convert any handshake.extmemory operations and the
2015 // top-level I/O associated with these.
2016 for (auto hwModule : mod.getOps<hw::HWModuleOp>())
2017 if (failed(convertExtMemoryOps(hwModule)))
2018 return signalPassFailure();
2019
2020 // Run conversions which need see everything.
2021 HWSymbolCache symbolCache;
2022 symbolCache.addDefinitions(mod);
2023 symbolCache.freeze();
2024 RewritePatternSet patterns(mod.getContext());
2025 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2026 symbolCache);
2027 if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
2028 mod->emitOpError() << "error during conversion";
2029 signalPassFailure();
2030 }
2031 }
2032};
2033} // end anonymous namespace
2034
2035std::unique_ptr<mlir::Pass> circt::createHandshakeToHWPass() {
2036 return std::make_unique<HandshakeToHWPass>();
2037}
AIGLongestPathObject wrap(llvm::PointerUnion< Object *, DataflowPath::OutputPort * > object)
Definition AIG.cpp:57
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition CalyxOps.cpp:121
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static Type tupleToStruct(TupleType tuple)
Definition DCToHW.cpp:48
std::function< std::string(Operation *)> NameUniquer
Definition DCToHW.cpp:45
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1428
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
Definition OM.cpp:111
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
const ModulePortInfo & getPortList() const
Definition HWOps.h:111
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition HWSymCache.h:75
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
create(low_bit, result_type, input=None)
Definition comb.py:187
Channels are the basic communication primitives.
Definition Types.h:102
const Type * getInner() const
Definition Types.h:105
create(elements)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
create(data_type, value)
Definition hw.py:433
create(elements, Type result_type=None)
Definition hw.py:532
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition seq.py:157
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
Definition hw.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.