15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/IR/BuiltinDialect.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/ImplicitLocOpBuilder.h"
24#include "mlir/IR/OperationSupport.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/Support/FormatVariadic.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/MathExtras.h"
33#define GEN_PASS_DEF_FLATTENMEMREF
34#define GEN_PASS_DEF_FLATTENMEMREFCALLS
35#include "circt/Transforms/Passes.h.inc"
42 return memref.getShape().size() == 1;
56 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
57 type.getElementType());
63 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
64 type.getElementType(), uniqueID);
69static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
70 ValueRange indices, MemRefType memrefType) {
71 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
72 Location loc = op->getLoc();
74 if (indices.empty()) {
76 return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
80 Value finalIdx = indices.front();
81 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
82 Value partialIdx = memIdx.value();
83 int64_t indexMulFactor = 1;
86 for (
unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
88 int64_t dimSize = memrefType.getShape()[i];
89 indexMulFactor *= dimSize;
93 if (llvm::isPowerOf2_64(indexMulFactor)) {
96 .create<arith::ConstantOp>(
97 loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
100 rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
102 auto constant = rewriter
103 .create<arith::ConstantOp>(
104 loc, rewriter.getIndexAttr(indexMulFactor))
107 rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
111 auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
112 finalIdx = sumOp.getResult();
118 return llvm::any_of(values, [](Value v) {
119 auto memref = dyn_cast<MemRefType>(v.getType());
129 using OpConversionPattern::OpConversionPattern;
132 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 MemRefType type = op.getMemRefType();
136 op.getIndices().size() == 1)
139 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
140 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
142 SmallVector<Value>{finalIdx});
148 using OpConversionPattern::OpConversionPattern;
151 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter)
const override {
153 MemRefType type = op.getMemRefType();
155 op.getIndices().size() == 1)
158 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
159 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
161 SmallVector<Value>{finalIdx});
167 using OpConversionPattern::OpConversionPattern;
170 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
171 ConversionPatternRewriter &rewriter)
const override {
172 MemRefType type = op.getType();
176 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
182 using OpConversionPattern::OpConversionPattern;
185 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter)
const override {
187 MemRefType type = op.getType();
193 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
195 SmallVector<Attribute> flattenedVals;
196 for (
auto attr : cstAttr.getValues<Attribute>())
197 flattenedVals.push_back(attr);
199 auto newTypeAttr = TypeAttr::get(newType);
201 auto newName = rewriter.getStringAttr(newNameStr);
204 RankedTensorType tensorType = RankedTensorType::get(
205 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
206 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
208 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
209 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
210 op.getConstantAttr(), op.getAlignmentAttr());
217 using OpConversionPattern::OpConversionPattern;
220 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
221 ConversionPatternRewriter &rewriter)
const override {
222 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
223 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
224 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
226 MemRefType type = globalOp.getType();
231 auto originalName = globalOp.getSymNameAttr();
235 auto newName = newNameIt->second;
237 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
244 using OpConversionPattern::OpConversionPattern;
247 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
248 ConversionPatternRewriter &rewriter)
const override {
249 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
250 if (!flattenedSource)
253 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
255 !flattenedSrcType.hasStaticShape()) {
256 rewriter.replaceOp(op, flattenedSource);
266template <
typename TOp>
269 using OpAdaptor =
typename TOp::Adaptor;
271 matchAndRewrite(TOp op, OpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter)
const override {
273 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
274 adaptor.getOperands(), op->getAttrs());
281struct CondBranchOpConversion
283 using OpConversionPattern::OpConversionPattern;
286 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const override {
288 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
289 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
290 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
299 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
300 bool rewriteFunctions =
false)
302 rewriteFunctions(rewriteFunctions) {}
305 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
307 llvm::SmallVector<Type> convResTypes;
308 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
310 auto newCallOp = rewriter.create<func::CallOp>(
311 op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());
313 if (!rewriteFunctions) {
314 rewriter.replaceOp(op, newCallOp);
321 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
322 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
323 FunctionType funcType = FunctionType::get(
324 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
325 func::FuncOp newFuncOp;
327 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
328 calledFunction, op.getCallee(), funcType);
331 rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
332 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
333 rewriter.replaceOp(op, newCallOp);
339 bool rewriteFunctions;
342template <
typename... TOp>
343void addGenericLegalityConstraint(ConversionTarget &target) {
344 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
351static void populateFlattenMemRefsLegality(ConversionTarget &target) {
352 target.addLegalDialect<arith::ArithDialect>();
353 target.addDynamicallyLegalOp<memref::AllocOp>(
355 target.addDynamicallyLegalOp<memref::StoreOp>(
356 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
357 target.addDynamicallyLegalOp<memref::LoadOp>(
358 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
359 target.addDynamicallyLegalOp<memref::GlobalOp>(
361 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
363 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
364 func::CallOp, func::ReturnOp, memref::DeallocOp,
365 memref::CopyOp>(target);
367 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
368 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
369 return hasMultiDimMemRef(block.getArguments());
372 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
373 if (auto memref = dyn_cast<MemRefType>(type))
374 return isUniDimensional(memref);
378 return argsConverted && resultsConverted;
385static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
386 ValueRange inputs, Location loc) {
387 assert(type.hasStaticShape() &&
388 "Can only subview flatten memref's with static shape (for now...).");
389 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
390 int64_t memSize = sourceType.getNumElements();
391 unsigned dims = sourceType.getShape().size();
394 SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0));
395 SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1));
396 offsets[offsets.size() - 1] = builder.getIndexAttr(memSize);
397 SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));
400 MemRefType outType = MemRefType::get({memSize}, type.getElementType());
401 return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes,
405static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
407 typeConverter.addConversion([](Type type) {
return type; });
409 typeConverter.addConversion([](MemRefType memref) {
412 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
413 memref.getElementType());
417struct FlattenMemRefPass
418 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
420 void runOnOperation()
override {
422 auto *ctx = &getContext();
423 TypeConverter typeConverter;
424 populateTypeConversionPatterns(typeConverter);
427 SetVector<StringRef> rewrittenCallees;
428 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
429 GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion,
430 OperandConversionPattern<func::ReturnOp>,
431 OperandConversionPattern<memref::DeallocOp>,
432 CondBranchOpConversion,
433 OperandConversionPattern<memref::DeallocOp>,
434 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
436 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
439 ConversionTarget target(*ctx);
440 populateFlattenMemRefsLegality(target);
442 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
450struct FlattenMemRefCallsPass
451 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
453 void runOnOperation()
override {
454 auto *ctx = &getContext();
455 TypeConverter typeConverter;
456 populateTypeConversionPatterns(typeConverter);
465 patterns.add<CallOpConversion>(typeConverter, ctx,
468 ConversionTarget target(*ctx);
469 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
470 addGenericLegalityConstraint<func::CallOp>(target);
471 addGenericLegalityConstraint<func::FuncOp>(target);
475 typeConverter.addTargetMaterialization(materializeSubViewFlattening);
477 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
489 return std::make_unique<FlattenMemRefPass>();
493 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType && "element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static std::atomic< unsigned > globalCounter(0)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which need to be rewritten, if they contain memref arguments.