26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
56struct HandshakeLoweringState {
57 ModuleOp parentModule;
64class ESITypeConverter :
public TypeConverter {
67 addConversion([](Type type) -> Type {
return esiWrapper(type); });
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
74 return UnrealizedConversionCastOp::create(builder, loc, resultType,
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
84 return UnrealizedConversionCastOp::create(builder, loc, resultType,
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
101 return subModuleName;
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
113 auto opType = op.getType();
114 if (
auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
123 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
124 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
126 {memOp.getMemRefType().getElementType()}};
131 SmallVector<Type> inTypes, outTypes;
132 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
134 llvm::transform(op->getResults(), std::back_inserter(outTypes),
145 std::string typeName;
147 if (type.isIntOrIndex()) {
148 if (
auto indexType = dyn_cast<IndexType>(type))
149 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
150 else if (type.isSignedInteger())
151 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
153 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
154 }
else if (
auto tupleType = dyn_cast<TupleType>(type)) {
155 typeName +=
"_tuple";
158 }
else if (
auto structType = dyn_cast<hw::StructType>(type)) {
159 typeName +=
"_struct";
160 for (
auto element : structType.getElements())
161 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
162 }
else if (isa<NoneType>(type)) {
165 emitError(loc) <<
"unsupported data type '" << type <<
"'";
172 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
173 return instanceOp.getModule().str();
178 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
179 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
180 auto intType = intAttr.getType();
182 if (intType.isSignedInteger())
183 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
184 else if (intType.isUnsignedInteger())
185 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
187 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
189 oldOp->emitError(
"unsupported constant type");
194 if (!inTypes.empty())
195 subModuleName +=
"_in";
196 for (
auto inType : inTypes)
197 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
199 if (!outTypes.empty())
200 subModuleName +=
"_out";
201 for (
auto outType : outTypes)
202 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
205 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
206 subModuleName +=
"_id" + std::to_string(memOp.getId());
209 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
210 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
213 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
214 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
215 if (bufferOp.isSequential())
216 subModuleName +=
"_seq";
218 subModuleName +=
"_fifo";
220 if (
auto initValues = bufferOp.getInitValues()) {
221 subModuleName +=
"_init";
222 for (
const Attribute e : *initValues) {
223 assert(isa<IntegerAttr>(e));
225 "_" + std::to_string(dyn_cast<IntegerAttr>(e).
getInt());
231 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
232 ctrlInterface && ctrlInterface.isControl()) {
234 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
235 std::to_string(oldOp->getNumResults()) +
"outs";
236 subModuleName +=
"_ctrl";
239 (!inTypes.empty() || !outTypes.empty()) &&
240 "Insufficient discriminating type info generated for the operation!");
243 return subModuleName;
255 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
257 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
264 HWModuleLike targetModule;
265 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
270 if (isa<handshake::InstanceOp>(oldOp))
272 "handshake.instance target modules should always have been lowered "
273 "before the modules that reference them!");
283static llvm::SmallVector<hw::detail::FieldInfo>
285 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
286 for (
auto port : portInfo)
287 fieldInfo.push_back({port.name, port.type});
295 auto *ctx = mod.getContext();
298 llvm::DenseMap<unsigned, Value> memrefPorts;
299 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
300 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
301 if (channel && isa<MemRefType>(channel.getInner()))
302 memrefPorts[i] = arg;
305 if (memrefPorts.empty())
310 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
311 ArrayRef<hw::PortInfo> info,
315 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
319 for (
auto [i, arg] : memrefPorts) {
321 auto memName = mod.getArgName(i);
324 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
326 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
327 extmemInstance, extmemInstance.getModuleNameAttr()));
338 SmallVector<PortInfo> outputs(portInfo.
getOutputs());
340 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i, outputs,
341 hw::ModulePort::Direction::Input);
342 mod.insertPorts({{i, inPortInfo}}, {});
343 auto newInPort = mod.getArgumentForInput(i);
345 b.setInsertionPointToStart(mod.getBodyBlock());
346 auto newInPortExploded = hw::StructExplodeOp::create(
347 b, arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
348 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
352 unsigned outArgI = mod.getNumOutputPorts();
353 SmallVector<PortInfo> inputs(portInfo.
getInputs());
355 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
356 inputs, hw::ModulePort::Direction::Output);
358 auto memOutputArgs = extmemInstance.getOperands().drop_front();
359 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
361 b, arg.getLoc(), outPortInfo.type, memOutputArgs);
362 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
367 extmemInstance.erase();
371 mod.modifyPorts( {}, {},
382struct InputHandshake {
384 std::shared_ptr<Backedge> ready;
390struct OutputHandshake {
391 std::shared_ptr<Backedge> valid;
393 std::shared_ptr<Backedge>
data;
398struct HandshakeWire {
400 MLIRContext *ctx = dataType.getContext();
401 auto i1Type = IntegerType::get(ctx, 1);
402 valid = std::make_shared<Backedge>(bb.
get(i1Type));
403 ready = std::make_shared<Backedge>(bb.
get(i1Type));
404 data = std::make_shared<Backedge>(bb.
get(dataType));
409 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
410 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
412 std::shared_ptr<Backedge> valid;
413 std::shared_ptr<Backedge> ready;
414 std::shared_ptr<Backedge>
data;
417template <
typename T,
typename TInner>
418llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
419 llvm::function_ref<T(TInner &)> extractor) {
420 llvm::SmallVector<T> result;
421 llvm::transform(container, std::back_inserter(result), extractor);
425 llvm::SmallVector<InputHandshake> inputs;
426 llvm::SmallVector<OutputHandshake> outputs;
428 llvm::SmallVector<Value> getInputValids() {
429 return extractValues<Value, InputHandshake>(
430 inputs, [](
auto &hs) {
return hs.valid; });
432 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
433 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
434 inputs, [](
auto &hs) {
return hs.ready; });
436 llvm::SmallVector<Value> getInputDatas() {
437 return extractValues<Value, InputHandshake>(
438 inputs, [](
auto &hs) {
return hs.data; });
440 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
441 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
442 outputs, [](
auto &hs) {
return hs.valid; });
444 llvm::SmallVector<Value> getOutputReadys() {
445 return extractValues<Value, OutputHandshake>(
446 outputs, [](
auto &hs) {
return hs.ready; });
448 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
449 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
450 outputs, [](
auto &hs) {
return hs.data; });
459 Value clk = Value(), Value rst = Value())
460 :
info(std::move(
info)), b(builder), loc(loc),
clk(
clk), rst(rst) {}
462 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
465 bool isZeroWidth = apv.getBitWidth() == 0;
467 auto it = constants.find(apv);
468 if (it != constants.end())
474 constants[apv] = cval;
478 Value constant(
unsigned width, int64_t value,
479 std::optional<StringRef> name = {}) {
481 APInt(width, value,
false,
true));
483 std::pair<Value, Value>
wrap(Value data, Value valid,
484 std::optional<StringRef> name = {}) {
485 auto wrapOp = esi::WrapValidReadyOp::create(b, loc, data, valid);
486 return {wrapOp.getResult(0), wrapOp.getResult(1)};
488 std::pair<Value, Value>
unwrap(Value channel, Value ready,
489 std::optional<StringRef> name = {}) {
490 auto unwrapOp = esi::UnwrapValidReadyOp::create(b, loc, channel, ready);
491 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
495 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
496 Value rst = Value()) {
497 Value resolvedClk =
clk ?
clk : this->
clk;
498 Value resolvedRst = rst ? rst : this->rst;
500 "No global clock provided to this RTLBuilder - a clock "
501 "signal must be provided to the reg(...) function.");
503 "No global reset provided to this RTLBuilder - a reset "
504 "signal must be provided to the reg(...) function.");
510 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
511 std::optional<StringRef> name = {}) {
512 return comb::ICmpOp::create(b, loc, predicate, lhs, rhs);
515 Value buildNamedOp(llvm::function_ref<Value()> f,
516 std::optional<StringRef> name) {
519 Operation *op = v.getDefiningOp();
520 if (name.has_value()) {
521 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
522 nameAttr = b.getStringAttr(*name);
528 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
530 [&]() {
return comb::AndOp::create(b, loc, values,
false); }, name);
533 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
535 [&]() {
return comb::OrOp::create(b, loc, values,
false); }, name);
539 Value bNot(Value value, std::optional<StringRef> name = {}) {
540 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
541 std::string inferedName;
545 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
546 inferedName = (
"not_" +
valueName.getValue()).str();
552 [&]() {
return comb::XorOp::create(b, loc, value, allOnes); }, name);
554 return b.createOrFold<
comb::XorOp>(loc, value, allOnes,
false);
557 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
559 [&]() {
return comb::ShlOp::create(b, loc, value, shift); }, name);
562 Value concat(ValueRange values, std::optional<StringRef> name = {}) {
564 [&]() {
return comb::ConcatOp::create(b, loc, values); }, name);
568 Value pack(ValueRange values, Type structType = Type(),
569 std::optional<StringRef> name = {}) {
580 ValueRange unpack(Value value) {
581 auto structType = cast<hw::StructType>(value.getType());
582 llvm::SmallVector<Type> innerTypes;
583 structType.getInnerTypes(innerTypes);
584 return hw::StructExplodeOp::create(b, loc, innerTypes, value).getResults();
587 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
588 llvm::SmallVector<Value> bits;
589 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
595 Value rOr(Value v, std::optional<StringRef> name = {}) {
596 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
600 Value extract(Value v,
unsigned lo,
unsigned hi,
601 std::optional<StringRef> name = {}) {
602 unsigned width = hi - lo + 1;
608 Value truncate(Value value,
unsigned width,
609 std::optional<StringRef> name = {}) {
610 return extract(value, 0, width - 1, name);
613 Value zext(Value value,
unsigned outWidth,
614 std::optional<StringRef> name = {}) {
615 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
616 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
617 if (inWidth == outWidth)
619 auto c0 = constant(outWidth - inWidth, 0);
620 return concat({c0, value}, name);
623 Value sext(Value value,
unsigned outWidth,
624 std::optional<StringRef> name = {}) {
625 return comb::createOrFoldSExt(b, loc, value, b.getIntegerType(outWidth));
629 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
630 return extract(v, index, index, name);
634 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
640 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
648 Value mux(Value index, ValueRange values,
649 std::optional<StringRef> name = {}) {
650 if (values.size() == 2)
651 return comb::MuxOp::create(b, loc, index, values[1], values[0]);
653 return arrayGet(arrayCreate(values), index, name);
658 Value ohMux(Value index, ValueRange inputs) {
660 unsigned numInputs = inputs.size();
661 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
662 "one-hot select can't mux inputs");
666 auto dataType = inputs[0].getType();
668 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
669 Value muxValue = constant(width, 0);
672 for (
size_t i = numInputs - 1; i != 0; --i) {
673 Value input = inputs[i];
674 Value selectBit = bit(index, i);
675 muxValue = mux(selectBit, {muxValue, input});
685 DenseMap<APInt, Value> constants;
690static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
691 return TypeSwitch<Type, Value>(type)
692 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
693 .Case<IntType, IntegerType>([&](
auto type) {
694 return s.constant(type.getIntOrFloatBitWidth(), 0);
696 .Case<hw::StructType>([&](
auto structType) {
697 SmallVector<Value> zeroValues;
698 for (
auto field : structType.getElements())
699 zeroValues.push_back(createZeroDataConst(
s, loc, field.type));
702 .Default([&](Type) -> Value {
703 emitError(loc) <<
"unsupported type for zero value: " << type;
710addSequentialIOOperandsIfNeeded(Operation *op,
711 llvm::SmallVectorImpl<Value> &operands) {
715 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
717 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
719 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
726 HandshakeConversionPattern(ESITypeConverter &typeConverter,
727 MLIRContext *
context, OpBuilder &submoduleBuilder,
728 HandshakeLoweringState &ls)
730 submoduleBuilder(submoduleBuilder), ls(ls) {}
732 using OpAdaptor =
typename T::Adaptor;
735 matchAndRewrite(T op, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const override {
745 submoduleBuilder.setInsertionPoint(op->getParentOp());
746 implModule = hw::HWModuleOp::create(
747 submoduleBuilder, op.getLoc(),
753 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
754 clk = ports.getInput(
"clock");
755 rst = ports.getInput(
"reset");
759 RTLBuilder
s(ports.
getPortList(), b, op.getLoc(), clk, rst);
765 llvm::SmallVector<Value> operands = adaptor.getOperands();
766 addSequentialIOOperandsIfNeeded(op, operands);
767 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
768 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
781 UnwrappedIO unwrapped;
782 for (
auto port : ports.getInputs()) {
783 if (!isa<esi::ChannelType>(port.getType()))
786 auto ready = std::make_shared<Backedge>(bb.
get(
s.b.getI1Type()));
787 auto [
data, valid] =
s.unwrap(port, *ready);
791 unwrapped.inputs.push_back(hs);
793 for (
auto &outputInfo : ports.
getPortList().getOutputs()) {
795 dyn_cast<esi::ChannelType>(outputInfo.type);
800 auto data = std::make_shared<Backedge>(bb.
get(innerType));
801 auto valid = std::make_shared<Backedge>(bb.
get(
s.b.getI1Type()));
802 auto [dataCh, ready] =
s.wrap(*data, *valid);
806 ports.
setOutput(outputInfo.name, dataCh);
807 unwrapped.outputs.push_back(hs);
812 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
813 OutputHandshake &output, Value cond)
const {
814 auto validAndReady =
s.bAnd({output.ready, cond});
815 for (
auto &input : inputs)
816 input.ready->setValue(validAndReady);
819 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820 OutputHandshake &output)
const {
821 llvm::SmallVector<Value> valids;
822 for (
auto &input : inputs)
823 valids.push_back(input.valid);
824 Value allValid =
s.bAnd(valids);
825 output.valid->setValue(allValid);
826 setAllReadyWithCond(s, inputs, output, allValid);
832 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
833 InputHandshake &select)
const {
835 size_t numInputs = unwrapped.inputs.size();
836 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
837 Value truncatedSelect =
838 select.data.getType().getIntOrFloatBitWidth() > selectWidth
839 ?
s.truncate(select.data, selectWidth)
843 auto selectZext =
s.zext(truncatedSelect, numInputs);
844 auto select1h =
s.shl(
s.constant(numInputs, 1), selectZext);
845 auto &res = unwrapped.outputs[0];
848 auto selectedInputValid =
849 s.mux(truncatedSelect, unwrapped.getInputValids());
851 auto selAndInputValid =
s.bAnd({selectedInputValid, select.valid});
852 res.valid->setValue(selAndInputValid);
853 auto resValidAndReady =
s.bAnd({selAndInputValid, res.ready});
856 select.ready->setValue(resValidAndReady);
859 for (
auto [inIdx, in] :
llvm::enumerate(unwrapped.inputs)) {
861 auto isSelected =
s.bit(select1h, inIdx);
865 auto activeAndResultValidAndReady =
866 s.bAnd({isSelected, resValidAndReady});
867 in.ready->setValue(activeAndResultValidAndReady);
871 res.data->setValue(
s.mux(truncatedSelect, unwrapped.getInputDatas()));
876 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
877 ArrayRef<OutputHandshake> outputs)
const {
878 auto c0I1 =
s.constant(1, 0);
879 llvm::SmallVector<Value> doneWires;
880 for (
auto [i, output] :
llvm::enumerate(outputs)) {
881 auto doneBE = bb.
get(
s.b.getI1Type());
882 auto emitted =
s.bAnd({doneBE,
s.bNot(*input.ready)});
883 auto emittedReg =
s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
884 auto outValid =
s.bAnd({
s.bNot(emittedReg), input.valid});
885 output.valid->setValue(outValid);
886 auto validReady =
s.bAnd({output.ready, outValid});
887 auto done =
s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
888 doneBE.setValue(done);
889 doneWires.push_back(done);
891 input.ready->setValue(
s.bAnd(doneWires,
"allDone"));
897 void buildUnitRateJoinLogic(
898 RTLBuilder &s, UnwrappedIO &unwrappedIO,
899 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
900 assert(unwrappedIO.outputs.size() == 1 &&
901 "Expected exactly one output for unit-rate join actor");
903 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
906 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
907 unwrappedIO.outputs[0].data->setValue(unitRes);
910 void buildUnitRateForkLogic(
912 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
913 assert(unwrappedIO.inputs.size() == 1 &&
914 "Expected exactly one input for unit-rate fork actor");
916 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
919 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
920 assert(unitResults.size() == unwrappedIO.outputs.size() &&
921 "Expected unit builder to return one result per output");
922 for (
auto [res, outport] :
llvm::zip(unitResults, unwrappedIO.outputs))
923 outport.
data->setValue(res);
926 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
927 bool signExtend)
const {
929 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
930 .getIntOrFloatBitWidth();
931 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
933 return s.sext(inputs[0], outWidth);
934 return s.zext(inputs[0], outWidth);
938 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
939 unsigned targetWidth)
const {
941 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
942 .getIntOrFloatBitWidth();
943 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
944 return s.truncate(inputs[0], outWidth);
949 static size_t getNumIndexBits(uint64_t numValues) {
950 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
953 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
955 DenseMap<size_t, Value> &indexMapping)
const {
956 auto numInputs = inputs.size();
957 auto priorityArb = defaultValue;
959 for (
size_t i = numInputs; i > 0; --i) {
960 size_t inputIndex = i - 1;
961 size_t oneHotIndex =
size_t{1} << inputIndex;
962 auto constIndex =
s.constant(numInputs, oneHotIndex);
963 indexMapping[inputIndex] = constIndex;
964 priorityArb =
s.mux(inputs[inputIndex], {priorityArb, constIndex});
970 OpBuilder &submoduleBuilder;
971 HandshakeLoweringState &ls;
974class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
976 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
979 auto unwrapped = unwrapIO(s, bb, ports);
980 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
981 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
986class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
988 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
991 auto unwrappedIO = unwrapIO(s, bb, ports);
992 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
993 unwrappedIO.outputs[0].data->setValue(
s.constant(0, 0));
997class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
999 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1002 auto unwrappedIO = unwrapIO(s, bb, ports);
1005 HandshakeWire wire(bb,
s.b.getNoneType());
1007 OutputHandshake output = wire.getAsOutput();
1008 buildJoinLogic(s, unwrappedIO.inputs, output);
1010 InputHandshake input = wire.getAsInput();
1018 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1022 for (
auto &&[in, out] :
llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1027class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1029 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1032 auto unwrappedIO = unwrapIO(s, bb, ports);
1035 auto select = unwrappedIO.inputs[0];
1036 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1037 buildMuxLogic(s, unwrappedIO, select);
1041class InstanceConversionPattern
1042 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1044 using HandshakeConversionPattern<
1045 handshake::InstanceOp>::HandshakeConversionPattern;
1049 "If we indeed perform conversion in post-order, this "
1050 "should never be called. The base HandshakeConversionPattern logic "
1051 "will instantiate the external module.");
1055class ESIInstanceConversionPattern
1058 ESIInstanceConversionPattern(MLIRContext *
context,
1063 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter)
const override {
1070 SmallVector<Value> operands;
1071 for (
size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1073 operands.push_back(adaptor.getOperands()[i]);
1074 operands.push_back(adaptor.getClk());
1075 operands.push_back(adaptor.getRst());
1078 Operation *targetModule = symCache.
getDefinition(op.getModuleAttr());
1080 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1081 op.getInstNameAttr(), operands);
1089class ReturnConversionPattern
1092 using OpConversionPattern::OpConversionPattern;
1094 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter)
const override {
1098 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1099 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1100 outputOp->setOperands(adaptor.getOperands());
1101 outputOp->moveAfter(&parent.getBodyBlock()->back());
1102 rewriter.eraseOp(op);
1109template <
typename TIn,
typename TOut = TIn>
1110class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1112 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1115 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1116 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1121 return TOut::create(
s.b, op.getLoc(), inputs,
1122 ArrayRef<NamedAttribute>{});
1127class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1129 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1132 auto unwrappedIO = unwrapIO(s, bb, ports);
1133 buildUnitRateJoinLogic(s, unwrappedIO,
1134 [&](ValueRange inputs) {
return s.pack(inputs); });
1138class StructCreateConversionPattern
1139 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1141 using HandshakeConversionPattern<
1145 auto unwrappedIO = unwrapIO(s, bb, ports);
1146 auto structType = op.getResult().getType();
1147 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1148 return s.pack(inputs, structType);
1153class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1155 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1158 auto unwrappedIO = unwrapIO(s, bb, ports);
1159 buildUnitRateForkLogic(s, bb, unwrappedIO,
1160 [&](Value input) {
return s.unpack(input); });
1164class ConditionalBranchConversionPattern
1165 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1167 using HandshakeConversionPattern<
1168 ConditionalBranchOp>::HandshakeConversionPattern;
1171 auto unwrappedIO = unwrapIO(s, bb, ports);
1172 auto cond = unwrappedIO.inputs[0];
1173 auto arg = unwrappedIO.inputs[1];
1174 auto trueRes = unwrappedIO.outputs[0];
1175 auto falseRes = unwrappedIO.outputs[1];
1177 auto condArgValid =
s.bAnd({cond.valid, arg.valid});
1180 trueRes.valid->setValue(
s.bAnd({cond.data, condArgValid}));
1181 falseRes.valid->setValue(
s.bAnd({s.bNot(cond.data), condArgValid}));
1184 trueRes.data->setValue(arg.data);
1185 falseRes.data->setValue(arg.data);
1188 auto selectedResultReady =
1189 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1190 auto condArgReady =
s.bAnd({selectedResultReady, condArgValid});
1191 arg.ready->setValue(condArgReady);
1192 cond.ready->setValue(condArgReady);
1196template <
typename TIn,
bool signExtend>
1197class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1199 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1202 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1203 this->buildExtendLogic(s, unwrappedIO, signExtend);
1209template <
typename ArithOp, comb::ICmpPredicate pred>
1210class MinMaxConversionPattern :
public HandshakeConversionPattern<ArithOp> {
1212 using HandshakeConversionPattern<ArithOp>::HandshakeConversionPattern;
1215 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1216 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1220 comb::ICmpOp::create(
s.b, op.getLoc(), pred, inputs[0], inputs[1]);
1221 return comb::MuxOp::create(
s.b, op.getLoc(), cmp, inputs[0], inputs[1]);
1226class ComparisonConversionPattern
1227 :
public HandshakeConversionPattern<arith::CmpIOp> {
1229 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1232 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1233 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1234 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1235 return comb::ICmpOp::create(
s.b, op.getLoc(), predicate, inputs[0],
1240 switch (op.getPredicate()) {
1241 case arith::CmpIPredicate::eq:
1242 return buildCompareLogic(comb::ICmpPredicate::eq);
1243 case arith::CmpIPredicate::ne:
1244 return buildCompareLogic(comb::ICmpPredicate::ne);
1245 case arith::CmpIPredicate::slt:
1246 return buildCompareLogic(comb::ICmpPredicate::slt);
1247 case arith::CmpIPredicate::ult:
1248 return buildCompareLogic(comb::ICmpPredicate::ult);
1249 case arith::CmpIPredicate::sle:
1250 return buildCompareLogic(comb::ICmpPredicate::sle);
1251 case arith::CmpIPredicate::ule:
1252 return buildCompareLogic(comb::ICmpPredicate::ule);
1253 case arith::CmpIPredicate::sgt:
1254 return buildCompareLogic(comb::ICmpPredicate::sgt);
1255 case arith::CmpIPredicate::ugt:
1256 return buildCompareLogic(comb::ICmpPredicate::ugt);
1257 case arith::CmpIPredicate::sge:
1258 return buildCompareLogic(comb::ICmpPredicate::sge);
1259 case arith::CmpIPredicate::uge:
1260 return buildCompareLogic(comb::ICmpPredicate::uge);
1262 assert(
false &&
"invalid CmpIOp");
1266class TruncateConversionPattern
1267 :
public HandshakeConversionPattern<arith::TruncIOp> {
1269 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1272 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1273 unsigned targetBits =
1274 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1275 buildTruncateLogic(s, unwrappedIO, targetBits);
1279class ControlMergeConversionPattern
1280 :
public HandshakeConversionPattern<ControlMergeOp> {
1282 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1285 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1286 auto resData = unwrappedIO.outputs[0];
1287 auto resIndex = unwrappedIO.outputs[1];
1290 unsigned numInputs = unwrappedIO.inputs.size();
1291 auto indexType =
s.b.getIntegerType(numInputs);
1292 Value noWinner =
s.constant(numInputs, 0);
1293 Value c0I1 =
s.constant(1, 0);
1296 auto won = bb.
get(indexType);
1297 Value wonReg =
s.reg(
"won_reg", won, noWinner);
1300 auto win = bb.
get(indexType);
1304 auto fired = bb.
get(
s.b.getI1Type());
1307 auto resultEmitted = bb.
get(
s.b.getI1Type());
1308 Value resultEmittedReg =
s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1309 auto indexEmitted = bb.
get(
s.b.getI1Type());
1310 Value indexEmittedReg =
s.reg(
"index_emitted_reg", indexEmitted, c0I1);
1313 auto resultDone = bb.
get(
s.b.getI1Type());
1314 auto indexDone = bb.
get(
s.b.getI1Type());
1318 auto hasWinnerCondition =
s.rOr({win});
1319 auto hadWinnerCondition =
s.rOr({wonReg});
1327 DenseMap<size_t, Value> argIndexValues;
1328 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1329 noWinner, argIndexValues);
1330 priorityArb =
s.mux(hadWinnerCondition, {priorityArb, wonReg});
1331 win.setValue(priorityArb);
1341 auto resultNotEmitted =
s.bNot(resultEmittedReg);
1342 auto resultValid =
s.bAnd({hasWinnerCondition, resultNotEmitted});
1343 resData.valid->setValue(resultValid);
1344 resData.data->setValue(
s.ohMux(win, unwrappedIO.getInputDatas()));
1346 auto indexNotEmitted =
s.bNot(indexEmittedReg);
1347 auto indexValid =
s.bAnd({hasWinnerCondition, indexNotEmitted});
1348 resIndex.valid->setValue(indexValid);
1352 SmallVector<Value, 8> indexOutputs;
1353 for (
size_t i = 0; i < numInputs; ++i)
1354 indexOutputs.push_back(
s.constant(64, i));
1356 auto indexOutput =
s.ohMux(win, indexOutputs);
1357 resIndex.data->setValue(indexOutput);
1363 won.setValue(
s.mux(fired, {win, noWinner}));
1368 auto resultValidAndReady =
s.bAnd({resultValid, resData.ready});
1369 resultDone.setValue(
s.bOr({resultValidAndReady, resultEmittedReg}));
1371 auto indexValidAndReady =
s.bAnd({indexValid, resIndex.ready});
1372 indexDone.setValue(
s.bOr({indexValidAndReady, indexEmittedReg}));
1376 fired.setValue(
s.bAnd({resultDone, indexDone}));
1382 resultEmitted.setValue(
s.mux(fired, {resultDone, c0I1}));
1383 indexEmitted.setValue(
s.mux(fired, {indexDone, c0I1}));
1388 auto winnerOrDefault =
s.mux(fired, {noWinner, win});
1389 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1390 auto &indexValue = argIndexValues[i];
1391 ir->setValue(
s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1396class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1398 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1401 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1402 auto resData = unwrappedIO.outputs[0];
1405 unsigned numInputs = unwrappedIO.inputs.size();
1406 auto indexType =
s.b.getIntegerType(numInputs);
1407 Value noWinner =
s.constant(numInputs, 0);
1410 auto win = bb.
get(indexType);
1413 auto hasWinnerCondition =
s.rOr(win);
1420 DenseMap<size_t, Value> argIndexValues;
1421 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1422 noWinner, argIndexValues);
1423 win.setValue(priorityArb);
1430 resData.valid->setValue(hasWinnerCondition);
1431 resData.data->setValue(
s.ohMux(win, unwrappedIO.getInputDatas()));
1436 auto resultValidAndReady =
s.bAnd({hasWinnerCondition, resData.ready});
1441 auto winnerOrDefault =
s.mux(resultValidAndReady, {noWinner, win});
1442 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1443 auto &indexValue = argIndexValues[i];
1444 ir->setValue(
s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1449class LoadConversionPattern
1450 :
public HandshakeConversionPattern<handshake::LoadOp> {
1452 using HandshakeConversionPattern<
1453 handshake::LoadOp>::HandshakeConversionPattern;
1456 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1457 auto addrFromUser = unwrappedIO.inputs[0];
1458 auto dataFromMem = unwrappedIO.inputs[1];
1459 auto controlIn = unwrappedIO.inputs[2];
1460 auto dataToUser = unwrappedIO.outputs[0];
1461 auto addrToMem = unwrappedIO.outputs[1];
1463 addrToMem.data->setValue(addrFromUser.data);
1464 dataToUser.data->setValue(dataFromMem.data);
1468 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1472 dataToUser.valid->setValue(dataFromMem.valid);
1473 dataFromMem.ready->setValue(dataToUser.ready);
1477class StoreConversionPattern
1478 :
public HandshakeConversionPattern<handshake::StoreOp> {
1480 using HandshakeConversionPattern<
1481 handshake::StoreOp>::HandshakeConversionPattern;
1484 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1485 auto addrFromUser = unwrappedIO.inputs[0];
1486 auto dataFromUser = unwrappedIO.inputs[1];
1487 auto controlIn = unwrappedIO.inputs[2];
1488 auto dataToMem = unwrappedIO.outputs[0];
1489 auto addrToMem = unwrappedIO.outputs[1];
1492 auto outputsReady =
s.bAnd({dataToMem.ready, addrToMem.ready});
1496 HandshakeWire joinWire(bb,
s.b.getNoneType());
1497 joinWire.ready->setValue(outputsReady);
1498 OutputHandshake joinOutput = joinWire.getAsOutput();
1499 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1502 addrToMem.data->setValue(addrFromUser.data);
1503 dataToMem.data->setValue(dataFromUser.data);
1506 addrToMem.valid->setValue(*joinWire.valid);
1507 dataToMem.valid->setValue(*joinWire.valid);
1511class MemoryConversionPattern
1512 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1514 using HandshakeConversionPattern<
1515 handshake::MemoryOp>::HandshakeConversionPattern;
1518 auto loc = op.getLoc();
1521 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1523 InputHandshake &
addr;
1524 OutputHandshake &
data;
1525 OutputHandshake &done;
1528 InputHandshake &
addr;
1529 InputHandshake &
data;
1530 OutputHandshake &done;
1532 SmallVector<LoadPort, 4> loadPorts;
1533 SmallVector<StorePort, 4> storePorts;
1535 unsigned stCount = op.getStCount();
1536 unsigned ldCount = op.getLdCount();
1537 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1538 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1539 unwrappedIO.outputs[i],
1540 unwrappedIO.outputs[ldCount + stCount + i]};
1541 loadPorts.push_back(port);
1544 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1545 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1546 unwrappedIO.inputs[i * 2],
1547 unwrappedIO.outputs[ldCount + i]};
1548 storePorts.push_back(port);
1552 auto c0I0 =
s.constant(0, 0);
1554 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1555 auto hlmem = seq::HLMemOp::create(
1556 s.b, loc,
s.clk,
s.rst,
1557 "_handshake_memory_" + std::to_string(op.getId()),
1558 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1561 for (
auto &ld : loadPorts) {
1562 llvm::SmallVector<Value> addresses = {
s.truncate(ld.addr.data, cl2dim)};
1563 auto readData = seq::ReadPortOp::create(
s.b, loc, hlmem.getHandle(),
1564 addresses, ld.addr.valid,
1566 ld.data.data->setValue(readData);
1567 ld.done.data->setValue(c0I0);
1569 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1573 for (
auto &st : storePorts) {
1576 auto writeValidBufferMuxBE = bb.
get(
s.b.getI1Type());
1577 auto writeValidBuffer =
1578 s.reg(
"writeValidBuffer", writeValidBufferMuxBE,
s.constant(1, 0));
1579 st.done.valid->setValue(writeValidBuffer);
1580 st.done.data->setValue(c0I0);
1584 auto storeCompleted =
1585 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
1589 auto notWriteValidBuffer =
s.bNot(writeValidBuffer);
1590 auto emptyOrComplete =
1591 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1594 st.addr.ready->setValue(emptyOrComplete);
1595 st.data.ready->setValue(emptyOrComplete);
1598 auto writeValid =
s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
1604 writeValidBufferMuxBE.setValue(
1605 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1609 llvm::SmallVector<Value> addresses = {
s.truncate(st.addr.data, cl2dim)};
1610 seq::WritePortOp::create(
s.b, loc, hlmem.getHandle(), addresses,
1611 st.data.data, writeValid,
1617class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1619 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1622 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1624 unwrappedIO.inputs[0].ready->setValue(
s.constant(1, 1));
1628class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1630 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1633 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1635 unwrappedIO.outputs[0].valid->setValue(
s.constant(1, 1));
1636 unwrappedIO.outputs[0].data->setValue(
s.constant(0, 0));
1640class ConstantConversionPattern
1641 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1643 using HandshakeConversionPattern<
1644 handshake::ConstantOp>::HandshakeConversionPattern;
1647 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1648 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1649 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1650 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1651 unwrappedIO.outputs[0].data->setValue(
s.constant(constantValue));
1655class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1657 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1660 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1661 auto input = unwrappedIO.inputs[0];
1662 auto output = unwrappedIO.outputs[0];
1663 InputHandshake lastStage;
1664 SmallVector<int64_t> initValues;
1667 if (op.getInitValues())
1668 initValues = op.getInitValueArray();
1671 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1672 op.getNumSlots(), input, output, initValues);
1675 output.data->setValue(lastStage.data);
1676 output.valid->setValue(lastStage.valid);
1677 lastStage.ready->setValue(output.ready);
1680 struct SeqBufferStage {
1681 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1682 RTLBuilder &s,
size_t index,
1683 std::optional<int64_t> initValue)
1684 : dataType(dataType), preStage(preStage),
s(
s), bb(bb), index(index) {
1687 c0s = createZeroDataConst(s,
s.loc, dataType);
1688 currentStage.ready = std::make_shared<Backedge>(bb.
get(
s.b.getI1Type()));
1690 auto hasInitValue =
s.constant(1, initValue.has_value());
1691 auto validBE = bb.
get(
s.b.getI1Type());
1692 auto validReg =
s.reg(getRegName(
"valid"), validBE, hasInitValue);
1693 auto readyBE = bb.
get(
s.b.getI1Type());
1695 Value initValueCs = c0s;
1696 if (initValue.has_value())
1697 initValueCs =
s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1702 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1703 buildControlBufferLogic(validReg, readyBE, dataReg);
1706 StringAttr getRegName(StringRef name) {
1707 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
1710 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1712 auto c0I1 =
s.constant(1, 0);
1713 auto readyRegWire = bb.
get(
s.b.getI1Type());
1714 auto readyReg =
s.reg(getRegName(
"ready"), readyRegWire, c0I1);
1718 currentStage.valid =
s.mux(readyReg, {validReg, readyReg},
1719 "controlValid" + std::to_string(index));
1722 auto notReadyReg =
s.bNot(readyReg);
1725 auto succNotReady =
s.bNot(*currentStage.ready);
1726 auto neitherReady =
s.bAnd({succNotReady, notReadyReg});
1727 auto ctrlNotReady =
s.mux(neitherReady, {readyReg, validReg});
1728 auto bothReady =
s.bAnd({*currentStage.ready, readyReg});
1731 auto resetSignal =
s.mux(bothReady, {ctrlNotReady, c0I1});
1732 readyRegWire.setValue(resetSignal);
1735 auto ctrlDataRegBE = bb.
get(dataType);
1736 auto ctrlDataReg =
s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
1737 auto dataResult =
s.mux(readyReg, {dataReg, ctrlDataReg});
1738 currentStage.data = dataResult;
1740 auto dataNotReadyMux =
s.mux(neitherReady, {ctrlDataReg, dataReg});
1741 auto dataResetSignal =
s.mux(bothReady, {dataNotReadyMux, c0s});
1742 ctrlDataRegBE.setValue(dataResetSignal);
1745 Value buildDataBufferLogic(Value validReg, Value initValue,
1749 auto notValidReg =
s.bNot(validReg);
1750 auto emptyOrReady =
s.bOr({notValidReg, readyBE});
1751 preStage.ready->setValue(emptyOrReady);
1757 auto validRegMux =
s.mux(emptyOrReady, {validReg, preStage.valid});
1763 auto dataRegBE = bb.
get(dataType);
1765 s.reg(getRegName(
"data"),
1766 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1767 dataRegBE.setValue(dataReg);
1771 InputHandshake getOutput() {
return currentStage; }
1774 InputHandshake &preStage;
1775 InputHandshake currentStage;
1784 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1785 Type dataType,
unsigned size,
1786 InputHandshake &input,
1787 OutputHandshake &output,
1788 llvm::ArrayRef<int64_t> initValues)
const {
1791 InputHandshake currentStage = input;
1793 for (
unsigned i = 0; i < size; ++i) {
1794 bool isInitialized = i < initValues.size();
1796 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1797 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1801 return currentStage;
1805class IndexCastConversionPattern
1806 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1808 using HandshakeConversionPattern<
1809 arith::IndexCastOp>::HandshakeConversionPattern;
1812 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1813 unsigned sourceBits =
1814 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1815 unsigned targetBits =
1816 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1817 if (targetBits < sourceBits)
1818 buildTruncateLogic(s, unwrappedIO, targetBits);
1820 buildExtendLogic(s, unwrappedIO,
true);
1824template <
typename T>
1827 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1828 MLIRContext *
context, OpBuilder &submoduleBuilder,
1829 HandshakeLoweringState &ls)
1831 submoduleBuilder(submoduleBuilder), ls(ls) {}
1832 using OpAdaptor =
typename T::Adaptor;
1835 matchAndRewrite(T op, OpAdaptor adaptor,
1836 ConversionPatternRewriter &rewriter)
const override {
1841 implModule = hw::HWModuleExternOp::create(
1842 submoduleBuilder, op.getLoc(),
1846 llvm::SmallVector<Value> operands = adaptor.getOperands();
1847 addSequentialIOOperandsIfNeeded(op, operands);
1848 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1849 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1854 OpBuilder &submoduleBuilder;
1855 HandshakeLoweringState &ls;
1860 using OpConversionPattern::OpConversionPattern;
1864 ConversionPatternRewriter &rewriter)
const override {
1868 HWModuleLike hwModule;
1869 if (op.isExternal()) {
1870 hwModule = hw::HWModuleExternOp::create(
1871 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1873 auto hwModuleOp = hw::HWModuleOp::create(
1874 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1875 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1876 rewriter.inlineBlockBefore(&op.getBody().front(),
1877 hwModuleOp.getBodyBlock()->getTerminator(),
1879 hwModule = hwModuleOp;
1886 auto *parentOp = op->getParentOp();
1887 auto *predeclModule =
1888 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1889 if (predeclModule) {
1890 if (failed(SymbolTable::replaceAllSymbolUses(
1891 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1893 rewriter.eraseOp(predeclModule);
1897 rewriter.eraseOp(op);
1909 ConversionTarget &target,
1911 OpBuilder &moduleBuilder) {
1913 std::map<std::string, unsigned> instanceNameCntr;
1914 NameUniquer instanceUniquer = [&](Operation *op) {
1916 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1919 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1922 instName += std::to_string(instanceNameCntr[instName]++);
1927 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1929 RewritePatternSet
patterns(op.getContext());
1930 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1932 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1933 SyncConversionPattern>(typeConverter, op.getContext(),
1938 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1939 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1940 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1941 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1942 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1943 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1944 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1945 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1946 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1947 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1948 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1949 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1950 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1951 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1952 MinMaxConversionPattern<arith::MaxSIOp, comb::ICmpPredicate::sge>,
1953 MinMaxConversionPattern<arith::MaxUIOp, comb::ICmpPredicate::uge>,
1954 MinMaxConversionPattern<arith::MinSIOp, comb::ICmpPredicate::sle>,
1955 MinMaxConversionPattern<arith::MinUIOp, comb::ICmpPredicate::ule>,
1957 StructCreateConversionPattern,
1959 ConditionalBranchConversionPattern, MuxConversionPattern,
1960 PackConversionPattern, UnpackConversionPattern,
1961 ComparisonConversionPattern, BufferConversionPattern,
1962 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1963 MergeConversionPattern, ControlMergeConversionPattern,
1964 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1965 InstanceConversionPattern,
1967 ExtendConversionPattern<arith::ExtUIOp,
false>,
1968 ExtendConversionPattern<arith::ExtSIOp,
true>,
1969 TruncateConversionPattern, IndexCastConversionPattern>(
1970 typeConverter, op.getContext(), moduleBuilder, ls);
1972 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1973 return op->emitOpError() <<
"error during conversion";
1978class HandshakeToHWPass
1979 :
public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1981 void runOnOperation()
override {
1982 mlir::ModuleOp mod = getOperation();
1986 for (
auto f : mod.getOps<
handshake::FuncOp>()) {
1988 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1989 "are used exactly once. Remember to run the "
1990 "fork/sink materialization pass before HW lowering.";
1991 signalPassFailure();
1997 std::string topLevel;
1999 SmallVector<std::string> sortedFuncs;
2001 signalPassFailure();
2005 ESITypeConverter typeConverter;
2006 ConversionTarget target(getContext());
2012 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
2018 OpBuilder submoduleBuilder(mod.getContext());
2019 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2020 for (
auto &funcName :
llvm::reverse(sortedFuncs)) {
2022 assert(funcOp &&
"handshake.func not found in module!");
2024 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2025 signalPassFailure();
2032 for (
auto hwModule : mod.getOps<
hw::HWModuleOp>())
2034 return signalPassFailure();
2040 RewritePatternSet
patterns(mod.getContext());
2041 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2043 if (failed(applyPartialConversion(mod, target, std::move(
patterns)))) {
2044 mod->emitOpError() <<
"error during conversion";
2045 signalPassFailure();
2052 return std::make_unique<HandshakeToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static std::unique_ptr< Context > context
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
void setOutput(unsigned i, Value v)
const ModulePortInfo & getPortList() const
This stores lookup tables to make manipulating and working with the IR more efficient.
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Channels are the basic communication primitives.
const Type * getInner() const
create(elements, Type result_type=None)
create(elements, Type result_type=None)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.
void eraseInput(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.