26 #include "mlir/Dialect/Arith/IR/Arith.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Pass/PassManager.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/MathExtras.h"
37 #define GEN_PASS_DEF_HANDSHAKETOHW
38 #include "circt/Conversion/Passes.h.inc"
42 using namespace circt;
// State handed to the conversion patterns (they keep a reference to it as
// `ls`). Only the `parentModule` member is visible in this excerpt.
56 struct HandshakeLoweringState {
57 ModuleOp parentModule;
// Type converter used by the lowering. Every type is routed through
// esiWrapper(), and both materialization hooks bridge a single value with an
// UnrealizedConversionCastOp.
64 class ESITypeConverter :
public TypeConverter {
// Single catch-all conversion rule.
67 addConversion([](Type type) -> Type {
return esiWrapper(type); });
// Target materialization. Anything other than exactly one input takes a
// separate path whose body is not visible in this excerpt.
69 addTargetMaterialization([&](mlir::OpBuilder &builder,
70 mlir::Type resultType, mlir::ValueRange inputs,
71 mlir::Location loc) -> mlir::Value {
72 if (inputs.size() != 1)
75 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Source materialization; mirrors the target materialization above.
79 addSourceMaterialization([&](mlir::OpBuilder &builder,
80 mlir::Type resultType, mlir::ValueRange inputs,
81 mlir::Location loc) -> mlir::Value {
82 if (inputs.size() != 1)
85 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
99 std::string subModuleName = oldOp->getName().getStringRef().str();
100 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
101 return subModuleName;
105 auto callOp = dyn_cast<handshake::InstanceOp>(op);
113 auto opType = op.getType();
114 if (
auto channelType = dyn_cast<esi::ChannelType>(opType))
115 return channelType.getInner();
121 SmallVector<Type> filterRes;
122 llvm::copy_if(input, std::back_inserter(filterRes),
123 [](Type type) {
return !isa<NoneType>(type); });
131 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
132 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
134 {memOp.getMemRefType().getElementType()}};
139 std::vector<Type> inTypes, outTypes;
140 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
142 llvm::transform(op->getResults(), std::back_inserter(outTypes),
154 std::string typeName;
156 if (type.isIntOrIndex()) {
157 if (
auto indexType = dyn_cast<IndexType>(type))
158 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
159 else if (type.isSignedInteger())
160 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
162 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
163 }
else if (
auto tupleType = dyn_cast<TupleType>(type)) {
164 typeName +=
"_tuple";
167 }
else if (
auto structType = dyn_cast<hw::StructType>(type)) {
168 typeName +=
"_struct";
169 for (
auto element : structType.getElements())
170 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
172 emitError(loc) <<
"unsupported data type '" << type <<
"'";
179 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
180 return instanceOp.getModule().str();
185 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
186 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
187 auto intType = intAttr.getType();
189 if (intType.isSignedInteger())
190 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
191 else if (intType.isUnsignedInteger())
192 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
194 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
196 oldOp->emitError(
"unsupported constant type");
201 if (!inTypes.empty())
202 subModuleName +=
"_in";
203 for (
auto inType : inTypes)
204 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
206 if (!outTypes.empty())
207 subModuleName +=
"_out";
208 for (
auto outType : outTypes)
209 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
212 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
213 subModuleName +=
"_id" + std::to_string(memOp.getId());
216 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
217 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
220 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
221 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
222 if (bufferOp.isSequential())
223 subModuleName +=
"_seq";
225 subModuleName +=
"_fifo";
227 if (
auto initValues = bufferOp.getInitValues()) {
228 subModuleName +=
"_init";
229 for (
const Attribute e : *initValues) {
230 assert(isa<IntegerAttr>(e));
232 "_" + std::to_string(dyn_cast<IntegerAttr>(e).
getInt());
238 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
239 ctrlInterface && ctrlInterface.isControl()) {
241 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
242 std::to_string(oldOp->getNumResults()) +
"outs";
243 subModuleName +=
"_ctrl";
246 (!inTypes.empty() || !outTypes.empty()) &&
247 "Insufficient discriminating type info generated for the operation!");
250 return subModuleName;
262 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
264 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
271 HWModuleLike targetModule;
272 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
277 if (isa<handshake::InstanceOp>(oldOp))
279 "handshake.instance target modules should always have been lowered "
280 "before the modules that reference them!");
290 static llvm::SmallVector<hw::detail::FieldInfo>
292 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
293 for (
auto port : portInfo)
294 fieldInfo.push_back({port.name, port.type});
302 auto *ctx = mod.getContext();
305 llvm::DenseMap<unsigned, Value> memrefPorts;
306 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
307 auto channel = dyn_cast<esi::ChannelType>(arg.getType());
308 if (channel && isa<MemRefType>(channel.getInner()))
309 memrefPorts[i] = arg;
312 if (memrefPorts.empty())
317 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
318 ArrayRef<hw::PortInfo> info,
322 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
326 for (
auto [i, arg] : memrefPorts) {
328 auto memName = mod.getArgName(i);
331 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
333 cast<hw::HWModuleExternOp>(SymbolTable::lookupNearestSymbolFrom(
334 extmemInstance, extmemInstance.getModuleNameAttr()));
345 SmallVector<PortInfo> outputs(portInfo.
getOutputs());
347 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i, outputs,
349 mod.insertPorts({{i, inPortInfo}}, {});
350 auto newInPort = mod.getArgumentForInput(i);
352 b.setInsertionPointToStart(mod.getBodyBlock());
353 auto newInPortExploded = b.create<hw::StructExplodeOp>(
354 arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
355 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
359 unsigned outArgI = mod.getNumOutputPorts();
360 SmallVector<PortInfo> inputs(portInfo.
getInputs());
362 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
365 auto memOutputArgs = extmemInstance.getOperands().drop_front();
366 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
368 arg.getLoc(), outPortInfo.type, memOutputArgs);
369 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
374 extmemInstance.erase();
378 mod.modifyPorts( {}, {},
// Handshake signals of one module input. `ready` is a backedge because the
// module logic drives it later through setValue().
389 struct InputHandshake {
391 std::shared_ptr<Backedge> ready;
// Handshake signals of one module output. `valid` and `data` are backedges
// that the module logic drives later through setValue().
397 struct OutputHandshake {
398 std::shared_ptr<Backedge> valid;
400 std::shared_ptr<Backedge>
data;
// An internal handshake channel made of three backedges (valid, ready, data).
// It can be viewed either as an input or as an output so that two pieces of
// logic can be chained inside one module.
405 struct HandshakeWire {
407 MLIRContext *ctx = dataType.getContext();
// valid and ready are single-bit backedges; data uses the given data type.
409 valid = std::make_shared<Backedge>(bb.
get(i1Type));
410 ready = std::make_shared<Backedge>(bb.
get(i1Type));
411 data = std::make_shared<Backedge>(bb.
get(dataType));
// Input view: valid and data are dereferenced, ready stays a backedge so the
// consumer can drive it.
416 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
// Output view: ready is dereferenced, valid and data stay backedges so the
// producer can drive them.
417 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
419 std::shared_ptr<Backedge> valid;
420 std::shared_ptr<Backedge> ready;
421 std::shared_ptr<Backedge>
data;
// Applies `extractor` to every element of `container` and appends the
// results, in order, to a fresh SmallVector<T>.
424 template <
typename T,
typename TInner>
425 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
426 llvm::function_ref<T(TInner &)> extractor) {
427 llvm::SmallVector<T> result;
428 llvm::transform(container, std::back_inserter(result), extractor);
432 llvm::SmallVector<InputHandshake> inputs;
433 llvm::SmallVector<OutputHandshake> outputs;
// Collects the `valid` value of every entry in `inputs`, in order.
435 llvm::SmallVector<Value> getInputValids() {
436 return extractValues<Value, InputHandshake>(
437 inputs, [](
auto &hs) {
return hs.valid; });
// Collects the `ready` backedge of every entry in `inputs`, in order.
439 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
440 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
441 inputs, [](
auto &hs) {
return hs.ready; });
// Collects the `data` value of every entry in `inputs`, in order.
443 llvm::SmallVector<Value> getInputDatas() {
444 return extractValues<Value, InputHandshake>(
445 inputs, [](
auto &hs) {
return hs.data; });
// Collects the `valid` backedge of every entry in `outputs`, in order.
447 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
448 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
449 outputs, [](
auto &hs) {
return hs.valid; });
// Collects the `ready` value of every entry in `outputs`, in order.
451 llvm::SmallVector<Value> getOutputReadys() {
452 return extractValues<Value, OutputHandshake>(
453 outputs, [](
auto &hs) {
return hs.ready; });
// Collects the `data` backedge of every entry in `outputs`, in order.
455 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
456 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
457 outputs, [](
auto &hs) {
return hs.data; });
// Binds the builder to a port list, an OpBuilder and a location. `clk` and
// `rst` default to null Values; reg() falls back to these members when it is
// called without an explicit clock/reset.
465 RTLBuilder(hw::ModulePortInfo info, OpBuilder &builder, Location loc,
466 Value clk = Value(), Value rst = Value())
467 : info(std::move(info)), b(builder), loc(loc),
clk(
clk), rst(rst) {}
// Returns a constant for `apv`. The `constants` map is consulted first, and
// `cval` is stored in it under `apv` for later calls.
469 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
// Zero-width values are flagged here; how the flag is used is not visible in
// this excerpt.
472 bool isZeroWidth = apv.getBitWidth() == 0;
474 auto it = constants.find(apv);
475 if (it != constants.end())
481 constants[apv] = cval;
// Convenience overload that builds the APInt from a bit width and an int64_t.
// The two trailing APInt arguments are false/true — presumably isSigned=false
// and implicitTrunc=true; confirm against the LLVM version in use. Callers
// such as bNot() pass -1 to obtain an all-ones value.
485 Value constant(
unsigned width, int64_t value,
486 std::optional<StringRef> name = {}) {
488 APInt(width, value,
false,
true));
// Creates an esi::WrapValidReadyOp from `data` and `valid` and returns its
// results 0 and 1 as a pair (callers bind them as channel and ready).
// `name` is not used in the visible lines.
490 std::pair<Value, Value>
wrap(Value data, Value valid,
491 std::optional<StringRef> name = {}) {
492 auto wrapOp = b.create<esi::WrapValidReadyOp>(loc,
data, valid);
493 return {wrapOp.getResult(0), wrapOp.getResult(1)};
// Creates an esi::UnwrapValidReadyOp from `channel` and `ready` and returns
// its results 0 and 1 as a pair (callers bind them as data and valid).
// `name` is not used in the visible lines.
495 std::pair<Value, Value>
unwrap(Value channel, Value ready,
496 std::optional<StringRef> name = {}) {
497 auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
498 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
// Creates a seq::CompRegOp with input `in` and reset value `rstValue`.
// An explicit `clk`/`rst` argument takes precedence; a null argument falls
// back to the builder-wide signal given at construction.
502 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
503 Value rst = Value()) {
504 Value resolvedClk =
clk ?
clk : this->
clk;
505 Value resolvedRst = rst ? rst : this->rst;
// The two messages below presumably belong to checks that a clock and a
// reset were resolved; the check expressions are not visible here.
507 "No global clock provided to this RTLBuilder - a clock "
508 "signal must be provided to the reg(...) function.");
510 "No global reset provided to this RTLBuilder - a reset "
511 "signal must be provided to the reg(...) function.");
513 return b.create<
seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
// Creates a comb::ICmpOp comparing `lhs` and `rhs` with `predicate`.
// `name` is not used in the visible lines.
517 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
518 std::optional<StringRef> name = {}) {
519 return b.
create<comb::ICmpOp>(loc, predicate, lhs, rhs);
// Helper used by the logic builders: `f` produces a value, and when `name`
// is given it is attached to the defining op as the "sv.namehint" attribute.
// NOTE(review): getDefiningOp() is null for block arguments; whether that is
// guarded is not visible in this excerpt.
522 Value buildNamedOp(llvm::function_ref<Value()> f,
523 std::optional<StringRef> name) {
526 Operation *op = v.getDefiningOp();
527 if (name.has_value()) {
528 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
529 nameAttr = b.getStringAttr(*name);
// Bitwise AND of all `values` via comb::AndOp, optionally named.
// The trailing `false` is presumably the op's two-state flag — confirm.
535 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
537 [&]() {
return b.create<
comb::AndOp>(loc, values,
false); }, name);
// Bitwise OR of all `values` via comb::OrOp, optionally named.
// The trailing `false` is presumably the op's two-state flag — confirm.
540 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
542 [&]() {
return b.create<
comb::OrOp>(loc, values,
false); }, name);
// Bitwise NOT, implemented as XOR with an all-ones constant of the same
// width as `value`.
546 Value bNot(Value value, std::optional<StringRef> name = {}) {
547 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
548 std::string inferedName;
// When the operand's defining op carries an "sv.namehint", a name of the
// form "not_<hint>" is derived from it.
552 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
553 inferedName = (
"not_" +
valueName.getValue()).str();
// Two creation paths are visible: a named XorOp through buildNamedOp and a
// createOrFold fallback. The condition selecting between them is not visible.
559 [&]() {
return b.create<
comb::XorOp>(loc, value, allOnes); }, name);
561 return b.createOrFold<
comb::XorOp>(loc, value, allOnes,
false);
// Left shift of `value` by `shift` via comb::ShlOp, optionally named.
564 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
566 [&]() {
return b.create<
comb::ShlOp>(loc, value, shift); }, name);
// Concatenates `values` via comb::ConcatOp; built through buildNamedOp.
569 Value
concat(ValueRange values, std::optional<StringRef> name = {}) {
570 return buildNamedOp([&]() {
return b.create<
comb::ConcatOp>(loc, values); },
// Packs `values` into an aggregate. `structType` defaults to a null Type;
// within this file one caller passes an explicit struct type and another
// omits it. The body is not visible in this excerpt.
575 Value pack(ValueRange values, Type structType = Type(),
576 std::optional<StringRef> name = {}) {
// Splits a struct-typed `value` into its fields with hw::StructExplodeOp and
// returns the op's results. `value` must have hw::StructType (cast<> asserts
// otherwise).
585 ValueRange unpack(Value value) {
586 auto structType = cast<hw::StructType>(value.getType());
587 llvm::SmallVector<Type> innerTypes;
588 structType.getInnerTypes(innerTypes);
589 return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
// Builds a vector with one entry per bit position of `v`, iterating from
// index 0 up to the bit width of v's type. The loop body is not visible in
// this excerpt.
592 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
593 llvm::SmallVector<Value> bits;
594 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
// OR-reduction: ORs together the values returned by toBits(v).
600 Value rOr(Value v, std::optional<StringRef> name = {}) {
601 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
// Extracts the bit range [lo, hi] of `v`; both bounds are inclusive, hence
// the +1 in the width computation. Callers must ensure hi >= lo.
605 Value extract(Value v,
unsigned lo,
unsigned hi,
606 std::optional<StringRef> name = {}) {
607 unsigned width = hi - lo + 1;
// Keeps the low `width` bits of `value` (bits 0 .. width-1).
// NOTE(review): width == 0 makes `width - 1` wrap around as unsigned;
// confirm that callers never request a zero-width truncation.
613 Value truncate(Value value,
unsigned width,
614 std::optional<StringRef> name = {}) {
615 return extract(value, 0, width - 1, name);
// Zero-extends `value` to `outWidth` bits by concatenating a zero constant of
// the missing width in front of it. Requires inWidth <= outWidth.
// NOTE(review): the assert message reads "<-" where "<=" is evidently meant.
618 Value zext(Value value,
unsigned outWidth,
619 std::optional<StringRef> name = {}) {
620 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
621 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
// Equal widths are special-cased; the statement taken is not visible here.
622 if (inWidth == outWidth)
624 auto c0 = constant(outWidth - inWidth, 0);
625 return concat({c0, value}, name);
// Sign-extension counterpart of zext(); takes the value and the target
// width. The body is not visible in this excerpt.
628 Value sext(Value value,
unsigned outWidth,
629 std::optional<StringRef> name = {}) {
// Extracts the single bit at position `index` of `v`.
634 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
635 return extract(v, index, index, name);
// Builds an array value from `values`; used by mux() as the operand of
// arrayGet(). The body is not visible in this excerpt.
639 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
// Indexes `array` with `index` via hw::ArrayGetOp, optionally named.
645 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
647 [&]() {
return b.create<
hw::ArrayGetOp>(loc, array, index); }, name);
// Selects values[index].
653 Value mux(Value index, ValueRange values,
654 std::optional<StringRef> name = {}) {
// Two-way case: a plain comb::MuxOp. values[1] is passed first and values[0]
// second, so that index 0 picks values[0] (presumably true-value before
// false-value — confirm against the comb.mux operand order).
655 if (values.size() == 2)
656 return b.create<
comb::MuxOp>(loc, index, values[1], values[0]);
// General case: index into an array built from the values.
658 return arrayGet(arrayCreate(values), index, name);
// One-hot mux: `index` has one bit per input and the result is built as a
// chain of two-way muxes starting from a zero constant.
// NOTE(review): inputs[0] is read unconditionally and `numInputs - 1` is
// computed on an unsigned value, so an empty `inputs` is not handled here.
663 Value ohMux(Value index, ValueRange inputs) {
665 unsigned numInputs = inputs.size();
666 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
667 "one-hot select can't mux inputs");
// NoneType data is treated as zero-width.
671 auto dataType = inputs[0].getType();
673 isa<NoneType>(dataType) ? 0 : dataType.getIntOrFloatBitWidth();
674 Value muxValue = constant(width, 0);
// Chain from the highest select bit downwards.
// NOTE(review): the loop stops before i == 0, so inputs[0] is never muxed in
// by the visible code and select bit 0 yields the zero start value. Confirm
// whether this is intended.
677 for (
size_t i = numInputs - 1; i != 0; --i) {
678 Value input = inputs[i];
679 Value selectBit = bit(index, i);
680 muxValue = mux(selectBit, {muxValue, input});
686 hw::ModulePortInfo info;
690 DenseMap<APInt, Value> constants;
// Builds an all-zero value of `type`:
//  - NoneType        -> zero-width constant,
//  - integer types   -> zero constant of the type's width,
//  - hw::StructType  -> one zero value per field, built recursively,
//  - anything else   -> an error is emitted at `loc`.
695 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
696 return TypeSwitch<Type, Value>(type)
697 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
698 .Case<IntType, IntegerType>([&](
auto type) {
699 return s.constant(type.getIntOrFloatBitWidth(), 0);
701 .Case<hw::StructType>([&](
auto structType) {
702 SmallVector<Value> zeroValues;
703 for (
auto field : structType.getElements())
704 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
707 .Default([&](Type) -> Value {
708 emitError(loc) <<
"unsupported type for zero value: " << type;
// Supplies the extra operands needed by sequential ops. The visible lines
// read the second-to-last and the last input argument of the enclosing
// hw::HWModuleOp — presumably clock and reset, appended to `operands`; the
// guarding condition and the append calls are not visible here.
715 addSequentialIOOperandsIfNeeded(Operation *op,
716 llvm::SmallVectorImpl<Value> &operands) {
720 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
722 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
724 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
728 template <
typename T>
// Stores references to the builder used for creating submodules and to the
// shared lowering state; both must outlive the pattern.
731 HandshakeConversionPattern(ESITypeConverter &typeConverter,
732 MLIRContext *context, OpBuilder &submoduleBuilder,
733 HandshakeLoweringState &ls)
735 submoduleBuilder(submoduleBuilder), ls(ls) {}
737 using OpAdaptor =
typename T::Adaptor;
// Replaces `op` with an hw::InstanceOp of an implementation module.
740 matchAndRewrite(T op, OpAdaptor adaptor,
741 ConversionPatternRewriter &rewriter)
const override {
// Submodules are created in front of the op's parent operation.
750 submoduleBuilder.setInsertionPoint(op->getParentOp());
753 portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
// Ops carrying the HasClock trait read their clock and reset from the ports
// named "clock" and "reset".
757 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
758 clk = ports.getInput(
"clock");
759 rst = ports.getInput(
"reset");
763 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
// Instantiate the implementation module with the converted operands, plus
// whatever addSequentialIOOperandsIfNeeded supplies; the instance name comes
// from the lowering state's name uniquer.
769 llvm::SmallVector<Value> operands = adaptor.getOperands();
770 addSequentialIOOperandsIfNeeded(op, operands);
771 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
772 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
777 hw::HWModulePortAccessor &ports)
const = 0;
784 hw::HWModulePortAccessor &ports)
const {
785 UnwrappedIO unwrapped;
786 for (
auto port : ports.getInputs()) {
787 if (!isa<esi::ChannelType>(port.getType()))
790 auto ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
791 auto [
data, valid] = s.unwrap(port, *ready);
795 unwrapped.inputs.push_back(hs);
797 for (
auto &outputInfo : ports.getPortList().getOutputs()) {
799 dyn_cast<esi::ChannelType>(outputInfo.type);
805 auto valid = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
806 auto [dataCh, ready] = s.wrap(*data, *valid);
810 ports.setOutput(outputInfo.name, dataCh);
811 unwrapped.outputs.push_back(hs);
// Drives the ready backedge of every input with (output.ready AND cond).
816 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
817 OutputHandshake &output, Value cond)
const {
818 auto validAndReady = s.bAnd({output.ready, cond});
819 for (
auto &input : inputs)
820 input.ready->setValue(validAndReady);
// Join handshaking: the output is valid once all inputs are valid, and every
// input becomes ready when all inputs are valid and the output is ready.
// Only valid/ready are driven here, not the output data.
823 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake> inputs,
824 OutputHandshake &output)
const {
825 llvm::SmallVector<Value> valids;
826 for (
auto &input : inputs)
827 valids.push_back(input.valid);
828 Value allValid = s.bAnd(valids);
829 output.valid->setValue(allValid);
830 setAllReadyWithCond(s, inputs, output, allValid);
// Mux handshaking and data path. `unwrapped.inputs` holds the data inputs
// only (the select input is passed separately) and outputs[0] is the result.
836 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
837 InputHandshake &select)
const {
// Narrow the select value to the number of bits needed to index the inputs.
// The alternative arm of the conditional is not visible in this excerpt.
839 size_t numInputs = unwrapped.inputs.size();
840 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
841 Value truncatedSelect =
842 select.data.getType().getIntOrFloatBitWidth() > selectWidth
843 ? s.truncate(select.data, selectWidth)
// One-hot form of the select: 1 shifted left by the select value, with one
// bit per input.
847 auto selectZext = s.zext(truncatedSelect, numInputs);
848 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
849 auto &res = unwrapped.outputs[0];
// The result is valid when both the select and the selected input are valid.
852 auto selectedInputValid =
853 s.mux(truncatedSelect, unwrapped.getInputValids());
855 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
856 res.valid->setValue(selAndInputValid);
857 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
// The select is consumed when the result transfers.
860 select.ready->setValue(resValidAndReady);
// Only the selected input is consumed when the result transfers.
863 for (
auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
865 auto isSelected = s.bit(select1h, inIdx);
869 auto activeAndResultValidAndReady =
870 s.bAnd({isSelected, resValidAndReady});
871 in.ready->setValue(activeAndResultValidAndReady);
// Data path: forward the selected input's data.
875 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
// Fork handshaking: each output keeps an "emitted" register recording that
// it already transferred the current token; the input becomes ready once all
// outputs are done. Requires a clock/reset in `s` because registers are used.
880 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
881 ArrayRef<OutputHandshake> outputs)
const {
882 auto c0I1 = s.constant(1, 0);
883 llvm::SmallVector<Value> doneWires;
884 for (
auto [i, output] : llvm::enumerate(outputs)) {
// `done` is defined further down, hence the backedge.
885 auto doneBE = bb.
get(s.b.getI1Type());
// The emitted flag is held only while the input has not been consumed.
886 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
887 auto emittedReg = s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
888 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
// An output is done when it transfers now or has transferred earlier.
890 auto validReady = s.bAnd({output.ready, outValid});
891 auto done = s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
892 doneBE.setValue(done);
893 doneWires.push_back(done);
895 input.ready->setValue(s.bAnd(doneWires,
"allDone"));
// Unit-rate actor with N inputs and exactly one output: join handshaking on
// all inputs, with the output data computed by `unitBuilder` from the input
// data values.
901 void buildUnitRateJoinLogic(
902 RTLBuilder &s, UnwrappedIO &unwrappedIO,
903 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
904 assert(unwrappedIO.outputs.size() == 1 &&
905 "Expected exactly one output for unit-rate join actor");
907 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
910 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
911 unwrappedIO.outputs[0].data->setValue(unitRes);
// Unit-rate actor with exactly one input and N outputs: fork handshaking,
// with `unitBuilder` producing one data value per output from the input data.
914 void buildUnitRateForkLogic(
916 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
917 assert(unwrappedIO.inputs.size() == 1 &&
918 "Expected exactly one input for unit-rate fork actor");
920 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
923 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
924 assert(unitResults.size() == unwrappedIO.outputs.size() &&
925 "Expected unit builder to return one result per output");
// Results and outputs are paired positionally.
926 for (
auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
927 outport.data->setValue(res);
// Unit-rate extension of inputs[0] to the width of the output's data type.
// Both a sext and a zext return are visible; presumably `signExtend` selects
// between them (the condition line is not visible in this excerpt).
930 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
931 bool signExtend)
const {
// The target width is read from the type of the output data backedge.
933 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
934 .getIntOrFloatBitWidth();
935 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
937 return s.sext(inputs[0], outWidth);
938 return s.zext(inputs[0], outWidth);
// Unit-rate truncation of inputs[0] to the width of the output's data type.
// NOTE(review): `targetWidth` is not referenced in the visible lines; the
// width passed to truncate() is derived from the output type instead.
942 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
943 unsigned targetWidth)
const {
945 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
946 .getIntOrFloatBitWidth();
947 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
948 return s.truncate(inputs[0], outWidth);
// Number of bits needed to index `numValues` entries: ceil(log2(numValues)),
// but never less than one bit (0 or 1 values still yield 1).
953 static size_t getNumIndexBits(uint64_t numValues) {
954 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
// Builds a chain of muxes yielding the one-hot code of the lowest-indexed
// asserted input, or `defaultValue` when none is asserted. Each input's
// one-hot constant is also recorded in `indexMapping`.
957 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value> inputs,
959 DenseMap<size_t, Value> &indexMapping)
const {
960 auto numInputs = inputs.size();
961 auto priorityArb = defaultValue;
// Iterate from the last input to the first, so that input 0 ends up as the
// outermost mux and therefore has the highest priority.
963 for (
size_t i = numInputs; i > 0; --i) {
964 size_t inputIndex = i - 1;
965 size_t oneHotIndex =
size_t{1} << inputIndex;
966 auto constIndex = s.constant(numInputs, oneHotIndex);
967 indexMapping[inputIndex] = constIndex;
968 priorityArb = s.mux(inputs[inputIndex], {priorityArb, constIndex});
974 OpBuilder &submoduleBuilder;
975 HandshakeLoweringState &ls;
// Lowers ForkOp: the input data is replicated to every output, with the
// handshaking supplied by buildUnitRateForkLogic.
978 class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
980 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
982 hw::HWModulePortAccessor &ports)
const override {
983 auto unwrapped = unwrapIO(s, bb, ports);
984 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
// One copy of the input value per output.
985 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
// Lowers JoinOp: join handshaking from all inputs to the single output; the
// output data is driven with a zero-width constant.
990 class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
992 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
994 hw::HWModulePortAccessor &ports)
const override {
995 auto unwrappedIO = unwrapIO(s, bb, ports);
996 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
997 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
// Lowers SyncOp as a join followed by a fork, connected through an internal
// handshake wire; data is passed from each input to the matching output.
1001 class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
1003 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
1005 hw::HWModulePortAccessor &ports)
const override {
1006 auto unwrappedIO = unwrapIO(s, bb, ports);
// Internal wire carrying only control (NoneType data).
1009 HandshakeWire wire(bb, s.b.getNoneType());
// Join all inputs into the wire ...
1011 OutputHandshake output = wire.getAsOutput();
1012 buildJoinLogic(s, unwrappedIO.inputs, output);
// ... and fork the wire to all outputs.
1014 InputHandshake input = wire.getAsInput();
1022 buildForkLogic(s, bb, input, unwrappedIO.outputs);
// Data bypasses the join/fork: input i drives output i.
1026 for (
auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1027 out.data->setValue(in.data);
// Lowers MuxOp: the first unwrapped input is the select; it is removed from
// the input list so that buildMuxLogic sees data inputs only.
1031 class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1033 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1035 hw::HWModulePortAccessor &ports)
const override {
1036 auto unwrappedIO = unwrapIO(s, bb, ports);
1039 auto select = unwrappedIO.inputs[0];
1040 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1041 buildMuxLogic(s, unwrappedIO, select);
// Pattern for handshake::InstanceOp. Per the message below, its module
// builder is not expected to run: the base pattern instantiates the already
// lowered target module instead.
1045 class InstanceConversionPattern
1046 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1048 using HandshakeConversionPattern<
1049 handshake::InstanceOp>::HandshakeConversionPattern;
1051 hw::HWModulePortAccessor &ports)
const override {
1053 "If we indeed perform conversion in post-order, this "
1054 "should never be called. The base HandshakeConversionPattern logic "
1055 "will instantiate the external module.");
// Rewrites an ESIInstanceOp into an hw::InstanceOp of the referenced module.
1059 class ESIInstanceConversionPattern
1062 ESIInstanceConversionPattern(MLIRContext *context,
1067 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
1068 ConversionPatternRewriter &rewriter)
const override {
// Operand order of the new instance: all operands from index
// NumFixedOperands onwards first, then clock and reset at the end.
1074 SmallVector<Value> operands;
1075 for (
size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
1077 operands.push_back(adaptor.getOperands()[i]);
1078 operands.push_back(adaptor.getClk());
1079 operands.push_back(adaptor.getRst());
// The target module is resolved through the symbol cache.
1082 Operation *targetModule = symCache.
getDefinition(op.getModuleAttr());
1084 rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
1085 op.getInstNameAttr(), operands);
// Folds a handshake ReturnOp into the enclosing hw module's output op.
1093 class ReturnConversionPattern
1096 using OpConversionPattern::OpConversionPattern;
1098 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter)
const override {
// Reuse the first hw::OutputOp in the parent module body: give it the
// converted return operands, move it behind the current last operation of
// the block, then erase the return op.
1102 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1103 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1104 outputOp->setOperands(adaptor.getOperands());
1105 outputOp->moveAfter(&parent.getBodyBlock()->back());
1106 rewriter.eraseOp(op);
// Generic lowering for unit-rate ops: join handshaking on all inputs, with
// the result computed by a freshly created TOut op (defaulting to TIn) that
// takes the input data values and an empty attribute list.
1113 template <
typename TIn,
typename TOut = TIn>
1114 class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1116 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1118 hw::HWModulePortAccessor &ports)
const override {
1119 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1120 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1125 return s.b.create<TOut>(op.getLoc(), inputs,
1126 ArrayRef<NamedAttribute>{});
// Lowers PackOp as a unit-rate join whose result is RTLBuilder::pack() of the
// input data values (no explicit struct type is passed).
1131 class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1133 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1135 hw::HWModulePortAccessor &ports)
const override {
1136 auto unwrappedIO = unwrapIO(s, bb, ports);
1137 buildUnitRateJoinLogic(s, unwrappedIO,
1138 [&](ValueRange inputs) {
return s.pack(inputs); });
// Lowers hw::StructCreateOp as a unit-rate join whose result packs the input
// data values into the op's own result type.
1142 class StructCreateConversionPattern
1143 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1145 using HandshakeConversionPattern<
1148 hw::HWModulePortAccessor &ports)
const override {
1149 auto unwrappedIO = unwrapIO(s, bb, ports);
1150 auto structType = op.getResult().getType();
1151 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1152 return s.pack(inputs, structType);
// Lowers UnpackOp as a unit-rate fork whose per-output data values are the
// fields produced by RTLBuilder::unpack() on the input data.
1157 class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1159 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1161 hw::HWModulePortAccessor &ports)
const override {
1162 auto unwrappedIO = unwrapIO(s, bb, ports);
1163 buildUnitRateForkLogic(s, bb, unwrappedIO,
1164 [&](Value input) {
return s.unpack(input); });
// Lowers ConditionalBranchOp. Inputs are [condition, argument]; outputs are
// [true result, false result]. The argument is routed to the output selected
// by the condition value.
1168 class ConditionalBranchConversionPattern
1169 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1171 using HandshakeConversionPattern<
1172 ConditionalBranchOp>::HandshakeConversionPattern;
1174 hw::HWModulePortAccessor &ports)
const override {
1175 auto unwrappedIO = unwrapIO(s, bb, ports);
1176 auto cond = unwrappedIO.inputs[0];
1177 auto arg = unwrappedIO.inputs[1];
1178 auto trueRes = unwrappedIO.outputs[0];
1179 auto falseRes = unwrappedIO.outputs[1];
// Both the condition and the argument must be valid before either output
// can be valid.
1181 auto condArgValid = s.bAnd({cond.valid, arg.valid});
// Exactly one output is valid, depending on the condition value.
1184 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1185 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
// Both outputs see the argument data; validity decides which one is used.
1188 trueRes.data->setValue(arg.data);
1189 falseRes.data->setValue(arg.data);
// Inputs are consumed when the selected output is ready and both inputs
// are valid.
1192 auto selectedResultReady =
1193 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1194 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1195 arg.ready->setValue(condArgReady);
1196 cond.ready->setValue(condArgReady);
// Lowers an extension op TIn through buildExtendLogic; the `signExtend`
// template argument is forwarded to it unchanged.
1200 template <
typename TIn,
bool signExtend>
1201 class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1203 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1205 hw::HWModulePortAccessor &ports)
const override {
1206 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1207 this->buildExtendLogic(s, unwrappedIO, signExtend);
// Lowers arith::CmpIOp as a unit-rate join whose result is a comb::ICmpOp.
// Each arith predicate is mapped to the comb predicate of the same name.
1211 class ComparisonConversionPattern
1212 :
public HandshakeConversionPattern<arith::CmpIOp> {
1214 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1216 hw::HWModulePortAccessor &ports)
const override {
1217 auto unwrappedIO = this->unwrapIO(s, bb, ports);
// Shared builder; only the predicate differs between the cases below.
1218 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1219 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange inputs) {
1220 return s.b.create<comb::ICmpOp>(op.getLoc(), predicate, inputs[0],
1225 switch (op.getPredicate()) {
1226 case arith::CmpIPredicate::eq:
1227 return buildCompareLogic(comb::ICmpPredicate::eq);
1228 case arith::CmpIPredicate::ne:
1229 return buildCompareLogic(comb::ICmpPredicate::ne);
1230 case arith::CmpIPredicate::slt:
1231 return buildCompareLogic(comb::ICmpPredicate::slt);
1232 case arith::CmpIPredicate::ult:
1233 return buildCompareLogic(comb::ICmpPredicate::ult);
1234 case arith::CmpIPredicate::sle:
1235 return buildCompareLogic(comb::ICmpPredicate::sle);
1236 case arith::CmpIPredicate::ule:
1237 return buildCompareLogic(comb::ICmpPredicate::ule);
1238 case arith::CmpIPredicate::sgt:
1239 return buildCompareLogic(comb::ICmpPredicate::sgt);
1240 case arith::CmpIPredicate::ugt:
1241 return buildCompareLogic(comb::ICmpPredicate::ugt);
1242 case arith::CmpIPredicate::sge:
1243 return buildCompareLogic(comb::ICmpPredicate::sge);
1244 case arith::CmpIPredicate::uge:
1245 return buildCompareLogic(comb::ICmpPredicate::uge);
// Reached only for a predicate value outside the cases above.
1247 assert(
false &&
"invalid CmpIOp");
// Lowers arith::TruncIOp through buildTruncateLogic, passing the bit width
// of the op's (validated) result type as the target width.
1251 class TruncateConversionPattern
1252 :
public HandshakeConversionPattern<arith::TruncIOp> {
1254 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1256 hw::HWModulePortAccessor &ports)
const override {
1257 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1258 unsigned targetBits =
1259 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1260 buildTruncateLogic(s, unwrappedIO, targetBits);
// Lowers ControlMergeOp. A priority arbiter picks a winning input (one-hot
// encoded); the winner is held in a register until both outputs (result and
// index) have transferred, each tracked by its own "emitted" register.
1264 class ControlMergeConversionPattern
1265 :
public HandshakeConversionPattern<ControlMergeOp> {
1267 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1269 hw::HWModulePortAccessor &ports)
const override {
1270 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1271 auto resData = unwrappedIO.outputs[0];
1272 auto resIndex = unwrappedIO.outputs[1];
// The winner is one-hot encoded with one bit per input; all-zero means
// "no winner".
1275 unsigned numInputs = unwrappedIO.inputs.size();
1276 auto indexType = s.b.getIntegerType(numInputs);
1277 Value noWinner = s.constant(numInputs, 0);
1278 Value c0I1 = s.constant(1, 0);
// Register holding the winner across cycles (resets to "no winner").
1281 auto won = bb.
get(indexType);
1282 Value wonReg = s.reg(
"won_reg", won, noWinner);
// Winner of the current cycle.
1285 auto win = bb.
get(indexType);
// Asserted when both outputs are done.
1289 auto fired = bb.
get(s.b.getI1Type());
// Per-output registers remembering that the output already transferred.
1292 auto resultEmitted = bb.
get(s.b.getI1Type());
1293 Value resultEmittedReg = s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1294 auto indexEmitted = bb.
get(s.b.getI1Type());
1295 Value indexEmittedReg = s.reg(
"index_emitted_reg", indexEmitted, c0I1);
// Per-output "done" signals.
1298 auto resultDone = bb.
get(s.b.getI1Type());
1299 auto indexDone = bb.
get(s.b.getI1Type());
// OR-reductions: is there a winner now / was there one already registered?
1303 auto hasWinnerCondition = s.rOr({win});
1304 auto hadWinnerCondition = s.rOr({wonReg});
// Arbitrate among the input valids, but stick to a previously registered
// winner when there is one.
1312 DenseMap<size_t, Value> argIndexValues;
1313 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1314 noWinner, argIndexValues);
1315 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1316 win.setValue(priorityArb);
// Result output: valid while there is a winner and it has not been emitted.
1326 auto resultNotEmitted = s.bNot(resultEmittedReg);
1327 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1328 resData.valid->setValue(resultValid);
1329 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
// Index output: same scheme as the result output.
1331 auto indexNotEmitted = s.bNot(indexEmittedReg);
1332 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1333 resIndex.valid->setValue(indexValid);
// Index data: the winner's position as a 64-bit constant, selected by the
// one-hot winner.
1337 SmallVector<Value, 8> indexOutputs;
1338 for (
size_t i = 0; i < numInputs; ++i)
1339 indexOutputs.push_back(s.constant(64, i));
1341 auto indexOutput = s.ohMux(win, indexOutputs);
1342 resIndex.data->setValue(indexOutput);
// Clear the registered winner once fired; otherwise keep the current one.
1348 won.setValue(s.mux(fired, {win, noWinner}));
// An output is done when it transfers now or has been emitted earlier.
1353 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1354 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1356 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1357 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1361 fired.setValue(s.bAnd({resultDone, indexDone}));
// The emitted registers follow the done signals and are cleared on fire.
1367 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1368 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
// Only the winning input is acknowledged, and only in the cycle the merge
// fires: an input is ready when its one-hot code equals the winner.
1373 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1374 for (
auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1375 auto &indexValue = argIndexValues[i];
1376 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Builds the RTL for a handshake MergeOp: an arbiter selects one input and
// that input's data is forwarded to the single result output (outputs[0]).
1381 class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1383 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1385 hw::HWModulePortAccessor &ports)
const override {
1386 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1387 auto resData = unwrappedIO.outputs[0];
// The winner signal is numInputs bits wide; the all-zero constant of the same
// width serves as the "no winner" value.
1390 unsigned numInputs = unwrappedIO.inputs.size();
1391 auto indexType = s.b.getIntegerType(numInputs);
1392 Value noWinner = s.constant(numInputs, 0);
// `win` is a backedge: it is used below before it is driven via setValue().
1395 auto win = bb.
get(indexType);
// NOTE(review): rOr is presumably an OR-reduction, i.e. "some bit of win is
// set" -- confirm against RTLBuilder.
1398 auto hasWinnerCondition = s.rOr(win);
// buildPriorityArbiter receives the input valids and the no-winner value; it
// also receives argIndexValues, whose per-input entries are read again in the
// ready loop below.
1405 DenseMap<size_t, Value> argIndexValues;
1406 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1407 noWinner, argIndexValues);
1408 win.setValue(priorityArb);
// The result is valid whenever there is a winner; its data is selected from
// the input datas by `win`.
1415 resData.valid->setValue(hasWinnerCondition);
1416 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1421 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
// Each input's ready is the comparison of its arbiter index value against
// winnerOrDefault. NOTE(review): assumes mux yields `win` when the select is 1
// and `noWinner` otherwise -- confirm operand order in RTLBuilder::mux.
1426 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
1427 for (
auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1428 auto &indexValue = argIndexValues[i];
1429 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Builds the RTL for a handshake LoadOp.
// Port usage as named below: inputs[0] = address from user, inputs[1] = data
// from memory, inputs[2] = control; outputs[0] = data to user, outputs[1] =
// address to memory.
1434 class LoadConversionPattern
1435 :
public HandshakeConversionPattern<handshake::LoadOp> {
1437 using HandshakeConversionPattern<
1438 handshake::LoadOp>::HandshakeConversionPattern;
1440 hw::HWModulePortAccessor &ports)
const override {
1441 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1442 auto addrFromUser = unwrappedIO.inputs[0];
1443 auto dataFromMem = unwrappedIO.inputs[1];
1444 auto controlIn = unwrappedIO.inputs[2];
1445 auto dataToUser = unwrappedIO.outputs[0];
1446 auto addrToMem = unwrappedIO.outputs[1];
// Data wires pass straight through in both directions.
1448 addrToMem.data->setValue(addrFromUser.data);
1449 dataToUser.data->setValue(dataFromMem.data);
// The handshake towards memory is built by buildJoinLogic from the user
// address and the control input.
1453 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
// The return path (memory data -> user) has its valid/ready wired directly.
1457 dataToUser.valid->setValue(dataFromMem.valid);
1458 dataFromMem.ready->setValue(dataToUser.ready);
// Builds the RTL for a handshake StoreOp.
// Port usage as named below: inputs[0] = address from user, inputs[1] = data
// from user, inputs[2] = control; outputs[0] = data to memory, outputs[1] =
// address to memory.
1462 class StoreConversionPattern
1463 :
public HandshakeConversionPattern<handshake::StoreOp> {
1465 using HandshakeConversionPattern<
1466 handshake::StoreOp>::HandshakeConversionPattern;
1468 hw::HWModulePortAccessor &ports)
const override {
1469 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1470 auto addrFromUser = unwrappedIO.inputs[0];
1471 auto dataFromUser = unwrappedIO.inputs[1];
1472 auto controlIn = unwrappedIO.inputs[2];
1473 auto dataToMem = unwrappedIO.outputs[0];
1474 auto addrToMem = unwrappedIO.outputs[1];
// Both memory-facing outputs must be ready at the same time.
1477 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
// An internal handshake wire carries the joined handshake of the three
// inputs; its ready is driven by outputsReady and its valid feeds both
// outputs below.
1481 HandshakeWire joinWire(bb, s.b.getNoneType());
1482 joinWire.ready->setValue(outputsReady);
1483 OutputHandshake joinOutput = joinWire.getAsOutput();
1484 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
// Address and data values pass straight through.
1487 addrToMem.data->setValue(addrFromUser.data);
1488 dataToMem.data->setValue(dataFromUser.data);
// Both outputs share the join's valid signal.
1491 addrToMem.valid->setValue(*joinWire.valid);
1492 dataToMem.valid->setValue(*joinWire.valid);
// Builds the RTL for a handshake MemoryOp: creates a seq::HLMemOp and one
// seq read port per load port / one seq write port per store port, together
// with the handshake logic around them.
1496 class MemoryConversionPattern
1497 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1499 using HandshakeConversionPattern<
1500 handshake::MemoryOp>::HandshakeConversionPattern;
1502 hw::HWModulePortAccessor &ports)
const override {
1503 auto loc = op.getLoc();
1506 auto unwrappedIO = this->unwrapIO(s, bb, ports);
// Per-port bundles of references into unwrappedIO. The fields are declared in
// (addr, data, done) order, first for LoadPort and then for StorePort
// (assuming these are the only fields; the struct headers are not visible
// in this listing).
1508 InputHandshake &
addr;
1509 OutputHandshake &
data;
1510 OutputHandshake &done;
1513 InputHandshake &
addr;
1514 InputHandshake &
data;
1515 OutputHandshake &done;
1517 SmallVector<LoadPort, 4> loadPorts;
1518 SmallVector<StorePort, 4> storePorts;
1520 unsigned stCount = op.getStCount();
1521 unsigned ldCount = op.getLdCount();
// Load port i: address is inputs[2*stCount + i] (load inputs follow the
// 2*stCount store inputs), data is outputs[i], done is
// outputs[ldCount + stCount + i].
1522 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1523 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1524 unwrappedIO.outputs[i],
1525 unwrappedIO.outputs[ldCount + stCount + i]};
1526 loadPorts.push_back(port);
// Store port i: address is inputs[2*i + 1], data is inputs[2*i], done is
// outputs[ldCount + i].
1529 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1530 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1531 unwrappedIO.inputs[i * 2],
1532 unwrappedIO.outputs[ldCount + i]};
1533 storePorts.push_back(port);
// Zero-width constant that drives the data wire of every `done` output.
1537 auto c0I0 = s.constant(0, 0);
// Address width: ceil(log2) of the first memref dimension. Addresses are
// truncated to this width before being handed to the memory ports.
1539 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1540 auto hlmem = s.b.create<seq::HLMemOp>(
1541 loc, s.clk, s.rst,
"_handshake_memory_" + std::to_string(op.getId()),
1542 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
// Load ports: the read port is given the address valid signal (presumably as
// its enable); the address handshake is forked to the data and done outputs.
1545 for (
auto &ld : loadPorts) {
1546 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1547 auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1548 addresses, ld.addr.valid,
1550 ld.data.data->setValue(readData);
1551 ld.done.data->setValue(c0I0);
1553 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
// Store ports.
1557 for (
auto &st : storePorts) {
// A 1-bit register (reset value 0) whose output is the `done` valid signal.
1560 auto writeValidBufferMuxBE = bb.
get(s.b.getI1Type());
1561 auto writeValidBuffer =
1562 s.reg(
"writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1563 st.done.valid->
setValue(writeValidBuffer);
1564 st.done.data->setValue(c0I0);
// The buffered done token is consumed when the done output is ready.
1568 auto storeCompleted =
1569 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
// New address/data is accepted when the buffer is empty or being drained.
1573 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1574 auto emptyOrComplete =
1575 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1578 st.addr.ready->setValue(emptyOrComplete);
1579 st.data.ready->setValue(emptyOrComplete);
// A write is requested when both address and data are valid.
1582 auto writeValid = s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
// Next-state of the buffer register. NOTE(review): assumes mux yields
// writeValid when emptyOrComplete is 1 and holds writeValidBuffer otherwise --
// confirm operand order in RTLBuilder::mux.
1588 writeValidBufferMuxBE.setValue(
1589 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
// The write port takes the truncated address, the store data and writeValid.
1593 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1594 s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1595 st.data.data, writeValid,
// Builds the RTL for a handshake SinkOp: the single input's ready signal is
// tied to the 1-bit constant 1.
1601 class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1603 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1605 hw::HWModulePortAccessor &ports)
const override {
1606 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1608 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
// Builds the RTL for a handshake SourceOp: the single output's valid signal is
// tied to the 1-bit constant 1 and its data wire to a zero-width constant.
1612 class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1614 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1616 hw::HWModulePortAccessor &ports)
const override {
1617 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1619 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
1620 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
// Builds the RTL for a handshake ConstantOp: the input's valid/ready handshake
// is wired straight to the output, and the output data is the op's integer
// "value" attribute.
1624 class ConstantConversionPattern
1625 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1627 using HandshakeConversionPattern<
1628 handshake::ConstantOp>::HandshakeConversionPattern;
1630 hw::HWModulePortAccessor &ports)
const override {
1631 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1632 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1633 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
// The attribute is fetched as an IntegerAttr; no other attribute kind is
// handled here.
1634 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1635 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
// Builds the RTL for a handshake BufferOp as a chain of SeqBufferStage
// instances, one per slot, and connects the last stage to the output.
1639 class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1641 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1643 hw::HWModulePortAccessor &ports)
const override {
1644 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1645 auto input = unwrappedIO.inputs[0];
1646 auto output = unwrappedIO.outputs[0];
1647 InputHandshake lastStage;
1648 SmallVector<int64_t> initValues;
// Initial values are optional on the op; when absent the list stays empty.
1651 if (op.getInitValues())
1652 initValues = op.getInitValueArray();
1655 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1656 op.getNumSlots(), input, output, initValues);
// Hook the final stage up to the op's output handshake.
1659 output.data->setValue(lastStage.data);
1660 output.valid->setValue(lastStage.valid);
1661 lastStage.ready->setValue(output.ready);
// One buffer slot. The constructor builds the stage's logic between
// `preStage` (the upstream handshake) and `currentStage` (exposed through
// getOutput()). `index` is used to give registers and signals unique names;
// `initValue`, when present, makes the stage start out valid with that data.
1664 struct SeqBufferStage {
1665 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1666 RTLBuilder &s,
size_t index,
1667 std::optional<int64_t> initValue)
1668 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
// Zero constant of the data type, used as the default register reset value.
1671 c0s = createZeroDataConst(s, s.loc, dataType);
// The stage's ready is a backedge, to be driven later by whoever consumes
// this stage.
1672 currentStage.ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
// The valid register resets to 1 exactly when an initial value was supplied.
1674 auto hasInitValue = s.constant(1, initValue.has_value());
1675 auto validBE = bb.
get(s.b.getI1Type());
1676 auto validReg = s.reg(getRegName(
"valid"), validBE, hasInitValue);
1677 auto readyBE = bb.
get(s.b.getI1Type());
// Data reset value: the supplied initial value, or zero when there is none.
1679 Value initValueCs = c0s;
1680 if (initValue.has_value())
1681 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1686 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1687 buildControlBufferLogic(validReg, readyBE, dataReg);
// Returns the register name "<name><index>_reg" as a StringAttr.
1690 StringAttr getRegName(StringRef name) {
1691 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
// Builds the second register layer of the stage (the "ready" and "ctrl_data"
// registers) and drives currentStage.valid / currentStage.data from it.
1694 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1696 auto c0I1 = s.constant(1, 0);
1697 auto readyRegWire = bb.
get(s.b.getI1Type());
1698 auto readyReg = s.reg(getRegName(
"ready"), readyRegWire, c0I1);
// NOTE(review): mux semantics assumed to be "second operand when select is
// 1"; under that reading the stage's valid is validReg while readyReg is 0
// and readyReg itself (i.e. 1) otherwise -- confirm against RTLBuilder::mux.
1702 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1703 "controlValid" + std::to_string(index));
1706 auto notReadyReg = s.bNot(readyReg);
// Next-state logic for readyReg, expressed through two conditions:
//   neitherReady: successor not ready and readyReg clear,
//   bothReady:    successor ready and readyReg set (selects the constant 0).
1709 auto succNotReady = s.bNot(*currentStage.ready);
1710 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1711 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1712 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
1715 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1716 readyRegWire.setValue(resetSignal);
// The data path mirrors the control path with a "ctrl_data" register; the
// stage's output data is selected between dataReg and ctrlDataReg by readyReg.
1719 auto ctrlDataRegBE = bb.
get(dataType);
1720 auto ctrlDataReg = s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
1721 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1722 currentStage.data = dataResult;
1724 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1725 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1726 ctrlDataRegBE.
setValue(dataResetSignal);
// Builds the first register layer of the stage (the "data" register and the
// next-state of the valid register) and drives preStage.ready.
1729 Value buildDataBufferLogic(Value validReg, Value initValue,
// The predecessor may push when the valid register is empty or readyBE is set.
1733 auto notValidReg = s.bNot(validReg);
1734 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1735 preStage.ready->setValue(emptyOrReady);
// NOTE(review): presumably "take preStage.valid when emptyOrReady, otherwise
// hold validReg" -- confirm operand order in RTLBuilder::mux.
1741 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
// The data register's input is chosen the same way between the dataRegBE
// backedge and preStage.data; its reset value is initValue.
1747 auto dataRegBE = bb.
get(dataType);
1749 s.reg(getRegName(
"data"),
1750 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
// Returns (a copy of) the handshake produced by this stage.
1755 InputHandshake getOutput() {
return currentStage; }
1758 InputHandshake &preStage;
1759 InputHandshake currentStage;
// Chains `size` SeqBufferStage instances starting from `input` and returns the
// handshake of the final stage. Stage i gets initValues[i] as its initial
// value when i < initValues.size(), and no initial value otherwise.
1768 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1769 Type dataType,
unsigned size,
1770 InputHandshake &input,
1771 OutputHandshake &output,
1772 llvm::ArrayRef<int64_t> initValues)
const {
1775 InputHandshake currentStage = input;
1777 for (
unsigned i = 0; i < size; ++i) {
1778 bool isInitialized = i < initValues.size();
1780 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1781 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1785 return currentStage;
// Builds the RTL for arith::IndexCastOp by comparing the bit widths of the
// (toValidType-converted) operand and result types: a narrower result is
// produced by truncation, anything else goes through buildExtendLogic with
// `true` as its last argument.
1789 class IndexCastConversionPattern
1790 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1792 using HandshakeConversionPattern<
1793 arith::IndexCastOp>::HandshakeConversionPattern;
1795 hw::HWModulePortAccessor &ports)
const override {
1796 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1797 unsigned sourceBits =
1798 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1799 unsigned targetBits =
1800 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1801 if (targetBits < sourceBits)
1802 buildTruncateLogic(s, unwrappedIO, targetBits);
// NOTE(review): the boolean is presumably the sign-extension flag -- confirm
// against buildExtendLogic's signature.
1804 buildExtendLogic(s, unwrappedIO,
true);
// Conversion pattern, templated on the op type T, that replaces the matched op
// with an hw::InstanceOp of `implModule`. The instance name comes from the
// lowering state's name uniquer.
1808 template <
typename T>
1811 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1812 MLIRContext *context, OpBuilder &submoduleBuilder,
1813 HandshakeLoweringState &ls)
1815 submoduleBuilder(submoduleBuilder), ls(ls) {}
1816 using OpAdaptor =
typename T::Adaptor;
1819 matchAndRewrite(T op, OpAdaptor adaptor,
1820 ConversionPatternRewriter &rewriter)
const override {
// Instance operands are the adaptor's operands; addSequentialIOOperandsIfNeeded
// is given the list by name and may extend it.
1830 llvm::SmallVector<Value> operands = adaptor.getOperands();
1831 addSequentialIOOperandsIfNeeded(op, operands);
1832 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1833 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
// Both members are references; the referenced objects must outlive the pattern.
1838 OpBuilder &submoduleBuilder;
1839 HandshakeLoweringState &ls;
// Conversion pattern for the function op: creates a module named after the op
// (an external-module path when op.isExternal(), otherwise a module whose body
// receives the op's inlined first block), redirects uses of any predeclared
// module to it, and erases the original op.
1844 using OpConversionPattern::OpConversionPattern;
1848 ConversionPatternRewriter &rewriter)
const override {
1852 HWModuleLike hwModule;
1853 if (op.isExternal()) {
1855 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1858 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
// The last two block arguments of the new module are excluded from the values
// that replace the function body's arguments (presumably clock and reset --
// confirm against the port construction, which is not visible here).
1859 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1860 rewriter.inlineBlockBefore(&op.getBody().front(),
1861 hwModuleOp.getBodyBlock()->getTerminator(),
1863 hwModule = hwModuleOp;
// If a predeclaration symbol exists in the parent, redirect all of its uses to
// the new module and erase the predeclaration.
1870 auto *parentOp = op->getParentOp();
1871 auto *predeclModule =
1872 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1873 if (predeclModule) {
1874 if (failed(SymbolTable::replaceAllSymbolUses(
1875 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1877 rewriter.eraseOp(predeclModule);
1881 rewriter.eraseOp(op);
// Lowers one handshake function `op`: sets up the instance-name uniquer and
// the lowering state, registers every op conversion pattern, and runs a
// partial conversion over `op` against `target`. Emits "error during
// conversion" on `op` if the conversion fails.
1893 ConversionTarget &target,
1895 OpBuilder &moduleBuilder) {
// Instance naming: ops carrying an integer "handshake_id" attribute get an
// "_id<N>" suffix; all others get a per-name running counter appended.
1897 std::map<std::string, unsigned> instanceNameCntr;
1898 NameUniquer instanceUniquer = [&](Operation *op) {
1900 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1903 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1906 instName += std::to_string(instanceNameCntr[instName]++);
1911 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1913 RewritePatternSet
patterns(op.getContext());
1914 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1916 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1917 SyncConversionPattern>(typeConverter, op.getContext(),
// Unit-rate arith -> comb mappings. Signedness must be preserved: the unsigned
// arith ops (DivUI, RemUI, ShRUI) map to the unsigned comb ops (DivU, ModU,
// ShrU) and the signed ones (DivSI, RemSI, ShRSI) to the signed comb ops
// (DivS, ModS, ShrS).
1922 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1923 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1924 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1925 UnitRateConversionPattern<arith::DivUIOp, comb::DivUOp>,
1926 UnitRateConversionPattern<arith::DivSIOp, comb::DivSOp>,
1927 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1928 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1929 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1930 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1931 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1932 UnitRateConversionPattern<arith::ShLIOp, comb::ShlOp>,
1933 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1934 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1935 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1937 StructCreateConversionPattern,
1939 ConditionalBranchConversionPattern, MuxConversionPattern,
1940 PackConversionPattern, UnpackConversionPattern,
1941 ComparisonConversionPattern, BufferConversionPattern,
1942 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1943 MergeConversionPattern, ControlMergeConversionPattern,
1944 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1945 InstanceConversionPattern,
1947 ExtendConversionPattern<arith::ExtUIOp,
false>,
1948 ExtendConversionPattern<arith::ExtSIOp,
true>,
1949 TruncateConversionPattern, IndexCastConversionPattern>(
1950 typeConverter, op.getContext(), moduleBuilder, ls);
1952 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1953 return op->emitOpError() <<
"error during conversion";
// Module-level pass that lowers handshake functions to HW.
// Steps visible in this listing: report functions that fail the one-use
// check, collect a function order in `sortedFuncs`, convert each function with
// convertFuncOp, then run a second partial conversion on the whole module
// with ESIInstanceConversionPattern. Any failure signals pass failure.
1958 class HandshakeToHWPass
1959 :
public circt::impl::HandshakeToHWBase<HandshakeToHWPass> {
1961 void runOnOperation()
override {
1962 mlir::ModuleOp mod = getOperation();
1968 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1969 "are used exactly once. Remember to run the "
1970 "fork/sink materialization pass before HW lowering.";
1971 signalPassFailure();
1977 std::string topLevel;
1979 SmallVector<std::string> sortedFuncs;
1981 signalPassFailure();
1985 ESITypeConverter typeConverter;
1986 ConversionTarget target(getContext());
// The handshake and arith dialects are declared illegal for the conversion.
1992 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
// Sub-modules are created at the start of the module body. Functions are
// visited in the reverse of the order held in sortedFuncs.
1998 OpBuilder submoduleBuilder(mod.getContext());
1999 submoduleBuilder.setInsertionPointToStart(mod.getBody());
2000 for (
auto &funcName : llvm::reverse(sortedFuncs)) {
2002 assert(funcOp &&
"handshake.func not found in module!");
2004 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
2005 signalPassFailure();
2014 return signalPassFailure();
// Second conversion round over the whole module.
2020 RewritePatternSet
patterns(mod.getContext());
2021 patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
2023 if (failed(applyPartialConversion(mod, target, std::move(
patterns)))) {
2024 mod->emitOpError() <<
"error during conversion";
2025 signalPassFailure();
// Returns a newly allocated HandshakeToHWPass.
2032 return std::make_unique<HandshakeToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneType's from the input.
static Type getOperandDataType(Value op)
Extracts the data-carrying type of op's type: the inner type when that type is an ESI channel.
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfo's which defines the HW interface of the to-be-converted op.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
This stores lookup tables to make manipulating and working with the IR more efficient.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Channels are the basic communication primitives.
const Type * getInner() const
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
Returns the hw::ModulePortInfo that corresponds to the given handshake operation and its input and output types.
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
Attribute name for the name of a predeclaration of the to-be-lowered hw.module from a handshake function.
esi::ChannelType esiWrapper(Type t)
Wraps a type into an ESI ChannelType type.
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
Checks all block arguments and values within op to ensure that all values have exactly one use.
Type toValidType(Type t)
Converts 't' into a valid HW type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.
void eraseInput(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()