CIRCT  20.0.0git
FlattenMemRefs.cpp
Go to the documentation of this file.
1 //===- FlattenMemRefs.cpp - MemRef flattening pass --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the MemRef flattening pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15 #include "mlir/Conversion/LLVMCommon/Pattern.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/MathExtras.h"
29 
30 namespace circt {
31 #define GEN_PASS_DEF_FLATTENMEMREF
32 #define GEN_PASS_DEF_FLATTENMEMREFCALLS
33 #include "circt/Transforms/Passes.h.inc"
34 } // namespace circt
35 
36 using namespace mlir;
37 using namespace circt;
38 
39 bool circt::isUniDimensional(MemRefType memref) {
40  return memref.getShape().size() == 1;
41 }
42 
43 /// A struct for maintaining function declarations which needs to be rewritten,
44 /// if they contain memref arguments that was flattened.
46  func::FuncOp op;
47  FunctionType type;
48 };
49 
/// Counter appended to every flattened global's symbol name so that generated
/// names are unique. Only ever incremented (see getFlattenedMemRefName), so the
/// suffix keeps growing across pass invocations within one process.
static std::atomic<unsigned> globalCounter(0);
/// Maps the symbol name of an original (non-rank-1) memref.global to the symbol
/// name of its flattened replacement. Written by GlobalOpConversion and read by
/// GetGlobalOpConversion.
/// NOTE(review): this is file-scope mutable state with no synchronization, and
/// its StringAttr keys/values are owned by an MLIRContext; entries may outlive
/// the context that created them. Consider moving it into the pass instance.
static DenseMap<StringAttr, StringAttr> globalNameMap;
52 
53 static MemRefType getFlattenedMemRefType(MemRefType type) {
54  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
55  type.getElementType());
56 }
57 
58 static std::string getFlattenedMemRefName(StringAttr baseName,
59  MemRefType type) {
60  unsigned uniqueID = globalCounter++;
61  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
62  type.getElementType(), uniqueID);
63 }
64 
65 // Flatten indices by generating the product of the i'th index and the [0:i-1]
66 // shapes, for each index, and then summing these.
67 static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
68  ValueRange indices, MemRefType memrefType) {
69  assert(memrefType.hasStaticShape() && "expected statically shaped memref");
70  Location loc = op->getLoc();
71 
72  if (indices.empty()) {
73  // Singleton memref (e.g. memref<i32>) - return 0.
74  return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
75  .getResult();
76  }
77 
78  Value finalIdx = indices.front();
79  for (auto memIdx : llvm::enumerate(indices.drop_front())) {
80  Value partialIdx = memIdx.value();
81  int64_t indexMulFactor = 1;
82 
83  // Calculate the product of the i'th index and the [0:i-1] shape dims.
84  for (unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
85  ++i) {
86  int64_t dimSize = memrefType.getShape()[i];
87  indexMulFactor *= dimSize;
88  }
89 
90  // Multiply product by the current index operand.
91  if (llvm::isPowerOf2_64(indexMulFactor)) {
92  auto constant =
93  rewriter
94  .create<arith::ConstantOp>(
95  loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
96  .getResult();
97  finalIdx =
98  rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
99  } else {
100  auto constant = rewriter
101  .create<arith::ConstantOp>(
102  loc, rewriter.getIndexAttr(indexMulFactor))
103  .getResult();
104  finalIdx =
105  rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
106  }
107 
108  // Sum up with the prior lower dimension accessors.
109  auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
110  finalIdx = sumOp.getResult();
111  }
112  return finalIdx;
113 }
114 
115 static bool hasMultiDimMemRef(ValueRange values) {
116  return llvm::any_of(values, [](Value v) {
117  auto memref = dyn_cast<MemRefType>(v.getType());
118  if (!memref)
119  return false;
120  return !isUniDimensional(memref);
121  });
122 }
123 
124 namespace {
125 
126 struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
127  using OpConversionPattern::OpConversionPattern;
128 
129  LogicalResult
130  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
131  ConversionPatternRewriter &rewriter) const override {
132  MemRefType type = op.getMemRefType();
133  if (isUniDimensional(type) || !type.hasStaticShape() ||
134  /*Already converted?*/ op.getIndices().size() == 1)
135  return failure();
136  Value finalIdx =
137  flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
138  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
139 
140  SmallVector<Value>{finalIdx});
141  return success();
142  }
143 };
144 
145 struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
146  using OpConversionPattern::OpConversionPattern;
147 
148  LogicalResult
149  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
150  ConversionPatternRewriter &rewriter) const override {
151  MemRefType type = op.getMemRefType();
152  if (isUniDimensional(type) || !type.hasStaticShape() ||
153  /*Already converted?*/ op.getIndices().size() == 1)
154  return failure();
155  Value finalIdx =
156  flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
157  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
158  adaptor.getMemref(),
159  SmallVector<Value>{finalIdx});
160  return success();
161  }
162 };
163 
164 struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
165  using OpConversionPattern::OpConversionPattern;
166 
167  LogicalResult
168  matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
169  ConversionPatternRewriter &rewriter) const override {
170  MemRefType type = op.getType();
171  if (isUniDimensional(type) || !type.hasStaticShape())
172  return failure();
173  MemRefType newType = getFlattenedMemRefType(type);
174  rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
175  return success();
176  }
177 };
178 
179 struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
180  using OpConversionPattern::OpConversionPattern;
181 
182  LogicalResult
183  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
184  ConversionPatternRewriter &rewriter) const override {
185  MemRefType type = op.getType();
186  if (isUniDimensional(type) || !type.hasStaticShape())
187  return failure();
188  MemRefType newType = getFlattenedMemRefType(type);
189 
190  auto cstAttr =
191  llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
192 
193  SmallVector<Attribute> flattenedVals;
194  for (auto attr : cstAttr.getValues<Attribute>())
195  flattenedVals.push_back(attr);
196 
197  auto newTypeAttr = TypeAttr::get(newType);
198  auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
199  auto newName = rewriter.getStringAttr(newNameStr);
200  globalNameMap[op.getSymNameAttr()] = newName;
201 
202  RankedTensorType tensorType = RankedTensorType::get(
203  {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
204  auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
205 
206  rewriter.replaceOpWithNewOp<memref::GlobalOp>(
207  op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
208  op.getConstantAttr(), op.getAlignmentAttr());
209 
210  return success();
211  }
212 };
213 
214 struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
215  using OpConversionPattern::OpConversionPattern;
216 
217  LogicalResult
218  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
219  ConversionPatternRewriter &rewriter) const override {
220  auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
221  auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
222  SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
223 
224  MemRefType type = globalOp.getType();
225  if (isUniDimensional(type) || !type.hasStaticShape())
226  return failure();
227 
228  MemRefType newType = getFlattenedMemRefType(type);
229  auto originalName = globalOp.getSymNameAttr();
230  auto newNameIt = globalNameMap.find(originalName);
231  if (newNameIt == globalNameMap.end())
232  return failure();
233  auto newName = newNameIt->second;
234 
235  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
236 
237  return success();
238  }
239 };
240 
241 // A generic pattern which will replace an op with a new op of the same type
242 // but using the adaptor (type converted) operands.
243 template <typename TOp>
244 struct OperandConversionPattern : public OpConversionPattern<TOp> {
246  using OpAdaptor = typename TOp::Adaptor;
247  LogicalResult
248  matchAndRewrite(TOp op, OpAdaptor adaptor,
249  ConversionPatternRewriter &rewriter) const override {
250  rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
251  adaptor.getOperands(), op->getAttrs());
252  return success();
253  }
254 };
255 
256 // Cannot use OperandConversionPattern for branch op since the default builder
257 // doesn't provide a method for communicating block successors.
258 struct CondBranchOpConversion
259  : public OpConversionPattern<mlir::cf::CondBranchOp> {
260  using OpConversionPattern::OpConversionPattern;
261 
262  LogicalResult
263  matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
264  ConversionPatternRewriter &rewriter) const override {
265  rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
266  op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
267  adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
268  return success();
269  }
270 };
271 
// Rewrites a func.call so that its operands and results use the type-converted
// (flattened) types. If rewriteFunctions is set, the callee is also rewritten:
// a resolvable callee is replaced by a private func.func carrying the updated
// signature. If the callee cannot be resolved, a new private func.func with
// that signature is created instead. No body is populated for these functions
// here.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
  /// @param rewriteFunctions when true, also replace/declare the callee with
  ///        the converted signature (FlattenMemRefCallsPass passes true).
  CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
                   bool rewriteFunctions = false)
      : OpConversionPattern(typeConverter, context),
        rewriteFunctions(rewriteFunctions) {}

  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the result types; give up if any of them fails to convert.
    llvm::SmallVector<Type> convResTypes;
    if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
      return failure();
    // Build the replacement call from the adaptor's (already converted)
    // operands and the converted result types.
    auto newCallOp = rewriter.create<func::CallOp>(
        op.getLoc(), adaptor.getCallee(), convResTypes, adaptor.getOperands());

    if (!rewriteFunctions) {
      rewriter.replaceOp(op, newCallOp);
      return success();
    }

    // Override any definition corresponding to the updated signature.
    // It is up to users of this pass to define how these rewritten functions
    // are to be implemented.
    // Ops created from here on are inserted before the function containing
    // the call.
    rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
    // May be null when the callee symbol cannot be resolved (checked below).
    auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
    // The new signature mirrors the rewritten call: converted operand types in,
    // converted result types out.
    FunctionType funcType = FunctionType::get(
        op.getContext(), newCallOp.getOperandTypes(), convResTypes);
    func::FuncOp newFuncOp;
    // NOTE(review): when several calls target the same callee, each call runs
    // this replacement again. Confirm the conversion driver tolerates resolving
    // to a callee that is already scheduled for replacement.
    if (calledFunction)
      newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
          calledFunction, op.getCallee(), funcType);
    else
      newFuncOp =
          rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
    newFuncOp.setVisibility(SymbolTable::Visibility::Private);
    rewriter.replaceOp(op, newCallOp);

    return success();
  }

