26#include "mlir/Dialect/Arith/IR/Arith.h"
27#include "mlir/Dialect/MemRef/IR/MemRef.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/MathExtras.h"
37#define GEN_PASS_DEF_HANDSHAKETOHW
38#include "circt/Conversion/Passes.h.inc"
51 return toValidType(mlir::TupleType::get(types[0].getContext(), types));
56struct HandshakeLoweringState {
57 ModuleOp parentModule;
64class ESITypeConverter :
public TypeConverter {
67 addConversion([](Type type) -> Type {
return esiWrapper(type); });
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
75 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
85 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
101 return subModuleName;
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
113 auto opType = op.getType();
114 if (
auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
121 SmallVector<Type> filterRes;
122 llvm::copy_if(input, std::back_inserter(filterRes),
123 [](Type type) {
return !isa<NoneType>(type); });
131 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
132 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
134 {memOp.getMemRefType().getElementType()}};
139 std::vector<Type> inTypes, outTypes;
140 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142 llvm::transform(op->getResults(), std::back_inserter(outTypes),
154 std::string typeName;
156 if (type.isIntOrIndex()) {
157 if (
auto indexType = dyn_cast<IndexType>(type))
158 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159 else if (type.isSignedInteger())
160 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
162 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
163 }
else if (
auto tupleType = dyn_cast<TupleType>(type)) {
164 typeName +=
"_tuple";
167 }
else if (
auto structType = dyn_cast<hw::StructType>(type)) {
168 typeName +=
"_struct";
169 for (
auto element : structType.getElements())
170 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
172 emitError(loc) <<
"unsupported data type '" << type <<
"'";
179 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180 return instanceOp.getModule().str();
185 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187 auto intType = intAttr.getType();
189 if (intType.isSignedInteger())
190 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
191 else if (intType.isUnsignedInteger())
192 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
194 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
196 oldOp->emitError(
"unsupported constant type");
201 if (!inTypes.empty())
202 subModuleName +=
"_in";
203 for (
auto inType : inTypes)
204 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
206 if (!outTypes.empty())
207 subModuleName +=
"_out";
208 for (
auto outType : outTypes)
209 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
212 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213 subModuleName +=
"_id" + std::to_string(memOp.getId());
216 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
220 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
222 if (bufferOp.isSequential())
223 subModuleName +=
"_seq";
225 subModuleName +=
"_fifo";
227 if (
auto initValues = bufferOp.getInitValues()) {
228 subModuleName +=
"_init";
229 for (
const Attribute e : *initValues) {
230 assert(isa<IntegerAttr>(e));
232 "_" + std::to_string(dyn_cast<IntegerAttr>(e).
getInt());
238 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239 ctrlInterface && ctrlInterface.isControl()) {
241 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
242 std::to_string(oldOp->getNumResults()) +
"outs";
243 subModuleName +=
"_ctrl";
246 (!inTypes.empty() || !outTypes.empty()) &&
247 "Insufficient discriminating type info generated for the operation!");
250 return subModuleName;
262 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
264 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
271 HWModuleLike targetModule;
272 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
277 if (isa<handshake::InstanceOp>(oldOp))
279 "handshake.instance target modules should always have been lowered "
280 "before the modules that reference them!");
290static llvm::SmallVector<hw::detail::FieldInfo>
292 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293 for (
auto port : portInfo)
294 fieldInfo.push_back({port.name, port.type});
302 auto *ctx = mod.getContext();
305 llvm::DenseMap<unsigned, Value> memrefPorts;
306 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308 if (channel && isa<MemRefType>(channel.getInner()))
309 memrefPorts[i] = arg;
312 if (memrefPorts.empty())
317 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
318 ArrayRef<hw::PortInfo> info,
322 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
326 for (
auto [i, arg] : memrefPorts) {
328 auto memName = mod.getArgName(i);
331 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
333 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334 extmemInstance, extmemInstance.getModuleNameAttr()));
345 SmallVector<PortInfo> outputs(portInfo.
getOutputs());
347 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i, outputs,
348 hw::ModulePort::Direction::Input);
349 mod.insertPorts({{i, inPortInfo}}, {});
350 auto newInPort = mod.getArgumentForInput(i);
352 b.setInsertionPointToStart(mod.getBodyBlock());
353 auto newInPortExploded = b.create<hw::StructExplodeOp>(
354 arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
359 unsigned outArgI = mod.getNumOutputPorts();
360 SmallVector<PortInfo> inputs(portInfo.
getInputs());
362 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
363 inputs, hw::ModulePort::Direction::Output);
365 auto memOutputArgs = extmemInstance.getOperands().drop_front();
366 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
368 arg.getLoc(), outPortInfo.type, memOutputArgs);
369 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
374 extmemInstance.erase();
378 mod.modifyPorts( {}, {},
389struct InputHandshake {
391 std::shared_ptr<Backedge> ready;
397struct OutputHandshake {
398 std::shared_ptr<Backedge> valid;
400 std::shared_ptr<Backedge>
data;
405struct HandshakeWire {
407 MLIRContext *ctx = dataType.getContext();
408 auto i1Type = IntegerType::get(ctx, 1);
409 valid = std::make_shared<Backedge>(bb.
get(i1Type));
410 ready = std::make_shared<Backedge>(bb.
get(i1Type));
411 data = std::make_shared<Backedge>(bb.
get(dataType));
416 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
417 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
419 std::shared_ptr<Backedge> valid;
420 std::shared_ptr<Backedge> ready;
421 std::shared_ptr<Backedge>
data;
424template <
typename T,
typename TInner>
425llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426 llvm::function_ref<T(TInner &)> extractor) {
427 llvm::SmallVector<T> result;
428 llvm::transform(container, std::back_inserter(result), extractor);
432 llvm::SmallVector<InputHandshake> inputs;
433 llvm::SmallVector<OutputHandshake> outputs;
435 llvm::SmallVector<Value> getInputValids() {
436 return extractValues<Value, InputHandshake>(
437 inputs, [](
auto &hs) {
return hs.valid; });
439 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441 inputs, [](
auto &hs) {
return hs.ready; });
443 llvm::SmallVector<Value> getInputDatas() {
444 return extractValues<Value, InputHandshake>(
445 inputs, [](
auto &hs) {
return hs.data; });
447 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449 outputs, [](
auto &hs) {
return hs.valid; });
451 llvm::SmallVector<Value> getOutputReadys() {
452 return extractValues<Value, OutputHandshake>(
453 outputs, [](
auto &hs) {
return hs.ready; });
455 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457 outputs, [](
auto &hs) {
return hs.data; });
466 Value clk = Value(), Value rst = Value())
467 : info(std::move(info)), b(builder), loc(loc),
clk(
clk), rst(rst) {}
469 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
472 bool isZeroWidth = apv.getBitWidth() == 0;
474 auto it = constants.find(apv);
475 if (it != constants.end())
481 constants[apv] = cval;
485 Value constant(
unsigned width, int64_t value,
486 std::optional<StringRef> name = {}) {
488 APInt(width, value,
false,
true));
490 std::pair<Value, Value>
wrap(Value data, Value valid,
491 std::optional<StringRef> name = {}) {
492 auto wrapOp = b.create<esi::WrapValidReadyOp>(loc,
data, valid);
493 return {wrapOp.getResult(0), wrapOp.getResult(1)};
495 std::pair<Value, Value>
unwrap(Value channel, Value ready,
496 std::optional<StringRef> name = {}) {
497 auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
498 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
502 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503 Value rst = Value()) {
504 Value resolvedClk =
clk ?
clk : this->
clk;
505 Value resolvedRst = rst ? rst : this->rst;
507 "No global clock provided to this RTLBuilder - a clock "
508 "signal must be provided to the reg(...) function.");
510 "No global reset provided to this RTLBuilder - a reset "
511 "signal must be provided to the reg(...) function.");
513 return b.create<
seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
517 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518 std::optional<StringRef> name = {}) {
519 return b.
create<comb::ICmpOp>(loc, predicate, lhs, rhs);
522 Value buildNamedOp(llvm::function_ref<Value()> f,
523 std::optional<StringRef> name) {
526 Operation *op = v.getDefiningOp();
527 if (name.has_value()) {
528 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
529 nameAttr = b.getStringAttr(*name);
535 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
537 [&]() {
return b.create<
comb::AndOp>(loc, values,
false); }, name);
540 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
542 [&]() {
return b.create<
comb::OrOp>(loc, values,
false); }, name);
546 Value bNot(Value value, std::optional<StringRef> name = {}) {
547 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548 std::string inferedName;
552 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
553 inferedName = (
"not_" +
valueName.getValue()).str();
559 [&]() {
return b.create<
comb::XorOp>(loc, value, allOnes); }, name);
561 return b.createOrFold<
comb::XorOp>(loc, value, allOnes,
false);
564 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
566 [&]() {
return b.create<
comb::ShlOp>(loc, value, shift); }, name);
569 Value
concat(ValueRange values, std::optional<StringRef> name = {}) {
570 return buildNamedOp([&]() {
return b.create<
comb::ConcatOp>(loc, values); },
575 Value pack(ValueRange values, Type structType = Type(),
576 std::optional<StringRef> name = {}) {
585 ValueRange unpack(Value value) {
586 auto structType = cast<hw::StructType>(value.getType());
587 llvm::SmallVector<Type> innerTypes;
588 structType.getInnerTypes(innerTypes);
589 return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
592 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
593 llvm::SmallVector<Value> bits;
594 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
600 Value rOr(Value v, std::optional<StringRef> name = {}) {
601 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
605 Value extract(Value v,
unsigned lo,
unsigned hi,
606 std::optional<StringRef> name = {}) {
607 unsigned width = hi - lo + 1;
613 Value truncate(Value value,
unsigned width,
614 std::optional<StringRef> name = {}) {
615 return extract(value, 0, width - 1, name);
618 Value zext(Value value,
unsigned outWidth,
619 std::optional<StringRef> name = {}) {
620 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
621 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
622 if (inWidth == outWidth)
624 auto c0 = constant(outWidth - inWidth, 0);
625 return concat({c0, value}, name);
628 Value sext(Value value,
unsigned outWidth,
629 std::optional<StringRef> name = {}) {
630 return comb::createOrFoldSExt(loc, value, b.getIntegerType(outWidth), b);
634 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
635 return extract(v, index, index, name);
639 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
645 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
647 [&]() {
return b.create<
hw::ArrayGetOp>(loc, array, index); }, name);
653 Value mux(Value index, ValueRange values,
654 std::optional<StringRef> name = {}) {
655 if (values.size() == 2)
656 return b.create<
comb::MuxOp>(loc, index, values[1], values[0]);
658 return arrayGet(arrayCreate(values), index, name);
663 Value ohMux(Value index, ValueRange inputs) {
665 unsigned numInputs = inputs.size();
666 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
667 "one-hot select can't mux inputs");
671 auto dataType = inputs[0].getType();
673 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
674 Value muxValue = constant(width, 0);
677 for (
size_t i = numInputs - 1; i != 0; --i) {
678 Value input = inputs[i];
679 Value selectBit = bit(index, i);
680 muxValue = mux(selectBit, {muxValue, input});
690 DenseMap<APInt, Value> constants;
695static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
696 return TypeSwitch<Type, Value>(type)
697 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
698 .Case<IntType, IntegerType>([&](
auto type) {
699 return s.constant(type.getIntOrFloatBitWidth(), 0);
701 .Case<hw::StructType>([&](
auto structType) {
702 SmallVector<Value> zeroValues;
703 for (
auto field : structType.getElements())
704 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
707 .Default([&](Type) -> Value {
708 emitError(loc) <<
"unsupported type for zero value: " << type;
715addSequentialIOOperandsIfNeeded(Operation *op,
716 llvm::SmallVectorImpl<Value> &operands) {
720 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
722 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
724 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
731 HandshakeConversionPattern(ESITypeConverter &typeConverter,
732 MLIRContext *context, OpBuilder &submoduleBuilder,
733 HandshakeLoweringState &ls)
735 submoduleBuilder(submoduleBuilder), ls(ls) {}
737 using OpAdaptor =
typename T::Adaptor;
740 matchAndRewrite(T op, OpAdaptor adaptor,
741 ConversionPatternRewriter &rewriter)
const override {
750 submoduleBuilder.setInsertionPoint(op->getParentOp());
757 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
758 clk = ports.getInput(
"clock");
759 rst = ports.getInput(
"reset");
763 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
769 llvm::SmallVector<Value> operands = adaptor.getOperands();
770 addSequentialIOOperandsIfNeeded(op, operands);
771 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
772 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
785 UnwrappedIO unwrapped;
786 for (
auto port : ports.getInputs()) {
787 if (!isa<esi::ChannelType>(port.getType()))
790 auto ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
791 auto [
data, valid] = s.unwrap(port, *ready);
795 unwrapped.inputs.push_back(hs);
797 for (
auto &outputInfo : ports.
getPortList().getOutputs()) {
799 dyn_cast<esi::ChannelType>(outputInfo.type);
804 auto data = std::make_shared<Backedge>(bb.
get(innerType));
805 auto valid = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
806 auto [dataCh, ready] = s.wrap(*data, *valid);
810 ports.
setOutput(outputInfo.name, dataCh);
811 unwrapped.outputs.push_back(hs);
816 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
817 OutputHandshake &output, Value cond)
const {
818 auto validAndReady = s.bAnd({output.ready, cond});
819 for (
auto &input : inputs)
820 input.ready->setValue(validAndReady);
823 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
824 OutputHandshake &output)
const {
825 llvm::SmallVector<Value> valids;
826 for (
auto &input : inputs)
827 valids.push_back(input.valid);
828 Value allValid = s.bAnd(valids);
829 output.valid->setValue(allValid);
830 setAllReadyWithCond(s, inputs, output, allValid);
836 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
837 InputHandshake &select)
const {
839 size_t numInputs = unwrapped.inputs.size();
840 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
841 Value truncatedSelect =
842 select.data.getType().getIntOrFloatBitWidth() > selectWidth
843 ? s.truncate(select.data, selectWidth)
847 auto selectZext = s.zext(truncatedSelect, numInputs);
848 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
849 auto &res = unwrapped.outputs[0];
852 auto selectedInputValid =
853 s.mux(truncatedSelect, unwrapped.getInputValids());
855 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
856 res.valid->setValue(selAndInputValid);
857 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
860 select.ready->setValue(resValidAndReady);
863 for (
auto [inIdx, in] :
llvm::enumerate(unwrapped.inputs)) {
865 auto isSelected = s.bit(select1h, inIdx);
869 auto activeAndResultValidAndReady =
870 s.bAnd({isSelected, resValidAndReady});
871 in.ready->setValue(activeAndResultValidAndReady);
875 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
880 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
881 ArrayRef<OutputHandshake> outputs)
const {
882 auto c0I1 = s.constant(1, 0);
883 llvm::SmallVector<Value> doneWires;
884 for (
auto [i, output] :
llvm::enumerate(outputs)) {
885 auto doneBE = bb.
get(s.b.getI1Type());
886 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
887 auto emittedReg = s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
888 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
889 output.valid->setValue(outValid);
890 auto validReady = s.bAnd({output.ready, outValid});
891 auto done = s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
892 doneBE.setValue(done);
893 doneWires.push_back(done);
895 input.ready->setValue(s.bAnd(doneWires,
"allDone"));
901 void buildUnitRateJoinLogic(
902 RTLBuilder &s, UnwrappedIO &unwrappedIO,
903 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
904 assert(unwrappedIO.outputs.size() == 1 &&
905 "Expected exactly one output for unit-rate join actor");
907 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
910 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
911 unwrappedIO.outputs[0].data->setValue(unitRes);
914 void buildUnitRateForkLogic(
916 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
917 assert(unwrappedIO.inputs.size() == 1 &&
918 "Expected exactly one input for unit-rate fork actor");
920 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
923 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
924 assert(unitResults.size() == unwrappedIO.outputs.size() &&
925 "Expected unit builder to return one result per output");
926 for (
auto [res, outport] :
llvm::zip(unitResults, unwrappedIO.outputs))
927 outport.
data->setValue(res);
930 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
931 bool signExtend)
const {
933 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
934 .getIntOrFloatBitWidth();
935 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
937 return s.sext(inputs[0], outWidth);
938 return s.zext(inputs[0], outWidth);
942 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
943 unsigned targetWidth)
const {
945 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
946 .getIntOrFloatBitWidth();
947 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
948 return s.truncate(inputs[0], outWidth);
953 static size_t getNumIndexBits(uint64_t numValues) {
954 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
957 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
959 DenseMap<size_t, Value> &indexMapping)
const {
960 auto numInputs = inputs.size();
961 auto priorityArb = defaultValue;
963 for (
size_t i = numInputs; i > 0; --i) {
964 size_t inputIndex = i - 1;
965 size_t oneHotIndex =
size_t{1} << inputIndex;
966 auto constIndex = s.constant(numInputs, oneHotIndex);
967 indexMapping[inputIndex] = constIndex;
968 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
974 OpBuilder &submoduleBuilder;
975 HandshakeLoweringState &ls;
978class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
980 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
983 auto unwrapped = unwrapIO(s, bb, ports);
984 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
985 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
990class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
992 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
995 auto unwrappedIO = unwrapIO(s, bb, ports);
996 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
997 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1001class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
1003 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1006 auto unwrappedIO = unwrapIO(s, bb, ports);
1009 HandshakeWire wire(bb, s.b.getNoneType());
1011 OutputHandshake output = wire.getAsOutput();
1012 buildJoinLogic(s, unwrappedIO.inputs, output);
1014 InputHandshake input = wire.getAsInput();
1022 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1026 for (
auto &&[in, out] :
llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1031class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1033 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1036 auto unwrappedIO = unwrapIO(s, bb, ports);
1039 auto select = unwrappedIO.inputs[0];
1040 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1041 buildMuxLogic(s, unwrappedIO, select);
1045class InstanceConversionPattern
1046 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1048 using HandshakeConversionPattern<
1049 handshake::InstanceOp>::HandshakeConversionPattern;
1053 "If we indeed perform conversion in post-order, this "
1054 "should never be called. The base HandshakeConversionPattern logic "
1055 "will instantiate the external module.");
1059class ESIInstanceConversionPattern
1062 ESIInstanceConversionPattern(MLIRContext *context,
1067 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1068 ConversionPatternRewriter &rewriter)
const override {
1074 SmallVector<Value> operands;
1075 for (
size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1077 operands.push_back(adaptor.getOperands()[i]);
1078 operands.push_back(adaptor.getClk());
1079 operands.push_back(adaptor.getRst());
1082 Operation *targetModule = symCache.
getDefinition(op.getModuleAttr());
1084 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1085 op.getInstNameAttr(), operands);
1093class ReturnConversionPattern
1096 using OpConversionPattern::OpConversionPattern;
1098 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter)
const override {
1102 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1103 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1104 outputOp->setOperands(adaptor.getOperands());
1105 outputOp->moveAfter(&parent.getBodyBlock()->back());
1106 rewriter.eraseOp(op);
1113template <
typename TIn,
typename TOut = TIn>
1114class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1116 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1119 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1120 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1125 return s.b.create<TOut>(op.getLoc(), inputs,
1126 ArrayRef<NamedAttribute>{});
1131class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1133 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1136 auto unwrappedIO = unwrapIO(s, bb, ports);
1137 buildUnitRateJoinLogic(s, unwrappedIO,
1138 [&](ValueRange inputs) {
return s.pack(inputs); });
1142class StructCreateConversionPattern
1143 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1145 using HandshakeConversionPattern<
1149 auto unwrappedIO = unwrapIO(s, bb, ports);
1150 auto structType = op.getResult().getType();
1151 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1152 return s.pack(inputs, structType);
1157class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1159 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1162 auto unwrappedIO = unwrapIO(s, bb, ports);
1163 buildUnitRateForkLogic(s, bb, unwrappedIO,
1164 [&](Value input) {
return s.unpack(input); });
1168class ConditionalBranchConversionPattern
1169 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1171 using HandshakeConversionPattern<
1172 ConditionalBranchOp>::HandshakeConversionPattern;
1175 auto unwrappedIO = unwrapIO(s, bb, ports);
1176 auto cond = unwrappedIO.inputs[0];
1177 auto arg = unwrappedIO.inputs[1];
1178 auto trueRes = unwrappedIO.outputs[0];
1179 auto falseRes = unwrappedIO.outputs[1];
1181 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1184 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1185 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1188 trueRes.data->setValue(arg.data);
1189 falseRes.data->setValue(arg.data);
1192 auto selectedResultReady =
1193 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1194 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1195 arg.ready->setValue(condArgReady);
1196 cond.ready->setValue(condArgReady);
1200template <
typename TIn,
bool signExtend>
1201class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1203 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1206 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1207 this->buildExtendLogic(s, unwrappedIO, signExtend);
1211class ComparisonConversionPattern
1212 :
public HandshakeConversionPattern<arith::CmpIOp> {
1214 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1217 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1218 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1219 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1220 return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1225 switch (op.getPredicate()) {
1226 case arith::CmpIPredicate::eq:
1227 return buildCompareLogic(comb::ICmpPredicate::eq);
1228 case arith::CmpIPredicate::ne:
1229 return buildCompareLogic(comb::ICmpPredicate::ne);
1230 case arith::CmpIPredicate::slt:
1231 return buildCompareLogic(comb::ICmpPredicate::slt);
1232 case arith::CmpIPredicate::ult:
1233 return buildCompareLogic(comb::ICmpPredicate::ult);
1234 case arith::CmpIPredicate::sle:
1235 return buildCompareLogic(comb::ICmpPredicate::sle);
1236 case arith::CmpIPredicate::ule:
1237 return buildCompareLogic(comb::ICmpPredicate::ule);
1238 case arith::CmpIPredicate::sgt:
1239 return buildCompareLogic(comb::ICmpPredicate::sgt);
1240 case arith::CmpIPredicate::ugt:
1241 return buildCompareLogic(comb::ICmpPredicate::ugt);
1242 case arith::CmpIPredicate::sge:
1243 return buildCompareLogic(comb::ICmpPredicate::sge);
1244 case arith::CmpIPredicate::uge:
1245 return buildCompareLogic(comb::ICmpPredicate::uge);
1247 assert(
false &&
"invalid CmpIOp");
1251class TruncateConversionPattern
1252 :
public HandshakeConversionPattern<arith::TruncIOp> {
1254 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1257 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1258 unsigned targetBits =
1259 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1260 buildTruncateLogic(s, unwrappedIO, targetBits);
1264class ControlMergeConversionPattern
1265 :
public HandshakeConversionPattern<ControlMergeOp> {
1267 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1270 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1271 auto resData = unwrappedIO.outputs[0];
1272 auto resIndex = unwrappedIO.outputs[1];
1275 unsigned numInputs = unwrappedIO.inputs.size();
1276 auto indexType = s.b.getIntegerType(numInputs);
1277 Value noWinner = s.constant(numInputs, 0);
1278 Value c0I1 = s.constant(1, 0);
1281 auto won = bb.
get(indexType);
1282 Value wonReg = s.reg(
"won_reg", won, noWinner);
1285 auto win = bb.
get(indexType);
1289 auto fired = bb.
get(s.b.getI1Type());
1292 auto resultEmitted = bb.
get(s.b.getI1Type());
1293 Value resultEmittedReg = s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1294 auto indexEmitted = bb.
get(s.b.getI1Type());
1295 Value indexEmittedReg = s.reg(
"index_emitted_reg", indexEmitted, c0I1);
1298 auto resultDone = bb.
get(s.b.getI1Type());
1299 auto indexDone = bb.
get(s.b.getI1Type());
1303 auto hasWinnerCondition = s.rOr({win});
1304 auto hadWinnerCondition = s.rOr({wonReg});
1312 DenseMap<size_t, Value> argIndexValues;
1313 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1314 noWinner, argIndexValues);
1315 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1316 win.setValue(priorityArb);
1326 auto resultNotEmitted = s.bNot(resultEmittedReg);
1327 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1328 resData.valid->setValue(resultValid);
1329 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1331 auto indexNotEmitted = s.bNot(indexEmittedReg);
1332 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1333 resIndex.valid->setValue(indexValid);
1337 SmallVector<Value, 8> indexOutputs;
1338 for (
size_t i = 0; i < numInputs; ++i)
1339 indexOutputs.push_back(s.constant(64, i));
1341 auto indexOutput = s.ohMux(win, indexOutputs);
1342 resIndex.data->setValue(indexOutput);
1348 won.setValue(s.mux(fired, {win, noWinner}));
1353 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1354 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1356 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1357 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1361 fired.setValue(s.bAnd({resultDone, indexDone}));
1367 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1368 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1373 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1374 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1375 auto &indexValue = argIndexValues[i];
1376 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1381class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1383 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1386 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1387 auto resData = unwrappedIO.outputs[0];
1390 unsigned numInputs = unwrappedIO.inputs.size();
1391 auto indexType = s.b.getIntegerType(numInputs);
1392 Value noWinner = s.constant(numInputs, 0);
1395 auto win = bb.
get(indexType);
1398 auto hasWinnerCondition = s.rOr(win);
1405 DenseMap<size_t, Value> argIndexValues;
1406 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1407 noWinner, argIndexValues);
1408 win.setValue(priorityArb);
1415 resData.valid->setValue(hasWinnerCondition);
1416 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1421 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
1426 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1427 for (
auto [i, ir] :
llvm::enumerate(unwrappedIO.getInputReadys())) {
1428 auto &indexValue = argIndexValues[i];
1429 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
1434class LoadConversionPattern
1435 :
public HandshakeConversionPattern<handshake::LoadOp> {
1437 using HandshakeConversionPattern<
1438 handshake::LoadOp>::HandshakeConversionPattern;
1441 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1442 auto addrFromUser = unwrappedIO.inputs[0];
1443 auto dataFromMem = unwrappedIO.inputs[1];
1444 auto controlIn = unwrappedIO.inputs[2];
1445 auto dataToUser = unwrappedIO.outputs[0];
1446 auto addrToMem = unwrappedIO.outputs[1];
1448 addrToMem.data->setValue(addrFromUser.data);
1449 dataToUser.data->setValue(dataFromMem.data);
1453 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
1457 dataToUser.valid->setValue(dataFromMem.valid);
1458 dataFromMem.ready->setValue(dataToUser.ready);
1462class StoreConversionPattern
1463 :
public HandshakeConversionPattern<handshake::StoreOp> {
1465 using HandshakeConversionPattern<
1466 handshake::StoreOp>::HandshakeConversionPattern;
1469 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1470 auto addrFromUser = unwrappedIO.inputs[0];
1471 auto dataFromUser = unwrappedIO.inputs[1];
1472 auto controlIn = unwrappedIO.inputs[2];
1473 auto dataToMem = unwrappedIO.outputs[0];
1474 auto addrToMem = unwrappedIO.outputs[1];
1477 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
1481 HandshakeWire joinWire(bb, s.b.getNoneType());
1482 joinWire.ready->setValue(outputsReady);
1483 OutputHandshake joinOutput = joinWire.getAsOutput();
1484 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
1487 addrToMem.data->setValue(addrFromUser.data);
1488 dataToMem.data->setValue(dataFromUser.data);
1491 addrToMem.valid->setValue(*joinWire.valid);
1492 dataToMem.valid->setValue(*joinWire.valid);
1496class MemoryConversionPattern
1497 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1499 using HandshakeConversionPattern<
1500 handshake::MemoryOp>::HandshakeConversionPattern;
1503 auto loc = op.getLoc();
1506 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1508 InputHandshake &
addr;
1509 OutputHandshake &
data;
1510 OutputHandshake &done;
1513 InputHandshake &
addr;
1514 InputHandshake &
data;
1515 OutputHandshake &done;
1517 SmallVector<LoadPort, 4> loadPorts;
1518 SmallVector<StorePort, 4> storePorts;
1520 unsigned stCount = op.getStCount();
1521 unsigned ldCount = op.getLdCount();
1522 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1523 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1524 unwrappedIO.outputs[i],
1525 unwrappedIO.outputs[ldCount + stCount + i]};
1526 loadPorts.push_back(port);
1529 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1530 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1531 unwrappedIO.inputs[i * 2],
1532 unwrappedIO.outputs[ldCount + i]};
1533 storePorts.push_back(port);
1537 auto c0I0 = s.constant(0, 0);
1539 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1540 auto hlmem = s.b.create<seq::HLMemOp>(
1541 loc, s.clk, s.rst,
"_handshake_memory_" + std::to_string(op.getId()),
1542 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
1545 for (
auto &ld : loadPorts) {
1546 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1547 auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1548 addresses, ld.addr.valid,
1550 ld.data.data->setValue(readData);
1551 ld.done.data->setValue(c0I0);
1553 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
1557 for (
auto &st : storePorts) {
1560 auto writeValidBufferMuxBE = bb.
get(s.b.getI1Type());
1561 auto writeValidBuffer =
1562 s.reg(
"writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1563 st.done.valid->setValue(writeValidBuffer);
1564 st.done.data->setValue(c0I0);
1568 auto storeCompleted =
1569 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
1573 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1574 auto emptyOrComplete =
1575 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1578 st.addr.ready->setValue(emptyOrComplete);
1579 st.data.ready->setValue(emptyOrComplete);
1582 auto writeValid = s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
1588 writeValidBufferMuxBE.setValue(
1589 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
1593 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1594 s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1595 st.data.data, writeValid,
1601class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1603 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1606 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1608 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
1612class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1614 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1617 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1619 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1620 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
1624class ConstantConversionPattern
1625 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1627 using HandshakeConversionPattern<
1628 handshake::ConstantOp>::HandshakeConversionPattern;
1631 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1632 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1633 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1634 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1635 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
1639class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1641 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1644 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1645 auto input = unwrappedIO.inputs[0];
1646 auto output = unwrappedIO.outputs[0];
1647 InputHandshake lastStage;
1648 SmallVector<int64_t> initValues;
1651 if (op.getInitValues())
1652 initValues = op.getInitValueArray();
1655 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1656 op.getNumSlots(), input, output, initValues);
1659 output.data->setValue(lastStage.data);
1660 output.valid->setValue(lastStage.valid);
1661 lastStage.ready->setValue(output.ready);
1664 struct SeqBufferStage {
1665 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1666 RTLBuilder &s,
size_t index,
1667 std::optional<int64_t> initValue)
1668 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1671 c0s = createZeroDataConst(s, s.loc, dataType);
1672 currentStage.ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
1674 auto hasInitValue = s.constant(1, initValue.has_value());
1675 auto validBE = bb.
get(s.b.getI1Type());
1676 auto validReg = s.reg(getRegName(
"valid"), validBE, hasInitValue);
1677 auto readyBE = bb.
get(s.b.getI1Type());
1679 Value initValueCs = c0s;
1680 if (initValue.has_value())
1681 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1686 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1687 buildControlBufferLogic(validReg, readyBE, dataReg);
1690 StringAttr getRegName(StringRef name) {
1691 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
1694 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1696 auto c0I1 = s.constant(1, 0);
1697 auto readyRegWire = bb.
get(s.b.getI1Type());
1698 auto readyReg = s.reg(getRegName(
"ready"), readyRegWire, c0I1);
1702 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1703 "controlValid" + std::to_string(index));
1706 auto notReadyReg = s.bNot(readyReg);
1709 auto succNotReady = s.bNot(*currentStage.ready);
1710 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1711 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1712 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1715 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1716 readyRegWire.setValue(resetSignal);
1719 auto ctrlDataRegBE = bb.
get(dataType);
1720 auto ctrlDataReg = s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
1721 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1722 currentStage.data = dataResult;
1724 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1725 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1726 ctrlDataRegBE.setValue(dataResetSignal);
1729 Value buildDataBufferLogic(Value validReg, Value initValue,
1733 auto notValidReg = s.bNot(validReg);
1734 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1735 preStage.ready->setValue(emptyOrReady);
1741 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1747 auto dataRegBE = bb.
get(dataType);
1749 s.reg(getRegName(
"data"),
1750 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
1751 dataRegBE.setValue(dataReg);
1755 InputHandshake getOutput() {
return currentStage; }
1758 InputHandshake &preStage;
1759 InputHandshake currentStage;
1768 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1769 Type dataType,
unsigned size,
1770 InputHandshake &input,
1771 OutputHandshake &output,
1772 llvm::ArrayRef<int64_t> initValues)
const {
1775 InputHandshake currentStage = input;
1777 for (
unsigned i = 0; i < size; ++i) {
1778 bool isInitialized = i < initValues.size();
1780 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1781 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1785 return currentStage;
1789class IndexCastConversionPattern
1790 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1792 using HandshakeConversionPattern<
1793 arith::IndexCastOp>::HandshakeConversionPattern;
1796 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1797 unsigned sourceBits =
1798 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1799 unsigned targetBits =
1800 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1801 if (targetBits < sourceBits)
1802 buildTruncateLogic(s, unwrappedIO, targetBits);
1804 buildExtendLogic(s, unwrappedIO,
true);
1808template <
typename T>
1811 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1812 MLIRContext *context, OpBuilder &submoduleBuilder,
1813 HandshakeLoweringState &ls)
1815 submoduleBuilder(submoduleBuilder), ls(ls) {}
1816 using OpAdaptor =
typename T::Adaptor;
1819 matchAndRewrite(T op, OpAdaptor adaptor,
1820 ConversionPatternRewriter &rewriter)
const override {
1830 llvm::SmallVector<Value> operands = adaptor.getOperands();
1831 addSequentialIOOperandsIfNeeded(op, operands);
1832 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1833 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1838 OpBuilder &submoduleBuilder;
1839 HandshakeLoweringState &ls;
1844 using OpConversionPattern::OpConversionPattern;
1848 ConversionPatternRewriter &rewriter)
const override {
1852 HWModuleLike hwModule;
1853 if (op.isExternal()) {
1855 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1858 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1859 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1860 rewriter.inlineBlockBefore(&op.getBody().front(),
1861 hwModuleOp.getBodyBlock()->getTerminator(),
1863 hwModule = hwModuleOp;
1870 auto *parentOp = op->getParentOp();
1871 auto *predeclModule =
1872 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1873 if (predeclModule) {
1874 if (failed(SymbolTable::replaceAllSymbolUses(
1875 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1877 rewriter.eraseOp(predeclModule);
1881 rewriter.eraseOp(op);
1893 ConversionTarget &target,
1895 OpBuilder &moduleBuilder) {
1897 std::map<std::string, unsigned> instanceNameCntr;
1898 NameUniquer instanceUniquer = [&](Operation *op) {
1900 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1903 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1906 instName += std::to_string(instanceNameCntr[instName]++);
1911 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1913 RewritePatternSet
patterns(op.getContext());
1914 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1916 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1917 SyncConversionPattern>(typeConverter, op.getContext(),
1922 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1923 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1924 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1925 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1926 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1927 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1928 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1929 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1930 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1931 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1932 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1933 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1934 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1935 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1937 StructCreateConversionPattern,
1939 ConditionalBranchConversionPattern, MuxConversionPattern,
1940 PackConversionPattern, UnpackConversionPattern,
1941 ComparisonConversionPattern, BufferConversionPattern,
1942 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1943 MergeConversionPattern, ControlMergeConversionPattern,
1944 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1945 InstanceConversionPattern,
1947 ExtendConversionPattern<arith::ExtUIOp,
false>,
1948 ExtendConversionPattern<arith::ExtSIOp,
true>,
1949 TruncateConversionPattern, IndexCastConversionPattern>(
1950 typeConverter, op.getContext(), moduleBuilder, ls);
1952 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1953 return op->emitOpError() <<
"error during conversion";
1958class HandshakeToHWPass
1959 :
public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1961 void runOnOperation()
override {
1962 mlir::ModuleOp mod = getOperation();
1966 for (
auto f : mod.getOps<
handshake::FuncOp>()) {
1968 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1969 "are used exactly once. Remember to run the "
1970 "fork/sink materialization pass before HW lowering.";
1971 signalPassFailure();
1977 std::string topLevel;
1979 SmallVector<std::string> sortedFuncs;
1981 signalPassFailure();
1985 ESITypeConverter typeConverter;
1986 ConversionTarget target(getContext());
1992 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
1998 OpBuilder submoduleBuilder(mod.getContext());
1999 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2000 for (
auto &funcName :
llvm::reverse(sortedFuncs)) {
2002 assert(funcOp &&
"handshake.func not found in module!");
2004 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2005 signalPassFailure();
2012 for (
auto hwModule : mod.getOps<
hw::HWModuleOp>())
2014 return signalPassFailure();
2020 RewritePatternSet
patterns(mod.getContext());
2021 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2023 if (failed(applyPartialConversion(mod, target, std::move(
patterns)))) {
2024 mod->emitOpError() <<
"error during conversion";
2025 signalPassFailure();
2032 return std::make_unique<HandshakeToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static std::string getCallName(Operation *op)
static Type getOperandDataType(Value op)
Extracts the data-carrying type of a value's type (the inner type when the value is an ESI channel).
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfos that defines the HW interface of the op to be converted.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneTypes out of the input.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
void setOutput(unsigned i, Value v)
This stores lookup tables to make manipulating and working with the IR more efficient.
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Channels are the basic communication primitives.
const Type * getInner() const
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its in- and outp...
std::map< std::string, std::set< std::string > > InstanceGraph
Instance graph of the handshake::FuncOps in the program, as built by resolveInstanceGraph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake funct...
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.
void eraseInput(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.