FlattenMemRefs.cpp
//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;
bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations that need to be rewritten if
/// they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};
static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
                         type.getElementType());
}
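
// For example, memref<4x6x8xi32> is flattened to memref<192xi32>
// (4 * 6 * 8 = 192 elements).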

static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  unsigned uniqueID = globalCounter++;
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), uniqueID);
}
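
// E.g., a base name "foo" and type memref<4x6x8xi32> produce a name along the
// lines of "foo_192xi32_0"; the trailing counter value keeps generated names
// unique ("foo" is a placeholder here).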

// Flatten the multi-dimensional indices into a single index using Horner's
// scheme: for each successive index, multiply the accumulated index by the
// size of the dimension that index addresses, then add the index itself.
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();

    // The size of the dimension addressed by the current index operand.
    int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];

    // Scale the accumulated index by the dimension size, using a left-shift
    // when the size is a power of two.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant = arith::ConstantOp::create(
                          rewriter, loc,
                          rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
                          .getResult();
      finalIdx =
          arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
    } else {
      auto constant = arith::ConstantOp::create(
                          rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      finalIdx =
          arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
    }

    // Sum up with the current index operand.
    auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}
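
// Illustrative example: for indices [%i, %j, %k] into memref<4x6x8xi32>, the
// scheme above computes ((%i * 6) + %j) * 8 + %k, with the multiplications
// emitted as left-shifts where the dimension size is a power of two.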

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};
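
// A sketch of the rewrite performed by LoadOpConversion on illustrative IR
// (value names are placeholders):
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xi32>
//
// becomes, once the memref operand has been flattened to memref<32xi32>:
//
//   %c3 = arith.constant 3 : index
//   %0 = arith.shli %i, %c3 : index
//   %1 = arith.addi %0, %j : index
//   %v = memref.load %flat[%1] : memref<32xi32>
//
// StoreOpConversion below rewrites stores analogously.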

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

struct AllocaOpConversion : public OpConversionPattern<memref::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
    return success();
  }
};

struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    // Bail out on globals without a dense constant initializer rather than
    // dereferencing a null attribute below.
    if (!cstAttr)
      return failure();

    SmallVector<Attribute> flattenedVals;
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};
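
// Illustrative example of the global rewrite (symbol names are placeholders):
//
//   memref.global "private" constant @g : memref<2x2xi32> =
//       dense<[[0, 1], [2, 3]]>
//
// is replaced by a flattened global along the lines of
//
//   memref.global "private" constant @constant_4xi32_0 : memref<4xi32> =
//       dense<[0, 1, 2, 3]>
//
// with globalNameMap recording the old-to-new symbol mapping so that
// GetGlobalOpConversion below can redirect accessors.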

struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    // Bail out if the symbol does not resolve to a memref.global.
    if (!globalOp)
      return failure();

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return failure();
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value flattenedSource = rewriter.getRemappedValue(op.getSource());
    if (!flattenedSource)
      return failure();

    auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
    if (isUniDimensional(flattenedSrcType) ||
        !flattenedSrcType.hasStaticShape()) {
      rewriter.replaceOp(op, flattenedSource);
      return success();
    }

    return failure();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type-converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

// Rewrites a call op signature to flattened types. If rewriteFunctions is set,
// will also replace the callee with a private definition of the called
// function with the updated signature.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp =
        func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
                             convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};

template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::AllocaOp>(
      [](memref::AllocaOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::GlobalOp>(
      [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
      [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
  addGenericLegalityConstraint<func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::all_of(op.getArgumentTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memref as a unidimensional memref via a
// memref.collapse_shape operation.
// TODO: This is also possible for dynamically shaped memories.
static Value materializeCollapseShapeFlattening(OpBuilder &builder,
                                                MemRefType type,
                                                ValueRange inputs,
                                                Location loc) {
  assert(type.hasStaticShape() &&
         "can only flatten memref's with static shape (for now...)");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);

  // Build the ReassociationIndices to collapse completely to a 1D memref.
  auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
  assert(indices.has_value() && "expected a valid collapse");

  // Create the collapse_shape op producing the flattened 1D memref.
  return memref::CollapseShapeOp::create(builder, loc, inputs[0],
                                         indices.value());
}
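
// Illustrative materialization (value names are placeholders): when the type
// converter needs a memref<32xi32> where a memref<4x8xi32> is provided, the
// materializer above emits
//
//   %flat = memref.collapse_shape %mem [[0, 1]]
//       : memref<4x8xi32> into memref<32xi32>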

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}

struct FlattenMemRefPass
    : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);

    RewritePatternSet patterns(ctx);
    patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
                 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
                 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    ConversionTarget target(*ctx);
    populateFlattenMemRefsLegality(target);
    mlir::cf::populateCFStructuralTypeConversionsAndLegality(typeConverter,
                                                             patterns, target);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

struct FlattenMemRefCallsPass
    : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run conversion on call ops within the body of the function. Callee
    // functions are rewritten due to rewriteFunctions=true. We do not use
    // populateFuncOpTypeConversionPattern to rewrite the function signatures,
    // since non-called functions should not have their types converted.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.collapse_shape operations.
    typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt