14 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15 #include "mlir/Conversion/LLVMCommon/Pattern.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/MathExtras.h"
31 #define GEN_PASS_DEF_FLATTENMEMREF
32 #define GEN_PASS_DEF_FLATTENMEMREFCALLS
33 #include "circt/Transforms/Passes.h.inc"
37 using namespace circt;
/// isUniDimensional (fragment; the signature line is not visible here):
/// a memref is treated as uni-dimensional when its shape has exactly one
/// dimension.
40 return memref.getShape().size() == 1;
/// getFlattenedMemRefType (fragment; only the last line of the body is
/// visible): the final argument is the original `type`'s element type, so the
/// flattened memref type presumably keeps it.
55 type.getElementType());
/// getFlattenedMemRefName (fragment; the signature line is not visible):
/// formats a symbol name as "<baseName>_<numElements>x<elementType>_<uniqueID>".
/// Where `uniqueID` comes from is not visible — presumably the file-level
/// atomic globalCounter; confirm.
61 return llvm::formatv(
"{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
62 type.getElementType(), uniqueID);
/// Flattens a list of per-dimension indices into one linear index for
/// `memrefType`, emitting arith ops through `rewriter` at `op`'s location.
/// Only statically shaped memrefs are supported (asserted).
///
/// - No indices: returns an index constant 0.
/// - Otherwise the running value starts at the first index; for each
///   following index it is scaled by `indexMulFactor` (a left shift by
///   log2(factor) when the factor is a power of two, a multiply otherwise)
///   and then the index is added.
///
/// NOTE(review): `indexMulFactor` is the product of every shape dimension
/// from the current index's position (memIdx.index() + 1, because the
/// enumeration skips the first index) to the last dimension, and it is
/// applied to the running value at each step. For rank-2 memrefs that is the
/// single trailing dimension and the result is correct. For rank >= 3 the
/// running value is rescaled by the whole trailing product again and again.
/// Example, memref<2x3x4> with indices (1, 0, 0): step 1 gives 1*12 + 0 = 12,
/// step 2 gives 12*4 + 0 = 48. The row-major offset is 12, and 48 is past the
/// 24 elements. Because the accumulator is rescaled every step (Horner form),
/// the factor should presumably be just
/// memrefType.getShape()[memIdx.index() + 1]. This reading assumes the
/// elided lines assign the shl/mul results back to `finalIdx` and step the
/// inner loop with ++i — confirm against the full source.
67 static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
68 ValueRange indices, MemRefType memrefType) {
69 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
70 Location loc = op->getLoc();
// With no indices to combine, the linear index is simply 0.
72 if (indices.empty()) {
74 return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
78 Value finalIdx = indices.front();
79 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
80 Value partialIdx = memIdx.value();
81 int64_t indexMulFactor = 1;
// Product of the shape dims from the current index's own position to the
// last dimension; see the NOTE above.
84 for (
unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
86 int64_t dimSize = memrefType.getShape()[i];
87 indexMulFactor *= dimSize;
// Scale the running index: shift for power-of-two factors, multiply
// otherwise.
91 if (llvm::isPowerOf2_64(indexMulFactor)) {
94 .create<arith::ConstantOp>(
95 loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
98 rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
100 auto constant = rewriter
101 .create<arith::ConstantOp>(
102 loc, rewriter.getIndexAttr(indexMulFactor))
105 rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
// Add the current index to the scaled running value.
109 auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
110 finalIdx = sumOp.getResult();
/// hasMultiDimMemRef (fragment; the signature line and the end of the lambda
/// are not visible): reports whether any value in `values` matches a
/// predicate that starts by casting the value's type to MemRefType.
/// Presumably the elided part returns true for memrefs that are not
/// uni-dimensional — confirm.
116 return llvm::any_of(values, [](Value v) {
117 auto memref = dyn_cast<MemRefType>(v.getType());
/// LoadOpConversion (fragment; the struct header is not visible): rewrites a
/// memref.load whose index count is not one into a load from the converted
/// memref (`adaptor.getMemref()`) using the single index computed by
/// flattenIndices. The index-count condition at source line 134 is only
/// partially visible; presumably single-index loads are rejected so the
/// pattern does not fire on them (they are legal per the LoadOp rule in
/// populateFlattenMemRefsLegality).
127 using OpConversionPattern::OpConversionPattern;
130 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter)
const override {
132 MemRefType type = op.getMemRefType();
134 op.getIndices().size() == 1)
137 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
// Re-issue the load on the converted memref with the flattened index.
138 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
140 SmallVector<Value>{finalIdx});
/// StoreOpConversion (fragment; the struct header is not visible): rewrites a
/// memref.store whose index count is not one into a store of the converted
/// value (`adaptor.getValue()`) using the single index computed by
/// flattenIndices. The index-count condition at source line 153 is only
/// partially visible; presumably single-index stores are rejected so the
/// pattern does not fire on them (they are legal per the StoreOp rule in
/// populateFlattenMemRefsLegality).
146 using OpConversionPattern::OpConversionPattern;
149 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter)
const override {
151 MemRefType type = op.getMemRefType();
153 op.getIndices().size() == 1)
156 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
// Re-issue the store with the flattened index; the memref operand is on an
// elided continuation line.
157 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
159 SmallVector<Value>{finalIdx});
/// AllocOpConversion (fragment; the struct header is not visible): replaces a
/// memref.alloc with a new memref.alloc of `newType`. `newType` is derived
/// from `type` in elided lines — presumably via getFlattenedMemRefType(type);
/// confirm. The adaptor parameter is left unnamed (unused).
165 using OpConversionPattern::OpConversionPattern;
168 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
169 ConversionPatternRewriter &rewriter)
const override {
170 MemRefType type = op.getType();
174 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
/// GlobalOpConversion (fragment; the struct header is not visible): replaces
/// a memref.global with a new global named `newName`. Its type holds one
/// element per value of the original constant initializer and keeps the
/// original element type. The original symbol visibility, constant flag and
/// alignment are carried over. `newNameStr`, `newTypeAttr` and `newInitValue`
/// are built in elided lines.
180 using OpConversionPattern::OpConversionPattern;
183 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter)
const override {
185 MemRefType type = op.getType();
191 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
// NOTE(review): dyn_cast_or_null yields null when there is no constant
// initializer or it is not a DenseElementsAttr. Any null check is in elided
// source line 192 — confirm one exists before `cstAttr` is dereferenced below.
// Collect the initializer's elements in iteration order; their count becomes
// the flattened size.
193 SmallVector<Attribute> flattenedVals;
194 for (
auto attr : cstAttr.getValues<Attribute>())
195 flattenedVals.push_back(attr);
199 auto newName = rewriter.getStringAttr(newNameStr);
203 {
static_cast<int64_t
>(flattenedVals.size())}, type.getElementType());
// Emit the replacement global, reusing the remaining attributes of `op`.
206 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
207 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
208 op.getConstantAttr(), op.getAlignmentAttr());
/// GetGlobalOpConversion (fragment; the struct header is not visible):
/// - Resolves the memref.global referenced by `op` in the nearest enclosing
///   symbol table.
/// - Finds the name that global was renamed to. `newNameIt` is presumably an
///   iterator into the file-level static globalNameMap; the lookup itself is
///   elided.
/// - Replaces the memref.get_global with one of `newType` that references
///   the new name. `newType` is computed in elided lines.
/// NOTE(review): a file-level static map is shared across pass instances —
/// confirm this is safe if the pass runs on several modules concurrently.
215 using OpConversionPattern::OpConversionPattern;
218 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter)
const override {
220 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
221 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
222 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
// NOTE(review): `globalOp` is null if the symbol is missing or is not a
// memref.global. Source line 223 is elided — confirm it guards `globalOp`
// before the dereference below.
224 MemRefType type = globalOp.getType();
229 auto originalName = globalOp.getSymNameAttr();
233 auto newName = newNameIt->second;
235 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
/// OperandConversionPattern<TOp> (fragment; the struct header is not
/// visible): a generic pattern that recreates a TOp with the same result
/// types and attributes, but with the converted operands supplied by the
/// adaptor. This file instantiates it for func.return, memref.dealloc and
/// memref.copy.
243 template <
typename TOp>
246 using OpAdaptor =
typename TOp::Adaptor;
248 matchAndRewrite(TOp op, OpAdaptor adaptor,
249 ConversionPatternRewriter &rewriter)
const override {
250 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
251 adaptor.getOperands(), op->getAttrs());
/// CondBranchOpConversion (fragment; the base-class line is not visible):
/// recreates a cf.cond_br using the converted condition and the converted
/// true/false successor operands, branching to the same destination blocks.
258 struct CondBranchOpConversion
260 using OpConversionPattern::OpConversionPattern;
263 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
264 ConversionPatternRewriter &rewriter)
const override {
265 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
266 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
267 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
/// CallOpConversion (fragment; the struct header is not visible): converts a
/// func.call's result types with the type converter and re-creates the call
/// with the converted operands.
/// - When `rewriteFunctions` is false, the old call is simply replaced.
/// - Otherwise the callee is also (re)declared as a private func.func with
///   the converted signature, with the insertion point set before the
///   function containing the call.
/// - The resolved callee op is replaced when present; otherwise a fresh
///   declaration is created. The branch condition is elided — presumably
///   `if (calledFunction)`.
276 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
277 bool rewriteFunctions =
false)
279 rewriteFunctions(rewriteFunctions) {}
282 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter)
const override {
284 llvm::SmallVector<Type> convResTypes;
// Give up when a result type cannot be converted (the early return is
// elided).
285 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
287 auto newCallOp = rewriter.create<func::CallOp>(
288 op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());
290 if (!rewriteFunctions) {
291 rewriter.replaceOp(op, newCallOp);
// Rewrite the callee declaration so that its signature (operand types of the
// new call, converted result types) matches the converted call.
298 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
299 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
301 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
302 func::FuncOp newFuncOp;
304 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
305 calledFunction, op.getCallee(), funcType);
308 rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
// Declarations produced here are always private.
309 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
310 rewriter.replaceOp(op, newCallOp);
// When true, matchAndRewrite also rewrites the callee's func.func.
316 bool rewriteFunctions;
/// Registers a dynamic legality callback on `target` for every op type in the
/// `TOp` pack, using a fold expression over the pack. The callback body is
/// elided here — presumably it checks the op's operand/result types for
/// multi-dimensional memrefs (cf. hasMultiDimMemRef); confirm.
319 template <
typename... TOp>
320 void addGenericLegalityConstraint(ConversionTarget &target) {
321 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
/// Registers the ops that FlattenMemRefPass treats as legal:
/// - all arith ops;
/// - memref.load / memref.store once they use exactly one index;
/// - memref.alloc, memref.global and memref.get_global, via dynamic callbacks
///   whose bodies are elided here;
/// - cf.cond_br, cf.br, func.call, func.return, memref.dealloc and
///   memref.copy, via the generic check from addGenericLegalityConstraint;
/// - func.func once no block has a multi-dimensional memref argument and
///   every memref result type is uni-dimensional. Non-memref results are
///   handled in elided lines — presumably treated as legal.
328 static void populateFlattenMemRefsLegality(ConversionTarget &target) {
// Arith ops are always legal; flattenIndices emits arith constants, shifts,
// multiplies and adds.
329 target.addLegalDialect<arith::ArithDialect>();
330 target.addDynamicallyLegalOp<memref::AllocOp>(
332 target.addDynamicallyLegalOp<memref::StoreOp>(
333 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
334 target.addDynamicallyLegalOp<memref::LoadOp>(
335 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
336 target.addDynamicallyLegalOp<memref::GlobalOp>(
338 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
340 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
341 func::CallOp, func::ReturnOp, memref::DeallocOp,
342 memref::CopyOp>(target);
// A function is legal once its block arguments and results no longer use
// multi-dimensional memrefs.
344 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
345 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
346 return hasMultiDimMemRef(block.getArguments());
349 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
350 if (auto memref = dyn_cast<MemRefType>(type))
351 return isUniDimensional(memref);
355 return argsConverted && resultsConverted;
/// Target materialization, registered by FlattenMemRefCallsPass through
/// addTargetMaterialization. It turns the memref value inputs[0] into a
/// memref.subview whose result type is a 1-D memref of `type`'s element type,
/// holding as many elements as the source memref. Only inputs[0] is used.
/// Asserts that the target `type` is statically shaped.
///
/// NOTE(review): the vector names look swapped.
/// - `sizes` holds all zeros and is passed as the first index list after the
///   source. That is the offsets slot of memref::SubViewOp's builder
///   (source, offsets, sizes, strides).
/// - `offsets` holds [1, ..., 1, memSize], which reads like the sizes.
/// The rest of the create<> call is elided — confirm the argument order and
/// consider renaming the two vectors.
/// NOTE(review): for a source with more than one dimension, a size of
/// memSize along the last dimension exceeds that dimension's extent. Confirm
/// that the memref.subview verifier accepts this subview and that it does
/// what is intended.
362 static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
363 ValueRange inputs, Location loc) {
364 assert(type.hasStaticShape() &&
365 "Can only subview flatten memref's with static shape (for now...).");
366 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
367 int64_t memSize = sourceType.getNumElements();
368 unsigned dims = sourceType.getShape().size();
// Per-dimension index lists for the subview (see the NOTE above about their
// names).
371 SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0));
372 SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1));
373 offsets[offsets.size() - 1] = builder.getIndexAttr(memSize);
374 SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));
// The result is a 1-D memref covering every element of the source.
377 MemRefType outType =
MemRefType::get({memSize}, type.getElementType());
378 return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes,
/// Registers the memref-flattening type conversions on `typeConverter`.
/// - By default every type converts to itself.
/// - A MemRefType converts to a 1-D memref with getNumElements() elements of
///   the same element type. Source lines 387-388 inside that rule are
///   elided — presumably an early-out for memrefs that are already
///   uni-dimensional.
/// Per the MLIR TypeConverter documentation, the most recently added
/// conversion is tried first, so the MemRefType rule takes precedence over
/// the identity fallback for memrefs.
/// NOTE(review): only the element count and element type are passed to
/// MemRefType::get, so the original memref's layout and memory space are not
/// carried over — confirm that dropping them is intended.
382 static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
// Identity fallback: types not handled by a later-registered rule are kept.
384 typeConverter.addConversion([](Type type) {
return type; });
// Flatten memrefs to a single dimension.
386 typeConverter.addConversion([](MemRefType memref) {
389 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
390 memref.getElementType());
/// FlattenMemRefPass (fragment): flattens multi-dimensional memrefs to 1-D
/// memrefs by running a partial conversion on the operation. It combines:
/// - the patterns registered below, plus the func.func signature conversion
///   from populateFunctionOpInterfaceTypeConversionPattern;
/// - the legality rules from populateFlattenMemRefsLegality;
/// - the memref type conversion from populateTypeConversionPatterns.
/// The constructor arguments of the patterns and the failure handling are on
/// elided lines.
394 struct FlattenMemRefPass
395 :
public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
397 void runOnOperation()
override {
399 auto *ctx = &getContext();
400 TypeConverter typeConverter;
401 populateTypeConversionPatterns(typeConverter);
// NOTE(review): `rewrittenCallees` is not referenced in any visible line of
// this pass — confirm whether it is still needed.
404 SetVector<StringRef> rewrittenCallees;
// NOTE(review): OperandConversionPattern<memref::DeallocOp> is listed twice
// below (source lines 408 and 410); the second entry looks redundant.
405 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
406 GlobalOpConversion, GetGlobalOpConversion,
407 OperandConversionPattern<func::ReturnOp>,
408 OperandConversionPattern<memref::DeallocOp>,
409 CondBranchOpConversion,
410 OperandConversionPattern<memref::DeallocOp>,
411 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
413 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
416 ConversionTarget target(*ctx);
417 populateFlattenMemRefsLegality(target);
// Run the partial conversion; the failure branch is elided (presumably it
// signals pass failure).
419 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
/// FlattenMemRefCallsPass (fragment): the only pattern visible here is
/// CallOpConversion. Its remaining constructor arguments, presumably
/// including the rewriteFunctions flag, are elided. The pass makes calls
/// pass flattened 1-D memrefs, and materializeSubViewFlattening bridges
/// multi-dimensional memref operands to the flattened type.
427 struct FlattenMemRefCallsPass
428 :
public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
430 void runOnOperation()
override {
431 auto *ctx = &getContext();
432 TypeConverter typeConverter;
433 populateTypeConversionPatterns(typeConverter);
442 patterns.add<CallOpConversion>(typeConverter, ctx,
445 ConversionTarget target(*ctx);
// memref and builtin ops are legal; this presumably also covers the
// memref.subview ops created by the materialization.
446 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
// Calls and functions get the generic dynamic legality check.
447 addGenericLegalityConstraint<func::CallOp>(target);
448 addGenericLegalityConstraint<func::FuncOp>(target);
// Bridge multi-dimensional memref values to the converted 1-D type with a
// memref.subview.
452 typeConverter.addTargetMaterialization(materializeSubViewFlattening);
// Run the partial conversion; the failure branch is elided.
454 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
/// createFlattenMemRefPass (fragment): returns a new FlattenMemRefPass.
466 return std::make_unique<FlattenMemRefPass>();
/// createFlattenMemRefCallsPass (fragment): returns a new
/// FlattenMemRefCallsPass.
470 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType &&"element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static std::atomic< unsigned > globalCounter(0)
static bool hasMultiDimMemRef(ValueRange values)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations that need to be rewritten if they contain memref ar...