23 #include "mlir/Dialect/Arith/IR/Arith.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/MathExtras.h"
31 #define GEN_PASS_DEF_HANDSHAKETODC
32 #include "circt/Conversion/Passes.h.inc"
36 using namespace circt;
37 using namespace handshake;
40 using namespace handshaketodc;
46 DCTuple(Value token, Value
data) : token(token),
data(
data) {}
47 DCTuple(dc::UnpackOp unpack)
48 : token(unpack.getToken()),
data(unpack.getOutput()) {}
54 static DCTuple unpack(OpBuilder &b, Value v) {
55 if (isa<dc::ValueType>(v.getType()))
56 return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
57 assert(isa<dc::TokenType>(v.getType()) &&
"Expected a dc::TokenType");
58 return DCTuple(v, {});
61 static Value pack(OpBuilder &b, Value token, Value data = {}) {
64 return b.create<dc::PackOp>(token.getLoc(), token,
data);
67 class DCTypeConverter :
public TypeConverter {
70 addConversion([](Type type) -> Type {
71 if (isa<NoneType>(type))
75 addConversion([](ValueType type) {
return type; });
76 addConversion([](TokenType type) {
return type; });
78 addTargetMaterialization(
79 [](mlir::OpBuilder &builder, mlir::Type resultType,
80 mlir::ValueRange inputs,
81 mlir::Location loc) -> std::optional<mlir::Value> {
82 if (inputs.size() != 1)
86 if (isa<dc::TokenType>(resultType) &&
87 isa<dc::ValueType>(inputs.front().getType()))
88 return unpack(builder, inputs.front()).token;
91 auto vt = dyn_cast<dc::ValueType>(resultType);
92 if (vt && !vt.getInnerType())
93 return pack(builder, inputs.front());
98 addSourceMaterialization(
99 [](mlir::OpBuilder &builder, mlir::Type resultType,
100 mlir::ValueRange inputs,
101 mlir::Location loc) -> std::optional<mlir::Value> {
102 if (inputs.size() != 1)
106 if (isa<dc::TokenType>(resultType) &&
107 isa<dc::ValueType>(inputs.front().getType()))
108 return unpack(builder, inputs.front()).token;
111 auto vt = dyn_cast<dc::ValueType>(resultType);
112 if (vt && !vt.getInnerType())
113 return pack(builder, inputs.front());
120 template <
typename OpTy>
124 using OpAdaptor =
typename OpTy::Adaptor;
126 DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
129 convertedOps(convertedOps) {}
133 class CondBranchConversionPattern
134 :
public DCOpConversionPattern<handshake::ConditionalBranchOp> {
136 using DCOpConversionPattern<
137 handshake::ConditionalBranchOp>::DCOpConversionPattern;
138 using OpAdaptor =
typename handshake::ConditionalBranchOp::Adaptor;
141 matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter)
const override {
143 auto condition = unpack(rewriter, adaptor.getConditionOperand());
144 auto data = unpack(rewriter, adaptor.getDataOperand());
147 auto join = rewriter.create<dc::JoinOp>(
148 op.getLoc(), ValueRange{condition.token,
data.token});
151 auto packedCondition = pack(rewriter, join, condition.data);
154 auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
157 llvm::SmallVector<Value, 4> packed;
158 packed.push_back(pack(rewriter, branch.getTrueToken(),
data.data));
159 packed.push_back(pack(rewriter, branch.getFalseToken(),
data.data));
161 rewriter.replaceOp(op, packed);
166 class ForkOpConversionPattern
167 :
public DCOpConversionPattern<handshake::ForkOp> {
169 using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
170 using OpAdaptor =
typename handshake::ForkOp::Adaptor;
173 matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
174 ConversionPatternRewriter &rewriter)
const override {
175 auto input = unpack(rewriter, adaptor.getOperand());
176 auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
180 llvm::SmallVector<Value, 4> packed;
181 for (
auto res : forkOut.getResults())
182 packed.push_back(pack(rewriter, res, input.data));
184 rewriter.replaceOp(op, packed);
189 class JoinOpConversion :
public DCOpConversionPattern<handshake::JoinOp> {
191 using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
192 using OpAdaptor =
typename handshake::JoinOp::Adaptor;
195 matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter)
const override {
197 llvm::SmallVector<Value, 4> inputTokens;
198 for (
auto input : adaptor.getData())
199 inputTokens.push_back(unpack(rewriter, input).token);
201 rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
206 class MergeOpConversion :
public DCOpConversionPattern<handshake::MergeOp> {
208 using DCOpConversionPattern<handshake::MergeOp>::DCOpConversionPattern;
209 using OpAdaptor =
typename handshake::MergeOp::Adaptor;
212 matchAndRewrite(handshake::MergeOp op, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter)
const override {
214 if (op.getNumOperands() > 2)
215 return rewriter.notifyMatchFailure(op,
"only two inputs supported");
217 SmallVector<Value, 4> tokens,
data;
219 for (
auto input : adaptor.getDataOperands()) {
220 auto up = unpack(rewriter, input);
221 tokens.push_back(up.token);
223 data.push_back(up.data);
227 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
228 auto selectedIndexUnpacked = unpack(rewriter, selectedIndex);
233 auto dataMux = rewriter.create<arith::SelectOp>(
234 op.getLoc(), selectedIndexUnpacked.data,
data[0],
data[1]);
235 convertedOps->insert(dataMux);
238 mergeOutput = pack(rewriter, selectedIndexUnpacked.token, dataMux);
242 mergeOutput = selectedIndexUnpacked.token;
245 rewriter.replaceOp(op, mergeOutput);
250 class ControlMergeOpConversion
251 :
public DCOpConversionPattern<handshake::ControlMergeOp> {
253 using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
255 using OpAdaptor =
typename handshake::ControlMergeOp::Adaptor;
258 matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
259 ConversionPatternRewriter &rewriter)
const override {
260 if (op.getDataOperands().size() != 2)
261 return op.emitOpError(
"expected two data operands");
263 llvm::SmallVector<Value> tokens,
data;
264 for (
auto input : adaptor.getDataOperands()) {
265 auto up = unpack(rewriter, input);
266 tokens.push_back(up.token);
268 data.push_back(up.data);
271 bool isIndexType = isa<IndexType>(op.getIndex().getType());
274 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
275 auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
276 auto selValue = mergeOpUnpacked.data;
278 Value dataSide = selectedIndex;
281 auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
283 convertedOps->insert(dataMux);
285 auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
293 selValue = rewriter.create<arith::IndexCastOp>(
294 op.getLoc(), rewriter.getIndexType(), selValue);
295 convertedOps->insert(selValue.getDefiningOp());
296 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
300 selValue = rewriter.create<arith::ExtUIOp>(
301 op.getLoc(), op.getIndex().getType(), selValue);
302 convertedOps->insert(selValue.getDefiningOp());
303 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
306 rewriter.replaceOp(op, {dataSide, selectedIndex});
311 class SyncOpConversion :
public DCOpConversionPattern<handshake::SyncOp> {
313 using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
314 using OpAdaptor =
typename handshake::SyncOp::Adaptor;
317 matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter)
const override {
319 llvm::SmallVector<Value, 4> inputTokens;
320 for (
auto input : adaptor.getOperands())
321 inputTokens.push_back(unpack(rewriter, input).token);
323 auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
326 llvm::SmallVector<Value, 4> wrappedInputs;
327 for (
auto input : adaptor.getOperands())
328 wrappedInputs.push_back(pack(rewriter, syncToken, input));
330 rewriter.replaceOp(op, wrappedInputs);
336 class ConstantOpConversion
337 :
public DCOpConversionPattern<handshake::ConstantOp> {
339 using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
340 using OpAdaptor =
typename handshake::ConstantOp::Adaptor;
343 matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
344 ConversionPatternRewriter &rewriter)
const override {
346 auto token = rewriter.create<dc::SourceOp>(op.getLoc());
348 rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
349 convertedOps->insert(cst);
350 rewriter.replaceOp(op, pack(rewriter, token, cst));
355 struct UnitRateConversionPattern :
public ConversionPattern {
357 UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
359 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
360 joinedOps(joinedOps) {}
361 using ConversionPattern::ConversionPattern;
366 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
367 ConversionPatternRewriter &rewriter)
const override {
368 llvm::SmallVector<Value> inputData;
371 if (operands.empty()) {
372 if (!op->hasTrait<OpTrait::ConstantLike>())
373 return op->emitOpError(
374 "no-operand operation which isn't constant-like. Too dangerous "
375 "to assume semantics - won't convert");
379 outToken = rewriter.create<dc::SourceOp>(op->getLoc());
381 llvm::SmallVector<Value> inputTokens;
382 for (
auto input : operands) {
383 auto dct = unpack(rewriter, input);
384 inputData.push_back(dct.data);
385 inputTokens.push_back(dct.token);
388 assert(!inputTokens.empty() &&
"Expected at least one input token");
389 outToken = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
393 auto opName = op->getName();
394 if (opName.getStringRef() ==
"handshake.select") {
395 opName = OperationName(
"arith.select", getContext());
396 }
else if (opName.getStringRef() ==
"handshake.constant") {
397 opName = OperationName(
"arith.constant", getContext());
401 OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
402 op->getAttrs(), op->getSuccessors());
404 Operation *newOp = rewriter.create(state);
405 joinedOps->insert(newOp);
408 llvm::SmallVector<Value> results;
409 for (
auto result : newOp->getResults())
410 results.push_back(pack(rewriter, outToken, result));
412 rewriter.replaceOp(op, results);
420 class SinkOpConversionPattern
421 :
public DCOpConversionPattern<handshake::SinkOp> {
423 using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
424 using OpAdaptor =
typename handshake::SinkOp::Adaptor;
427 matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
429 auto input = unpack(rewriter, adaptor.getOperand());
430 rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
435 class SourceOpConversionPattern
436 :
public DCOpConversionPattern<handshake::SourceOp> {
438 using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
439 using OpAdaptor =
typename handshake::SourceOp::Adaptor;
442 matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
443 ConversionPatternRewriter &rewriter)
const override {
444 rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
449 class BufferOpConversion :
public DCOpConversionPattern<handshake::BufferOp> {
451 using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
452 using OpAdaptor =
typename handshake::BufferOp::Adaptor;
455 matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
456 ConversionPatternRewriter &rewriter)
const override {
457 rewriter.getI32IntegerAttr(1);
458 rewriter.replaceOpWithNewOp<dc::BufferOp>(
459 op, adaptor.getOperand(),
static_cast<size_t>(op.getNumSlots()));
467 using OpAdaptor =
typename handshake::ReturnOp::Adaptor;
470 matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter)
const override {
475 auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
476 outputOp->setOperands(adaptor.getOperands());
477 outputOp->moveAfter(&hwModule.getBodyBlock()->back());
478 rewriter.eraseOp(op);
483 class MuxOpConversionPattern :
public DCOpConversionPattern<handshake::MuxOp> {
485 using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
486 using OpAdaptor =
typename handshake::MuxOp::Adaptor;
489 matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
490 ConversionPatternRewriter &rewriter)
const override {
491 auto select = unpack(rewriter, adaptor.getSelectOperand());
492 auto selectData = select.data;
493 auto selectToken = select.token;
494 bool isIndexType = isa<IndexType>(selectData.getType());
496 bool withData = !isa<NoneType>(op.getResult().getType());
498 llvm::SmallVector<DCTuple> inputs;
499 for (
auto input : adaptor.getDataOperands())
500 inputs.push_back(unpack(rewriter, input));
503 Value controlMux = inputs.front().token;
508 dataMux = inputs[0].data;
510 llvm::SmallVector<Value> controlMuxInputs = {inputs.front().token};
511 for (
auto [i, input] :
512 llvm::enumerate(llvm::make_range(inputs.begin() + 1, inputs.end()))) {
517 Value inputData = input.data;
518 Value inputControl = input.token;
520 cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
522 size_t width = cast<IntegerType>(selectData.getType()).getWidth();
523 cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i,
width);
525 auto inputSelected = rewriter.create<arith::CmpIOp>(
526 op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
527 dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
531 convertedOps->insert(cmpIndex.getDefiningOp());
532 convertedOps->insert(dataMux.getDefiningOp());
533 convertedOps->insert(inputSelected);
538 auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
539 controlMux = rewriter.create<dc::SelectOp>(
540 op.getLoc(), inputSelectedControl, inputControl, controlMux);
541 convertedOps->insert(controlMux.getDefiningOp());
546 op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
551 static hw::ModulePortInfo getModulePortInfoHS(
const TypeConverter &tc,
552 handshake::FuncOp funcOp) {
553 SmallVector<hw::PortInfo> inputs, outputs;
554 auto *ctx = funcOp->getContext();
555 auto ft = funcOp.getFunctionType();
558 for (
auto [index, type] : llvm::enumerate(ft.getInputs())) {
566 for (
auto [index, type] : llvm::enumerate(ft.getResults())) {
574 return hw::ModulePortInfo{inputs, outputs};
580 using OpAdaptor =
typename handshake::FuncOp::Adaptor;
589 matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
590 ConversionPatternRewriter &rewriter)
const override {
591 ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
593 if (op.isExternal()) {
595 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
598 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
600 auto &region = op->getRegions().front();
602 Region &moduleRegion = hwModule->getRegions().front();
603 rewriter.mergeBlocks(&region.getBlocks().front(), hwModule.getBodyBlock(),
604 hwModule.getBodyBlock()->getArguments());
605 TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
606 (void)getTypeConverter()->convertSignatureArgs(
607 TypeRange(moduleRegion.getArgumentTypes()), result);
608 rewriter.applySignatureConversion(hwModule.getBodyBlock(), result);
611 rewriter.eraseOp(op);
616 class HandshakeToDCPass
617 :
public circt::impl::HandshakeToDCBase<HandshakeToDCPass> {
619 void runOnOperation()
override {
620 mlir::ModuleOp mod = getOperation();
621 auto targetModifier = [](mlir::ConversionTarget &target) {
622 target.addLegalDialect<hw::HWDialect, func::FuncDialect>();
625 auto patternBuilder = [&](TypeConverter &typeConverter,
640 return std::make_unique<HandshakeToDCPass>();
645 llvm::function_ref<
void(TypeConverter &typeConverter,
649 llvm::function_ref<
void(mlir::ConversionTarget &)> configureTarget) {
661 mlir::MLIRContext *ctx = op->getContext();
662 ConversionTarget target(*ctx);
663 target.addIllegalDialect<handshake::HandshakeDialect>();
664 target.addLegalDialect<dc::DCDialect>();
665 target.addLegalOp<mlir::ModuleOp>();
669 configureTarget(target);
677 target.markUnknownOpDynamicallyLegal(
678 [&](Operation *op) {
return convertedOps.contains(op); });
680 DCTypeConverter typeConverter;
686 patterns.add<BufferOpConversion, CondBranchConversionPattern,
687 SinkOpConversionPattern, SourceOpConversionPattern,
688 MuxOpConversionPattern, ForkOpConversionPattern,
689 JoinOpConversion, MergeOpConversion, ControlMergeOpConversion,
690 ConstantOpConversion, SyncOpConversion>(ctx, typeConverter,
695 patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);
698 patternBuilder(typeConverter, convertedOps,
patterns);
699 return applyPartialConversion(op, target, std::move(
patterns));
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
DenseSet< Operation * > ConvertedOps
LogicalResult runHandshakeToDC(mlir::Operation *op, llvm::function_ref< void(TypeConverter &typeConverter, ConvertedOps &convertedOps, RewritePatternSet &patterns)> patternBuilder, llvm::function_ref< void(mlir::ConversionTarget &)> configureTarget={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createHandshakeToDCPass()
This holds a decoded list of input/inout and output ports for a module or instance.
Creates a new Calyx component for each FuncOp in the program.