15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
22#include "mlir/IR/BuiltinDialect.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/ImplicitLocOpBuilder.h"
25#include "mlir/IR/OperationSupport.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "llvm/Support/FormatVariadic.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
34#define GEN_PASS_DEF_FLATTENMEMREF
35#define GEN_PASS_DEF_FLATTENMEMREFCALLS
36#include "circt/Transforms/Passes.h.inc"
43 return memref.getShape().size() == 1;
57 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
58 type.getElementType());
64 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
65 type.getElementType(), uniqueID);
70static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
71 ValueRange indices, MemRefType memrefType) {
72 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
73 Location loc = op->getLoc();
75 if (indices.empty()) {
77 return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
81 Value finalIdx = indices.front();
82 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
83 Value partialIdx = memIdx.value();
84 int64_t indexMulFactor = 1;
87 for (
unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
89 int64_t dimSize = memrefType.getShape()[i];
90 indexMulFactor *= dimSize;
94 if (llvm::isPowerOf2_64(indexMulFactor)) {
95 auto constant = arith::ConstantOp::create(
97 rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
100 arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
102 auto constant = arith::ConstantOp::create(
103 rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
106 arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
110 auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
111 finalIdx = sumOp.getResult();
117 return llvm::any_of(values, [](Value v) {
118 auto memref = dyn_cast<MemRefType>(v.getType());
128 using OpConversionPattern::OpConversionPattern;
131 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
133 MemRefType type = op.getMemRefType();
135 op.getIndices().size() == 1)
138 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
139 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
141 SmallVector<Value>{finalIdx});
147 using OpConversionPattern::OpConversionPattern;
150 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
151 ConversionPatternRewriter &rewriter)
const override {
152 MemRefType type = op.getMemRefType();
154 op.getIndices().size() == 1)
157 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
158 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
160 SmallVector<Value>{finalIdx});
166 using OpConversionPattern::OpConversionPattern;
169 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
170 ConversionPatternRewriter &rewriter)
const override {
171 MemRefType type = op.getType();
175 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
181 using OpConversionPattern::OpConversionPattern;
184 matchAndRewrite(memref::AllocaOp op, OpAdaptor ,
185 ConversionPatternRewriter &rewriter)
const override {
186 MemRefType type = op.getType();
190 rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
196 using OpConversionPattern::OpConversionPattern;
199 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter)
const override {
201 MemRefType type = op.getType();
207 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
209 SmallVector<Attribute> flattenedVals;
210 for (
auto attr : cstAttr.getValues<Attribute>())
211 flattenedVals.push_back(attr);
213 auto newTypeAttr = TypeAttr::get(newType);
215 auto newName = rewriter.getStringAttr(newNameStr);
218 RankedTensorType tensorType = RankedTensorType::get(
219 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
220 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
222 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
223 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
224 op.getConstantAttr(), op.getAlignmentAttr());
231 using OpConversionPattern::OpConversionPattern;
234 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter)
const override {
236 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
237 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
238 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
240 MemRefType type = globalOp.getType();
245 auto originalName = globalOp.getSymNameAttr();
249 auto newName = newNameIt->second;
251 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
258 using OpConversionPattern::OpConversionPattern;
261 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
263 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
264 if (!flattenedSource)
267 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
269 !flattenedSrcType.hasStaticShape()) {
270 rewriter.replaceOp(op, flattenedSource);
280template <
typename TOp>
283 using OpAdaptor =
typename TOp::Adaptor;
285 matchAndRewrite(TOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override {
287 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
288 adaptor.getOperands(), op->getAttrs());
295struct CondBranchOpConversion
297 using OpConversionPattern::OpConversionPattern;
300 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter)
const override {
302 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
303 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
304 adaptor.getFalseDestOperands(),
nullptr,
305 op.getTrueDest(), op.getFalseDest());
314 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
315 bool rewriteFunctions =
false)
317 rewriteFunctions(rewriteFunctions) {}
320 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter)
const override {
322 llvm::SmallVector<Type> convResTypes;
323 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
326 func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
327 convResTypes, adaptor.getOperands());
329 if (!rewriteFunctions) {
330 rewriter.replaceOp(op, newCallOp);
337 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
338 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
339 FunctionType funcType = FunctionType::get(
340 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
341 func::FuncOp newFuncOp;
343 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
344 calledFunction, op.getCallee(), funcType);
347 func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
348 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
349 rewriter.replaceOp(op, newCallOp);
355 bool rewriteFunctions;
358template <
typename... TOp>
359void addGenericLegalityConstraint(ConversionTarget &target) {
360 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
367static void populateFlattenMemRefsLegality(ConversionTarget &target) {
368 target.addLegalDialect<arith::ArithDialect>();
369 target.addDynamicallyLegalOp<memref::AllocOp>(
371 target.addDynamicallyLegalOp<memref::AllocaOp>(
373 target.addDynamicallyLegalOp<memref::StoreOp>(
374 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
375 target.addDynamicallyLegalOp<memref::LoadOp>(
376 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
377 target.addDynamicallyLegalOp<memref::GlobalOp>(
379 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
381 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
382 func::CallOp, func::ReturnOp, memref::DeallocOp,
383 memref::CopyOp>(target);
385 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
386 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
387 return hasMultiDimMemRef(block.getArguments());
390 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
391 if (auto memref = dyn_cast<MemRefType>(type))
392 return isUniDimensional(memref);
396 return argsConverted && resultsConverted;
403static Value materializeCollapseShapeFlattening(OpBuilder &builder,
407 assert(type.hasStaticShape() &&
408 "Can only subview flatten memref's with static shape (for now...).");
409 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
410 int64_t memSize = sourceType.getNumElements();
411 ArrayRef<int64_t> sourceShape = sourceType.getShape();
412 ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
415 auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
416 assert(indices.has_value() &&
"expected a valid collapse");
419 return memref::CollapseShapeOp::create(builder, loc, inputs[0],
423static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
425 typeConverter.addConversion([](Type type) {
return type; });
427 typeConverter.addConversion([](MemRefType memref) {
430 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
431 memref.getElementType());
435struct FlattenMemRefPass
436 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
438 void runOnOperation()
override {
440 auto *ctx = &getContext();
441 TypeConverter typeConverter;
442 populateTypeConversionPatterns(typeConverter);
445 SetVector<StringRef> rewrittenCallees;
446 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
447 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
448 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
449 OperandConversionPattern<memref::DeallocOp>,
450 CondBranchOpConversion,
451 OperandConversionPattern<memref::DeallocOp>,
452 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
454 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
457 ConversionTarget target(*ctx);
458 populateFlattenMemRefsLegality(target);
460 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
468struct FlattenMemRefCallsPass
469 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
471 void runOnOperation()
override {
472 auto *ctx = &getContext();
473 TypeConverter typeConverter;
474 populateTypeConversionPatterns(typeConverter);
483 patterns.add<CallOpConversion>(typeConverter, ctx,
486 ConversionTarget target(*ctx);
487 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
488 addGenericLegalityConstraint<func::CallOp>(target);
489 addGenericLegalityConstraint<func::FuncOp>(target);
493 typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
495 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
507 return std::make_unique<FlattenMemRefPass>();
511 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType &&"element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static std::atomic< unsigned > globalCounter(0)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which needs to be rewritten, if they contain memref ar...