13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "arc-lower-lut"
21 #define GEN_PASS_DEF_LOWERLUT
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
/// Evaluates the body of `lut` for every possible input combination by
/// constant folding and stores the resulting output values in `table`.
/// NOTE(review): the definition checks whether the table was already built for
/// this same `lut` and appears to reuse it; the early-return lines are not
/// visible in this excerpt — confirm.
44 LogicalResult computeTableEntries(LutOp lut);
/// Returns a non-owning view of `table`. The view is invalidated whenever
/// `table` is modified (e.g. by a later computeTableEntries call).
48 ArrayRef<IntegerAttr> getRefToTableEntries();
/// Appends the computed table entries to `tableEntries`; existing contents of
/// `tableEntries` are kept.
51 void getCopyOfTableEntries(SmallVector<IntegerAttr> &tableEntries);
/// Appends one SSA value per table entry to `tableEntries`. Equal entries share
/// a single value (deduplicated through a map); the values are presumably
/// created with `builder` — the creation lines are not visible here.
55 void getTableEntriesAsConstValues(OpBuilder &builder,
56 SmallVector<Value> &tableEntries);
/// Total table size in bits: 2^(input bit width) entries times the output
/// bit width.
58 uint32_t getTableSize();
/// Sum of the bit widths of all LUT inputs.
60 uint32_t getInputBitWidth();
/// Output value for each input combination, stored from the highest input
/// combination index down to index 0 (see computeTableEntries).
64 SmallVector<IntegerAttr> table;
/// Base conversion pattern for LutOp. Before dispatching to the LutOp-typed
/// hooks below, it runs the shared LutCalculator on the op, so derived
/// patterns can read the computed table through `lutCalculator`.
70 class LutLoweringPattern :
public ConversionPattern {
// The calculator is held by reference, so it must outlive the pattern.
72 LutLoweringPattern(LutCalculator &lutCalculator, MLIRContext *context,
73 mlir::PatternBenefit benefit = 1)
74 : ConversionPattern(LutOp::getOperationName(), benefit, context),
75 lutCalculator(lutCalculator) {}
// Same as above, but with a type converter for operand conversion.
76 LutLoweringPattern(LutCalculator &lutCalculator, TypeConverter &typeConverter,
77 MLIRContext *context, mlir::PatternBenefit benefit = 1)
78 : ConversionPattern(typeConverter, LutOp::getOperationName(), benefit,
80 lutCalculator(lutCalculator) {}
// Generic match hook: computes the LUT's table and rejects the op when the
// computation fails (the failure-return line is not visible in this excerpt).
84 LogicalResult match(Operation *op)
const final {
85 auto lut = cast<LutOp>(op);
86 if (failed(lutCalculator.computeTableEntries(lut)))
// Generic rewrite hook: forwards to the LutOp-typed rewrite with an adaptor
// built from the converted operands and the op's attributes.
90 void rewrite(Operation *op, ArrayRef<Value> operands,
91 ConversionPatternRewriter &rewriter)
const final {
92 rewrite(cast<LutOp>(op), LutOpAdaptor(operands, op->getAttrDictionary()),
// Generic matchAndRewrite: computes the table first, then forwards to the
// LutOp-typed matchAndRewrite.
96 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
97 ConversionPatternRewriter &rewriter)
const final {
98 auto lut = cast<LutOp>(op);
99 if (failed(lutCalculator.computeTableEntries(lut)))
101 return matchAndRewrite(lut, LutOpAdaptor(operands, op->getAttrDictionary()),
/// LutOp-typed hooks. Derived patterns must override either matchAndRewrite,
/// or both match and rewrite; the defaults below abort via llvm_unreachable.
107 virtual LogicalResult match(LutOp op)
const {
108 llvm_unreachable(
"must override match or matchAndRewrite");
110 virtual void rewrite(LutOp op, LutOpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter)
const {
112 llvm_unreachable(
"must override matchAndRewrite or a rewrite method");
// Default typed matchAndRewrite: run match, then rewrite on success.
114 virtual LogicalResult
115 matchAndRewrite(LutOp op, LutOpAdaptor adaptor,
116 ConversionPatternRewriter &rewriter)
const {
117 if (failed(match(op)))
119 rewrite(op, adaptor, rewriter);
// Shared with every pattern built from the same calculator (see
// runOnOperation), so derived patterns read the table computed during matching.
124 LutCalculator &lutCalculator;
// Keep the base-class matchAndRewrite overloads visible next to the typed one.
127 using ConversionPattern::matchAndRewrite;
/// Builds `table` by evaluating the LUT body once per input combination.
/// bw is the total input bit width. For each of the 2^bw combinations, every
/// body operation is constant-folded in program order. The value reaching the
/// terminator's first operand is then recorded.
/// NOTE(review): `1 << bw` (signed int) and `1U << bw` overflow for large bw,
/// and memory grows as 2^bw. Any bound on bw is enforced on lines not visible
/// here — confirm.
139 LogicalResult LutCalculator::computeTableEntries(LutOp lut) {
// Reuse the table if it was already computed for this very LUT.
144 if (this->lut == lut && !table.empty())
// vals[v][i] is the constant value of SSA value `v` for input combination `i`.
153 DenseMap<Value, SmallVector<Attribute>> vals;
154 const uint32_t bw = getInputBitWidth();
156 for (
auto arg : lut.getBodyBlock()->getArguments())
157 vals[arg] = SmallVector<Attribute>(1U << bw);
159 for (
auto &operation : lut.getBodyBlock()->without_terminator()) {
160 for (
auto operand : operation.getResults()) {
161 if (vals.count(operand))
163 vals[operand] = SmallVector<Attribute>(1U << bw);
// Split each combination `i` into per-argument slices. `offset` is decremented
// before each extraction, so earlier block arguments take more-significant bits
// (assuming offset starts at bw; its initialization is not visible here).
168 for (
int i = 0; i < (1 << bw); ++i) {
169 const APInt input(bw, i);
171 for (
auto arg : lut.getBodyBlock()->getArguments()) {
172 const unsigned argBitWidth = arg.getType().getIntOrFloatBitWidth();
173 offset -= argBitWidth;
175 input.extractBits(argBitWidth, offset));
// Fold every body operation for every input combination, in program order.
179 for (
auto &operation : lut.getBodyBlock()->without_terminator()) {
// constants[i] holds this operation's operand constants for combination i.
182 SmallVector<SmallVector<Attribute>, 8> constants(1U << bw);
183 for (
size_t j = 0, e = operation.getNumOperands(); j < e; ++j) {
184 SmallVector<Attribute> &tmp = vals[operation.getOperand(j)];
185 for (
int i = (1U << bw) - 1; i >= 0; i--)
186 constants[i].push_back(tmp[i]);
// Fold the operation once per combination; a fold failure is logged
// (the subsequent error handling is not visible in this excerpt).
190 SmallVector<SmallVector<OpFoldResult>, 8> results(
191 1U << bw, SmallVector<OpFoldResult, 8>());
192 for (
int i = (1U << bw) - 1; i >= 0; i--) {
193 if (failed(operation.fold(constants[i], results[i]))) {
194 LLVM_DEBUG(llvm::dbgs() <<
"Failed to fold operation '";
195 operation.print(llvm::dbgs()); llvm::dbgs() <<
"'\n");
// Record each result. A fold may return an existing Value instead of an
// Attribute; then that value's constant for the same combination is used.
// NOTE(review): PointerUnion::get<> is deprecated in newer LLVM;
// cast<Value>(...) is the replacement.
201 for (
size_t i = 0, e = operation.getNumResults(); i < e; ++i) {
202 SmallVector<Attribute> &ref = vals[operation.getResult(i)];
203 for (
int j = (1U << bw) - 1; j >= 0; j--) {
205 if (!(foldAttr = dyn_cast<Attribute>(results[j][i])))
206 foldAttr = vals[results[j][i].get<Value>()][j];
// Emit the output for each combination, from the highest index down to 0, so
// table[0] corresponds to the all-ones input.
214 auto outValue = lut.getBodyBlock()->getTerminator()->getOperand(0);
215 for (
int j = (1U << bw) - 1; j >= 0; j--)
216 table.push_back(cast<IntegerAttr>(vals[outValue][j]));
221 ArrayRef<IntegerAttr> LutCalculator::getRefToTableEntries() {
  // Non-owning view over `table`; it is invalidated whenever `table` is
  // modified (e.g. when computeTableEntries pushes new entries).
  return ArrayRef<IntegerAttr>(table);
}
/// Appends (does not replace) the computed table entries to `tableEntries`.
223 void LutCalculator::getCopyOfTableEntries(
224 SmallVector<IntegerAttr> &tableEntries) {
225 tableEntries.append(table);
/// Appends one Value per table entry to `tableEntries`, in table order.
/// Equal entries are mapped to a single Value, so each distinct constant is
/// materialized once. The creation lines (presumably using `builder`) are not
/// visible in this excerpt.
228 void LutCalculator::getTableEntriesAsConstValues(
229 OpBuilder &builder, SmallVector<Value> &tableEntries) {
// Deduplication map from entry attribute to its materialized value.
233 DenseMap<IntegerAttr, Value> map;
234 for (
auto entry : table) {
235 if (!map.count(entry))
238 tableEntries.push_back(map[entry]);
/// Returns the sum of the bit widths of all inputs of the current `lut`.
/// cast<IntegerType> asserts if an input is not an integer type.
242 uint32_t LutCalculator::getInputBitWidth() {
244 for (
auto val : lut.getInputs())
245 bw += cast<IntegerType>(val.getType()).getWidth();
/// Returns the table size in bits: 2^(input bit width) entries, each as wide
/// as the LUT output.
/// NOTE(review): `1 << width` is a signed int shift, which is UB for widths >= 31.
249 uint32_t LutCalculator::getTableSize() {
250 return (1 << getInputBitWidth()) *
251 lut.getOutput().getType().getIntOrFloatBitWidth();
/// Lowers a LUT to one wide integer constant holding the whole table. The
/// output is obtained by shifting that constant right by a lookup offset and
/// extracting the low output-width bits.
264 struct LutToInteger : LutLoweringPattern {
265 using LutLoweringPattern::LutLoweringPattern;
// NOTE(review): lines between the size queries and the table fetch are not
// visible; they may restrict which LUTs this pattern accepts.
268 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter)
const final {
271 const uint32_t tableSize = lutCalculator.getTableSize();
272 const uint32_t inputBw = lutCalculator.getInputBitWidth();
// Pack the entries from the most-significant end downwards. table[0] (the
// highest input combination) lands in the top bits, and input 0 lands at bit 0.
278 auto constants = lutCalculator.getRefToTableEntries();
279 APInt result(tableSize, 0);
280 unsigned nextInsertion = tableSize;
282 for (
auto attr : constants) {
283 auto chunk = attr.getValue();
284 nextInsertion -= chunk.getBitWidth();
285 result.insertBits(chunk, nextInsertion);
288 Value table = rewriter.create<
hw::ConstantOp>(lut.getLoc(), result);
// Zero constant of width tableSize - inputBw (presumably used to widen the
// inputs to tableSize bits).
294 lut->getLoc(), rewriter.getIntegerType(tableSize - inputBw), 0);
// Constant equal to the output bit width, i.e. the stride between entries.
// The lookup offset is presumably input index * stride; that line is not shown.
298 lut.getLoc(), entryOffset.getType(),
299 lut.getResult().getType().getIntOrFloatBitWidth());
305 rewriter.create<
comb::ShrUOp>(lut->getLoc(), table, lookupValue);
// Keep only the low output-width bits of the shifted table.
307 lut.getLoc(), shiftedTable, 0,
308 lut.getOutput().getType().getIntOrFloatBitWidth());
310 rewriter.replaceOp(lut, extracted);
/// Lowers a LUT to an hw aggregate (array) constant of the table entries,
/// indexed by a value built from the LUT inputs.
318 struct LutToArray : LutLoweringPattern {
319 using LutLoweringPattern::LutLoweringPattern;
322 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter)
const final {
// Entries are in computeTableEntries order (highest input combination first).
324 auto constants = lutCalculator.getRefToTableEntries();
325 SmallVector<Attribute> constantAttrs(constants.begin(), constants.end());
326 auto tableSize = lutCalculator.getTableSize();
327 auto inputBw = lutCalculator.getInputBitWidth();
// Tables of at most 256 bits are handled differently; the guarded statement is
// not visible here (presumably returns failure) — confirm.
329 if (tableSize <= 256)
332 Value table = rewriter.create<hw::AggregateConstantOp>(
334 rewriter.getArrayAttr(constantAttrs));
// Lookup index of width inputBw built from the LUT inputs (presumably their
// concatenation; the op being created is not visible here).
336 lut.getLoc(), rewriter.getIntegerType(inputBw), lut.getInputs());
337 const Value extracted =
338 rewriter.create<
hw::ArrayGetOp>(lut.getLoc(), table, lookupValue);
340 rewriter.replaceOp(lut, extracted);
/// The arc-lower-lut pass, built on the tablegen-generated LowerLUTBase
/// (enabled by GEN_PASS_DEF_LOWERLUT).
354 struct LowerLUTPass :
public arc::impl::LowerLUTBase<LowerLUTPass> {
355 void runOnOperation()
override;
/// Rewrites every arc::LutOp into Comb/HW logic using a partial dialect
/// conversion.
360 void LowerLUTPass::runOnOperation() {
361 MLIRContext &context = getContext();
362 ConversionTarget target(context);
363 RewritePatternSet
patterns(&context);
// Comb, HW and Arc are legal, but LutOp itself is illegal, so the conversion
// only succeeds if every LUT is rewritten by one of the patterns.
364 target.addLegalDialect<comb::CombDialect, hw::HWDialect, arc::ArcDialect>();
365 target.addIllegalOp<arc::LutOp>();
// One calculator shared by both patterns, so a table computed during matching
// can be reused for the same LUT.
369 LutCalculator lutCalculator;
370 patterns.add<LutToInteger, LutToArray>(lutCalculator, &context);
// The failure handling around this call is not visible in this excerpt.
373 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Body of createLowerLUTPass() (header not visible in this excerpt): returns a
// new, uniquely owned LowerLUTPass instance.
378 return std::make_unique<LowerLUTPass>();
def create(data_type, value)
std::unique_ptr< mlir::Pass > createLowerLUTPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.