//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations which need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
                         type.getElementType());
}
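
// For example (illustrative), a memref<4x8xi32> flattens to memref<32xi32>:
// the element count (4 * 8 = 32) becomes the single dimension.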

static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  unsigned uniqueID = globalCounter++;
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), uniqueID);
}
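
// For example (illustrative), a base name "table" and a memref<4x8xi32> yield
// a name along the lines of "table_32xi32_0", encoding the flattened element
// count, the element type, and a unique ID taken from globalCounter.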

// Flatten indices using Horner's scheme: for each successive index operand,
// scale the running flat index by the size of the dimension that operand
// addresses and add the operand, so that indices [i, j, k] into a
// memref<AxBxC> flatten to (i * B + j) * C + k.
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();

    // Scale the running index by the size of the dimension addressed by the
    // current index operand. (Scaling by the product of all remaining
    // dimensions instead would double-count the inner dimensions for memrefs
    // of rank > 2.)
    int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];

    // Use a shift when the scale factor is a power of two.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant = arith::ConstantOp::create(
                          rewriter, loc,
                          rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
                          .getResult();
      finalIdx =
          arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
    } else {
      auto constant = arith::ConstantOp::create(
                          rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      finalIdx =
          arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
    }

    // Sum up with the prior lower dimension accessors.
    auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}
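
// For example (illustrative), flattening the indices [%i, %j] of a
// memref<4x8xi32> access scales %i by the row width 8 (a power of two, so a
// shift is emitted) and adds %j:
//
//   %c3   = arith.constant 3 : index
//   %tmp  = arith.shli %i, %c3 : index
//   %flat = arith.addi %tmp, %j : index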

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};
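
// For example (illustrative), a multi-dimensional load such as
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xi32>
//
// is rewritten to a load from the flattened (type-converted) memref using the
// index produced by flattenIndices:
//
//   %v = memref.load %flatMem[%flat] : memref<32xi32>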

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

struct AllocaOpConversion : public OpConversionPattern<memref::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
    return success();
  }
};

struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    // Bail out on globals without a dense constant initializer; only those
    // can have their initial value flattened below.
    if (!cstAttr)
      return failure();

    SmallVector<Attribute> flattenedVals;
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};
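
// For example (illustrative), a constant global such as
//
//   memref.global "private" constant @table : memref<2x2xi32> =
//       dense<[[0, 1], [2, 3]]>
//
// is replaced by a unidimensional global under a freshly generated symbol
// name (recorded in globalNameMap so that get_global users can be retargeted):
//
//   memref.global "private" constant @<newName> : memref<4xi32> =
//       dense<[0, 1, 2, 3]>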

struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    // Bail out if the symbol does not resolve to a memref.global op.
    if (!globalOp)
      return failure();

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return failure();
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value flattenedSource = rewriter.getRemappedValue(op.getSource());
    if (!flattenedSource)
      return failure();

    auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
    if (isUniDimensional(flattenedSrcType) ||
        !flattenedSrcType.hasStaticShape()) {
      rewriter.replaceOp(op, flattenedSource);
      return success();
    }

    return failure();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};
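
// For example (illustrative), OperandConversionPattern<memref::DeallocOp>
// simply rebuilds the dealloc around the type-converted operand:
//
//   memref.dealloc %mem : memref<4x8xi32>
// ==>
//   memref.dealloc %flat : memref<32xi32>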

// Cannot use OperandConversionPattern for the branch op since the default
// builder doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
    : public OpConversionPattern<mlir::cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
        adaptor.getFalseDestOperands(), /*branch_weights=*/nullptr,
        op.getTrueDest(), op.getFalseDest());
    return success();
  }
};

// Rewrites a call op signature to flattened types. If rewriteFunctions is set,
// will also replace the callee with a private definition of the called
// function with the updated signature.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp =
        func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
                             convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};
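
// For example (illustrative), with rewriteFunctions set, a call such as
//
//   %r = func.call @fn(%mem) : (memref<4x8xi32>) -> i32
//
// is rewritten to use the flattened operand, and a private declaration with
// the updated signature is emitted for users of the pass to implement:
//
//   func.func private @fn(memref<32xi32>) -> i32
//   ...
//   %r = func.call @fn(%flat) : (memref<32xi32>) -> i32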

template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::AllocaOp>(
      [](memref::AllocaOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::GlobalOp>(
      [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
      [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
                               func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
      return hasMultiDimMemRef(block.getArguments());
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memory to unidimensional memory by using a
// memref.collapse_shape operation.
// TODO: This is also possible for dynamically shaped memories.
static Value materializeCollapseShapeFlattening(OpBuilder &builder,
                                                MemRefType type,
                                                ValueRange inputs,
                                                Location loc) {
  assert(type.hasStaticShape() &&
         "can only flatten memrefs with static shape (for now...)");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);

  // Build ReassociationIndices to collapse completely to a 1D memref.
  auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
  assert(indices.has_value() && "expected a valid collapse");

  // Emit the collapse, which also produces the appropriate result type.
  return memref::CollapseShapeOp::create(builder, loc, inputs[0],
                                         indices.value());
}
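
// For example (illustrative), materializing a memref<4x8xi32> value as its
// flattened type collapses both dimensions into one:
//
//   %flat = memref.collapse_shape %mem [[0, 1]]
//       : memref<4x8xi32> into memref<32xi32>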

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}

struct FlattenMemRefPass
    : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);

    RewritePatternSet patterns(ctx);
    patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
                 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
                 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 CondBranchOpConversion,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    ConversionTarget target(*ctx);
    populateFlattenMemRefsLegality(target);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};
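
// A minimal usage sketch for the pass above and its companion below (assuming
// the pass names registered in circt/Transforms/Passes.td):
//
//   circt-opt --flatten-memref input.mlir
//   circt-opt --flatten-memref-calls input.mlir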

struct FlattenMemRefCallsPass
    : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run the conversion on call ops within the body of the function.
    // Callee functions are rewritten through rewriteFunctions=true. We do not
    // use populateFuncOpTypeConversionPattern to rewrite the function
    // signatures, since non-called functions should not have their types
    // converted. It is up to users of this pass to define how these rewritten
    // functions are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.collapse_shape operations.
    typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt