15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
23#include "mlir/IR/BuiltinDialect.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/OperationSupport.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "llvm/Support/FormatVariadic.h"
31#include "llvm/Support/LogicalResult.h"
32#include "llvm/Support/MathExtras.h"
35#define GEN_PASS_DEF_FLATTENMEMREF
36#define GEN_PASS_DEF_FLATTENMEMREFCALLS
37#include "circt/Transforms/Passes.h.inc"
// Presumably the body of isUniDimensional(MemRefType memref); its header is
// not part of this excerpt. A memref is treated as already flattened exactly
// when its shape has a single dimension.
44 return memref.getShape().size() == 1;
// Presumably the body of getFlattenedMemRefType(MemRefType type); its header is
// not part of this excerpt. It builds a rank-1 memref holding
// type.getNumElements() elements of the same element type.
// NOTE(review): only shape and element type are passed to MemRefType::get, so
// any layout or memory space on `type` is not carried over. Confirm intended.
58 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
59 type.getElementType());
// Presumably the body of getFlattenedMemRefName(StringAttr baseName,
// MemRefType type); its header is not part of this excerpt.
// It produces "<baseName>_<numElements>x<elementType>_<uniqueID>".
// `uniqueID` is defined on a line not shown here. TODO confirm its source
// (likely a global counter) to ensure the names stay unique.
65 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
66 type.getElementType(), uniqueID);
// Linearizes the multi-dimensional access `indices` into a single row-major
// offset for `memrefType`. The arithmetic is emitted through `rewriter` at
// `op`'s location.
//
// For shape (d0, d1, ..., d{n-1}) and indices (i0, i1, ..., i{n-1}), the result
// is computed with Horner's scheme:
//   ((i0 * d1 + i1) * d2 + i2) ... * d{n-1} + i{n-1}
// which equals i0*d1*...*d{n-1} + ... + i{n-2}*d{n-1} + i{n-1}.
// A multiplier that is a power of two is emitted as a left shift instead of a
// multiplication. With no indices at all (e.g. a rank-0 memref), the result is
// the index constant 0.
//
// Requires a statically shaped memref (asserted).
71static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
72 ValueRange indices, MemRefType memrefType) {
73 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
74 Location loc = op->getLoc();
76 if (indices.empty()) {
78 return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
82 Value finalIdx = indices.front();
83 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
84 Value partialIdx = memIdx.value();
// partialIdx indexes dimension memIdx.index() + 1. finalIdx already holds the
// linearized prefix over the preceding dimensions, so it is scaled by this
// dimension's extent only. Scaling it by the product of all trailing extents
// would over-weight the leading terms once the rank is >= 3. For example,
// shape 2x3x4 with indices (1, 0, 0) would give 48 instead of 12.
85 int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];
95 if (llvm::isPowerOf2_64(indexMulFactor)) {
96 auto constant = arith::ConstantOp::create(
98 rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
101 arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
103 auto constant = arith::ConstantOp::create(
104 rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
107 arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
// Add the current dimension's index to the scaled prefix.
111 auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
112 finalIdx = sumOp.getResult();
// Presumably the body of hasMultiDimMemRef(ValueRange values); its header is
// not part of this excerpt. It returns true if any value satisfies the
// predicate. The predicate starts by checking for a MemRefType; the rest of it
// is not visible here.
118 return llvm::any_of(values, [](Value v) {
119 auto memref = dyn_cast<MemRefType>(v.getType());
// Conversion pattern for memref.load (presumably LoadOpConversion). It replaces
// a multi-index load with a load from the remapped (flattened) memref using a
// single index produced by flattenIndices.
// The early-exit guard is only partially visible. Its condition includes
// "already has exactly one index".
129 using OpConversionPattern::OpConversionPattern;
132 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 MemRefType type = op.getMemRefType();
136 op.getIndices().size() == 1)
// The ORIGINAL memref type is passed deliberately: the linearization needs
// the multi-dimensional shape, while the indices come from the adaptor
// (remapped values).
139 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
140 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
142 SmallVector<Value>{finalIdx});
// Conversion pattern for memref.store (presumably StoreOpConversion). It
// replaces a multi-index store with a store of the remapped value into the
// flattened memref at the single index produced by flattenIndices.
// The early-exit guard is only partially visible. Its condition includes
// "already has exactly one index".
148 using OpConversionPattern::OpConversionPattern;
151 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter)
const override {
153 MemRefType type = op.getMemRefType();
155 op.getIndices().size() == 1)
// Shape comes from the original (unflattened) type; indices are the
// adaptor's remapped values.
158 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
159 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
161 SmallVector<Value>{finalIdx});
// Conversion pattern for memref.alloc (presumably AllocOpConversion). It
// replaces the allocation with one of the flattened type `newType`, which is
// computed on lines not visible in this excerpt.
// The adaptor is unused, so no operands are forwarded to the new alloc.
// NOTE(review): op.getAlignmentAttr() is not passed to the replacement, so any
// requested alignment appears to be dropped. Confirm this is intended.
167 using OpConversionPattern::OpConversionPattern;
170 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
171 ConversionPatternRewriter &rewriter)
const override {
172 MemRefType type = op.getType();
176 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
// Conversion pattern for memref.alloca (presumably AllocaOpConversion). It
// mirrors the alloc pattern: the op is replaced by an alloca of the flattened
// type `newType`, which is computed on lines not visible in this excerpt.
// NOTE(review): as with alloc, the alignment attribute is not forwarded.
// Confirm this is intended.
182 using OpConversionPattern::OpConversionPattern;
185 matchAndRewrite(memref::AllocaOp op, OpAdaptor ,
186 ConversionPatternRewriter &rewriter)
const override {
187 MemRefType type = op.getType();
191 rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
// Conversion pattern for memref.global (presumably GlobalOpConversion). It
// re-creates the global under a new name with the flattened type. The new
// global's initial value is a 1-D DenseElementsAttr holding the original
// initializer's elements in iteration order. Visibility, the `constant` flag
// and alignment are copied from the original op.
197 using OpConversionPattern::OpConversionPattern;
200 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter)
const override {
202 MemRefType type = op.getType();
// NOTE(review): cstAttr comes from dyn_cast_or_null and can be null (e.g. when
// there is no constant initial value). Any guard before the getValues() loop
// is on a line not visible here. Confirm it exists.
208 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
210 SmallVector<Attribute> flattenedVals;
211 for (
auto attr : cstAttr.getValues<Attribute>())
212 flattenedVals.push_back(attr);
214 auto newTypeAttr = TypeAttr::get(newType);
// newNameStr is produced on a line not visible here (presumably via
// getFlattenedMemRefName). TODO confirm.
216 auto newName = rewriter.getStringAttr(newNameStr);
// The flattened initializer is a rank-1 tensor with one entry per element.
219 RankedTensorType tensorType = RankedTensorType::get(
220 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
221 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
223 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
224 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
225 op.getConstantAttr(), op.getAlignmentAttr());
// Conversion pattern for memref.get_global (presumably GetGlobalOpConversion).
// It resolves the referenced memref.global in the nearest enclosing symbol
// table, maps its original name to the renamed flattened global, and emits a
// get_global of the flattened type under that new name.
// The name-map lookup that yields `newNameIt` is on lines not visible here.
232 using OpConversionPattern::OpConversionPattern;
235 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
237 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
// NOTE(review): globalOp may be null (symbol missing or not a memref.global).
// The guard, if any, is on a line not visible here. Confirm it exists.
238 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
239 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
241 MemRefType type = globalOp.getType();
246 auto originalName = globalOp.getSymNameAttr();
250 auto newName = newNameIt->second;
252 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
// Conversion pattern for memref.reshape (presumably ReshapeOpConversion). Since
// the source has already been flattened, the reshape is replaced outright by
// the remapped (flattened) source value when the guarded condition holds.
// That condition is only partially visible; its visible tail is "flattened
// source has no static shape".
259 using OpConversionPattern::OpConversionPattern;
262 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter)
const override {
// getRemappedValue returns null when no remapping is available; the handling
// for that case is on a line not visible here.
264 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
265 if (!flattenedSource)
268 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
270 !flattenedSrcType.hasStaticShape()) {
271 rewriter.replaceOp(op, flattenedSource);
// Generic conversion pattern: it re-creates a TOp with the adaptor's remapped
// (type-converted) operands and the original attributes.
// Result types are passed through UNCONVERTED (op->getResultTypes()), so this
// pattern is only appropriate for ops whose results are not affected by
// memref flattening.
281template <
typename TOp>
284 using OpAdaptor =
typename TOp::Adaptor;
286 matchAndRewrite(TOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const override {
288 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
289 adaptor.getOperands(), op->getAttrs());
// Conversion pattern for func.call.
// The result types are converted with the type converter, and the call is
// re-created with the remapped operands.
// When `rewriteFunctions` is true, the callee is also given a signature
// matching the new call: an existing callee symbol is replaced by a new
// func.func, otherwise a new func.func is created. The resulting function is
// marked private. It is inserted immediately before the caller's enclosing
// func.func.
298 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
299 bool rewriteFunctions =
false)
301 rewriteFunctions(rewriteFunctions) {}
304 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
306 llvm::SmallVector<Type> convResTypes;
307 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
310 func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
311 convResTypes, adaptor.getOperands());
// Calls-only mode: swap in the new call and leave the callee untouched.
313 if (!rewriteFunctions) {
314 rewriter.replaceOp(op, newCallOp);
321 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
// NOTE(review): resolveCallable() can return null (callee symbol not found).
// The branch that chooses replace-vs-create is on a line not visible here.
// Confirm it checks for null before replaceOpWithNewOp.
322 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
// The callee signature is derived from the new call's operand types and the
// converted result types.
323 FunctionType funcType = FunctionType::get(
324 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
325 func::FuncOp newFuncOp;
327 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
328 calledFunction, op.getCallee(), funcType);
331 func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
332 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
333 rewriter.replaceOp(op, newCallOp);
// Selects whether callee declarations are rewritten along with the call.
339 bool rewriteFunctions;
// Registers a dynamic-legality callback for every op type in the TOp pack,
// using a comma fold expression.
// The callback body is not visible in this excerpt. Presumably it marks an op
// illegal while it still touches multi-dimensional memrefs; TODO confirm.
342template <
typename... TOp>
343void addGenericLegalityConstraint(ConversionTarget &target) {
344 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
// Configures which ops count as legal after memref flattening:
//  - the arith dialect is fully legal (flattenIndices emits arith ops);
//  - memref.load / memref.store are legal once they use exactly one index;
//  - alloc, alloca, global and get_global have dynamic legality callbacks
//    whose bodies are not visible in this excerpt;
//  - call, return, dealloc and copy use the generic legality constraint;
//  - func.func is legal when every memref-typed argument and result is
//    uni-dimensional. Handling of non-memref types is on lines not visible.
351static void populateFlattenMemRefsLegality(ConversionTarget &target) {
352 target.addLegalDialect<arith::ArithDialect>();
353 target.addDynamicallyLegalOp<memref::AllocOp>(
355 target.addDynamicallyLegalOp<memref::AllocaOp>(
357 target.addDynamicallyLegalOp<memref::StoreOp>(
358 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
359 target.addDynamicallyLegalOp<memref::LoadOp>(
360 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
361 target.addDynamicallyLegalOp<memref::GlobalOp>(
363 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
365 addGenericLegalityConstraint<func::CallOp, func::ReturnOp, memref::DeallocOp,
366 memref::CopyOp>(target);
368 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
369 auto argsConverted = llvm::all_of(op.getArgumentTypes(), [](Type type) {
370 if (auto memref = dyn_cast<MemRefType>(type))
371 return isUniDimensional(memref);
375 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
376 if (auto memref = dyn_cast<MemRefType>(type))
377 return isUniDimensional(memref);
381 return argsConverted && resultsConverted;
// Target materialization: it converts a multi-dimensional memref value
// (inputs[0]) to its flattened form by emitting a memref.collapse_shape that
// merges all source dimensions into a single dimension of getNumElements()
// elements. Only statically shaped memrefs are supported (asserted).
388static Value materializeCollapseShapeFlattening(OpBuilder &builder,
392 assert(type.hasStaticShape() &&
393 "Can only subview flatten memref's with static shape (for now...).");
394 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
395 int64_t memSize = sourceType.getNumElements();
396 ArrayRef<int64_t> sourceShape = sourceType.getShape();
// Single-element ArrayRef that views the local `memSize`. It is only valid
// while this function runs, which covers its use below.
397 ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
400 auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
401 assert(indices.has_value() &&
"expected a valid collapse");
404 return memref::CollapseShapeOp::create(builder, loc, inputs[0],
// Installs the flattening type conversions on `typeConverter`:
//  - a fallback that leaves every type unchanged;
//  - a MemRefType conversion to a rank-1 memref of getNumElements() elements.
//    Lines 413-414 of that lambda are not visible here.
// TypeConverter tries the most recently added conversion first, so memrefs
// reach the second callback before the identity fallback.
// NOTE(review): like getFlattenedMemRefType, this drops layout and memory
// space. The two copies of this logic should be kept in sync.
408static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
410 typeConverter.addConversion([](Type type) {
return type; });
412 typeConverter.addConversion([](MemRefType memref) {
415 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
416 memref.getElementType());
// Pass driver that rewrites multi-dimensional memrefs into rank-1 memrefs.
// It runs a partial dialect conversion with these patterns:
//  - load/store/alloc/alloca/global/get_global/reshape patterns;
//  - operand-only rewrites for func.return, memref.dealloc and memref.copy;
//  - func.call rewriting;
//  - func.func signature conversion;
//  - the control-flow structural type conversions.
420struct FlattenMemRefPass
421 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
423 void runOnOperation()
override {
425 auto *ctx = &getContext();
426 TypeConverter typeConverter;
427 populateTypeConversionPatterns(typeConverter);
430 SetVector<StringRef> rewrittenCallees;
// Each pattern is registered exactly once. Listing the same pattern twice
// (memref.dealloc was duplicated) only adds redundant match attempts.
431 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
432 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
433 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
434 OperandConversionPattern<memref::DeallocOp>,
436 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
438 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
441 ConversionTarget target(*ctx);
442 populateFlattenMemRefsLegality(target);
443 mlir::cf::populateCFStructuralTypeConversionsAndLegality(typeConverter,
446 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// Pass driver that rewrites only func.call ops, using CallOpConversion. Its
// remaining constructor arguments continue on a line not visible here.
// The memref and builtin dialects are fully legal. func.call and func.func use
// the generic legality constraint.
// Where a call operand must change from a multi-dimensional memref to its
// flattened type, materializeCollapseShapeFlattening inserts a
// memref.collapse_shape.
454struct FlattenMemRefCallsPass
455 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
457 void runOnOperation()
override {
458 auto *ctx = &getContext();
459 TypeConverter typeConverter;
460 populateTypeConversionPatterns(typeConverter);
469 patterns.add<CallOpConversion>(typeConverter, ctx,
472 ConversionTarget target(*ctx);
473 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
474 addGenericLegalityConstraint<func::CallOp>(target);
475 addGenericLegalityConstraint<func::FuncOp>(target);
479 typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
481 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// Body of createFlattenMemRefPass(): returns a new FlattenMemRefPass instance.
493 return std::make_unique<FlattenMemRefPass>();
// Body of createFlattenMemRefCallsPass(): returns a new FlattenMemRefCallsPass
// instance.
497 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType &&"element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static std::atomic< unsigned > globalCounter(0)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations that need to be rewritten if they contain memref arguments that were flattened.