18#include "mlir/IR/ImplicitLocOpBuilder.h"
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
24#define DEBUG_TYPE "reg-of-vec-to-mem"
32#define GEN_PASS_DEF_REGOFVECTOMEM
33#include "circt/Dialect/Seq/SeqPasses.h.inc"
51 hw::ArrayInjectOp writeAccess;
/// Pass that looks for array-typed `FirRegOp` registers used as a
/// read/modify/write store and rewrites them into a `seq::FirMemOp` with a
/// read port and a write port.
/// NOTE(review): this excerpt of the class body is missing lines (including
/// the closing `};`); confirm the member list against the full file.
54class RegOfVecToMemPass :
public impl::RegOfVecToMemBase<RegOfVecToMemPass> {
56 void runOnOperation()
override;
/// Tries to match `reg` against the register-as-memory idiom, filling in
/// `pattern`; returns true on a successful match.
59 bool analyzeMemoryPattern(FirRegOp
reg, MemoryPattern &
pattern);
/// Emits the seq::FirMemOp and its ports for a matched `pattern` and queues
/// the replaced operations in `opsToErase`.
60 bool createFirMemory(MemoryPattern &
pattern);
/// True for hw::ArrayType and hw::UnpackedArrayType.
61 bool isArrayType(Type type);
/// {element count, element bit width} for hw::ArrayType of integers.
62 std::optional<std::pair<uint64_t, uint64_t>> getArrayDimensions(Type type);
/// Operations made dead by createFirMemory. They are collected here instead
/// of being erased immediately; runOnOperation iterates this list at the end.
64 SmallVector<Operation *> opsToErase;
/// Returns true if `type` is either a packed `hw::ArrayType` or an
/// `hw::UnpackedArrayType`; only registers of these types are considered by
/// the pass.
/// NOTE(review): getArrayDimensions() only handles hw::ArrayType on its
/// visible path, so an unpacked-array register passes this filter but may
/// not yield dimensions later — confirm this is intended.
69bool RegOfVecToMemPass::isArrayType(Type type) {
70 return isa<hw::ArrayType, hw::UnpackedArrayType>(type);
/// Returns {number of elements, element bit width} when `type` is an
/// `hw::ArrayType` whose element type is a builtin `IntegerType`.
/// NOTE(review): the remaining return path(s) are not visible in this
/// excerpt; presumably std::nullopt is returned for every other type
/// (including hw::UnpackedArrayType and non-integer element types) — confirm.
73std::optional<std::pair<uint64_t, uint64_t>>
74RegOfVecToMemPass::getArrayDimensions(Type type) {
75 if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
76 auto elemType = arrayType.getElementType();
77 if (
auto intType = dyn_cast<IntegerType>(elemType)) {
78 return std::make_pair(arrayType.getNumElements(), intType.getWidth());
/// Tries to match `reg` against a register-of-vector memory idiom and records
/// the pieces in `pattern`.
///
/// On the visible path the expected shape is:
///   - `reg` has an array type (isArrayType);
///   - among its users there is an hw::ArrayGetOp (read), an
///     hw::ArrayInjectOp (write) and a comb::MuxOp;
///   - the mux has exactly one use, its false operand is compared against
///     `reg`'s result, and its true operand is looked up as an
///     hw::ArrayInjectOp.
/// The write address/data come from that array_inject, the write enable from
/// the mux condition, and the read address from the array_get.
///
/// NOTE(review): this excerpt is missing lines (early returns, the
/// declarations of readAccess/writeMux/mux, and the function tail), so the
/// exact failure paths cannot be confirmed here.
84bool RegOfVecToMemPass::analyzeMemoryPattern(FirRegOp
reg,
86 LLVM_DEBUG(llvm::dbgs() <<
"Analyzing register: " <<
reg <<
"\n");
89 if (!isArrayType(
reg.getType()))
// NOTE(review): spelled without the hw:: qualifier here, unlike
// hw::ArrayInjectOp elsewhere in this file; this only compiles if the name is
// otherwise in scope.
93 ArrayInjectOp writeAccess;
// Classify the register's users. The !readAccess / !writeAccess / !writeMux
// guards keep only the FIRST user of each kind; further users of the same
// kind are not recorded.
// NOTE(review): createFirMemory only redirects pattern.readAccess before the
// register is queued for erasure, so confirm the elided lines reject
// registers that have additional users.
95 for (
auto *user :
reg.getResult().getUsers()) {
96 LLVM_DEBUG(llvm::dbgs() <<
" Register user: " << *user <<
"\n");
97 if (
auto arrayGet = dyn_cast<hw::ArrayGetOp>(user); !readAccess && arrayGet)
98 readAccess = arrayGet;
99 else if (
auto arrayInject = dyn_cast<hw::ArrayInjectOp>(user);
100 !writeAccess && arrayInject)
101 writeAccess = arrayInject;
102 else if (
auto mux = dyn_cast<comb::MuxOp>(user); !writeMux && mux)
// All three user kinds are required; the failure branch is not visible here.
107 if (!readAccess || !writeAccess || !writeMux)
// The register's next value; the lines that derive `mux` from it are not
// visible in this excerpt.
114 auto nextValue =
reg.getNext();
119 LLVM_DEBUG(llvm::dbgs() <<
" Found driving mux: " << mux <<
"\n");
// Reject muxes with other users, presumably because the mux is removed by
// the rewrite.
123 if (!mux.getResult().hasOneUse()) {
124 LLVM_DEBUG(llvm::dbgs() <<
" Mux has multiple uses, cannot transform\n");
// Expected mux shape: mux(cond, <array_inject result>, reg).
129 Value writeResult = mux.getTrueValue();
130 Value currentMemory = mux.getFalseValue();
133 if (currentMemory !=
reg.getResult())
137 auto arrayInject = writeResult.getDefiningOp<hw::ArrayInjectOp>();
141 LLVM_DEBUG(llvm::dbgs() <<
" Found array_inject: " << arrayInject <<
"\n");
// Record the write port: address and data from the array_inject, enable
// from the mux condition.
142 pattern.writeAccess = arrayInject;
143 pattern.writeAddr = arrayInject.getIndex();
144 pattern.writeData = arrayInject.getElement();
145 pattern.writeEnable = mux.getCond();
// The read port is the array_get found in the user loop above.
148 auto arrayGet = readAccess;
149 LLVM_DEBUG(llvm::dbgs() <<
" Found array_get: " << arrayGet <<
"\n");
151 pattern.readAddr = arrayGet.getIndex();
// Look for a FirRegOp that consumes the read data on pattern.clock.
// NOTE(review): pattern.clock is assigned on lines not visible in this
// excerpt — confirm it is set before this comparison.
154 for (
auto *readUser : arrayGet.getResult().getUsers()) {
155 if (
auto outputReg = dyn_cast<FirRegOp>(readUser)) {
156 if (outputReg.getClk() ==
pattern.clock) {
157 LLVM_DEBUG(llvm::dbgs()
158 <<
" Found output register: " << outputReg <<
"\n");
// NOTE(review): on the visible lines, success depends only on
// pattern.readAccess being non-null; whether an output register was found
// does not affect it, yet createFirMemory dereferences pattern.outputReg —
// confirm the elided lines guard this.
165 bool success =
pattern.readAccess !=
nullptr;
166 LLVM_DEBUG(llvm::dbgs() <<
" Pattern analysis "
167 << (success ?
"succeeded" :
"failed") <<
"\n");
/// Materializes a matched `pattern` as a seq::FirMemOp named "mem", with one
/// read port and one write port, both clocked by `pattern.clock`.
/// It then redirects consumers of the old read value to the new read data
/// and queues the replaced operations in `opsToErase`.
///
/// NOTE(review): lines are missing from this excerpt (the dims check, the
/// head of the memType declaration, remaining builder arguments, and the
/// return statement).
171bool RegOfVecToMemPass::createFirMemory(MemoryPattern &
pattern) {
172 LLVM_DEBUG(llvm::dbgs() <<
"Creating FirMemory for pattern\n");
// NOTE(review): `dims` is dereferenced below; the std::nullopt check is
// presumably on the elided lines — confirm, since getArrayDimensions only
// handles hw::ArrayType of integers on its visible path.
174 auto dims = getArrayDimensions(
pattern.memReg.getType());
178 uint64_t depth = dims->first;
179 uint64_t width = dims->second;
181 LLVM_DEBUG(llvm::dbgs() <<
" Memory dimensions: " << depth <<
" x " << width
// New operations are built at the register's location, inserted before the
// register.
184 ImplicitLocOpBuilder builder(
pattern.memReg.getLoc(),
pattern.memReg);
// Memory type: `depth` words of `width` bits; the trailing 1 is presumably
// the mask width — confirm against FirMemType::get.
188 FirMemType::get(builder.getContext(), depth, width, 1);
// The 0 and 1 arguments are presumably the read and write latencies —
// confirm against the seq::FirMemOp builder signature.
189 auto firMem = seq::FirMemOp::create(
190 builder, memType, 0, 1,
193 builder.getStringAttr(
"mem"), hw::InnerSymAttr{},
194 seq::FirMemInitAttr{}, StringAttr{},
// Address operands of bit width 0 are replaced by a 1-bit constant 0
// (presumably these arise from single-element arrays).
200 auto fixZeroWidthAddr = [&](Value
addr) -> Value {
201 if (
addr.getType().getIntOrFloatBitWidth() == 0) {
203 mlir::IntegerType::get(&getContext(), 1), 0)
// Read port on pattern.clock at the (possibly widened) read address.
210 auto readAddr = fixZeroWidthAddr(
pattern.readAddr);
211 Value readData = FirMemReadOp::create(
212 builder, firMem, readAddr,
pattern.clock,
215 LLVM_DEBUG(llvm::dbgs() <<
" Created read port\n"
216 << firMem <<
"\n " << readData);
// Write port on pattern.clock; its remaining operands are on lines not
// visible here.
220 auto writeAddr = fixZeroWidthAddr(
pattern.writeAddr);
221 FirMemWriteOp::create(builder, firMem, writeAddr,
pattern.clock,
224 LLVM_DEBUG(llvm::dbgs() <<
" Created write port\n");
// Redirect consumers of the output register's next value and of the old
// array_get result to the memory read data.
229 pattern.outputReg.getNext().replaceAllUsesWith(readData);
232 pattern.readAccess.getResult().replaceAllUsesWith(readData);
// Erasure is deferred: the register and its read/write/mux ops are queued and
// handled at the end of runOnOperation.
// NOTE(review): the mux's false operand is the register (checked in
// analyzeMemoryPattern), so these ops still reference each other. Whatever
// erases them must drop those references first.
235 opsToErase.push_back(
pattern.memReg);
237 opsToErase.push_back(
pattern.readAccess);
239 opsToErase.push_back(
pattern.writeAccess);
241 opsToErase.push_back(
pattern.writeMux);
/// Pass entry point.
/// - Collects every array-typed FirRegOp under the current operation.
/// - Processes each collected register.
/// - Walks `opsToErase`.
/// NOTE(review): the per-register body and the erase loop are only partially
/// visible in this excerpt. Presumably each register goes through
/// analyzeMemoryPattern() and, on success, createFirMemory() — confirm.
246void RegOfVecToMemPass::runOnOperation() {
247 auto module = getOperation();
// Candidates are gathered before any rewriting (presumably so the IR is not
// mutated while it is being walked).
249 SmallVector<FirRegOp> arrayRegs;
252 module.walk([&](FirRegOp reg) {
253 if (isArrayType(reg.getType())) {
254 arrayRegs.push_back(reg);
259 for (
auto reg : arrayRegs) {
// Process the operations queued by createFirMemory; each one is logged with
// its remaining use count (the erase itself is on lines not visible here).
267 for (
auto *op : opsToErase) {
268 LLVM_DEBUG(llvm::dbgs()
269 <<
"Erasing operation: " << *op <<
" number of uses:"
RewritePatternSet pattern
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)