CIRCT 21.0.0git
Loading...
Searching...
No Matches
LegalizeMemrefs.cpp
Go to the documentation of this file.
1//===- LegalizeMemrefs.cpp - handshake memref legalization pass -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the memref legalization pass.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Rewrite/FrozenRewritePatternSet.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24namespace circt {
25namespace handshake {
26#define GEN_PASS_DEF_HANDSHAKELEGALIZEMEMREFS
27#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
28} // namespace handshake
29} // namespace circt
30
31using namespace circt;
32using namespace handshake;
33using namespace mlir;
34
35namespace {
36
37struct HandshakeLegalizeMemrefsPass
38 : public circt::handshake::impl::HandshakeLegalizeMemrefsBase<
39 HandshakeLegalizeMemrefsPass> {
40 void runOnOperation() override {
41 func::FuncOp op = getOperation();
42 if (op.isExternal())
43 return;
44
45 // Erase all memref.dealloc operations - this implies that we consider all
46 // memref.alloc's in the IR to be "static", in the C sense. It is then up to
47 // callers of the handshake module to determine whether a call to said
48 // module implies a _call_ (shared semantics) or an _instance_.
49 for (auto dealloc :
50 llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
51 dealloc.erase();
52
53 auto b = OpBuilder(op);
54
55 // Convert any memref.copy to explicit store operations (scf loop in case of
56 // an array).
57 for (auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
58 b.setInsertionPoint(copy);
59 auto loc = copy.getLoc();
60 auto src = copy.getSource();
61 auto dst = copy.getTarget();
62 auto memrefType = cast<MemRefType>(src.getType());
63 if (!isUniDimensional(memrefType)) {
64 llvm::errs() << "Cannot legalize multi-dimensional memref operation "
65 << copy
66 << ". Please run the memref flattening pass before this "
67 "pass.";
68 signalPassFailure();
69 return;
70 }
71
72 auto emitLoadStore = [&](Value index) {
73 llvm::SmallVector<Value> indices = {index};
74 auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
75 b.create<memref::StoreOp>(loc, loadValue, dst, indices);
76 };
77
78 auto n = memrefType.getShape()[0];
79
80 if (n > 1) {
81 auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
82 auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
83 auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
84
85 b.create<scf::ForOp>(
86 loc, lb, ub, step, llvm::SmallVector<Value>(),
87 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
88 emitLoadStore(iv);
89 b.create<scf::YieldOp>(loc);
90 });
91 } else
92 emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
93
94 copy.erase();
95 }
96 };
97};
98} // namespace
99
100std::unique_ptr<mlir::Pass>
102 return std::make_unique<HandshakeLegalizeMemrefsPass>();
103}
std::unique_ptr< mlir::Pass > createHandshakeLegalizeMemrefsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
bool isUniDimensional(mlir::MemRefType memref)