14 #include "../PassDetail.h"
24 #include "mlir/Dialect/Arith/IR/Arith.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/Support/MathExtras.h"
31 using namespace circt;
32 using namespace handshake;
// Set of operations recorded by the conversion patterns in this file. The
// pass's dynamic-legality callback treats an otherwise-unknown operation as
// legal only when it is contained in this set.
38 using ConvertedOps = DenseSet<Operation *>;
// Constructs a tuple directly from a token value and a data value.
// NOTE(review): the enclosing struct header and its member declarations are
// not visible in this excerpt.
42 DCTuple(Value token, Value data) : token(token),
data(
data) {}
// Constructs a tuple from the token and output results of a dc::UnpackOp.
43 DCTuple(dc::UnpackOp unpack)
44 : token(unpack.getToken()),
data(unpack.getOutput()) {}
// Splits `v` into its token and data parts.
//  - If `v` has type dc::ValueType, a dc::UnpackOp is created at v's location
//    and the tuple is built from that op.
//  - Otherwise `v` is asserted to be of type dc::TokenType and is returned as
//    the token, with an empty data value.
50 static DCTuple unpack(OpBuilder &b, Value v) {
51 if (v.getType().isa<dc::ValueType>())
52 return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
53 assert(v.getType().isa<dc::TokenType>() &&
"Expected a dc::TokenType");
// Token-only input: there is no data component, so the data slot stays empty.
54 return DCTuple(v, {});
// Creates a dc::PackOp, at the location of `token`, from `token` and `data`.
// `data` defaults to an empty Value.
// NOTE(review): the original lines between the signature and this return are
// not visible here; there may be special handling for an empty `data` —
// confirm against the full source.
57 static Value pack(OpBuilder &b, Value token, Value data = {}) {
60 return b.create<dc::PackOp>(token.getLoc(), token,
data);
// Type converter used by the Handshake-to-DC patterns. It registers type
// conversions plus target and source materializations.
// NOTE(review): the constructor header and the bodies of several branches
// are missing from this excerpt; the comments below describe only the
// visible conditions.
63 class DCTypeConverter :
public TypeConverter {
// Generic conversion for arbitrary types; NoneType is special-cased (the
// resulting type is not visible in this excerpt).
66 addConversion([](Type type) -> Type {
67 if (type.isa<NoneType>())
// dc::ValueType and dc::TokenType are already in the target form and map to
// themselves.
71 addConversion([](ValueType type) {
return type; });
72 addConversion([](TokenType type) {
return type; });
// Target materialization. Two cases are distinguished: a dc::TokenType result
// requested from a dc::ValueType input, and a dc::ValueType result that has
// no inner type. The actions taken in each case are not visible here.
74 addTargetMaterialization(
75 [](mlir::OpBuilder &
builder, mlir::Type resultType,
77 mlir::Location loc) -> std::optional<mlir::Value> {
82 if (resultType.isa<dc::TokenType>() &&
83 inputs.front().getType().isa<dc::ValueType>())
87 auto vt = resultType.dyn_cast<dc::ValueType>();
88 if (vt && !vt.getInnerType())
// Source materialization: tests the same two conditions as the target
// materialization above.
94 addSourceMaterialization(
95 [](mlir::OpBuilder &
builder, mlir::Type resultType,
97 mlir::Location loc) -> std::optional<mlir::Value> {
102 if (resultType.isa<dc::TokenType>() &&
103 inputs.front().getType().isa<dc::ValueType>())
107 auto vt = resultType.dyn_cast<dc::ValueType>();
108 if (vt && !vt.getInnerType())
// Templated base for the per-operation conversion patterns in this file. In
// addition to the context and type converter it takes a pointer to a
// ConvertedOps set, which derived patterns use to record operations they
// create.
// NOTE(review): the class header and base-class initializer are not visible
// in this excerpt.
116 template <
typename OpTy>
120 using OpAdaptor =
typename OpTy::Adaptor;
122 DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
123 ConvertedOps *convertedOps)
125 convertedOps(convertedOps) {}
// Non-owning pointer to the set of recorded operations.
126 mutable ConvertedOps *convertedOps;
// Lowers handshake::ConditionalBranchOp to a dc::JoinOp followed by a
// dc::BranchOp, re-attaching the data payload to both branch outputs.
129 class CondBranchConversionPattern
130 :
public DCOpConversionPattern<handshake::ConditionalBranchOp> {
132 using DCOpConversionPattern<
133 handshake::ConditionalBranchOp>::DCOpConversionPattern;
134 using OpAdaptor =
typename handshake::ConditionalBranchOp::Adaptor;
137 matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
138 ConversionPatternRewriter &rewriter)
const override {
// Split both operands into their token and data parts.
139 auto condition = unpack(rewriter, adaptor.getConditionOperand());
140 auto data = unpack(rewriter, adaptor.getDataOperand());
// Join the condition token with the data token.
143 auto join = rewriter.create<dc::JoinOp>(
144 op.getLoc(), ValueRange{condition.token,
data.token});
// Pack the joined token together with the condition's data.
147 auto packedCondition = pack(rewriter, join, condition.data);
// Branch on the packed condition.
150 auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
// Pack the same data value with the true token and with the false token; these
// two values replace the original op's results, in that order.
153 llvm::SmallVector<Value, 4> packed;
154 packed.push_back(pack(rewriter, branch.getTrueToken(),
data.data));
155 packed.push_back(pack(rewriter, branch.getFalseToken(),
data.data));
157 rewriter.replaceOp(op, packed);
// Lowers handshake::ForkOp to a dc::ForkOp on the input's token, then packs
// the input's data with every fork result.
162 class ForkOpConversionPattern
163 :
public DCOpConversionPattern<handshake::ForkOp> {
165 using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
166 using OpAdaptor =
typename handshake::ForkOp::Adaptor;
169 matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
170 ConversionPatternRewriter &rewriter)
const override {
171 auto input = unpack(rewriter, adaptor.getOperand());
// NOTE(review): the remaining arguments of this dc::ForkOp creation are not
// visible in this excerpt.
172 auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
// Each fork result is packed with the (shared) input data.
176 llvm::SmallVector<Value, 4> packed;
177 for (
auto res : forkOut.getResults())
178 packed.push_back(pack(rewriter, res, input.data));
180 rewriter.replaceOp(op, packed);
// Lowers handshake::JoinOp to dc::JoinOp. Only the token part of each input
// is used; the data parts produced by unpacking are discarded.
185 class JoinOpConversion :
public DCOpConversionPattern<handshake::JoinOp> {
187 using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
188 using OpAdaptor =
typename handshake::JoinOp::Adaptor;
191 matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter)
const override {
// Collect the token of every input.
193 llvm::SmallVector<Value, 4> inputTokens;
194 for (
auto input : adaptor.getData())
195 inputTokens.push_back(unpack(rewriter, input).token);
197 rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
// Lowers handshake::ControlMergeOp with exactly two data operands to a
// dc::MergeOp on the input tokens, plus an arith::SelectOp for the data.
// NOTE(review): several original lines (including the conditions guarding the
// data-side mux and the closing braces) are missing from this excerpt.
202 class ControlMergeOpConversion
203 :
public DCOpConversionPattern<handshake::ControlMergeOp> {
205 using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
207 using OpAdaptor =
typename handshake::ControlMergeOp::Adaptor;
210 matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
211 ConversionPatternRewriter &rewriter)
const override {
// Only the two-input form is handled; anything else is reported as an error.
212 if (op.getDataOperands().size() != 2)
213 return op.emitOpError(
"expected two data operands");
// Split every data operand into its token and data parts.
215 llvm::SmallVector<Value> tokens,
data;
216 for (
auto input : adaptor.getDataOperands()) {
217 auto up = unpack(rewriter, input);
218 tokens.push_back(up.token);
220 data.push_back(up.data);
// Merge the tokens; the merge result is unpacked to obtain the selection value
// and its accompanying token.
224 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
225 auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
226 auto selValue = mergeOpUnpacked.data;
// The data-side result starts out as the merge result itself.
228 Value dataSide = selectedIndex;
// Select between the data inputs using the selection value, and record the
// created op in convertedOps.
// NOTE(review): the select's value operands and any reassignment of dataSide
// are not visible in this excerpt.
231 auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
233 convertedOps->insert(dataMux);
235 auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
// If the original index result has IndexType, cast the selection value to
// index, record the cast op, and re-pack it with the merge token.
242 if (op.getIndex().getType().isa<IndexType>()) {
243 selValue = rewriter.create<arith::IndexCastOp>(
244 op.getLoc(), rewriter.getIndexType(), selValue);
245 convertedOps->insert(selValue.getDefiningOp());
246 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
249 rewriter.replaceOp(op, {dataSide, selectedIndex});
// Lowers handshake::SyncOp: the tokens of all operands are joined with a
// dc::JoinOp, and each result is produced by packing that joined token.
254 class SyncOpConversion :
public DCOpConversionPattern<handshake::SyncOp> {
256 using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
257 using OpAdaptor =
typename handshake::SyncOp::Adaptor;
260 matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter)
const override {
// Collect the token of every operand and join them.
262 llvm::SmallVector<Value, 4> inputTokens;
263 for (
auto input : adaptor.getOperands())
264 inputTokens.push_back(unpack(rewriter, input).token);
266 auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
// One replacement value per operand.
// NOTE(review): `input` here is the adapted operand itself, not the data part
// obtained by unpacking in the loop above — confirm this is intended.
269 llvm::SmallVector<Value, 4> wrappedInputs;
270 for (
auto input : adaptor.getOperands())
271 wrappedInputs.push_back(pack(rewriter, syncToken, input));
273 rewriter.replaceOp(op, wrappedInputs);
// Lowers handshake::ConstantOp to a dc::SourceOp token packed with an
// arith::ConstantOp carrying the original value attribute.
279 class ConstantOpConversion
280 :
public DCOpConversionPattern<handshake::ConstantOp> {
282 using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
283 using OpAdaptor =
typename handshake::ConstantOp::Adaptor;
286 matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const override {
// The token comes from a freshly created dc::SourceOp.
289 auto token = rewriter.create<dc::SourceOp>(op.getLoc());
// NOTE(review): the line declaring `cst` (the left-hand side of this creation)
// is not visible in this excerpt.
291 rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
// Record the arith constant in convertedOps.
292 convertedOps->insert(cst);
293 rewriter.replaceOp(op, pack(rewriter, token, cst));
// Generic pattern (registered with MatchAnyOpTypeTag, benefit 1) for
// single-result operations: operand tokens are joined, the operation is
// re-created on the unpacked operand data, and its result is packed with the
// joined token.
298 struct UnitRateConversionPattern :
public ConversionPattern {
300 UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
301 ConvertedOps *joinedOps)
302 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
303 joinedOps(joinedOps) {}
304 using ConversionPattern::ConversionPattern;
309 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
310 ConversionPatternRewriter &rewriter)
const override {
// Only operations with exactly one result are supported.
311 if (op->getNumResults() != 1)
312 return op->emitOpError(
"expected single result for pattern to apply");
// Split every operand into data and token.
314 llvm::SmallVector<Value, 4> inputData;
315 llvm::SmallVector<Value, 4> inputTokens;
316 for (
auto input : operands) {
317 auto dct = unpack(rewriter, input);
318 inputData.push_back(dct.data);
319 inputTokens.push_back(dct.token);
// Join all operand tokens.
323 auto join = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
// handshake.select and handshake.constant are renamed to their arith
// counterparts; every other operation keeps its name.
326 auto opName = op->getName();
327 if (opName.getStringRef() ==
"handshake.select") {
328 opName = OperationName(
"arith.select", getContext());
329 }
else if (opName.getStringRef() ==
"handshake.constant") {
330 opName = OperationName(
"arith.constant", getContext());
// Re-create the operation on the unpacked data, keeping the original result
// types, attributes and successors, and record it in joinedOps.
334 OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
335 op->getAttrs(), op->getSuccessors());
337 Operation *newOp = rewriter.create(state);
338 joinedOps->insert(newOp);
// The replacement is the new op's first result packed with the joined token.
341 rewriter.replaceOp(op, ValueRange{pack(rewriter, join.getResult(),
342 newOp->getResults().front())});
// Non-owning pointer to the set in which re-created operations are recorded.
347 mutable ConvertedOps *joinedOps;
// Lowers handshake::SinkOp to dc::SinkOp on the operand's token; the
// operand's data part is not used.
350 class SinkOpConversionPattern
351 :
public DCOpConversionPattern<handshake::SinkOp> {
353 using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
354 using OpAdaptor =
typename handshake::SinkOp::Adaptor;
357 matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 auto input = unpack(rewriter, adaptor.getOperand());
360 rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
// Lowers handshake::SourceOp by replacing it one-for-one with dc::SourceOp.
365 class SourceOpConversionPattern
366 :
public DCOpConversionPattern<handshake::SourceOp> {
368 using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
369 using OpAdaptor =
typename handshake::SourceOp::Adaptor;
372 matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
373 ConversionPatternRewriter &rewriter)
const override {
374 rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
// Lowers handshake::BufferOp to dc::BufferOp, forwarding the adapted operand
// and the original op's slot count (cast to size_t).
// A stray `rewriter.getI32IntegerAttr(1);` statement whose result was
// discarded has been removed; it had no effect on the rewrite.
379 class BufferOpConversion :
public DCOpConversionPattern<handshake::BufferOp> {
381 using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
382 using OpAdaptor =
typename handshake::BufferOp::Adaptor;
385 matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
386 ConversionPatternRewriter &rewriter)
const override {
388 rewriter.replaceOpWithNewOp<dc::BufferOp>(
389 op, adaptor.getOperand(),
static_cast<size_t>(op.getNumSlots()));
// Lowers handshake::ReturnOp by reusing the hw::OutputOp already present in
// the enclosing hw::HWModuleOp: the output op receives the adapted operands
// and is moved to the end of the module body, then the return op is erased.
394 class ReturnOpConversion :
public DCOpConversionPattern<handshake::ReturnOp> {
396 using DCOpConversionPattern<handshake::ReturnOp>::DCOpConversionPattern;
397 using OpAdaptor =
typename handshake::ReturnOp::Adaptor;
400 matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
401 ConversionPatternRewriter &rewriter)
const override {
// Takes the first hw::OutputOp found in the parent module's body block.
404 auto hwModule = op->getParentOfType<hw::HWModuleOp>();
405 auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
// NOTE(review): the output op is mutated and moved directly rather than via
// the rewriter — confirm this is safe under the conversion driver.
406 outputOp->setOperands(adaptor.getOperands());
407 outputOp->moveAfter(&hwModule.getBodyBlock()->back());
408 rewriter.eraseOp(op);
// Lowers handshake::MuxOp to a chain of arith compare/select ops on the data
// side and dc::SelectOp ops on the control (token) side, starting from the
// first data input.
// NOTE(review): several original lines (declarations of `cmpIndex` and
// `dataMux`, the index/integer branch, some operands, closing braces) are
// missing from this excerpt.
413 class MuxOpConversionPattern :
public DCOpConversionPattern<handshake::MuxOp> {
415 using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
416 using OpAdaptor =
typename handshake::MuxOp::Adaptor;
419 matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter)
const override {
// Split the select operand into its data and token parts.
421 auto select = unpack(rewriter, adaptor.getSelectOperand());
422 auto selectData = select.data;
423 auto selectToken = select.token;
424 bool isIndexType = selectData.getType().isa<IndexType>();
// The mux carries data unless its result type is NoneType.
426 bool withData = !op.getResult().getType().isa<NoneType>();
// Unpack every data operand.
428 llvm::SmallVector<DCTuple>
inputs;
429 for (
auto input : adaptor.getDataOperands())
430 inputs.push_back(unpack(rewriter, input));
// The control chain starts from the first input's token.
433 Value controlMux =
inputs.front().token;
// NOTE(review): no use of controlMuxInputs is visible in this excerpt.
440 llvm::SmallVector<Value> controlMuxInputs = {
inputs.front().token};
// Iterate over the remaining inputs (second input onwards).
// NOTE(review): `i` enumerates the sub-range, so it is 0 for the second input;
// confirm the comparison constants below are meant to be `i` and not `i + 1`.
441 for (
auto [i, input] :
442 llvm::enumerate(llvm::make_range(
inputs.begin() + 1,
inputs.end()))) {
447 Value inputData = input.data;
448 Value inputControl = input.token;
// Comparison constant: an index constant, or an integer constant of the select
// value's bit width.
450 cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
452 size_t width = selectData.getType().cast<IntegerType>().getWidth();
453 cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i,
width);
// inputSelected is true when the select value equals the comparison constant.
455 auto inputSelected = rewriter.create<arith::CmpIOp>(
456 op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
457 dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
// Record the created arith ops in convertedOps.
461 convertedOps->insert(cmpIndex.getDefiningOp());
462 convertedOps->insert(dataMux.getDefiningOp());
463 convertedOps->insert(inputSelected);
// Control side: pack the comparison result with the select token and use it to
// choose between this input's token and the chain built so far.
468 auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
469 controlMux = rewriter.create<dc::SelectOp>(
470 op.getLoc(), inputSelectedControl, inputControl, controlMux);
471 convertedOps->insert(controlMux.getDefiningOp());
// The final value packs the control chain with the data chain, or with an
// empty Value when the mux carries no data.
476 op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
// Builds the hw::ModulePortInfo for a handshake::FuncOp by walking the
// inputs and then the results of its function type.
// NOTE(review): the statements that construct each port are mostly missing
// from this excerpt; only the trailing hw::InnerSymAttr{} argument of each is
// visible, and the use of `tc` and `ctx` cannot be confirmed from here.
481 static hw::ModulePortInfo getModulePortInfoHS(
const TypeConverter &tc,
482 handshake::FuncOp funcOp) {
484 auto *ctx = funcOp->getContext();
485 auto ft = funcOp.getFunctionType();
// One entry per function input.
488 for (
auto [index, type] : llvm::enumerate(ft.getInputs())) {
492 hw::InnerSymAttr{}});
// One entry per function result.
496 for (
auto [index, type] : llvm::enumerate(ft.getResults())) {
501 hw::InnerSymAttr{}});
// Lowers handshake::FuncOp to an hw module with the same name, using the port
// list computed by getModulePortInfoHS. External functions become
// hw::HWModuleExternOp; otherwise an hw::HWModuleOp is created, the function
// body is merged into the module body, and the module region's signature is
// converted with the type converter. The original function is then erased.
// The two `region` references below previously appeared as `®ion`, a
// mis-encoding of `&region`; they are restored here.
// NOTE(review): the lines separating the external and non-external branches
// (and the closing braces) are not visible in this excerpt.
507 class FuncOpConversion :
public DCOpConversionPattern<handshake::FuncOp> {
509 using DCOpConversionPattern<handshake::FuncOp>::DCOpConversionPattern;
510 using OpAdaptor =
typename handshake::FuncOp::Adaptor;
519 matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
520 ConversionPatternRewriter &rewriter)
const override {
521 ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
// External function: only a module declaration is emitted.
523 if (op.isExternal()) {
524 rewriter.create<hw::HWModuleExternOp>(
525 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
// Function with a body: create the module that will receive it.
527 auto hwModule = rewriter.create<hw::HWModuleOp>(
528 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
530 auto &region = op->getRegions().front();
532 Region &moduleRegion = hwModule->getRegions().front();
// Move the function's first block into the module body, mapping its block
// arguments onto the module body's arguments.
533 rewriter.mergeBlocks(&region.getBlocks().front(), hwModule.getBodyBlock(),
534 hwModule.getBodyBlock()->getArguments());
// Convert the module region's argument types; the result of
// convertSignatureArgs is deliberately ignored.
535 TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
536 (void)getTypeConverter()->convertSignatureArgs(
537 TypeRange(moduleRegion.getArgumentTypes()), result);
538 rewriter.applySignatureConversion(&moduleRegion, result);
541 rewriter.eraseOp(op);
// Pass that runs the Handshake-to-DC conversion on a module.
// NOTE(review): several original lines (including the failure handling after
// applyPartialConversion and closing braces) are missing from this excerpt.
546 class HandshakeToDCPass :
public HandshakeToDCBase<HandshakeToDCPass> {
548 void runOnOperation()
override {
549 mlir::ModuleOp mod = getOperation();
// Operations recorded by the patterns; shared with every pattern by pointer.
561 ConvertedOps convertedOps;
// The Handshake dialect is illegal; DC, func and hw dialects and the builtin
// module op are legal.
563 ConversionTarget target(getContext());
564 target.addIllegalDialect<handshake::HandshakeDialect>();
565 target.addLegalDialect<dc::DCDialect, func::FuncDialect, hw::HWDialect>();
566 target.addLegalOp<mlir::ModuleOp>();
// Any other operation is legal only once a pattern has recorded it in
// convertedOps.
574 target.markUnknownOpDynamicallyLegal(
575 [&](Operation *op) {
return convertedOps.contains(op); });
577 DCTypeConverter typeConverter;
578 RewritePatternSet
patterns(&getContext());
// Register the per-operation patterns; each receives the context, the type
// converter and a pointer to convertedOps.
584 .add<FuncOpConversion, BufferOpConversion, CondBranchConversionPattern,
585 SinkOpConversionPattern, SourceOpConversionPattern,
586 MuxOpConversionPattern, ReturnOpConversion,
587 ForkOpConversionPattern, JoinOpConversion,
588 ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
589 &getContext(), typeConverter, &convertedOps);
// Register the generic single-result pattern.
593 patterns.add<UnitRateConversionPattern>(&getContext(), typeConverter,
// Run a partial conversion over the module with the target and patterns above.
596 if (failed(applyPartialConversion(mod, target, std::move(
patterns))))
// Returns a newly allocated HandshakeToDCPass.
// NOTE(review): the signature of the enclosing factory function is not
// visible in this excerpt.
603 return std::make_unique<HandshakeToDCPass>();
assert(baseType &&"element must be base type")
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createHandshakeToDCPass()