26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
56struct HandshakeLoweringState {
57 ModuleOp parentModule;
64class ESITypeConverter :
public TypeConverter {
67 addConversion([](Type type) -> Type {
return esiWrapper(type); });
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
74 return UnrealizedConversionCastOp::create(builder, loc, resultType,
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
84 return UnrealizedConversionCastOp::create(builder, loc, resultType,
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
101 return subModuleName;
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
113 auto opType = op.getType();
114 if (
auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
123 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
124 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
126 {memOp.getMemRefType().getElementType()}};
131 SmallVector<Type> inTypes, outTypes;
132 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
134 llvm::transform(op->getResults(), std::back_inserter(outTypes),
145 std::string typeName;
147 if (type.isIntOrIndex()) {
148 if (
auto indexType = dyn_cast<IndexType>(type))
149 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
150 else if (type.isSignedInteger())
151 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
153 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
154 }
else if (
auto tupleType = dyn_cast<TupleType>(type)) {
155 typeName +=
"_tuple";
158 }
else if (
auto structType = dyn_cast<hw::StructType>(type)) {
159 typeName +=
"_struct";
160 for (
auto element : structType.getElements())
161 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
162 }
else if (isa<NoneType>(type)) {
165 emitError(loc) <<
"unsupported data type '" << type <<
"'";
172 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
173 return instanceOp.getModule().str();
178 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
179 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
180 auto intType = intAttr.getType();
182 if (intType.isSignedInteger())
183 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
184 else if (intType.isUnsignedInteger())
185 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
187 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
189 oldOp->emitError(
"unsupported constant type");
194 if (!inTypes.empty())
195 subModuleName +=
"_in";
196 for (
auto inType : inTypes)
197 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
199 if (!outTypes.empty())
200 subModuleName +=
"_out";
201 for (
auto outType : outTypes)
202 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
205 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
206 subModuleName +=
"_id" + std::to_string(memOp.getId());
209 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
210 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
213 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
214 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
215 if (bufferOp.isSequential())
216 subModuleName +=
"_seq";
218 subModuleName +=
"_fifo";
220 if (
auto initValues = bufferOp.getInitValues()) {
221 subModuleName +=
"_init";
222 for (
const Attribute e : *initValues) {
223 assert(isa<IntegerAttr>(e));
225 "_" + std::to_string(dyn_cast<IntegerAttr>(e).
getInt());
231 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
232 ctrlInterface && ctrlInterface.isControl()) {
234 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
235 std::to_string(oldOp->getNumResults()) +
"outs";
236 subModuleName +=
"_ctrl";
239 (!inTypes.empty() || !outTypes.empty()) &&
240 "Insufficient discriminating type info generated for the operation!");
243 return subModuleName;
255 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
257 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
264 HWModuleLike targetModule;
265 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
270 if (isa<handshake::InstanceOp>(oldOp))
272 "handshake.instance target modules should always have been lowered "
273 "before the modules that reference them!");
283static llvm::SmallVector<hw::detail::FieldInfo>
285 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
286 for (
auto port : portInfo)
287 fieldInfo.push_back({port.name, port.type});
295 auto *ctx = mod.getContext();
298 llvm::DenseMap<unsigned, Value> memrefPorts;
299 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
300 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
301 if (channel && isa<MemRefType>(channel.getInner()))
302 memrefPorts[i] = arg;
305 if (memrefPorts.empty())
310 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
311 ArrayRef<hw::PortInfo> info,
315 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
319 for (
auto [i, arg] : memrefPorts) {
321 auto memName = mod.getArgName(i);
324 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
326 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
327 extmemInstance, extmemInstance.getModuleNameAttr()));
338 SmallVector<PortInfo> outputs(portInfo.
getOutputs());
340 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i, outputs,
341 hw::ModulePort::Direction::Input);
342 mod.insertPorts({{i, inPortInfo}}, {});
343 auto newInPort = mod.getArgumentForInput(i);
345 b.setInsertionPointToStart(mod.getBodyBlock());
346 auto newInPortExploded = hw::StructExplodeOp::create(
347 b, arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
348 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
352 unsigned outArgI = mod.getNumOutputPorts();
353 SmallVector<PortInfo> inputs(portInfo.
getInputs());
355 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
356 inputs, hw::ModulePort::Direction::Output);
358 auto memOutputArgs = extmemInstance.getOperands().drop_front();
359 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
361 b, arg.getLoc(), outPortInfo.type, memOutputArgs);
362 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
367 extmemInstance.erase();
371 mod.modifyPorts( {}, {},
382struct InputHandshake {
384 std::shared_ptr<Backedge> ready;
390struct OutputHandshake {
391 std::shared_ptr<Backedge> valid;
393 std::shared_ptr<Backedge>
data;
398struct HandshakeWire {
400 MLIRContext *ctx = dataType.getContext();
401 auto i1Type = IntegerType::get(ctx, 1);
402 valid = std::make_shared<Backedge>(bb.
get(i1Type));
403 ready = std::make_shared<Backedge>(bb.
get(i1Type));
404 data = std::make_shared<Backedge>(bb.
get(dataType));
409 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
410 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
412 std::shared_ptr<Backedge> valid;
413 std::shared_ptr<Backedge> ready;
414 std::shared_ptr<Backedge>
data;
417template <
typename T,
typename TInner>
418llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
419 llvm::function_ref<T(TInner &)> extractor) {
420 llvm::SmallVector<T> result;
421 llvm::transform(container, std::back_inserter(result), extractor);
425 llvm::SmallVector<InputHandshake> inputs;
426 llvm::SmallVector<OutputHandshake> outputs;
428 llvm::SmallVector<Value> getInputValids() {
429 return extractValues<Value, InputHandshake>(
430 inputs, [](
auto &hs) {
return hs.valid; });
432 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
433 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
434 inputs, [](
auto &hs) {
return hs.ready; });
436 llvm::SmallVector<Value> getInputDatas() {
437 return extractValues<Value, InputHandshake>(
438 inputs, [](
auto &hs) {
return hs.data; });
440 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
441 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
442 outputs, [](
auto &hs) {
return hs.valid; });
444 llvm::SmallVector<Value> getOutputReadys() {
445 return extractValues<Value, OutputHandshake>(
446 outputs, [](
auto &hs) {
return hs.ready; });
448 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
449 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
450 outputs, [](
auto &hs) {
return hs.data; });
459 Value clk = Value(), Value rst = Value())
460 :
info(std::move(
info)), b(builder), loc(loc),
clk(
clk), rst(rst) {}
462 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
465 bool isZeroWidth = apv.getBitWidth() == 0;
467 auto it = constants.find(apv);
468 if (it != constants.end())
474 constants[apv] = cval;
478 Value constant(
unsigned width, int64_t value,
479 std::optional<StringRef> name = {}) {
481 APInt(width, value,
false,
true));
483 std::pair<Value, Value>
wrap(Value data, Value valid,
484 std::optional<StringRef> name = {}) {
485 auto wrapOp = esi::WrapValidReadyOp::create(b, loc, data, valid);
486 return {wrapOp.getResult(0), wrapOp.getResult(1)};
488 std::pair<Value, Value>
unwrap(Value channel, Value ready,
489 std::optional<StringRef> name = {}) {
490 auto unwrapOp = esi::UnwrapValidReadyOp::create(b, loc, channel, ready);
491 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
495 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
496 Value rst = Value()) {
497 Value resolvedClk =
clk ?
clk : this->
clk;
498 Value resolvedRst = rst ? rst : this->rst;
500 "No global clock provided to this RTLBuilder - a clock "
501 "signal must be provided to the reg(...) function.");
503 "No global reset provided to this RTLBuilder - a reset "
504 "signal must be provided to the reg(...) function.");
510 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
511 std::optional<StringRef> name = {}) {
512 return comb::ICmpOp::create(b, loc, predicate, lhs, rhs);
515 Value buildNamedOp(llvm::function_ref<Value()> f,
516 std::optional<StringRef> name) {
519 Operation *op = v.getDefiningOp();
520 if (name.has_value()) {
521 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
522 nameAttr = b.getStringAttr(*name);
528 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
530 [&]() {
return comb::AndOp::create(b, loc, values,
false); }, name);
533 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
535 [&]() {
return comb::OrOp::create(b, loc, values,
false); }, name);
539 Value bNot(Value value, std::optional<StringRef> name = {}) {
540 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
541 std::string inferedName;
545 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
546 inferedName = (
"not_" +
valueName.getValue()).str();
552 [&]() {
return comb::XorOp::create(b, loc, value, allOnes); }, name);
554 return b.createOrFold<
comb::XorOp>(loc, value, allOnes,
false);
557 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
559 [&]() {
return comb::ShlOp::create(b, loc, value, shift); }, name);
562 Value concat(ValueRange values, std::optional<StringRef> name = {}) {
564 [&]() {
return comb::ConcatOp::create(b, loc, values); }, name);
568 Value pack(ValueRange values, Type structType = Type(),
569 std::optional<StringRef> name = {}) {
580 ValueRange unpack(Value value) {
581 auto structType = cast<hw::StructType>(value.getType());
582 llvm::SmallVector<Type> innerTypes;
583 structType.getInnerTypes(innerTypes);
584 return hw::StructExplodeOp::create(b, loc, innerTypes, value).getResults();
587 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
588 llvm::SmallVector<Value> bits;
589 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
595 Value rOr(Value v, std::optional<StringRef> name = {}) {
596 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
600 Value extract(Value v,
unsigned lo,
unsigned hi,
601 std::optional<StringRef> name = {}) {
602 unsigned width = hi - lo + 1;
608 Value truncate(Value value,
unsigned width,
609 std::optional<StringRef> name = {}) {
610 return extract(value, 0, width - 1, name);
613 Value zext(Value value,
unsigned outWidth,
614 std::optional<StringRef> name = {}) {
615 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
616 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
617 if (inWidth == outWidth)
619 auto c0 = constant(outWidth - inWidth, 0);
620 return concat({c0, value}, name);
623 Value sext(Value value,
unsigned outWidth,
624 std::optional<StringRef> name = {}) {
625 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
629 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
630 return extract(v, index, index, name);
634 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
640 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
648 Value mux(Value index, ValueRange values,
649 std::optional<StringRef> name = {}) {
650 if (values.size() == 2)
651 return comb::MuxOp::create(b, loc, index, values[1], values[0]);
653 return arrayGet(arrayCreate(values), index, name);
658 Value ohMux(Value index, ValueRange inputs) {
660 unsigned numInputs = inputs.size();
661 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
662 "one-hot select can't mux inputs");
666 auto dataType = inputs[0].getType();
668 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
669 Value muxValue = constant(width, 0);
672 for (
size_t i = numInputs - 1; i != 0; --i) {
673 Value input = inputs[i];
674 Value selectBit = bit(index, i);
675 muxValue = mux(selectBit, {muxValue, input});
685 DenseMap<APInt, Value> constants;
690static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
691 return TypeSwitch<Type, Value>(type)
692 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
693 .Case<IntType, IntegerType>([&](
auto type) {
694 return s.constant(type.getIntOrFloatBitWidth(), 0);
696 .Case<hw::StructType>([&](
auto structType) {
697 SmallVector<Value> zeroValues;
698 for (
auto field : structType.getElements())
699 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
702 .Default([&](Type) -> Value {
703 emitError(loc) <<
"unsupported type for zero value: " << type;
710addSequentialIOOperandsIfNeeded(Operation *op,
711 llvm::SmallVectorImpl<Value> &operands) {
715 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
717 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
719 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
726 HandshakeConversionPattern(ESITypeConverter &typeConverter,
727 MLIRContext *
context, OpBuilder &submoduleBuilder,
728 HandshakeLoweringState &ls)
730 submoduleBuilder(submoduleBuilder), ls(ls) {}
732 using OpAdaptor =
typename T::Adaptor;
735 matchAndRewrite(T op, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const override {
745 submoduleBuilder.setInsertionPoint(op->getParentOp());
746 implModule = hw::HWModuleOp::create(
747 submoduleBuilder, op.getLoc(),
753 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
754 clk = ports.getInput(
"clock");
755 rst = ports.getInput(
"reset");
759 RTLBuilder s(ports.
getPortList(), b, op.getLoc(), clk, rst);
765 llvm::SmallVector<Value> operands = adaptor.getOperands();
766 addSequentialIOOperandsIfNeeded(op, operands);
767 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
768 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
781 UnwrappedIO unwrapped;
782 for (
auto port : ports.getInputs()) {
783 if (!isa<esi::ChannelType>(port.getType()))
786 auto ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
787 auto [
data, valid] = s.unwrap(port, *ready);
791 unwrapped.inputs.push_back(hs);
793 for (
auto &outputInfo : ports.
getPortList().getOutputs()) {
795 dyn_cast<esi::ChannelType>(outputInfo.type);
800 auto data = std::make_shared<Backedge>(bb.
get(innerType));
801 auto valid = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
802 auto [dataCh, ready] = s.wrap(*data, *valid);
806 ports.
setOutput(outputInfo.name, dataCh);
807 unwrapped.outputs.push_back(hs);
812 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
813 OutputHandshake &output, Value cond)
const {
814 auto validAndReady = s.bAnd({output.ready, cond});
815 for (
auto &input : inputs)
816 input.ready->setValue(validAndReady);
819 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820 OutputHandshake &output)
const {
821 llvm::SmallVector<Value> valids;
822 for (
auto &input : inputs)
823 valids.push_back(input.valid);
824 Value allValid = s.bAnd(valids);
825 output.valid->setValue(allValid);
826 setAllReadyWithCond(s, inputs, output, allValid);
832 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
833 InputHandshake &select)
const {
835 size_t numInputs = unwrapped.inputs.size();
836 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
837 Value truncatedSelect =
838 select.data.getType().getIntOrFloatBitWidth() > selectWidth
839 ? s.truncate(select.data, selectWidth)
843 auto selectZext = s.zext(truncatedSelect, numInputs);
844 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
845 auto &res = unwrapped.outputs[0];
848 auto selectedInputValid =
849 s.mux(truncatedSelect, unwrapped.getInputValids());
851 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
852 res.valid->setValue(selAndInputValid);
853 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
856 select.ready->setValue(resValidAndReady);
859 for (
auto [inIdx, in] :
llvm::enumerate(unwrapped.inputs)) {
861 auto isSelected = s.bit(select1h, inIdx);
865 auto activeAndResultValidAndReady =
866 s.bAnd({isSelected, resValidAndReady});
867 in.ready->setValue(activeAndResultValidAndReady);
871 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
876 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
877 ArrayRef<OutputHandshake> outputs)
const {
878 auto c0I1 = s.constant(1, 0);
879 llvm::SmallVector<Value> doneWires;
880 for (
auto [i, output] :
llvm::enumerate(outputs)) {
881 auto doneBE = bb.
get(s.b.getI1Type());
882 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
883 auto emittedReg = s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
884 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
885 output.valid->setValue(outValid);
886 auto validReady = s.bAnd({output.ready, outValid});
887 auto done = s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
888 doneBE.setValue(done);
889 doneWires.push_back(done);
891 input.ready->setValue(s.bAnd(doneWires,
"allDone"));
897 void buildUnitRateJoinLogic(
898 RTLBuilder &s, UnwrappedIO &unwrappedIO,
899 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
900 assert(unwrappedIO.outputs.size() == 1 &&
901 "Expected exactly one output for unit-rate join actor");
903 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
906 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
907 unwrappedIO.outputs[0].data->setValue(unitRes);
910 void buildUnitRateForkLogic(
912 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
913 assert(unwrappedIO.inputs.size() == 1 &&
914 "Expected exactly one input for unit-rate fork actor");
916 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
919 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
920 assert(unitResults.size() == unwrappedIO.outputs.size() &&
921 "Expected unit builder to return one result per output");
922 for (
auto [res, outport] :
llvm::zip(unitResults, unwrappedIO.outputs))
923 outport.
data->setValue(res);
926 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
927 bool signExtend)
const {
929 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
930 .getIntOrFloatBitWidth();
931 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
933 return s.sext(inputs[0], outWidth);
934 return s.zext(inputs[0], outWidth);
938 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
939 unsigned targetWidth)
const {
941 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
942 .getIntOrFloatBitWidth();
943 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
944 return s.truncate(inputs[0], outWidth);
949 static size_t getNumIndexBits(uint64_t numValues) {
950 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
953 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
955 DenseMap<size_t, Value> &indexMapping)
const {
956 auto numInputs = inputs.size();
957 auto priorityArb = defaultValue;
959 for (
size_t i = numInputs; i > 0; --i) {
960 size_t inputIndex = i - 1;
961 size_t oneHotIndex =
size_t{1} << inputIndex;
962 auto constIndex = s.constant(numInputs, oneHotIndex);
963 indexMapping[inputIndex] = constIndex;
964 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
970 OpBuilder &submoduleBuilder;
971 HandshakeLoweringState &ls;
974class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
976 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
979 auto unwrapped = unwrapIO(s, bb, ports);
980 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
981 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
986class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
988 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
991 auto unwrappedIO = unwrapIO(s, bb, ports);
992 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
993 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
997class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
999 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1002 auto unwrappedIO = unwrapIO(s, bb, ports);
1005 HandshakeWire wire(bb, s.b.getNoneType());
1007 OutputHandshake output = wire.getAsOutput();
1008 buildJoinLogic(s, unwrappedIO.inputs, output);
1010 InputHandshake input = wire.getAsInput();
1018 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1022 for (
auto &&[in, out] :
llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1027class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1029 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1032 auto unwrappedIO = unwrapIO(s, bb, ports);
1035 auto select = unwrappedIO.inputs[0];
1036 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1037 buildMuxLogic(s, unwrappedIO, select);
1041class InstanceConversionPattern
1042 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1044 using HandshakeConversionPattern<
1045 handshake::InstanceOp>::HandshakeConversionPattern;
1049 "If we indeed perform conversion in post-order, this "
1050 "should never be called. The base HandshakeConversionPattern logic "
1051 "will instantiate the external module.");
1055class ESIInstanceConversionPattern
1058 ESIInstanceConversionPattern(MLIRContext *
context,
1063 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter)
const override {
1070 SmallVector<Value> operands;
1071 for (
size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1073 operands.push_back(adaptor.getOperands()[i]);
1074 operands.push_back(adaptor.getClk());
1075 operands.push_back(adaptor.getRst());
1078 Operation *targetModule = symCache.
getDefinition(op.getModuleAttr());
1080 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1081 op.getInstNameAttr(), operands);
1089class ReturnConversionPattern
1092 using OpConversionPattern::OpConversionPattern;
1094 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter)
const override {
1098 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1099 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1100 outputOp->setOperands(adaptor.getOperands());
1101 outputOp->moveAfter(&parent.getBodyBlock()->back());
1102 rewriter.eraseOp(op);
1109template <
typename TIn,
typename TOut = TIn>
1110class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1112 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1115 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1116 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1121 return TOut::create(s.b, op.getLoc(), inputs,
1122 ArrayRef<NamedAttribute>{});
1127class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1129 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1132 auto unwrappedIO = unwrapIO(s, bb, ports);
1133 buildUnitRateJoinLogic(s, unwrappedIO,
1134 [&](ValueRange inputs) {
return s.pack(inputs); });
1138class StructCreateConversionPattern
1139 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1141 using HandshakeConversionPattern<
1145 auto unwrappedIO = unwrapIO(s, bb, ports);
1146 auto structType = op.getResult().getType();
1147 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1148 return s.pack(inputs, structType);
1153class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1155 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1158 auto unwrappedIO = unwrapIO(s, bb, ports);
1159 buildUnitRateForkLogic(s, bb, unwrappedIO,
1160 [&](Value input) {
return s.unpack(input); });
1164class ConditionalBranchConversionPattern
1165 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1167 using HandshakeConversionPattern<
1168 ConditionalBranchOp>::HandshakeConversionPattern;
1171 auto unwrappedIO = unwrapIO(s, bb, ports);
1172 auto cond = unwrappedIO.inputs[0];
1173 auto arg = unwrappedIO.inputs[1];
1174 auto trueRes = unwrappedIO.outputs[0];
1175 auto falseRes = unwrappedIO.outputs[1];
1177 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1180 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1181 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1184 trueRes.data->setValue(arg.data);
1185 falseRes.data->setValue(arg.data);
1188 auto selectedResultReady =
1189 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1190 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1191 arg.ready->setValue(condArgReady);
1192 cond.ready->setValue(condArgReady);
1196template <
typename TIn,
bool signExtend>
1197class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1199 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1202 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1203 this->buildExtendLogic(s, unwrappedIO, signExtend);
1207class ComparisonConversionPattern
1208 :
public HandshakeConversionPattern<arith::CmpIOp> {
1210 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1213 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1214 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1215 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1216 return comb::ICmpOp::create(s.b, op.getLoc(), predicate, inputs[0],
1221 switch (op.getPredicate()) {
1222 case arith::CmpIPredicate::eq:
1223 return buildCompareLogic(comb::ICmpPredicate::eq);
1224 case arith::CmpIPredicate::ne:
1225 return buildCompareLogic(comb::ICmpPredicate::ne);
1226 case arith::CmpIPredicate::slt:
1227 return buildCompareLogic(comb::ICmpPredicate::slt);
1228 case arith::CmpIPredicate::ult:
1229 return buildCompareLogic(comb::ICmpPredicate::ult);
1230 case arith::CmpIPredicate::sle:
1231 return buildCompareLogic(comb::ICmpPredicate::sle);
1232 case arith::CmpIPredicate::ule:
1233 return buildCompareLogic(comb::ICmpPredicate::ule);
1234 case arith::CmpIPredicate::sgt:
1235 return buildCompareLogic(comb::ICmpPredicate::sgt);
1236 case arith::CmpIPredicate::ugt:
1237 return buildCompareLogic(comb::ICmpPredicate::ugt);
1238 case arith::CmpIPredicate::sge:
1239 return buildCompareLogic(comb::ICmpPredicate::sge);
1240 case arith::CmpIPredicate::uge:
1241 return buildCompareLogic(comb::ICmpPredicate::uge);
1243 assert(
false &&
"invalid CmpIOp");
1247class TruncateConversionPattern
1248 :
public HandshakeConversionPattern<arith::TruncIOp> {
1250 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1253 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1254 unsigned targetBits =
1255 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1256 buildTruncateLogic(s, unwrappedIO, targetBits);
1260class ControlMergeConversionPattern
1261 :
public HandshakeConversionPattern<ControlMergeOp> {
1263 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1266 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1267 auto resData = unwrappedIO.outputs[0];
1268 auto resIndex = unwrappedIO.outputs[1];
1271 unsigned numInputs = unwrappedIO.inputs.size();
1272 auto indexType = s.b.getIntegerType(numInputs);
1273 Value noWinner = s.constant(numInputs, 0);
1274 Value c0I1 = s.constant(1, 0);
1277 auto won = bb.
get(indexType);
1278 Value wonReg = s.reg(
"won_reg", won, noWinner);
1281 auto win = bb.
get(indexType);
1285 auto fired = bb.
get(s.b.getI1Type());
1288 auto resultEmitted = bb.
get(s.b.getI1Type());
1289 Value resultEmittedReg = s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1290 auto indexEmitted = bb.
get(s.b.getI1Type());
1291 Value indexEmittedReg = s.reg(
"index_emitted_reg", indexEmitted, c0I1);
1294 auto resultDone = bb.
get(s.b.getI1Type());
1295 auto indexDone = bb.
get(s.b.getI1Type());
1299 auto hasWinnerCondition = s.rOr({win});
1300 auto hadWinnerCondition = s.rOr({wonReg});
1308 DenseMap<size_t, Value> argIndexValues;
1309 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1310 noWinner, argIndexValues);
1311 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1312 win.setValue(priorityArb);
1322 auto resultNotEmitted = s.bNot(resultEmittedReg);
1323 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1324 resData.valid->setValue(resultValid);
1325 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1327 auto indexNotEmitted = s.bNot(indexEmittedReg);
1328 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1329 resIndex.valid->setValue(indexValid);
1333 SmallVector<Value, 8> indexOutputs;
1334 for (
size_t i = 0; i < numInputs; ++i)
1335 indexOutputs.push_back(s.constant(64, i));
1337 auto indexOutput = s.ohMux(win, indexOutputs);
1338 resIndex.data->setValue(indexOutput);
1344 won.setValue(s.mux(fired, {win, noWinner}));
1349 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1350 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1352 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1353 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1357 fired.setValue(s.bAnd({resultDone, indexDone}));
1363 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1364 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1369 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1370 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1371 auto &indexValue = argIndexValues[i];
1372 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Lowers handshake.merge. Each cycle a priority arbiter
// (buildPriorityArbiter over the input valid signals) selects one valid
// input. That input's data is forwarded to the single output, and only the
// selected input is acknowledged, and only when the output transfer happens.
// NOTE(review): this listing is a lossy extraction. Original line numbers are
// fused into the text and several lines are missing (e.g. 1378, the
// buildModule signature at 1380-1381, comments, closing braces). Regenerate
// from the real source before editing the logic.
1377class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1379 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1382 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1383 auto resData = unwrappedIO.outputs[0];
// One bit per input. `win` is used below as the select of a one-hot mux
// (ohMux); `noWinner` is the all-zero value of the same width.
1386 unsigned numInputs = unwrappedIO.inputs.size();
1387 auto indexType = s.b.getIntegerType(numInputs);
1388 Value noWinner = s.constant(numInputs, 0);
1391 auto win = bb.
get(indexType);
// Any bit set in `win` means some input won arbitration this cycle.
1394 auto hasWinnerCondition = s.rOr(win);
// Filled in by buildPriorityArbiter (presumably with the value `win` takes
// when input i is selected). It is read below to drive each input's ready.
1401 DenseMap<size_t, Value> argIndexValues;
1402 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1403 noWinner, argIndexValues);
1404 win.setValue(priorityArb);
// The output is valid whenever some input won; its data is the winner's data.
1411 resData.valid->setValue(hasWinnerCondition);
1412 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
// Acknowledge the winning input only once the output handshake completes.
// Input i is ready when `winnerOrDefault` equals its entry in
// argIndexValues. This assumes RTLBuilder::mux yields the second value when
// its selector is 1, and that no entry equals `noWinner`. TODO confirm.
1417 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1422 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1423 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1424 auto &indexValue = argIndexValues[i];
1425 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Lowers handshake.load. Port layout, as indexed below:
//   inputs  = [address from user, data from memory, control token]
//   outputs = [data to user, address to memory]
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// lines such as the buildModule signature and closing braces are missing.
1430class LoadConversionPattern
1431 :
public HandshakeConversionPattern<handshake::LoadOp> {
1433 using HandshakeConversionPattern<
1434 handshake::LoadOp>::HandshakeConversionPattern;
1437 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1438 auto addrFromUser = unwrappedIO.inputs[0];
1439 auto dataFromMem = unwrappedIO.inputs[1];
1440 auto controlIn = unwrappedIO.inputs[2];
1441 auto dataToUser = unwrappedIO.outputs[0];
1442 auto addrToMem = unwrappedIO.outputs[1];
// Payloads are plain wires: address user->memory, data memory->user.
1444 addrToMem.data->setValue(addrFromUser.data);
1445 dataToUser.data->setValue(dataFromMem.data);
// The address is issued to memory only once both the user address and the
// control token are available (a join of the two handshakes).
1449 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
// The data handshake coming back from memory is passed through unchanged.
1453 dataToUser.valid->setValue(dataFromMem.valid);
1454 dataFromMem.ready->setValue(dataToUser.ready);
// Lowers handshake.store. Port layout, as indexed below:
//   inputs  = [address from user, data from user, control token]
//   outputs = [data to memory, address to memory]
// All three inputs are joined. The joined token drives the valid signal of
// both memory-side outputs together.
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// lines such as the buildModule signature and closing braces are missing.
1458class StoreConversionPattern
1459 :
public HandshakeConversionPattern<handshake::StoreOp> {
1461 using HandshakeConversionPattern<
1462 handshake::StoreOp>::HandshakeConversionPattern;
1465 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1466 auto addrFromUser = unwrappedIO.inputs[0];
1467 auto dataFromUser = unwrappedIO.inputs[1];
1468 auto controlIn = unwrappedIO.inputs[2];
1469 auto dataToMem = unwrappedIO.outputs[0];
1470 auto addrToMem = unwrappedIO.outputs[1];
// The join may only complete when both memory-side outputs can accept.
1473 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
// Internal none-typed handshake that carries the joined valid/ready pair.
1477 HandshakeWire joinWire(bb, s.b.getNoneType());
1478 joinWire.ready->setValue(outputsReady);
1479 OutputHandshake joinOutput = joinWire.getAsOutput();
1480 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
// Payloads are forwarded unconditionally; validity comes from the join.
1483 addrToMem.data->setValue(addrFromUser.data);
1484 dataToMem.data->setValue(dataFromUser.data);
1487 addrToMem.valid->setValue(*joinWire.valid);
1488 dataToMem.valid->setValue(*joinWire.valid);
// Lowers handshake.memory to a seq.hlmem, with one read port per load port
// and one write port per store port.
// Operand/result layout, as indexed below:
//   inputs  = stCount (store data, store address) pairs, then ldCount load
//             addresses
//   outputs = ldCount load data, then stCount store-done tokens, then
//             ldCount load-done tokens
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// several lines are missing: the buildModule signature, the LoadPort/StorePort
// struct headers, the trailing arguments of the ReadPortOp/WritePortOp calls,
// and closing braces. Regenerate from the real source before editing.
1492class MemoryConversionPattern
1493 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1495 using HandshakeConversionPattern<
1496 handshake::MemoryOp>::HandshakeConversionPattern;
1499 auto loc = op.getLoc();
1502 auto unwrappedIO = this->unwrapIO(s, bb, ports);
// Load port handshakes: address in, data out, done token out.
1504 InputHandshake &
addr;
1505 OutputHandshake &
data;
1506 OutputHandshake &done;
// Store port handshakes: address in, data in, done token out.
1509 InputHandshake &
addr;
1510 InputHandshake &
data;
1511 OutputHandshake &done;
1513 SmallVector<LoadPort, 4> loadPorts;
1514 SmallVector<StorePort, 4> storePorts;
1516 unsigned stCount = op.getStCount();
1517 unsigned ldCount = op.getLdCount();
1518 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1519 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1520 unwrappedIO.outputs[i],
1521 unwrappedIO.outputs[ldCount + stCount + i]};
1522 loadPorts.push_back(port);
// For each store pair the address is the odd operand and the data the even one.
1525 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1526 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1527 unwrappedIO.inputs[i * 2],
1528 unwrappedIO.outputs[ldCount + i]};
1529 storePorts.push_back(port);
// Zero-width constant used as the payload of the done tokens.
1533 auto c0I0 = s.constant(0, 0);
// Address width is ceil(log2(first memref dimension)); addresses are
// truncated to it.
// NOTE(review): only getShape()[0] is used and each port gets a single
// address. This presumably assumes a 1-D memref. For a size-1 memory,
// cl2dim is 0; confirm that a 0-bit address is acceptable to seq.hlmem.
1535 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1536 auto hlmem = seq::HLMemOp::create(
1537 s.b, loc, s.clk, s.rst,
1538 "_handshake_memory_" + std::to_string(op.getId()),
1539 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
// Load ports. The address valid is used as the read enable, and the address
// handshake is forked into the data and done outputs. The read latency
// argument is on a line missing from this listing.
1542 for (
auto &ld : loadPorts) {
1543 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1544 auto readData = seq::ReadPortOp::create(s.b, loc, hlmem.getHandle(),
1545 addresses, ld.addr.valid,
1547 ld.data.data->setValue(readData);
1548 ld.done.data->setValue(c0I0);
1550 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
// Store ports. `writeValidBuffer` is a 1-bit register (reset to 0) that
// drives the valid signal of the done token.
1554 for (
auto &st : storePorts) {
1557 auto writeValidBufferMuxBE = bb.
get(s.b.getI1Type());
1558 auto writeValidBuffer =
1559 s.reg(
"writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1560 st.done.valid->setValue(writeValidBuffer);
1561 st.done.data->setValue(c0I0);
// The pending done token is consumed when it is valid and its consumer is ready.
1565 auto storeCompleted =
1566 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
// Address and data are accepted when the buffer is empty or being drained.
1570 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1571 auto emptyOrComplete =
1572 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1575 st.addr.ready->setValue(emptyOrComplete);
1576 st.data.ready->setValue(emptyOrComplete);
// The write enable is asserted whenever both address and data are valid.
// NOTE(review): writeValid is not gated by emptyOrComplete. While the done
// token is stalled, the still-valid address/data is therefore written again
// every cycle. That is idempotent, but confirm it is intended.
1579 auto writeValid = s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
// Buffer update: load writeValid when empty/draining, otherwise hold. This
// assumes RTLBuilder::mux yields the second value when its selector is 1.
1585 writeValidBufferMuxBE.setValue(
1586 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1590 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1591 seq::WritePortOp::create(s.b, loc, hlmem.getHandle(), addresses,
1592 st.data.data, writeValid,
// Lowers handshake.sink. Incoming tokens are discarded, so the single input
// is tied permanently ready (constant 1'b1).
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// the buildModule signature and closing braces are missing.
1598class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1600 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1603 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1605 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
// Lowers handshake.source. The single output is permanently valid and
// carries a zero-width constant payload (s.constant(0, 0)).
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// the buildModule signature and closing braces are missing.
1609class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1611 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1614 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1616 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1617 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
// Lowers handshake.constant. The control input's valid/ready pair is passed
// straight through to the output, and the output payload is the op's
// "value" attribute materialized as a hardware constant.
// NOTE(review): getAttrOfType<IntegerAttr> returns a null attribute for a
// non-integer "value" (e.g. a float constant), and .getValue() would then
// crash. Confirm that such constants are rejected before this pattern runs.
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// the buildModule signature and closing braces are missing.
1621class ConstantConversionPattern
1622 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1624 using HandshakeConversionPattern<
1625 handshake::ConstantOp>::HandshakeConversionPattern;
1628 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1629 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1630 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1631 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1632 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
// Lowers handshake.buffer to a chain of op.getNumSlots() register stages
// (SeqBufferStage). The first initValues.size() stages start out holding the
// corresponding initial value.
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// many lines are missing: the buildModule signature, member declarations of
// SeqBufferStage (c0s, dataType, s, bb, index), several assignments, and
// closing braces. Regenerate from the real source before editing.
1636class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1638 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1641 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1642 auto input = unwrappedIO.inputs[0];
1643 auto output = unwrappedIO.outputs[0];
1644 InputHandshake lastStage;
1645 SmallVector<int64_t> initValues;
1648 if (op.getInitValues())
1649 initValues = op.getInitValueArray();
// NOTE(review): the result of buildSeqBufferLogic is presumably assigned to
// `lastStage` on a line missing from this listing (1650-1651). As shown,
// `lastStage` is read below without ever being assigned.
1652 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1653 op.getNumSlots(), input, output, initValues);
// The last stage of the chain drives the buffer's output handshake.
1656 output.data->setValue(lastStage.data);
1657 output.valid->setValue(lastStage.valid);
1658 lastStage.ready->setValue(output.ready);
// One buffer slot. It drives `preStage.ready`, and through getOutput() it
// exposes a new handshake for the next stage whose ready is a fresh backedge.
1661 struct SeqBufferStage {
1662 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1663 RTLBuilder &s,
size_t index,
1664 std::optional<int64_t> initValue)
1665 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1668 c0s = createZeroDataConst(s, s.loc, dataType);
1669 currentStage.ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
// The valid register resets to 1 exactly when this stage has an initial value.
1671 auto hasInitValue = s.constant(1, initValue.has_value());
1672 auto validBE = bb.
get(s.b.getI1Type());
1673 auto validReg = s.reg(getRegName(
"valid"), validBE, hasInitValue);
1674 auto readyBE = bb.
get(s.b.getI1Type());
// The data register resets to the initial value if given, otherwise to zero.
1676 Value initValueCs = c0s;
1677 if (initValue.has_value())
1678 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
// NOTE(review): `dataReg` (used on 1684) is presumably bound to this call's
// result on a line missing from this listing.
1683 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1684 buildControlBufferLogic(validReg, readyBE, dataReg);
// Register names are "<name><stage index>_reg", e.g. "valid0_reg".
1687 StringAttr getRegName(StringRef name) {
1688 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
// Output-side logic. A second register pair ("ready" flag + "ctrl_data")
// holds a value when the successor is not ready. While the flag is set, the
// stage presents ctrl_data as its output with valid = 1. Mux reading
// assumes RTLBuilder::mux yields the second value when its selector is 1.
1691 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1693 auto c0I1 = s.constant(1, 0);
1694 auto readyRegWire = bb.
get(s.b.getI1Type());
1695 auto readyReg = s.reg(getRegName(
"ready"), readyRegWire, c0I1);
1699 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1700 "controlValid" + std::to_string(index));
// NOTE(review): `readyBE` is presumably driven from `notReadyReg` on a line
// missing here (1704).
1703 auto notReadyReg = s.bNot(readyReg);
// Flag update: capture validReg when neither the successor is ready nor the
// flag is set; clear it when both the successor is ready and the flag is
// set; otherwise hold.
1706 auto succNotReady = s.bNot(*currentStage.ready);
1707 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1708 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1709 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1712 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1713 readyRegWire.setValue(resetSignal);
// ctrl_data follows the same capture/clear/hold pattern as the flag.
1716 auto ctrlDataRegBE = bb.
get(dataType);
1717 auto ctrlDataReg = s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
1718 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1719 currentStage.data = dataResult;
1721 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1722 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1723 ctrlDataRegBE.setValue(dataResetSignal);
// Input-side logic. The stage accepts from its predecessor when it is empty
// or its output side is ready. On acceptance it loads the predecessor's
// valid and data; otherwise it holds them.
1726 Value buildDataBufferLogic(Value validReg, Value initValue,
1730 auto notValidReg = s.bNot(validReg);
1731 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1732 preStage.ready->setValue(emptyOrReady);
// NOTE(review): `validBE` is presumably set to `validRegMux` on a line
// missing from this listing.
1738 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1744 auto dataRegBE = bb.
get(dataType);
1746 s.reg(getRegName(
"data"),
1747 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1748 dataRegBE.setValue(dataReg);
1752 InputHandshake getOutput() {
return currentStage; }
1755 InputHandshake &preStage;
1756 InputHandshake currentStage;
// Chains `size` stages starting from `input`. Stage i gets initValues[i]
// when i < initValues.size(), and no initial value otherwise. The visible
// lines return the last stage's handshake.
1765 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1766 Type dataType,
unsigned size,
1767 InputHandshake &input,
1768 OutputHandshake &output,
1769 llvm::ArrayRef<int64_t> initValues)
const {
1772 InputHandshake currentStage = input;
1774 for (
unsigned i = 0; i < size; ++i) {
1775 bool isInitialized = i < initValues.size();
1777 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1778 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1782 return currentStage;
// Lowers arith.index_cast by comparing the converted bit widths of source
// and result. A narrower result is truncated; otherwise the value is
// extended with `true` as the extension flag. The same flag is passed for
// arith.extsi in the pattern list (ExtendConversionPattern<arith::ExtSIOp,
// true>), i.e. a sign extension. The else-branch keyword (line 1800) is
// missing from this lossy listing, as are the buildModule signature and
// closing braces.
1786class IndexCastConversionPattern
1787 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1789 using HandshakeConversionPattern<
1790 arith::IndexCastOp>::HandshakeConversionPattern;
1793 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1794 unsigned sourceBits =
1795 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1796 unsigned targetBits =
1797 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1798 if (targetBits < sourceBits)
1799 buildTruncateLogic(s, unwrappedIO, targetBits);
1801 buildExtendLogic(s, unwrappedIO,
true);
// Generic pattern that replaces an operation of type T with an hw.instance
// of an external HW module. A module declaration is created through
// `submoduleBuilder` (presumably only when no matching module exists yet;
// the lookup lines 1818-1821 are missing from this lossy listing).
// Instance operands are the adaptor's operands, extended by
// addSequentialIOOperandsIfNeeded. Instance names come from
// ls.nameUniquer.
1805template <
typename T>
1808 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1809 MLIRContext *
context, OpBuilder &submoduleBuilder,
1810 HandshakeLoweringState &ls)
1812 submoduleBuilder(submoduleBuilder), ls(ls) {}
1813 using OpAdaptor =
typename T::Adaptor;
1816 matchAndRewrite(T op, OpAdaptor adaptor,
1817 ConversionPatternRewriter &rewriter)
const override {
1822 implModule = hw::HWModuleExternOp::create(
1823 submoduleBuilder, op.getLoc(),
1827 llvm::SmallVector<Value> operands = adaptor.getOperands();
// Presumably appends clock/reset operands for ops that need them.
1828 addSequentialIOOperandsIfNeeded(op, operands);
1829 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1830 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
// Builder positioned where new submodule declarations are inserted.
1835 OpBuilder &submoduleBuilder;
// Shared lowering state (parent module, instance-name uniquer).
1836 HandshakeLoweringState &ls;
// Converts a handshake.func (this appears to be FuncOpConversionPattern; its
// class header is missing from this lossy listing).
// - An external function becomes an hw.module.extern with the same name.
// - Otherwise an hw.module is created and the function body is inlined
//   before its terminator. The last two module block arguments are excluded
//   (presumably the clock and reset ports; confirm).
// If a predeclaration module is referenced, all of its symbol uses are
// redirected to the new module and it is erased. The original op is then
// erased.
1841 using OpConversionPattern::OpConversionPattern;
1845 ConversionPatternRewriter &rewriter)
const override {
1849 HWModuleLike hwModule;
1850 if (op.isExternal()) {
1851 hwModule = hw::HWModuleExternOp::create(
1852 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1854 auto hwModuleOp = hw::HWModuleOp::create(
1855 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1856 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1857 rewriter.inlineBlockBefore(&op.getBody().front(),
1858 hwModuleOp.getBodyBlock()->getTerminator(),
1860 hwModule = hwModuleOp;
// Replace a predeclared module (if any) with the module created above.
1867 auto *parentOp = op->getParentOp();
1868 auto *predeclModule =
1869 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1870 if (predeclModule) {
1871 if (failed(SymbolTable::replaceAllSymbolUses(
1872 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1874 rewriter.eraseOp(predeclModule);
1878 rewriter.eraseOp(op);
// Lowers a single handshake.func to HW. Handshake/arith operations in its
// body are replaced using the pattern set registered below. Instance names
// come from `instanceUniquer`: operations with a "handshake_id" attribute
// get an "_id<N>" suffix; others presumably get a per-name running counter
// (the `else` is missing from this listing). On failure of the partial
// conversion, an error is emitted on `op` and failure is returned.
// NOTE(review): this listing is a lossy extraction. Original line numbers
// are fused into the text, and several lines are missing: the first
// signature line, comments, `else`/closing braces, and the trailing
// `return success();`. Regenerate from the real source before further edits.
1890 ConversionTarget &target,
1892 OpBuilder &moduleBuilder) {
// Per-base-name counters used to uniquify instance names.
1894 std::map<std::string, unsigned> instanceNameCntr;
1895 NameUniquer instanceUniquer = [&](Operation *op) {
1897 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1900 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1903 instName += std::to_string(instanceNameCntr[instName]++);
1908 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1910 RewritePatternSet
patterns(op.getContext());
1911 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1913 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1914 SyncConversionPattern>(typeConverter, op.getContext(),
1919 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1920 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1921 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
// arith.divui is unsigned division and must lower to comb.divu; arith.divsi
// is signed division and must lower to comb.divs. These mappings were
// previously swapped, which changed results whenever an operand had its top
// bit set. The RemUI -> ModU / RemSI -> ModS entries below already follow
// this convention.
1922 UnitRateConversionPattern<arith::DivUIOp, comb::DivUOp>,
1923 UnitRateConversionPattern<arith::DivSIOp, comb::DivSOp>,
1924 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1925 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1926 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1927 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1928 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1929 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1930 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1931 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1932 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1934 StructCreateConversionPattern,
1936 ConditionalBranchConversionPattern, MuxConversionPattern,
1937 PackConversionPattern, UnpackConversionPattern,
1938 ComparisonConversionPattern, BufferConversionPattern,
1939 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1940 MergeConversionPattern, ControlMergeConversionPattern,
1941 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1942 InstanceConversionPattern,
1944 ExtendConversionPattern<arith::ExtUIOp,
false>,
1945 ExtendConversionPattern<arith::ExtSIOp,
true>,
1946 TruncateConversionPattern, IndexCastConversionPattern>(
1947 typeConverter, op.getContext(), moduleBuilder, ls);
1949 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1950 return op->emitOpError() <<
"error during conversion";
// Module-level pass lowering all handshake.func operations to HW.
// Visible steps:
//   1. Every handshake.func must pass the single-use check; otherwise an
//      error asks the user to run fork/sink materialization first.
//   2. Functions are converted in the reverse of `sortedFuncs` (presumably
//      so that instantiated callees are lowered before their callers).
//   3. Every resulting hw.module is post-processed (the call on line 2010 is
//      missing from this lossy listing; presumably convertExtMemoryOps).
//   4. A final partial conversion runs ESIInstanceConversionPattern over the
//      whole module.
// NOTE(review): lossy extraction. Line numbers are fused into the text and
// many lines (conditions, `return`s, legal-op declarations, closing braces)
// are missing. Regenerate from the real source before editing.
1955class HandshakeToHWPass
1956 :
public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1958 void runOnOperation()
override {
1959 mlir::ModuleOp mod = getOperation();
// Step 1: single-use precondition check.
1963 for (
auto f : mod.getOps<
handshake::FuncOp>()) {
1965 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1966 "are used exactly once. Remember to run the "
1967 "fork/sink materialization pass before HW lowering.";
1968 signalPassFailure();
// Instance graph resolution (call missing from this listing) fills
// topLevel and sortedFuncs.
1974 std::string topLevel;
1976 SmallVector<std::string> sortedFuncs;
1978 signalPassFailure();
1982 ESITypeConverter typeConverter;
1983 ConversionTarget target(getContext());
1989 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
// Step 2: new submodules are inserted at the start of the module body.
1995 OpBuilder submoduleBuilder(mod.getContext());
1996 submoduleBuilder.setInsertionPointToStart(mod.getBody());
1997 for (
auto &funcName :
llvm::reverse(sortedFuncs)) {
1999 assert(funcOp &&
"handshake.func not found in module!");
2001 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2002 signalPassFailure();
// Step 3: per-hw.module post-processing.
2009 for (
auto hwModule : mod.getOps<
hw::HWModuleOp>())
2011 return signalPassFailure();
// Step 4: final conversion of ESI instances across the whole module.
2017 RewritePatternSet
patterns(mod.getContext());
2018 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2020 if (failed(applyPartialConversion(mod, target, std::move(
patterns)))) {
2021 mod->emitOpError() <<
"error during conversion";
2022 signalPassFailure();
// createHandshakeToHWPass(): factory returning a new pass instance.
2029 return std::make_unique<HandshakeToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static std::unique_ptr< Context > context
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
void setOutput(unsigned i, Value v)
const ModulePortInfo & getPortList() const
This stores lookup tables to make manipulating and working with the IR more efficient.
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Channels are the basic communication primitives.
const Type * getInner() const
create(elements, Type result_type=None)
create(elements, Type result_type=None)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.
void eraseInput(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.