#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

// Pull in the generated base classes for the passes defined in this file.
// NOTE(review): the tablegen'd .inc must be included inside the `circt`
// namespace so that `circt::impl::FlattenMemRefBase` et al. resolve.
namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;
43 return memref.getShape().size() == 1;
57 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
58 type.getElementType());
64 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
65 type.getElementType(), uniqueID);
70static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
71 ValueRange indices, MemRefType memrefType) {
72 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
73 Location loc = op->getLoc();
75 if (indices.empty()) {
77 return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
81 Value finalIdx = indices.front();
82 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
83 Value partialIdx = memIdx.value();
84 int64_t indexMulFactor = 1;
87 for (
unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
89 int64_t dimSize = memrefType.getShape()[i];
90 indexMulFactor *= dimSize;
94 if (llvm::isPowerOf2_64(indexMulFactor)) {
97 .create<arith::ConstantOp>(
98 loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
101 rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
103 auto constant = rewriter
104 .create<arith::ConstantOp>(
105 loc, rewriter.getIndexAttr(indexMulFactor))
108 rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
112 auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
113 finalIdx = sumOp.getResult();
70static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, {
…}
119 return llvm::any_of(values, [](Value v) {
120 auto memref = dyn_cast<MemRefType>(v.getType());
130 using OpConversionPattern::OpConversionPattern;
133 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
134 ConversionPatternRewriter &rewriter)
const override {
135 MemRefType type = op.getMemRefType();
137 op.getIndices().size() == 1)
140 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
141 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
143 SmallVector<Value>{finalIdx});
149 using OpConversionPattern::OpConversionPattern;
152 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
153 ConversionPatternRewriter &rewriter)
const override {
154 MemRefType type = op.getMemRefType();
156 op.getIndices().size() == 1)
159 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
160 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
162 SmallVector<Value>{finalIdx});
168 using OpConversionPattern::OpConversionPattern;
171 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
172 ConversionPatternRewriter &rewriter)
const override {
173 MemRefType type = op.getType();
177 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
183 using OpConversionPattern::OpConversionPattern;
186 matchAndRewrite(memref::AllocaOp op, OpAdaptor ,
187 ConversionPatternRewriter &rewriter)
const override {
188 MemRefType type = op.getType();
192 rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
198 using OpConversionPattern::OpConversionPattern;
201 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter)
const override {
203 MemRefType type = op.getType();
209 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
211 SmallVector<Attribute> flattenedVals;
212 for (
auto attr : cstAttr.getValues<Attribute>())
213 flattenedVals.push_back(attr);
215 auto newTypeAttr = TypeAttr::get(newType);
217 auto newName = rewriter.getStringAttr(newNameStr);
220 RankedTensorType tensorType = RankedTensorType::get(
221 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
222 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
224 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
225 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
226 op.getConstantAttr(), op.getAlignmentAttr());
233 using OpConversionPattern::OpConversionPattern;
236 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const override {
238 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
239 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
240 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
242 MemRefType type = globalOp.getType();
247 auto originalName = globalOp.getSymNameAttr();
251 auto newName = newNameIt->second;
253 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
260 using OpConversionPattern::OpConversionPattern;
263 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
264 ConversionPatternRewriter &rewriter)
const override {
265 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
266 if (!flattenedSource)
269 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
271 !flattenedSrcType.hasStaticShape()) {
272 rewriter.replaceOp(op, flattenedSource);
282template <
typename TOp>
285 using OpAdaptor =
typename TOp::Adaptor;
287 matchAndRewrite(TOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
290 adaptor.getOperands(), op->getAttrs());
297struct CondBranchOpConversion
299 using OpConversionPattern::OpConversionPattern;
302 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
304 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
305 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
306 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
315 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
316 bool rewriteFunctions =
false)
318 rewriteFunctions(rewriteFunctions) {}
321 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
323 llvm::SmallVector<Type> convResTypes;
324 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
326 auto newCallOp = rewriter.create<func::CallOp>(
327 op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());
329 if (!rewriteFunctions) {
330 rewriter.replaceOp(op, newCallOp);
337 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
338 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
339 FunctionType funcType = FunctionType::get(
340 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
341 func::FuncOp newFuncOp;
343 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
344 calledFunction, op.getCallee(), funcType);
347 rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
348 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
349 rewriter.replaceOp(op, newCallOp);
355 bool rewriteFunctions;
358template <
typename... TOp>
359void addGenericLegalityConstraint(ConversionTarget &target) {
360 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
367static void populateFlattenMemRefsLegality(ConversionTarget &target) {
368 target.addLegalDialect<arith::ArithDialect>();
369 target.addDynamicallyLegalOp<memref::AllocOp>(
371 target.addDynamicallyLegalOp<memref::AllocaOp>(
373 target.addDynamicallyLegalOp<memref::StoreOp>(
374 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
375 target.addDynamicallyLegalOp<memref::LoadOp>(
376 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
377 target.addDynamicallyLegalOp<memref::GlobalOp>(
379 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
381 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
382 func::CallOp, func::ReturnOp, memref::DeallocOp,
383 memref::CopyOp>(target);
385 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
386 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
387 return hasMultiDimMemRef(block.getArguments());
390 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
391 if (auto memref = dyn_cast<MemRefType>(type))
392 return isUniDimensional(memref);
396 return argsConverted && resultsConverted;
403static Value materializeCollapseShapeFlattening(OpBuilder &builder,
407 assert(type.hasStaticShape() &&
408 "Can only subview flatten memref's with static shape (for now...).");
409 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
410 int64_t memSize = sourceType.getNumElements();
411 ArrayRef<int64_t> sourceShape = sourceType.getShape();
412 ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
415 auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
416 assert(indices.has_value() &&
"expected a valid collapse");
419 return builder.create<memref::CollapseShapeOp>(loc, inputs[0],
423static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
425 typeConverter.addConversion([](Type type) {
return type; });
427 typeConverter.addConversion([](MemRefType memref) {
430 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
431 memref.getElementType());
435struct FlattenMemRefPass
436 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
438 void runOnOperation()
override {
440 auto *ctx = &getContext();
441 TypeConverter typeConverter;
442 populateTypeConversionPatterns(typeConverter);
445 SetVector<StringRef> rewrittenCallees;
446 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
447 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
448 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
449 OperandConversionPattern<memref::DeallocOp>,
450 CondBranchOpConversion,
451 OperandConversionPattern<memref::DeallocOp>,
452 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
454 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
457 ConversionTarget target(*ctx);
458 populateFlattenMemRefsLegality(target);
460 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
468struct FlattenMemRefCallsPass
469 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
471 void runOnOperation()
override {
472 auto *ctx = &getContext();
473 TypeConverter typeConverter;
474 populateTypeConversionPatterns(typeConverter);
483 patterns.add<CallOpConversion>(typeConverter, ctx,
486 ConversionTarget target(*ctx);
487 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
488 addGenericLegalityConstraint<func::CallOp>(target);
489 addGenericLegalityConstraint<func::FuncOp>(target);
493 typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
495 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
507 return std::make_unique<FlattenMemRefPass>();
511 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType && "element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static std::atomic< unsigned > globalCounter(0)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which need to be rewritten, if they contain memref arguments.