13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/DialectConversion.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "arc-lower-lut"
21#define GEN_PASS_DEF_LOWERLUT
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
44 LogicalResult computeTableEntries(LutOp lut);
48 ArrayRef<IntegerAttr> getRefToTableEntries();
51 void getCopyOfTableEntries(SmallVector<IntegerAttr> &tableEntries);
55 void getTableEntriesAsConstValues(OpBuilder &builder,
56 SmallVector<Value> &tableEntries);
58 uint32_t getTableSize();
60 uint32_t getInputBitWidth();
64 SmallVector<IntegerAttr> table;
76LogicalResult LutCalculator::computeTableEntries(LutOp lut) {
81 if (this->lut == lut && !table.empty())
90 DenseMap<Value, SmallVector<Attribute>> vals;
91 const uint32_t bw = getInputBitWidth();
94 vals[arg] = SmallVector<Attribute>(1U << bw);
96 for (
auto &operation : lut.
getBodyBlock()->without_terminator()) {
97 for (
auto operand : operation.getResults()) {
98 if (vals.count(operand))
100 vals[operand] = SmallVector<Attribute>(1U << bw);
105 for (
int i = 0; i < (1 << bw); ++i) {
106 const APInt input(bw, i);
109 const unsigned argBitWidth = arg.getType().getIntOrFloatBitWidth();
110 offset -= argBitWidth;
111 vals[arg][i] = IntegerAttr::get(arg.getType(),
112 input.extractBits(argBitWidth, offset));
116 for (
auto &operation : lut.
getBodyBlock()->without_terminator()) {
119 SmallVector<SmallVector<Attribute>, 8> constants(1U << bw);
120 for (
size_t j = 0, e = operation.getNumOperands(); j < e; ++j) {
121 SmallVector<Attribute> &tmp = vals[operation.getOperand(j)];
122 for (
int i = (1U << bw) - 1; i >= 0; i--)
123 constants[i].push_back(tmp[i]);
127 SmallVector<SmallVector<OpFoldResult>, 8> results(
128 1U << bw, SmallVector<OpFoldResult, 8>());
129 for (
int i = (1U << bw) - 1; i >= 0; i--) {
130 if (failed(operation.fold(constants[i], results[i]))) {
131 LLVM_DEBUG(llvm::dbgs() <<
"Failed to fold operation '";
132 operation.print(llvm::dbgs()); llvm::dbgs() <<
"'\n");
138 for (
size_t i = 0, e = operation.getNumResults(); i < e; ++i) {
139 SmallVector<Attribute> &ref = vals[operation.getResult(i)];
140 for (
int j = (1U << bw) - 1; j >= 0; j--) {
142 if (!(foldAttr = dyn_cast<Attribute>(results[j][i])))
143 foldAttr = vals[llvm::cast<Value>(results[j][i])][j];
151 auto outValue = lut.getBodyBlock()->getTerminator()->getOperand(0);
152 for (
int j = (1U << bw) - 1; j >= 0; j--)
153 table.push_back(cast<IntegerAttr>(vals[outValue][j]));
158ArrayRef<IntegerAttr> LutCalculator::getRefToTableEntries() {
return table; }
160void LutCalculator::getCopyOfTableEntries(
161 SmallVector<IntegerAttr> &tableEntries) {
162 tableEntries.append(table);
165void LutCalculator::getTableEntriesAsConstValues(
166 OpBuilder &builder, SmallVector<Value> &tableEntries) {
170 DenseMap<IntegerAttr, Value> map;
171 for (
auto entry : table) {
172 if (!map.count(entry))
175 tableEntries.push_back(map[entry]);
179uint32_t LutCalculator::getInputBitWidth() {
181 for (
auto val : lut.getInputs())
182 bw += cast<IntegerType>(val.getType()).
getWidth();
186uint32_t LutCalculator::getTableSize() {
187 return (1 << getInputBitWidth()) *
188 lut.getOutput().getType().getIntOrFloatBitWidth();
202 LutToInteger(LutCalculator &calculator, MLIRContext *context)
206 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
207 ConversionPatternRewriter &rewriter)
const final {
208 if (failed(lutCalculator.computeTableEntries(lut)))
211 const uint32_t tableSize = lutCalculator.getTableSize();
212 const uint32_t inputBw = lutCalculator.getInputBitWidth();
218 auto constants = lutCalculator.getRefToTableEntries();
219 APInt result(tableSize, 0);
220 unsigned nextInsertion = tableSize;
222 for (
auto attr : constants) {
223 auto chunk = attr.getValue();
224 nextInsertion -= chunk.getBitWidth();
225 result.insertBits(chunk, nextInsertion);
235 rewriter.getIntegerType(tableSize - inputBw), 0);
236 Value entryOffset = comb::ConcatOp::create(rewriter, lut.getLoc(),
237 zextValue, lut.getInputs());
239 rewriter, lut.getLoc(), entryOffset.getType(),
240 lut.getResult().getType().getIntOrFloatBitWidth());
241 Value lookupValue = comb::MulOp::create(rewriter, lut.getLoc(), entryOffset,
246 comb::ShrUOp::create(rewriter, lut->getLoc(), table, lookupValue);
248 rewriter, lut.getLoc(), shiftedTable, 0,
249 lut.getOutput().getType().getIntOrFloatBitWidth());
251 rewriter.replaceOp(lut, extracted);
255 LutCalculator &lutCalculator;
262 LutToArray(LutCalculator &calculator, MLIRContext *context)
266 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter)
const final {
268 if (failed(lutCalculator.computeTableEntries(lut)))
271 auto constants = lutCalculator.getRefToTableEntries();
272 SmallVector<Attribute> constantAttrs(constants.begin(), constants.end());
273 auto tableSize = lutCalculator.getTableSize();
274 auto inputBw = lutCalculator.getInputBitWidth();
276 if (tableSize <= 256)
279 Value table = hw::AggregateConstantOp::create(
280 rewriter, lut.getLoc(),
281 hw::ArrayType::get(lut.getType(), constantAttrs.size()),
282 rewriter.getArrayAttr(constantAttrs));
283 Value lookupValue = comb::ConcatOp::create(rewriter, lut.getLoc(),
284 rewriter.getIntegerType(inputBw),
286 const Value extracted =
289 rewriter.replaceOp(lut, extracted);
293 LutCalculator &lutCalculator;
305struct LowerLUTPass :
public arc::impl::LowerLUTBase<LowerLUTPass> {
306 void runOnOperation()
override;
311void LowerLUTPass::runOnOperation() {
312 MLIRContext &context = getContext();
313 ConversionTarget target(context);
314 RewritePatternSet
patterns(&context);
315 target.addLegalDialect<comb::CombDialect, hw::HWDialect, arc::ArcDialect>();
316 target.addIllegalOp<arc::LutOp>();
320 LutCalculator lutCalculator;
321 patterns.add<LutToInteger, LutToArray>(lutCalculator, &context);
324 applyPartialConversion(getOperation(), target, std::move(
patterns))))
329 return std::make_unique<LowerLUTPass>();
static Block * getBodyBlock(FModuleLike mod)
std::unique_ptr< mlir::Pass > createLowerLUTPass()
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.