23 #include "mlir/Dialect/Arith/IR/Arith.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/MathExtras.h"
31 #define GEN_PASS_DEF_HANDSHAKETODC
32 #include "circt/Conversion/Passes.h.inc"
36 using namespace circt;
37 using namespace handshake;
40 using namespace handshaketodc;
46 DCTuple(Value token, Value
data) : token(token),
data(
data) {}
47 DCTuple(dc::UnpackOp unpack)
48 : token(unpack.getToken()),
data(unpack.getOutput()) {}
54 static DCTuple unpack(OpBuilder &b, Value v) {
55 if (isa<dc::ValueType>(v.getType()))
56 return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
57 assert(isa<dc::TokenType>(v.getType()) &&
"Expected a dc::TokenType");
58 return DCTuple(v, {});
61 static Value pack(OpBuilder &b, Value token, Value data = {}) {
64 return b.create<dc::PackOp>(token.getLoc(), token,
data);
// Type converter for the Handshake -> DC conversion.
//  - A conversion for NoneType is registered first.
//  - ValueType and TokenType are converted to themselves.
//  - The target and source materializations share the same rules:
//      * a !dc.value materialized as a !dc.token is reduced to its token via
//        unpack();
//      * a !dc.value result type without an inner type is produced with
//        pack() and no data;
//      * anything else becomes an UnrealizedConversionCastOp.
// NOTE(review): the line numbering embedded in this listing skips lines inside
// this class (e.g. after the NoneType check and after each
// `inputs.size() != 1` check) — confirm against the full file.
67 class DCTypeConverter :
public TypeConverter {
70 addConversion([](Type type) -> Type {
71 if (isa<NoneType>(type))
// Identity conversions for ValueType and TokenType.
75 addConversion([](ValueType type) {
return type; });
76 addConversion([](TokenType type) {
return type; });
// Target materialization: produce a value of `resultType` from `inputs`.
78 addTargetMaterialization(
79 [](mlir::OpBuilder &builder, mlir::Type resultType,
80 mlir::ValueRange inputs,
81 mlir::Location loc) -> std::optional<mlir::Value> {
82 if (inputs.size() != 1)
// !dc.value -> !dc.token: keep only the token half.
86 if (isa<dc::TokenType>(resultType) &&
87 isa<dc::ValueType>(inputs.front().getType()))
88 return unpack(builder, inputs.front()).token;
// Result is a !dc.value without an inner type: pack the input with no data.
91 auto vt = dyn_cast<dc::ValueType>(resultType);
92 if (vt && !vt.getInnerType())
93 return pack(builder, inputs.front());
// Fallback: an unrealized cast from inputs[0] to resultType.
96 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Source materialization: same rules as the target materialization.
100 addSourceMaterialization(
101 [](mlir::OpBuilder &builder, mlir::Type resultType,
102 mlir::ValueRange inputs,
103 mlir::Location loc) -> std::optional<mlir::Value> {
104 if (inputs.size() != 1)
// !dc.value -> !dc.token: keep only the token half.
108 if (isa<dc::TokenType>(resultType) &&
109 isa<dc::ValueType>(inputs.front().getType()))
110 return unpack(builder, inputs.front()).token;
// Result is a !dc.value without an inner type: pack the input with no data.
113 auto vt = dyn_cast<dc::ValueType>(resultType);
114 if (vt && !vt.getInnerType())
115 return pack(builder, inputs.front());
// Fallback: an unrealized cast from inputs[0] to resultType.
118 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Shared base for the Handshake -> DC op conversion patterns in this file.
// It re-exports the handshake op's Adaptor as OpAdaptor and keeps a
// `convertedOps` member. Subclasses call convertedOps->insert(...) on the
// non-DC ops (arith.*, dc.select, ...) they create. Presumably this is the
// same set runHandshakeToDC consults in markUnknownOpDynamicallyLegal.
// NOTE(review): the base-class clause and the middle of the constructor
// (remaining parameters and base initializer) are not present in this
// listing — confirm against the full file.
124 template <
typename OpTy>
128 using OpAdaptor =
typename OpTy::Adaptor;
130 DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
133 convertedOps(convertedOps) {}
137 class CondBranchConversionPattern
138 :
public DCOpConversionPattern<handshake::ConditionalBranchOp> {
140 using DCOpConversionPattern<
141 handshake::ConditionalBranchOp>::DCOpConversionPattern;
142 using OpAdaptor =
typename handshake::ConditionalBranchOp::Adaptor;
145 matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter)
const override {
147 auto condition = unpack(rewriter, adaptor.getConditionOperand());
148 auto data = unpack(rewriter, adaptor.getDataOperand());
151 auto join = rewriter.create<dc::JoinOp>(
152 op.getLoc(), ValueRange{condition.token,
data.token});
155 auto packedCondition = pack(rewriter, join, condition.data);
158 auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
161 llvm::SmallVector<Value, 4> packed;
162 packed.push_back(pack(rewriter, branch.getTrueToken(),
data.data));
163 packed.push_back(pack(rewriter, branch.getFalseToken(),
data.data));
165 rewriter.replaceOp(op, packed);
// Lowers handshake.fork:
//  - the input is unpacked and its token is forked with dc.fork;
//  - each forked token is packed with the same (possibly null) input data to
//    form one result.
170 class ForkOpConversionPattern
171 :
public DCOpConversionPattern<handshake::ForkOp> {
173 using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
174 using OpAdaptor =
typename handshake::ForkOp::Adaptor;
177 matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter)
const override {
179 auto input = unpack(rewriter, adaptor.getOperand());
// NOTE(review): the dc::ForkOp builder argument(s) after `input.token` are not
// present in this listing — presumably the number of outputs; confirm.
180 auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
184 llvm::SmallVector<Value, 4> packed;
185 for (
auto res : forkOut.getResults())
186 packed.push_back(pack(rewriter, res, input.data));
188 rewriter.replaceOp(op, packed);
193 class JoinOpConversion :
public DCOpConversionPattern<handshake::JoinOp> {
195 using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
196 using OpAdaptor =
typename handshake::JoinOp::Adaptor;
199 matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter)
const override {
201 llvm::SmallVector<Value, 4> inputTokens;
202 for (
auto input : adaptor.getData())
203 inputTokens.push_back(unpack(rewriter, input).token);
205 rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
// Lowers handshake.merge with at most two data operands; more operands are
// rejected with a match failure.
//  - The inputs' tokens feed a dc.merge, whose result is unpacked.
//  - Its data half (named selectedIndex here) drives an arith.select between
//    data[0] and data[1].
//  - The select is recorded in convertedOps and packed with the merge token.
//  - On the other path, the merge token alone becomes the output.
// NOTE(review): only `> 2` operands are rejected, yet the select reads data[0]
// and data[1]; the guarding lines are not present in this listing — confirm a
// single-operand merge cannot reach the select.
210 class MergeOpConversion :
public DCOpConversionPattern<handshake::MergeOp> {
212 using DCOpConversionPattern<handshake::MergeOp>::DCOpConversionPattern;
213 using OpAdaptor =
typename handshake::MergeOp::Adaptor;
216 matchAndRewrite(handshake::MergeOp op, OpAdaptor adaptor,
217 ConversionPatternRewriter &rewriter)
const override {
218 if (op.getNumOperands() > 2)
219 return rewriter.notifyMatchFailure(op,
"only two inputs supported");
221 SmallVector<Value, 4> tokens,
data;
223 for (
auto input : adaptor.getDataOperands()) {
224 auto up = unpack(rewriter, input);
225 tokens.push_back(up.token);
227 data.push_back(up.data);
// Control side: merge the tokens; the unpacked data identifies the input.
231 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
232 auto selectedIndexUnpacked = unpack(rewriter, selectedIndex);
// Data side: mux the two data inputs on the selected index.
237 auto dataMux = rewriter.create<arith::SelectOp>(
238 op.getLoc(), selectedIndexUnpacked.data,
data[0],
data[1]);
239 convertedOps->insert(dataMux);
242 mergeOutput = pack(rewriter, selectedIndexUnpacked.token, dataMux);
246 mergeOutput = selectedIndexUnpacked.token;
249 rewriter.replaceOp(op, mergeOutput);
// Lowers handshake.control_merge with exactly two data operands; otherwise an
// op error is emitted.
//  - The inputs' tokens feed a dc.merge. Its unpacked data (selValue) drives
//    an arith.select over the data inputs to form the data result.
//  - selValue is also converted to the handshake index result type —
//    arith.index_cast to `index`, or arith.ext_ui otherwise (presumably chosen
//    by isIndexType) — and re-packed with the merge token as the index
//    result.
//  - All created arith ops are recorded in convertedOps.
// NOTE(review): unlike MergeOpConversion, this reports a hard op error instead
// of notifyMatchFailure — confirm that is intended.
254 class ControlMergeOpConversion
255 :
public DCOpConversionPattern<handshake::ControlMergeOp> {
257 using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
259 using OpAdaptor =
typename handshake::ControlMergeOp::Adaptor;
262 matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter)
const override {
264 if (op.getDataOperands().size() != 2)
265 return op.emitOpError(
"expected two data operands");
267 llvm::SmallVector<Value> tokens,
data;
268 for (
auto input : adaptor.getDataOperands()) {
269 auto up = unpack(rewriter, input);
270 tokens.push_back(up.token);
272 data.push_back(up.data);
275 bool isIndexType = isa<IndexType>(op.getIndex().getType());
// Control side: merge the tokens; the merge's data half selects the input.
278 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
279 auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
280 auto selValue = mergeOpUnpacked.data;
282 Value dataSide = selectedIndex;
285 auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
287 convertedOps->insert(dataMux);
289 auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
// Index result for an `index`-typed index: cast the selector to index.
297 selValue = rewriter.create<arith::IndexCastOp>(
298 op.getLoc(), rewriter.getIndexType(), selValue);
299 convertedOps->insert(selValue.getDefiningOp());
300 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
// Index result for an integer-typed index: zero-extend the selector.
304 selValue = rewriter.create<arith::ExtUIOp>(
305 op.getLoc(), op.getIndex().getType(), selValue);
306 convertedOps->insert(selValue.getDefiningOp());
307 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
310 rewriter.replaceOp(op, {dataSide, selectedIndex});
315 class SyncOpConversion :
public DCOpConversionPattern<handshake::SyncOp> {
317 using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
318 using OpAdaptor =
typename handshake::SyncOp::Adaptor;
321 matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
323 llvm::SmallVector<Value, 4> inputTokens;
324 for (
auto input : adaptor.getOperands())
325 inputTokens.push_back(unpack(rewriter, input).token);
327 auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
330 llvm::SmallVector<Value, 4> wrappedInputs;
331 for (
auto input : adaptor.getOperands())
332 wrappedInputs.push_back(pack(rewriter, syncToken, input));
334 rewriter.replaceOp(op, wrappedInputs);
// Lowers handshake.constant: an arith.constant carrying adaptor.getValue() is
// created, recorded in convertedOps, and packed with a token from a freshly
// created dc.source.
// NOTE(review): the declaration of `cst` (the arith.constant) is not present
// in this listing. Also, only adaptor.getValue() is read here — any operand of
// the handshake constant is not consumed by this pattern; confirm intended.
340 class ConstantOpConversion
341 :
public DCOpConversionPattern<handshake::ConstantOp> {
343 using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
344 using OpAdaptor =
typename handshake::ConstantOp::Adaptor;
347 matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter)
const override {
350 auto token = rewriter.create<dc::SourceOp>(op.getLoc());
352 rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
353 convertedOps->insert(cst);
354 rewriter.replaceOp(op, pack(rewriter, token, cst));
// Generic fallback pattern (MatchAnyOpTypeTag, benefit 1) for unit-rate ops.
//  - The output token is the dc.join of all operand tokens. Operand-less ops
//    use a dc.source instead, and only if they are ConstantLike; otherwise an
//    op error is emitted.
//  - The op is re-created on the unpacked data operands with the same result
//    types, attributes and successors. handshake.select / handshake.constant
//    are renamed to arith.select / arith.constant.
//  - The new op is recorded in joinedOps, and every result is packed with the
//    output token.
359 struct UnitRateConversionPattern :
public ConversionPattern {
361 UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
363 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
364 joinedOps(joinedOps) {}
365 using ConversionPattern::ConversionPattern;
370 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
371 ConversionPatternRewriter &rewriter)
const override {
372 llvm::SmallVector<Value> inputData;
// Operand-less ops have no token to join; only constant-like ones are safe.
375 if (operands.empty()) {
376 if (!op->hasTrait<OpTrait::ConstantLike>())
377 return op->emitOpError(
378 "no-operand operation which isn't constant-like. Too dangerous "
379 "to assume semantics - won't convert");
383 outToken = rewriter.create<dc::SourceOp>(op->getLoc());
385 llvm::SmallVector<Value> inputTokens;
386 for (
auto input : operands) {
387 auto dct = unpack(rewriter, input);
// NOTE(review): dct.data is null for !dc.token operands; pushing it would give
// the re-created op a null operand unless handled elsewhere — confirm.
388 inputData.push_back(dct.data);
389 inputTokens.push_back(dct.token);
392 assert(!inputTokens.empty() &&
"Expected at least one input token");
393 outToken = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
// Map handshake ops that have a direct arith equivalent.
397 auto opName = op->getName();
398 if (opName.getStringRef() ==
"handshake.select") {
399 opName = OperationName(
"arith.select", getContext());
400 }
else if (opName.getStringRef() ==
"handshake.constant") {
401 opName = OperationName(
"arith.constant", getContext());
405 OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
406 op->getAttrs(), op->getSuccessors());
408 Operation *newOp = rewriter.create(state);
409 joinedOps->insert(newOp);
// Every result shares the joined (or sourced) output token.
412 llvm::SmallVector<Value> results;
413 for (
auto result : newOp->getResults())
414 results.push_back(pack(rewriter, outToken, result));
416 rewriter.replaceOp(op, results);
424 class SinkOpConversionPattern
425 :
public DCOpConversionPattern<handshake::SinkOp> {
427 using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
428 using OpAdaptor =
typename handshake::SinkOp::Adaptor;
431 matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter)
const override {
433 auto input = unpack(rewriter, adaptor.getOperand());
434 rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
439 class SourceOpConversionPattern
440 :
public DCOpConversionPattern<handshake::SourceOp> {
442 using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
443 using OpAdaptor =
typename handshake::SourceOp::Adaptor;
446 matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter)
const override {
448 rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
453 class BufferOpConversion :
public DCOpConversionPattern<handshake::BufferOp> {
455 using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
456 using OpAdaptor =
typename handshake::BufferOp::Adaptor;
459 matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter)
const override {
461 rewriter.getI32IntegerAttr(1);
462 rewriter.replaceOpWithNewOp<dc::BufferOp>(
463 op, adaptor.getOperand(),
static_cast<size_t>(op.getNumSlots()));
// Lowers handshake.return by reusing the first hw.output op in the body block
// of `hwModule`:
//  - its operands are replaced by the converted return operands;
//  - it is moved to the end of the body block;
//  - the handshake.return is erased.
// NOTE(review): the class header and the definition of `hwModule` are not
// present in this listing. The first hw.output is dereferenced without an
// emptiness check. The hw.output is also mutated directly rather than through
// the rewriter (e.g. modifyOpInPlace). Confirm all of these are safe.
471 using OpAdaptor =
typename handshake::ReturnOp::Adaptor;
474 matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
475 ConversionPatternRewriter &rewriter)
const override {
479 auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
480 outputOp->setOperands(adaptor.getOperands());
481 outputOp->moveAfter(&hwModule.getBodyBlock()->back());
482 rewriter.eraseOp(op);
// Lowers handshake.mux.
//  - The select operand is unpacked. For each data input after the first, its
//    data is compared (arith.cmpi eq) against a constant: arith.constant
//    index, or an integer of the select's width.
//  - The comparisons build a chain of arith.select for the data and of
//    dc.select for the control tokens, both starting from inputs[0].
//  - The created arith ops and dc.select ops are recorded in convertedOps.
//  - The result is the final control token, packed with the data mux only
//    when the result type is not NoneType.
487 class MuxOpConversionPattern :
public DCOpConversionPattern<handshake::MuxOp> {
489 using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
490 using OpAdaptor =
typename handshake::MuxOp::Adaptor;
493 matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter)
const override {
495 auto select = unpack(rewriter, adaptor.getSelectOperand());
496 auto selectData = select.data;
497 auto selectToken = select.token;
498 bool isIndexType = isa<IndexType>(selectData.getType());
// Data-less (NoneType) muxes only steer tokens.
500 bool withData = !isa<NoneType>(op.getResult().getType());
502 llvm::SmallVector<DCTuple> inputs;
503 for (
auto input : adaptor.getDataOperands())
504 inputs.push_back(unpack(rewriter, input));
507 Value controlMux = inputs.front().token;
512 dataMux = inputs[0].data;
// NOTE(review): `controlMuxInputs` is not used in the visible code.
514 llvm::SmallVector<Value> controlMuxInputs = {inputs.front().token};
// NOTE(review): `i` comes from enumerate over inputs[1..], so it is 0 for
// inputs[1]. Unless an offset is applied on lines not present in this
// listing, inputs[k] is matched against select == k-1 — confirm the intended
// index encoding.
515 for (
auto [i, input] :
516 llvm::enumerate(llvm::make_range(inputs.begin() + 1, inputs.end()))) {
521 Value inputData = input.data;
522 Value inputControl = input.token;
524 cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
526 size_t width = cast<IntegerType>(selectData.getType()).getWidth();
527 cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i,
width);
529 auto inputSelected = rewriter.create<arith::CmpIOp>(
530 op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
531 dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
535 convertedOps->insert(cmpIndex.getDefiningOp());
536 convertedOps->insert(dataMux.getDefiningOp());
537 convertedOps->insert(inputSelected);
// Control side: the comparison result, gated by the select token, chooses
// between this input's token and the chain built so far.
542 auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
543 controlMux = rewriter.create<dc::SelectOp>(
544 op.getLoc(), inputSelectedControl, inputControl, controlMux);
545 convertedOps->insert(controlMux.getDefiningOp());
550 op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
// Builds the hw::ModulePortInfo for a handshake.func: presumably one input
// port per function input and one output port per function result. The
// per-port loop bodies are not present in this listing; `tc` is presumably
// used there to convert each port type — confirm.
555 static hw::ModulePortInfo getModulePortInfoHS(
const TypeConverter &tc,
556 handshake::FuncOp funcOp) {
557 SmallVector<hw::PortInfo> inputs, outputs;
558 auto *ctx = funcOp->getContext();
559 auto ft = funcOp.getFunctionType();
562 for (
auto [index, type] : llvm::enumerate(ft.getInputs())) {
570 for (
auto [index, type] : llvm::enumerate(ft.getResults())) {
578 return hw::ModulePortInfo{inputs, outputs};
// Lowers handshake.func to an hw module. Ports come from getModulePortInfoHS
// using this pattern's type converter.
//  - External functions take the isExternal() path (presumably an external hw
//    module).
//  - Otherwise the function body is merged into the module body, and the block
//    signature is converted with the type converter (any failure of
//    convertSignatureArgs is discarded via `(void)`).
//  - The handshake.func is then erased.
// NOTE(review): `®ion` below looks like a mis-encoded `&region` (an HTML
// `&reg;` entity artifact) and is not a valid identifier — restore `&region`.
584 using OpAdaptor =
typename handshake::FuncOp::Adaptor;
593 matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
594 ConversionPatternRewriter &rewriter)
const override {
595 ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
597 if (op.isExternal()) {
599 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
602 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
604 auto ®ion = op->getRegions().front();
606 Region &moduleRegion = hwModule->getRegions().front();
607 rewriter.mergeBlocks(®ion.getBlocks().front(), hwModule.getBodyBlock(),
608 hwModule.getBodyBlock()->getArguments());
609 TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
610 (void)getTypeConverter()->convertSignatureArgs(
611 TypeRange(moduleRegion.getArgumentTypes()), result);
612 rewriter.applySignatureConversion(hwModule.getBodyBlock(), result);
615 rewriter.eraseOp(op);
// Pass wrapper for the Handshake -> DC conversion. It runs on the builtin
// module and defines two lambdas:
//  - a target modifier that marks the hw and func dialects legal;
//  - a pattern builder (body not present in this listing).
// The final line is the body of the pass factory, which returns a new
// HandshakeToDCPass instance.
620 class HandshakeToDCPass
621 :
public circt::impl::HandshakeToDCBase<HandshakeToDCPass> {
623 void runOnOperation()
override {
624 mlir::ModuleOp mod = getOperation();
625 auto targetModifier = [](mlir::ConversionTarget &target) {
626 target.addLegalDialect<hw::HWDialect, func::FuncDialect>();
629 auto patternBuilder = [&](TypeConverter &typeConverter,
644 return std::make_unique<HandshakeToDCPass>();
// Core driver of the Handshake -> DC conversion on `op`.
//  - The handshake dialect is illegal; the dc dialect and builtin.module are
//    legal. The caller can adjust the target via `configureTarget`.
//  - Unknown ops are legal exactly when they are in `convertedOps`, the set
//    the patterns record their newly created non-DC ops into.
//  - The built-in DC patterns and UnitRateConversionPattern are registered,
//    the caller's `patternBuilder` may add more, and a partial conversion is
//    applied.
// NOTE(review): configureTarget defaults to `{}` in its declaration. Make sure
// the empty-callback case is guarded before `configureTarget(target)`; no
// guard is present in this listing.
649 llvm::function_ref<
void(TypeConverter &typeConverter,
653 llvm::function_ref<
void(mlir::ConversionTarget &)> configureTarget) {
665 mlir::MLIRContext *ctx = op->getContext();
666 ConversionTarget target(*ctx);
667 target.addIllegalDialect<handshake::HandshakeDialect>();
668 target.addLegalDialect<dc::DCDialect>();
669 target.addLegalOp<mlir::ModuleOp>();
673 configureTarget(target);
// Ops created by the patterns (arith.*, dc.select, ...) are tracked as legal.
681 target.markUnknownOpDynamicallyLegal(
682 [&](Operation *op) {
return convertedOps.contains(op); });
684 DCTypeConverter typeConverter;
690 patterns.add<BufferOpConversion, CondBranchConversionPattern,
691 SinkOpConversionPattern, SourceOpConversionPattern,
692 MuxOpConversionPattern, ForkOpConversionPattern,
693 JoinOpConversion, MergeOpConversion, ControlMergeOpConversion,
694 ConstantOpConversion, SyncOpConversion>(ctx, typeConverter,
// Catch-all for remaining unit-rate ops.
699 patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);
702 patternBuilder(typeConverter, convertedOps,
patterns);
703 return applyPartialConversion(op, target, std::move(
patterns));
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
DenseSet< Operation * > ConvertedOps
LogicalResult runHandshakeToDC(mlir::Operation *op, llvm::function_ref< void(TypeConverter &typeConverter, ConvertedOps &convertedOps, RewritePatternSet &patterns)> patternBuilder, llvm::function_ref< void(mlir::ConversionTarget &)> configureTarget={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToDCPass()
This holds a decoded list of input/inout and output ports for a module or instance.
Creates a new Calyx component for each FuncOp in the program.