26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
56struct HandshakeLoweringState {
57 ModuleOp parentModule;
64class ESITypeConverter :
public TypeConverter {
67 addConversion([](Type type) -> Type {
return esiWrapper(type); });
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
74 return UnrealizedConversionCastOp::create(builder, loc, resultType,
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
84 return UnrealizedConversionCastOp::create(builder, loc, resultType,
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
101 return subModuleName;
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
113 auto opType = op.getType();
114 if (
auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
121 SmallVector<Type> filterRes;
122 llvm::copy_if(input, std::back_inserter(filterRes),
123 [](Type type) {
return !isa<NoneType>(type); });
131 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
132 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
134 {memOp.getMemRefType().getElementType()}};
139 std::vector<Type> inTypes, outTypes;
140 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142 llvm::transform(op->getResults(), std::back_inserter(outTypes),
154 std::string typeName;
156 if (type.isIntOrIndex()) {
157 if (
auto indexType = dyn_cast<IndexType>(type))
158 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159 else if (type.isSignedInteger())
160 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
162 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
163 }
else if (
auto tupleType = dyn_cast<TupleType>(type)) {
164 typeName +=
"_tuple";
167 }
else if (
auto structType = dyn_cast<hw::StructType>(type)) {
168 typeName +=
"_struct";
169 for (
auto element : structType.getElements())
170 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
172 emitError(loc) <<
"unsupported data type '" << type <<
"'";
179 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180 return instanceOp.getModule().str();
185 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187 auto intType = intAttr.getType();
189 if (intType.isSignedInteger())
190 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
191 else if (intType.isUnsignedInteger())
192 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
194 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
196 oldOp->emitError(
"unsupported constant type");
201 if (!inTypes.empty())
202 subModuleName +=
"_in";
203 for (
auto inType : inTypes)
204 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
206 if (!outTypes.empty())
207 subModuleName +=
"_out";
208 for (
auto outType : outTypes)
209 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
212 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213 subModuleName +=
"_id" + std::to_string(memOp.getId());
216 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
220 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
222 if (bufferOp.isSequential())
223 subModuleName +=
"_seq";
225 subModuleName +=
"_fifo";
227 if (
auto initValues = bufferOp.getInitValues()) {
228 subModuleName +=
"_init";
229 for (
const Attribute e : *initValues) {
230 assert(isa<IntegerAttr>(e));
232 "_" + std::to_string(dyn_cast<IntegerAttr>(e).
getInt());
238 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239 ctrlInterface && ctrlInterface.isControl()) {
241 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
242 std::to_string(oldOp->getNumResults()) +
"outs";
243 subModuleName +=
"_ctrl";
246 (!inTypes.empty() || !outTypes.empty()) &&
247 "Insufficient discriminating type info generated for the operation!");
250 return subModuleName;
262 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
264 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
271 HWModuleLike targetModule;
272 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
277 if (isa<handshake::InstanceOp>(oldOp))
279 "handshake.instance target modules should always have been lowered "
280 "before the modules that reference them!");
290static llvm::SmallVector<hw::detail::FieldInfo>
292 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293 for (
auto port : portInfo)
294 fieldInfo.push_back({port.name, port.type});
302 auto *ctx = mod.getContext();
305 llvm::DenseMap<unsigned, Value> memrefPorts;
306 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308 if (channel && isa<MemRefType>(channel.getInner()))
309 memrefPorts[i] = arg;
312 if (memrefPorts.empty())
317 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
318 ArrayRef<hw::PortInfo> info,
322 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
326 for (
auto [i, arg] : memrefPorts) {
328 auto memName = mod.getArgName(i);
331 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
333 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334 extmemInstance, extmemInstance.getModuleNameAttr()));
345 SmallVector<PortInfo> outputs(portInfo.
getOutputs());
347 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i, outputs,
348 hw::ModulePort::Direction::Input);
349 mod.insertPorts({{i, inPortInfo}}, {});
350 auto newInPort = mod.getArgumentForInput(i);
352 b.setInsertionPointToStart(mod.getBodyBlock());
353 auto newInPortExploded = hw::StructExplodeOp::create(
354 b, arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
359 unsigned outArgI = mod.getNumOutputPorts();
360 SmallVector<PortInfo> inputs(portInfo.
getInputs());
362 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
363 inputs, hw::ModulePort::Direction::Output);
365 auto memOutputArgs = extmemInstance.getOperands().drop_front();
366 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
368 b, arg.getLoc(), outPortInfo.type, memOutputArgs);
369 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
374 extmemInstance.erase();
378 mod.modifyPorts( {}, {},
389struct InputHandshake {
391 std::shared_ptr<Backedge> ready;
397struct OutputHandshake {
398 std::shared_ptr<Backedge> valid;
400 std::shared_ptr<Backedge>
data;
405struct HandshakeWire {
407 MLIRContext *ctx = dataType.getContext();
408 auto i1Type = IntegerType::get(ctx, 1);
409 valid = std::make_shared<Backedge>(bb.
get(i1Type));
410 ready = std::make_shared<Backedge>(bb.
get(i1Type));
411 data = std::make_shared<Backedge>(bb.
get(dataType));
416 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
417 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
419 std::shared_ptr<Backedge> valid;
420 std::shared_ptr<Backedge> ready;
421 std::shared_ptr<Backedge>
data;
424template <
typename T,
typename TInner>
425llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426 llvm::function_ref<T(TInner &)> extractor) {
427 llvm::SmallVector<T> result;
428 llvm::transform(container, std::back_inserter(result), extractor);
432 llvm::SmallVector<InputHandshake> inputs;
433 llvm::SmallVector<OutputHandshake> outputs;
435 llvm::SmallVector<Value> getInputValids() {
436 return extractValues<Value, InputHandshake>(
437 inputs, [](
auto &hs) {
return hs.valid; });
439 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441 inputs, [](
auto &hs) {
return hs.ready; });
443 llvm::SmallVector<Value> getInputDatas() {
444 return extractValues<Value, InputHandshake>(
445 inputs, [](
auto &hs) {
return hs.data; });
447 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449 outputs, [](
auto &hs) {
return hs.valid; });
451 llvm::SmallVector<Value> getOutputReadys() {
452 return extractValues<Value, OutputHandshake>(
453 outputs, [](
auto &hs) {
return hs.ready; });
455 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457 outputs, [](
auto &hs) {
return hs.data; });
466 Value clk = Value(), Value rst = Value())
467 :
info(std::move(
info)), b(builder), loc(loc),
clk(
clk), rst(rst) {}
469 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
472 bool isZeroWidth = apv.getBitWidth() == 0;
474 auto it = constants.find(apv);
475 if (it != constants.end())
481 constants[apv] = cval;
485 Value constant(
unsigned width, int64_t value,
486 std::optional<StringRef> name = {}) {
488 APInt(width, value,
false,
true));
490 std::pair<Value, Value>
wrap(Value data, Value valid,
491 std::optional<StringRef> name = {}) {
492 auto wrapOp = esi::WrapValidReadyOp::create(b, loc, data, valid);
493 return {wrapOp.getResult(0), wrapOp.getResult(1)};
495 std::pair<Value, Value>
unwrap(Value channel, Value ready,
496 std::optional<StringRef> name = {}) {
497 auto unwrapOp = esi::UnwrapValidReadyOp::create(b, loc, channel, ready);
498 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
502 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503 Value rst = Value()) {
504 Value resolvedClk =
clk ?
clk : this->
clk;
505 Value resolvedRst = rst ? rst : this->rst;
507 "No global clock provided to this RTLBuilder - a clock "
508 "signal must be provided to the reg(...) function.");
510 "No global reset provided to this RTLBuilder - a reset "
511 "signal must be provided to the reg(...) function.");
517 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518 std::optional<StringRef> name = {}) {
519 return comb::ICmpOp::create(b, loc, predicate, lhs, rhs);
522 Value buildNamedOp(llvm::function_ref<Value()> f,
523 std::optional<StringRef> name) {
526 Operation *op = v.getDefiningOp();
527 if (name.has_value()) {
528 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
529 nameAttr = b.getStringAttr(*name);
535 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
537 [&]() {
return comb::AndOp::create(b, loc, values,
false); }, name);
540 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
542 [&]() {
return comb::OrOp::create(b, loc, values,
false); }, name);
546 Value bNot(Value value, std::optional<StringRef> name = {}) {
547 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548 std::string inferedName;
552 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
553 inferedName = (
"not_" +
valueName.getValue()).str();
559 [&]() {
return comb::XorOp::create(b, loc, value, allOnes); }, name);
561 return b.createOrFold<
comb::XorOp>(loc, value, allOnes,
false);
564 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
566 [&]() {
return comb::ShlOp::create(b, loc, value, shift); }, name);
569 Value
concat(ValueRange values, std::optional<StringRef> name = {}) {
571 [&]() {
return comb::ConcatOp::create(b, loc, values); }, name);
575 Value pack(ValueRange values, Type structType = Type(),
576 std::optional<StringRef> name = {}) {
587 ValueRange unpack(Value value) {
588 auto structType = cast<hw::StructType>(value.getType());
589 llvm::SmallVector<Type> innerTypes;
590 structType.getInnerTypes(innerTypes);
591 return hw::StructExplodeOp::create(b, loc, innerTypes, value).getResults();
594 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
595 llvm::SmallVector<Value> bits;
596 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
602 Value rOr(Value v, std::optional<StringRef> name = {}) {
603 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
607 Value extract(Value v,
unsigned lo,
unsigned hi,
608 std::optional<StringRef> name = {}) {
609 unsigned width = hi - lo + 1;
615 Value truncate(Value value,
unsigned width,
616 std::optional<StringRef> name = {}) {
617 return extract(value, 0, width - 1, name);
// Zero-extends `value` to `outWidth` bits by concatenating {zeros, value},
// where the zero constant is (outWidth - inWidth) bits wide. The value must
// not be wider than the target width (checked by the assert below).
620 Value zext(Value value,
unsigned outWidth,
621 std::optional<StringRef> name = {}) {
622 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
// NOTE(review): the assert message reads "<-"; it presumably means "<="
// (the condition checks inWidth <= outWidth). Worth fixing the literal.
623 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
// Equal widths need no extension.
624 if (inWidth == outWidth)
626 auto c0 = constant(outWidth - inWidth, 0);
627 return concat({c0, value}, name);
630 Value sext(Value value,
unsigned outWidth,
631 std::optional<StringRef> name = {}) {
632 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
636 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
637 return extract(v, index, index, name);
641 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
647 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
655 Value mux(Value index, ValueRange values,
656 std::optional<StringRef> name = {}) {
657 if (values.size() == 2)
658 return comb::MuxOp::create(b, loc, index, values[1], values[0]);
660 return arrayGet(arrayCreate(values), index, name);
// Builds a one-hot mux tree over `inputs`. `index` must carry exactly one
// select bit per input (asserted below); the intent is to return the input
// whose select bit is set.
// NOTE(review): the chain is seeded with a zero constant and the loop below
// visits i = numInputs-1 down to 1 only (`i != 0`), so inputs[0] is never
// muxed in. When only bit 0 of `index` is set, the visible code produces the
// zero constant instead of inputs[0]. Confirm whether the code after this
// loop handles input 0; if not, seeding the chain with inputs[0] (or letting
// the loop include i == 0) looks like the intended behavior.
665 Value ohMux(Value index, ValueRange inputs) {
667 unsigned numInputs = inputs.size();
668 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
669 "one-hot select can't mux inputs");
// Seed the chain with a zero of the data width (0 bits for NoneType data).
673 auto dataType = inputs[0].getType();
675 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
676 Value muxValue = constant(width, 0);
// Chain 2-way muxes from the highest select bit downward. The two-value
// mux() helper returns values[1] (here `input`) when the select bit is 1,
// otherwise the running value.
679 for (
size_t i = numInputs - 1; i != 0; --i) {
680 Value input = inputs[i];
681 Value selectBit = bit(index, i);
682 muxValue = mux(selectBit, {muxValue, input});
692 DenseMap<APInt, Value> constants;
697static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
698 return TypeSwitch<Type, Value>(type)
699 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
700 .Case<IntType, IntegerType>([&](
auto type) {
701 return s.constant(type.getIntOrFloatBitWidth(), 0);
703 .Case<hw::StructType>([&](
auto structType) {
704 SmallVector<Value> zeroValues;
705 for (
auto field : structType.getElements())
706 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
709 .Default([&](Type) -> Value {
710 emitError(loc) <<
"unsupported type for zero value: " << type;
717addSequentialIOOperandsIfNeeded(Operation *op,
718 llvm::SmallVectorImpl<Value> &operands) {
722 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
724 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
726 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
733 HandshakeConversionPattern(ESITypeConverter &typeConverter,
734 MLIRContext *context, OpBuilder &submoduleBuilder,
735 HandshakeLoweringState &ls)
737 submoduleBuilder(submoduleBuilder), ls(ls) {}
739 using OpAdaptor =
typename T::Adaptor;
742 matchAndRewrite(T op, OpAdaptor adaptor,
743 ConversionPatternRewriter &rewriter)
const override {
752 submoduleBuilder.setInsertionPoint(op->getParentOp());
753 implModule = hw::HWModuleOp::create(
754 submoduleBuilder, op.getLoc(),
760 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
761 clk = ports.getInput(
"clock");
762 rst = ports.getInput(
"reset");
766 RTLBuilder s(ports.
getPortList(), b, op.getLoc(), clk, rst);
772 llvm::SmallVector<Value> operands = adaptor.getOperands();
773 addSequentialIOOperandsIfNeeded(op, operands);
774 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
775 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
788 UnwrappedIO unwrapped;
789 for (
auto port : ports.getInputs()) {
790 if (!isa<esi::ChannelType>(port.getType()))
793 auto ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
794 auto [
data, valid] = s.unwrap(port, *ready);
798 unwrapped.inputs.push_back(hs);
800 for (
auto &outputInfo : ports.
getPortList().getOutputs()) {
802 dyn_cast<esi::ChannelType>(outputInfo.type);
807 auto data = std::make_shared<Backedge>(bb.
get(innerType));
808 auto valid = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
809 auto [dataCh, ready] = s.wrap(*data, *valid);
813 ports.
setOutput(outputInfo.name, dataCh);
814 unwrapped.outputs.push_back(hs);
819 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
820 OutputHandshake &output, Value cond)
const {
821 auto validAndReady = s.bAnd({output.ready, cond});
822 for (
auto &input : inputs)
823 input.ready->setValue(validAndReady);
826 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
827 OutputHandshake &output)
const {
828 llvm::SmallVector<Value> valids;
829 for (
auto &input : inputs)
830 valids.push_back(input.valid);
831 Value allValid = s.bAnd(valids);
832 output.valid->setValue(allValid);
833 setAllReadyWithCond(s, inputs, output, allValid);
839 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
840 InputHandshake &select)
const {
842 size_t numInputs = unwrapped.inputs.size();
843 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
844 Value truncatedSelect =
845 select.data.getType().getIntOrFloatBitWidth() > selectWidth
846 ? s.truncate(select.data, selectWidth)
850 auto selectZext = s.zext(truncatedSelect, numInputs);
851 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
852 auto &res = unwrapped.outputs[0];
855 auto selectedInputValid =
856 s.mux(truncatedSelect, unwrapped.getInputValids());
858 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
859 res.valid->setValue(selAndInputValid);
860 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
863 select.ready->setValue(resValidAndReady);
866 for (
auto [inIdx, in] :
llvm::enumerate(unwrapped.inputs)) {
868 auto isSelected = s.bit(select1h, inIdx);
872 auto activeAndResultValidAndReady =
873 s.bAnd({isSelected, resValidAndReady});
874 in.ready->setValue(activeAndResultValidAndReady);
878 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
883 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
884 ArrayRef<OutputHandshake> outputs)
const {
885 auto c0I1 = s.constant(1, 0);
886 llvm::SmallVector<Value> doneWires;
887 for (
auto [i, output] :
llvm::enumerate(outputs)) {
888 auto doneBE = bb.
get(s.b.getI1Type());
889 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
890 auto emittedReg = s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
891 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
892 output.valid->setValue(outValid);
893 auto validReady = s.bAnd({output.ready, outValid});
894 auto done = s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
895 doneBE.setValue(done);
896 doneWires.push_back(done);
898 input.ready->setValue(s.bAnd(doneWires,
"allDone"));
904 void buildUnitRateJoinLogic(
905 RTLBuilder &s, UnwrappedIO &unwrappedIO,
906 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
907 assert(unwrappedIO.outputs.size() == 1 &&
908 "Expected exactly one output for unit-rate join actor");
910 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
913 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
914 unwrappedIO.outputs[0].data->setValue(unitRes);
917 void buildUnitRateForkLogic(
919 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
920 assert(unwrappedIO.inputs.size() == 1 &&
921 "Expected exactly one input for unit-rate fork actor");
923 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
926 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
927 assert(unitResults.size() == unwrappedIO.outputs.size() &&
928 "Expected unit builder to return one result per output");
929 for (
auto [res, outport] :
llvm::zip(unitResults, unwrappedIO.outputs))
930 outport.
data->setValue(res);
933 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
934 bool signExtend)
const {
936 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
937 .getIntOrFloatBitWidth();
938 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
940 return s.sext(inputs[0], outWidth);
941 return s.zext(inputs[0], outWidth);
945 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
946 unsigned targetWidth)
const {
948 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
949 .getIntOrFloatBitWidth();
950 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
951 return s.truncate(inputs[0], outWidth);
956 static size_t getNumIndexBits(uint64_t numValues) {
957 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
// Builds a priority arbiter over the `inputs` valid signals. Input i is mapped
// to the one-hot constant (1 << i) of width inputs.size(), and that constant
// is recorded in `indexMapping[i]`. The muxes are chained from the last input
// to the first, so input 0's mux is outermost: the lowest-index asserted input
// wins, and `defaultValue` is produced when no input is asserted.
960 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
962 DenseMap<size_t, Value> &indexMapping)
const {
963 auto numInputs = inputs.size();
964 auto priorityArb = defaultValue;
// Walk inputs from last to first so earlier inputs wrap later ones and
// therefore take priority.
966 for (
size_t i = numInputs; i > 0; --i) {
967 size_t inputIndex = i - 1;
// NOTE(review): `size_t{1} << inputIndex` is undefined once inputIndex
// reaches the bit width of size_t (e.g. more than 64 inputs on a 64-bit
// host). The int64_t overload of constant() also cannot express one-hot
// values wider than 64 bits. Building the value with
// APInt::getOneBitSet(numInputs, inputIndex) and passing it to the APInt
// overload of constant() would avoid both problems.
968 size_t oneHotIndex =
size_t{1} << inputIndex;
969 auto constIndex = s.constant(numInputs, oneHotIndex);
970 indexMapping[inputIndex] = constIndex;
971 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
977 OpBuilder &submoduleBuilder;
978 HandshakeLoweringState &ls;
981class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
983 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
986 auto unwrapped = unwrapIO(s, bb, ports);
987 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
988 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
993class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
995 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
998 auto unwrappedIO = unwrapIO(s, bb, ports);
999 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
1000 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1004class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
1006 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1009 auto unwrappedIO = unwrapIO(s, bb, ports);
1012 HandshakeWire wire(bb, s.b.getNoneType());
1014 OutputHandshake output = wire.getAsOutput();
1015 buildJoinLogic(s, unwrappedIO.inputs, output);
1017 InputHandshake input = wire.getAsInput();
1025 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1029 for (
auto &&[in, out] :
llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1034class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1036 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1039 auto unwrappedIO = unwrapIO(s, bb, ports);
1042 auto select = unwrappedIO.inputs[0];
1043 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1044 buildMuxLogic(s, unwrappedIO, select);
1048class InstanceConversionPattern
1049 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1051 using HandshakeConversionPattern<
1052 handshake::InstanceOp>::HandshakeConversionPattern;
1056 "If we indeed perform conversion in post-order, this "
1057 "should never be called. The base HandshakeConversionPattern logic "
1058 "will instantiate the external module.");
1062class ESIInstanceConversionPattern
1065 ESIInstanceConversionPattern(MLIRContext *context,
1070 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter)
const override {
1077 SmallVector<Value> operands;
1078 for (
size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1080 operands.push_back(adaptor.getOperands()[i]);
1081 operands.push_back(adaptor.getClk());
1082 operands.push_back(adaptor.getRst());
1085 Operation *targetModule = symCache.
getDefinition(op.getModuleAttr());
1087 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1088 op.getInstNameAttr(), operands);
1096class ReturnConversionPattern
1099 using OpConversionPattern::OpConversionPattern;
1101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1102 ConversionPatternRewriter &rewriter)
const override {
1105 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1106 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1107 outputOp->setOperands(adaptor.getOperands());
1108 outputOp->moveAfter(&parent.getBodyBlock()->back());
1109 rewriter.eraseOp(op);
1116template <
typename TIn,
typename TOut = TIn>
1117class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1119 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1122 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1123 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1128 return TOut::create(s.b, op.getLoc(), inputs,
1129 ArrayRef<NamedAttribute>{});
1134class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1136 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1139 auto unwrappedIO = unwrapIO(s, bb, ports);
1140 buildUnitRateJoinLogic(s, unwrappedIO,
1141 [&](ValueRange inputs) {
return s.pack(inputs); });
1145class StructCreateConversionPattern
1146 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1148 using HandshakeConversionPattern<
1152 auto unwrappedIO = unwrapIO(s, bb, ports);
1153 auto structType = op.getResult().getType();
1154 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1155 return s.pack(inputs, structType);
1160class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1162 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1165 auto unwrappedIO = unwrapIO(s, bb, ports);
1166 buildUnitRateForkLogic(s, bb, unwrappedIO,
1167 [&](Value input) {
return s.unpack(input); });
1171class ConditionalBranchConversionPattern
1172 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1174 using HandshakeConversionPattern<
1175 ConditionalBranchOp>::HandshakeConversionPattern;
1178 auto unwrappedIO = unwrapIO(s, bb, ports);
1179 auto cond = unwrappedIO.inputs[0];
1180 auto arg = unwrappedIO.inputs[1];
1181 auto trueRes = unwrappedIO.outputs[0];
1182 auto falseRes = unwrappedIO.outputs[1];
1184 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1187 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1188 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1191 trueRes.data->setValue(arg.data);
1192 falseRes.data->setValue(arg.data);
1195 auto selectedResultReady =
1196 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1197 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1198 arg.ready->setValue(condArgReady);
1199 cond.ready->setValue(condArgReady);
1203template <
typename TIn,
bool signExtend>
1204class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1206 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1209 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1210 this->buildExtendLogic(s, unwrappedIO, signExtend);
1214class ComparisonConversionPattern
1215 :
public HandshakeConversionPattern<arith::CmpIOp> {
1217 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1220 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1221 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1222 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1223 return comb::ICmpOp::create(s.b, op.getLoc(), predicate, inputs[0],
1228 switch (op.getPredicate()) {
1229 case arith::CmpIPredicate::eq:
1230 return buildCompareLogic(comb::ICmpPredicate::eq);
1231 case arith::CmpIPredicate::ne:
1232 return buildCompareLogic(comb::ICmpPredicate::ne);
1233 case arith::CmpIPredicate::slt:
1234 return buildCompareLogic(comb::ICmpPredicate::slt);
1235 case arith::CmpIPredicate::ult:
1236 return buildCompareLogic(comb::ICmpPredicate::ult);
1237 case arith::CmpIPredicate::sle:
1238 return buildCompareLogic(comb::ICmpPredicate::sle);
1239 case arith::CmpIPredicate::ule:
1240 return buildCompareLogic(comb::ICmpPredicate::ule);
1241 case arith::CmpIPredicate::sgt:
1242 return buildCompareLogic(comb::ICmpPredicate::sgt);
1243 case arith::CmpIPredicate::ugt:
1244 return buildCompareLogic(comb::ICmpPredicate::ugt);
1245 case arith::CmpIPredicate::sge:
1246 return buildCompareLogic(comb::ICmpPredicate::sge);
1247 case arith::CmpIPredicate::uge:
1248 return buildCompareLogic(comb::ICmpPredicate::uge);
1250 assert(
false &&
"invalid CmpIOp");
1254class TruncateConversionPattern
1255 :
public HandshakeConversionPattern<arith::TruncIOp> {
1257 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1260 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1261 unsigned targetBits =
1262 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1263 buildTruncateLogic(s, unwrappedIO, targetBits);
1267class ControlMergeConversionPattern
1268 :
public HandshakeConversionPattern<ControlMergeOp> {
1270 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1273 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1274 auto resData = unwrappedIO.outputs[0];
1275 auto resIndex = unwrappedIO.outputs[1];
1278 unsigned numInputs = unwrappedIO.inputs.size();
1279 auto indexType = s.b.getIntegerType(numInputs);
1280 Value noWinner = s.constant(numInputs, 0);
1281 Value c0I1 = s.constant(1, 0);
1284 auto won = bb.
get(indexType);
1285 Value wonReg = s.reg(
"won_reg", won, noWinner);
1288 auto win = bb.
get(indexType);
1292 auto fired = bb.
get(s.b.getI1Type());
1295 auto resultEmitted = bb.
get(s.b.getI1Type());
1296 Value resultEmittedReg = s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1297 auto indexEmitted = bb.
get(s.b.getI1Type());
1298 Value indexEmittedReg = s.reg(
"index_emitted_reg", indexEmitted, c0I1);
1301 auto resultDone = bb.
get(s.b.getI1Type());
1302 auto indexDone = bb.
get(s.b.getI1Type());
1306 auto hasWinnerCondition = s.rOr({win});
1307 auto hadWinnerCondition = s.rOr({wonReg});
1315 DenseMap<size_t, Value> argIndexValues;
1316 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1317 noWinner, argIndexValues);
1318 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1319 win.setValue(priorityArb);
1329 auto resultNotEmitted = s.bNot(resultEmittedReg);
1330 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1331 resData.valid->setValue(resultValid);
1332 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1334 auto indexNotEmitted = s.bNot(indexEmittedReg);
1335 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1336 resIndex.valid->setValue(indexValid);
1340 SmallVector<Value, 8> indexOutputs;
1341 for (
size_t i = 0; i < numInputs; ++i)
1342 indexOutputs.push_back(s.constant(64, i));
1344 auto indexOutput = s.ohMux(win, indexOutputs);
1345 resIndex.data->setValue(indexOutput);
1351 won.setValue(s.mux(fired, {win, noWinner}));
1356 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1357 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1359 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1360 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1364 fired.setValue(s.bAnd({resultDone, indexDone}));
1370 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1371 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1376 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1377 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1378 auto &indexValue = argIndexValues[i];
1379 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1384class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1386 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1389 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1390 auto resData = unwrappedIO.outputs[0];
1393 unsigned numInputs = unwrappedIO.inputs.size();
1394 auto indexType = s.b.getIntegerType(numInputs);
1395 Value noWinner = s.constant(numInputs, 0);
1398 auto win = bb.
get(indexType);
1401 auto hasWinnerCondition = s.rOr(win);
1408 DenseMap<size_t, Value> argIndexValues;
1409 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1410 noWinner, argIndexValues);
1411 win.setValue(priorityArb);
1418 resData.valid->setValue(hasWinnerCondition);
1419 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1424 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1429 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1430 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1431 auto &indexValue = argIndexValues[i];
1432 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1437class LoadConversionPattern
1438 :
public HandshakeConversionPattern<handshake::LoadOp> {
1440 using HandshakeConversionPattern<
1441 handshake::LoadOp>::HandshakeConversionPattern;
1444 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1445 auto addrFromUser = unwrappedIO.inputs[0];
1446 auto dataFromMem = unwrappedIO.inputs[1];
1447 auto controlIn = unwrappedIO.inputs[2];
1448 auto dataToUser = unwrappedIO.outputs[0];
1449 auto addrToMem = unwrappedIO.outputs[1];
1451 addrToMem.data->setValue(addrFromUser.data);
1452 dataToUser.data->setValue(dataFromMem.data);
1456 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1460 dataToUser.valid->setValue(dataFromMem.valid);
1461 dataFromMem.ready->setValue(dataToUser.ready);
1465class StoreConversionPattern
1466 :
public HandshakeConversionPattern<handshake::StoreOp> {
1468 using HandshakeConversionPattern<
1469 handshake::StoreOp>::HandshakeConversionPattern;
1472 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1473 auto addrFromUser = unwrappedIO.inputs[0];
1474 auto dataFromUser = unwrappedIO.inputs[1];
1475 auto controlIn = unwrappedIO.inputs[2];
1476 auto dataToMem = unwrappedIO.outputs[0];
1477 auto addrToMem = unwrappedIO.outputs[1];
1480 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1484 HandshakeWire joinWire(bb, s.b.getNoneType());
1485 joinWire.ready->setValue(outputsReady);
1486 OutputHandshake joinOutput = joinWire.getAsOutput();
1487 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1490 addrToMem.data->setValue(addrFromUser.data);
1491 dataToMem.data->setValue(dataFromUser.data);
1494 addrToMem.valid->setValue(*joinWire.valid);
1495 dataToMem.valid->setValue(*joinWire.valid);
1499class MemoryConversionPattern
1500 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1502 using HandshakeConversionPattern<
1503 handshake::MemoryOp>::HandshakeConversionPattern;
1506 auto loc = op.getLoc();
1509 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1511 InputHandshake &
addr;
1512 OutputHandshake &
data;
1513 OutputHandshake &done;
1516 InputHandshake &
addr;
1517 InputHandshake &
data;
1518 OutputHandshake &done;
1520 SmallVector<LoadPort, 4> loadPorts;
1521 SmallVector<StorePort, 4> storePorts;
1523 unsigned stCount = op.getStCount();
1524 unsigned ldCount = op.getLdCount();
1525 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1526 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1527 unwrappedIO.outputs[i],
1528 unwrappedIO.outputs[ldCount + stCount + i]};
1529 loadPorts.push_back(port);
1532 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1533 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1534 unwrappedIO.inputs[i * 2],
1535 unwrappedIO.outputs[ldCount + i]};
1536 storePorts.push_back(port);
1540 auto c0I0 = s.constant(0, 0);
1542 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1543 auto hlmem = seq::HLMemOp::create(
1544 s.b, loc, s.clk, s.rst,
1545 "_handshake_memory_" + std::to_string(op.getId()),
1546 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1549 for (
auto &ld : loadPorts) {
1550 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1551 auto readData = seq::ReadPortOp::create(s.b, loc, hlmem.getHandle(),
1552 addresses, ld.addr.valid,
1554 ld.data.data->setValue(readData);
1555 ld.done.data->setValue(c0I0);
1557 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1561 for (
auto &st : storePorts) {
1564 auto writeValidBufferMuxBE = bb.
get(s.b.getI1Type());
1565 auto writeValidBuffer =
1566 s.reg(
"writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1567 st.done.valid->setValue(writeValidBuffer);
1568 st.done.data->setValue(c0I0);
1572 auto storeCompleted =
1573 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
1577 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1578 auto emptyOrComplete =
1579 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1582 st.addr.ready->setValue(emptyOrComplete);
1583 st.data.ready->setValue(emptyOrComplete);
1586 auto writeValid = s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
1592 writeValidBufferMuxBE.setValue(
1593 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1597 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1598 seq::WritePortOp::create(s.b, loc, hlmem.getHandle(), addresses,
1599 st.data.data, writeValid,
1605class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1607 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1610 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1612 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1616class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1618 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1621 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1623 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1624 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1628class ConstantConversionPattern
1629 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1631 using HandshakeConversionPattern<
1632 handshake::ConstantOp>::HandshakeConversionPattern;
1635 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1636 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1637 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1638 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1639 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1643class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1645 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1648 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1649 auto input = unwrappedIO.inputs[0];
1650 auto output = unwrappedIO.outputs[0];
1651 InputHandshake lastStage;
1652 SmallVector<int64_t> initValues;
1655 if (op.getInitValues())
1656 initValues = op.getInitValueArray();
1659 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1660 op.getNumSlots(), input, output, initValues);
1663 output.data->setValue(lastStage.data);
1664 output.valid->setValue(lastStage.valid);
1665 lastStage.ready->setValue(output.ready);
1668 struct SeqBufferStage {
1669 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1670 RTLBuilder &s,
size_t index,
1671 std::optional<int64_t> initValue)
1672 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1675 c0s = createZeroDataConst(s, s.loc, dataType);
1676 currentStage.ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
1678 auto hasInitValue = s.constant(1, initValue.has_value());
1679 auto validBE = bb.
get(s.b.getI1Type());
1680 auto validReg = s.reg(getRegName(
"valid"), validBE, hasInitValue);
1681 auto readyBE = bb.
get(s.b.getI1Type());
1683 Value initValueCs = c0s;
1684 if (initValue.has_value())
1685 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1690 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1691 buildControlBufferLogic(validReg, readyBE, dataReg);
1694 StringAttr getRegName(StringRef name) {
1695 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
1698 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1700 auto c0I1 = s.constant(1, 0);
1701 auto readyRegWire = bb.
get(s.b.getI1Type());
1702 auto readyReg = s.reg(getRegName(
"ready"), readyRegWire, c0I1);
1706 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1707 "controlValid" + std::to_string(index));
1710 auto notReadyReg = s.bNot(readyReg);
1713 auto succNotReady = s.bNot(*currentStage.ready);
1714 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1715 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1716 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1719 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1720 readyRegWire.setValue(resetSignal);
1723 auto ctrlDataRegBE = bb.
get(dataType);
1724 auto ctrlDataReg = s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
1725 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1726 currentStage.data = dataResult;
1728 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1729 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1730 ctrlDataRegBE.setValue(dataResetSignal);
1733 Value buildDataBufferLogic(Value validReg, Value initValue,
1737 auto notValidReg = s.bNot(validReg);
1738 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1739 preStage.ready->setValue(emptyOrReady);
1745 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1751 auto dataRegBE = bb.
get(dataType);
1753 s.reg(getRegName(
"data"),
1754 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1755 dataRegBE.setValue(dataReg);
1759 InputHandshake getOutput() {
return currentStage; }
1762 InputHandshake &preStage;
1763 InputHandshake currentStage;
1772 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1773 Type dataType,
unsigned size,
1774 InputHandshake &input,
1775 OutputHandshake &output,
1776 llvm::ArrayRef<int64_t> initValues)
const {
1779 InputHandshake currentStage = input;
1781 for (
unsigned i = 0; i < size; ++i) {
1782 bool isInitialized = i < initValues.size();
1784 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1785 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1789 return currentStage;
1793class IndexCastConversionPattern
1794 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1796 using HandshakeConversionPattern<
1797 arith::IndexCastOp>::HandshakeConversionPattern;
1800 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1801 unsigned sourceBits =
1802 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1803 unsigned targetBits =
1804 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1805 if (targetBits < sourceBits)
1806 buildTruncateLogic(s, unwrappedIO, targetBits);
1808 buildExtendLogic(s, unwrappedIO,
true);
1812template <
typename T>
1815 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1816 MLIRContext *context, OpBuilder &submoduleBuilder,
1817 HandshakeLoweringState &ls)
1819 submoduleBuilder(submoduleBuilder), ls(ls) {}
1820 using OpAdaptor =
typename T::Adaptor;
1823 matchAndRewrite(T op, OpAdaptor adaptor,
1824 ConversionPatternRewriter &rewriter)
const override {
1834 llvm::SmallVector<Value> operands = adaptor.getOperands();
1835 addSequentialIOOperandsIfNeeded(op, operands);
1836 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1837 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1842 OpBuilder &submoduleBuilder;
1843 HandshakeLoweringState &ls;
1848 using OpConversionPattern::OpConversionPattern;
1852 ConversionPatternRewriter &rewriter)
const override {
1856 HWModuleLike hwModule;
1857 if (op.isExternal()) {
1858 hwModule = hw::HWModuleExternOp::create(
1859 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1861 auto hwModuleOp = hw::HWModuleOp::create(
1862 rewriter, op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1863 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1864 rewriter.inlineBlockBefore(&op.getBody().front(),
1865 hwModuleOp.getBodyBlock()->getTerminator(),
1867 hwModule = hwModuleOp;
1874 auto *parentOp = op->getParentOp();
1875 auto *predeclModule =
1876 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1877 if (predeclModule) {
1878 if (failed(SymbolTable::replaceAllSymbolUses(
1879 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1881 rewriter.eraseOp(predeclModule);
1885 rewriter.eraseOp(op);
1897 ConversionTarget &target,
1899 OpBuilder &moduleBuilder) {
1901 std::map<std::string, unsigned> instanceNameCntr;
1902 NameUniquer instanceUniquer = [&](Operation *op) {
1904 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1907 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1910 instName += std::to_string(instanceNameCntr[instName]++);
1915 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1917 RewritePatternSet
patterns(op.getContext());
1918 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1920 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1921 SyncConversionPattern>(typeConverter, op.getContext(),
1926 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1927 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1928 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
// Division must keep its signedness: unsigned arith.divui lowers to
// comb.divu and signed arith.divsi lowers to comb.divs, mirroring the
// RemUI -> ModU / RemSI -> ModS pairs that follow. (Previously swapped.)
1929 UnitRateConversionPattern<arith::DivUIOp, comb::DivUOp>,
1930 UnitRateConversionPattern<arith::DivSIOp, comb::DivSOp>,
1931 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1932 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1933 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1934 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1935 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1936 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1937 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1938 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1939 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1941 StructCreateConversionPattern,
1943 ConditionalBranchConversionPattern, MuxConversionPattern,
1944 PackConversionPattern, UnpackConversionPattern,
1945 ComparisonConversionPattern, BufferConversionPattern,
1946 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1947 MergeConversionPattern, ControlMergeConversionPattern,
1948 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1949 InstanceConversionPattern,
1951 ExtendConversionPattern<arith::ExtUIOp,
false>,
1952 ExtendConversionPattern<arith::ExtSIOp,
true>,
1953 TruncateConversionPattern, IndexCastConversionPattern>(
1954 typeConverter, op.getContext(), moduleBuilder, ls);
1956 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1957 return op->emitOpError() <<
"error during conversion";
1962class HandshakeToHWPass
1963 :
public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1965 void runOnOperation()
override {
1966 mlir::ModuleOp mod = getOperation();
1970 for (
auto f : mod.getOps<
handshake::FuncOp>()) {
1972 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1973 "are used exactly once. Remember to run the "
1974 "fork/sink materialization pass before HW lowering.";
1975 signalPassFailure();
1981 std::string topLevel;
1983 SmallVector<std::string> sortedFuncs;
1985 signalPassFailure();
1989 ESITypeConverter typeConverter;
1990 ConversionTarget target(getContext());
1996 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
2002 OpBuilder submoduleBuilder(mod.getContext());
2003 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2004 for (
auto &funcName :
llvm::reverse(sortedFuncs)) {
2006 assert(funcOp &&
"handshake.func not found in module!");
2008 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2009 signalPassFailure();
2016 for (
auto hwModule : mod.getOps<
hw::HWModuleOp>())
2018 return signalPassFailure();
2024 RewritePatternSet
patterns(mod.getContext());
2025 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2027 if (failed(applyPartialConversion(mod, target, std::move(
patterns)))) {
2028 mod->emitOpError() <<
"error during conversion";
2029 signalPassFailure();
2036 return std::make_unique<HandshakeToHWPass>();
AIGLongestPathObject wrap(llvm::PointerUnion< Object *, DataflowPath::OutputPort * > object)
assert(baseType &&"element must be base type")
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the type of the data-carrying type of opType.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
void setOutput(unsigned i, Value v)
const ModulePortInfo & getPortList() const
This stores lookup tables to make manipulating and working with the IR more efficient.
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Channels are the basic communication primitives.
const Type * getInner() const
create(elements, Type result_type=None)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.
void eraseInput(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.