21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/RegionUtils.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
31#define GEN_PASS_DEF_HWVECTORIZATION
32#include "circt/Dialect/HW/Passes.h.inc"
/// Constructs a reference to bit `index` of the value `source`.
48 Bit(Value source,
int index) : source(source), index(index) {}
/// Default constructor: null source and index 0. Default-constructed entries
/// appear e.g. when a BitArray's `bits` vector is resized.
49 Bit() : source(nullptr), index(0) {}
/// Per-bit origin table: bits[i] records which (source value, bit index)
/// produces bit i of the tracked value, with bits[0] being the LSB.
57 llvm::SmallVector<Bit> bits;
/// Checks whether the array has exactly `size` entries and entry i is bit i of
/// `sourceInput`, i.e. the tracked value is `sourceInput` passed through
/// unchanged.
// NOTE(review): this function's return statements are missing from this
// listing (gaps in the original line numbering); confirm against the full file.
60 bool isLinear(
int size, Value sourceInput)
const {
61 if (bits.size() !=
static_cast<size_t>(size))
63 for (
int i = 0; i < size; ++i) {
64 if (bits[i].source != sourceInput || bits[i].index != i)
/// Checks whether the array has exactly `size` entries and entry i is bit
/// (size - 1 - i) of `sourceInput`, i.e. the tracked value is the bit-reversal
/// of `sourceInput`.
// NOTE(review): this function's return statements are missing from this
// listing (gaps in the original line numbering); confirm against the full file.
71 bool isReverse(
int size, Value sourceInput)
const {
72 if (bits.size() !=
static_cast<size_t>(size))
74 for (
int i = 0; i < size; ++i) {
75 if (bits[i].source != sourceInput || (size - 1) - i != bits[i].index)
/// Scans all bits for one common source value, starting from a null `source`.
// NOTE(review): most of the loop body and the return are missing from this
// listing; presumably the result is null when two bits come from different
// sources (see the `source != bit.source` comparison) -- confirm.
82 Value getSingleSourceValue()
const {
83 Value source =
nullptr;
84 for (
const auto &bit : bits) {
89 else if (source != bit.source)
/// Number of tracked bits.
95 size_t size()
const {
return bits.size(); }
109 dyn_cast<hw::OutputOp>(module.getBodyBlock()->getTerminator());
113 IRRewriter rewriter(module.getContext());
114 bool changed =
false;
122 for (Value oldOutputVal : outputOp->getOperands()) {
123 auto type = dyn_cast<IntegerType>(oldOutputVal.getType());
127 unsigned bitWidth = type.getWidth();
128 auto it = bitArrays.find(oldOutputVal);
129 if (it == bitArrays.end())
132 BitArray &
arr = it->second;
133 if (
arr.size() != bitWidth)
136 Value sourceInput =
arr.getSingleSourceValue();
141 bool transformed =
false;
145 if (
arr.isLinear(bitWidth, sourceInput)) {
146 oldOutputVal.replaceAllUsesWith(sourceInput);
148 }
else if (
arr.isReverse(bitWidth, sourceInput)) {
149 rewriter.setInsertionPointAfterValue(sourceInput);
151 comb::ReverseOp::create(rewriter, sourceInput.getLoc(),
152 sourceInput.getType(), sourceInput);
153 oldOutputVal.replaceAllUsesWith(reversed);
155 }
else if (isValidPermutation(arr, bitWidth)) {
156 applyMixVectorization(rewriter, oldOutputVal, sourceInput, arr,
164 if (!transformed && !hasCrossBitDependencies(oldOutputVal) &&
165 canVectorizeStructurally(oldOutputVal)) {
166 rewriter.setInsertionPointAfterValue(oldOutputVal);
168 unsigned width = cast<IntegerType>(oldOutputVal.getType()).getWidth();
169 Value slice0 = findBitSource(oldOutputVal, 0);
171 DenseMap<Value, Value> vectorizedMap;
172 Value vec = vectorizeSubgraph(rewriter, slice0, width, vectorizedMap);
174 oldOutputVal.replaceAllUsesWith(vec);
185 (void)mlir::runRegionDCE(rewriter, module.getBody());
/// Per-bit origin information computed for each analyzed value (filled for
/// 1-bit extracts, concats, and 1-bit and/or/xor/mux results).
190 llvm::DenseMap<Value, BitArray> bitArrays;
/// For each output bit i, walks the operand graph backwards from
/// findBitSource(outputVal, i).
/// - Values accepted by isSafeSharedValue are (presumably) skipped.
/// - `visitedLocal` deduplicates values within one bit's walk.
/// - `visitedUnsafe` is shared across all bits, so a failed insert there means
///   the value was already reached from a different bit.
// NOTE(review): the branch bodies and the final return are missing from this
// listing; presumably a failed `visitedUnsafe` insert makes the function
// return true -- confirm against the full file.
195 bool hasCrossBitDependencies(mlir::Value outputVal) {
196 unsigned bitWidth = cast<IntegerType>(outputVal.getType()).getWidth();
198 llvm::DenseSet<mlir::Value> visitedUnsafe;
199 llvm::SmallVector<mlir::Value> worklist;
201 for (
unsigned i = 0; i < bitWidth; ++i) {
202 llvm::DenseSet<mlir::Value> visitedLocal;
203 mlir::Value bitSource = findBitSource(outputVal, i);
207 worklist.push_back(bitSource);
208 while (!worklist.empty()) {
209 auto top = worklist.pop_back_val();
210 if (isSafeSharedValue(top))
212 if (!visitedLocal.insert(top).second)
215 if (!visitedUnsafe.insert(top).second)
217 if (
auto *op = top.getDefiningOp()) {
218 for (
auto operand : op->getOperands())
219 worklist.push_back(operand);
/// Decides whether `val` may be shared between bit slices without counting as
/// a cross-bit dependency. Callers replicate 1-bit safe values across the
/// vector width.
// NOTE(review): the body of this function is missing from this listing.
230 bool isSafeSharedValue(mlir::Value val) {
/// Checks that every output bit i >= 1 is produced by an operation tree
/// equivalent to bit 0's (areSubgraphsEquivalent). Each comb.extract leaf for
/// bit i must read bit (lowBit0 + i * stride) of the same input.
/// - When bits 0 and 1 are both extracts of one input, the stride is their
///   low-bit difference.
/// - Otherwise a stride of 1 is used.
// NOTE(review): several lines (including the extract0/extract1 casts and the
// returns) are missing from this listing. findBitSource(output, 1) is queried
// unconditionally on the visible lines; confirm that a bit width below 2 is
// rejected before that call.
240 bool canVectorizeStructurally(Value output) {
241 unsigned bitWidth = cast<IntegerType>(output.getType()).getWidth();
245 Value slice0Val = findBitSource(output, 0);
249 Value slice1Val = findBitSource(output, 1);
256 if (!extract0 || !extract1 || extract0.getInput() != extract1.getInput()) {
257 for (
unsigned i = 1; i < bitWidth; ++i) {
258 Value sliceNVal = findBitSource(output, i);
259 if (!sliceNVal || !sliceNVal.getDefiningOp())
261 llvm::DenseMap<mlir::Value, mlir::Value> map;
262 if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, 1, map)) {
269 int stride = (int)extract1.getLowBit() - (int)extract0.getLowBit();
271 for (
unsigned i = 1; i < bitWidth; ++i) {
272 Value sliceNVal = findBitSource(output, i);
273 if (!sliceNVal || !sliceNVal.getDefiningOp())
276 llvm::DenseMap<mlir::Value, mlir::Value> map;
277 if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, stride, map)) {
/// Structurally matches the operation tree rooted at `slice0Val` (bit 0's
/// slice) against the tree rooted at `sliceNVal` (the slice for bit
/// `sliceIndex`). Matching rules visible here:
/// - a comb.extract leaf matches an extract of the same input whose low bit is
///   lowBit0 + sliceIndex * stride;
/// - an identical block argument or hw.constant matches itself;
/// - otherwise both ops must have the same name and operand count, and every
///   operand pair must match recursively.
/// `slice0ToNMap` memoizes matches, so a bit-0 value reused within the tree
/// must always map to the same slice-N value.
// NOTE(review): only the op name and operand count are compared, not
// attributes or result types, so e.g. two comparisons with different
// predicates would match. Confirm that callers restrict the accepted op set.
// The failure-return lines are missing from this listing.
290 bool areSubgraphsEquivalent(Value slice0Val, Value sliceNVal,
291 unsigned sliceIndex,
int stride,
292 DenseMap<Value, Value> &slice0ToNMap) {
294 if (slice0ToNMap.count(slice0Val))
295 return slice0ToNMap[slice0Val] == sliceNVal;
297 Operation *op0 = slice0Val.getDefiningOp();
298 Operation *opN = sliceNVal.getDefiningOp();
300 if (
auto extract0 = dyn_cast_or_null<comb::ExtractOp>(op0)) {
301 auto extractN = dyn_cast_or_null<comb::ExtractOp>(opN);
303 if (extractN && extract0.getInput() == extractN.getInput() &&
304 extractN.getLowBit() ==
305 static_cast<unsigned>(
static_cast<int>(extract0.getLowBit()) +
306 static_cast<int>(sliceIndex) * stride)) {
307 slice0ToNMap[slice0Val] = sliceNVal;
313 if (slice0Val == sliceNVal && (mlir::isa<BlockArgument>(slice0Val) ||
314 mlir::isa<hw::ConstantOp>(op0))) {
315 slice0ToNMap[slice0Val] = sliceNVal;
319 if (!op0 || !opN || op0->getName() != opN->getName() ||
320 op0->getNumOperands() != opN->getNumOperands())
323 for (
unsigned i = 0; i < op0->getNumOperands(); ++i) {
324 if (!areSubgraphsEquivalent(op0->getOperand(i), opN->getOperand(i),
325 sliceIndex, stride, slice0ToNMap))
329 slice0ToNMap[slice0Val] = sliceNVal;
/// Traces bit `bitIndex` of `vectorVal` back to the 1-bit value that produces
/// it. Cases visible here:
/// - a 1-bit block argument, or a single-result op with a 1-bit result, is its
///   own source;
/// - comb.concat lays out operands MSB-first, so the operand covering
///   `bitIndex` is located and the search recurses into it with a rebased
///   index;
/// - a two-operand comb.or whose constant operand has a 0 at `bitIndex`, or a
///   two-operand comb.and whose constant operand has a 1 there, passes the
///   other operand's bit through, so the search recurses into that operand.
// NOTE(review): the fall-through returns and guard bodies are missing from
// this listing; presumably unmatched cases return a null Value. Also confirm
// that a multi-bit block argument (null defining op) is handled before `op`
// is dereferenced.
338 Value findBitSource(Value vectorVal,
unsigned bitIndex) {
340 if (
auto blockArg = dyn_cast<BlockArgument>(vectorVal)) {
341 if (blockArg.getType().isInteger(1))
346 Operation *op = vectorVal.getDefiningOp();
348 if (op->getNumResults() == 1 && op->getResult(0).getType().isInteger(1)) {
349 return op->getResult(0);
352 if (
auto concat = dyn_cast<comb::ConcatOp>(op)) {
353 unsigned currentBit = cast<IntegerType>(vectorVal.getType()).getWidth();
354 for (Value operand : concat.getInputs()) {
355 unsigned operandWidth = cast<IntegerType>(operand.getType()).getWidth();
356 currentBit -= operandWidth;
357 if (bitIndex >= currentBit && bitIndex < currentBit + operandWidth) {
358 return findBitSource(operand, bitIndex - currentBit);
361 }
else if (
auto orOp = dyn_cast<comb::OrOp>(op)) {
362 if (orOp.getNumOperands() != 2)
365 Value lhs = orOp.getInputs()[0];
366 Value rhs = orOp.getInputs()[1];
369 dyn_cast_or_null<hw::ConstantOp>(rhs.getDefiningOp())) {
370 if (!constRhs.getValue()[bitIndex])
371 return findBitSource(lhs, bitIndex);
375 dyn_cast_or_null<hw::ConstantOp>(lhs.getDefiningOp())) {
376 if (!constLhs.getValue()[bitIndex])
377 return findBitSource(rhs, bitIndex);
379 }
else if (
auto andOp = dyn_cast<comb::AndOp>(op)) {
380 if (andOp.getNumOperands() != 2)
383 Value lhs = andOp.getInputs()[0];
384 Value rhs = andOp.getInputs()[1];
387 dyn_cast_or_null<hw::ConstantOp>(rhs.getDefiningOp())) {
388 if (constRhs.getValue()[bitIndex])
389 return findBitSource(lhs, bitIndex);
393 dyn_cast_or_null<hw::ConstantOp>(lhs.getDefiningOp())) {
394 if (constLhs.getValue()[bitIndex])
395 return findBitSource(rhs, bitIndex);
/// Rebuilds bit 0's operation tree rooted at `scalarRoot` as a `width`-bit
/// operation, memoizing every result in `map`.
/// - A comb.extract leaf is replaced by its entire input vector.
/// - A 1-bit value accepted by isSafeSharedValue is broadcast with
///   comb.replicate.
/// - and/or/xor are recreated over the vectorized operands.
/// - A mux is recreated with the original scalar select and vectorized data
///   operands.
/// If an operand fails to vectorize, null is memoized and returned.
// NOTE(review): the extract case maps a leaf to the whole input without (on
// the visible lines) checking that its low bit is 0, the stride is 1, or the
// input width equals `width`. canVectorizeStructurally accepts other strides,
// so confirm a guard exists on the missing lines.
// NOTE(review): the mux case reuses bit 0's select for every lane. That is
// only correct when all bits' muxes share the same select; a select that is
// itself a per-bit extract would pass areSubgraphsEquivalent but be
// mis-vectorized here.
409 Value vectorizeSubgraph(OpBuilder &b, Value scalarRoot,
unsigned width,
410 DenseMap<Value, Value> &map) {
411 if (map.count(scalarRoot))
412 return map[scalarRoot];
418 dyn_cast_or_null<comb::ExtractOp>(scalarRoot.getDefiningOp())) {
419 Value vec = ex.getInput();
420 map[scalarRoot] = vec;
426 if (isSafeSharedValue(scalarRoot)) {
427 if (cast<IntegerType>(scalarRoot.getType()).getWidth() == 1)
428 return comb::ReplicateOp::create(b, scalarRoot.getLoc(),
429 b.getIntegerType(width), scalarRoot);
434 Operation *op = scalarRoot.getDefiningOp();
439 SmallVector<Value> ops;
440 for (Value operand : op->getOperands()) {
441 Value v = vectorizeSubgraph(b, operand, width, map);
443 return map[scalarRoot] =
nullptr;
447 Type vecTy =
b.getIntegerType(width);
451 if (isa<comb::AndOp>(op))
452 result = comb::AndOp::create(b, op->getLoc(), vecTy, ops);
453 else if (isa<comb::OrOp>(op))
454 result = comb::OrOp::create(b, op->getLoc(), vecTy, ops);
455 else if (isa<comb::XorOp>(op))
456 result = comb::XorOp::create(b, op->getLoc(), vecTy, ops);
457 else if (
auto muxOp = dyn_cast<comb::MuxOp>(op)) {
458 Value sel = muxOp.getCond();
459 result = comb::MuxOp::create(b, muxOp.getLoc(), sel, ops[1], ops[2]);
464 map[scalarRoot] = result;
/// Checks that `arr` has exactly `bitWidth` entries whose indices are all
/// below `bitWidth` and pairwise distinct (tracked in `seen`).
// NOTE(review): the visible lines neither reject a negative bit.index before
// seen.test() nor compare bit.source against a common source. The missing
// lines may handle both; confirm against the full file.
470 bool isValidPermutation(
const BitArray &arr,
unsigned bitWidth) {
471 if (
arr.size() != bitWidth)
473 llvm::SmallBitVector seen(bitWidth);
474 for (
const auto &bit :
arr.bits) {
476 if (bit.index >=
static_cast<int>(bitWidth) || seen.test(bit.index))
/// Rewrites `oldOutputVal` as a concatenation of contiguous slices of
/// `sourceInput`, inserted right after `sourceInput`.
/// 1. Starting at output bit i, it grows a run while output bit i + len reads
///    source bit startBit + len.
/// 2. It creates one `len`-bit slice of `sourceInput` at `startBit` per run.
/// 3. It reverses the chunk list: chunks are collected LSB-first, while
///    comb.concat takes operands MSB-first.
/// 4. It replaces all uses of `oldOutputVal` with the concat.
// NOTE(review): part of the signature (the bitWidth parameter) and several
// loop lines, including the start of the slice-creating call, are missing
// from this listing.
490 void applyMixVectorization(IRRewriter &rewriter, Value oldOutputVal,
491 Value sourceInput,
const BitArray &arr,
493 rewriter.setInsertionPointAfterValue(sourceInput);
494 Location loc = sourceInput.getLoc();
498 llvm::SmallVector<Value> chunks;
500 while (i < bitWidth) {
501 unsigned startBit =
arr.bits[i].index;
503 while (i + len < bitWidth &&
504 arr.bits[i + len].index ==
static_cast<int>(startBit + len))
508 rewriter, loc, rewriter.getIntegerType(len), sourceInput, startBit));
513 std::reverse(chunks.begin(), chunks.end());
515 Value newVal = comb::ConcatOp::create(
516 rewriter, loc, rewriter.getIntegerType(bitWidth), chunks);
518 oldOutputVal.replaceAllUsesWith(newVal);
523 module.walk([&](Operation *op) {
524 llvm::TypeSwitch<Operation *>(op)
525 .Case<comb::ExtractOp>([&](comb::ExtractOp extractOp) {
528 dyn_cast<IntegerType>(extractOp.getResult().getType());
529 if (!resultType || resultType.getWidth() != 1)
534 Bit(extractOp.getInput(), extractOp.getLowBit()));
535 bitArrays[extractOp.getResult()] = bits;
539 dyn_cast<IntegerType>(concatOp.getResult().getType());
543 unsigned totalWidth = resultType.getWidth();
544 BitArray concatenatedArray;
545 concatenatedArray.bits.resize(totalWidth);
547 unsigned currentBitOffset = 0;
548 for (Value operand :
llvm::reverse(concatOp.getInputs())) {
549 unsigned operandWidth =
550 cast<IntegerType>(operand.getType()).getWidth();
551 auto it = bitArrays.find(operand);
552 if (it != bitArrays.end()) {
553 for (
unsigned i = 0; i < it->second.bits.size(); ++i)
554 concatenatedArray.bits[i + currentBitOffset] =
557 currentBitOffset += operandWidth;
559 bitArrays[concatOp.getResult()] = concatenatedArray;
561 .Case<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp>(
563 auto result = op->getResult(0);
564 auto resultType = dyn_cast<IntegerType>(result.getType());
565 if (resultType && resultType.getWidth() == 1) {
567 arr.bits.push_back(Bit(result, 0));
568 bitArrays[result] =
arr;
575struct HWVectorizationPass
576 :
public hw::impl::HWVectorizationBase<HWVectorizationPass> {
578 void runOnOperation()
override {
579 Vectorizer v(getOperation());
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.