//===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace circt;

bool circt::isUniDimensional(MemRefType memref) {
  return memref.getShape().size() == 1;
}

/// A struct for maintaining function declarations which need to be rewritten
/// if they contain memref arguments that were flattened.
struct FunctionRewrite {
  func::FuncOp op;
  FunctionType type;
};

// Flatten indices by multiplying the i'th index with the product of the
// [0:i-1] shape dimensions, for each index, and summing the results.
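// E.g., for a memref<4x6xi32> accessed as [%i, %j], the flattened index is
// %i + %j * 4; power-of-two factors such as this one are lowered to left
// shifts (arith.shli) rather than multiplications.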
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
                            ValueRange indices, MemRefType memrefType) {
  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
  Location loc = op->getLoc();

  if (indices.empty()) {
    // Singleton memref (e.g. memref<i32>) - return 0.
    return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
        .getResult();
  }

  Value finalIdx = indices.front();
  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
    Value partialIdx = memIdx.value();
    int64_t indexMulFactor = 1;

    // Calculate the product of the i'th index and the [0:i-1] shape dims.
    for (unsigned i = 0; i <= memIdx.index(); ++i) {
      int64_t dimSize = memrefType.getShape()[i];
      indexMulFactor *= dimSize;
    }

    // Multiply product by the current index operand.
    if (llvm::isPowerOf2_64(indexMulFactor)) {
      auto constant =
          rewriter
              .create<arith::ConstantOp>(
                  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
              .getResult();
      partialIdx =
          rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
    } else {
      auto constant = rewriter
                          .create<arith::ConstantOp>(
                              loc, rewriter.getIndexAttr(indexMulFactor))
                          .getResult();
      partialIdx =
          rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
    }

    // Sum up with the prior lower dimension accessors.
    auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
    finalIdx = sumOp.getResult();
  }
  return finalIdx;
}

static bool hasMultiDimMemRef(ValueRange values) {
  return llvm::any_of(values, [](Value v) {
    auto memref = v.getType().dyn_cast<MemRefType>();
    if (!memref)
      return false;
    return !isUniDimensional(memref);
  });
}

namespace {

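// As an illustration (a sketch of the rewrite; value names are hypothetical),
// a load such as
//   %v = memref.load %mem[%i, %j] : memref<4x6xi32>
// becomes, once %mem has been flattened to memref<24xi32> by the type
// converter,
//   %c2 = arith.constant 2 : index
//   %0 = arith.shli %j, %c2 : index
//   %idx = arith.addi %i, %0 : index
//   %v = memref.load %mem[%idx] : memref<24xi32>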
struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
                                                SmallVector<Value>{finalIdx});
    return success();
  }
};

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getMemRefType();
    if (isUniDimensional(type) || !type.hasStaticShape() ||
        /*Already converted?*/ op.getIndices().size() == 1)
      return failure();
    Value finalIdx =
        flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
                                                 adaptor.getMemref(),
                                                 SmallVector<Value>{finalIdx});
    return success();
  }
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = MemRefType::get(
        SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
    return success();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
  using OpConversionPattern<TOp>::OpConversionPattern;
  using OpAdaptor = typename TOp::Adaptor;
  LogicalResult
  matchAndRewrite(TOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
                                     adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

// Cannot use OperandConversionPattern for branch op since the default builder
// doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
    : public OpConversionPattern<mlir::cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
        adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
    return success();
  }
};

// Rewrites a call op signature to flattened types. If rewriteFunctions is set,
// will also replace the callee with a private declaration of the called
// function using the updated signature.
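// For illustration (a sketch; the callee name and types are hypothetical),
//   func.call @fill(%mem) : (memref<4x6xi32>) -> ()
// becomes, when rewriteFunctions is set,
//   func.func private @fill(memref<24xi32>)
//   ...
//   func.call @fill(%mem) : (memref<24xi32>) -> ()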
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    auto newCallOp = rewriter.replaceOpWithNewOp<func::CallOp>(
        op, adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions)
      return success();

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);

    return success();
  }

private:
  bool rewriteFunctions;
};

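// Marks each of the given operation types as dynamically legal whenever none
// of its operands or results is a multi-dimensional memref, using a C++17 fold
// expression to register one legality callback per op type.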
template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
    return !hasMultiDimMemRef(op->getOperands()) &&
           !hasMultiDimMemRef(op->getResults());
  }),
   ...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
  target.addLegalDialect<arith::ArithDialect>();
  target.addDynamicallyLegalOp<memref::AllocOp>(
      [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
  target.addDynamicallyLegalOp<memref::StoreOp>(
      [](memref::StoreOp op) { return op.getIndices().size() == 1; });
  target.addDynamicallyLegalOp<memref::LoadOp>(
      [](memref::LoadOp op) { return op.getIndices().size() == 1; });

  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
                               func::CallOp, func::ReturnOp, memref::DeallocOp,
                               memref::CopyOp>(target);

  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
    auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
      return hasMultiDimMemRef(block.getArguments());
    });

    auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
      if (auto memref = type.dyn_cast<MemRefType>())
        return isUniDimensional(memref);
      return true;
    });

    return argsConverted && resultsConverted;
  });
}

// Materializes a multidimensional memory to unidimensional memory by using a
// memref.subview operation.
// TODO: This is also possible for dynamically shaped memories.
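// This is registered as a target materialization in FlattenMemRefCallsPass
// below; when only call sites are rewritten, it bridges values that still
// carry a multi-dimensional memref type to the flattened type expected by the
// rewritten callee.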
static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
                                          ValueRange inputs, Location loc) {
  assert(type.hasStaticShape() &&
         "Can only subview flatten memrefs with static shape (for now...).");
  MemRefType sourceType = inputs[0].getType().cast<MemRefType>();
  int64_t memSize = sourceType.getNumElements();
  unsigned dims = sourceType.getShape().size();

  // Build the offsets, sizes and strides of the subview.
  SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(1));
  sizes[sizes.size() - 1] = builder.getIndexAttr(memSize);
  SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));

  // Generate the appropriate return type:
  MemRefType outType = MemRefType::get({memSize}, type.getElementType());
  return builder.create<memref::SubViewOp>(loc, outType, inputs[0], offsets,
                                           sizes, strides);
}

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
  // Add default conversion for all types generically.
  typeConverter.addConversion([](Type type) { return type; });
  // Add specific conversion for memref types.
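  // E.g., memref<2x3x4xi32> converts to memref<24xi32>; unidimensional memrefs
  // are returned unchanged.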
  typeConverter.addConversion([](MemRefType memref) {
    if (isUniDimensional(memref))
      return memref;
    return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
                           memref.getElementType());
  });
}

struct FlattenMemRefPass : public FlattenMemRefBase<FlattenMemRefPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);

    RewritePatternSet patterns(ctx);
    SetVector<StringRef> rewrittenCallees;
    patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
                 OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 CondBranchOpConversion,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    ConversionTarget target(*ctx);
    populateFlattenMemRefsLegality(target);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

struct FlattenMemRefCallsPass
    : public FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    TypeConverter typeConverter;
    populateTypeConversionPatterns(typeConverter);
    RewritePatternSet patterns(ctx);

    // Only run conversion on call ops within the body of the function. Callee
    // functions are rewritten through rewriteFunctions=true. We do not use
    // populateFuncOpTypeConversionPattern to rewrite the function signatures,
    // since non-called functions should not have their types converted.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    patterns.add<CallOpConversion>(typeConverter, ctx,
                                   /*rewriteFunctions=*/true);

    ConversionTarget target(*ctx);
    target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
    addGenericLegalityConstraint<func::CallOp>(target);
    addGenericLegalityConstraint<func::FuncOp>(target);

    // Add a target materializer to handle memory flattening through
    // memref.subview operations.
    typeConverter.addTargetMaterialization(materializeSubViewFlattening);

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed()) {
      signalPassFailure();
      return;
    }
  }
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
  return std::make_unique<FlattenMemRefPass>();
}

std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
  return std::make_unique<FlattenMemRefCallsPass>();
}

} // namespace circt
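
// Usage sketch (illustrative; assumes an mlir::PassManager `pm` configured by
// the surrounding tool):
//   pm.addPass(circt::createFlattenMemRefPass());
//   // Or, to rewrite only call sites and emit flattened callee declarations:
//   pm.addPass(circt::createFlattenMemRefCallsPass());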