14 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15 #include "mlir/Conversion/LLVMCommon/Pattern.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/MathExtras.h"
30 #define GEN_PASS_DEF_FLATTENMEMREF
31 #define GEN_PASS_DEF_FLATTENMEMREFCALLS
32 #include "circt/Transforms/Passes.h.inc"
36 using namespace circt;
// A memref is "uni-dimensional" when its shape has exactly one dimension.
39 return memref.getShape().size() == 1;
// Flattens a multi-dimensional index list into a single linear index for a
// statically shaped memref. Power-of-two scale factors are lowered to a left
// shift; other factors use an explicit multiply.
// NOTE(review): several lines of this function are elided in this view; the
// comments below only describe what the visible code establishes.
51 static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
52 ValueRange indices, MemRefType memrefType) {
53 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
54 Location loc = op->getLoc();
// No indices (rank-0 access): the flat index is the constant 0.
56 if (indices.empty()) {
58 return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
// Accumulate each remaining index, scaled, onto the first one.
62 Value finalIdx = indices.front();
63 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
64 Value partialIdx = memIdx.value();
65 int64_t indexMulFactor = 1;
// Scale factor for this index: the product of the leading dimension
// sizes up to and including this position.
68 for (
unsigned i = 0; i <= memIdx.index(); ++i) {
69 int64_t dimSize = memrefType.getShape()[i];
70 indexMulFactor *= dimSize;
// Power-of-two factor: shift left by log2(factor) instead of multiplying.
74 if (llvm::isPowerOf2_64(indexMulFactor)) {
77 .create<arith::ConstantOp>(
78 loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
81 rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
// General case: multiply by the (non-power-of-two) factor.
83 auto constant = rewriter
84 .create<arith::ConstantOp>(
85 loc, rewriter.getIndexAttr(indexMulFactor))
88 rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
// Fold the scaled partial index into the running flat index.
92 auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
93 finalIdx = sumOp.getResult();
// True iff any value in the range has a memref type satisfying the predicate
// on the elided lines — presumably "is multi-dimensional"; confirm against
// the full source.
99 return llvm::any_of(values, [](Value v) {
100 auto memref = dyn_cast<MemRefType>(v.getType());
// Rewrites a multi-dimensional memref.load into a load using a single
// flattened index computed by flattenIndices().
110 using OpConversionPattern::OpConversionPattern;
113 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter)
const override {
115 MemRefType type = op.getMemRefType();
// Guard (partially elided): loads that already use one index are skipped.
117 op.getIndices().size() == 1)
// Collapse the index list into one linear index.
120 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
121 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
123 SmallVector<Value>{finalIdx});
// Rewrites a multi-dimensional memref.store into a store using a single
// flattened index computed by flattenIndices().
129 using OpConversionPattern::OpConversionPattern;
132 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 MemRefType type = op.getMemRefType();
// Guard (partially elided): stores that already use one index are skipped.
136 op.getIndices().size() == 1)
// Collapse the index list into one linear index.
139 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
140 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
142 SmallVector<Value>{finalIdx});
// Replaces a multi-dimensional memref.alloc with a 1-D alloc holding the
// same total number of elements of the same element type.
148 using OpConversionPattern::OpConversionPattern;
151 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
152 ConversionPatternRewriter &rewriter)
const override {
153 MemRefType type = op.getType();
// Flattened result type: one dimension covering every element.
157 SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
158 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
// Generic conversion pattern: rebuilds TOp with its type-converted
// (flattened) operands while keeping the original result types and
// attributes unchanged.
165 template <
typename TOp>
168 using OpAdaptor =
typename TOp::Adaptor;
170 matchAndRewrite(TOp op, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter)
const override {
172 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
173 adaptor.getOperands(), op->getAttrs());
// Rebuilds cf.cond_br with the converted condition and branch operands;
// the successor blocks themselves are reused unchanged.
180 struct CondBranchOpConversion
182 using OpConversionPattern::OpConversionPattern;
185 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter)
const override {
187 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
188 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
189 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
// Rewrites func.call ops whose operand/result types the type converter
// flattens. When `rewriteFunctions` is set, the callee's declaration is
// also rewritten (or, for an unresolved callee, a private declaration with
// the converted signature is created).
198 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
199 bool rewriteFunctions =
false)
201 rewriteFunctions(rewriteFunctions) {}
204 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override {
206 llvm::SmallVector<Type> convResTypes;
// Convert the call's result types; bail out if conversion fails.
207 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
209 auto newCallOp = rewriter.create<func::CallOp>(
210 op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());
// Callee rewriting disabled: only swap in the converted call op.
212 if (!rewriteFunctions) {
213 rewriter.replaceOp(op, newCallOp);
// Insert above the enclosing function so a newly created callee
// declaration lands at module scope.
220 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
221 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
223 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
224 func::FuncOp newFuncOp;
// Resolved callee: replace its declaration with the converted signature.
226 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
227 calledFunction, op.getCallee(), funcType);
// Unresolved callee: create a private declaration with the new type.
230 rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
231 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
232 rewriter.replaceOp(op, newCallOp);
// Whether callee declarations should be rewritten alongside call sites.
238 bool rewriteFunctions;
// Marks each TOp dynamically legal via a shared predicate. The predicate
// body is elided in this view — presumably it checks that no operand or
// result uses a multi-dimensional memref; confirm against the full source.
241 template <
typename... TOp>
242 void addGenericLegalityConstraint(ConversionTarget &target) {
243 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
// Declares which ops count as legal once memrefs have been flattened.
250 static void populateFlattenMemRefsLegality(ConversionTarget &target) {
251 target.addLegalDialect<arith::ArithDialect>();
252 target.addDynamicallyLegalOp<memref::AllocOp>(
// Loads and stores are legal once they use exactly one index.
254 target.addDynamicallyLegalOp<memref::StoreOp>(
255 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
256 target.addDynamicallyLegalOp<memref::LoadOp>(
257 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
// Branch/call/return/dealloc/copy ops share the generic constraint.
259 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
260 func::CallOp, func::ReturnOp, memref::DeallocOp,
261 memref::CopyOp>(target);
// A func is legal when no block argument is a multi-dim memref ...
263 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
264 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
265 return hasMultiDimMemRef(block.getArguments());
// ... and every memref result type is uni-dimensional.
268 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
269 if (auto memref = dyn_cast<MemRefType>(type))
270 return isUniDimensional(memref);
274 return argsConverted && resultsConverted;
// Target materialization: bridges an unconverted multi-dimensional memref
// to its flattened form by emitting a memref.subview collapsed to 1-D.
281 static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
282 ValueRange inputs, Location loc) {
283 assert(type.hasStaticShape() &&
284 "Can only subview flatten memref's with static shape (for now...).");
285 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
286 int64_t memSize = sourceType.getNumElements();
287 unsigned dims = sourceType.getShape().size();
// NOTE(review): the names `sizes` and `offsets` appear swapped relative to
// their contents — `sizes` is all zeros (actual offsets) and `offsets`
// holds 1s with memSize in the last slot (actual sizes). SubViewOp takes
// (offsets, sizes, strides); verify the operand order on the elided
// continuation line and consider renaming for clarity.
290 SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0));
291 SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1));
292 offsets[offsets.size() - 1] = builder.getIndexAttr(memSize);
293 SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));
// Result type: a rank-1 memref covering all memSize elements.
296 MemRefType outType =
MemRefType::get({memSize}, type.getElementType());
297 return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes,
// Type conversions: identity for every type except memrefs, which are
// rewritten to 1-D memrefs with the same total element count.
301 static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
303 typeConverter.addConversion([](Type type) {
return type; });
305 typeConverter.addConversion([](MemRefType memref) {
308 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
309 memref.getElementType());
// Pass driver: flattens all multi-dimensional memrefs in the module using
// a partial dialect conversion.
313 struct FlattenMemRefPass
314 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
316 void runOnOperation()
override {
318 auto *ctx = &getContext();
319 TypeConverter typeConverter;
320 populateTypeConversionPatterns(typeConverter);
323 SetVector<StringRef> rewrittenCallees;
// NOTE(review): OperandConversionPattern<memref::DeallocOp> is registered
// twice in this list; the duplicate is redundant and should be removed.
324 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
325 OperandConversionPattern<func::ReturnOp>,
326 OperandConversionPattern<memref::DeallocOp>,
327 CondBranchOpConversion,
328 OperandConversionPattern<memref::DeallocOp>,
329 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
// Also convert func signatures (block arguments) via the same converter.
331 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
334 ConversionTarget target(*ctx);
335 populateFlattenMemRefsLegality(target);
// Partial conversion: ops outside the legality set are left untouched.
337 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// Pass driver: rewrites only call sites (and callee declarations) so that
// callers can interface with functions that take flattened memrefs.
345 struct FlattenMemRefCallsPass
346 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
348 void runOnOperation()
override {
349 auto *ctx = &getContext();
350 TypeConverter typeConverter;
351 populateTypeConversionPatterns(typeConverter);
// Only the call conversion pattern runs in this pass.
360 patterns.add<CallOpConversion>(typeConverter, ctx,
363 ConversionTarget target(*ctx);
364 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
365 addGenericLegalityConstraint<func::CallOp>(target);
366 addGenericLegalityConstraint<func::FuncOp>(target);
// Bridge converted call sites to unconverted surrounding code by
// materializing a flattening memref.subview at the type boundary.
370 typeConverter.addTargetMaterialization(materializeSubViewFlattening);
372 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// Factory for the full memref-flattening pass.
384 return std::make_unique<FlattenMemRefPass>();
// Factory for the call-site-only flattening pass.
388 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType && "element must be base type")
static bool hasMultiDimMemRef(ValueRange values)
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which need to be rewritten, if they contain memref ar...