CIRCT  20.0.0git
LowerArcsToFuncs.cpp
Go to the documentation of this file.
1 //===- LowerArcsToFuncs.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 #include "llvm/Support/Debug.h"
15 
16 #define DEBUG_TYPE "arc-lower-arcs-to-funcs"
17 
18 namespace circt {
19 namespace arc {
20 #define GEN_PASS_DEF_LOWERARCSTOFUNCS
21 #include "circt/Dialect/Arc/ArcPasses.h.inc"
22 } // namespace arc
23 } // namespace circt
24 
25 using namespace mlir;
26 using namespace circt;
27 
28 //===----------------------------------------------------------------------===//
29 // Pass Implementation
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct LowerArcsToFuncsPass
34  : public arc::impl::LowerArcsToFuncsBase<LowerArcsToFuncsPass> {
35 
36  LogicalResult lowerToFuncs();
37  void runOnOperation() override;
38 };
39 
40 struct DefineOpLowering : public OpConversionPattern<arc::DefineOp> {
41  using OpConversionPattern::OpConversionPattern;
42  LogicalResult
43  matchAndRewrite(arc::DefineOp op, OpAdaptor adaptor,
44  ConversionPatternRewriter &rewriter) const final {
45  auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(),
46  op.getFunctionType());
47  func->setAttr(
48  "llvm.linkage",
49  LLVM::LinkageAttr::get(getContext(), LLVM::linkage::Linkage::Internal));
50  rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
51  rewriter.eraseOp(op);
52  return success();
53  }
54 };
55 
56 struct OutputOpLowering : public OpConversionPattern<arc::OutputOp> {
57  using OpConversionPattern::OpConversionPattern;
58  LogicalResult
59  matchAndRewrite(arc::OutputOp op, OpAdaptor adaptor,
60  ConversionPatternRewriter &rewriter) const final {
61  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOutputs());
62  return success();
63  }
64 };
65 
66 struct CallOpLowering : public OpConversionPattern<arc::CallOp> {
67  using OpConversionPattern::OpConversionPattern;
68  LogicalResult
69  matchAndRewrite(arc::CallOp op, OpAdaptor adaptor,
70  ConversionPatternRewriter &rewriter) const final {
71  SmallVector<Type> newResultTypes;
72  if (failed(
73  typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
74  return failure();
75  rewriter.replaceOpWithNewOp<func::CallOp>(
76  op, newResultTypes, op.getArcAttr(), adaptor.getInputs());
77  return success();
78  }
79 };
80 
81 struct StateOpLowering : public OpConversionPattern<arc::StateOp> {
82  using OpConversionPattern::OpConversionPattern;
83  LogicalResult
84  matchAndRewrite(arc::StateOp op, OpAdaptor adaptor,
85  ConversionPatternRewriter &rewriter) const final {
86  SmallVector<Type> newResultTypes;
87  if (failed(
88  typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
89  return failure();
90  rewriter.replaceOpWithNewOp<func::CallOp>(
91  op, newResultTypes, op.getArcAttr(), adaptor.getInputs());
92  return success();
93  }
94 };
95 
96 } // namespace
97 
98 static void populateLegality(ConversionTarget &target) {
99  target.addLegalDialect<func::FuncDialect>();
100  target.addLegalDialect<LLVM::LLVMDialect>();
101 
102  target.addIllegalOp<arc::CallOp>();
103  target.addIllegalOp<arc::DefineOp>();
104  target.addIllegalOp<arc::OutputOp>();
105  target.addIllegalOp<arc::StateOp>();
106 }
107 
108 static void populateOpConversion(RewritePatternSet &patterns,
109  TypeConverter &typeConverter) {
110  auto *context = patterns.getContext();
111  patterns
112  .add<CallOpLowering, DefineOpLowering, OutputOpLowering, StateOpLowering>(
113  typeConverter, context);
114 }
115 
116 static void populateTypeConversion(TypeConverter &typeConverter) {
117  typeConverter.addConversion([&](Type type) { return type; });
118 }
119 
120 LogicalResult LowerArcsToFuncsPass::lowerToFuncs() {
121  LLVM_DEBUG(llvm::dbgs() << "Lowering arcs to funcs\n");
122  ConversionTarget target(getContext());
123  TypeConverter converter;
124  RewritePatternSet patterns(&getContext());
125  populateLegality(target);
126  populateTypeConversion(converter);
127  populateOpConversion(patterns, converter);
128  return applyPartialConversion(getOperation(), target, std::move(patterns));
129 }
130 
131 void LowerArcsToFuncsPass::runOnOperation() {
132  if (failed(lowerToFuncs()))
133  return signalPassFailure();
134 }
135 
136 std::unique_ptr<Pass> arc::createLowerArcsToFuncsPass() {
137  return std::make_unique<LowerArcsToFuncsPass>();
138 }
static void populateLegality(ConversionTarget &target)
static void populateOpConversion(RewritePatternSet &patterns, TypeConverter &typeConverter)
static void populateTypeConversion(TypeConverter &typeConverter)
std::unique_ptr< mlir::Pass > createLowerArcsToFuncsPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21