15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
23#include "mlir/IR/BuiltinDialect.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/OperationSupport.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "llvm/Support/FormatVariadic.h"
31#include "llvm/Support/LogicalResult.h"
32#include "llvm/Support/MathExtras.h"
35#define GEN_PASS_DEF_FLATTENMEMREF
36#define GEN_PASS_DEF_FLATTENMEMREFCALLS
37#include "circt/Transforms/Passes.h.inc"
// isUniDimensional: true iff the memref type has exactly one dimension
// (a rank-0 memref is therefore NOT considered uni-dimensional).
44 return memref.getShape().size() == 1;
// getFlattenedMemRefType: builds a rank-1 memref type whose single dimension is
// the total element count of `type`, keeping the element type.
// NOTE(review): only shape and element type are passed to MemRefType::get, so any
// layout or memory space on `type` is not carried over.
60 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
61 type.getElementType());
// getFlattenedMemRefName: produces a new symbol name of the form
// "<baseName>_<numElements>x<elementType>_<id>". The id comes from the shared
// `state.counter`, which is post-incremented so every call yields a distinct suffix.
67 unsigned uniqueID = state.
counter++;
68 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
69 type.getElementType(), uniqueID);
// flattenIndices: emits arith ops that turn the multi-dimensional `indices` of a
// statically shaped `memrefType` into a single linear index value.
// - `op` only supplies the location for the created ops.
// - Empty `indices` (rank-0 memref) yields the index constant 0.
// - Otherwise the first index seeds an accumulator that is repeatedly scaled and
//   summed with the next index. Scaling uses shli when the factor is a power of
//   two, and muli otherwise.
74static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
75 ValueRange indices, MemRefType memrefType) {
76 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
77 Location loc = op->getLoc();
// A rank-0 memref has no indices; its single element lives at linear index 0.
79 if (indices.empty()) {
81 return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
85 Value finalIdx = indices.front();
86 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
87 Value partialIdx = memIdx.value();
88 int64_t indexMulFactor = 1;
// The factor is the product of shape[p .. rank-1], where
// p = memIdx.index() + 1 is this index's position in the original list.
// NOTE(review): assuming elided lines 103/109 assign the shl/mul result back to
// finalIdx, the running prefix is scaled by shape[p]*...*shape[rank-1] on every
// step. Row-major (Horner) linearization only needs shape[p] here. This is fine
// for rank 2, but for rank >= 3 it looks wrong: shape [2,3,4] with indices
// (1,2,3) gives (1*12+2)*4+3 = 59 instead of 1*12+2*4+3 = 23, which is past the
// 24 elements. Verify against the full source and tests.
91 for (
unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
93 int64_t dimSize = memrefType.getShape()[i];
94 indexMulFactor *= dimSize;
// Power-of-two factors are applied as a left shift by log2(factor).
98 if (llvm::isPowerOf2_64(indexMulFactor)) {
99 auto constant = arith::ConstantOp::create(
101 rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
104 arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
// Other factors (including 0) use an explicit index multiply.
106 auto constant = arith::ConstantOp::create(
107 rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
110 arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
// Add this dimension's index to the scaled accumulator.
114 auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
115 finalIdx = sumOp.getResult();
// hasMultiDimMemRef: true if any value in `values` has a MemRefType that
// still needs flattening. The exact per-value predicate is on elided lines;
// presumably it is !isUniDimensional(memref). TODO confirm.
121 return llvm::any_of(values, [](Value v) {
122 auto memref = dyn_cast<MemRefType>(v.getType());
// LoadOpConversion: rewrites a memref.load on a multi-dimensional memref into a
// load from the converted (flattened) memref at the index computed by
// flattenIndices. The early-exit condition is only partly visible: it ends with
// `op.getIndices().size() == 1`.
132 using OpConversionPattern::OpConversionPattern;
135 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter)
const override {
137 MemRefType type = op.getMemRefType();
139 op.getIndices().size() == 1)
// Strides come from the ORIGINAL (unflattened) memref type.
142 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
// adaptor.getMemref() is the already-flattened memref value.
143 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
145 SmallVector<Value>{finalIdx});
// StoreOpConversion: rewrites a memref.store on a multi-dimensional memref into a
// store of the converted value into the flattened memref at the index computed
// by flattenIndices. This mirrors LoadOpConversion.
151 using OpConversionPattern::OpConversionPattern;
154 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter)
const override {
156 MemRefType type = op.getMemRefType();
158 op.getIndices().size() == 1)
// Strides come from the ORIGINAL (unflattened) memref type.
161 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
162 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
164 SmallVector<Value>{finalIdx});
// AllocOpConversion: replaces a memref.alloc of a multi-dimensional type with an
// alloc of the flattened type. The guard and the computation of newType are on
// elided lines.
// NOTE(review): only `newType` is passed to the new AllocOp, so an alignment
// attribute on the original op is not forwarded — confirm this is intended.
170 using OpConversionPattern::OpConversionPattern;
173 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
174 ConversionPatternRewriter &rewriter)
const override {
175 MemRefType type = op.getType();
179 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
// AllocaOpConversion: same as AllocOpConversion but for memref.alloca.
// NOTE(review): only `newType` is passed to the new AllocaOp, so an alignment
// attribute on the original op is not forwarded — confirm this is intended.
185 using OpConversionPattern::OpConversionPattern;
188 matchAndRewrite(memref::AllocaOp op, OpAdaptor ,
189 ConversionPatternRewriter &rewriter)
const override {
190 MemRefType type = op.getType();
194 rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
// GlobalOpConversion: replaces a memref.global of a multi-dimensional type with a
// global of the flattened type under a fresh name.
// - The DenseElementsAttr initializer is re-packed into a rank-1 tensor of the
//   same values.
// - The old-name -> new-name mapping is recorded in state.nameMap so that
//   GetGlobalOpConversion can retarget references.
200 GlobalOpConversion(TypeConverter &typeConverter, MLIRContext *
context,
205 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
206 ConversionPatternRewriter &rewriter)
const override {
207 MemRefType type = op.getType();
// NOTE(review): dyn_cast_or_null may yield null (e.g. a global with no dense
// initializer), and cstAttr is dereferenced below. Confirm the elided line 214
// guards this case.
213 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
215 SmallVector<Attribute> flattenedVals;
216 for (
auto attr : cstAttr.getValues<Attribute>())
217 flattenedVals.push_back(attr);
219 auto newTypeAttr = TypeAttr::get(newType);
222 auto newName = rewriter.getStringAttr(newNameStr);
// Remember the rename so memref.get_global users can be updated later.
223 state.nameMap[op.getSymNameAttr()] = newName;
225 RankedTensorType tensorType = RankedTensorType::get(
226 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
227 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
// Visibility, constant-ness and alignment are copied from the original global.
229 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
230 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
231 op.getConstantAttr(), op.getAlignmentAttr());
// GetGlobalOpConversion: retargets memref.get_global to the renamed, flattened
// global.
// - The referenced memref.global is looked up in the nearest enclosing
//   symbol table.
// - The new name is read from state.nameMap, which GlobalOpConversion fills in.
//   This pattern therefore depends on the global having been converted first.
241 GetGlobalOpConversion(TypeConverter &typeConverter, MLIRContext *
context,
246 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter)
const override {
248 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
249 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
250 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
// NOTE(review): globalOp may be null here (dyn_cast_or_null). Confirm the elided
// line 251 guards it before getType() is called.
252 MemRefType type = globalOp.getType();
257 auto originalName = globalOp.getSymNameAttr();
258 auto newNameIt = state.nameMap.find(originalName);
// No recorded rename: handled on the elided line 260 (presumably a match failure).
259 if (newNameIt == state.nameMap.end())
261 auto newName = newNameIt->second;
263 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
// ReshapeOpConversion: handles memref.reshape whose source has been flattened.
// - It fetches the remapped (converted) source value.
// - Under the (partly visible) condition ending in
//   `!flattenedSrcType.hasStaticShape()`, it forwards that flattened source
//   directly as the result.
// The remaining branches are on elided lines.
273 using OpConversionPattern::OpConversionPattern;
276 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
277 ConversionPatternRewriter &rewriter)
const override {
278 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
279 if (!flattenedSource)
282 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
284 !flattenedSrcType.hasStaticShape()) {
285 rewriter.replaceOp(op, flattenedSource);
// OperandConversionPattern<TOp>: generic pattern that rebuilds `TOp` with its
// converted (adaptor) operands, keeping the original attributes.
// Result types are taken verbatim from the original op and are NOT
// type-converted. The pattern therefore suits ops whose results need no
// flattening (it is used for func.return, memref.dealloc and memref.copy).
295template <
typename TOp>
298 using OpAdaptor =
typename TOp::Adaptor;
300 matchAndRewrite(TOp op, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter)
const override {
302 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
303 adaptor.getOperands(), op->getAttrs());
// CallOpConversion: rebuilds func.call with type-converted result types and
// converted operands.
// - If `rewriteFunctions` is false, only the call is replaced.
// - Otherwise, a func.func with the new signature (operand types of the new
//   call -> converted results) is also produced. It either replaces the resolved
//   callee, or is created as a private declaration before the enclosing function.
312 CallOpConversion(TypeConverter &typeConverter, MLIRContext *
context,
313 bool rewriteFunctions =
false)
315 rewriteFunctions(rewriteFunctions) {}
318 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter)
const override {
320 llvm::SmallVector<Type> convResTypes;
// Failure to convert any result type is handled on elided lines 322-323.
321 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
324 func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
325 convResTypes, adaptor.getOperands());
327 if (!rewriteFunctions) {
328 rewriter.replaceOp(op, newCallOp);
// New function ops are inserted just before the function containing the call.
335 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
336 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
337 FunctionType funcType = FunctionType::get(
338 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
339 func::FuncOp newFuncOp;
// NOTE(review): the replacement FuncOp is built from name + type only. If the
// resolved callee has a body, that body is not carried over in the visible
// lines. Confirm the elided line 340 restricts this path to declarations.
341 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
342 calledFunction, op.getCallee(), funcType);
345 func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
346 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
347 rewriter.replaceOp(op, newCallOp);
353 bool rewriteFunctions;
// addGenericLegalityConstraint<TOp...>: registers a dynamic legality callback
// for every op type in the pack, using a C++17 fold over the comma operator.
// The predicate body is on elided lines. Presumably it checks that no operand or
// result is a multi-dimensional memref — TODO confirm.
356template <
typename... TOp>
357void addGenericLegalityConstraint(ConversionTarget &target) {
358 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
// populateFlattenMemRefsLegality: configures the conversion target for the
// FlattenMemRef pass.
// - All arith ops are legal, since flattenIndices emits them.
// - memref load/store are legal only with exactly one index, so rank-0 accesses
//   (no indices) are also rewritten.
// - func.func is legal once every memref argument and result is uni-dimensional.
// - alloc/alloca/global/get_global predicates are on elided lines.
365static void populateFlattenMemRefsLegality(ConversionTarget &target) {
366 target.addLegalDialect<arith::ArithDialect>();
367 target.addDynamicallyLegalOp<memref::AllocOp>(
369 target.addDynamicallyLegalOp<memref::AllocaOp>(
371 target.addDynamicallyLegalOp<memref::StoreOp>(
372 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
373 target.addDynamicallyLegalOp<memref::LoadOp>(
374 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
375 target.addDynamicallyLegalOp<memref::GlobalOp>(
377 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
379 addGenericLegalityConstraint<func::CallOp, func::ReturnOp, memref::DeallocOp,
380 memref::CopyOp>(target);
// A function signature is legal when all of its memref-typed arguments and
// results are rank-1 (the non-memref branch is on elided lines).
382 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
383 auto argsConverted = llvm::all_of(op.getArgumentTypes(), [](Type type) {
384 if (auto memref = dyn_cast<MemRefType>(type))
385 return isUniDimensional(memref);
389 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
390 if (auto memref = dyn_cast<MemRefType>(type))
391 return isUniDimensional(memref);
395 return argsConverted && resultsConverted;
// materializeCollapseShapeFlattening: target materialization that converts a
// multi-dimensional memref value into its flattened form. It emits a
// memref.collapse_shape to a single dimension of size numElements, using the
// reassociation computed by getReassociationIndicesForCollapse.
// Only static shapes are supported (asserted). The assert text mentions
// "subview", but the op emitted here is a collapse_shape.
402static Value materializeCollapseShapeFlattening(OpBuilder &builder,
406 assert(type.hasStaticShape() &&
407 "Can only subview flatten memref's with static shape (for now...).");
408 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
409 int64_t memSize = sourceType.getNumElements();
410 ArrayRef<int64_t> sourceShape = sourceType.getShape();
// Single-element ArrayRef viewing the local `memSize`; valid for this function's
// lifetime only.
411 ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
414 auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
415 assert(indices.has_value() &&
"expected a valid collapse");
418 return memref::CollapseShapeOp::create(builder, loc, inputs[0],
// populateTypeConversionPatterns: installs the type conversions used by both
// passes.
// - A catch-all identity conversion applies to every type.
// - A MemRefType conversion produces a rank-1 memref of numElements with the same
//   element type. Any guard for non-static or already-flat memrefs is on elided
//   lines.
// NOTE(review): as in getFlattenedMemRefType, layout and memory space are not
// forwarded.
422static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
424 typeConverter.addConversion([](Type type) {
return type; });
426 typeConverter.addConversion([](MemRefType memref) {
429 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
430 memref.getElementType());
// FlattenMemRefPass: flattens multi-dimensional memrefs throughout the
// operation. It runs a partial dialect conversion with:
// - the load/store/alloc/alloca/reshape/global/get_global/call patterns,
// - generic operand rewriting for return/dealloc/copy,
// - func signature conversion,
// - cf structural type conversions.
434struct FlattenMemRefPass
435 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
437 void runOnOperation()
override {
439 auto *ctx = &getContext();
440 TypeConverter typeConverter;
442 populateTypeConversionPatterns(typeConverter);
// NOTE(review): rewrittenCallees is not used in any visible line of this pass.
445 SetVector<StringRef> rewrittenCallees;
// NOTE(review): OperandConversionPattern<memref::DeallocOp> appears twice in this
// list. One entry looks redundant and only registers a duplicate pattern.
446 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
447 AllocaOpConversion, ReshapeOpConversion,
448 OperandConversionPattern<func::ReturnOp>,
449 OperandConversionPattern<memref::DeallocOp>,
450 OperandConversionPattern<memref::DeallocOp>,
451 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
453 patterns.add<GlobalOpConversion, GetGlobalOpConversion>(typeConverter, ctx,
455 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
458 ConversionTarget target(*ctx);
459 populateFlattenMemRefsLegality(target);
460 mlir::cf::populateCFStructuralTypeConversionsAndLegality(typeConverter,
463 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// FlattenMemRefCallsPass: rewrites only func.call sites (and, via
// CallOpConversion, their callee signatures) to use flattened memrefs.
// The memref and builtin dialects are legal. Call and func ops use the generic
// legality check.
// Multi-dimensional values flowing into the rewritten calls are materialized
// through memref.collapse_shape (materializeCollapseShapeFlattening).
471struct FlattenMemRefCallsPass
472 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
474 void runOnOperation()
override {
475 auto *ctx = &getContext();
476 TypeConverter typeConverter;
477 populateTypeConversionPatterns(typeConverter);
486 patterns.add<CallOpConversion>(typeConverter, ctx,
489 ConversionTarget target(*ctx);
490 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
491 addGenericLegalityConstraint<func::CallOp>(target);
492 addGenericLegalityConstraint<func::FuncOp>(target);
// Registered after the patterns; this is fine as long as the patterns keep a
// reference to this same TypeConverter.
496 typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
498 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
// createFlattenMemRefPass: factory returning a new FlattenMemRefPass instance.
510 return std::make_unique<FlattenMemRefPass>();
// createFlattenMemRefCallsPass: factory returning a new FlattenMemRefCallsPass instance.
514 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static MemRefType getFlattenedMemRefType(MemRefType type)
static std::string getFlattenedMemRefName(FlattenMemRefsState &state, StringAttr baseName, MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
DenseMap< StringAttr, StringAttr > nameMap
A struct for maintaining function declarations which need to be rewritten if they contain memref ar...