private:
  // Whether matchAndRewrite also replaces/declares the callee.
  bool rewriteFunctions;
};
318 
319 template <typename... TOp>
320 void addGenericLegalityConstraint(ConversionTarget &target) {
321  (target.addDynamicallyLegalOp<TOp>([](TOp op) {
322  return !hasMultiDimMemRef(op->getOperands()) &&
323  !hasMultiDimMemRef(op->getResults());
324  }),
325  ...);
326 }
327 
328 static void populateFlattenMemRefsLegality(ConversionTarget &target) {
329  target.addLegalDialect<arith::ArithDialect>();
330  target.addDynamicallyLegalOp<memref::AllocOp>(
331  [](memref::AllocOp op) { return isUniDimensional(op.getType()); });
332  target.addDynamicallyLegalOp<memref::StoreOp>(
333  [](memref::StoreOp op) { return op.getIndices().size() == 1; });
334  target.addDynamicallyLegalOp<memref::LoadOp>(
335  [](memref::LoadOp op) { return op.getIndices().size() == 1; });
336  target.addDynamicallyLegalOp<memref::GlobalOp>(
337  [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
338  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
339  [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
340  addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
341  func::CallOp, func::ReturnOp, memref::DeallocOp,
342  memref::CopyOp>(target);
343 
344  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
345  auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
346  return hasMultiDimMemRef(block.getArguments());
347  });
348 
349  auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
350  if (auto memref = dyn_cast<MemRefType>(type))
351  return isUniDimensional(memref);
352  return true;
353  });
354 
355  return argsConverted && resultsConverted;
356  });
357 }
358 
359 // Materializes a multidimensional memory to unidimensional memory by using a
360 // memref.subview operation.
361 // TODO: This is also possible for dynamically shaped memories.
362 static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
363  ValueRange inputs, Location loc) {
364  assert(type.hasStaticShape() &&
365  "Can only subview flatten memref's with static shape (for now...).");
366  MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
367  int64_t memSize = sourceType.getNumElements();
368  unsigned dims = sourceType.getShape().size();
369 
370  // Build offset, sizes and strides
371  SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0));
372  SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1));
373  offsets[offsets.size() - 1] = builder.getIndexAttr(memSize);
374  SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));
375 
376  // Generate the appropriate return type:
377  MemRefType outType = MemRefType::get({memSize}, type.getElementType());
378  return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes,
379  offsets, strides);
380 }
381 
382 static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
383  // Add default conversion for all types generically.
384  typeConverter.addConversion([](Type type) { return type; });
385  // Add specific conversion for memref types.
386  typeConverter.addConversion([](MemRefType memref) {
387  if (isUniDimensional(memref))
388  return memref;
389  return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
390  memref.getElementType());
391  });
392 }
393 
394 struct FlattenMemRefPass
395  : public circt::impl::FlattenMemRefBase<FlattenMemRefPass> {
396 public:
397  void runOnOperation() override {
398 
399  auto *ctx = &getContext();
400  TypeConverter typeConverter;
401  populateTypeConversionPatterns(typeConverter);
402 
403  RewritePatternSet patterns(ctx);
404  SetVector<StringRef> rewrittenCallees;
405  patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
406  GlobalOpConversion, GetGlobalOpConversion,
407  OperandConversionPattern<func::ReturnOp>,
408  OperandConversionPattern<memref::DeallocOp>,
409  CondBranchOpConversion,
410  OperandConversionPattern<memref::DeallocOp>,
411  OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
412  typeConverter, ctx);
413  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
414  patterns, typeConverter);
415 
416  ConversionTarget target(*ctx);
417  populateFlattenMemRefsLegality(target);
418 
419  if (applyPartialConversion(getOperation(), target, std::move(patterns))
420  .failed()) {
421  signalPassFailure();
422  return;
423  }
424  }
425 };
426 
427 struct FlattenMemRefCallsPass
428  : public circt::impl::FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
429 public:
430  void runOnOperation() override {
431  auto *ctx = &getContext();
432  TypeConverter typeConverter;
433  populateTypeConversionPatterns(typeConverter);
434  RewritePatternSet patterns(ctx);
435 
436  // Only run conversion on call ops within the body of the function. callee
437  // functions are rewritten by rewriteFunctions=true. We do not use
438  // populateFuncOpTypeConversionPattern to rewrite the function signatures,
439  // since non-called functions should not have their types converted.
440  // It is up to users of this pass to define how these rewritten functions
441  // are to be implemented.
442  patterns.add<CallOpConversion>(typeConverter, ctx,
443  /*rewriteFunctions=*/true);
444 
445  ConversionTarget target(*ctx);
446  target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
447  addGenericLegalityConstraint<func::CallOp>(target);
448  addGenericLegalityConstraint<func::FuncOp>(target);
449 
450  // Add a target materializer to handle memory flattening through
451  // memref.subview operations.
452  typeConverter.addTargetMaterialization(materializeSubViewFlattening);
453 
454  if (applyPartialConversion(getOperation(), target, std::move(patterns))
455  .failed()) {
456  signalPassFailure();
457  return;
458  }
459  }
460 };
461 
462 } // namespace
463 
464 namespace circt {
465 std::unique_ptr<mlir::Pass> createFlattenMemRefPass() {
466  return std::make_unique<FlattenMemRefPass>();
467 }
468 
469 std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass() {
470  return std::make_unique<FlattenMemRefCallsPass>();
471 }
472 
473 } // namespace circt
assert(baseType &&"element must be base type")
static MemRefType getFlattenedMemRefType(MemRefType type)
static std::atomic< unsigned > globalCounter(0)
static bool hasMultiDimMemRef(ValueRange values)
static DenseMap< StringAttr, StringAttr > globalNameMap
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
static std::string getFlattenedMemRefName(StringAttr baseName, MemRefType type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which needs to be rewritten, if they contain memref ar...
FunctionType type
func::FuncOp op