CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWVectorization.cpp
Go to the documentation of this file.
1//===- HWVectorization.cpp - HW Vectorization Pass --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs structural vectorization of hardware modules,
10// merging scalar bit-level assignments into vectorized operations.
11// This version handles linear, reverse, and mix permutation and
12// structural vectorization using bit-tracking.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/RegionUtils.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
28
29namespace circt {
30namespace hw {
31#define GEN_PASS_DEF_HWVECTORIZATION
32#include "circt/Dialect/HW/Passes.h.inc"
33} // namespace hw
34} // namespace circt
35
36using namespace mlir;
37using namespace circt;
38using namespace comb;
39using namespace hw;
40
41namespace {
42
43/// Represents a specific bit from a source SSA Value.
44struct Bit {
45 Value source;
46 int index;
47
48 Bit(Value source, int index) : source(source), index(index) {}
49 Bit() : source(nullptr), index(0) {}
50};
51
52/// Maintains a mapping of bit indices to their source origins.
53/// Uses SmallVector since the map is always dense (size == bitWidth).
54struct BitArray {
55 // Each element at position i holds the Bit for output bit i.
56 // An unset entry has source == nullptr.
57 llvm::SmallVector<Bit> bits;
58
59 /// Checks if all bits form a linear sequence: output[i] <- source[i].
60 bool isLinear(int size, Value sourceInput) const {
61 if (bits.size() != static_cast<size_t>(size))
62 return false;
63 for (int i = 0; i < size; ++i) {
64 if (bits[i].source != sourceInput || bits[i].index != i)
65 return false;
66 }
67 return true;
68 }
69
70 /// Checks if all bits form a reverse sequence: output[i] <- source[N-1-i].
71 bool isReverse(int size, Value sourceInput) const {
72 if (bits.size() != static_cast<size_t>(size))
73 return false;
74 for (int i = 0; i < size; ++i) {
75 if (bits[i].source != sourceInput || (size - 1) - i != bits[i].index)
76 return false;
77 }
78 return true;
79 }
80
81 /// Returns the single source Value if all tracked bits share the same source.
82 Value getSingleSourceValue() const {
83 Value source = nullptr;
84 for (const auto &bit : bits) {
85 if (!bit.source)
86 return nullptr;
87 if (!source)
88 source = bit.source;
89 else if (source != bit.source)
90 return nullptr;
91 }
92 return source;
93 }
94
95 size_t size() const { return bits.size(); }
96};
97
98class Vectorizer {
99public:
100 Vectorizer(hw::HWModuleOp module) : module(module) {}
101
102 /// Entry point: analyze provenance then apply the best vectorization for
103 /// each output port.
104 void vectorize() {
105 // Phase 1: populate `bitArrays` by walking all ops in program order.
106 processOps();
107
108 auto outputOp =
109 dyn_cast<hw::OutputOp>(module.getBodyBlock()->getTerminator());
110 if (!outputOp)
111 return;
112
113 IRRewriter rewriter(module.getContext());
114 bool changed = false;
115
116 // Phase 2: for each integer output, attempt vectorization strategies in
117 // order of increasing complexity:
118 // (a) Linear – direct wire-through -> drop the concat entirely
119 // (b) Reverse – mirror permutation -> comb.reverse
120 // (c) Mix – arbitrary bijection -> extract + concat chunks
121 // (d) Structural – isomorphic scalar cones -> wide AND/OR/XOR/MUX
122 for (Value oldOutputVal : outputOp->getOperands()) {
123 auto type = dyn_cast<IntegerType>(oldOutputVal.getType());
124 if (!type)
125 continue;
126
127 unsigned bitWidth = type.getWidth();
128 auto it = bitArrays.find(oldOutputVal);
129 if (it == bitArrays.end())
130 continue;
131
132 BitArray &arr = it->second;
133 if (arr.size() != bitWidth)
134 continue;
135
136 Value sourceInput = arr.getSingleSourceValue();
137
138 // `transformed` tracks whether *this* output was successfully vectorized.
139 // It must be local to each iteration so that a successful transform on
140 // one output port does not suppress strategy (d) for a later port.
141 bool transformed = false;
142
143 // 1. Try vectorizing from a single source (Linear, Reverse, Mix).
144 if (sourceInput) {
145 if (arr.isLinear(bitWidth, sourceInput)) {
146 oldOutputVal.replaceAllUsesWith(sourceInput);
147 transformed = true;
148 } else if (arr.isReverse(bitWidth, sourceInput)) {
149 rewriter.setInsertionPointAfterValue(sourceInput);
150 Value reversed =
151 comb::ReverseOp::create(rewriter, sourceInput.getLoc(),
152 sourceInput.getType(), sourceInput);
153 oldOutputVal.replaceAllUsesWith(reversed);
154 transformed = true;
155 } else if (isValidPermutation(arr, bitWidth)) {
156 applyMixVectorization(rewriter, oldOutputVal, sourceInput, arr,
157 bitWidth);
158 transformed = true;
159 }
160 }
161
162 // 2. If it wasn't vectorized (or if it has multiple sources), try
163 // Structural.
164 if (!transformed && !hasCrossBitDependencies(oldOutputVal) &&
165 canVectorizeStructurally(oldOutputVal)) {
166 rewriter.setInsertionPointAfterValue(oldOutputVal);
167
168 unsigned width = cast<IntegerType>(oldOutputVal.getType()).getWidth();
169 Value slice0 = findBitSource(oldOutputVal, 0);
170 if (slice0) {
171 DenseMap<Value, Value> vectorizedMap;
172 Value vec = vectorizeSubgraph(rewriter, slice0, width, vectorizedMap);
173 if (vec) {
174 oldOutputVal.replaceAllUsesWith(vec);
175 transformed = true;
176 }
177 }
178 }
179
180 if (transformed)
181 changed = true;
182 }
183
184 if (changed)
185 (void)mlir::runRegionDCE(rewriter, module.getBody());
186 }
187
188private:
189 /// Maps each SSA Value to its bit-level provenance after the analysis phase.
190 llvm::DenseMap<Value, BitArray> bitArrays;
191 hw::HWModuleOp module;
192
193 /// Analyzes the logic cones of all bit lanes to detect illegal cross-bit
194 /// dependencies in O(bitWidth + N) time.
195 bool hasCrossBitDependencies(mlir::Value outputVal) {
196 unsigned bitWidth = cast<IntegerType>(outputVal.getType()).getWidth();
197
198 llvm::DenseSet<mlir::Value> visitedUnsafe; // Accumulate unsafe values.
199 llvm::SmallVector<mlir::Value> worklist;
200
201 for (unsigned i = 0; i < bitWidth; ++i) {
202 llvm::DenseSet<mlir::Value> visitedLocal;
203 mlir::Value bitSource = findBitSource(outputVal, i);
204 if (!bitSource)
205 continue;
206
207 worklist.push_back(bitSource);
208 while (!worklist.empty()) {
209 auto top = worklist.pop_back_val();
210 if (isSafeSharedValue(top))
211 continue; // don't add to the set.
212 if (!visitedLocal.insert(top).second)
213 continue; // Arriving multiple time in the same iteration is fine.
214 // If it's already visited, there is a depencency
215 if (!visitedUnsafe.insert(top).second)
216 return true;
217 if (auto *op = top.getDefiningOp()) {
218 for (auto operand : op->getOperands())
219 worklist.push_back(operand);
220 }
221 }
222 }
223 return false;
224 }
225
226 /// Determines if a shared value is safe for vectorization. Only constants
227 /// and block arguments are safe to share between bit lanes. Any intermediate
228 /// operation is considered unsafe as it may introduce cross-lane
229 /// dependencies.
230 bool isSafeSharedValue(mlir::Value val) {
231 return val &&
232 (isa<BlockArgument>(val) || val.getDefiningOp<hw::ConstantOp>());
233 }
234
235 /// Checks if a logic cone is composed of structurally equivalent slices
236 /// that can be merged into a vector operation.
237 ///
238 /// The check succeeds when every bit slice i of the output is produced by a
239 /// subgraph that is isomorphic to the bit-0 subgraph (slice0).
240 bool canVectorizeStructurally(Value output) {
241 unsigned bitWidth = cast<IntegerType>(output.getType()).getWidth();
242 if (bitWidth <= 1)
243 return false;
244
245 Value slice0Val = findBitSource(output, 0);
246 if (!slice0Val)
247 return false;
248
249 Value slice1Val = findBitSource(output, 1);
250 if (!slice1Val)
251 return false;
252
253 auto extract0 = slice0Val.getDefiningOp<comb::ExtractOp>();
254 auto extract1 = slice1Val.getDefiningOp<comb::ExtractOp>();
255
256 if (!extract0 || !extract1 || extract0.getInput() != extract1.getInput()) {
257 for (unsigned i = 1; i < bitWidth; ++i) {
258 Value sliceNVal = findBitSource(output, i);
259 if (!sliceNVal || !sliceNVal.getDefiningOp())
260 return false;
261 llvm::DenseMap<mlir::Value, mlir::Value> map;
262 if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, 1, map)) {
263 return false;
264 }
265 }
266 return true;
267 }
268
269 int stride = (int)extract1.getLowBit() - (int)extract0.getLowBit();
270
271 for (unsigned i = 1; i < bitWidth; ++i) {
272 Value sliceNVal = findBitSource(output, i);
273 if (!sliceNVal || !sliceNVal.getDefiningOp())
274 return false;
275
276 llvm::DenseMap<mlir::Value, mlir::Value> map;
277 if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, stride, map)) {
278 return false;
279 }
280 }
281 return true;
282 }
283
284 /// Recursively compares two subgraphs to determine if they are isomorphic
285 /// with respect to a constant bit-stride.
286 ///
287 /// It assumes that all ExtractOp low-bit indices in the second subgraph
288 /// are exactly (sliceIndex * stride) greater than those in the first.
289 /// Caches results in slice0ToNMap to handle DAGs efficiently.
290 bool areSubgraphsEquivalent(Value slice0Val, Value sliceNVal,
291 unsigned sliceIndex, int stride,
292 DenseMap<Value, Value> &slice0ToNMap) {
293
294 if (slice0ToNMap.count(slice0Val))
295 return slice0ToNMap[slice0Val] == sliceNVal;
296
297 Operation *op0 = slice0Val.getDefiningOp();
298 Operation *opN = sliceNVal.getDefiningOp();
299
300 if (auto extract0 = dyn_cast_or_null<comb::ExtractOp>(op0)) {
301 auto extractN = dyn_cast_or_null<comb::ExtractOp>(opN);
302
303 if (extractN && extract0.getInput() == extractN.getInput() &&
304 extractN.getLowBit() ==
305 static_cast<unsigned>(static_cast<int>(extract0.getLowBit()) +
306 static_cast<int>(sliceIndex) * stride)) {
307 slice0ToNMap[slice0Val] = sliceNVal;
308 return true;
309 }
310 return false;
311 }
312
313 if (slice0Val == sliceNVal && (mlir::isa<BlockArgument>(slice0Val) ||
314 mlir::isa<hw::ConstantOp>(op0))) {
315 slice0ToNMap[slice0Val] = sliceNVal;
316 return true;
317 }
318
319 if (!op0 || !opN || op0->getName() != opN->getName() ||
320 op0->getNumOperands() != opN->getNumOperands())
321 return false;
322
323 for (unsigned i = 0; i < op0->getNumOperands(); ++i) {
324 if (!areSubgraphsEquivalent(op0->getOperand(i), opN->getOperand(i),
325 sliceIndex, stride, slice0ToNMap))
326 return false;
327 }
328
329 slice0ToNMap[slice0Val] = sliceNVal;
330 return true;
331 }
332
333 /// Traverses through ConcatOps and basic logic gates to locate the
334 /// original 1-bit source for a specific bit index.
335 ///
336 /// Returns the 1-bit Value or nullptr if the bit cannot be traced back
337 /// to a concrete scalar source
338 Value findBitSource(Value vectorVal, unsigned bitIndex) {
339
340 if (auto blockArg = dyn_cast<BlockArgument>(vectorVal)) {
341 if (blockArg.getType().isInteger(1))
342 return blockArg;
343 return nullptr;
344 }
345
346 Operation *op = vectorVal.getDefiningOp();
347
348 if (op->getNumResults() == 1 && op->getResult(0).getType().isInteger(1)) {
349 return op->getResult(0);
350 }
351
352 if (auto concat = dyn_cast<comb::ConcatOp>(op)) {
353 unsigned currentBit = cast<IntegerType>(vectorVal.getType()).getWidth();
354 for (Value operand : concat.getInputs()) {
355 unsigned operandWidth = cast<IntegerType>(operand.getType()).getWidth();
356 currentBit -= operandWidth;
357 if (bitIndex >= currentBit && bitIndex < currentBit + operandWidth) {
358 return findBitSource(operand, bitIndex - currentBit);
359 }
360 }
361 } else if (auto orOp = dyn_cast<comb::OrOp>(op)) {
362 if (orOp.getNumOperands() != 2)
363 return nullptr;
364
365 Value lhs = orOp.getInputs()[0];
366 Value rhs = orOp.getInputs()[1];
367
368 if (auto constRhs =
369 dyn_cast_or_null<hw::ConstantOp>(rhs.getDefiningOp())) {
370 if (!constRhs.getValue()[bitIndex])
371 return findBitSource(lhs, bitIndex);
372 }
373
374 if (auto constLhs =
375 dyn_cast_or_null<hw::ConstantOp>(lhs.getDefiningOp())) {
376 if (!constLhs.getValue()[bitIndex])
377 return findBitSource(rhs, bitIndex);
378 }
379 } else if (auto andOp = dyn_cast<comb::AndOp>(op)) {
380 if (andOp.getNumOperands() != 2)
381 return nullptr;
382
383 Value lhs = andOp.getInputs()[0];
384 Value rhs = andOp.getInputs()[1];
385
386 if (auto constRhs =
387 dyn_cast_or_null<hw::ConstantOp>(rhs.getDefiningOp())) {
388 if (constRhs.getValue()[bitIndex])
389 return findBitSource(lhs, bitIndex);
390 }
391
392 if (auto constLhs =
393 dyn_cast_or_null<hw::ConstantOp>(lhs.getDefiningOp())) {
394 if (constLhs.getValue()[bitIndex])
395 return findBitSource(rhs, bitIndex);
396 }
397 }
398
399 return nullptr;
400 }
401
402 /// Recursively builds the vectorized counterpart of a scalar subgraph.
403 ///
404 /// `scalarRoot` is the 1-bit root of the scalar bit-0 cone.
405 /// `width` is the target vector width (= number of bit lanes).
406 /// `map` caches already-vectorized scalar values to avoid duplicate work.
407 ///
408 /// Returns nullptr if any node in the subgraph cannot be vectorized.
409 Value vectorizeSubgraph(OpBuilder &b, Value scalarRoot, unsigned width,
410 DenseMap<Value, Value> &map) {
411 if (map.count(scalarRoot))
412 return map[scalarRoot];
413
414 // Base case: an ExtractOp represents one bit lane of a wider source vector.
415 // Return that source vector directly; the other lanes are handled by the
416 // isomorphic slices discovered in canVectorizeStructurally
417 if (auto ex =
418 dyn_cast_or_null<comb::ExtractOp>(scalarRoot.getDefiningOp())) {
419 Value vec = ex.getInput();
420 map[scalarRoot] = vec;
421 return vec;
422 }
423
424 // Base case: a 1-bit constant or block argument (e.g., a shared selector)
425 // must be broadcast to all `width` lanes via comb.replicate.
426 if (isSafeSharedValue(scalarRoot)) {
427 if (cast<IntegerType>(scalarRoot.getType()).getWidth() == 1)
428 return comb::ReplicateOp::create(b, scalarRoot.getLoc(),
429 b.getIntegerType(width), scalarRoot);
430 // Wider constants are already the right width; pass through unchanged.
431 return scalarRoot;
432 }
433
434 Operation *op = scalarRoot.getDefiningOp();
435 if (!op)
436 return nullptr;
437
438 // Recursively vectorize all operands before creating the wide op.
439 SmallVector<Value> ops;
440 for (Value operand : op->getOperands()) {
441 Value v = vectorizeSubgraph(b, operand, width, map);
442 if (!v)
443 return map[scalarRoot] = nullptr;
444 ops.push_back(v);
445 }
446
447 Type vecTy = b.getIntegerType(width);
448 Value result;
449
450 // Lift the scalar op to its N-bit equivalent.
451 if (isa<comb::AndOp>(op))
452 result = comb::AndOp::create(b, op->getLoc(), vecTy, ops);
453 else if (isa<comb::OrOp>(op))
454 result = comb::OrOp::create(b, op->getLoc(), vecTy, ops);
455 else if (isa<comb::XorOp>(op))
456 result = comb::XorOp::create(b, op->getLoc(), vecTy, ops);
457 else if (auto muxOp = dyn_cast<comb::MuxOp>(op)) {
458 Value sel = muxOp.getCond();
459 result = comb::MuxOp::create(b, muxOp.getLoc(), sel, ops[1], ops[2]);
460 } else
461 // Unsupported op kind; signal failure to the caller.
462 return nullptr;
463
464 map[scalarRoot] = result;
465 return result;
466 }
467
468 /// Checks that all bit indices are in [0, bitWidth] and form a bijection.
469 /// Guards applyMixVectorization against malformed BitArrays.
470 bool isValidPermutation(const BitArray &arr, unsigned bitWidth) {
471 if (arr.size() != bitWidth)
472 return false;
473 llvm::SmallBitVector seen(bitWidth);
474 for (const auto &bit : arr.bits) {
475 assert(bit.index >= 0);
476 if (bit.index >= static_cast<int>(bitWidth) || seen.test(bit.index))
477 return false;
478 seen.set(bit.index);
479 }
480 return true;
481 }
482
483 /// Handles arbitrary permutations from a single source by grouping runs of
484 /// consecutive source-bit indices into ExtractOps, then concatenating them.
485 ///
486 /// Example: bits = [2, 3, 0, 1] produces:
487 /// %0 = extract src[2:1] // bits 0-1 of output <- source[2:3]
488 /// %1 = extract src[0:1] // bits 2-3 of output <- source[0:1]
489 /// %out = concat(%1, %0) // MSB->LSB order
490 void applyMixVectorization(IRRewriter &rewriter, Value oldOutputVal,
491 Value sourceInput, const BitArray &arr,
492 unsigned bitWidth) {
493 rewriter.setInsertionPointAfterValue(sourceInput);
494 Location loc = sourceInput.getLoc();
495
496 // Walk output bits LSB->MSB, greedily extending each run while source
497 // indices remain consecutive.
498 llvm::SmallVector<Value> chunks;
499 unsigned i = 0;
500 while (i < bitWidth) {
501 unsigned startBit = arr.bits[i].index;
502 unsigned len = 1;
503 while (i + len < bitWidth &&
504 arr.bits[i + len].index == static_cast<int>(startBit + len))
505 ++len;
506
507 chunks.push_back(comb::ExtractOp::create(
508 rewriter, loc, rewriter.getIntegerType(len), sourceInput, startBit));
509 i += len;
510 }
511
512 // comb.concat expects operands MSB->LSB, so reverse the chunk list.
513 std::reverse(chunks.begin(), chunks.end());
514
515 Value newVal = comb::ConcatOp::create(
516 rewriter, loc, rewriter.getIntegerType(bitWidth), chunks);
517
518 oldOutputVal.replaceAllUsesWith(newVal);
519 }
520
521 /// Single walk that handles ExtractOp and ConcatOp using TypeSwitch.
522 void processOps() {
523 module.walk([&](Operation *op) {
524 llvm::TypeSwitch<Operation *>(op)
525 .Case<comb::ExtractOp>([&](comb::ExtractOp extractOp) {
526 // Only handle single-bit extracts; skip multi-bit ranges.
527 auto resultType =
528 dyn_cast<IntegerType>(extractOp.getResult().getType());
529 if (!resultType || resultType.getWidth() != 1)
530 return;
531
532 BitArray bits;
533 bits.bits.push_back(
534 Bit(extractOp.getInput(), extractOp.getLowBit()));
535 bitArrays[extractOp.getResult()] = bits;
536 })
537 .Case<comb::ConcatOp>([&](comb::ConcatOp concatOp) {
538 auto resultType =
539 dyn_cast<IntegerType>(concatOp.getResult().getType());
540 if (!resultType)
541 return;
542
543 unsigned totalWidth = resultType.getWidth();
544 BitArray concatenatedArray;
545 concatenatedArray.bits.resize(totalWidth);
546
547 unsigned currentBitOffset = 0;
548 for (Value operand : llvm::reverse(concatOp.getInputs())) {
549 unsigned operandWidth =
550 cast<IntegerType>(operand.getType()).getWidth();
551 auto it = bitArrays.find(operand);
552 if (it != bitArrays.end()) {
553 for (unsigned i = 0; i < it->second.bits.size(); ++i)
554 concatenatedArray.bits[i + currentBitOffset] =
555 it->second.bits[i];
556 }
557 currentBitOffset += operandWidth;
558 }
559 bitArrays[concatOp.getResult()] = concatenatedArray;
560 })
561 .Case<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp>(
562 [&](Operation *op) {
563 auto result = op->getResult(0);
564 auto resultType = dyn_cast<IntegerType>(result.getType());
565 if (resultType && resultType.getWidth() == 1) {
566 BitArray arr;
567 arr.bits.push_back(Bit(result, 0));
568 bitArrays[result] = arr;
569 }
570 });
571 });
572 }
573};
574
575struct HWVectorizationPass
576 : public hw::impl::HWVectorizationBase<HWVectorizationPass> {
577
578 void runOnOperation() override {
579 Vectorizer v(getOperation());
580 v.vectorize();
581 }
582};
583
584} // namespace
assert(baseType &&"element must be base type")
create(low_bit, result_type, input=None)
Definition comb.py:187
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
Definition hw.py:1