14 #include "../PassDetail.h"
25 #include "mlir/Dialect/Arith/IR/Arith.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/MathExtras.h"
35 using namespace circt;
49 struct HandshakeLoweringState {
50 ModuleOp parentModule;
57 class ESITypeConverter :
public TypeConverter {
60 addConversion([](Type type) -> Type {
return esiWrapper(type); });
62 addTargetMaterialization(
63 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
65 mlir::Location loc) -> std::optional<mlir::Value> {
71 addSourceMaterialization(
72 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
74 mlir::Location loc) -> std::optional<mlir::Value> {
90 std::string subModuleName = oldOp->getName().getStringRef().str();
91 std::replace(subModuleName.begin(), subModuleName.end(),
'.',
'_');
96 auto callOp = dyn_cast<handshake::InstanceOp>(op);
104 auto opType = op.getType();
105 if (
auto channelType = opType.dyn_cast<esi::ChannelType>())
106 return channelType.getInner();
112 SmallVector<Type> filterRes;
113 llvm::copy_if(input, std::back_inserter(filterRes),
114 [](Type type) {
return !type.isa<NoneType>(); });
122 return TypeSwitch<Operation *, DiscriminatingTypes>(op)
123 .Case<MemoryOp, ExternalMemoryOp>([&](
auto memOp) {
125 {memOp.getMemRefType().getElementType()}};
130 std::vector<Type> inTypes, outTypes;
131 llvm::transform(op->getOperands(), std::back_inserter(inTypes),
133 llvm::transform(op->getResults(), std::back_inserter(outTypes),
145 std::string typeName;
147 if (type.isIntOrIndex()) {
148 if (
auto indexType = type.dyn_cast<IndexType>())
149 typeName +=
"_ui" + std::to_string(indexType.kInternalStorageBitWidth);
150 else if (type.isSignedInteger())
151 typeName +=
"_si" + std::to_string(type.getIntOrFloatBitWidth());
153 typeName +=
"_ui" + std::to_string(type.getIntOrFloatBitWidth());
154 }
else if (
auto tupleType = type.dyn_cast<TupleType>()) {
155 typeName +=
"_tuple";
158 }
else if (
auto structType = type.dyn_cast<hw::StructType>()) {
159 typeName +=
"_struct";
160 for (
auto element : structType.getElements())
161 typeName +=
"_" + element.name.str() +
getTypeName(loc, element.type);
163 emitError(loc) <<
"unsupported data type '" << type <<
"'";
170 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
171 return instanceOp.getModule().str();
176 if (
auto constOp = dyn_cast<handshake::ConstantOp>(oldOp)) {
177 if (
auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>()) {
178 auto intType = intAttr.getType();
180 if (intType.isSignedInteger())
181 subModuleName +=
"_c" + std::to_string(intAttr.getSInt());
182 else if (intType.isUnsignedInteger())
183 subModuleName +=
"_c" + std::to_string(intAttr.getUInt());
185 subModuleName +=
"_c" + std::to_string((uint64_t)intAttr.getInt());
187 oldOp->emitError(
"unsupported constant type");
192 if (!inTypes.empty())
193 subModuleName +=
"_in";
194 for (
auto inType : inTypes)
195 subModuleName +=
getTypeName(oldOp->getLoc(), inType);
197 if (!outTypes.empty())
198 subModuleName +=
"_out";
199 for (
auto outType : outTypes)
200 subModuleName +=
getTypeName(oldOp->getLoc(), outType);
203 if (
auto memOp = dyn_cast<handshake::MemoryOp>(oldOp))
204 subModuleName +=
"_id" + std::to_string(memOp.getId());
207 if (
auto comOp = dyn_cast<mlir::arith::CmpIOp>(oldOp))
208 subModuleName +=
"_" + stringifyEnum(comOp.getPredicate()).str();
211 if (
auto bufferOp = dyn_cast<handshake::BufferOp>(oldOp)) {
212 subModuleName +=
"_" + std::to_string(bufferOp.getNumSlots()) +
"slots";
213 if (bufferOp.isSequential())
214 subModuleName +=
"_seq";
216 subModuleName +=
"_fifo";
218 if (
auto initValues = bufferOp.getInitValues()) {
219 subModuleName +=
"_init";
220 for (
const Attribute e : *initValues) {
221 assert(e.isa<IntegerAttr>());
223 "_" + std::to_string(e.dyn_cast<IntegerAttr>().getInt());
229 if (
auto ctrlInterface = dyn_cast<handshake::ControlInterface>(oldOp);
230 ctrlInterface && ctrlInterface.isControl()) {
232 subModuleName +=
"_" + std::to_string(oldOp->getNumOperands()) +
"ins_" +
233 std::to_string(oldOp->getNumResults()) +
"outs";
234 subModuleName +=
"_ctrl";
237 (!inTypes.empty() || !outTypes.empty()) &&
238 "Insufficient discriminating type info generated for the operation!");
241 return subModuleName;
253 if (
auto mod = parentModule.lookupSymbol<HWModuleOp>(modName))
255 if (
auto mod = parentModule.lookupSymbol<HWModuleExternOp>(modName))
262 HWModuleLike targetModule;
263 if (
auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp))
268 if (isa<handshake::InstanceOp>(oldOp))
270 "handshake.instance target modules should always have been lowered "
271 "before the modules that reference them!");
281 static llvm::SmallVector<hw::detail::FieldInfo>
283 llvm::SmallVector<hw::detail::FieldInfo> fieldInfo;
284 for (
auto port : portInfo)
285 fieldInfo.push_back({port.name, port.type});
293 auto ports = mod.getPortList();
294 auto *ctx = mod.getContext();
297 llvm::DenseMap<unsigned, Value> memrefPorts;
298 for (
auto [i, arg] : llvm::enumerate(mod.getBodyBlock()->getArguments())) {
299 auto channel = arg.getType().dyn_cast<esi::ChannelType>();
300 if (channel && channel.getInner().isa<MemRefType>())
301 memrefPorts[i] = arg;
304 if (memrefPorts.empty())
309 auto getMemoryIOInfo = [&](Location loc, Twine portName,
unsigned argIdx,
310 ArrayRef<hw::PortInfo> info,
314 hw::PortInfo{{b.getStringAttr(portName), type, direction}, argIdx};
318 for (
auto [i, arg] : memrefPorts) {
320 auto memName = mod.getArgName(i);
323 auto extmemInstance = cast<hw::InstanceOp>(*arg.getUsers().begin());
325 cast<hw::HWModuleExternOp>(extmemInstance.getReferencedModuleSlow());
326 auto portInfo = extmemMod.getPortList();
332 portInfo.eraseInput(0);
335 SmallVector<PortInfo>
outputs(portInfo.getOutputs());
337 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_in", i,
outputs,
339 mod.insertPorts({{i, inPortInfo}}, {});
340 auto newInPort = mod.getArgumentForInput(i);
342 b.setInsertionPointToStart(mod.getBodyBlock());
343 auto newInPortExploded = b.create<hw::StructExplodeOp>(
344 arg.getLoc(), extmemMod.getOutputTypes(), newInPort);
345 extmemInstance.replaceAllUsesWith(newInPortExploded.getResults());
349 unsigned outArgI = mod.getNumOutputPorts();
350 SmallVector<PortInfo>
inputs(portInfo.getInputs());
352 getMemoryIOInfo(arg.getLoc(), memName.strref() +
"_out", outArgI,
355 auto memOutputArgs = extmemInstance.getOperands().drop_front();
356 b.setInsertionPoint(mod.getBodyBlock()->getTerminator());
357 auto memOutputStruct = b.create<hw::StructCreateOp>(
358 arg.getLoc(), outPortInfo.type, memOutputArgs);
359 mod.appendOutputs({{outPortInfo.name, memOutputStruct}});
364 extmemInstance.erase();
368 mod.modifyPorts( {}, {},
379 struct InputHandshake {
381 std::shared_ptr<Backedge> ready;
387 struct OutputHandshake {
388 std::shared_ptr<Backedge> valid;
390 std::shared_ptr<Backedge>
data;
395 struct HandshakeWire {
397 MLIRContext *ctx = dataType.getContext();
399 valid = std::make_shared<Backedge>(bb.
get(i1Type));
400 ready = std::make_shared<Backedge>(bb.
get(i1Type));
401 data = std::make_shared<Backedge>(bb.
get(dataType));
406 InputHandshake getAsInput() {
return {*valid, ready, *
data}; }
407 OutputHandshake getAsOutput() {
return {valid, *ready,
data}; }
409 std::shared_ptr<Backedge> valid;
410 std::shared_ptr<Backedge> ready;
411 std::shared_ptr<Backedge>
data;
414 template <
typename T,
typename TInner>
415 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
416 llvm::function_ref<T(TInner &)> extractor) {
417 llvm::SmallVector<T> result;
418 llvm::transform(container, std::back_inserter(result), extractor);
422 llvm::SmallVector<InputHandshake>
inputs;
423 llvm::SmallVector<OutputHandshake>
outputs;
425 llvm::SmallVector<Value> getInputValids() {
426 return extractValues<Value, InputHandshake>(
427 inputs, [](
auto &hs) {
return hs.valid; });
429 llvm::SmallVector<std::shared_ptr<Backedge>> getInputReadys() {
430 return extractValues<std::shared_ptr<Backedge>, InputHandshake>(
431 inputs, [](
auto &hs) {
return hs.ready; });
433 llvm::SmallVector<Value> getInputDatas() {
434 return extractValues<Value, InputHandshake>(
435 inputs, [](
auto &hs) {
return hs.data; });
437 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputValids() {
438 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
439 outputs, [](
auto &hs) {
return hs.valid; });
441 llvm::SmallVector<Value> getOutputReadys() {
442 return extractValues<Value, OutputHandshake>(
443 outputs, [](
auto &hs) {
return hs.ready; });
445 llvm::SmallVector<std::shared_ptr<Backedge>> getOutputDatas() {
446 return extractValues<std::shared_ptr<Backedge>, OutputHandshake>(
447 outputs, [](
auto &hs) {
return hs.data; });
455 RTLBuilder(hw::ModulePortInfo info, OpBuilder &
builder, Location loc,
456 Value clk = Value(), Value rst = Value())
457 : info(std::move(info)), b(
builder), loc(loc),
clk(
clk), rst(rst) {}
459 Value constant(
const APInt &apv, std::optional<StringRef> name = {}) {
462 bool isZeroWidth = apv.getBitWidth() == 0;
464 auto it = constants.find(apv);
465 if (it != constants.end())
469 auto cval = b.create<hw::ConstantOp>(loc, apv);
471 constants[apv] = cval;
475 Value constant(
unsigned width, int64_t value,
476 std::optional<StringRef> name = {}) {
477 return constant(APInt(
width, value));
479 std::pair<Value, Value>
wrap(Value data, Value valid,
480 std::optional<StringRef> name = {}) {
481 auto wrapOp = b.create<esi::WrapValidReadyOp>(loc,
data, valid);
482 return {wrapOp.getResult(0), wrapOp.getResult(1)};
484 std::pair<Value, Value>
unwrap(Value channel, Value ready,
485 std::optional<StringRef> name = {}) {
486 auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
487 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
491 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
492 Value rst = Value()) {
493 Value resolvedClk =
clk ?
clk : this->
clk;
494 Value resolvedRst = rst ? rst : this->rst;
496 "No global clock provided to this RTLBuilder - a clock "
497 "signal must be provided to the reg(...) function.");
499 "No global reset provided to this RTLBuilder - a reset "
500 "signal must be provided to the reg(...) function.");
502 return b.create<seq::CompRegOp>(loc, in.getType(), in, resolvedClk, name,
503 resolvedRst, rstValue, hw::InnerSymAttr());
506 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
507 std::optional<StringRef> name = {}) {
508 return b.create<comb::ICmpOp>(loc, predicate, lhs, rhs);
511 Value buildNamedOp(llvm::function_ref<Value()> f,
512 std::optional<StringRef> name) {
515 Operation *op = v.getDefiningOp();
516 if (name.has_value()) {
517 op->setAttr(
"sv.namehint", b.getStringAttr(*name));
518 nameAttr = b.getStringAttr(*name);
524 Value bAnd(ValueRange values, std::optional<StringRef> name = {}) {
526 [&]() {
return b.create<comb::AndOp>(loc, values,
false); }, name);
529 Value bOr(ValueRange values, std::optional<StringRef> name = {}) {
531 [&]() {
return b.create<comb::OrOp>(loc, values,
false); }, name);
535 Value bNot(Value value, std::optional<StringRef> name = {}) {
536 auto allOnes = constant(value.getType().getIntOrFloatBitWidth(), -1);
537 std::string inferedName;
541 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
542 inferedName = (
"not_" +
valueName.getValue()).str();
548 [&]() {
return b.create<comb::XorOp>(loc, value, allOnes); }, name);
550 return b.createOrFold<comb::XorOp>(loc, value, allOnes,
false);
553 Value shl(Value value, Value shift, std::optional<StringRef> name = {}) {
555 [&]() {
return b.create<comb::ShlOp>(loc, value, shift); }, name);
558 Value
concat(ValueRange values, std::optional<StringRef> name = {}) {
559 return buildNamedOp([&]() {
return b.create<comb::ConcatOp>(loc, values); },
564 Value pack(ValueRange values, Type structType = Type(),
565 std::optional<StringRef> name = {}) {
569 [&]() {
return b.create<hw::StructCreateOp>(loc, structType, values); },
574 ValueRange unpack(Value value) {
575 auto structType = value.getType().cast<hw::StructType>();
576 llvm::SmallVector<Type> innerTypes;
577 structType.getInnerTypes(innerTypes);
578 return b.create<hw::StructExplodeOp>(loc, innerTypes, value).getResults();
581 llvm::SmallVector<Value> toBits(Value v, std::optional<StringRef> name = {}) {
582 llvm::SmallVector<Value>
bits;
583 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
584 bits.push_back(b.create<comb::ExtractOp>(loc, v, i, 1));
589 Value rOr(Value v, std::optional<StringRef> name = {}) {
590 return buildNamedOp([&]() {
return bOr(toBits(v)); }, name);
594 Value extract(Value v,
unsigned lo,
unsigned hi,
595 std::optional<StringRef> name = {}) {
596 unsigned width = hi - lo + 1;
598 [&]() {
return b.create<comb::ExtractOp>(loc, v, lo,
width); }, name);
602 Value truncate(Value value,
unsigned width,
603 std::optional<StringRef> name = {}) {
604 return extract(value, 0,
width - 1, name);
607 Value zext(Value value,
unsigned outWidth,
608 std::optional<StringRef> name = {}) {
609 unsigned inWidth = value.getType().getIntOrFloatBitWidth();
610 assert(inWidth <= outWidth &&
"zext: input width must be <- output width.");
611 if (inWidth == outWidth)
613 auto c0 = constant(outWidth - inWidth, 0);
614 return concat({c0, value}, name);
617 Value sext(Value value,
unsigned outWidth,
618 std::optional<StringRef> name = {}) {
623 Value bit(Value v,
unsigned index, std::optional<StringRef> name = {}) {
624 return extract(v, index, index, name);
628 Value arrayCreate(ValueRange values, std::optional<StringRef> name = {}) {
630 [&]() {
return b.create<hw::ArrayCreateOp>(loc, values); }, name);
634 Value arrayGet(Value array, Value index, std::optional<StringRef> name = {}) {
636 [&]() {
return b.create<hw::ArrayGetOp>(loc, array, index); }, name);
642 Value mux(Value index, ValueRange values,
643 std::optional<StringRef> name = {}) {
644 if (values.size() == 2)
645 return b.create<comb::MuxOp>(loc, index, values[1], values[0]);
647 return arrayGet(arrayCreate(values), index, name);
652 Value ohMux(Value index, ValueRange
inputs) {
654 unsigned numInputs =
inputs.size();
655 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
656 "one-hot select can't mux inputs");
660 auto dataType =
inputs[0].getType();
662 dataType.isa<NoneType>() ? 0 : dataType.getIntOrFloatBitWidth();
663 Value muxValue = constant(
width, 0);
666 for (
size_t i = numInputs - 1; i != 0; --i) {
668 Value selectBit = bit(index, i);
669 muxValue = mux(selectBit, {muxValue, input});
675 hw::ModulePortInfo info;
679 DenseMap<APInt, Value> constants;
684 static Value createZeroDataConst(RTLBuilder &s, Location loc, Type type) {
685 return TypeSwitch<Type, Value>(type)
686 .Case<NoneType>([&](NoneType) {
return s.constant(0, 0); })
687 .Case<IntType, IntegerType>([&](
auto type) {
688 return s.constant(type.getIntOrFloatBitWidth(), 0);
690 .Case<hw::StructType>([&](
auto structType) {
691 SmallVector<Value> zeroValues;
692 for (
auto field : structType.getElements())
693 zeroValues.push_back(createZeroDataConst(s, loc, field.type));
694 return s.b.create<hw::StructCreateOp>(loc, structType, zeroValues);
696 .Default([&](Type) -> Value {
697 emitError(loc) <<
"unsupported type for zero value: " << type;
704 addSequentialIOOperandsIfNeeded(Operation *op,
705 llvm::SmallVectorImpl<Value> &operands) {
709 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
711 parent.getArgumentForInput(parent.getNumInputPorts() - 2));
713 parent.getArgumentForInput(parent.getNumInputPorts() - 1));
717 template <
typename T>
720 HandshakeConversionPattern(ESITypeConverter &typeConverter,
721 MLIRContext *context, OpBuilder &submoduleBuilder,
722 HandshakeLoweringState &ls)
724 submoduleBuilder(submoduleBuilder), ls(ls) {}
726 using OpAdaptor =
typename T::Adaptor;
729 matchAndRewrite(T op, OpAdaptor adaptor,
730 ConversionPatternRewriter &rewriter)
const override {
739 submoduleBuilder.setInsertionPoint(op->getParentOp());
740 implModule = submoduleBuilder.create<hw::HWModuleOp>(
742 portInfo, [&](OpBuilder &b, hw::HWModulePortAccessor &ports) {
746 if (op->template hasTrait<mlir::OpTrait::HasClock>()) {
747 clk = ports.getInput(
"clock");
748 rst = ports.getInput(
"reset");
752 RTLBuilder s(ports.getPortList(), b, op.getLoc(), clk, rst);
758 llvm::SmallVector<Value> operands = adaptor.getOperands();
759 addSequentialIOOperandsIfNeeded(op, operands);
760 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
761 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
766 hw::HWModulePortAccessor &ports)
const = 0;
773 hw::HWModulePortAccessor &ports)
const {
774 UnwrappedIO unwrapped;
775 for (
auto port : ports.getInputs()) {
776 if (!isa<esi::ChannelType>(port.getType()))
779 auto ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
780 auto [
data, valid] = s.unwrap(port, *ready);
784 unwrapped.inputs.push_back(hs);
786 for (
auto &outputInfo : ports.getPortList().getOutputs()) {
787 esi::ChannelType channelType =
788 dyn_cast<esi::ChannelType>(outputInfo.type);
794 auto valid = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
795 auto [dataCh, ready] = s.wrap(*data, *valid);
799 ports.setOutput(outputInfo.name, dataCh);
800 unwrapped.outputs.push_back(hs);
805 void setAllReadyWithCond(RTLBuilder &s, ArrayRef<InputHandshake>
inputs,
806 OutputHandshake &output, Value cond)
const {
807 auto validAndReady = s.bAnd({output.ready, cond});
808 for (
auto &input :
inputs)
809 input.ready->setValue(validAndReady);
812 void buildJoinLogic(RTLBuilder &s, ArrayRef<InputHandshake>
inputs,
813 OutputHandshake &output)
const {
814 llvm::SmallVector<Value> valids;
815 for (
auto &input :
inputs)
816 valids.push_back(input.valid);
817 Value allValid = s.bAnd(valids);
818 output.valid->setValue(allValid);
819 setAllReadyWithCond(s,
inputs, output, allValid);
825 void buildMuxLogic(RTLBuilder &s, UnwrappedIO &unwrapped,
826 InputHandshake &select)
const {
828 size_t numInputs = unwrapped.inputs.size();
829 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
830 Value truncatedSelect =
831 select.data.getType().getIntOrFloatBitWidth() > selectWidth
832 ? s.truncate(select.data, selectWidth)
836 auto selectZext = s.zext(truncatedSelect, numInputs);
837 auto select1h = s.shl(s.constant(numInputs, 1), selectZext);
838 auto &res = unwrapped.outputs[0];
841 auto selectedInputValid =
842 s.mux(truncatedSelect, unwrapped.getInputValids());
844 auto selAndInputValid = s.bAnd({selectedInputValid, select.valid});
845 res.valid->setValue(selAndInputValid);
846 auto resValidAndReady = s.bAnd({selAndInputValid, res.ready});
849 select.ready->setValue(resValidAndReady);
852 for (
auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
854 auto isSelected = s.bit(select1h, inIdx);
858 auto activeAndResultValidAndReady =
859 s.bAnd({isSelected, resValidAndReady});
860 in.ready->setValue(activeAndResultValidAndReady);
864 res.data->setValue(s.mux(truncatedSelect, unwrapped.getInputDatas()));
869 void buildForkLogic(RTLBuilder &s,
BackedgeBuilder &bb, InputHandshake &input,
870 ArrayRef<OutputHandshake>
outputs)
const {
871 auto c0I1 = s.constant(1, 0);
872 llvm::SmallVector<Value> doneWires;
873 for (
auto [i, output] : llvm::enumerate(
outputs)) {
874 auto doneBE = bb.
get(s.b.getI1Type());
875 auto emitted = s.bAnd({doneBE, s.bNot(*input.ready)});
876 auto emittedReg = s.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
877 auto outValid = s.bAnd({s.bNot(emittedReg), input.valid});
879 auto validReady = s.bAnd({output.ready, outValid});
880 auto done = s.bOr({validReady, emittedReg},
"done" + std::to_string(i));
881 doneBE.setValue(done);
882 doneWires.push_back(done);
884 input.ready->setValue(s.bAnd(doneWires,
"allDone"));
890 void buildUnitRateJoinLogic(
891 RTLBuilder &s, UnwrappedIO &unwrappedIO,
892 llvm::function_ref<Value(ValueRange)> unitBuilder)
const {
893 assert(unwrappedIO.outputs.size() == 1 &&
894 "Expected exactly one output for unit-rate join actor");
896 this->buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
899 auto unitRes = unitBuilder(unwrappedIO.getInputDatas());
900 unwrappedIO.outputs[0].data->setValue(unitRes);
903 void buildUnitRateForkLogic(
905 llvm::function_ref<llvm::SmallVector<Value>(Value)> unitBuilder)
const {
906 assert(unwrappedIO.inputs.size() == 1 &&
907 "Expected exactly one input for unit-rate fork actor");
909 this->buildForkLogic(s, bb, unwrappedIO.inputs[0], unwrappedIO.outputs);
912 auto unitResults = unitBuilder(unwrappedIO.inputs[0].data);
913 assert(unitResults.size() == unwrappedIO.outputs.size() &&
914 "Expected unit builder to return one result per output");
915 for (
auto [res, outport] : llvm::zip(unitResults, unwrappedIO.outputs))
916 outport.data->setValue(res);
919 void buildExtendLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
920 bool signExtend)
const {
922 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
923 .getIntOrFloatBitWidth();
924 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange
inputs) {
926 return s.sext(
inputs[0], outWidth);
927 return s.zext(
inputs[0], outWidth);
931 void buildTruncateLogic(RTLBuilder &s, UnwrappedIO &unwrappedIO,
932 unsigned targetWidth)
const {
934 toValidType(
static_cast<Value
>(*unwrappedIO.outputs[0].data).getType())
935 .getIntOrFloatBitWidth();
936 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange
inputs) {
937 return s.truncate(
inputs[0], outWidth);
942 static size_t getNumIndexBits(uint64_t numValues) {
943 return numValues > 1 ? llvm::Log2_64_Ceil(numValues) : 1;
946 Value buildPriorityArbiter(RTLBuilder &s, ArrayRef<Value>
inputs,
948 DenseMap<size_t, Value> &indexMapping)
const {
949 auto numInputs =
inputs.size();
950 auto priorityArb = defaultValue;
952 for (
size_t i = numInputs; i > 0; --i) {
953 size_t inputIndex = i - 1;
954 size_t oneHotIndex =
size_t{1} << inputIndex;
955 auto constIndex = s.constant(numInputs, oneHotIndex);
956 indexMapping[inputIndex] = constIndex;
957 priorityArb = s.mux(
inputs[inputIndex], {priorityArb, constIndex});
963 OpBuilder &submoduleBuilder;
964 HandshakeLoweringState &ls;
967 class ForkConversionPattern :
public HandshakeConversionPattern<ForkOp> {
969 using HandshakeConversionPattern<ForkOp>::HandshakeConversionPattern;
971 hw::HWModulePortAccessor &ports)
const override {
972 auto unwrapped = unwrapIO(s, bb, ports);
973 buildUnitRateForkLogic(s, bb, unwrapped, [&](Value input) {
974 return llvm::SmallVector<Value>(unwrapped.outputs.size(), input);
979 class JoinConversionPattern :
public HandshakeConversionPattern<JoinOp> {
981 using HandshakeConversionPattern<JoinOp>::HandshakeConversionPattern;
983 hw::HWModulePortAccessor &ports)
const override {
984 auto unwrappedIO = unwrapIO(s, bb, ports);
985 buildJoinLogic(s, unwrappedIO.inputs, unwrappedIO.outputs[0]);
986 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
990 class SyncConversionPattern :
public HandshakeConversionPattern<SyncOp> {
992 using HandshakeConversionPattern<SyncOp>::HandshakeConversionPattern;
994 hw::HWModulePortAccessor &ports)
const override {
995 auto unwrappedIO = unwrapIO(s, bb, ports);
998 HandshakeWire wire(bb, s.b.getNoneType());
1000 OutputHandshake output = wire.getAsOutput();
1001 buildJoinLogic(s, unwrappedIO.inputs, output);
1003 InputHandshake input = wire.getAsInput();
1011 buildForkLogic(s, bb, input, unwrappedIO.outputs);
1015 for (
auto &&[in, out] : llvm::zip(unwrappedIO.inputs, unwrappedIO.outputs))
1016 out.data->setValue(in.data);
1020 class MuxConversionPattern :
public HandshakeConversionPattern<MuxOp> {
1022 using HandshakeConversionPattern<MuxOp>::HandshakeConversionPattern;
1024 hw::HWModulePortAccessor &ports)
const override {
1025 auto unwrappedIO = unwrapIO(s, bb, ports);
1028 auto select = unwrappedIO.inputs[0];
1029 unwrappedIO.inputs.erase(unwrappedIO.inputs.begin());
1030 buildMuxLogic(s, unwrappedIO, select);
1034 class InstanceConversionPattern
1035 :
public HandshakeConversionPattern<handshake::InstanceOp> {
1037 using HandshakeConversionPattern<
1038 handshake::InstanceOp>::HandshakeConversionPattern;
1040 hw::HWModulePortAccessor &ports)
const override {
1042 "If we indeed perform conversion in post-order, this "
1043 "should never be called. The base HandshakeConversionPattern logic "
1044 "will instantiate the external module.");
1048 class ReturnConversionPattern
1051 using OpConversionPattern::OpConversionPattern;
1053 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter)
const override {
1057 auto parent = cast<hw::HWModuleOp>(op->getParentOp());
1058 auto outputOp = *parent.getBodyBlock()->getOps<hw::OutputOp>().begin();
1059 outputOp->setOperands(adaptor.getOperands());
1060 outputOp->moveAfter(&parent.getBodyBlock()->back());
1061 rewriter.eraseOp(op);
1068 template <
typename TIn,
typename TOut = TIn>
1069 class UnitRateConversionPattern :
public HandshakeConversionPattern<TIn> {
1071 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1073 hw::HWModulePortAccessor &ports)
const override {
1074 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1075 this->buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange
inputs) {
1080 return s.b.create<TOut>(op.getLoc(),
inputs,
1081 ArrayRef<NamedAttribute>{});
1086 class PackConversionPattern :
public HandshakeConversionPattern<PackOp> {
1088 using HandshakeConversionPattern<PackOp>::HandshakeConversionPattern;
1090 hw::HWModulePortAccessor &ports)
const override {
1091 auto unwrappedIO = unwrapIO(s, bb, ports);
1092 buildUnitRateJoinLogic(s, unwrappedIO,
1097 class StructCreateConversionPattern
1098 :
public HandshakeConversionPattern<hw::StructCreateOp> {
1100 using HandshakeConversionPattern<
1101 hw::StructCreateOp>::HandshakeConversionPattern;
1103 hw::HWModulePortAccessor &ports)
const override {
1104 auto unwrappedIO = unwrapIO(s, bb, ports);
1105 auto structType = op.getResult().getType();
1106 buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange
inputs) {
1107 return s.pack(
inputs, structType);
1112 class UnpackConversionPattern :
public HandshakeConversionPattern<UnpackOp> {
1114 using HandshakeConversionPattern<UnpackOp>::HandshakeConversionPattern;
1116 hw::HWModulePortAccessor &ports)
const override {
1117 auto unwrappedIO = unwrapIO(s, bb, ports);
1118 buildUnitRateForkLogic(s, bb, unwrappedIO,
1119 [&](Value input) {
return s.unpack(input); });
1123 class ConditionalBranchConversionPattern
1124 :
public HandshakeConversionPattern<ConditionalBranchOp> {
1126 using HandshakeConversionPattern<
1127 ConditionalBranchOp>::HandshakeConversionPattern;
1129 hw::HWModulePortAccessor &ports)
const override {
1130 auto unwrappedIO = unwrapIO(s, bb, ports);
1131 auto cond = unwrappedIO.inputs[0];
1132 auto arg = unwrappedIO.inputs[1];
1133 auto trueRes = unwrappedIO.outputs[0];
1134 auto falseRes = unwrappedIO.outputs[1];
1136 auto condArgValid = s.bAnd({cond.valid, arg.valid});
1139 trueRes.valid->setValue(s.bAnd({cond.data, condArgValid}));
1140 falseRes.valid->setValue(s.bAnd({s.bNot(cond.data), condArgValid}));
1143 trueRes.data->setValue(arg.data);
1144 falseRes.data->setValue(arg.data);
1147 auto selectedResultReady =
1148 s.mux(cond.data, {falseRes.ready, trueRes.ready});
1149 auto condArgReady = s.bAnd({selectedResultReady, condArgValid});
1150 arg.ready->setValue(condArgReady);
1151 cond.ready->setValue(condArgReady);
1155 template <
typename TIn,
bool signExtend>
1156 class ExtendConversionPattern :
public HandshakeConversionPattern<TIn> {
1158 using HandshakeConversionPattern<TIn>::HandshakeConversionPattern;
1160 hw::HWModulePortAccessor &ports)
const override {
1161 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1162 this->buildExtendLogic(s, unwrappedIO, signExtend);
1166 class ComparisonConversionPattern
1167 :
public HandshakeConversionPattern<arith::CmpIOp> {
1169 using HandshakeConversionPattern<arith::CmpIOp>::HandshakeConversionPattern;
1171 hw::HWModulePortAccessor &ports)
const override {
1172 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1173 auto buildCompareLogic = [&](comb::ICmpPredicate predicate) {
1174 return buildUnitRateJoinLogic(s, unwrappedIO, [&](ValueRange
inputs) {
1175 return s.b.create<comb::ICmpOp>(op.getLoc(), predicate,
inputs[0],
1180 switch (op.getPredicate()) {
1181 case arith::CmpIPredicate::eq:
1182 return buildCompareLogic(comb::ICmpPredicate::eq);
1183 case arith::CmpIPredicate::ne:
1184 return buildCompareLogic(comb::ICmpPredicate::ne);
1185 case arith::CmpIPredicate::slt:
1186 return buildCompareLogic(comb::ICmpPredicate::slt);
1187 case arith::CmpIPredicate::ult:
1188 return buildCompareLogic(comb::ICmpPredicate::ult);
1189 case arith::CmpIPredicate::sle:
1190 return buildCompareLogic(comb::ICmpPredicate::sle);
1191 case arith::CmpIPredicate::ule:
1192 return buildCompareLogic(comb::ICmpPredicate::ule);
1193 case arith::CmpIPredicate::sgt:
1194 return buildCompareLogic(comb::ICmpPredicate::sgt);
1195 case arith::CmpIPredicate::ugt:
1196 return buildCompareLogic(comb::ICmpPredicate::ugt);
1197 case arith::CmpIPredicate::sge:
1198 return buildCompareLogic(comb::ICmpPredicate::sge);
1199 case arith::CmpIPredicate::uge:
1200 return buildCompareLogic(comb::ICmpPredicate::uge);
1202 assert(
false &&
"invalid CmpIOp");
1206 class TruncateConversionPattern
1207 :
public HandshakeConversionPattern<arith::TruncIOp> {
1209 using HandshakeConversionPattern<arith::TruncIOp>::HandshakeConversionPattern;
1211 hw::HWModulePortAccessor &ports)
const override {
1212 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1213 unsigned targetBits =
1214 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1215 buildTruncateLogic(s, unwrappedIO, targetBits);
1219 class ControlMergeConversionPattern
1220 :
public HandshakeConversionPattern<ControlMergeOp> {
1222 using HandshakeConversionPattern<ControlMergeOp>::HandshakeConversionPattern;
1224 hw::HWModulePortAccessor &ports)
const override {
1225 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1226 auto resData = unwrappedIO.outputs[0];
1227 auto resIndex = unwrappedIO.outputs[1];
1230 unsigned numInputs = unwrappedIO.inputs.size();
1231 auto indexType = s.b.getIntegerType(numInputs);
1232 Value noWinner = s.constant(numInputs, 0);
1233 Value c0I1 = s.constant(1, 0);
1236 auto won = bb.
get(indexType);
1237 Value wonReg = s.reg(
"won_reg", won, noWinner);
1240 auto win = bb.
get(indexType);
1244 auto fired = bb.
get(s.b.getI1Type());
1247 auto resultEmitted = bb.
get(s.b.getI1Type());
1248 Value resultEmittedReg = s.reg(
"result_emitted_reg", resultEmitted, c0I1);
1249 auto indexEmitted = bb.
get(s.b.getI1Type());
1250 Value indexEmittedReg = s.reg(
"index_emitted_reg", indexEmitted, c0I1);
1253 auto resultDone = bb.
get(s.b.getI1Type());
1254 auto indexDone = bb.
get(s.b.getI1Type());
1258 auto hasWinnerCondition = s.rOr({win});
1259 auto hadWinnerCondition = s.rOr({wonReg});
1267 DenseMap<size_t, Value> argIndexValues;
1268 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1269 noWinner, argIndexValues);
1270 priorityArb = s.mux(hadWinnerCondition, {priorityArb, wonReg});
1271 win.setValue(priorityArb);
1281 auto resultNotEmitted = s.bNot(resultEmittedReg);
1282 auto resultValid = s.bAnd({hasWinnerCondition, resultNotEmitted});
1283 resData.valid->setValue(resultValid);
1284 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
1286 auto indexNotEmitted = s.bNot(indexEmittedReg);
1287 auto indexValid = s.bAnd({hasWinnerCondition, indexNotEmitted});
1288 resIndex.valid->setValue(indexValid);
1292 SmallVector<Value, 8> indexOutputs;
1293 for (
size_t i = 0; i < numInputs; ++i)
1294 indexOutputs.push_back(s.constant(64, i));
1296 auto indexOutput = s.ohMux(win, indexOutputs);
1297 resIndex.data->setValue(indexOutput);
1303 won.setValue(s.mux(fired, {win, noWinner}));
1308 auto resultValidAndReady = s.bAnd({resultValid, resData.ready});
1309 resultDone.setValue(s.bOr({resultValidAndReady, resultEmittedReg}));
1311 auto indexValidAndReady = s.bAnd({indexValid, resIndex.ready});
1312 indexDone.setValue(s.bOr({indexValidAndReady, indexEmittedReg}));
1316 fired.setValue(s.bAnd({resultDone, indexDone}));
1322 resultEmitted.setValue(s.mux(fired, {resultDone, c0I1}));
1323 indexEmitted.setValue(s.mux(fired, {indexDone, c0I1}));
1328 auto winnerOrDefault = s.mux(fired, {noWinner, win});
1329 for (
auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1330 auto &indexValue = argIndexValues[i];
1331 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Lowers handshake.merge. A priority arbiter picks one valid input per cycle,
// that input's data is forwarded to the single output, and the picked input is
// acknowledged only once the output handshake completes.
// Mux selections read s.mux(c, {a, b}) as `c ? b : a`
// (NOTE(review): confirm against RTLBuilder::mux).
1336 class MergeConversionPattern :
public HandshakeConversionPattern<MergeOp> {
1338 using HandshakeConversionPattern<MergeOp>::HandshakeConversionPattern;
1340 hw::HWModulePortAccessor &ports)
const override {
1341 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1342 auto resData = unwrappedIO.outputs[0];
// The winner is a numInputs-bit vector used as the select of a one-hot mux
// (ohMux below); the all-zero value `noWinner` means no input was selected.
1345 unsigned numInputs = unwrappedIO.inputs.size();
1346 auto indexType = s.b.getIntegerType(numInputs);
1347 Value noWinner = s.constant(numInputs, 0);
// Backedge: `win` is consumed below before the arbiter that drives it exists.
1350 auto win = bb.
get(indexType);
// OR-reduction of `win`: true when some input won this cycle.
1353 auto hasWinnerCondition = s.rOr(win);
// buildPriorityArbiter fills argIndexValues with, per input index, the value
// that the winner is compared against below to drive that input's ready.
1360 DenseMap<size_t, Value> argIndexValues;
1361 Value priorityArb = buildPriorityArbiter(s, unwrappedIO.getInputValids(),
1362 noWinner, argIndexValues);
1363 win.setValue(priorityArb);
// The output is valid whenever an input won; its data is the winner's data.
1370 resData.valid->setValue(hasWinnerCondition);
1371 resData.data->setValue(s.ohMux(win, unwrappedIO.getInputDatas()));
// The output transfer completes when the output is both valid and ready.
1376 auto resultValidAndReady = s.bAnd({hasWinnerCondition, resData.ready});
// Expose the winner to the inputs only on a completed output transfer;
// otherwise use noWinner so that no input is acknowledged.
1381 auto winnerOrDefault = s.mux(resultValidAndReady, {noWinner, win});
// Input i is ready exactly when winnerOrDefault equals argIndexValues[i].
1382 for (
auto [i, ir] : llvm::enumerate(unwrappedIO.getInputReadys())) {
1383 auto &indexValue = argIndexValues[i];
1384 ir->setValue(s.cmp(winnerOrDefault, indexValue, comb::ICmpPredicate::eq));
// Lowers handshake.load. Port layout as indexed below:
//   inputs  = (address from user, data from memory, control token)
//   outputs = (data to user, address to memory)
1389 class LoadConversionPattern
1390 :
public HandshakeConversionPattern<handshake::LoadOp> {
1392 using HandshakeConversionPattern<
1393 handshake::LoadOp>::HandshakeConversionPattern;
1395 hw::HWModulePortAccessor &ports)
const override {
1396 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1397 auto addrFromUser = unwrappedIO.inputs[0];
1398 auto dataFromMem = unwrappedIO.inputs[1];
1399 auto controlIn = unwrappedIO.inputs[2];
1400 auto dataToUser = unwrappedIO.outputs[0];
1401 auto addrToMem = unwrappedIO.outputs[1];
// Address and loaded data payloads pass through unchanged.
1403 addrToMem.data->setValue(addrFromUser.data);
1404 dataToUser.data->setValue(dataFromMem.data);
// Join the user address with the control token to drive the addrToMem
// handshake (presumably valid only when both are valid — see buildJoinLogic).
1408 buildJoinLogic(s, {addrFromUser, controlIn}, addrToMem);
// The memory->user data channel is a direct valid/ready connection.
1412 dataToUser.valid->setValue(dataFromMem.valid);
1413 dataFromMem.ready->setValue(dataToUser.ready);
// Lowers handshake.store. Port layout as indexed below:
//   inputs  = (address, data, control token) from the user
//   outputs = (data to memory, address to memory)
1417 class StoreConversionPattern
1418 :
public HandshakeConversionPattern<handshake::StoreOp> {
1420 using HandshakeConversionPattern<
1421 handshake::StoreOp>::HandshakeConversionPattern;
1423 hw::HWModulePortAccessor &ports)
const override {
1424 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1425 auto addrFromUser = unwrappedIO.inputs[0];
1426 auto dataFromUser = unwrappedIO.inputs[1];
1427 auto controlIn = unwrappedIO.inputs[2];
1428 auto dataToMem = unwrappedIO.outputs[0];
1429 auto addrToMem = unwrappedIO.outputs[1];
// Both memory-side outputs are issued together, so the join is ready only
// when both of them are ready.
1432 auto outputsReady = s.bAnd({dataToMem.ready, addrToMem.ready});
// Intermediate NoneType handshake wire that carries the joined valid/ready of
// data, address and control; its ready is driven by outputsReady.
1436 HandshakeWire joinWire(bb, s.b.getNoneType());
1437 joinWire.ready->setValue(outputsReady);
1438 OutputHandshake joinOutput = joinWire.getAsOutput();
1439 buildJoinLogic(s, {dataFromUser, addrFromUser, controlIn}, joinOutput);
// Address and data payloads pass through unchanged.
1442 addrToMem.data->setValue(addrFromUser.data);
1443 dataToMem.data->setValue(dataFromUser.data);
// Both outputs become valid together, from the joined wire's valid.
1446 addrToMem.valid->setValue(*joinWire.valid);
1447 dataToMem.valid->setValue(*joinWire.valid);
// Lowers handshake.memory to a seq.hlmem named "_handshake_memory_<id>", with
// one read port per load port and one write port per store port.
// Mux selections read s.mux(c, {a, b}) as `c ? b : a`
// (NOTE(review): confirm against RTLBuilder::mux).
1451 class MemoryConversionPattern
1452 :
public HandshakeConversionPattern<handshake::MemoryOp> {
1454 using HandshakeConversionPattern<
1455 handshake::MemoryOp>::HandshakeConversionPattern;
1457 hw::HWModulePortAccessor &ports)
const override {
1458 auto loc = op.getLoc();
1461 auto unwrappedIO = this->unwrapIO(s, bb, ports);
// Per-load-port view: address in, data out, done token out.
1463 InputHandshake &
addr;
1464 OutputHandshake &
data;
1465 OutputHandshake &done;
// Per-store-port view: address in, data in, done token out.
1468 InputHandshake &
addr;
1469 InputHandshake &
data;
1470 OutputHandshake &done;
1472 SmallVector<LoadPort, 4> loadPorts;
1473 SmallVector<StorePort, 4> storePorts;
// Port layout as indexed below (S = stCount, L = ldCount):
//   inputs  = [st0.data, st0.addr, ..., st{S-1}.data, st{S-1}.addr,
//              ld0.addr, ..., ld{L-1}.addr]
//   outputs = [ld0.data, ..., ld{L-1}.data, st0.done, ..., st{S-1}.done,
//              ld0.done, ..., ld{L-1}.done]
1475 unsigned stCount = op.getStCount();
1476 unsigned ldCount = op.getLdCount();
1477 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
1478 LoadPort port = {unwrappedIO.inputs[stCount * 2 + i],
1479 unwrappedIO.outputs[i],
1480 unwrappedIO.outputs[ldCount + stCount + i]};
1481 loadPorts.push_back(port);
1484 for (
unsigned i = 0, e = stCount; i != e; ++i) {
1485 StorePort port = {unwrappedIO.inputs[i * 2 + 1],
1486 unwrappedIO.inputs[i * 2],
1487 unwrappedIO.outputs[ldCount + i]};
1488 storePorts.push_back(port);
// Zero-width constant used as the payload of the done tokens.
1492 auto c0I0 = s.constant(0, 0);
// Address width used by the memory: ceil(log2(extent of dimension 0)).
1494 auto cl2dim = llvm::Log2_64_Ceil(op.getMemRefType().getShape()[0]);
1495 auto hlmem = s.b.create<seq::HLMemOp>(
1496 loc, s.clk, s.rst,
"_handshake_memory_" + std::to_string(op.getId()),
1497 op.getMemRefType().getShape(), op.getMemRefType().getElementType());
// Load ports: one read port each, addressed by the truncated address with
// ld.addr.valid passed as its enable operand. The address handshake is then
// forked to both the data and done outputs.
1500 for (
auto &ld : loadPorts) {
1501 llvm::SmallVector<Value> addresses = {s.truncate(ld.addr.data, cl2dim)};
1502 auto readData = s.b.create<seq::ReadPortOp>(loc, hlmem.getHandle(),
1503 addresses, ld.addr.valid,
1505 ld.data.data->setValue(readData);
1506 ld.done.data->setValue(c0I0);
1508 buildForkLogic(s, bb, ld.addr, {ld.data, ld.done});
// Store ports: a one-bit register (writeValidBuffer) holds the pending
// store's done token and drives done.valid.
1512 for (
auto &st : storePorts) {
1515 auto writeValidBufferMuxBE = bb.
get(s.b.getI1Type());
1516 auto writeValidBuffer =
1517 s.reg(
"writeValidBuffer", writeValidBufferMuxBE, s.constant(1, 0));
1518 st.done.valid->
setValue(writeValidBuffer);
1519 st.done.data->setValue(c0I0);
// The done token is consumed when it is valid and its consumer is ready.
1523 auto storeCompleted =
1524 s.bAnd({st.done.ready, writeValidBuffer},
"storeCompleted");
// A new store is accepted when the buffer is empty or its token is being
// consumed this cycle.
1528 auto notWriteValidBuffer = s.bNot(writeValidBuffer);
1529 auto emptyOrComplete =
1530 s.bOr({notWriteValidBuffer, storeCompleted},
"emptyOrComplete");
1533 st.addr.ready->setValue(emptyOrComplete);
1534 st.data.ready->setValue(emptyOrComplete);
// A write is requested when both address and data are valid.
1537 auto writeValid = s.bAnd({st.addr.valid, st.data.valid},
"writeValid");
// Next buffer value: writeValid when accepting, otherwise hold.
1543 writeValidBufferMuxBE.setValue(
1544 s.mux(emptyOrComplete, {writeValidBuffer, writeValid}));
// NOTE(review): the write enable below is writeValid alone, not gated by
// emptyOrComplete. A store whose inputs are not yet acknowledged may therefore
// be written on more than one cycle; confirm this is intended.
1548 llvm::SmallVector<Value> addresses = {s.truncate(st.addr.data, cl2dim)};
1549 s.b.create<seq::WritePortOp>(loc, hlmem.getHandle(), addresses,
1550 st.data.data, writeValid,
// Lowers handshake.sink: the single input is tied ready, so every incoming
// token is accepted and dropped.
1556 class SinkConversionPattern :
public HandshakeConversionPattern<SinkOp> {
1558 using HandshakeConversionPattern<SinkOp>::HandshakeConversionPattern;
1560 hw::HWModulePortAccessor &ports)
const override {
1561 auto unwrappedIO = this->unwrapIO(s, bb, ports);
// Constant 1-bit `1`: always ready.
1563 unwrappedIO.inputs[0].ready->setValue(s.constant(1, 1));
// Lowers handshake.source: the single output is always valid and carries a
// zero-width payload.
1567 class SourceConversionPattern :
public HandshakeConversionPattern<SourceOp> {
1569 using HandshakeConversionPattern<SourceOp>::HandshakeConversionPattern;
1571 hw::HWModulePortAccessor &ports)
const override {
1572 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1574 unwrappedIO.outputs[0].valid->setValue(s.constant(1, 1));
// Zero-width constant: the token carries no data.
1575 unwrappedIO.outputs[0].data->setValue(s.constant(0, 0));
// Lowers handshake.constant: the control input's valid/ready pass straight
// through to the output, whose data is the op's "value" integer attribute.
1579 class ConstantConversionPattern
1580 :
public HandshakeConversionPattern<handshake::ConstantOp> {
1582 using HandshakeConversionPattern<
1583 handshake::ConstantOp>::HandshakeConversionPattern;
1585 hw::HWModulePortAccessor &ports)
const override {
1586 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1587 unwrappedIO.outputs[0].valid->setValue(unwrappedIO.inputs[0].valid);
1588 unwrappedIO.inputs[0].ready->setValue(unwrappedIO.outputs[0].ready);
1589 auto constantValue = op->getAttrOfType<IntegerAttr>(
"value").getValue();
1590 unwrappedIO.outputs[0].data->setValue(s.constant(constantValue));
// Lowers handshake.buffer to a chain of op.getNumSlots() register stages
// (see buildSeqBufferLogic / SeqBufferStage), optionally pre-loaded with the
// op's initial values.
1594 class BufferConversionPattern :
public HandshakeConversionPattern<BufferOp> {
1596 using HandshakeConversionPattern<BufferOp>::HandshakeConversionPattern;
1598 hw::HWModulePortAccessor &ports)
const override {
1599 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1600 auto input = unwrappedIO.inputs[0];
1601 auto output = unwrappedIO.outputs[0];
1602 InputHandshake lastStage;
1603 SmallVector<int64_t> initValues;
// Initial values are optional; when present, slot i starts with initValues[i].
1606 if (op.getInitValues())
1607 initValues = op.getInitValueArray();
1610 buildSeqBufferLogic(s, bb,
toValidType(op.getDataType()),
1611 op.getNumSlots(), input, output, initValues);
// The last stage of the chain drives the buffer's output handshake.
1614 output.data->setValue(lastStage.data);
1615 output.valid->setValue(lastStage.valid);
1616 lastStage.ready->setValue(output.ready);
// One slot of a sequential buffer. It has two parts:
//  - a primary valid/data register pair fed from preStage;
//  - a secondary ready/ctrl_data register pair that captures the primary
//    token while the successor stalls.
// Register names have the form "<kind><index>_reg" (see getRegName).
// Mux selections read s.mux(c, {a, b}) as `c ? b : a`
// (NOTE(review): confirm against RTLBuilder::mux).
1619 struct SeqBufferStage {
1620 SeqBufferStage(Type dataType, InputHandshake &preStage,
BackedgeBuilder &bb,
1621 RTLBuilder &s,
size_t index,
1622 std::optional<int64_t> initValue)
1623 : dataType(dataType), preStage(preStage), s(s), bb(bb), index(index) {
1626 c0s = createZeroDataConst(s, s.loc, dataType);
// The successor drives this stage's ready later, so it starts as a backedge.
1627 currentStage.ready = std::make_shared<Backedge>(bb.
get(s.b.getI1Type()));
// With an initial value, the valid register's reset value is 1, so the stage
// starts out holding a token.
1629 auto hasInitValue = s.constant(1, initValue.has_value());
1630 auto validBE = bb.
get(s.b.getI1Type());
1631 auto validReg = s.reg(getRegName(
"valid"), validBE, hasInitValue);
1632 auto readyBE = bb.
get(s.b.getI1Type());
// Reset value of the data register: the initial value if given, else zero.
1634 Value initValueCs = c0s;
1635 if (initValue.has_value())
1636 initValueCs = s.constant(dataType.getIntOrFloatBitWidth(), *initValue);
1641 buildDataBufferLogic(validReg, initValueCs, validBE, readyBE);
1642 buildControlBufferLogic(validReg, readyBE, dataReg);
// Builds a register name "<name><index>_reg".
1645 StringAttr getRegName(StringRef name) {
1646 return s.b.getStringAttr(name + std::to_string(index) +
"_reg");
// Secondary (skid) slot logic:
//  - readyReg is set when the successor is not ready while readyReg is clear;
//    it then takes validReg, and the primary data is captured into ctrl_data.
//  - While readyReg is set, the stage outputs ctrl_data with valid = 1.
//  - Both registers clear once the successor accepts that token.
1649 void buildControlBufferLogic(Value validReg,
Backedge &readyBE,
1651 auto c0I1 = s.constant(1, 0);
1652 auto readyRegWire = bb.
get(s.b.getI1Type());
1653 auto readyReg = s.reg(getRegName(
"ready"), readyRegWire, c0I1);
// Output valid: readyReg ? readyReg (i.e. 1) : validReg.
1657 currentStage.valid = s.mux(readyReg, {validReg, readyReg},
1658 "controlValid" + std::to_string(index));
1661 auto notReadyReg = s.bNot(readyReg);
1664 auto succNotReady = s.bNot(*currentStage.ready);
1665 auto neitherReady = s.bAnd({succNotReady, notReadyReg});
1666 auto ctrlNotReady = s.mux(neitherReady, {readyReg, validReg});
1667 auto bothReady = s.bAnd({*currentStage.ready, readyReg});
// Next readyReg: 0 when the captured token is accepted, validReg when
// capturing, otherwise hold.
1670 auto resetSignal = s.mux(bothReady, {ctrlNotReady, c0I1});
1671 readyRegWire.setValue(resetSignal);
1674 auto ctrlDataRegBE = bb.
get(dataType);
1675 auto ctrlDataReg = s.reg(getRegName(
"ctrl_data"), ctrlDataRegBE, c0s);
// Output data: the captured token while readyReg is set, else the primary data.
1676 auto dataResult = s.mux(readyReg, {dataReg, ctrlDataReg});
1677 currentStage.data = dataResult;
// Next ctrl_data: zero once accepted, dataReg when capturing, otherwise hold.
1679 auto dataNotReadyMux = s.mux(neitherReady, {ctrlDataReg, dataReg});
1680 auto dataResetSignal = s.mux(bothReady, {dataNotReadyMux, c0s});
1681 ctrlDataRegBE.
setValue(dataResetSignal);
// Primary slot logic:
//  - The stage accepts from preStage when its valid register is empty or
//    readyBE is asserted.
//  - On acceptance, data loads preStage.data; otherwise it holds.
//  - validRegMux likewise selects preStage.valid on acceptance (presumably
//    fed to validBE on a line not shown — confirm).
1684 Value buildDataBufferLogic(Value validReg, Value initValue,
1688 auto notValidReg = s.bNot(validReg);
1689 auto emptyOrReady = s.bOr({notValidReg, readyBE});
1690 preStage.ready->setValue(emptyOrReady);
1696 auto validRegMux = s.mux(emptyOrReady, {validReg, preStage.valid});
1702 auto dataRegBE = bb.
get(dataType);
1704 s.reg(getRegName(
"data"),
1705 s.mux(emptyOrReady, {dataRegBE, preStage.data}), initValue);
// This stage's output handshake, consumed by the next stage or the buffer port.
1710 InputHandshake getOutput() {
return currentStage; }
1713 InputHandshake &preStage;
1714 InputHandshake currentStage;
// Builds a chain of `size` SeqBufferStages starting from `input`, each stage
// feeding the next. Stage i is initialized with initValues[i] when that entry
// exists; otherwise it starts empty. Returns the handshake reached after the
// last stage (just `input` when size is 0).
1723 InputHandshake buildSeqBufferLogic(RTLBuilder &s,
BackedgeBuilder &bb,
1724 Type dataType,
unsigned size,
1725 InputHandshake &input,
1726 OutputHandshake &output,
1727 llvm::ArrayRef<int64_t> initValues)
const {
1730 InputHandshake currentStage = input;
1732 for (
unsigned i = 0; i <
size; ++i) {
// initValues may be shorter than size; the remaining stages get no init value.
1733 bool isInitialized = i < initValues.size();
1735 isInitialized ? std::optional<int64_t>(initValues[i]) : std::nullopt;
1736 currentStage = SeqBufferStage(dataType, currentStage, bb, s, i, initValue)
1740 return currentStage;
// Lowers arith.index_cast. It truncates when the result is narrower than the
// source; otherwise (including equal widths) it extends. The extension passes
// `true`, the same flag ExtendConversionPattern<arith::ExtSIOp, true> uses,
// i.e. sign extension.
1744 class IndexCastConversionPattern
1745 :
public HandshakeConversionPattern<arith::IndexCastOp> {
1747 using HandshakeConversionPattern<
1748 arith::IndexCastOp>::HandshakeConversionPattern;
1750 hw::HWModulePortAccessor &ports)
const override {
1751 auto unwrappedIO = this->unwrapIO(s, bb, ports);
1752 unsigned sourceBits =
1753 toValidType(op.getIn().getType()).getIntOrFloatBitWidth();
1754 unsigned targetBits =
1755 toValidType(op.getResult().getType()).getIntOrFloatBitWidth();
1756 if (targetBits < sourceBits)
1757 buildTruncateLogic(s, unwrappedIO, targetBits);
1759 buildExtendLogic(s, unwrappedIO,
true);
// Generic pattern: lowers an op of type T to an hw.instance of an external
// module (hw.module.extern) created with submoduleBuilder.
1763 template <
typename T>
1766 ExtModuleConversionPattern(ESITypeConverter &typeConverter,
1767 MLIRContext *context, OpBuilder &submoduleBuilder,
1768 HandshakeLoweringState &ls)
1770 submoduleBuilder(submoduleBuilder), ls(ls) {}
1771 using OpAdaptor =
typename T::Adaptor;
1774 matchAndRewrite(T op, OpAdaptor adaptor,
1775 ConversionPatternRewriter &rewriter)
const override {
1780 implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
// Instance operands: the converted operands, plus any sequential IO that
// addSequentialIOOperandsIfNeeded appends for this op.
1785 llvm::SmallVector<Value> operands = adaptor.getOperands();
1786 addSequentialIOOperandsIfNeeded(op, operands);
// The instance name comes from the lowering state's NameUniquer.
1787 rewriter.replaceOpWithNewOp<hw::InstanceOp>(
1788 op, implModule, rewriter.getStringAttr(ls.nameUniquer(op)), operands);
1793 OpBuilder &submoduleBuilder;
1794 HandshakeLoweringState &ls;
1799 using OpConversionPattern::OpConversionPattern;
// Converts a handshake.func into either an hw.module.extern (external
// function) or an hw.module whose body is the inlined function body.
1802 matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
1803 ConversionPatternRewriter &rewriter)
const override {
1807 HWModuleLike hwModule;
1808 if (op.isExternal()) {
1809 hwModule = rewriter.create<hw::HWModuleExternOp>(
1810 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
1812 auto hwModuleOp = rewriter.create<hw::HWModuleOp>(
1813 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
// The last two module arguments are excluded when inlining the body
// (presumably the clock/reset ports added by this lowering — confirm).
1814 auto args = hwModuleOp.getBodyBlock()->getArguments().drop_back(2);
1815 rewriter.inlineBlockBefore(&op.getBody().front(),
1816 hwModuleOp.getBodyBlock()->getTerminator(),
1818 hwModule = hwModuleOp;
// If a predeclared module with the `predecl` symbol exists, redirect its uses
// to the new module and erase it.
1825 auto *parentOp = op->getParentOp();
1826 auto *predeclModule =
1827 SymbolTable::lookupSymbolIn(parentOp, predecl.getValue());
1828 if (predeclModule) {
1829 if (failed(SymbolTable::replaceAllSymbolUses(
1830 predeclModule, hwModule.getModuleNameAttr(), parentOp)))
1832 rewriter.eraseOp(predeclModule);
1836 rewriter.eraseOp(op);
1848 ConversionTarget &target,
1849 handshake::FuncOp op,
1850 OpBuilder &moduleBuilder) {
// Instance-name uniquer shared by all patterns. A `handshake_id` attribute
// contributes an "_id<N>" suffix, and instanceNameCntr supplies a per-name
// counter suffix. The branch structure between the two is on lines not shown.
1852 std::map<std::string, unsigned> instanceNameCntr;
1853 NameUniquer instanceUniquer = [&](Operation *op) {
1855 if (
auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id"); idAttr) {
1858 instName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
1861 instName += std::to_string(instanceNameCntr[instName]++);
1866 auto ls = HandshakeLoweringState{op->getParentOfType<mlir::ModuleOp>(),
1868 RewritePatternSet
patterns(op.getContext());
1869 patterns.insert<FuncOpConversionPattern, ReturnConversionPattern>(
1871 patterns.insert<JoinConversionPattern, ForkConversionPattern,
1872 SyncConversionPattern>(typeConverter, op.getContext(),
1877 UnitRateConversionPattern<arith::AddIOp, comb::AddOp>,
1878 UnitRateConversionPattern<arith::SubIOp, comb::SubOp>,
1879 UnitRateConversionPattern<arith::MulIOp, comb::MulOp>,
1880 UnitRateConversionPattern<arith::DivUIOp, comb::DivSOp>,
1881 UnitRateConversionPattern<arith::DivSIOp, comb::DivUOp>,
1882 UnitRateConversionPattern<arith::RemUIOp, comb::ModUOp>,
1883 UnitRateConversionPattern<arith::RemSIOp, comb::ModSOp>,
1884 UnitRateConversionPattern<arith::AndIOp, comb::AndOp>,
1885 UnitRateConversionPattern<arith::OrIOp, comb::OrOp>,
1886 UnitRateConversionPattern<arith::XOrIOp, comb::XorOp>,
1887 UnitRateConversionPattern<arith::ShLIOp, comb::OrOp>,
1888 UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
1889 UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
1890 UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
1892 StructCreateConversionPattern,
1894 ConditionalBranchConversionPattern, MuxConversionPattern,
1895 PackConversionPattern, UnpackConversionPattern,
1896 ComparisonConversionPattern, BufferConversionPattern,
1897 SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
1898 MergeConversionPattern, ControlMergeConversionPattern,
1899 LoadConversionPattern, StoreConversionPattern, MemoryConversionPattern,
1900 InstanceConversionPattern,
// The boolean template argument selects sign extension (true, ExtSI) versus
// zero extension (false, ExtUI).
1902 ExtendConversionPattern<arith::ExtUIOp,
false>,
1903 ExtendConversionPattern<arith::ExtSIOp,
true>,
1904 TruncateConversionPattern, IndexCastConversionPattern>(
1905 typeConverter, op.getContext(), moduleBuilder, ls);
// Apply all patterns to this function; a failure is reported on the func op.
1907 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
1908 return op->emitOpError() <<
"error during conversion";
// Pass driver. It checks that every handshake.func is in single-use form,
// orders the functions (topLevel/sortedFuncs), lowers each function to HW,
// then post-processes the resulting hw.modules.
1913 class HandshakeToHWPass :
public HandshakeToHWBase<HandshakeToHWPass> {
1915 void runOnOperation()
override {
1916 mlir::ModuleOp mod = getOperation();
// Every value must have exactly one use (fork/sink materialized) beforehand.
1920 for (
auto f : mod.getOps<handshake::FuncOp>()) {
1922 f.emitOpError() <<
"HandshakeToHW: failed to verify that all values "
1923 "are used exactly once. Remember to run the "
1924 "fork/sink materialization pass before HW lowering.";
1925 signalPassFailure();
// Filled by the function-ordering step (its call is on lines not shown).
1931 std::string topLevel;
1933 SmallVector<std::string> sortedFuncs;
1935 signalPassFailure();
1939 ESITypeConverter typeConverter;
1940 ConversionTarget target(getContext());
// hw module ops are legal; handshake and arith ops must all be converted.
1943 target.addLegalOp<hw::HWModuleOp, hw::HWModuleExternOp, hw::OutputOp,
1946 .addIllegalDialect<handshake::HandshakeDialect, arith::ArithDialect>();
// Generated submodules are inserted at the start of the top-level module.
1952 OpBuilder submoduleBuilder(mod.getContext());
1953 submoduleBuilder.setInsertionPointToStart(mod.getBody());
// sortedFuncs is visited in reverse (presumably so instantiated functions are
// lowered before the functions that use them — confirm against the ordering).
1954 for (
auto &funcName : llvm::reverse(sortedFuncs)) {
1955 auto funcOp = mod.lookupSymbol<handshake::FuncOp>(funcName);
1956 assert(funcOp &&
"handshake.func not found in module!");
1958 convertFuncOp(typeConverter, target, funcOp, submoduleBuilder))) {
1959 signalPassFailure();
// Per-module post-processing (the check is on a line not shown); any failure
// fails the pass.
1966 for (
auto hwModule : mod.getOps<hw::HWModuleOp>())
1968 return signalPassFailure();
// Pass factory.
1974 return std::make_unique<HandshakeToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static Type tupleToStruct(TupleType tuple)
std::function< std::string(Operation *)> NameUniquer
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, bool withAnnotations=true)
static std::string getCallName(Operation *op)
static SmallVector< Type > filterNoneTypes(ArrayRef< Type > input)
Filters NoneTypes out of the input.
static Type getOperandDataType(Value op)
Extracts the data-carrying type of op (the inner type when op's type is an ESI channel).
static DiscriminatingTypes getHandshakeDiscriminatingTypes(Operation *op)
static ModulePortInfo getPortInfoForOp(Operation *op)
Returns a vector of PortInfos that defines the HW interface of the op being converted.
static std::string getBareSubModuleName(Operation *oldOp)
Returns a submodule name resulting from an operation, without discriminating type information.
static std::string getSubModuleName(Operation *oldOp)
Construct a name for creating HW sub-module.
static HWModuleLike checkSubModuleOp(mlir::ModuleOp parentModule, StringRef modName)
Check whether a submodule with the same name has been created elsewhere in the top level module.
std::pair< SmallVector< Type >, SmallVector< Type > > DiscriminatingTypes
Returns a set of types which may uniquely identify the provided op.
static LogicalResult convertFuncOp(ESITypeConverter &typeConverter, ConversionTarget &target, handshake::FuncOp op, OpBuilder &moduleBuilder)
static llvm::SmallVector< hw::detail::FieldInfo > portToFieldInfo(llvm::ArrayRef< hw::PortInfo > portInfo)
static std::string getTypeName(Location loc, Type type)
Get type name.
static LogicalResult convertExtMemoryOps(HWModuleOp mod)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
static size_t bits(::capnp::schema::Type::Reader type)
Return the number of bits used by a Capnp type.
static int64_t size(hw::ArrayType mType, capnp::schema::Field::Reader cField)
Returns the expected size of an array (capnp list) in 64-bit words.
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
mlir::Type innerType(mlir::Type type)
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs, TypeRange outputs)
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
static constexpr const char * kPredeclarationAttr
LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp op)
esi::ChannelType esiWrapper(mlir::Type t)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createHandshakeToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds a decoded list of input/inout and output ports for a module or instance.