FlattenMemRefs.cpp
//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations which need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
                         type.getElementType());
}
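
// For example, getFlattenedMemRefType applied to a memref<4x8xi32> yields a
// memref<32xi32>.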

static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  unsigned uniqueID = globalCounter++;
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), uniqueID);
}
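
// For instance, a 32-element i32 memref yields a name of the form
// <baseName>_32xi32_0, where the trailing counter keeps generated names
// unique.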

// Flatten indices by linearizing them in row-major order: the accumulated
// index is scaled by the size of each dimension before that dimension's
// index operand is added (Horner's scheme).
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();

    // The accumulated index must be scaled by the size of the current
    // dimension to make room for this dimension's index operand.
    int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];

    // Multiply the accumulated index by the scale factor, using a left shift
    // when the factor is a power of two.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant =
          rewriter
              .create<arith::ConstantOp>(
                  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
              .getResult();
      finalIdx =
          rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
    } else {
      auto constant = rewriter
                          .create<arith::ConstantOp>(
                              loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      finalIdx =
          rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
    }

    // Add this dimension's index operand to the scaled accumulated index.
    auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}
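
// For example, for a memref<4x8x16xi32> accessed at [%i, %j, %k], the
// flattened index computed above is ((%i * 8) + %j) * 16 + %k, i.e.
// %i * 128 + %j * 16 + %k, with the power-of-two factors lowered to shifts.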

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    // Bail out if the global lacks a dense constant initializer; iterating a
    // null attribute below would crash.
    if (!cstAttr)
      return failure();

    SmallVector<Attribute> flattenedVals;
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};

struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    // The symbol may fail to resolve to a memref.global; bail out rather than
    // dereference a null op.
    if (!globalOp)
      return failure();

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return failure();
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

// Cannot use OperandConversionPattern for branch ops since the default builder
// doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
    : public OpConversionPattern<mlir::cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
        adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
    return success();
  }
};

// Rewrites a call op's signature to flattened types. If rewriteFunctions is
// set, also replaces the callee with a private declaration of the called
// function, using the updated signature.
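//
// As an illustrative sketch (callee name hypothetical), a call such as
//
//   func.call @foo(%mem) : (memref<4x4xi32>) -> ()
//
// becomes
//
//   func.call @foo(%flat) : (memref<16xi32>) -> ()
//
// alongside a private func.func declaration of @foo with the new signature.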
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp = rewriter.create<func::CallOp>(
        op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};

template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::GlobalOp>(
      [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
      [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
                               func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
      return hasMultiDimMemRef(block.getArguments());
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memref as a unidimensional memref through a
// memref.subview operation.
// TODO: This is also possible for dynamically shaped memories.
static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
                                          ValueRange inputs, Location loc) {
  assert(type.hasStaticShape() &&
         "Can only subview flatten memrefs with static shape (for now...).");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  unsigned dims = sourceType.getShape().size();

  // Build the offsets (all zero), sizes (one in every dimension except the
  // last, which spans the full flattened size), and strides (all one).
  SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(1));
  sizes[sizes.size() - 1] = builder.getIndexAttr(memSize);
  SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));

  // Generate the appropriate return type:
  MemRefType outType = MemRefType::get({memSize}, type.getElementType());
  return builder.create<memref::SubViewOp>(loc, outType, inputs[0], offsets,
                                           sizes, strides);
}
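
// For a source of type memref<4x8xi32>, for instance, this materialization
// emits roughly:
//
//   %flat = memref.subview %src[0, 0] [1, 32] [1, 1]
//       : memref<4x8xi32> to memref<32xi32>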

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}
393
394struct FlattenMemRefPass
395 : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
396public:
397 void runOnOperation() override {
398
399 auto *ctx = &getContext();
400 TypeConverter typeConverter;
401 populateTypeConversionPatterns(typeConverter);
402
403 RewritePatternSet patterns(ctx);
404 SetVector<StringRef> rewrittenCallees;
405 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
406 GlobalOpConversion, GetGlobalOpConversion,
407 OperandConversionPattern<func::ReturnOp>,
408 OperandConversionPattern<memref::DeallocOp>,
409 CondBranchOpConversion,
410 OperandConversionPattern<memref::DeallocOp>,
411 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
412 typeConverter, ctx);
413 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
414 patterns, typeConverter);
415
416 ConversionTarget target(*ctx);
417 populateFlattenMemRefsLegality(target);
418
419 if (applyPartialConversion(getOperation(), target, std::move(patterns))
420 .failed()) {
421 signalPassFailure();
422 return;
423 }
424 }
425};
426
427struct FlattenMemRefCallsPass
428 : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
429public:
430 void runOnOperation() override {
431 auto *ctx = &getContext();
432 TypeConverter typeConverter;
433 populateTypeConversionPatterns(typeConverter);
434 RewritePatternSet patterns(ctx);
435
436 // Only run conversion on call ops within the body of the function. callee
437 // functions are rewritten by rewriteFunctions=true. We do not use
438 // populateFuncOpTypeConversionPattern to rewrite the function signatures,
439 // since non-called functions should not have their types converted.
440 // It is up to users of this pass to define how these rewritten functions
441 // are to be implemented.
442 patterns.add<CallOpConversion>(typeConverter, ctx,
443 /*rewriteFunctions=*/true);
444
445 ConversionTarget target(*ctx);
446 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
447 addGenericLegalityConstraint<func::CallOp>(target);
448 addGenericLegalityConstraint<func::FuncOp>(target);
449
450 // Add a target materializer to handle memory flattening through
451 // memref.subview operations.
452 typeConverter.addTargetMaterialization(materializeSubViewFlattening);
453
454 if (applyPartialConversion(getOperation(), target, std::move(patterns))
455 .failed()) {
456 signalPassFailure();
457 return;
458 }
459 }
460};
461
} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt