CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWVectorization.cpp
Go to the documentation of this file.
1//===- HWVectorization.cpp - HW Vectorization Pass --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs structural vectorization of hardware modules,
10// merging scalar bit-level assignments into vectorized operations.
11// This version handles linear, reverse, and mix vectorization using
12// bit-tracking.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/RegionUtils.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
28
29namespace circt {
30namespace hw {
31#define GEN_PASS_DEF_HWVECTORIZATION
32#include "circt/Dialect/HW/Passes.h.inc"
33} // namespace hw
34} // namespace circt
35
36using namespace mlir;
37using namespace circt;
38using namespace comb;
39using namespace hw;
40
41namespace {
42
43/// Represents a specific bit from a source SSA Value.
44struct Bit {
45 Value source;
46 int index;
47
48 Bit(Value source, int index) : source(source), index(index) {}
49 Bit() : source(nullptr), index(0) {}
50};
51
52/// Maintains a mapping of bit indices to their source origins.
53/// Uses SmallVector since the map is always dense (size == bitWidth).
54struct BitArray {
55 // Each element at position i holds the Bit for output bit i.
56 // An unset entry has source == nullptr.
57 llvm::SmallVector<Bit> bits;
58
59 /// Checks if all bits form a linear sequence: output[i] <- source[i].
60 bool isLinear(int size, Value sourceInput) const {
61 if (bits.size() != static_cast<size_t>(size))
62 return false;
63 for (int i = 0; i < size; ++i) {
64 if (bits[i].source != sourceInput || bits[i].index != i)
65 return false;
66 }
67 return true;
68 }
69
70 /// Checks if all bits form a reverse sequence: output[i] <- source[N-1-i].
71 bool isReverse(int size, Value sourceInput) const {
72 if (bits.size() != static_cast<size_t>(size))
73 return false;
74 for (int i = 0; i < size; ++i) {
75 if (bits[i].source != sourceInput || (size - 1) - i != bits[i].index)
76 return false;
77 }
78 return true;
79 }
80
81 /// Returns the single source Value if all tracked bits share the same source.
82 Value getSingleSourceValue() const {
83 Value source = nullptr;
84 for (const auto &bit : bits) {
85 if (!bit.source)
86 return nullptr;
87 if (!source)
88 source = bit.source;
89 else if (source != bit.source)
90 return nullptr;
91 }
92 return source;
93 }
94
95 size_t size() const { return bits.size(); }
96};
97
98class Vectorizer {
99public:
100 Vectorizer(hw::HWModuleOp module) : module(module) {}
101
102 /// Analyzes bit-level provenance and applies vectorization transforms.
103 void vectorize() {
104 processOps();
105
106 auto outputOp =
107 dyn_cast<hw::OutputOp>(module.getBodyBlock()->getTerminator());
108 if (!outputOp)
109 return;
110
111 IRRewriter rewriter(module.getContext());
112 bool changed = false;
113
114 for (Value oldOutputVal : outputOp->getOperands()) {
115 auto type = dyn_cast<IntegerType>(oldOutputVal.getType());
116 if (!type)
117 continue;
118
119 unsigned bitWidth = type.getWidth();
120 auto it = bitArrays.find(oldOutputVal);
121 if (it == bitArrays.end())
122 continue;
123
124 BitArray &arr = it->second;
125 if (arr.size() != bitWidth)
126 continue;
127
128 Value sourceInput = arr.getSingleSourceValue();
129 if (!sourceInput)
130 continue;
131
132 if (arr.isLinear(bitWidth, sourceInput)) {
133 oldOutputVal.replaceAllUsesWith(sourceInput);
134 changed = true;
135 } else if (arr.isReverse(bitWidth, sourceInput)) {
136 rewriter.setInsertionPointAfterValue(sourceInput);
137 Value reversed = comb::ReverseOp::create(
138 rewriter, sourceInput.getLoc(), sourceInput.getType(), sourceInput);
139 oldOutputVal.replaceAllUsesWith(reversed);
140 changed = true;
141 } else if (isValidPermutation(arr, bitWidth)) {
142 applyMixVectorization(rewriter, oldOutputVal, sourceInput, arr,
143 bitWidth);
144 changed = true;
145 }
146 }
147
148 if (changed)
149 (void)mlir::runRegionDCE(rewriter, module.getBody());
150 }
151
152private:
153 /// Maps values to their decomposed bit provenance.
154 llvm::DenseMap<Value, BitArray> bitArrays;
155 hw::HWModuleOp module;
156
157 /// Checks that all bit indices are in [0, bitWidth] and form a bijection.
158 /// Guards applyMixVectorization against malformed BitArrays.
159 bool isValidPermutation(const BitArray &arr, unsigned bitWidth) {
160 if (arr.size() != bitWidth)
161 return false;
162 llvm::SmallBitVector seen(bitWidth);
163 for (const auto &bit : arr.bits) {
164 assert(bit.index >= 0);
165 if (bit.index >= static_cast<int>(bitWidth) || seen.test(bit.index))
166 return false;
167 seen.set(bit.index);
168 }
169 return true;
170 }
171
172 /// Handles arbitrary permutations from a single source by grouping runs of
173 /// consecutive source-bit indices into ExtractOps, then concatenating them.
174 ///
175 /// Example: bits = [2, 3, 0, 1] produces:
176 /// %0 = extract src[2:1] // bits 0-1 of output <- source[2:3]
177 /// %1 = extract src[0:1] // bits 2-3 of output <- source[0:1]
178 /// %out = concat(%1, %0) // MSB->LSB order
179 void applyMixVectorization(IRRewriter &rewriter, Value oldOutputVal,
180 Value sourceInput, const BitArray &arr,
181 unsigned bitWidth) {
182 rewriter.setInsertionPointAfterValue(sourceInput);
183 Location loc = sourceInput.getLoc();
184
185 // Walk output bits LSB->MSB, greedily extending each run while source
186 // indices remain consecutive.
187 llvm::SmallVector<Value> chunks;
188 unsigned i = 0;
189 while (i < bitWidth) {
190 unsigned startBit = arr.bits[i].index;
191 unsigned len = 1;
192 while (i + len < bitWidth &&
193 arr.bits[i + len].index == static_cast<int>(startBit + len))
194 ++len;
195
196 chunks.push_back(comb::ExtractOp::create(
197 rewriter, loc, rewriter.getIntegerType(len), sourceInput, startBit));
198 i += len;
199 }
200
201 // comb.concat expects operands MSB->LSB, so reverse the chunk list.
202 std::reverse(chunks.begin(), chunks.end());
203
204 Value newVal = comb::ConcatOp::create(
205 rewriter, loc, rewriter.getIntegerType(bitWidth), chunks);
206
207 oldOutputVal.replaceAllUsesWith(newVal);
208 }
209
210 /// Single walk that handles ExtractOp and ConcatOp using TypeSwitch.
211 void processOps() {
212 module.walk([&](Operation *op) {
213 llvm::TypeSwitch<Operation *>(op)
214 .Case<comb::ExtractOp>([&](comb::ExtractOp extractOp) {
215 // Only handle single-bit extracts; skip multi-bit ranges.
216 auto resultType =
217 dyn_cast<IntegerType>(extractOp.getResult().getType());
218 if (!resultType || resultType.getWidth() != 1)
219 return;
220
221 BitArray bits;
222 bits.bits.push_back(
223 Bit(extractOp.getInput(), extractOp.getLowBit()));
224 bitArrays.insert({extractOp.getResult(), bits});
225 })
226 .Case<comb::ConcatOp>([&](comb::ConcatOp concatOp) {
227 auto resultType =
228 dyn_cast<IntegerType>(concatOp.getResult().getType());
229 if (!resultType)
230 return;
231
232 unsigned totalWidth = resultType.getWidth();
233 BitArray concatenatedArray;
234 concatenatedArray.bits.resize(totalWidth);
235
236 unsigned currentBitOffset = 0;
237 for (Value operand : llvm::reverse(concatOp.getInputs())) {
238 unsigned operandWidth =
239 cast<IntegerType>(operand.getType()).getWidth();
240 auto it = bitArrays.find(operand);
241 if (it != bitArrays.end()) {
242 for (unsigned i = 0; i < it->second.bits.size(); ++i)
243 concatenatedArray.bits[i + currentBitOffset] =
244 it->second.bits[i];
245 }
246 currentBitOffset += operandWidth;
247 }
248 bitArrays.insert({concatOp.getResult(), concatenatedArray});
249 });
250 });
251 }
252};
253
254struct HWVectorizationPass
255 : public hw::impl::HWVectorizationBase<HWVectorizationPass> {
256
257 void runOnOperation() override {
258 Vectorizer v(getOperation());
259 v.vectorize();
260 }
261};
262
263} // namespace
assert(baseType &&"element must be base type")
create(low_bit, result_type, input=None)
Definition comb.py:187
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
Definition hw.py:1