14 #include "../PassDetail.h"
24 #include "mlir/Dialect/Arith/IR/Arith.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/MathExtras.h"
31 using namespace circt;
32 using namespace handshake;
35 using namespace handshaketodc;
// Builds a tuple from an explicit control token and a (possibly null) data
// value.
41 DCTuple(Value token, Value data) : token(token),
data(
data) {}
// Builds a tuple from the token and output results of an existing
// dc::UnpackOp.
42 DCTuple(dc::UnpackOp unpack)
43 : token(unpack.getToken()),
data(unpack.getOutput()) {}
// Splits a DC-typed value into its control token and optional data.
//  - dc::ValueType: materializes a dc::UnpackOp at `v`'s location and returns
//    its token and data results.
//  - dc::TokenType: returns `v` itself as the token, with a null data Value.
// Any other type violates the assertion below.
49 static DCTuple unpack(OpBuilder &b, Value v) {
50 if (v.getType().isa<dc::ValueType>())
51 return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
52 assert(v.getType().isa<dc::TokenType>() &&
"Expected a dc::TokenType");
53 return DCTuple(v, {});
// Combines a control token and a data value into a dc::PackOp result, created
// at the token's location. `data` defaults to a null Value.
// NOTE(review): the handling of a null `data` is on lines not shown here;
// presumably the bare token is returned in that case. Confirm.
56 static Value pack(OpBuilder &b, Value token, Value data = {}) {
59 return b.create<dc::PackOp>(token.getLoc(), token,
data);
// Type converter used by the handshake->DC patterns. It registers a generic
// conversion that special-cases NoneType, identity conversions for types that
// are already DC value/token types, and target/source materializations that
// bridge between DC value and token types. Some lambda bodies are only
// partially visible in this excerpt.
62 class DCTypeConverter :
public TypeConverter {
65 addConversion([](Type type) -> Type {
66 if (type.isa<NoneType>())
// Already-DC types are passed through unchanged.
70 addConversion([](ValueType type) {
return type; });
71 addConversion([](TokenType type) {
return type; });
// Target materialization: invoked when a converted value must be presented
// as `resultType`.
73 addTargetMaterialization(
74 [](mlir::OpBuilder &
builder, mlir::Type resultType,
76 mlir::Location loc) -> std::optional<mlir::Value> {
// Special case: a dc::TokenType is requested from a dc::ValueType input
// (handling body not visible here).
81 if (resultType.isa<dc::TokenType>() &&
82 inputs.front().getType().isa<dc::ValueType>())
// Special case: a dc::ValueType without an inner type is requested.
86 auto vt = resultType.dyn_cast<dc::ValueType>();
87 if (vt && !vt.getInnerType())
// Source materialization: mirrors the target materialization checks above.
93 addSourceMaterialization(
94 [](mlir::OpBuilder &
builder, mlir::Type resultType,
96 mlir::Location loc) -> std::optional<mlir::Value> {
101 if (resultType.isa<dc::TokenType>() &&
102 inputs.front().getType().isa<dc::ValueType>())
106 auto vt = resultType.dyn_cast<dc::ValueType>();
107 if (vt && !vt.getInnerType())
// Common base for the handshake->DC op conversion patterns. Besides the type
// converter, it stores the shared `convertedOps` set. Patterns insert the
// non-DC ops they create (e.g. arith ops) into that set, and the conversion
// target treats ops in the set as legal.
115 template <
typename OpTy>
119 using OpAdaptor =
typename OpTy::Adaptor;
// `convertedOps` is stored for use by derived patterns.
121 DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
124 convertedOps(convertedOps) {}
// Lowers handshake.cond_br to a dc.branch. The condition and data tokens are
// joined, the joined token is repacked with the condition value to steer the
// branch, and the data value is re-attached to both branch output tokens.
128 class CondBranchConversionPattern
129 :
public DCOpConversionPattern<handshake::ConditionalBranchOp> {
131 using DCOpConversionPattern<
132 handshake::ConditionalBranchOp>::DCOpConversionPattern;
133 using OpAdaptor =
typename handshake::ConditionalBranchOp::Adaptor;
136 matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter)
const override {
138 auto condition = unpack(rewriter, adaptor.getConditionOperand());
139 auto data = unpack(rewriter, adaptor.getDataOperand());
// Synchronize on both the condition and the data token.
142 auto join = rewriter.create<dc::JoinOp>(
143 op.getLoc(), ValueRange{condition.token,
data.token});
// Attach the condition value to the joined token; this drives dc.branch.
146 auto packedCondition = pack(rewriter, join, condition.data);
149 auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
// Replacement values are ordered (true output, false output), and each
// carries the original data value.
152 llvm::SmallVector<Value, 4> packed;
153 packed.push_back(pack(rewriter, branch.getTrueToken(),
data.data));
154 packed.push_back(pack(rewriter, branch.getFalseToken(),
data.data));
156 rewriter.replaceOp(op, packed);
// Lowers handshake.fork to dc.fork on the input's control token. Each forked
// token is repacked with the same input data.
161 class ForkOpConversionPattern
162 :
public DCOpConversionPattern<handshake::ForkOp> {
164 using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
165 using OpAdaptor =
typename handshake::ForkOp::Adaptor;
168 matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
170 auto input = unpack(rewriter, adaptor.getOperand());
171 auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
// One replacement value per fork result, all sharing `input.data`.
175 llvm::SmallVector<Value, 4> packed;
176 for (
auto res : forkOut.getResults())
177 packed.push_back(pack(rewriter, res, input.data));
179 rewriter.replaceOp(op, packed);
// Lowers handshake.join to dc.join over the control tokens of all data
// operands. Only the `.token` of each unpacked input is used; input data is
// not forwarded.
184 class JoinOpConversion :
public DCOpConversionPattern<handshake::JoinOp> {
186 using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
187 using OpAdaptor =
typename handshake::JoinOp::Adaptor;
190 matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
191 ConversionPatternRewriter &rewriter)
const override {
192 llvm::SmallVector<Value, 4> inputTokens;
193 for (
auto input : adaptor.getData())
194 inputTokens.push_back(unpack(rewriter, input).token);
196 rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
// Lowers handshake.control_merge to dc.merge. Only the two-input form is
// supported; any other operand count is rejected with an op error.
//  - The input tokens are merged.
//  - The merge result's data serves as the select value of an arith.select
//    over the inputs' data.
//  - When the op's index result is IndexType, the select value is
//    index_cast and repacked.
// Some lines of this pattern are not visible in this excerpt.
201 class ControlMergeOpConversion
202 :
public DCOpConversionPattern<handshake::ControlMergeOp> {
204 using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
206 using OpAdaptor =
typename handshake::ControlMergeOp::Adaptor;
209 matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter)
const override {
211 if (op.getDataOperands().size() != 2)
212 return op.emitOpError(
"expected two data operands");
214 llvm::SmallVector<Value> tokens,
data;
215 for (
auto input : adaptor.getDataOperands()) {
216 auto up = unpack(rewriter, input);
217 tokens.push_back(up.token);
219 data.push_back(up.data);
// Merge the input tokens. The unpacked data of the merge result is used
// below as the select value.
223 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
224 auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
225 auto selValue = mergeOpUnpacked.data;
227 Value dataSide = selectedIndex;
// arith ops are not DC ops. Recording them in convertedOps makes the
// conversion target treat them as legal.
230 auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
232 convertedOps->insert(dataMux);
234 auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
// An IndexType index result needs the selection value cast to index and
// repacked with the merge token.
241 if (op.getIndex().getType().isa<IndexType>()) {
242 selValue = rewriter.create<arith::IndexCastOp>(
243 op.getLoc(), rewriter.getIndexType(), selValue);
244 convertedOps->insert(selValue.getDefiningOp());
245 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
// The op's results are replaced in the order (dataSide, selectedIndex).
248 rewriter.replaceOp(op, {dataSide, selectedIndex});
// Lowers handshake.sync. The control tokens of all operands are joined, and
// every result is emitted on that single joined token.
253 class SyncOpConversion :
public DCOpConversionPattern<handshake::SyncOp> {
255 using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
256 using OpAdaptor =
typename handshake::SyncOp::Adaptor;
259 matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const override {
261 llvm::SmallVector<Value, 4> inputTokens;
262 for (
auto input : adaptor.getOperands())
263 inputTokens.push_back(unpack(rewriter, input).token);
// A single joined token gates every output.
265 auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
// NOTE(review): `input` below is the converted operand itself (a DC
// value/token), not `unpack(rewriter, input).data` as the other patterns
// use. Confirm that packing a DC-typed value as data is intended.
268 llvm::SmallVector<Value, 4> wrappedInputs;
269 for (
auto input : adaptor.getOperands())
270 wrappedInputs.push_back(pack(rewriter, syncToken, input));
272 rewriter.replaceOp(op, wrappedInputs);
// Lowers handshake.constant to an arith.constant whose value is packed with a
// token produced by dc.source. The arith op is recorded in convertedOps so
// that it is considered legal.
278 class ConstantOpConversion
279 :
public DCOpConversionPattern<handshake::ConstantOp> {
281 using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
282 using OpAdaptor =
typename handshake::ConstantOp::Adaptor;
285 matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override {
// The constant's control token comes from a dc.source.
288 auto token = rewriter.create<dc::SourceOp>(op.getLoc());
290 rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
291 convertedOps->insert(cst);
292 rewriter.replaceOp(op, pack(rewriter, token, cst));
// Generic pattern that matches any op (MatchAnyOpTypeTag, benefit 1) and
// treats it as unit-rate:
//  - Output token: a dc.join of all operand tokens, or a dc.source for
//    operand-less constant-like ops.
//  - Re-creates the op on the unpacked operand data, renaming
//    handshake.select / handshake.constant to their arith counterparts.
//  - Records the new op in `joinedOps`.
//  - Packs every result with the output token.
297 struct UnitRateConversionPattern :
public ConversionPattern {
299 UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
301 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
302 joinedOps(joinedOps) {}
303 using ConversionPattern::ConversionPattern;
308 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
309 ConversionPatternRewriter &rewriter)
const override {
310 llvm::SmallVector<Value> inputData;
// Operand-less ops get their token from dc.source. Only constant-like ops are
// accepted here; anything else is rejected with an op error.
313 if (operands.empty()) {
314 if (!op->hasTrait<OpTrait::ConstantLike>())
315 return op->emitOpError(
316 "no-operand operation which isn't constant-like. Too dangerous "
317 "to assume semantics - won't convert");
321 outToken = rewriter.create<dc::SourceOp>(op->getLoc());
// Otherwise wait for every operand: collect data and tokens, then join the
// tokens.
323 llvm::SmallVector<Value> inputTokens;
324 for (
auto input : operands) {
325 auto dct = unpack(rewriter, input);
326 inputData.push_back(dct.data);
327 inputTokens.push_back(dct.token);
330 assert(!inputTokens.empty() &&
"Expected at least one input token");
331 outToken = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
// Handshake ops with a direct arith equivalent are renamed. All other ops
// keep their original name.
335 auto opName = op->getName();
336 if (opName.getStringRef() ==
"handshake.select") {
337 opName = OperationName(
"arith.select", getContext());
338 }
else if (opName.getStringRef() ==
"handshake.constant") {
339 opName = OperationName(
"arith.constant", getContext());
// Re-create the op with the same result types, attributes and successors,
// operating on the unpacked data operands.
343 OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
344 op->getAttrs(), op->getSuccessors());
346 Operation *newOp = rewriter.create(state);
347 joinedOps->insert(newOp);
// Every result shares the same output token.
350 llvm::SmallVector<Value> results;
351 for (
auto result : newOp->getResults())
352 results.push_back(pack(rewriter, outToken, result));
354 rewriter.replaceOp(op, results);
// Lowers handshake.sink to dc.sink on the operand's control token. Any data
// carried by the operand is dropped.
362 class SinkOpConversionPattern
363 :
public DCOpConversionPattern<handshake::SinkOp> {
365 using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
366 using OpAdaptor =
typename handshake::SinkOp::Adaptor;
369 matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter)
const override {
371 auto input = unpack(rewriter, adaptor.getOperand());
372 rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
// Lowers handshake.source one-to-one to dc.source.
377 class SourceOpConversionPattern
378 :
public DCOpConversionPattern<handshake::SourceOp> {
380 using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
381 using OpAdaptor =
typename handshake::SourceOp::Adaptor;
384 matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter)
const override {
386 rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
// Lowers handshake.buffer to dc.buffer with the same number of slots.
// (A stray `rewriter.getI32IntegerAttr(1);` whose result was discarded has
// been removed; it had no effect on the rewrite.)
391 class BufferOpConversion :
public DCOpConversionPattern<handshake::BufferOp> {
393 using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
394 using OpAdaptor =
typename handshake::BufferOp::Adaptor;
397 matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
398 ConversionPatternRewriter &rewriter)
const override {
// The converted operand is handed to dc.buffer as-is (it is not unpacked).
400 rewriter.replaceOpWithNewOp<dc::BufferOp>(
401 op, adaptor.getOperand(),
static_cast<size_t>(op.getNumSlots()));
409 using OpAdaptor =
typename handshake::ReturnOp::Adaptor;
// Replaces handshake.return by forwarding its converted operands to the
// first hw.output op in `hwModule`'s body block, moving that hw.output to the
// end of the block, and erasing the return. `hwModule` is obtained on lines
// not shown here (presumably the enclosing hw module; confirm).
412 matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter)
const override {
// Reuse the existing hw.output rather than creating a new terminator.
417 auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
418 outputOp->setOperands(adaptor.getOperands());
// Keep the output as the last op of the body block.
419 outputOp->moveAfter(&hwModule.getBodyBlock()->back());
420 rewriter.eraseOp(op);
// Lowers handshake.mux.
//  - Data side: a chain of arith.cmpi/arith.select ops that compares the
//    select value against each input's operand index.
//  - Control side: a matching chain of dc.select ops.
//  - Data operand 0 is the base (default) case of both chains.
//  - For NoneType results (`withData == false`), the result token is packed
//    without data.
425 class MuxOpConversionPattern :
public DCOpConversionPattern<handshake::MuxOp> {
427 using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
428 using OpAdaptor =
typename handshake::MuxOp::Adaptor;
431 matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter)
const override {
433 auto select = unpack(rewriter, adaptor.getSelectOperand());
434 auto selectData = select.data;
435 auto selectToken = select.token;
436 bool isIndexType = selectData.getType().isa<IndexType>();
438 bool withData = !op.getResult().getType().isa<NoneType>();
440 llvm::SmallVector<DCTuple>
inputs;
441 for (
auto input : adaptor.getDataOperands())
442 inputs.push_back(unpack(rewriter, input));
445 Value controlMux =
inputs.front().token;
452 llvm::SmallVector<Value> controlMuxInputs = {
inputs.front().token};
// The loop visits data operands 1..N-1 (operand 0 is the base case), so
// enumeration index `i` refers to operand `i + 1`. The compared constant must
// therefore be `i + 1`. Comparing against `i` shifted every selection by one
// and made select==0 pick operand 1 instead of operand 0.
453 for (
auto [i, input] :
454 llvm::enumerate(llvm::make_range(
inputs.begin() + 1,
inputs.end()))) {
459 Value inputData = input.data;
460 Value inputControl = input.token;
462 cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i + 1);
464 size_t width = selectData.getType().cast<IntegerType>().
getWidth();
465 cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i + 1,
width);
467 auto inputSelected = rewriter.create<arith::CmpIOp>(
468 op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
469 dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
// The arith ops are non-DC and must be recorded to be considered legal.
473 convertedOps->insert(cmpIndex.getDefiningOp());
474 convertedOps->insert(dataMux.getDefiningOp());
475 convertedOps->insert(inputSelected);
// Control side: the comparison result is packed with the select token and
// steers a dc.select between this input's token and the chain so far.
480 auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
481 controlMux = rewriter.create<dc::SelectOp>(
482 op.getLoc(), inputSelectedControl, inputControl, controlMux);
483 convertedOps->insert(controlMux.getDefiningOp());
488 op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
// Builds hw module port information for a handshake.func from its function
// type. It iterates over the function's inputs and results, with index and
// type.
// NOTE(review): the loop bodies are not visible here; presumably each port
// type is converted through `tc`. Confirm.
493 static hw::ModulePortInfo getModulePortInfoHS(
const TypeConverter &tc,
494 handshake::FuncOp funcOp) {
496 auto *ctx = funcOp->getContext();
497 auto ft = funcOp.getFunctionType();
500 for (
auto [index, type] : llvm::enumerate(ft.getInputs())) {
508 for (
auto [index, type] : llvm::enumerate(ft.getResults())) {
522 using OpAdaptor =
typename handshake::FuncOp::Adaptor;
// Lowers handshake.func to an hw module using the port info from
// getModulePortInfoHS.
//  - External functions take a separate creation path, only partially
//    visible in this excerpt.
//  - Otherwise the function body is merged into the module's body block and
//    the block signature is converted with this pattern's type converter.
//  - The original op is erased.
// (Fixes the mis-encoded `®ion` token, an HTML `&reg` artifact, back to
// `&region`.)
531 matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter)
const override {
533 ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
535 if (op.isExternal()) {
537 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
540 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
542 auto &region = op->getRegions().front();
544 Region &moduleRegion = hwModule->getRegions().front();
// Splice the handshake body into the module, remapping its block arguments to
// the module body's arguments.
545 rewriter.mergeBlocks(&region.getBlocks().front(), hwModule.getBodyBlock(),
546 hwModule.getBodyBlock()->getArguments());
// Convert the module body's argument types with this pattern's type
// converter. The LogicalResult of convertSignatureArgs is deliberately
// ignored here.
547 TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
548 (void)getTypeConverter()->convertSignatureArgs(
549 TypeRange(moduleRegion.getArgumentTypes()), result);
550 rewriter.applySignatureConversion(&moduleRegion, result);
553 rewriter.eraseOp(op);
// Pass wrapper around the handshake->DC conversion. It marks the hw and func
// dialects legal and adds the FuncOp/ReturnOp conversion patterns on top of
// the core patterns. Presumably both are passed to runHandshakeToDC; that
// call is not visible here, so confirm.
558 class HandshakeToDCPass :
public HandshakeToDCBase<HandshakeToDCPass> {
560 void runOnOperation()
override {
561 mlir::ModuleOp mod = getOperation();
// Extra legality: ops from the hw and func dialects may remain after
// conversion.
562 auto targetModifier = [](mlir::ConversionTarget &target) {
563 target.addLegalDialect<hw::HWDialect, func::FuncDialect>();
// Extra patterns: lower handshake.func / handshake.return to hw.
566 auto patternBuilder = [&](TypeConverter &typeConverter,
569 patterns.add<FuncOpConversion, ReturnOpConversion>(typeConverter,
// Factory: returns a new instance of the pass.
581 return std::make_unique<HandshakeToDCPass>();
586 llvm::function_ref<
void(TypeConverter &typeConverter,
590 llvm::function_ref<
void(mlir::ConversionTarget &)> configureTarget) {
// Conversion target: the handshake dialect is illegal, while the DC dialect
// and builtin.module are legal. Callers may extend the target through
// `configureTarget`.
602 mlir::MLIRContext *ctx = op->getContext();
603 ConversionTarget target(*ctx);
604 target.addIllegalDialect<handshake::HandshakeDialect>();
605 target.addLegalDialect<dc::DCDialect>();
606 target.addLegalOp<mlir::ModuleOp>();
610 configureTarget(target);
// Any other op is legal only if a pattern recorded it in `convertedOps`
// (e.g. the arith ops created during lowering).
618 target.markUnknownOpDynamicallyLegal(
619 [&](Operation *op) {
return convertedOps.contains(op); });
// Op-specific patterns are registered first. UnitRateConversionPattern
// matches any op type (MatchAnyOpTypeTag) and handles the remaining
// unit-rate ops.
621 DCTypeConverter typeConverter;
628 .add<BufferOpConversion, CondBranchConversionPattern,
629 SinkOpConversionPattern, SourceOpConversionPattern,
630 MuxOpConversionPattern, ForkOpConversionPattern, JoinOpConversion,
631 ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
632 ctx, typeConverter, &convertedOps);
636 patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);
// Let the caller contribute additional patterns (the pass adds the
// FuncOp/ReturnOp lowerings).
639 patternBuilder(typeConverter, convertedOps,
patterns);
// Partial conversion: ops that are not illegal may remain in the output.
640 return applyPartialConversion(op, target, std::move(
patterns));
assert(baseType &&"element must be base type")
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
DenseSet< Operation * > ConvertedOps
LogicalResult runHandshakeToDC(mlir::Operation *op, llvm::function_ref< void(TypeConverter &typeConverter, ConvertedOps &convertedOps, RewritePatternSet &patterns)> patternBuilder, llvm::function_ref< void(mlir::ConversionTarget &)> configureTarget={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToDCPass()
This holds a decoded list of input/inout and output ports for a module or instance.