21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/RegionUtils.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
31#define GEN_PASS_DEF_HWVECTORIZATION
32#include "circt/Dialect/HW/Passes.h.inc"
/// Builds a bit reference from the value it originates from (`source`) and
/// the bit position inside that value (`index`). The extract-op analysis in
/// this file constructs these from an extract's input and its low bit.
48 Bit(Value source,
int index) : source(source), index(index) {}
/// Default-constructs a placeholder entry with a null source and index 0.
/// Needed so that containers of Bit can be resized before being filled in.
49 Bit() : source(nullptr), index(0) {}
/// Per-position provenance: entry `i` records which source value and which
/// bit index of that source feeds position `i` of the tracked value. The
/// linear/reverse checks below compare `bits[i].index` against `i`.
57 llvm::SmallVector<Bit> bits;
60 bool isLinear(
int size, Value sourceInput)
const {
61 if (bits.size() !=
static_cast<size_t>(size))
63 for (
int i = 0; i < size; ++i) {
64 if (bits[i].source != sourceInput || bits[i].index != i)
71 bool isReverse(
int size, Value sourceInput)
const {
72 if (bits.size() !=
static_cast<size_t>(size))
74 for (
int i = 0; i < size; ++i) {
75 if (bits[i].source != sourceInput || (size - 1) - i != bits[i].index)
82 Value getSingleSourceValue()
const {
83 Value source =
nullptr;
84 for (
const auto &bit : bits) {
89 else if (source != bit.source)
/// Returns the number of bit entries currently tracked in `bits`.
95 size_t size()
const {
return bits.size(); }
107 dyn_cast<hw::OutputOp>(module.getBodyBlock()->getTerminator());
111 IRRewriter rewriter(module.getContext());
112 bool changed =
false;
114 for (Value oldOutputVal : outputOp->getOperands()) {
115 auto type = dyn_cast<IntegerType>(oldOutputVal.getType());
119 unsigned bitWidth = type.getWidth();
120 auto it = bitArrays.find(oldOutputVal);
121 if (it == bitArrays.end())
124 BitArray &
arr = it->second;
125 if (
arr.size() != bitWidth)
128 Value sourceInput =
arr.getSingleSourceValue();
132 if (
arr.isLinear(bitWidth, sourceInput)) {
133 oldOutputVal.replaceAllUsesWith(sourceInput);
135 }
else if (
arr.isReverse(bitWidth, sourceInput)) {
136 rewriter.setInsertionPointAfterValue(sourceInput);
137 Value reversed = comb::ReverseOp::create(
138 rewriter, sourceInput.getLoc(), sourceInput.getType(), sourceInput);
139 oldOutputVal.replaceAllUsesWith(reversed);
141 }
else if (isValidPermutation(arr, bitWidth)) {
142 applyMixVectorization(rewriter, oldOutputVal, sourceInput, arr,
149 (void)mlir::runRegionDCE(rewriter, module.getBody());
/// Maps an SSA value to the BitArray describing where each of its bits comes
/// from. Entries are inserted for extract and concat results during the
/// module walk and looked up when examining the module's output operands.
154 llvm::DenseMap<Value, BitArray> bitArrays;
159 bool isValidPermutation(
const BitArray &arr,
unsigned bitWidth) {
160 if (
arr.size() != bitWidth)
162 llvm::SmallBitVector seen(bitWidth);
163 for (
const auto &bit :
arr.bits) {
165 if (bit.index >=
static_cast<int>(bitWidth) || seen.test(bit.index))
179 void applyMixVectorization(IRRewriter &rewriter, Value oldOutputVal,
180 Value sourceInput,
const BitArray &arr,
182 rewriter.setInsertionPointAfterValue(sourceInput);
183 Location loc = sourceInput.getLoc();
187 llvm::SmallVector<Value> chunks;
189 while (i < bitWidth) {
190 unsigned startBit =
arr.bits[i].index;
192 while (i + len < bitWidth &&
193 arr.bits[i + len].index ==
static_cast<int>(startBit + len))
197 rewriter, loc, rewriter.getIntegerType(len), sourceInput, startBit));
202 std::reverse(chunks.begin(), chunks.end());
204 Value newVal = comb::ConcatOp::create(
205 rewriter, loc, rewriter.getIntegerType(bitWidth), chunks);
207 oldOutputVal.replaceAllUsesWith(newVal);
212 module.walk([&](Operation *op) {
213 llvm::TypeSwitch<Operation *>(op)
214 .Case<comb::ExtractOp>([&](comb::ExtractOp extractOp) {
217 dyn_cast<IntegerType>(extractOp.getResult().getType());
218 if (!resultType || resultType.getWidth() != 1)
223 Bit(extractOp.getInput(), extractOp.getLowBit()));
224 bitArrays.insert({extractOp.getResult(), bits});
228 dyn_cast<IntegerType>(concatOp.getResult().getType());
232 unsigned totalWidth = resultType.getWidth();
233 BitArray concatenatedArray;
234 concatenatedArray.bits.resize(totalWidth);
236 unsigned currentBitOffset = 0;
237 for (Value operand :
llvm::reverse(concatOp.getInputs())) {
238 unsigned operandWidth =
239 cast<IntegerType>(operand.getType()).getWidth();
240 auto it = bitArrays.find(operand);
241 if (it != bitArrays.end()) {
242 for (
unsigned i = 0; i < it->second.bits.size(); ++i)
243 concatenatedArray.bits[i + currentBitOffset] =
246 currentBitOffset += operandWidth;
248 bitArrays.insert({concatOp.getResult(), concatenatedArray});
254struct HWVectorizationPass
255 :
public hw::impl::HWVectorizationBase<HWVectorizationPass> {
257 void runOnOperation()
override {
258 Vectorizer v(getOperation());
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.