CIRCT 23.0.0git
Loading...
Searching...
No Matches
FlattenMemRefs.cpp
Go to the documentation of this file.
1//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the MemRef flattening pass.
10//
11//===----------------------------------------------------------------------===//
12
#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
33
34namespace circt {
35#define GEN_PASS_DEF_FLATTENMEMREF
36#define GEN_PASS_DEF_FLATTENMEMREFCALLS
37#include "circt/Transforms/Passes.h.inc"
38} // namespace circt
39
40using namespace mlir;
41using namespace circt;
42
43bool circt::isUniDimensional(MemRefType memref) {
44 return memref.getShape().size() == 1;
45}
46
47/// A struct for maintaining function declarations which needs to be rewritten,
48/// if they contain memref arguments that was flattened.
50 func::FuncOp op;
51 FunctionType type;
52};
53
55 unsigned counter = 0;
56 DenseMap<StringAttr, StringAttr> nameMap;
57};
58
59static MemRefType getFlattenedMemRefType(MemRefType type) {
60 return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
61 type.getElementType());
62}
63
65 StringAttr baseName,
66 MemRefType type) {
67 unsigned uniqueID = state.counter++;
68 return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
69 type.getElementType(), uniqueID);
70}
71
72// Flatten indices by generating the product of the i'th index and the [0:i-1]
73// shapes, for each index, and then summing these.
74static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
75 ValueRange indices, MemRefType memrefType) {
76 assert(memrefType.hasStaticShape() && "expected statically shaped memref");
77 Location loc = op->getLoc();
78
79 if (indices.empty()) {
80 // Singleton memref (e.g. memref<i32>) - return 0.
81 return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))
82 .getResult();
83 }
84
85 Value finalIdx = indices.front();
86 for (auto memIdx : llvm::enumerate(indices.drop_front())) {
87 Value partialIdx = memIdx.value();
88 int64_t indexMulFactor = 1;
89
90 // Calculate the product of the i'th index and the [0:i-1] shape dims.
91 for (unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
92 ++i) {
93 int64_t dimSize = memrefType.getShape()[i];
94 indexMulFactor *= dimSize;
95 }
96
97 // Multiply product by the current index operand.
98 if (llvm::isPowerOf2_64(indexMulFactor)) {
99 auto constant = arith::ConstantOp::create(
100 rewriter, loc,
101 rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
102 .getResult();
103 finalIdx =
104 arith::ShLIOp::create(rewriter, loc, finalIdx, constant).getResult();
105 } else {
106 auto constant = arith::ConstantOp::create(
107 rewriter, loc, rewriter.getIndexAttr(indexMulFactor))
108 .getResult();
109 finalIdx =
110 arith::MulIOp::create(rewriter, loc, finalIdx, constant).getResult();
111 }
112
113 // Sum up with the prior lower dimension accessors.
114 auto sumOp = arith::AddIOp::create(rewriter, loc, finalIdx, partialIdx);
115 finalIdx = sumOp.getResult();
116 }
117 return finalIdx;
118}
119
120static bool hasMultiDimMemRef(ValueRange values) {
121 return llvm::any_of(values, [](Value v) {
122 auto memref = dyn_cast<MemRefType>(v.getType());
123 if (!memref)
124 return false;
125 return !isUniDimensional(memref);
126 });
127}
128
129namespace {
130
131struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
132 using OpConversionPattern::OpConversionPattern;
133
134 LogicalResult
135 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 MemRefType type = op.getMemRefType();
138 if (isUniDimensional(type) || !type.hasStaticShape() ||
139 /*Already converted?*/ op.getIndices().size() == 1)
140 return failure();
141 Value finalIdx =
142 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
143 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
144
145 SmallVector<Value>{finalIdx});
146 return success();
147 }
148};
149
150struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
151 using OpConversionPattern::OpConversionPattern;
152
153 LogicalResult
154 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter) const override {
156 MemRefType type = op.getMemRefType();
157 if (isUniDimensional(type) || !type.hasStaticShape() ||
158 /*Already converted?*/ op.getIndices().size() == 1)
159 return failure();
160 Value finalIdx =
161 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
162 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
163 adaptor.getMemref(),
164 SmallVector<Value>{finalIdx});
165 return success();
166 }
167};
168
169struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
170 using OpConversionPattern::OpConversionPattern;
171
172 LogicalResult
173 matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
174 ConversionPatternRewriter &rewriter) const override {
175 MemRefType type = op.getType();
176 if (isUniDimensional(type) || !type.hasStaticShape())
177 return failure();
178 MemRefType newType = getFlattenedMemRefType(type);
179 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
180 return success();
181 }
182};
183
184struct AllocaOpConversion : public OpConversionPattern<memref::AllocaOp> {
185 using OpConversionPattern::OpConversionPattern;
186
187 LogicalResult
188 matchAndRewrite(memref::AllocaOp op, OpAdaptor /*adaptor*/,
189 ConversionPatternRewriter &rewriter) const override {
190 MemRefType type = op.getType();
191 if (isUniDimensional(type) || !type.hasStaticShape())
192 return failure();
193 MemRefType newType = getFlattenedMemRefType(type);
194 rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
195 return success();
196 }
197};
198
199struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
200 GlobalOpConversion(TypeConverter &typeConverter, MLIRContext *context,
201 FlattenMemRefsState &state)
202 : OpConversionPattern(typeConverter, context), state(state) {}
203
204 LogicalResult
205 matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
206 ConversionPatternRewriter &rewriter) const override {
207 MemRefType type = op.getType();
208 if (isUniDimensional(type) || !type.hasStaticShape())
209 return failure();
210 MemRefType newType = getFlattenedMemRefType(type);
211
212 auto cstAttr =
213 llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
214
215 SmallVector<Attribute> flattenedVals;
216 for (auto attr : cstAttr.getValues<Attribute>())
217 flattenedVals.push_back(attr);
218
219 auto newTypeAttr = TypeAttr::get(newType);
220 auto newNameStr =
221 getFlattenedMemRefName(state, op.getConstantAttrName(), type);
222 auto newName = rewriter.getStringAttr(newNameStr);
223 state.nameMap[op.getSymNameAttr()] = newName;
224
225 RankedTensorType tensorType = RankedTensorType::get(
226 {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
227 auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
228
229 rewriter.replaceOpWithNewOp<memref::GlobalOp>(
230 op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
231 op.getConstantAttr(), op.getAlignmentAttr());
232
233 return success();
234 }
235
236private:
237 FlattenMemRefsState &state;
238};
239
240struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
241 GetGlobalOpConversion(TypeConverter &typeConverter, MLIRContext *context,
242 FlattenMemRefsState &state)
243 : OpConversionPattern(typeConverter, context), state(state) {}
244
245 LogicalResult
246 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter) const override {
248 auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
249 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
250 SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
251
252 MemRefType type = globalOp.getType();
253 if (isUniDimensional(type) || !type.hasStaticShape())
254 return failure();
255
256 MemRefType newType = getFlattenedMemRefType(type);
257 auto originalName = globalOp.getSymNameAttr();
258 auto newNameIt = state.nameMap.find(originalName);
259 if (newNameIt == state.nameMap.end())
260 return failure();
261 auto newName = newNameIt->second;
262
263 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
264
265 return success();
266 }
267
268private:
269 FlattenMemRefsState &state;
270};
271
272struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
273 using OpConversionPattern::OpConversionPattern;
274
275 LogicalResult
276 matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
277 ConversionPatternRewriter &rewriter) const override {
278 Value flattenedSource = rewriter.getRemappedValue(op.getSource());
279 if (!flattenedSource)
280 return failure();
281
282 auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
283 if (isUniDimensional(flattenedSrcType) ||
284 !flattenedSrcType.hasStaticShape()) {
285 rewriter.replaceOp(op, flattenedSource);
286 return success();
287 }
288
289 return failure();
290 }
291};
292
293// A generic pattern which will replace an op with a new op of the same type
294// but using the adaptor (type converted) operands.
295template <typename TOp>
296struct OperandConversionPattern : public OpConversionPattern<TOp> {
298 using OpAdaptor = typename TOp::Adaptor;
299 LogicalResult
300 matchAndRewrite(TOp op, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) const override {
302 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
303 adaptor.getOperands(), op->getAttrs());
304 return success();
305 }
306};
307
// Rewrites a call op signature to flattened types. If rewriteFunctions is set,
// will also replace the callee with a private definition of the called
// function of the updated signature.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the call's result types through the flattening type converter.
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    // Recreate the call with converted result types and the adaptor's
    // (already type-converted) operands.
    auto newCallOp =
        func::CallOp::create(rewriter, op.getLoc(), adaptor.getCallee(),
                             convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    // Replace an existing callee definition in place, or create a fresh
    // declaration when the callee cannot be resolved.
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType);
    // The rewritten callee is left as a private declaration; its body is
    // supplied by users of this pass.
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  // When true, also rewrite/declare the callee with the flattened signature.
  bool rewriteFunctions;
};
355
356template <typename... TOp>
357void addGenericLegalityConstraint(ConversionTarget &target) {
358 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
359 return !hasMultiDimMemRef(op->getOperands()) &&
360 !hasMultiDimMemRef(op->getResults());
361 }),
362 ...);
363}
364
365static void populateFlattenMemRefsLegality(ConversionTarget &target) {
366 target.addLegalDialect<arith::ArithDialect>();
367 target.addDynamicallyLegalOp<memref::AllocOp>(
368 [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
369 target.addDynamicallyLegalOp<memref::AllocaOp>(
370 [](memref::AllocaOp op) { return isUniDimensional(op.getType()); });
371 target.addDynamicallyLegalOp<memref::StoreOp>(
372 [](memref::StoreOp op) { return op.getIndices().size() == 1; });
373 target.addDynamicallyLegalOp<memref::LoadOp>(
374 [](memref::LoadOp op) { return op.getIndices().size() == 1; });
375 target.addDynamicallyLegalOp<memref::GlobalOp>(
376 [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
377 target.addDynamicallyLegalOp<memref::GetGlobalOp>(
378 [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
379 addGenericLegalityConstraint<func::CallOp, func::ReturnOp, memref::DeallocOp,
380 memref::CopyOp>(target);
381
382 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
383 auto argsConverted = llvm::all_of(op.getArgumentTypes(), [](Type type) {
384 if (auto memref = dyn_cast<MemRefType>(type))
385 return isUniDimensional(memref);
386 return true;
387 });
388
389 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
390 if (auto memref = dyn_cast<MemRefType>(type))
391 return isUniDimensional(memref);
392 return true;
393 });
394
395 return argsConverted && resultsConverted;
396 });
397}
398
399// Materializes a multidimensional memory to unidimensional memory by using a
400// memref.collapse_shape operation.
401// TODO: This is also possible for dynamically shaped memories.
402static Value materializeCollapseShapeFlattening(OpBuilder &builder,
403 MemRefType type,
404 ValueRange inputs,
405 Location loc) {
406 assert(type.hasStaticShape() &&
407 "Can only subview flatten memref's with static shape (for now...).");
408 MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
409 int64_t memSize = sourceType.getNumElements();
410 ArrayRef<int64_t> sourceShape = sourceType.getShape();
411 ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
412
413 // Build ReassociationIndices to collapse completely to 1D MemRef.
414 auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
415 assert(indices.has_value() && "expected a valid collapse");
416
417 // Generate the appropriate return type:
418 return memref::CollapseShapeOp::create(builder, loc, inputs[0],
419 indices.value());
420}
421
422static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
423 // Add default conversion for all types generically.
424 typeConverter.addConversion([](Type type) { return type; });
425 // Add specific conversion for memref types.
426 typeConverter.addConversion([](MemRefType memref) {
427 if (isUniDimensional(memref))
428 return memref;
429 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
430 memref.getElementType());
431 });
432}
433
434struct FlattenMemRefPass
435 : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
436public:
437 void runOnOperation() override {
438
439 auto *ctx = &getContext();
440 TypeConverter typeConverter;
442 populateTypeConversionPatterns(typeConverter);
443
444 RewritePatternSet patterns(ctx);
445 SetVector<StringRef> rewrittenCallees;
446 patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
447 AllocaOpConversion, ReshapeOpConversion,
448 OperandConversionPattern<func::ReturnOp>,
449 OperandConversionPattern<memref::DeallocOp>,
450 OperandConversionPattern<memref::DeallocOp>,
451 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
452 typeConverter, ctx);
453 patterns.add<GlobalOpConversion, GetGlobalOpConversion>(typeConverter, ctx,
454 state);
455 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
456 patterns, typeConverter);
457
458 ConversionTarget target(*ctx);
459 populateFlattenMemRefsLegality(target);
460 mlir::cf::populateCFStructuralTypeConversionsAndLegality(typeConverter,
461 patterns, target);
462
463 if (applyPartialConversion(getOperation(), target, std::move(patterns))
464 .failed()) {
465 signalPassFailure();
466 return;
467 }
468 }
469};
470
471struct FlattenMemRefCallsPass
472 : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
473public:
474 void runOnOperation() override {
475 auto *ctx = &getContext();
476 TypeConverter typeConverter;
477 populateTypeConversionPatterns(typeConverter);
478 RewritePatternSet patterns(ctx);
479
480 // Only run conversion on call ops within the body of the function. callee
481 // functions are rewritten by rewriteFunctions=true. We do not use
482 // populateFuncOpTypeConversionPattern to rewrite the function signatures,
483 // since non-called functions should not have their types converted.
484 // It is up to users of this pass to define how these rewritten functions
485 // are to be implemented.
486 patterns.add<CallOpConversion>(typeConverter, ctx,
487 /*rewriteFunctions=*/true);
488
489 ConversionTarget target(*ctx);
490 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
491 addGenericLegalityConstraint<func::CallOp>(target);
492 addGenericLegalityConstraint<func::FuncOp>(target);
493
494 // Add a target materializer to handle memory flattening through
495 // memref.subview operations.
496 typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
497
498 if (applyPartialConversion(getOperation(), target, std::move(patterns))
499 .failed()) {
500 signalPassFailure();
501 return;
502 }
503 }
504};
505
506} // namespace
507
508namespace circt {
509std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
510 return std::make_unique<FlattenMemRefPass>();
511}
512
513std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
514 return std::make_unique<FlattenMemRefCallsPass>();
515}
516
517} // namespace circt
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static MemRefType getFlattenedMemRefType(MemRefType type)
static std::string getFlattenedMemRefName(FlattenMemRefsState &state, StringAttr baseName, MemRefType type)
static bool hasMultiDimMemRef(ValueRange values)
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
DenseMap< StringAttr, StringAttr > nameMap
A struct for maintaining function declarations which needs to be rewritten, if they contain memref ar...
FunctionType type