CIRCT 22.0.0git
Loading...
Searching...
No Matches
RegOfVecToMem.cpp
Go to the documentation of this file.
1//===- RegOfVecToMem.cpp - Convert Register Arrays to Memories -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass converts register arrays that follow memory access
10// patterns to seq.firmem operations.
11//
12//===----------------------------------------------------------------------===//
13
18#include "mlir/IR/ImplicitLocOpBuilder.h"
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
23
24#define DEBUG_TYPE "reg-of-vec-to-mem"
25
26using namespace circt;
27using namespace seq;
28using namespace hw;
29
30namespace circt {
31namespace seq {
32#define GEN_PASS_DEF_REGOFVECTOMEM
33#include "circt/Dialect/Seq/SeqPasses.h.inc"
34} // namespace seq
35} // namespace circt
36
37namespace {
38
39struct MemoryPattern {
40 FirRegOp memReg; // The register array representing memory
41 FirRegOp outputReg; // Optional output register
42 Value clock; // Clock signal
43 Value readAddr; // Read address
44 Value writeAddr; // Write address
45 Value writeData; // Write data
46 Value writeEnable; // Write enable
47 Value readEnable; // Read enable (optional)
48 comb::MuxOp writeMux; // Mux selecting between old/new memory state
49 comb::MuxOp readMux; // Mux for read data
50 hw::ArrayGetOp readAccess; // Array read operation
51 hw::ArrayInjectOp writeAccess; // Array write operation
52};
53
54class RegOfVecToMemPass : public impl::RegOfVecToMemBase<RegOfVecToMemPass> {
55public:
56 void runOnOperation() override;
57
58private:
59 bool analyzeMemoryPattern(FirRegOp reg, MemoryPattern &pattern);
60 bool createFirMemory(MemoryPattern &pattern);
61 bool isArrayType(Type type);
62 std::optional<std::pair<uint64_t, uint64_t>> getArrayDimensions(Type type);
63
64 SmallVector<Operation *> opsToErase;
65};
66
67} // end anonymous namespace
68
69bool RegOfVecToMemPass::isArrayType(Type type) {
70 return isa<hw::ArrayType, hw::UnpackedArrayType>(type);
71}
72
73std::optional<std::pair<uint64_t, uint64_t>>
74RegOfVecToMemPass::getArrayDimensions(Type type) {
75 if (auto arrayType = dyn_cast<hw::ArrayType>(type)) {
76 auto elemType = arrayType.getElementType();
77 if (auto intType = dyn_cast<IntegerType>(elemType)) {
78 return std::make_pair(arrayType.getNumElements(), intType.getWidth());
79 }
80 }
81 return std::nullopt;
82}
83
84bool RegOfVecToMemPass::analyzeMemoryPattern(FirRegOp reg,
85 MemoryPattern &pattern) {
86 LLVM_DEBUG(llvm::dbgs() << "Analyzing register: " << reg << "\n");
87
88 // Check if register has array type
89 if (!isArrayType(reg.getType()))
90 return false;
91
92 ArrayGetOp readAccess;
93 ArrayInjectOp writeAccess;
94 comb::MuxOp writeMux;
95 for (auto *user : reg.getResult().getUsers()) {
96 LLVM_DEBUG(llvm::dbgs() << " Register user: " << *user << "\n");
97 if (auto arrayGet = dyn_cast<hw::ArrayGetOp>(user); !readAccess && arrayGet)
98 readAccess = arrayGet;
99 else if (auto arrayInject = dyn_cast<hw::ArrayInjectOp>(user);
100 !writeAccess && arrayInject)
101 writeAccess = arrayInject;
102 else if (auto mux = dyn_cast<comb::MuxOp>(user); !writeMux && mux)
103 writeMux = mux;
104 else
105 return false;
106 }
107 if (!readAccess || !writeAccess || !writeMux)
108 return false;
109
110 pattern.memReg = reg;
111 pattern.clock = reg.getClk();
112
113 // Find the mux that drives this register
114 auto nextValue = reg.getNext();
115 auto mux = nextValue.getDefiningOp<comb::MuxOp>();
116 if (!mux)
117 return false;
118
119 LLVM_DEBUG(llvm::dbgs() << " Found driving mux: " << mux << "\n");
120 pattern.writeMux = mux;
121
122 // Check that the mux is only used by this register (safety check)
123 if (!mux.getResult().hasOneUse()) {
124 LLVM_DEBUG(llvm::dbgs() << " Mux has multiple uses, cannot transform\n");
125 return false;
126 }
127
128 // Analyze mux inputs: sel ? write_result : current_memory
129 Value writeResult = mux.getTrueValue();
130 Value currentMemory = mux.getFalseValue();
131
132 // Check if false value is the current register (feedback)
133 if (currentMemory != reg.getResult())
134 return false;
135
136 // Look for array_inject operation in write path
137 auto arrayInject = writeResult.getDefiningOp<hw::ArrayInjectOp>();
138 if (!arrayInject)
139 return false;
140
141 LLVM_DEBUG(llvm::dbgs() << " Found array_inject: " << arrayInject << "\n");
142 pattern.writeAccess = arrayInject;
143 pattern.writeAddr = arrayInject.getIndex();
144 pattern.writeData = arrayInject.getElement();
145 pattern.writeEnable = mux.getCond();
146
147 // Look for read pattern - find array_get users
148 auto arrayGet = readAccess;
149 LLVM_DEBUG(llvm::dbgs() << " Found array_get: " << arrayGet << "\n");
150 pattern.readAccess = arrayGet;
151 pattern.readAddr = arrayGet.getIndex();
152
153 // Check if read goes through output register
154 for (auto *readUser : arrayGet.getResult().getUsers()) {
155 if (auto outputReg = dyn_cast<FirRegOp>(readUser)) {
156 if (outputReg.getClk() == pattern.clock) {
157 LLVM_DEBUG(llvm::dbgs()
158 << " Found output register: " << outputReg << "\n");
159 pattern.outputReg = outputReg;
160 break;
161 }
162 }
163 }
164
165 bool success = pattern.readAccess != nullptr;
166 LLVM_DEBUG(llvm::dbgs() << " Pattern analysis "
167 << (success ? "succeeded" : "failed") << "\n");
168 return success;
169}
170
171bool RegOfVecToMemPass::createFirMemory(MemoryPattern &pattern) {
172 LLVM_DEBUG(llvm::dbgs() << "Creating FirMemory for pattern\n");
173
174 auto dims = getArrayDimensions(pattern.memReg.getType());
175 if (!dims)
176 return false;
177
178 uint64_t depth = dims->first;
179 uint64_t width = dims->second;
180
181 LLVM_DEBUG(llvm::dbgs() << " Memory dimensions: " << depth << " x " << width
182 << "\n");
183
184 ImplicitLocOpBuilder builder(pattern.memReg.getLoc(), pattern.memReg);
185
186 // Create FirMem
187 auto memType =
188 FirMemType::get(builder.getContext(), depth, width, /*maskWidth=*/1);
189 auto firMem = seq::FirMemOp::create(
190 builder, memType, /*readLatency=*/0, /*writeLatency=*/1,
191 /*readUnderWrite=*/seq::RUW::Undefined,
192 /*writeUnderWrite=*/seq::WUW::Undefined,
193 /*name=*/builder.getStringAttr("mem"), /*innerSym=*/hw::InnerSymAttr{},
194 /*init=*/seq::FirMemInitAttr{}, /*prefix=*/StringAttr{},
195 /*outputFile=*/Attribute{});
196
197 // Create read port
198 Value readData = FirMemReadOp::create(
199 builder, firMem, pattern.readAddr, pattern.clock,
200 /*enable=*/hw::ConstantOp::create(builder, builder.getI1Type(), 1));
201
202 LLVM_DEBUG(llvm::dbgs() << " Created read port\n"
203 << firMem << "\n " << readData);
204
205 Value mask;
206 // Create write port
207 FirMemWriteOp::create(builder, firMem, pattern.writeAddr, pattern.clock,
208 pattern.writeEnable, pattern.writeData, mask);
209
210 LLVM_DEBUG(llvm::dbgs() << " Created write port\n");
211
212 // Replace read access
213 if (pattern.outputReg)
214 // If there's an output register, replace its input
215 pattern.outputReg.getNext().replaceAllUsesWith(readData);
216 else
217 // Replace direct read access
218 pattern.readAccess.getResult().replaceAllUsesWith(readData);
219
220 // Mark old operations for removal
221 opsToErase.push_back(pattern.memReg);
222 if (pattern.readAccess)
223 opsToErase.push_back(pattern.readAccess);
224 if (pattern.writeAccess)
225 opsToErase.push_back(pattern.writeAccess);
226 if (pattern.writeMux)
227 opsToErase.push_back(pattern.writeMux);
228
229 return true;
230}
231
232void RegOfVecToMemPass::runOnOperation() {
233 auto module = getOperation();
234
235 SmallVector<FirRegOp> arrayRegs;
236
237 // Collect all FirRegOp with array types
238 module.walk([&](FirRegOp reg) {
239 if (isArrayType(reg.getType())) {
240 arrayRegs.push_back(reg);
241 }
242 });
243
244 // Analyze each array register for memory patterns
245 for (auto reg : arrayRegs) {
246 MemoryPattern pattern;
247 if (analyzeMemoryPattern(reg, pattern)) {
248 createFirMemory(pattern);
249 }
250 }
251
252 // Erase all marked operations
253 for (auto *op : opsToErase) {
254 LLVM_DEBUG(llvm::dbgs()
255 << "Erasing operation: " << *op << " number of uses:"
256 << "\n");
257 op->dropAllUses();
258 op->erase();
259 }
260 opsToErase.clear();
261}
RewritePatternSet pattern
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21