CIRCT  19.0.0git
LegalizeMemrefs.cpp
Go to the documentation of this file.
1 //===- LegalizeMemrefs.cpp - handshake memref legalization pass -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the memref legalization pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 namespace circt {
25 namespace handshake {
26 #define GEN_PASS_DEF_HANDSHAKELEGALIZEMEMREFS
27 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
28 } // namespace handshake
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace handshake;
33 using namespace mlir;
34 
35 namespace {
36 
37 struct HandshakeLegalizeMemrefsPass
38  : public circt::handshake::impl::HandshakeLegalizeMemrefsBase<
39  HandshakeLegalizeMemrefsPass> {
40  void runOnOperation() override {
41  func::FuncOp op = getOperation();
42  if (op.isExternal())
43  return;
44 
45  // Erase all memref.dealloc operations - this implies that we consider all
46  // memref.alloc's in the IR to be "static", in the C sense. It is then up to
47  // callers of the handshake module to determine whether a call to said
48  // module implies a _call_ (shared semantics) or an _instance_.
49  for (auto dealloc :
50  llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
51  dealloc.erase();
52 
53  auto b = OpBuilder(op);
54 
55  // Convert any memref.copy to explicit store operations (scf loop in case of
56  // an array).
57  for (auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
58  b.setInsertionPoint(copy);
59  auto loc = copy.getLoc();
60  auto src = copy.getSource();
61  auto dst = copy.getTarget();
62  auto memrefType = cast<MemRefType>(src.getType());
63  if (!isUniDimensional(memrefType)) {
64  llvm::errs() << "Cannot legalize multi-dimensional memref operation "
65  << copy
66  << ". Please run the memref flattening pass before this "
67  "pass.";
68  signalPassFailure();
69  return;
70  }
71 
72  auto emitLoadStore = [&](Value index) {
73  llvm::SmallVector<Value> indices = {index};
74  auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
75  b.create<memref::StoreOp>(loc, loadValue, dst, indices);
76  };
77 
78  auto n = memrefType.getShape()[0];
79 
80  if (n > 1) {
81  auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
82  auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
83  auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
84 
85  b.create<scf::ForOp>(
86  loc, lb, ub, step, llvm::SmallVector<Value>(),
87  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
88  emitLoadStore(iv);
89  b.create<scf::YieldOp>(loc);
90  });
91  } else
92  emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
93 
94  copy.erase();
95  }
96  };
97 };
98 } // namespace
99 
100 std::unique_ptr<mlir::Pass>
102  return std::make_unique<HandshakeLegalizeMemrefsPass>();
103 }
std::unique_ptr< mlir::Pass > createHandshakeLegalizeMemrefsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
bool isUniDimensional(mlir::MemRefType memref)