17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
22 #include "mlir/Transforms/DialectConversion.h"
26 #define GEN_PASS_DEF_HANDSHAKELEGALIZEMEMREFS
27 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
31 using namespace circt;
32 using namespace handshake;
// HandshakeLegalizeMemrefsPass: a function-level pass (getOperation() yields a
// func::FuncOp) built on the tablegen-generated HandshakeLegalizeMemrefsBase.
// It rewrites memref operations in the function body into a simpler form:
// memref.dealloc ops are visited first, then each memref.copy is expanded into
// explicit memref.load / memref.store operations (inside an scf loop for the
// multi-element case, or a single load/store at index 0 otherwise).
// NOTE(review): this excerpt is missing several source lines (e.g. the dealloc
// loop body, the multi-dimensional guard, the loop-creation call and closing
// braces); comments below only describe what is visible.
37 struct HandshakeLegalizeMemrefsPass
38 :
public circt::handshake::impl::HandshakeLegalizeMemrefsBase<
39 HandshakeLegalizeMemrefsPass> {
40 void runOnOperation()
override {
41 func::FuncOp op = getOperation();
// Iterate every memref.dealloc in the function. make_early_inc_range advances
// the iterator before the body runs, so the current op may be erased safely.
// The loop body is not visible here; presumably it erases each dealloc — confirm.
// NOTE(review): getOps<T>() visits only ops directly in the function's body
// region, not ops nested inside other ops' regions (e.g. scf.for bodies);
// confirm whether nested dealloc/copy ops need op.walk() instead.
50 llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
// Single builder reused for all copy expansions; its insertion point is
// re-targeted per copy below.
53 auto b = OpBuilder(op);
// Expand each memref.copy. Early-inc iteration again allows erasing the copy
// after it has been replaced.
57 for (
auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
// New ops are inserted immediately before the copy being expanded.
58 b.setInsertionPoint(copy);
59 auto loc = copy.getLoc();
60 auto src = copy.getSource();
61 auto dst = copy.getTarget();
// The expansion is driven by the source memref's type.
62 auto memrefType = cast<MemRefType>(src.getType());
// Diagnostic for memrefs that are not one-dimensional; the guarding condition
// (presumably an isUniDimensional check) is on a line not visible here.
64 llvm::errs() <<
"Cannot legalize multi-dimensional memref operation "
66 <<
". Please run the memref flattening pass before this "
// Emits one element transfer: load src[index] and store that value into
// dst[index]. Captures the outer builder `b` by reference.
72 auto emitLoadStore = [&](Value index) {
73 llvm::SmallVector<Value> indices = {index};
74 auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
75 b.create<memref::StoreOp>(loc, loadValue, dst, indices);
// Number of elements along the (single) dimension.
// NOTE(review): for a dynamically-sized memref getShape()[0] is the
// ShapedType::kDynamic sentinel rather than a real extent; the check applied to
// `n` is not visible here — confirm dynamic shapes are rejected or handled.
78 auto n = memrefType.getShape()[0];
// Index constants used as loop bounds: [0, n) with step 1.
81 auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
82 auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
83 auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
// Arguments to the loop construction (the create call itself is on a line not
// visible here; scf.yield below suggests an scf.for). No loop-carried values.
86 loc, lb, ub, step, llvm::SmallVector<Value>(),
// Loop body builder. Its `b` parameter shadows the outer builder `b`, while
// emitLoadStore still uses the captured outer one.
// NOTE(review): confirm the body's load/store land inside the loop body.
87 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
// Terminate the loop body.
89 b.create<scf::YieldOp>(loc);
// Single-element path: one load/store at constant index 0 (presumably the
// branch taken when no loop is needed; the condition is not visible here).
92 emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
// Factory for the pass: returns a freshly allocated HandshakeLegalizeMemrefsPass
// as an owning std::unique_ptr<mlir::Pass>.
// NOTE(review): the function's name/opening-brace line and closing brace are not
// part of this excerpt; the last line below appears to be its signature residue.
100 std::unique_ptr<mlir::Pass>
102 return std::make_unique<HandshakeLegalizeMemrefsPass>();
std::unique_ptr< mlir::Pass > createHandshakeLegalizeMemrefsPass()
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
// Declaration only; the definition is not part of this excerpt.
// NOTE(review): judging by its name, presumably returns true when `memref` has
// exactly one dimension — confirm against the definition before relying on it.
bool isUniDimensional(mlir::MemRefType memref)