//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

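/// Returns true if the given memref is already flat, i.e. has a rank-1 shape
/// such as memref<16xi32> (as opposed to, e.g., memref<4x4xi32>).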
bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations which need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
                         type.getElementType());
}

static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  unsigned uniqueID = globalCounter++;
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), uniqueID);
}

// Flatten indices in row-major order: the flat index is the sum over all i of
// indices[i] * prod(shape[i+1:]). This is computed Horner-style, scaling the
// running index by each dimension size as that dimension is entered.
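// For example, a memref<2x3x4xi32> accessed at (%i, %j, %k) yields the flat
// index ((%i * 3 + %j) * 4 + %k), i.e. %i*12 + %j*4 + %k; power-of-two scale
// factors are emitted as left-shifts instead of multiplications.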
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();

    // Scale the running index by the size of the dimension being entered.
    // (Scaling by the product of all trailing dimensions instead would
    // over-multiply the terms already accumulated in the running index.)
    int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];

    // Multiply the running index by the scale factor, using a shift when the
    // factor is a power of two.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant =
          rewriter
              .create<arith::ConstantOp>(
                  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
              .getResult();
      finalIdx =
          rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
    } else {
      auto constant = rewriter
                          .create<arith::ConstantOp>(
                              loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      finalIdx =
          rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
    }

    // Sum up with the prior lower dimension accessors.
    auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}

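/// Returns true if any value in the range has a memref type of rank > 1.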
static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

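// Flattens multi-dimensional memref.load ops into loads on the corresponding
// rank-1 memref. Illustrative example (SSA names are hypothetical):
//   %v = memref.load %mem[%i, %j] : memref<4x8xi32>
// becomes
//   %idx = <flattened index, %i * 8 + %j>
//   %v = memref.load %flat[%idx] : memref<32xi32>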
struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

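// Flattens the type and the dense constant initializer of a memref.global.
// Illustrative example (symbol names are hypothetical):
//   memref.global "private" constant @cst : memref<2x2xi32> =
//       dense<[[0, 1], [2, 3]]>
// becomes a rank-1 global with a uniqued name and the flattened initializer
//   dense<[0, 1, 2, 3]> : tensor<4xi32>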
struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    // Only globals with a dense constant initializer are supported; without
    // this guard the iteration below would dereference a null attribute.
    if (!cstAttr)
      return failure();

    SmallVector<Attribute> flattenedVals;
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};

struct GetGlobalOpConversion
    : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    // Guard against unresolved symbols and symbols that are not a
    // memref.global.
    if (!globalOp)
      return failure();

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return failure();
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value flattenedSource = rewriter.getRemappedValue(op.getSource());
    if (!flattenedSource)
      return failure();

    auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
    if (isUniDimensional(flattenedSrcType) ||
        !flattenedSrcType.hasStaticShape()) {
      rewriter.replaceOp(op, flattenedSource);
      return success();
    }

    return failure();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

// Cannot use OperandConversionPattern for branch op since the default builder
// doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
    : public OpConversionPattern<mlir::cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
        adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
    return success();
  }
};

// Rewrites a call op signature to flattened types. If rewriteFunctions is set,
// will also replace the callee with a private definition of the called
// function of the updated signature.
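// Illustrative example (function names are hypothetical): a call
//   func.call @f(%m) : (memref<2x2xi32>) -> ()
// becomes
//   func.call @f(%flat) : (memref<4xi32>) -> ()
// and, with rewriteFunctions set, any existing definition of @f is replaced
// by a private declaration with the flattened signature.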
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp = rewriter.create<func::CallOp>(
        op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};

template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::GlobalOp>(
      [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
      [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
                               func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
      return hasMultiDimMemRef(block.getArguments());
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memory to unidimensional memory by using a
// memref.subview operation.
// TODO: This is also possible for dynamically shaped memories.
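// The generated subview has zero offsets, unit strides, and sizes of 1 in all
// but the innermost dimension, which spans the entire flattened memory. E.g.,
// a memref<4x8xi32> input is viewed as memref<32xi32> via a subview with
// offsets [0, 0], sizes [1, 32], and strides [1, 1].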
static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
                                          ValueRange inputs, Location loc) {
  assert(type.hasStaticShape() &&
         "Can only subview flatten memrefs with static shape (for now...).");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  unsigned dims = sourceType.getShape().size();

  // Build the offsets, sizes and strides: zero offsets, unit strides, and
  // sizes of 1 in every dimension except the innermost, which spans the
  // whole flattened memory.
  SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(1));
  sizes[sizes.size() - 1] = builder.getIndexAttr(memSize);
  SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));

  // Generate the appropriate return type.
  MemRefType outType = MemRefType::get({memSize}, type.getElementType());
  return builder.create<memref::SubViewOp>(loc, outType, inputs[0], offsets,
                                           sizes, strides);
}

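// The resulting converter maps, e.g., memref<4x8xi32> to memref<32xi32>;
// rank-1 memrefs and all non-memref types are returned unchanged.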
static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}

struct FlattenMemRefPass
    : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);

    RewritePatternSet patterns(ctx);
    patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
                 GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion,
                 OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 CondBranchOpConversion,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    ConversionTarget target(*ctx);
    populateFlattenMemRefsLegality(target);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

struct FlattenMemRefCallsPass
    : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run conversion on call ops within the body of the function. Callee
    // functions are rewritten with rewriteFunctions=true. We do not use
    // populateFuncOpTypeConversionPattern to rewrite the function signatures,
    // since non-called functions should not have their types converted.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.subview operations.
    typeConverter.addTargetMaterialization(materializeSubViewFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt
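
// Example invocation (a sketch; the flag spellings assume the pass names
// registered in circt/Transforms/Passes.td and may differ across CIRCT
// versions):
//   circt-opt --flatten-memref       input.mlir
//   circt-opt --flatten-memref-calls input.mlir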