FlattenMemRefs.cpp
//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
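// As an illustrative example (not drawn from this file), the pass rewrites a
// multi-dimensional access such as
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xi32>
//
// into an access of a unidimensional memref with an explicitly computed
// row-major index, using a shift for the power-of-two dimension size:
//
//   %c3 = arith.constant 3 : index
//   %0 = arith.shli %i, %c3 : index   // %i * 8
//   %1 = arith.addi %0, %j : index
//   %v = memref.load %flat[%1] : memref<32xi32>
//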
//===----------------------------------------------------------------------===//

#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations which need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
                         type.getElementType());
}

static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  unsigned uniqueID = globalCounter++;
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), uniqueID);
}
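
// The generated name combines the base name, the flattened element count, the
// element type, and the unique counter value. For example (illustrative), a
// base name "g" and a memref<4x6xi32> would yield a name roughly of the form
// "g_24xi32_0".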

// Flatten the index list into a single row-major index: for each successive
// index, scale the accumulated index by the size of the dimension which that
// index addresses, then add the index (Horner's scheme).
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();

    // The scale factor for this step is the size of the dimension which the
    // current index addresses. Note that scaling by the full product of the
    // trailing dimension sizes here would over-scale the terms already
    // accumulated in finalIdx for memrefs of rank greater than two.
    int64_t indexMulFactor = memrefType.getShape()[memIdx.index() + 1];

    // Multiply the accumulated index by the scale factor, using a shift when
    // the factor is a power of two.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant =
          rewriter
              .create<arith::ConstantOp>(
                  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
              .getResult();
      finalIdx =
          rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
    } else {
      auto constant = rewriter
                          .create<arith::ConstantOp>(
                              loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      finalIdx =
          rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
    }

    // Add the current index to the scaled accumulator.
    auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}
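
// For example (illustrative), for a memref<4x6x8xi32> accessed with indices
// (%i, %j, %k), the loop above accumulates ((%i * 6 + %j) * 8 + %k); the
// multiplication by 8 is emitted as a left-shift by 3.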

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};
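
// For example (illustrative), AllocOpConversion rewrites
//   %0 = memref.alloc() : memref<4x8xi32>
// into
//   %0 = memref.alloc() : memref<32xi32>
// AllocaOpConversion below applies the same rewrite to memref.alloca.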

struct AllocaOpConversion : public OpConversionPattern<memref::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);
    rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
    return success();
  }
};

struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    // Only globals with a constant initializer are supported; bail out
    // instead of dereferencing a null attribute below.
    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    if (!cstAttr)
      return failure();

    SmallVector<Attribute> flattenedVals;
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};
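
// For example (illustrative), a constant global such as
//   memref.global "private" constant @g : memref<2x2xi32> = dense<[[0, 1], [2, 3]]>
// becomes a freshly named unidimensional global with a flattened initializer,
// e.g. something like
//   memref.global "private" constant @constant_4xi32_0 : memref<4xi32> = dense<[0, 1, 2, 3]>
// and the old-to-new name mapping is recorded in globalNameMap for
// GetGlobalOpConversion below.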

struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    if (!globalOp)
      return failure();

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return failure();
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value flattenedSource = rewriter.getRemappedValue(op.getSource());
    if (!flattenedSource)
      return failure();

    auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
    if (isUniDimensional(flattenedSrcType) ||
        !flattenedSrcType.hasStaticShape()) {
      rewriter.replaceOp(op, flattenedSource);
      return success();
    }

    return failure();
  }
};
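
// In other words (illustrative), once the source of
//   %r = memref.reshape %src(%shape)
//       : (memref<4x8xi32>, memref<1xindex>) -> memref<32xi32>
// has been remapped to an already-flattened memref<32xi32>, the reshape is
// redundant and %r is replaced by the flattened source directly.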

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};
294
295// Cannot use OperandConversionPattern for branch op since the default builder
296// doesn't provide a method for communicating block successors.
297struct CondBranchOpConversion
298 : public OpConversionPattern<mlir::cf::CondBranchOp> {
299 using OpConversionPattern::OpConversionPattern;
300
301 LogicalResult
302 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
305 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
306 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
307 return success();
308 }
309};

// Rewrites a call op's signature to flattened types. If rewriteFunctions is
// set, the callee is also replaced with a private definition of the called
// function with the updated signature.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp = rewriter.create<func::CallOp>(
        op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};
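
// For example (illustrative), with rewriteFunctions set,
//   func.call @f(%m) : (memref<2x2xi32>) -> ()
// becomes
//   func.call @f(%flat) : (memref<4xi32>) -> ()
// alongside a private declaration func.func private @f(memref<4xi32>), whose
// implementation is left to users of the pass.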
357
358template <typename... TOp>
359void addGenericLegalityConstraint(ConversionTarget &target) {
360 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
361 return !hasMultiDimMemRef(op->getOperands()) &&
362 !hasMultiDimMemRef(op->getResults());
363 }),
364 ...);
365}
366
367static void populateFlattenMemRefsLegality(ConversionTarget &target) {
368 target.addLegalDialect<arith::ArithDialect>();
369 target.addDynamicallyLegalOp<memref::AllocOp>(
370 [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
371 target.addDynamicallyLegalOp<memref::AllocaOp>(
372 [](memref::AllocaOp op) { return isUniDimensional(op.getType()); });
373 target.addDynamicallyLegalOp<memref::StoreOp>(
374 [](memref::StoreOp op) { return op.getIndices().size() == 1; });
375 target.addDynamicallyLegalOp<memref::LoadOp>(
376 [](memref::LoadOp op) { return op.getIndices().size() == 1; });
377 target.addDynamicallyLegalOp<memref::GlobalOp>(
378 [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
379 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
380 [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
381 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
382 func::CallOp, func::ReturnOp, memref::DeallocOp,
383 memref::CopyOp>(target);
384
385 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
386 auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
387 return hasMultiDimMemRef(block.getArguments());
388 });
389
390 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
391 if (auto memref = dyn_cast<MemRefType>(type))
392 return isUniDimensional(memref);
393 return true;
394 });
395
396 return argsConverted && resultsConverted;
397 });
398}

// Materializes a multidimensional memory to unidimensional memory by using a
// memref.collapse_shape operation.
// TODO: This is also possible for dynamically shaped memories.
static Value materializeCollapseShapeFlattening(OpBuilder &builder,
                                                MemRefType type,
                                                ValueRange inputs,
                                                Location loc) {
  assert(type.hasStaticShape() &&
         "Can only flatten memrefs with a static shape (for now...).");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);

  // Build the ReassociationIndices needed to collapse completely to a 1D
  // MemRef.
  auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
  assert(indices.has_value() && "expected a valid collapse");

  // Generate the collapse op and return its (unidimensional) result.
  return builder.create<memref::CollapseShapeOp>(loc, inputs[0],
                                                 indices.value());
}
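
// For example (illustrative), for a memref<4x8xi32> input the materializer
// emits:
//   %flat = memref.collapse_shape %m [[0, 1]] : memref<4x8xi32> into memref<32xi32>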

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}
434
435struct FlattenMemRefPass
436 : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
437public:
438 void runOnOperation() override {
439
440 auto *ctx = &getContext();
441 TypeConverter typeConverter;
442 populateTypeConversionPatterns(typeConverter);
443
444 RewritePatternSet patterns(ctx);
445 SetVector<StringRef> rewrittenCallees;
446 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
447 AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
448 ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
449 OperandConversionPattern<memref::DeallocOp>,
450 CondBranchOpConversion,
451 OperandConversionPattern<memref::DeallocOp>,
452 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
453 typeConverter, ctx);
454 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
455 patterns, typeConverter);
456
457 ConversionTarget target(*ctx);
458 populateFlattenMemRefsLegality(target);
459
460 if (applyPartialConversion(getOperation(), target, std::move(patterns))
461 .failed()) {
462 signalPassFailure();
463 return;
464 }
465 }
466};

struct FlattenMemRefCallsPass
    : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run the conversion on call ops within the body of the function.
    // Callee functions are rewritten by rewriteFunctions=true. We do not use
    // populateFuncOpTypeConversionPattern to rewrite the function signatures,
    // since non-called functions should not have their types converted.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.collapse_shape operations.
    typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt