CIRCT 23.0.0git
Loading...
Searching...
No Matches
RegOfVecToMem.cpp
Go to the documentation of this file.
1//===- RegOfVecToMem.cpp - Convert Register Arrays to Memories -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass converts register arrays that follow memory access
10// patterns to seq.firmem operations.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <iterator>
23
24#define DEBUG_TYPE "reg-of-vec-to-mem"
25
26using namespace circt;
27using namespace seq;
28using namespace hw;
29
30namespace circt {
31namespace seq {
32#define GEN_PASS_DEF_REGOFVECTOMEM
33#include "circt/Dialect/Seq/SeqPasses.h.inc"
34} // namespace seq
35} // namespace circt
36
37namespace {
38
39struct MemoryPattern {
40 FirRegOp memReg; // The register array representing memory
41 FirRegOp outputReg; // Optional output register
42 Value clock; // Clock signal
43 Value readAddr; // Read address
44 Value writeAddr; // Write address
45 Value writeData; // Write data
46 Value writeEnable; // Write enable
47 Value readEnable; // Read enable (optional)
48 comb::MuxOp writeMux; // Mux selecting between old/new memory state
49 comb::MuxOp readMux; // Mux for read data
50 hw::ArrayGetOp readAccess; // Array read operation
51 hw::ArrayInjectOp writeAccess; // Array write operation
52};
53
54class RegOfVecToMemPass : public impl::RegOfVecToMemBase<RegOfVecToMemPass> {
55public:
56 void runOnOperation() override;
57
58private:
59 bool analyzeMemoryPattern(FirRegOp reg, MemoryPattern &pattern);
60 bool createFirMemory(MemoryPattern &pattern);
61 bool isArrayType(Type type);
62 std::optional<std::pair<uint64_t, uint64_t>> getArrayDimensions(Type type);
63
64 SmallVector<Operation *> opsToErase;
65};
66
67} // end anonymous namespace
68
69bool RegOfVecToMemPass::isArrayType(Type type) {
70 return isa<hw::ArrayType, hw::UnpackedArrayType>(type);
71}
72
73std::optional<std::pair<uint64_t, uint64_t>>
74RegOfVecToMemPass::getArrayDimensions(Type type) {
75 if (auto arrayType = dyn_cast<hw::ArrayType>(type)) {
76 auto elemType = arrayType.getElementType();
77 if (auto intType = dyn_cast<IntegerType>(elemType)) {
78 return std::make_pair(arrayType.getNumElements(), intType.getWidth());
79 }
80 }
81 return std::nullopt;
82}
83
84bool RegOfVecToMemPass::analyzeMemoryPattern(FirRegOp reg,
85 MemoryPattern &pattern) {
86 LLVM_DEBUG(llvm::dbgs() << "Analyzing register: " << reg << "\n");
87
88 // Check if register has array type
89 if (!isArrayType(reg.getType()))
90 return false;
91
92 ArrayGetOp readAccess;
93 ArrayInjectOp writeAccess;
94 comb::MuxOp writeMux;
95 for (auto *user : reg.getResult().getUsers()) {
96 LLVM_DEBUG(llvm::dbgs() << " Register user: " << *user << "\n");
97 if (auto arrayGet = dyn_cast<hw::ArrayGetOp>(user); !readAccess && arrayGet)
98 readAccess = arrayGet;
99 else if (auto arrayInject = dyn_cast<hw::ArrayInjectOp>(user);
100 !writeAccess && arrayInject)
101 writeAccess = arrayInject;
102 else if (auto mux = dyn_cast<comb::MuxOp>(user); !writeMux && mux)
103 writeMux = mux;
104 else
105 return false;
106 }
107 if (!readAccess || !writeAccess || !writeMux)
108 return false;
109
110 pattern.memReg = reg;
111 pattern.clock = reg.getClk();
112
113 // Find the mux that drives this register
114 auto nextValue = reg.getNext();
115 auto mux = nextValue.getDefiningOp<comb::MuxOp>();
116 if (!mux)
117 return false;
118
119 LLVM_DEBUG(llvm::dbgs() << " Found driving mux: " << mux << "\n");
120 pattern.writeMux = mux;
121
122 // Check that the mux is only used by this register (safety check)
123 if (!mux.getResult().hasOneUse()) {
124 LLVM_DEBUG(llvm::dbgs() << " Mux has multiple uses, cannot transform\n");
125 return false;
126 }
127
128 // Analyze mux inputs: sel ? write_result : current_memory
129 Value writeResult = mux.getTrueValue();
130 Value currentMemory = mux.getFalseValue();
131
132 // Check if false value is the current register (feedback)
133 if (currentMemory != reg.getResult())
134 return false;
135
136 // Look for array_inject operation in write path
137 auto arrayInject = writeResult.getDefiningOp<hw::ArrayInjectOp>();
138 if (!arrayInject)
139 return false;
140
141 LLVM_DEBUG(llvm::dbgs() << " Found array_inject: " << arrayInject << "\n");
142 pattern.writeAccess = arrayInject;
143 pattern.writeAddr = arrayInject.getIndex();
144 pattern.writeData = arrayInject.getElement();
145 pattern.writeEnable = mux.getCond();
146
147 // Look for read pattern - find array_get users
148 auto arrayGet = readAccess;
149 LLVM_DEBUG(llvm::dbgs() << " Found array_get: " << arrayGet << "\n");
150 pattern.readAccess = arrayGet;
151 pattern.readAddr = arrayGet.getIndex();
152
153 // Check if read goes through output register
154 for (auto *readUser : arrayGet.getResult().getUsers()) {
155 if (auto outputReg = dyn_cast<FirRegOp>(readUser)) {
156 if (outputReg.getClk() == pattern.clock) {
157 LLVM_DEBUG(llvm::dbgs()
158 << " Found output register: " << outputReg << "\n");
159 pattern.outputReg = outputReg;
160 break;
161 }
162 }
163 }
164
165 bool success = pattern.readAccess != nullptr;
166 LLVM_DEBUG(llvm::dbgs() << " Pattern analysis "
167 << (success ? "succeeded" : "failed") << "\n");
168 return success;
169}
170
171bool RegOfVecToMemPass::createFirMemory(MemoryPattern &pattern) {
172 LLVM_DEBUG(llvm::dbgs() << "Creating FirMemory for pattern\n");
173
174 auto dims = getArrayDimensions(pattern.memReg.getType());
175 if (!dims)
176 return false;
177
178 uint64_t depth = dims->first;
179 uint64_t width = dims->second;
180
181 LLVM_DEBUG(llvm::dbgs() << " Memory dimensions: " << depth << " x " << width
182 << "\n");
183
184 ImplicitLocOpBuilder builder(pattern.memReg.getLoc(), pattern.memReg);
185
186 // Create FirMem
187 auto memType =
188 FirMemType::get(builder.getContext(), depth, width, /*maskWidth=*/1);
189 auto firMem = seq::FirMemOp::create(
190 builder, memType, /*readLatency=*/0, /*writeLatency=*/1,
191 /*readUnderWrite=*/seq::RUW::Undefined,
192 /*writeUnderWrite=*/seq::WUW::Undefined,
193 /*name=*/builder.getStringAttr("mem"), /*innerSym=*/hw::InnerSymAttr{},
194 /*init=*/seq::FirMemInitAttr{}, /*prefix=*/StringAttr{},
195 /*outputFile=*/Attribute{});
196
197 // FIRRTL currently uses a 1-bit address for a single element memory,
198 // however HW arrays use 0-bit addresses. To bridge this gap, create a 1-bit
199 // address equal to 0 if our address is 0-bit.
200 auto fixZeroWidthAddr = [&](Value addr) -> Value {
201 if (addr.getType().getIntOrFloatBitWidth() == 0) {
202 return hw::ConstantOp::create(builder,
203 mlir::IntegerType::get(&getContext(), 1), 0)
204 .getResult();
205 }
206 return addr;
207 };
208
209 // Create read port
210 auto readAddr = fixZeroWidthAddr(pattern.readAddr);
211 Value readData = FirMemReadOp::create(
212 builder, firMem, readAddr, pattern.clock,
213 /*enable=*/hw::ConstantOp::create(builder, builder.getI1Type(), 1));
214
215 LLVM_DEBUG(llvm::dbgs() << " Created read port\n"
216 << firMem << "\n " << readData);
217
218 Value mask;
219 // Create write port
220 auto writeAddr = fixZeroWidthAddr(pattern.writeAddr);
221 FirMemWriteOp::create(builder, firMem, writeAddr, pattern.clock,
222 pattern.writeEnable, pattern.writeData, mask);
223
224 LLVM_DEBUG(llvm::dbgs() << " Created write port\n");
225
226 // Replace read access
227 if (pattern.outputReg)
228 // If there's an output register, replace its input
229 pattern.outputReg.getNext().replaceAllUsesWith(readData);
230 else
231 // Replace direct read access
232 pattern.readAccess.getResult().replaceAllUsesWith(readData);
233
234 // Mark old operations for removal
235 opsToErase.push_back(pattern.memReg);
236 if (pattern.readAccess)
237 opsToErase.push_back(pattern.readAccess);
238 if (pattern.writeAccess)
239 opsToErase.push_back(pattern.writeAccess);
240 if (pattern.writeMux)
241 opsToErase.push_back(pattern.writeMux);
242
243 return true;
244}
245
246void RegOfVecToMemPass::runOnOperation() {
247 auto module = getOperation();
248
249 SmallVector<FirRegOp> arrayRegs;
250
251 // Collect all FirRegOp with array types
252 module.walk([&](FirRegOp reg) {
253 if (isArrayType(reg.getType())) {
254 arrayRegs.push_back(reg);
255 }
256 });
257
258 // Analyze each array register for memory patterns
259 for (auto reg : arrayRegs) {
260 MemoryPattern pattern;
261 if (analyzeMemoryPattern(reg, pattern)) {
262 createFirMemory(pattern);
263 }
264 }
265
266 // Erase all marked operations
267 for (auto *op : opsToErase) {
268 LLVM_DEBUG(llvm::dbgs()
269 << "Erasing operation: " << *op << " number of uses:"
270 << "\n");
271 op->dropAllUses();
272 op->erase();
273 }
274 opsToErase.clear();
275}
RewritePatternSet pattern
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21