14 #include "../PassDetail.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/MathExtras.h"
36 using namespace circt;
44 auto *ctx = tuple.getContext();
45 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
46 for (
auto [i,
innerType] : llvm::enumerate(tuple)) {
48 if (
auto tupleInnerType =
innerType.dyn_cast<TupleType>())
63 return TypeSwitch<Type, Type>(t)
65 .Case([](hw::StructType st) {
66 llvm::SmallVector<hw::StructType::FieldInfo> structFields(
68 for (
auto &field : structFields)
73 .Default([](Type t) {
return t; });
78 llvm::TypeSwitch<Type, Type>(t)
79 .Case([](ValueType vt) {
83 .Case([](TokenType tt) {
87 .Default([](
auto t) {
return toHWType(t); });
96 struct DCLoweringState {
97 ModuleOp parentModule;
// Type converter used by the lowering. Only part of the class is visible in
// this excerpt; the notes below describe the visible registrations only.
104 class ESITypeConverter :
public TypeConverter {
// Generic rule: any type is mapped through toESIHWType.
107 addConversion([](Type type) -> Type {
return toESIHWType(type); });
// esi::ChannelType is mapped to itself (returned unchanged).
108 addConversion([](esi::ChannelType t) -> Type {
return t; });
// A target materialization callback is registered; its body is not part of
// this excerpt.
109 addTargetMaterialization(
110 [](mlir::OpBuilder &
builder, mlir::Type resultType,
112 mlir::Location loc) -> std::optional<mlir::Value> {
// A source materialization callback is registered; its body is not part of
// this excerpt.
119 addSourceMaterialization(
120 [](mlir::OpBuilder &
builder, mlir::Type resultType,
122 mlir::Location loc) -> std::optional<mlir::Value> {
141 struct InputHandshake {
144 std::optional<Backedge> ready;
150 struct OutputHandshake {
152 std::optional<Backedge> valid;
154 std::optional<Backedge>
data;
// Wires an input handshake bundle to an output handshake bundle:
//  - the output's `valid` backedge is driven by the input's `valid`;
//  - the input's `ready` backedge is driven by the output's `ready`.
// Both backedges are dereferenced with `->` without a check, so the
// corresponding std::optional members must hold a value when this is called.
158 static void connect(InputHandshake &input, OutputHandshake &output) {
159 output.valid->setValue(input.valid);
160 input.ready->setValue(output.ready);
// Applies `extractor` to every element of `container` (in order) and appends
// each produced T to a fresh SmallVector<T> named `result`.
// T      - element type of the produced vector.
// TInner - element type of the input container.
// The statement that hands `result` back to the caller is not part of this
// excerpt; the declared return type is llvm::SmallVector<T>.
163 template <
typename T,
typename TInner>
164 llvm::SmallVector<T> extractValues(llvm::SmallVector<TInner> &container,
165 llvm::function_ref<T(TInner &)> extractor) {
166 llvm::SmallVector<T> result;
167 llvm::transform(container, std::back_inserter(result), extractor);
174 llvm::SmallVector<InputHandshake>
inputs;
175 llvm::SmallVector<OutputHandshake>
outputs;
// Collects the `valid` member of every entry of `inputs`, in order.
177 llvm::SmallVector<Value> getInputValids() {
178 return extractValues<Value, InputHandshake>(
179 inputs, [](
auto &hs) {
return hs.valid; });
// Collects the `ready` member (an optional backedge) of every entry of
// `inputs`, in order. The optionals are copied as-is, set or not.
181 llvm::SmallVector<std::optional<Backedge>> getInputReadys() {
182 return extractValues<std::optional<Backedge>, InputHandshake>(
183 inputs, [](
auto &hs) {
return hs.ready; });
// Collects the `valid` member (an optional backedge) of every entry of
// `outputs`, in order.
185 llvm::SmallVector<std::optional<Backedge>> getOutputValids() {
186 return extractValues<std::optional<Backedge>, OutputHandshake>(
187 outputs, [](
auto &hs) {
return hs.valid; });
// Collects the `data` member of every entry of `inputs`, in order.
189 llvm::SmallVector<Value> getInputDatas() {
190 return extractValues<Value, InputHandshake>(
191 inputs, [](
auto &hs) {
return hs.data; });
// Collects the `ready` member of every entry of `outputs`, in order.
193 llvm::SmallVector<Value> getOutputReadys() {
194 return extractValues<Value, OutputHandshake>(
195 outputs, [](
auto &hs) {
return hs.ready; });
// Collects the `channel` member of every entry of `outputs`, in order.
// The fork lowering in this file passes this list to rewriter.replaceOp.
198 llvm::SmallVector<Value> getOutputChannels() {
199 return extractValues<Value, OutputHandshake>(
200 outputs, [](
auto &hs) {
return hs.channel; });
// Collects the `data` member (an optional backedge) of every entry of
// `outputs`, in order.
202 llvm::SmallVector<std::optional<Backedge>> getOutputDatas() {
203 return extractValues<std::optional<Backedge>, OutputHandshake>(
204 outputs, [](
auto &hs) {
return hs.data; });
212 RTLBuilder(Location loc, OpBuilder &
builder, Value clk = Value(),
// Produces a constant Value for the APInt `apv`.
// The visible lines show a cache keyed by the APInt: `constants` is searched
// first, and a value named `cval` (created on lines not part of this excerpt)
// is stored back under `apv`.
216 Value constant(
const APInt &apv, StringRef name = {}) {
// Flags zero-bit constants; how the flag is consumed is not visible here.
219 bool isZeroWidth = apv.getBitWidth() == 0;
221 auto it = constants.find(apv);
222 if (it != constants.end())
228 constants[apv] = cval;
232 Value constant(
unsigned width, int64_t
value, StringRef name = {}) {
// Creates an esi::WrapValidReadyOp from `data` and `valid` and returns its
// two results as a pair, result 0 first. The caller in this file binds the
// pair as (channel, ready). `name` is not used by the visible lines.
235 std::pair<Value, Value>
wrap(Value data, Value valid, StringRef name = {}) {
236 auto wrapOp = b.create<esi::WrapValidReadyOp>(loc,
data, valid);
237 return {wrapOp.getResult(0), wrapOp.getResult(1)};
// Creates an esi::UnwrapValidReadyOp from `channel` and `ready` and returns
// its two results as a pair, result 0 first. The caller in this file binds
// the pair as (data, valid). `name` is not used by the visible lines.
239 std::pair<Value, Value>
unwrap(Value channel, Value ready,
240 StringRef name = {}) {
241 auto unwrapOp = b.create<esi::UnwrapValidReadyOp>(loc, channel, ready);
242 return {unwrapOp.getResult(0), unwrapOp.getResult(1)};
// Creates a seq::CompRegOp for input `in` with reset value `rstValue`.
// name     - passed by callers as the register's name; its use is on lines
//            not part of this excerpt.
// clk/rst  - optional per-call clock/reset. When left as a null Value the
//            builder's own `clk` / `rst` members are used instead.
246 Value
reg(StringRef name, Value in, Value rstValue, Value clk = Value(),
247 Value rst = Value()) {
// Prefer the explicitly supplied signal, otherwise fall back to the member.
248 Value resolvedClk =
clk ?
clk : this->
clk;
249 Value resolvedRst = rst ? rst : this->rst;
// The two message strings below belong to checks whose opening lines are not
// part of this excerpt (presumably asserts on resolvedClk / resolvedRst --
// confirm against the full file).
251 "No global clock provided to this RTLBuilder - a clock "
252 "signal must be provided to the reg(...) function.");
254 "No global reset provided to this RTLBuilder - a reset "
255 "signal must be provided to the reg(...) function.");
257 return b.create<
seq::CompRegOp>(loc, in, resolvedClk, resolvedRst, rstValue,
// Creates a comb::ICmpOp comparing `lhs` and `rhs` with `predicate`.
// `name` is accepted but not used by the visible lines.
261 Value cmp(Value lhs, Value rhs, comb::ICmpPredicate predicate,
262 StringRef name = {}) {
263 return b.
create<comb::ICmpOp>(loc, predicate, lhs, rhs);
// Helper that attaches a name to the operation behind a built value.
// Visible behavior: the defining op of `v` receives an "sv.namehint" string
// attribute holding `name`, and `nameAttr` is assigned the same string
// attribute. How `v` is obtained (presumably by invoking `f` -- confirm) and
// any guarding conditions are on lines not part of this excerpt.
266 Value buildNamedOp(llvm::function_ref<Value()> f, StringRef name) {
269 Operation *op = v.getDefiningOp();
271 op->setAttr(
"sv.namehint", b.getStringAttr(name));
272 nameAttr = b.getStringAttr(name);
// Bitwise AND over all of `values`, built as a single comb::AndOp.
// The op creation is wrapped in a lambda that is passed, together with
// `name`, to a call whose callee is on a line not part of this excerpt
// (presumably buildNamedOp -- confirm).
278 Value bitAnd(ValueRange values, StringRef name = {}) {
280 [&]() {
return b.create<
comb::AndOp>(loc, values,
false); }, name);
// Bitwise OR over all of `values`, built as a single comb::OrOp.
// The op creation is wrapped in a lambda that is passed, together with
// `name`, to a call whose callee is on a line not part of this excerpt
// (presumably buildNamedOp -- confirm).
284 Value bitOr(ValueRange values, StringRef name = {}) {
286 [&]() {
return b.create<
comb::OrOp>(loc, values,
false); }, name);
// Bitwise NOT of `value`.
// Visible steps: an all-ones constant (-1) of the operand's bit width is
// built, and a fallback name "not_<hint>" is derived from the operand's
// "sv.namehint" attribute. The operation that combines `allOnes` with `value`
// is on lines not part of this excerpt (presumably an XOR -- confirm).
290 Value bitNot(Value
value, StringRef name = {}) {
291 auto allOnes = constant(
value.getType().getIntOrFloatBitWidth(), -1);
292 std::string inferedName;
// Reads the name hint from the op defining `value`; any null check on
// getDefiningOp() would be on the lines missing before this one.
296 value.getDefiningOp()->getAttrOfType<StringAttr>(
"sv.namehint")) {
297 inferedName = (
"not_" +
valueName.getValue()).str();
306 Value shl(Value
value, Value shift, StringRef name = {}) {
311 Value
concat(ValueRange values, StringRef name = {}) {
312 return buildNamedOp([&]() {
return b.create<
comb::ConcatOp>(loc, values); },
316 llvm::SmallVector<Value>
extractBits(Value v, StringRef name = {}) {
317 llvm::SmallVector<Value> bits;
318 for (
unsigned i = 0, e = v.getType().getIntOrFloatBitWidth(); i != e; ++i)
// OR-reduction of `v`: ORs together the values returned by extractBits(v)
// and routes the result through buildNamedOp with `name`.
324 Value reduceOr(Value v, StringRef name = {}) {
325 return buildNamedOp([&]() {
return bitOr(
extractBits(v)); }, name);
// Extracts the inclusive bit range [lo, hi] of `v`; the width is
// hi - lo + 1. The extraction op itself is on lines not part of this excerpt.
329 Value extract(Value v,
unsigned lo,
unsigned hi, StringRef name = {}) {
// NOTE(review): unsigned arithmetic -- if hi < lo this wraps around instead
// of going negative; no guard is visible in this excerpt.
330 unsigned width = hi - lo + 1;
336 Value truncate(Value
value,
unsigned width, StringRef name = {}) {
// Zero-extends `value` to `outWidth` bits.
// Asserts that the input is not wider than the requested width. The equal-
// width case is handled by a dedicated branch whose body is not part of this
// excerpt. Otherwise a zero constant of (outWidth - inWidth) bits is built;
// how it is combined with `value` is on lines not shown (presumably a
// concat -- confirm).
340 Value zext(Value
value,
unsigned outWidth, StringRef name = {}) {
341 unsigned inWidth =
value.getType().getIntOrFloatBitWidth();
342 assert(inWidth <= outWidth &&
"zext: input width must be <= output width.");
343 if (inWidth == outWidth)
345 auto c0 = constant(outWidth - inWidth, 0);
349 Value sext(Value
value,
unsigned outWidth, StringRef name = {}) {
// Extracts the single bit at position `index` of `v` by delegating to
// extract() with lo == hi == index.
354 Value bit(Value v,
unsigned index, StringRef name = {}) {
355 return extract(v, index, index, name);
359 Value arrayCreate(ValueRange values, StringRef name = {}) {
365 Value arrayGet(Value array, Value index, StringRef name = {}) {
367 [&]() {
return b.create<
hw::ArrayGetOp>(loc, array, index); }, name);
// Selects one of `values` using `index`.
//  - Exactly two values: a comb::MuxOp is created with `index` as the
//    condition, values[1] as the first data operand and values[0] as the
//    second. `name` is not applied on this visible path.
//  - Otherwise: the values are packed with arrayCreate and indexed with
//    arrayGet, which receives `name`.
373 Value mux(Value index, ValueRange values, StringRef name = {}) {
374 if (values.size() == 2) {
377 return b.create<
comb::MuxOp>(loc, index, values[1], values[0]);
381 return arrayGet(arrayCreate(values), index, name);
// Builds a chain of two-input muxes that picks among `inputs` using a
// one-hot `index` (one select bit per input).
// Precondition (asserted): the bit width of `index` equals inputs.size().
386 Value oneHotMux(Value index, ValueRange
inputs) {
388 unsigned numInputs =
inputs.size();
389 assert(numInputs == index.getType().getIntOrFloatBitWidth() &&
390 "mismatch between width of one-hot select input and the number of "
391 "inputs to be selected");
// The data width comes from the first input; NoneType counts as width 0.
// inputs[0] is read unconditionally, so `inputs` must be non-empty.
394 auto dataType =
inputs[0].getType();
396 dataType.isa<NoneType>() ? 0 : dataType.getIntOrFloatBitWidth();
// The chain starts from an all-zero value of the data width.
397 Value muxValue = constant(
width, 0);
// NOTE(review): the loop runs i = numInputs-1 down to 1 and stops before
// i == 0, so select bit 0 / inputs[0] is never visited by this loop; with
// numInputs == 0 the size_t start value would also wrap around. Confirm
// whether skipping index 0 is intended.
400 for (
size_t i = numInputs - 1; i != 0; --i) {
402 Value selectBit = bit(index, i);
// `input` is defined on a line not part of this excerpt.
403 muxValue = mux(selectBit, {muxValue, input});
412 DenseMap<APInt, Value> constants;
// Returns true when `type` carries no bits: either an IntegerType whose
// width is 0, or (for non-integer types) a NoneType.
415 static bool isZeroWidthType(Type type) {
416 if (
auto intType = type.dyn_cast<IntegerType>())
417 return intType.getWidth() == 0;
418 return type.isa<NoneType>();
// Breaks ESI channels into explicit handshake signals.
// For every operand channel an InputHandshake is recorded; for every result
// type an OutputHandshake is recorded. Signals that are only known later are
// represented by backedges obtained from `bb`. Several parameter and body
// lines of this function are not part of this excerpt.
421 static UnwrappedIO unwrapIO(Location loc, ValueRange operands,
423 ConversionPatternRewriter &rewriter,
425 RTLBuilder rtlb(loc, rewriter);
426 UnwrappedIO unwrapped;
427 for (
auto in : operands) {
// Every operand must already be an ESI channel.
428 assert(isa<esi::ChannelType>(in.getType()));
// `ready` is an i1 backedge: it is fed into the unwrap now and is expected
// to be driven later by the pattern using this helper.
429 auto ready = bb.
get(rtlb.b.getI1Type());
430 auto [
data, valid] = rtlb.unwrap(in, ready);
431 unwrapped.inputs.push_back(InputHandshake{in, valid, ready,
data});
433 for (
auto outputType : results) {
// Every result type must be an ESI channel type (cast<> asserts otherwise).
435 esi::ChannelType channelType = cast<esi::ChannelType>(outputType);
// A zero-width (i0) constant with value 0 is created here; the condition
// selecting this path is on lines not part of this excerpt.
442 rewriter.create<
hw::ConstantOp>(loc, rewriter.getIntegerType(0), 0);
446 hs.data = dataBackedge;
// `valid` is an i1 backedge wrapped together with the data into a channel;
// wrap() also yields the channel's ready signal.
449 auto valid = bb.
get(rewriter.getI1Type());
450 auto [dataCh, ready] = rtlb.wrap(data, valid);
454 unwrapped.outputs.push_back(hs);
// Convenience overload: forwards to the Location-based unwrapIO using the
// op's own location and the op's result types.
459 static UnwrappedIO unwrapIO(Operation *op, ValueRange operands,
460 ConversionPatternRewriter &rewriter,
462 return unwrapIO(op->getLoc(), operands, op->getResultTypes(), rewriter, bb);
468 auto *parent = op->getParentOp();
469 auto parentFuncOp = dyn_cast<HWModuleLike>(parent);
471 return parent->emitOpError(
"parent op does not implement HWModuleLike");
473 auto argAttrs = parentFuncOp.getAllInputAttrs();
475 std::optional<size_t> clockIdx, resetIdx;
477 for (
auto [idx, battrs] : llvm::enumerate(argAttrs)) {
478 auto attrs = cast<DictionaryAttr>(battrs);
479 if (attrs.get(
"dc.clock")) {
481 return parent->emitOpError(
482 "multiple arguments contains a 'dc.clock' attribute");
486 if (attrs.get(
"dc.reset")) {
488 return parent->emitOpError(
489 "multiple arguments contains a 'dc.reset' attribute");
495 return parent->emitOpError(
"no argument contains a 'dc.clock' attribute");
498 return parent->emitOpError(
"no argument contains a 'dc.reset' attribute");
500 return {std::make_pair(parentFuncOp.getArgumentForInput(*clockIdx),
501 parentFuncOp.getArgumentForInput(*resetIdx))};
// Lowers a ForkOp to explicit valid/ready logic.
// Each output gets a register "emitted_<i>" (reset value 0); the input is
// acknowledged only once every output is done. Some lines of this function
// (e.g. the definitions of `bb`, `doneBE`, `emittedReg`, `done`) are not part
// of this excerpt.
508 matchAndRewrite(ForkOp op, OpAdaptor operands,
509 ConversionPatternRewriter &rewriter)
const override {
// Clock and reset are taken from the enclosing module.
511 auto crRes = getClockAndReset(op);
514 auto [clock, reset] = *crRes;
515 RTLBuilder rtlb(op.getLoc(), rewriter, clock, reset);
516 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
518 auto &input = io.inputs[0];
// 1-bit constant 0, used as the reset value of the per-output registers.
520 Value c0I1 = rtlb.constant(1, 0);
521 llvm::SmallVector<Value> doneWires;
522 for (
auto [i, output] : llvm::enumerate(io.outputs)) {
// Next register state: done AND NOT input-ready.
524 Value emitted = rtlb.bitAnd({doneBE, rtlb.bitNot(*input.ready)});
526 rtlb.reg(
"emitted_" + std::to_string(i), emitted, c0I1);
// Output is valid while it has not been emitted yet and the input is valid.
527 Value outValid = rtlb.bitAnd({rtlb.bitNot(emittedReg), input.valid});
528 output.valid->setValue(outValid);
529 Value validReady = rtlb.bitAnd({output.ready, outValid});
// done<i> = (output ready AND output valid) OR already emitted.
531 rtlb.bitOr({validReady, emittedReg},
"done" + std::to_string(i));
533 doneWires.push_back(done);
// The input becomes ready when all outputs are done.
535 input.ready->setValue(rtlb.bitAnd(doneWires,
"allDone"));
537 rewriter.replaceOp(op, io.getOutputChannels());
// Lowers a JoinOp: the single output is valid when all inputs are valid, and
// every input is acknowledged when the output is both valid and ready.
// The definition of `bb` is on a line not part of this excerpt.
547 matchAndRewrite(JoinOp op, OpAdaptor operands,
548 ConversionPatternRewriter &rewriter)
const override {
550 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
551 RTLBuilder rtlb(op.getLoc(), rewriter);
552 auto &output = io.outputs[0];
// AND of all input valid signals drives the output valid.
554 Value allValid = rtlb.bitAnd(io.getInputValids());
555 output.valid->setValue(allValid);
// The same (output ready AND all valid) signal drives every input's ready.
557 auto validAndReady = rtlb.bitAnd({output.ready, allValid});
558 for (
auto &input : io.inputs)
559 input.ready->setValue(validAndReady);
561 rewriter.replaceOp(op, io.outputs[0].channel);
// Lowers a SelectOp by delegating to buildMuxLogic.
// The definition of `bb` is on a line not part of this excerpt.
571 matchAndRewrite(SelectOp op, OpAdaptor operands,
572 ConversionPatternRewriter &rewriter)
const override {
574 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
575 RTLBuilder rtlb(op.getLoc(), rewriter);
// The first unwrapped input is the select signal. It is copied out and then
// removed from io.inputs so that only the data inputs remain in the list.
578 auto select = io.inputs[0];
579 io.inputs.erase(io.inputs.begin());
580 buildMuxLogic(rtlb, io, select);
582 rewriter.replaceOp(op, io.outputs[0].channel);
// Builds the valid/ready logic of a select: `select` chooses which entry of
// unwrapped.inputs feeds unwrapped.outputs[0]. Only handshake wiring is
// visible in this excerpt; some lines are missing.
589 void buildMuxLogic(RTLBuilder &rtlb, UnwrappedIO &unwrapped,
590 InputHandshake &select)
const {
// Number of bits needed to index the inputs: ceil(log2(numInputs)).
593 size_t numInputs = unwrapped.inputs.size();
594 size_t selectWidth = llvm::Log2_64_Ceil(numInputs);
// If the select data is wider than needed it is truncated to selectWidth;
// the other arm of this conditional is not part of this excerpt.
595 Value truncatedSelect =
596 select.data.getType().getIntOrFloatBitWidth() > selectWidth
597 ? rtlb.truncate(select.data, selectWidth)
// One-hot form of the select: (1 << select), computed in numInputs bits.
601 auto selectZext = rtlb.zext(truncatedSelect, numInputs);
602 auto select1h = rtlb.shl(rtlb.constant(numInputs, 1), selectZext);
603 auto &res = unwrapped.outputs[0];
// Valid signal of the currently selected input.
606 auto selectedInputValid =
607 rtlb.mux(truncatedSelect, unwrapped.getInputValids());
// Result is valid when both the selected input and the select are valid.
609 auto selAndInputValid = rtlb.bitAnd({selectedInputValid, select.valid});
610 res.valid->setValue(selAndInputValid);
611 auto resValidAndReady = rtlb.bitAnd({selAndInputValid, res.ready});
// The select is acknowledged when the result is valid and ready.
614 select.ready->setValue(resValidAndReady);
617 for (
auto [inIdx, in] : llvm::enumerate(unwrapped.inputs)) {
// An input is acknowledged only if its one-hot bit is set and the result is
// valid and ready.
619 auto isSelected = rtlb.bit(select1h, inIdx);
623 auto activeAndResultValidAndReady =
624 rtlb.bitAnd({isSelected, resValidAndReady});
625 in.ready->setValue(activeAndResultValidAndReady);
632 using OpConversionPattern::OpConversionPattern;
// Lowers a BranchOp: the condition input steers a token to the "true"
// (outputs[0]) or "false" (outputs[1]) result.
// The definition of `bb` is on a line not part of this excerpt.
634 matchAndRewrite(BranchOp op, OpAdaptor operands,
635 ConversionPatternRewriter &rewriter)
const override {
637 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
638 RTLBuilder rtlb(op.getLoc(), rewriter);
// NOTE(review): these three are taken by value (copies of the handshake
// structs), unlike the `auto &` used by other patterns in this file; the
// backedges are then set through the copies.
639 auto cond = io.inputs[0];
640 auto trueRes = io.outputs[0];
641 auto falseRes = io.outputs[1];
// true result valid  = cond.data AND cond.valid
// false result valid = NOT cond.data AND cond.valid
644 trueRes.valid->setValue(rtlb.bitAnd({cond.data, cond.valid}));
645 falseRes.valid->setValue(rtlb.bitAnd({rtlb.bitNot(cond.data), cond.valid}));
// The ready of the result picked by cond.data, gated by cond.valid, is fed
// back as the condition input's ready.
648 Value selectedResultReady =
649 rtlb.mux(cond.data, {falseRes.ready, trueRes.ready});
650 Value condReady = rtlb.bitAnd({selectedResultReady, cond.valid});
651 cond.ready->setValue(condReady);
653 rewriter.replaceOp(op,
654 SmallVector<Value>{trueRes.channel, falseRes.channel});
663 using OpConversionPattern::OpConversionPattern;
// Lowers a ToESIOp by replacing it with its (already converted) operands;
// no new operations are created on the visible lines.
666 matchAndRewrite(ToESIOp op, OpAdaptor operands,
667 ConversionPatternRewriter &rewriter)
const override {
668 rewriter.replaceOp(op, operands.getOperands());
677 using OpConversionPattern::OpConversionPattern;
// Lowers a FromESIOp by replacing it with its (already converted) operands;
// no new operations are created on the visible lines.
680 matchAndRewrite(FromESIOp op, OpAdaptor operands,
681 ConversionPatternRewriter &rewriter)
const override {
682 rewriter.replaceOp(op, operands.getOperands());
// Lowers a SinkOp: the input's ready is tied to a 1-bit constant 1 (always
// ready) and the op is erased since it has nothing to replace.
// The definition of `bb` is on a line not part of this excerpt.
692 matchAndRewrite(SinkOp op, OpAdaptor operands,
693 ConversionPatternRewriter &rewriter)
const override {
695 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
696 io.inputs[0].ready->setValue(
697 RTLBuilder(op.getLoc(), rewriter).constant(1, 1));
698 rewriter.eraseOp(op);
// Lowers a SourceOp: the output's valid is tied to a 1-bit constant 1
// (always valid) and the op is replaced by the output channel.
// The definition of `bb` is on a line not part of this excerpt.
707 matchAndRewrite(SourceOp op, OpAdaptor operands,
708 ConversionPatternRewriter &rewriter)
const override {
710 UnwrappedIO io = unwrapIO(op, operands.getOperands(), rewriter, bb);
711 RTLBuilder rtlb(op.getLoc(), rewriter);
712 io.outputs[0].valid->setValue(rtlb.constant(1, 1));
713 rewriter.replaceOp(op, io.outputs[0].channel);
// Lowers a PackOp: only the token operand is unwrapped as a channel; the
// data operand drives the output channel's data backedge directly.
// The remaining unwrapIO arguments and the valid/ready wiring between
// `input` and `output` are on lines not part of this excerpt.
722 matchAndRewrite(PackOp op, OpAdaptor operands,
723 ConversionPatternRewriter &rewriter)
const override {
725 UnwrappedIO io = unwrapIO(op, llvm::SmallVector<Value>{operands.getToken()},
727 RTLBuilder rtlb(op.getLoc(), rewriter);
728 auto &input = io.inputs[0];
729 auto &output = io.outputs[0];
730 output.data->setValue(operands.getInput());
732 rewriter.replaceOp(op, output.channel);
// Lowers an UnpackOp: the input channel is unwrapped, a channel is built for
// the token result only, and the op is replaced by
// [token channel, unpacked data value].
// The valid/ready wiring between `input` and `output` is on lines not part of
// this excerpt.
741 matchAndRewrite(UnpackOp op, OpAdaptor operands,
742 ConversionPatternRewriter &rewriter)
const override {
// Only the token's type is passed as a result type, so exactly one output
// handshake is requested.
744 UnwrappedIO io = unwrapIO(
745 op.getLoc(), llvm::SmallVector<Value>{operands.getInput()},
747 llvm::SmallVector<Type>{op.getToken().getType()}, rewriter, bb);
748 RTLBuilder rtlb(op.getLoc(), rewriter);
749 auto &input = io.inputs[0];
750 auto &output = io.outputs[0];
752 llvm::SmallVector<Value> unpackedValues;
753 unpackedValues.push_back(input.data);
// Replacement values: the token channel first, then the unpacked data.
756 llvm::SmallVector<Value>
outputs;
757 outputs.push_back(output.channel);
758 outputs.append(unpackedValues.begin(), unpackedValues.end());
759 rewriter.replaceOp(op,
outputs);
// Lowers a BufferOp to an esi::ChannelBufferOp with the same channel type as
// the converted input, clocked and reset by the enclosing module's signals.
// The trailing arguments of the replacement call are on lines not part of
// this excerpt.
769 matchAndRewrite(BufferOp op, OpAdaptor operands,
770 ConversionPatternRewriter &rewriter)
const override {
771 auto crRes = getClockAndReset(op);
774 auto [clock, reset] = *crRes;
778 Type channelType = operands.getInput().getType();
779 rewriter.replaceOpWithNewOp<esi::ChannelBufferOp>(
780 op, channelType, clock, reset, operands.getInput(), op.getSizeAttr(),
// Returns true when `type` is a TokenType or a ValueType.
788 static bool isDCType(Type type) {
return type.isa<TokenType, ValueType>(); }
793 if (
auto funcOp = dyn_cast<HWModuleLike>(op)) {
794 return llvm::none_of(funcOp.getPortTypes(),
isDCType) &&
795 llvm::none_of(funcOp.getBodyBlock()->getArgumentTypes(),
isDCType);
798 bool operandsOK = llvm::none_of(op->getOperandTypes(),
isDCType);
799 bool resultsOK = llvm::none_of(op->getResultTypes(),
isDCType);
800 return operandsOK && resultsOK;
// Pass that lowers DC dialect operations. Visible structure:
//  1. a pre-check that every dc token/value result has exactly one use;
//  2. a partial dialect conversion with the DC dialect marked illegal.
// Several lines of the class are not part of this excerpt.
808 class DCToHWPass :
public DCToHWBase<DCToHWPass> {
810 void runOnOperation()
override {
811 Operation *parent = getOperation();
// Walk all ops and reject any dc::TokenType / dc::ValueType result that is
// unused or used more than once.
815 auto walkRes = parent->walk([&](Operation *op) {
816 for (
auto res : op->getResults()) {
817 if (res.getType().isa<dc::TokenType, dc::ValueType>()) {
818 if (res.use_empty()) {
819 op->emitOpError() <<
"DCToHW: value " << res <<
" is unused.";
820 return WalkResult::interrupt();
822 if (!res.hasOneUse()) {
824 <<
"DCToHW: value " << res <<
" has multiple uses.";
825 return WalkResult::interrupt();
829 return WalkResult::advance();
// On a failed pre-check an additional error is emitted on the parent op.
832 if (walkRes.wasInterrupted()) {
833 parent->emitOpError()
834 <<
"DCToHW: failed to verify that all values "
835 "are used exactly once. Remember to run the "
836 "fork/sink materialization pass before HW lowering.";
// Conversion setup: ops outside explicitly configured dialects are judged by
// isLegalOp; every op of the DC dialect is illegal.
841 ESITypeConverter typeConverter;
842 ConversionTarget target(getContext());
843 target.markUnknownOpDynamicallyLegal(
isLegalOp);
847 target.addIllegalDialect<dc::DCDialect>();
849 RewritePatternSet
patterns(parent->getContext());
// Register the lowering patterns (one pattern line of this list is not part
// of this excerpt).
851 patterns.insert<ForkConversionPattern, JoinConversionPattern,
852 SelectConversionPattern, BranchConversionPattern,
853 PackConversionPattern, UnpackConversionPattern,
854 BufferConversionPattern, SourceConversionPattern,
856 ToESIConversionPattern, FromESIConversionPattern>(
857 typeConverter, parent->getContext());
// The failure branch of this check is on lines not part of this excerpt.
859 if (failed(applyPartialConversion(parent, target, std::move(
patterns))))
866 return std::make_unique<DCToHWPass>();
assert(baseType &&"element must be base type")
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static Type toHWType(Type t)
Converts any type 't' into a hw-compatible type.
static bool isDCType(Type type)
static Type tupleToStruct(TupleType tuple)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
std::function< std::string(Operation *)> NameUniquer
static Type toESIHWType(Type t)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static EvaluatorValuePtr unwrap(OMEvaluatorValue c)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
mlir::Type innerType(mlir::Type type)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createDCToHWPass()
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Generic pattern which replaces an operation by one of the same operation name, but with converted att...