12#include "mlir/IR/AttrTypeSubElements.h"
13#include "mlir/IR/Iterators.h"
14#include "mlir/Pass/Pass.h"
16#define DEBUG_TYPE "hw-convert-bitcasts"
20#define GEN_PASS_DEF_HWCONVERTBITCASTS
21#include "circt/Dialect/HW/Passes.h.inc"
30struct HWConvertBitcastsPass
31 : circt::hw::impl::HWConvertBitcastsBase<HWConvertBitcastsPass> {
32 using circt::hw::impl::HWConvertBitcastsBase<
33 HWConvertBitcastsPass>::HWConvertBitcastsBase;
35 void runOnOperation()
override;
36 static bool isTypeSupported(Type ty);
37 LogicalResult convertBitcastOp(OpBuilder builder,
BitcastOp bitcastOp);
45bool HWConvertBitcastsPass::isTypeSupported(Type ty) {
46 if (isa<IntegerType>(ty))
48 if (
auto arrayTy = hw::type_dyn_cast<hw::ArrayType>(ty))
49 return isTypeSupported(arrayTy.getElementType());
50 if (
auto structTy = hw::type_dyn_cast<hw::StructType>(ty))
51 return llvm::all_of(structTy.getElements(),
52 [](StructType::FieldInfo field) {
53 return isTypeSupported(field.type);
63 SmallVectorImpl<Value> &accumulator) {
65 if (isa<IntegerType>(inputVal.getType())) {
66 accumulator.push_back(inputVal);
71 assert(numBits >= 0 &&
"Bitwidth of input must be known");
74 if (
auto arrayTy = dyn_cast<ArrayType>(inputVal.getType())) {
76 auto indexType = builder.getIntegerType(llvm::Log2_64_Ceil(
numElements));
80 builder, loc, builder.getIntegerAttr(indexType,
numElements - i - 1));
88 if (
auto structTy = dyn_cast<StructType>(inputVal.getType())) {
90 auto explodeOp = StructExplodeOp::create(builder, loc, inputVal);
91 for (
auto elt : explodeOp.getResults())
96 assert(
false &&
"Unsupported type");
102 Value rawInteger, Type targetType) {
104 assert(numBits >= 0 &&
"Bitwidth of target must be known");
105 assert(numBits == rawInteger.getType().getIntOrFloatBitWidth());
108 if (isa<IntegerType>(targetType))
111 SmallVector<Value> elements;
114 if (
auto arrayTy = type_dyn_cast<ArrayType>(targetType)) {
116 auto sliceWidth =
getBitWidth(arrayTy.getElementType());
118 auto sliceTy = builder.getIntegerType(sliceWidth);
126 builder.getIntegerAttr(sliceTy, 0));
131 arrayTy.getElementType());
132 elements.push_back(elt);
140 if (
auto structTy = type_dyn_cast<StructType>(targetType)) {
142 unsigned consumedBits = 0;
145 auto eltBits =
getBitWidth(structTy.getElements()[i].type);
151 builder, loc, builder.getIntegerAttr(builder.getIntegerType(0), 0));
154 builder, loc, builder.getIntegerType(eltBits), rawInteger,
155 numBits - consumedBits - eltBits);
157 structTy.getElements()[i].type);
158 elements.push_back(elt);
159 consumedBits += eltBits;
161 assert(consumedBits == numBits);
167 assert(
false &&
"Unsupported type");
171LogicalResult HWConvertBitcastsPass::convertBitcastOp(OpBuilder builder,
173 bool inputSupported = isTypeSupported(bitcastOp.getInput().getType());
174 bool outputSupported = isTypeSupported(bitcastOp.getType());
175 if (!allowPartialConversion) {
177 bitcastOp.emitOpError(
"has unsupported input type");
178 if (!outputSupported)
179 bitcastOp.emitOpError(
"has unsupported output type");
181 if (!(inputSupported && outputSupported))
184 builder.setInsertionPoint(bitcastOp);
187 SmallVector<Value> integers;
191 if (integers.size() == 1)
192 concat = integers.front();
194 concat = comb::ConcatOp::create(builder, bitcastOp.getLoc(), integers)
199 concat, bitcastOp.getType());
202 bitcastOp.getResult().replaceAllUsesWith(result);
207void HWConvertBitcastsPass::runOnOperation() {
208 OpBuilder builder(getOperation());
209 bool anyFailed =
false;
210 bool anyChanged =
false;
211 getOperation().getBody()->walk<WalkOrder::PostOrder, ReverseIterator>(
214 if (failed(convertBitcastOp(builder, bitcastOp)))
219 markAllAnalysesPreserved();
223 if (anyFailed && !allowPartialConversion)
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static void collectIntegersRecursively(OpBuilder builder, Location loc, Value inputVal, SmallVectorImpl< Value > &accumulator)
static Value constructAggregateRecursively(OpBuilder builder, Location loc, Value rawInteger, Type targetType)
create(elements, Type result_type=None)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.