CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWConvertBitcasts.cpp
Go to the documentation of this file.
1//===- ConvertBitcasts.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/AttrTypeSubElements.h"
13#include "mlir/IR/Iterators.h"
14#include "mlir/Pass/Pass.h"
15
16#define DEBUG_TYPE "hw-convert-bitcasts"
17
18namespace circt {
19namespace hw {
20#define GEN_PASS_DEF_HWCONVERTBITCASTS
21#include "circt/Dialect/HW/Passes.h.inc"
22} // namespace hw
23} // namespace circt
24
25using namespace mlir;
26using namespace circt;
27using namespace hw;
28
29namespace {
30struct HWConvertBitcastsPass
31 : circt::hw::impl::HWConvertBitcastsBase<HWConvertBitcastsPass> {
32 using circt::hw::impl::HWConvertBitcastsBase<
33 HWConvertBitcastsPass>::HWConvertBitcastsBase;
34
35 void runOnOperation() override;
36 static bool isTypeSupported(Type ty);
37 LogicalResult convertBitcastOp(OpBuilder builder, BitcastOp bitcastOp);
38};
39} // namespace
40
41// Array conversion: Lower bits correspond to lower array index
42// Struct conversion: Higher bits correspond to lower field index
43
44// NOLINTNEXTLINE(misc-no-recursion)
45bool HWConvertBitcastsPass::isTypeSupported(Type ty) {
46 if (isa<IntegerType>(ty))
47 return true;
48 if (auto arrayTy = hw::type_dyn_cast<hw::ArrayType>(ty))
49 return isTypeSupported(arrayTy.getElementType());
50 if (auto structTy = hw::type_dyn_cast<hw::StructType>(ty))
51 return llvm::all_of(structTy.getElements(),
52 [](StructType::FieldInfo field) {
53 return isTypeSupported(field.type);
54 });
55 // TODO: Add support for: union, packed array, enum
56 return false;
57}
58
59// Collect integer fields of aggregates for concatenation
60// NOLINTNEXTLINE(misc-no-recursion)
61static void collectIntegersRecursively(OpBuilder builder, Location loc,
62 Value inputVal,
63 SmallVectorImpl<Value> &accumulator) {
64 // End of recursion: Integer value
65 if (isa<IntegerType>(inputVal.getType())) {
66 accumulator.push_back(inputVal);
67 return;
68 }
69
70 auto numBits = getBitWidth(inputVal.getType());
71 assert(numBits >= 0 && "Bitwidth of input must be known");
72
73 // Array Type
74 if (auto arrayTy = dyn_cast<ArrayType>(inputVal.getType())) {
75 unsigned numElements = arrayTy.getNumElements();
76 auto indexType = builder.getIntegerType(llvm::Log2_64_Ceil(numElements));
77 for (unsigned i = 0; i < numElements; ++i) {
78 // Process from high to low array index
79 auto indexCst = ConstantOp::create(
80 builder, loc, builder.getIntegerAttr(indexType, numElements - i - 1));
81 auto getOp = ArrayGetOp::create(builder, loc, inputVal, indexCst);
82 collectIntegersRecursively(builder, loc, getOp.getResult(), accumulator);
83 }
84 return;
85 }
86
87 // Struct Type
88 if (auto structTy = dyn_cast<StructType>(inputVal.getType())) {
89 // Process from low to high field index
90 auto explodeOp = StructExplodeOp::create(builder, loc, inputVal);
91 for (auto elt : explodeOp.getResults())
92 collectIntegersRecursively(builder, loc, elt, accumulator);
93 return;
94 }
95
96 assert(false && "Unsupported type");
97}
98
99// Convert integer to aggregate
100// NOLINTNEXTLINE(misc-no-recursion)
101static Value constructAggregateRecursively(OpBuilder builder, Location loc,
102 Value rawInteger, Type targetType) {
103 auto numBits = getBitWidth(targetType);
104 assert(numBits >= 0 && "Bitwidth of target must be known");
105 assert(numBits == rawInteger.getType().getIntOrFloatBitWidth());
106
107 // End of recursion: Integer value
108 if (isa<IntegerType>(targetType))
109 return rawInteger;
110
111 SmallVector<Value> elements;
112
113 // Array Type
114 if (auto arrayTy = type_dyn_cast<ArrayType>(targetType)) {
115 auto numElements = arrayTy.getNumElements();
116 auto sliceWidth = getBitWidth(arrayTy.getElementType());
117 assert(sliceWidth >= 0);
118 auto sliceTy = builder.getIntegerType(sliceWidth);
119 elements.reserve(numElements);
120 for (unsigned i = 0; i < numElements; ++i) {
121 // Process bits from MSB to LSB
122 auto offset = sliceWidth * (numElements - i - 1);
123 Value slice;
124 if (sliceWidth == 0)
125 slice = ConstantOp::create(builder, loc,
126 builder.getIntegerAttr(sliceTy, 0));
127 else
128 slice =
129 comb::ExtractOp::create(builder, loc, sliceTy, rawInteger, offset);
130 auto elt = constructAggregateRecursively(builder, loc, slice,
131 arrayTy.getElementType());
132 elements.push_back(elt);
133 }
134 // Array create reverses order:
135 // Higher bits are added first and become high indices.
136 return ArrayCreateOp::create(builder, loc, targetType, elements);
137 }
138
139 // Struct Type
140 if (auto structTy = type_dyn_cast<StructType>(targetType)) {
141 auto numElements = structTy.getElements().size();
142 unsigned consumedBits = 0;
143 for (unsigned i = 0; i < numElements; ++i) {
144
145 auto eltBits = getBitWidth(structTy.getElements()[i].type);
146 assert(eltBits >= 0);
147 Value slice;
148 // Process bits from MSB to LSB
149 if (eltBits == 0)
150 slice = ConstantOp::create(
151 builder, loc, builder.getIntegerAttr(builder.getIntegerType(0), 0));
152 else
154 builder, loc, builder.getIntegerType(eltBits), rawInteger,
155 numBits - consumedBits - eltBits);
156 auto elt = constructAggregateRecursively(builder, loc, slice,
157 structTy.getElements()[i].type);
158 elements.push_back(elt);
159 consumedBits += eltBits;
160 }
161 assert(consumedBits == numBits);
162 // Struct create does not reverse order:
163 // Higher bits are added first and become low indices.
164 return StructCreateOp::create(builder, loc, targetType, elements);
165 }
166
167 assert(false && "Unsupported type");
168 return {};
169}
170
171LogicalResult HWConvertBitcastsPass::convertBitcastOp(OpBuilder builder,
172 BitcastOp bitcastOp) {
173 bool inputSupported = isTypeSupported(bitcastOp.getInput().getType());
174 bool outputSupported = isTypeSupported(bitcastOp.getType());
175 if (!allowPartialConversion) {
176 if (!inputSupported)
177 bitcastOp.emitOpError("has unsupported input type");
178 if (!outputSupported)
179 bitcastOp.emitOpError("has unsupported output type");
180 }
181 if (!(inputSupported && outputSupported))
182 return failure();
183
184 builder.setInsertionPoint(bitcastOp);
185
186 // Convert input value to a packed integer
187 SmallVector<Value> integers;
188 collectIntegersRecursively(builder, bitcastOp.getLoc(), bitcastOp.getInput(),
189 integers);
190 Value concat;
191 if (integers.size() == 1)
192 concat = integers.front();
193 else
194 concat = comb::ConcatOp::create(builder, bitcastOp.getLoc(), integers)
195 .getResult();
196
197 // Convert packed integer to the target type
198 auto result = constructAggregateRecursively(builder, bitcastOp.getLoc(),
199 concat, bitcastOp.getType());
200
201 // Replace operation
202 bitcastOp.getResult().replaceAllUsesWith(result);
203 bitcastOp.erase();
204 return success();
205}
206
207void HWConvertBitcastsPass::runOnOperation() {
208 OpBuilder builder(getOperation());
209 bool anyFailed = false;
210 bool anyChanged = false;
211 getOperation().getBody()->walk<WalkOrder::PostOrder, ReverseIterator>(
212 [&](BitcastOp bitcastOp) {
213 anyChanged = true;
214 if (failed(convertBitcastOp(builder, bitcastOp)))
215 anyFailed = true;
216 });
217
218 if (!anyChanged) {
219 markAllAnalysesPreserved();
220 return;
221 }
222
223 if (anyFailed && !allowPartialConversion)
224 signalPassFailure();
225}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static void collectIntegersRecursively(OpBuilder builder, Location loc, Value inputVal, SmallVectorImpl< Value > &accumulator)
static Value constructAggregateRecursively(OpBuilder builder, Location loc, Value rawInteger, Type targetType)
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
create(data_type, value)
Definition hw.py:433
create(elements, Type result_type=None)
Definition hw.py:532
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1