CIRCT 21.0.0git
Loading...
Searching...
No Matches
Sig2RegPass.cpp
Go to the documentation of this file.
1//===- Sig2RegPass.cpp - Implement the Sig2Reg Pass -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implement Pass to promote LLHD signals to SSA values.
10//
11//===----------------------------------------------------------------------===//
12
17#include "llvm/Support/Debug.h"
18
19#define DEBUG_TYPE "llhd-sig2reg"
20
21namespace circt {
22namespace llhd {
23#define GEN_PASS_DEF_SIG2REG
24#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
25} // namespace llhd
26} // namespace circt
27
28using namespace mlir;
29using namespace circt;
30
31namespace {
32
33/// Represents an offset of an interval relative to a root interval. All values
34/// describe number of bits, not elements.
35struct Offset {
36 Offset(uint64_t min, uint64_t max, ArrayRef<Value> dynamic)
37 : min(min), max(max), dynamic(dynamic) {}
38
39 Offset(uint64_t idx) : min(idx), max(idx) {}
40
41 // The lower bound of the offset known statically.
42 uint64_t min = 0;
43 // The upper bound of the offset known statically.
44 uint64_t max = -1;
45 // A list of SSA values used to compute the final offset.
46 SmallVector<Value> dynamic;
47
48 /// Returns if we know the exact offset statically.
49 bool isStatic() const { return min == max; }
50};
51
52/// Represents an alias interval within a root interval that is written to or
53/// read from. All values refer to number of bits, not elements.
54struct Interval {
55 Interval(const Offset &low, uint64_t bitwidth, Value value,
56 llhd::TimeAttr delay = llhd::TimeAttr())
57 : low(low), bitwidth(bitwidth), value(value), delay(delay) {}
58
59 // The offset of the interval relative to the root interval (i.e. all the bits
60 // of the original signal).
61 Offset low;
62 // The width of the interval.
63 uint64_t bitwidth;
64 // The value written to this interval or the OpResult of a read.
65 Value value;
66 // The delay with which the value is written.
67 llhd::TimeAttr delay;
68};
69
70class SigPromoter {
71public:
72 SigPromoter(llhd::SignalOp sigOp) : sigOp(sigOp) {}
73
74 // Start at the signal operation and traverse all alias operations to compute
75 // all the intervals and sort them by ascending offset.
76 LogicalResult computeIntervals() {
77 SmallVector<std::pair<Operation *, Offset>> stack;
78
79 for (auto *user : sigOp->getUsers())
80 stack.emplace_back(user, Offset(0));
81
82 while (!stack.empty()) {
83 auto currAndOffset = stack.pop_back_val();
84 auto *curr = currAndOffset.first;
85 auto offset = currAndOffset.second;
86
87 if (curr->getBlock() != sigOp->getBlock()) {
88 LLVM_DEBUG(llvm::dbgs() << " - User in other block, skipping...\n\n");
89 return failure();
90 }
91
92 auto result =
93 TypeSwitch<Operation *, LogicalResult>(curr)
94 .Case<llhd::PrbOp>([&](llhd::PrbOp probeOp) {
95 auto bw = hw::getBitWidth(probeOp.getResult().getType());
96 if (bw <= 0)
97 return failure();
98
99 readIntervals.emplace_back(offset, bw, probeOp.getResult());
100 return success();
101 })
102 .Case<llhd::DrvOp>([&](llhd::DrvOp driveOp) {
103 if (driveOp.getEnable()) {
104 LLVM_DEBUG(llvm::dbgs()
105 << " - Conditional driver, skipping...\n\n");
106 return failure();
107 }
108
109 auto timeOp =
110 driveOp.getTime().getDefiningOp<llhd::ConstantTimeOp>();
111 if (!timeOp) {
112 LLVM_DEBUG(llvm::dbgs()
113 << " - Unknown drive delay, skipping...\n\n");
114 return failure();
115 }
116
117 auto bw = hw::getBitWidth(driveOp.getValue().getType());
118 if (bw <= 0)
119 return failure();
120
121 intervals.emplace_back(offset, bw, driveOp.getValue(),
122 timeOp.getValueAttr());
123 return success();
124 })
125 .Case<llhd::SigExtractOp>([&](llhd::SigExtractOp extractOp) {
126 if (auto constOp =
127 extractOp.getLowBit().getDefiningOp<hw::ConstantOp>();
128 constOp && offset.isStatic()) {
129 for (auto *user : extractOp->getUsers())
130 stack.emplace_back(
131 user,
132 Offset(constOp.getValue().getZExtValue() + offset.min));
133
134 return success();
135 }
136
137 auto bw = hw::getBitWidth(
138 cast<hw::InOutType>(extractOp.getInput().getType())
139 .getElementType());
140 if (bw <= 0)
141 return failure();
142
143 SmallVector<Value> indices(offset.dynamic);
144 indices.push_back(extractOp.getLowBit());
145
146 for (auto *user : extractOp->getUsers())
147 stack.emplace_back(
148 user, Offset(offset.min, offset.max + bw - 1, indices));
149
150 return success();
151 })
152 .Default([](auto *op) {
153 LLVM_DEBUG(llvm::dbgs() << " - User that is not a probe or "
154 "drive, skipping...\n "
155 << *op << "\n\n");
156 return failure();
157 });
158
159 if (failed(result))
160 return failure();
161
162 toDelete.push_back(curr);
163 }
164
165 llvm::sort(intervals, [](const Interval &a, const Interval &b) {
166 return a.low.min < b.low.min;
167 });
168
169 LLVM_DEBUG({
170 llvm::dbgs() << " - Detected intervals:\n";
171 dumpIntervals(llvm::dbgs(), 4);
172 });
173
174 return success();
175 }
176
177#ifndef NDEBUG
178
179 /// Print the list of intervals in a readable format for debugging.
180 void dumpIntervals(llvm::raw_ostream &os, unsigned indent = 0) {
181 os << llvm::indent(indent) << "[\n";
182 for (const auto &interval : intervals) {
183 os << llvm::indent(indent + 2) << "<from [" << interval.low.min << ", "
184 << interval.low.max << "]\n";
185 os << llvm::indent(indent + 3) << "width " << interval.bitwidth << "\n";
186
187 for (auto idx : interval.low.dynamic)
188 os << llvm::indent(indent + 3) << idx << "\n";
189
190 os << llvm::indent(indent + 3) << "value: " << interval.value << "\n";
191 os << llvm::indent(indent + 3) << "delay: " << interval.delay << "\n";
192 os << llvm::indent(indent + 2) << ">,\n";
193 }
194 os << llvm::indent(indent) << "]\n";
195 }
196
197#endif
198
199 /// Check if we can promote the entire signal according to the current
200 /// limitations of the pass.
201 bool isPromotable() {
202 for (unsigned i = 0; i < intervals.size(); ++i) {
203 if (i >= intervals.size() - 1)
204 break;
205
206 if (intervals[i].low.max + intervals[i].bitwidth - 1 >
207 intervals[i + 1].low.min) {
208 LLVM_DEBUG({
209 llvm::dbgs() << " - Potentially overlapping drives, skipping...\n\n";
210 });
211 return false;
212 }
213 }
214
215 return true;
216 }
217
218 /// Promote the signal. This builds the necessary operations, replaces the
219 /// values, and removes the signal and signal value handling operations.
220 void promote() {
221 auto bw = hw::getBitWidth(sigOp.getInit().getType());
222 assert(bw > 0 && "bw must be known and non-zero");
223
224 OpBuilder builder(sigOp);
225 Value val = sigOp.getInit();
226 Location loc = sigOp->getLoc();
227 auto type = builder.getIntegerType(bw);
228 val = builder.createOrFold<hw::BitcastOp>(loc, type, val);
229
230 // Handle the writes by starting with the signal init value and injecting
231 // the written values at the right offsets.
232 for (auto interval : intervals) {
233 Value invMask = builder.create<hw::ConstantOp>(
234 loc, APInt::getAllOnes(interval.bitwidth));
235
236 if (uint64_t(bw) > interval.bitwidth) {
237 Value pad = builder.create<hw::ConstantOp>(
238 loc, APInt::getZero(bw - interval.bitwidth));
239 invMask = builder.createOrFold<comb::ConcatOp>(loc, pad, invMask);
240 }
241
242 Value amt = buildDynamicIndex(builder, loc, interval.low.min,
243 interval.low.dynamic, bw);
244 invMask = builder.createOrFold<comb::ShlOp>(loc, invMask, amt);
245 Value allOnes =
246 builder.create<hw::ConstantOp>(loc, APInt::getAllOnes(bw));
247 Value mask = builder.createOrFold<comb::XorOp>(loc, invMask, allOnes);
248 val = builder.createOrFold<comb::AndOp>(loc, val, mask);
249
250 Value assignVal = builder.createOrFold<hw::BitcastOp>(
251 loc, builder.getIntegerType(interval.bitwidth), interval.value);
252
253 if (uint64_t(bw) > interval.bitwidth) {
254 Value pad = builder.create<hw::ConstantOp>(
255 loc, APInt::getZero(bw - interval.bitwidth));
256 assignVal = builder.createOrFold<comb::ConcatOp>(loc, pad, assignVal);
257 }
258
259 assignVal = builder.createOrFold<comb::ShlOp>(loc, assignVal, amt);
260 if (!isImmediate(interval.delay))
261 assignVal =
262 builder.createOrFold<llhd::DelayOp>(loc, assignVal, interval.delay);
263 val = builder.createOrFold<comb::OrOp>(loc, assignVal, val);
264 }
265
266 // Handle the reads by extracting right number of bits at the right offset.
267 for (auto interval : readIntervals) {
268 if (interval.low.isStatic()) {
269 Value read = builder.createOrFold<comb::ExtractOp>(
270 loc, builder.getIntegerType(interval.bitwidth), val,
271 interval.low.min);
272 read = builder.createOrFold<hw::BitcastOp>(
273 loc, interval.value.getType(), read);
274 interval.value.replaceAllUsesWith(read);
275 continue;
276 }
277
278 Value read = buildDynamicIndex(builder, loc, interval.low.min,
279 interval.low.dynamic, bw);
280 read = builder.createOrFold<comb::ShrUOp>(loc, val, read);
281 read = builder.createOrFold<comb::ExtractOp>(
282 loc, builder.getIntegerType(interval.bitwidth), read, 0);
283 read = builder.createOrFold<hw::BitcastOp>(loc, interval.value.getType(),
284 read);
285 interval.value.replaceAllUsesWith(read);
286 }
287
288 // Delete all operations operating on signal values.
289 for (auto *op : llvm::reverse(toDelete))
290 op->erase();
291
292 sigOp->erase();
293 }
294
295private:
296 /// Given a static offset and a list of dynamic offset values, materialize an
297 /// SSA value that adds all these offsets together and is an integer with the
298 /// given 'width'.
299 Value buildDynamicIndex(OpBuilder &builder, Location loc,
300 uint64_t constOffset, ArrayRef<Value> indices,
301 uint64_t width) {
302 Value index = builder.create<hw::ConstantOp>(
303 loc, builder.getIntegerType(width), constOffset);
304
305 for (auto idx : indices) {
306 auto bw = hw::getBitWidth(idx.getType());
307 Value pad =
308 builder.create<hw::ConstantOp>(loc, APInt::getZero(width - bw));
309 idx = builder.createOrFold<comb::ConcatOp>(loc, pad, idx);
310 index = builder.createOrFold<comb::AddOp>(loc, index, idx);
311 }
312
313 return index;
314 }
315
316 bool isImmediate(llhd::TimeAttr attr) const {
317 return attr.getTime() == 0 && attr.getDelta() == 0 &&
318 attr.getEpsilon() == 1;
319 }
320
321 // The signal to be promoted.
322 llhd::SignalOp sigOp;
323 // Intervals written to.
324 SmallVector<Interval> intervals;
325 // Intervals read from.
326 SmallVector<Interval> readIntervals;
327 // Operations to delete after promotion is done.
328 SmallVector<Operation *> toDelete;
329};
330
331struct Sig2RegPass : public circt::llhd::impl::Sig2RegBase<Sig2RegPass> {
332 void runOnOperation() override;
333};
334} // namespace
335
336void Sig2RegPass::runOnOperation() {
337 hw::HWModuleOp moduleOp = getOperation();
338
339 LLVM_DEBUG(llvm::dbgs() << "=== Sig2Reg in module " << moduleOp.getSymName()
340 << "\n\n");
341
342 for (auto sigOp :
343 llvm::make_early_inc_range(moduleOp.getOps<llhd::SignalOp>())) {
344 LLVM_DEBUG(llvm::dbgs() << " - Attempting to promote signal "
345 << sigOp.getName() << "\n");
346 SigPromoter promoter(sigOp);
347 if (failed(promoter.computeIntervals()) || !promoter.isPromotable())
348 continue;
349
350 promoter.promote();
351 LLVM_DEBUG(llvm::dbgs() << " - Successfully promoted!\n\n");
352 }
353
354 LLVM_DEBUG({
355 if (moduleOp.getOps<llhd::SignalOp>().empty())
356 llvm::dbgs() << " Successfully promoted all signals in module!\n";
357 });
358
359 LLVM_DEBUG(llvm::dbgs() << "\n");
360}
assert(baseType &&"element must be base type")
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.