CIRCT  19.0.0git
LegalizeMemrefs.cpp
Go to the documentation of this file.
1 //===- LegalizeMemrefs.cpp - handshake memref legalization pass -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the memref legalization pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace circt;
25 using namespace handshake;
26 using namespace mlir;
27 
28 namespace {
29 
30 struct HandshakeLegalizeMemrefsPass
31  : public HandshakeLegalizeMemrefsBase<HandshakeLegalizeMemrefsPass> {
32  void runOnOperation() override {
33  func::FuncOp op = getOperation();
34  if (op.isExternal())
35  return;
36 
37  // Erase all memref.dealloc operations - this implies that we consider all
38  // memref.alloc's in the IR to be "static", in the C sense. It is then up to
39  // callers of the handshake module to determine whether a call to said
40  // module implies a _call_ (shared semantics) or an _instance_.
41  for (auto dealloc :
42  llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
43  dealloc.erase();
44 
45  auto b = OpBuilder(op);
46 
47  // Convert any memref.copy to explicit store operations (scf loop in case of
48  // an array).
49  for (auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
50  b.setInsertionPoint(copy);
51  auto loc = copy.getLoc();
52  auto src = copy.getSource();
53  auto dst = copy.getTarget();
54  auto memrefType = src.getType().cast<MemRefType>();
55  if (!isUniDimensional(memrefType)) {
56  llvm::errs() << "Cannot legalize multi-dimensional memref operation "
57  << copy
58  << ". Please run the memref flattening pass before this "
59  "pass.";
60  signalPassFailure();
61  return;
62  }
63 
64  auto emitLoadStore = [&](Value index) {
65  llvm::SmallVector<Value> indices = {index};
66  auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
67  b.create<memref::StoreOp>(loc, loadValue, dst, indices);
68  };
69 
70  auto n = memrefType.getShape()[0];
71 
72  if (n > 1) {
73  auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
74  auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
75  auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
76 
77  b.create<scf::ForOp>(
78  loc, lb, ub, step, llvm::SmallVector<Value>(),
79  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
80  emitLoadStore(iv);
81  b.create<scf::YieldOp>(loc);
82  });
83  } else
84  emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
85 
86  copy.erase();
87  }
88  };
89 };
90 } // namespace
91 
92 std::unique_ptr<mlir::Pass>
94  return std::make_unique<HandshakeLegalizeMemrefsPass>();
95 }
std::unique_ptr< mlir::Pass > createHandshakeLegalizeMemrefsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
bool isUniDimensional(mlir::MemRefType memref)