//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
#define GEN_PASS_DEF_FLATTENMEMREF
#define GEN_PASS_DEF_FLATTENMEMREFCALLS
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations that need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

// Flatten indices by scaling the i'th index by the product of the [0:i-1]
// dimension sizes and summing the scaled indices.
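//
// Illustrative example: for a memref<4x6xi32> accessed with indices (%i, %j),
// this produces %i + %j * 4; since 4 is a power of two, the multiplication is
// emitted as a left shift, i.e. %i + (%j << 2).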
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();
    int64_t indexMulFactor = 1;

    // Calculate the product of the [0:i-1] dimension sizes.
    for (unsigned i = 0; i <= memIdx.index(); ++i) {
      int64_t dimSize = memrefType.getShape()[i];
      indexMulFactor *= dimSize;
    }

    // Multiply the product by the current index operand.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant =
          rewriter
              .create<arith::ConstantOp>(
                  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
              .getResult();
      partialIdx =
          rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
    } else {
      auto constant = rewriter
                          .create<arith::ConstantOp>(
                              loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      partialIdx =
          rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
    }

    // Sum up with the prior lower dimension accessors.
    auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = dyn_cast<MemRefType>(v.getType());
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

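// Flattens the index list of a multidimensional memref.load into a single
// flat index (see flattenIndices above). Loads on unidimensional or
// dynamically shaped memrefs, and loads that already use a single index, are
// left untouched.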
struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};

127 
128 struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
129  using OpConversionPattern::OpConversionPattern;
130 
131  LogicalResult
132  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
133  ConversionPatternRewriter &rewriter) const override {
134  MemRefType type = op.getMemRefType();
135  if (isUniDimensional(type) || !type.hasStaticShape() ||
136  /*Already converted?*/ op.getIndices().size() == 1)
137  return failure();
138  Value finalIdx =
139  flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
140  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
141  adaptor.getMemref(),
142  SmallVector<Value>{finalIdx});
143  return success();
144  }
145 };
146 
struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = MemRefType::get(
        SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
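// For example, OperandConversionPattern<memref::DeallocOp> simply recreates
// the dealloc with the type-converted (flattened) memref operand, carrying
// the original result types and attributes over unchanged.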
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

// Cannot use OperandConversionPattern for branch op since the default builder
// doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
    : public OpConversionPattern<mlir::cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
        adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
    return success();
  }
};

// Rewrites a call op's signature to flattened types. If rewriteFunctions is
// set, the callee is also replaced with a private definition of the called
// function with the updated signature.
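// Illustrative example (hypothetical @foo): a call such as
//   func.call @foo(%m) : (memref<4x8xi32>) -> ()
// becomes roughly
//   func.call @foo(%m_flat) : (memref<32xi32>) -> ()
// and, with rewriteFunctions set, a private func.func @foo with the flattened
// signature is (re)created next to the caller.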
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp = rewriter.create<func::CallOp>(
        op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  bool rewriteFunctions;
};

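// Marks each op type in the parameter pack as dynamically legal iff none of
// its operands or results are multidimensional memrefs; the constraint is
// expanded over all TOp types with a C++17 fold expression.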
template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });

  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
                               func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
      return hasMultiDimMemRef(block.getArguments());
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = dyn_cast<MemRefType>(type))
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memory to unidimensional memory by using a
// memref.subview operation.
// TODO: This is also possible for dynamically shaped memories.
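// Illustrative example: for a source value of type memref<4x8xi32>, this
// builds a memref.subview with offsets [0, 0], sizes [1, 32] and strides
// [1, 1], yielding a memref<32xi32> view of the original allocation.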
static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
                                          ValueRange inputs, Location loc) {
  assert(type.hasStaticShape() &&
         "Can only subview flatten memref's with static shape (for now...).");
  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
  int64_t memSize = sourceType.getNumElements();
  unsigned dims = sourceType.getShape().size();

  // Build offsets, sizes and strides.
  SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(1));
  sizes[sizes.size() - 1] = builder.getIndexAttr(memSize);
  SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));

  // Generate the appropriate return type:
  MemRefType outType = MemRefType::get({memSize}, type.getElementType());
  return builder.create<memref::SubViewOp>(loc, outType, inputs[0], offsets,
                                           sizes, strides);
}

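// Under the conversion below, e.g. a memref<4x8xi32> maps to memref<32xi32>;
// unidimensional memrefs and all other types are left unchanged.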
static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}

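// End-to-end sketch (illustrative, assuming the hypothetical function below):
//   func.func @foo(%m : memref<4x8xi32>, %i : index, %j : index) -> i32 {
//     %v = memref.load %m[%i, %j] : memref<4x8xi32>
//     return %v : i32
//   }
// This pass rewrites %m to memref<32xi32> and the load to use the single
// flattened index %i + (%j << 2); the exact IR depends on constant folding.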
struct FlattenMemRefPass
    : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
public:
  void runOnOperation() override {

    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);

    RewritePatternSet patterns(ctx);
    SetVector<StringRef> rewrittenCallees;
    patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
                 OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 CondBranchOpConversion,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    ConversionTarget target(*ctx);
    populateFlattenMemRefsLegality(target);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

struct FlattenMemRefCallsPass
    : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run the conversion on call ops within the body of the function;
    // callee functions are rewritten through rewriteFunctions=true. We do not
    // use populateFuncOpTypeConversionPattern to rewrite the function
    // signatures, since non-called functions should not have their types
    // converted. It is up to users of this pass to define how these rewritten
    // functions are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.subview operations.
    typeConverter.addTargetMaterialization(materializeSubViewFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}
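
// Usage sketch (illustrative, assuming the standard mlir::PassManager API):
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(circt::createFlattenMemRefPass());
//   if (failed(pm.run(module)))
//     /* handle the failure */;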

} // namespace circt