CIRCT 22.0.0git
Loading...
Searching...
No Matches
LinearScanRegisterAllocationPass.cpp
Go to the documentation of this file.
1//===- LinearScanRegisterAllocationPass.cpp - Register Allocation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass allocates registers using a simple linear scan algorithm.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/Support/Debug.h"
19
20namespace circt {
21namespace rtg {
22#define GEN_PASS_DEF_LINEARSCANREGISTERALLOCATIONPASS
23#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
24} // namespace rtg
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29
30#define DEBUG_TYPE "rtg-linear-scan-register-allocation"
31
32namespace {
33
34/// Represents a register and its live range.
35struct RegisterLiveRange {
36 rtg::RegisterAttrInterface fixedReg;
37 rtg::VirtualRegisterOp regOp;
38 unsigned start;
39 unsigned end;
40};
41
42class LinearScanRegisterAllocationPass
43 : public circt::rtg::impl::LinearScanRegisterAllocationPassBase<
44 LinearScanRegisterAllocationPass> {
45public:
46 void runOnOperation() override;
47};
48
49} // end namespace
50
51static void expireOldInterval(SmallVector<RegisterLiveRange *> &active,
52 RegisterLiveRange *reg) {
53 // TODO: use a better datastructure for 'active'
54 llvm::sort(active, [](auto *a, auto *b) { return a->end < b->end; });
55
56 for (auto *iter = active.begin(); iter != active.end(); ++iter) {
57 auto *a = *iter;
58 if (a->end >= reg->start)
59 return;
60
61 active.erase(iter--);
62 }
63}
64
65void LinearScanRegisterAllocationPass::runOnOperation() {
66 LLVM_DEBUG(llvm::dbgs() << "=== Processing "
67 << OpWithFlags(getOperation(),
68 OpPrintingFlags().skipRegions())
69 << "\n\n");
70
71 if (getOperation()->getNumRegions() != 1 ||
72 getOperation()->getRegion(0).getBlocks().size() != 1) {
73 getOperation()->emitError("expected a single region with a single block");
74 return signalPassFailure();
75 }
76
77 DenseMap<Operation *, unsigned> opIndices;
78 unsigned maxIdx;
79 for (auto [i, op] :
80 llvm::enumerate(getOperation()->getRegion(0).getBlocks().front())) {
81 // TODO: ideally check that the IR is already fully elaborated
82 opIndices[&op] = i;
83 maxIdx = i;
84 }
85
86 // Collect all the register intervals we have to consider.
87 SmallVector<std::unique_ptr<RegisterLiveRange>> regRanges;
88 SmallVector<RegisterLiveRange *> active;
89 for (auto &op : getOperation()->getRegion(0).getBlocks().front()) {
90 if (!isa<rtg::ConstantOp, rtg::VirtualRegisterOp>(&op) ||
91 !isa<rtg::RegisterTypeInterface>(op.getResult(0).getType()))
92 continue;
93
94 RegisterLiveRange lr;
95 lr.start = maxIdx;
96 lr.end = 0;
97
98 if (auto regOp = dyn_cast<rtg::VirtualRegisterOp>(&op))
99 lr.regOp = regOp;
100
101 if (auto regOp = dyn_cast<rtg::ConstantOp>(&op)) {
102 auto reg = dyn_cast<rtg::RegisterAttrInterface>(regOp.getValue());
103 if (!reg) {
104 op.emitError("expected register attribute");
105 return signalPassFailure();
106 }
107 lr.fixedReg = reg;
108 }
109
110 for (auto *user : op.getUsers()) {
111 if (!isa<rtg::InstructionOpInterface, rtg::ValidateOp>(user)) {
112 user->emitError("only operations implementing 'InstructionOpInterface' "
113 "and 'rtg.validate' are allowed to use registers");
114 return signalPassFailure();
115 }
116
117 // TODO: support labels and control-flow loops (jumps in general)
118 unsigned idx = opIndices.at(user);
119 lr.start = std::min(lr.start, idx);
120 lr.end = std::max(lr.end, idx);
121 }
122
123 regRanges.emplace_back(std::make_unique<RegisterLiveRange>(lr));
124
125 // Reserve fixed registers from the start. It will be made available again
126 // past the interval end. Not reserving it from the start can lead to the
127 // same register being chosen for a virtual register that overlaps with the
128 // fixed register interval.
129 // TODO: don't overapproximate that much
130 if (!lr.regOp)
131 active.push_back(regRanges.back().get());
132 }
133
134 // Sort such that we can process registers by increasing interval start.
135 llvm::sort(regRanges, [](const auto &a, const auto &b) {
136 return a->start < b->start || (a->start == b->start && !a->regOp);
137 });
138
139 for (auto &lr : regRanges) {
140 // Make registers out of live range available again.
141 expireOldInterval(active, lr.get());
142
143 // Handle already fixed registers.
144 if (!lr->regOp)
145 continue;
146
147 // Handle virtual registers.
148 auto configAttr =
149 cast<rtg::VirtualRegisterConfigAttr>(lr->regOp.getAllowedRegsAttr());
150 rtg::RegisterAttrInterface availableReg;
151 for (auto reg : configAttr.getAllowedRegs()) {
152 if (llvm::none_of(active, [&](auto *r) { return r->fixedReg == reg; })) {
153 availableReg = cast<rtg::RegisterAttrInterface>(reg);
154 break;
155 }
156 }
157
158 if (!availableReg) {
159 ++numRegistersSpilled;
160 lr->regOp->emitError(
161 "need to spill this register, but not supported yet");
162 return signalPassFailure();
163 }
164
165 lr->fixedReg = availableReg;
166 active.push_back(lr.get());
167 }
168
169 LLVM_DEBUG({
170 for (auto &regRange : regRanges) {
171 llvm::dbgs() << "Start: " << regRange->start << ", End: " << regRange->end
172 << ", Selected: " << regRange->fixedReg << "\n";
173 }
174 llvm::dbgs() << "\n";
175 });
176
177 for (auto &reg : regRanges) {
178 // No need to fix already fixed registers.
179 if (!reg->regOp)
180 continue;
181
182 IRRewriter rewriter(reg->regOp);
183 rewriter.replaceOpWithNewOp<rtg::ConstantOp>(reg->regOp, reg->fixedReg);
184 }
185}
static void expireOldInterval(SmallVector< RegisterLiveRange * > &active, RegisterLiveRange *reg)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21