13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/DialectConversion.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "arc-lower-lut"
21#define GEN_PASS_DEF_LOWERLUT
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
// Member declarations of the LUT table calculator. The enclosing class header
// is not visible in this excerpt.

// Computes the table entries for the given LutOp and stores them in `table`.
// The LogicalResult reports whether the computation succeeded.
44 LogicalResult computeTableEntries(LutOp lut);
// Returns a non-owning view of the stored entries (the definition returns
// `table` directly).
48 ArrayRef<IntegerAttr> getRefToTableEntries();
// Appends the stored entries to the caller-provided vector.
51 void getCopyOfTableEntries(SmallVector<IntegerAttr> &tableEntries);
// Pushes one Value per stored entry into `tableEntries`; the definition keeps
// a per-call map from entry attribute to Value.
55 void getTableEntriesAsConstValues(OpBuilder &builder,
56 SmallVector<Value> &tableEntries);
// Defined as (1 << input bit width) * output bit width.
58 uint32_t getTableSize();
// The definition accumulates the integer widths of all LUT inputs.
60 uint32_t getInputBitWidth();
// Entries produced by computeTableEntries, one IntegerAttr per input
// combination.
64 SmallVector<IntegerAttr> table;
// Computes the lookup-table entries of `lut`: every non-terminator operation
// in the LUT body is folded once per input combination, and the value of the
// terminator operand for each combination is appended to `table`.
// NOTE(review): several statements are missing from this excerpt (the embedded
// line numbers skip); the comments below describe only what is visible.
76LogicalResult LutCalculator::computeTableEntries(LutOp lut) {
// Checks whether the same LutOp was already processed and `table` is filled.
// The guarded statement is not visible here; presumably an early return --
// confirm.
81 if (this->lut == lut && !table.empty())
// For each SSA value: one Attribute slot per input combination.
90 DenseMap<Value, SmallVector<Attribute>> vals;
91 const uint32_t bw = getInputBitWidth();
// Allocate 2^bw slots for `arg` (the loop that declares `arg` is not visible).
94 vals[arg] = SmallVector<Attribute>(1U << bw);
// Allocate 2^bw slots for every result of every non-terminator body operation.
96 for (
auto &operation : lut.
getBodyBlock()->without_terminator()) {
97 for (
auto operand : operation.getResults()) {
// Results already present in `vals` are detected here; the guarded statement
// is not visible in this excerpt.
98 if (vals.count(operand))
100 vals[operand] = SmallVector<Attribute>(1U << bw);
// Enumerate every input combination i in [0, 2^bw).
// NOTE(review): this bound shifts a signed `1`, while the allocations above
// use `1U << bw`; confirm that large bit widths are rejected before this point.
105 for (
int i = 0; i < (1 << bw); ++i) {
// The combined input value for this combination, bw bits wide.
106 const APInt input(bw, i);
// Slice the bits that belong to `arg` out of the combined input and record
// them as the constant value of `arg` for combination i. `offset` is declared
// outside the visible lines and is decremented by the argument width first.
109 const unsigned argBitWidth = arg.getType().getIntOrFloatBitWidth();
110 offset -= argBitWidth;
111 vals[arg][i] = IntegerAttr::get(arg.getType(),
112 input.extractBits(argBitWidth, offset));
// Fold each body operation for every input combination.
116 for (
auto &operation : lut.
getBodyBlock()->without_terminator()) {
// constants[i] collects the operand attributes of this operation for
// combination i, in operand order.
119 SmallVector<SmallVector<Attribute>, 8> constants(1U << bw);
120 for (
size_t j = 0, e = operation.getNumOperands(); j < e; ++j) {
121 SmallVector<Attribute> &tmp = vals[operation.getOperand(j)];
122 for (
int i = (1U << bw) - 1; i >= 0; i--)
123 constants[i].push_back(tmp[i]);
// results[i] receives the fold results of this operation for combination i.
127 SmallVector<SmallVector<OpFoldResult>, 8> results(
128 1U << bw, SmallVector<OpFoldResult, 8>());
129 for (
int i = (1U << bw) - 1; i >= 0; i--) {
// When folding fails a debug message containing the operation is printed; the
// statement that follows the message is not visible in this excerpt.
130 if (failed(operation.fold(constants[i], results[i]))) {
131 LLVM_DEBUG(llvm::dbgs() <<
"Failed to fold operation '";
132 operation.print(llvm::dbgs()); llvm::dbgs() <<
"'\n");
// Transfer the fold results into the per-value storage of each op result.
138 for (
size_t i = 0, e = operation.getNumResults(); i < e; ++i) {
139 SmallVector<Attribute> &ref = vals[operation.getResult(i)];
140 for (
int j = (1U << bw) - 1; j >= 0; j--) {
// A fold result is either an Attribute or a Value. For a Value, the attribute
// already recorded for that value at combination j is used instead. The
// declaration of `foldAttr` and the store into `ref` are not visible here.
142 if (!(foldAttr = dyn_cast<Attribute>(results[j][i])))
143 foldAttr = vals[llvm::cast<Value>(results[j][i])][j];
// The table is read from the first operand of the body terminator. j runs
// from 2^bw - 1 down to 0, so the entry for the highest input combination is
// pushed first.
151 auto outValue = lut.getBodyBlock()->getTerminator()->getOperand(0);
152 for (
int j = (1U << bw) - 1; j >= 0; j--)
153 table.push_back(cast<IntegerAttr>(vals[outValue][j]));
// Returns a non-owning ArrayRef over the member `table`; no entries are
// copied. The view refers to storage owned by this calculator.
158ArrayRef<IntegerAttr> LutCalculator::getRefToTableEntries() {
return table; }
// Appends every stored table entry to `tableEntries`. Because `append` is
// used, elements already in the vector are kept and the entries are added
// after them.
160void LutCalculator::getCopyOfTableEntries(
161 SmallVector<IntegerAttr> &tableEntries) {
162 tableEntries.append(table);
// Pushes one Value per table entry into `tableEntries`, in table order.
165void LutCalculator::getTableEntriesAsConstValues(
166 OpBuilder &builder, SmallVector<Value> &tableEntries) {
// Maps each distinct entry attribute to a Value so that equal entries share
// the same Value.
170 DenseMap<IntegerAttr, Value> map;
171 for (
auto entry : table) {
// NOTE(review): the statement executed on a cache miss is not visible in this
// excerpt; presumably it creates a constant through `builder` and stores it
// in `map` -- confirm.
172 if (!map.count(entry))
175 tableEntries.push_back(map[entry]);
// Accumulates into `bw` the width of every input of the stored LutOp. Each
// input type is cast to IntegerType before its width is read. The
// declaration and the return of `bw` are not visible in this excerpt.
179uint32_t LutCalculator::getInputBitWidth() {
181 for (
auto val : lut.getInputs())
182 bw += cast<IntegerType>(val.getType()).
getWidth();
// Returns the number of table entries (1 << input bit width) multiplied by
// the bit width of the LUT output type.
// NOTE(review): the shift uses a signed `1`, whereas computeTableEntries sizes
// its vectors with `1U << bw`; the product is also returned as uint32_t.
// Confirm that wide inputs are rejected before this is called.
186uint32_t LutCalculator::getTableSize() {
187 return (1 << getInputBitWidth()) *
188 lut.getOutput().getType().getIntOrFloatBitWidth();
// Interior of the LutToInteger conversion pattern; the struct header and
// several statements are not visible in this excerpt.

// Constructor taking the calculator by reference and the MLIR context; its
// initializer list is not visible.
202 LutToInteger(LutCalculator &calculator, MLIRContext *context)
// Rewrites a LutOp into a logical right shift of one wide integer constant
// followed by an extract of the output width.
206 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
207 ConversionPatternRewriter &rewriter)
const final {
// The table must be computable; the statement guarded by this check is not
// visible here.
208 if (failed(lutCalculator.computeTableEntries(lut)))
211 const uint32_t tableSize = lutCalculator.getTableSize();
212 const uint32_t inputBw = lutCalculator.getInputBitWidth();
// Pack all table entries into a single APInt of `tableSize` bits. The first
// entry is inserted at the most significant position and each following entry
// directly below the previous one.
218 auto constants = lutCalculator.getRefToTableEntries();
219 APInt result(tableSize, 0);
220 unsigned nextInsertion = tableSize;
222 for (
auto attr : constants) {
223 auto chunk = attr.getValue();
224 nextInsertion -= chunk.getBitWidth();
225 result.insertBits(chunk, nextInsertion);
// Materialize the packed table as an hw constant.
228 Value table = rewriter.create<
hw::ConstantOp>(lut.getLoc(), result);
// Arguments of an op creation whose opening line is missing from this
// excerpt: the value 0 with an integer type of width tableSize - inputBw.
234 lut->getLoc(), rewriter.getIntegerType(tableSize - inputBw), 0);
// Arguments of another creation whose opening line is missing: the type of
// `entryOffset` and the bit width of the LUT result.
238 lut.getLoc(), entryOffset.getType(),
239 lut.getResult().getType().getIntOrFloatBitWidth());
// Shift the packed table right (unsigned) by `lookupValue`; the variable that
// receives this result is declared on a line not visible here.
245 rewriter.create<
comb::ShrUOp>(lut->getLoc(), table, lookupValue);
// Arguments of the extraction from `shiftedTable`: starting at bit 0, with the
// width of the LUT output type.
247 lut.getLoc(), shiftedTable, 0,
248 lut.getOutput().getType().getIntOrFloatBitWidth());
250 rewriter.replaceOp(lut, extracted);
// Calculator held by reference; it is not owned by the pattern.
254 LutCalculator &lutCalculator;
// Interior of the LutToArray conversion pattern; the struct header and
// several statements are not visible in this excerpt.

// Constructor taking the calculator by reference and the MLIR context; its
// initializer list is not visible.
261 LutToArray(LutCalculator &calculator, MLIRContext *context)
// Rewrites a LutOp into an hw aggregate constant array indexed by a lookup
// value.
265 matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
266 ConversionPatternRewriter &rewriter)
const final {
// The table must be computable; the statement guarded by this check is not
// visible here.
267 if (failed(lutCalculator.computeTableEntries(lut)))
// Copy the entries into a vector of generic Attributes for the array
// attribute built below.
270 auto constants = lutCalculator.getRefToTableEntries();
271 SmallVector<Attribute> constantAttrs(constants.begin(), constants.end());
272 auto tableSize = lutCalculator.getTableSize();
273 auto inputBw = lutCalculator.getInputBitWidth();
// NOTE(review): the statement guarded by this threshold is not visible;
// presumably tables of at most 256 bits are left to LutToInteger -- confirm.
275 if (tableSize <= 256)
// Array constant with one element of the LUT result type per table entry.
278 Value table = rewriter.create<hw::AggregateConstantOp>(
279 lut.getLoc(), hw::ArrayType::get(lut.getType(), constantAttrs.size()),
280 rewriter.getArrayAttr(constantAttrs));
// Arguments of an op creation whose opening line is missing: an integer type
// of width inputBw built from the LUT inputs (presumably the lookup index --
// confirm).
282 lut.getLoc(), rewriter.getIntegerType(inputBw), lut.getInputs());
// Index the array with `lookupValue` and replace the LutOp with the element.
283 const Value extracted =
284 rewriter.create<
hw::ArrayGetOp>(lut.getLoc(), table, lookupValue);
286 rewriter.replaceOp(lut, extracted);
// Calculator held by reference; it is not owned by the pattern.
290 LutCalculator &lutCalculator;
// Pass type derived from the generated arc::impl::LowerLUTBase; it overrides
// runOnOperation, which is defined out of line. The closing of the struct is
// not visible in this excerpt.
302struct LowerLUTPass :
public arc::impl::LowerLUTBase<LowerLUTPass> {
303 void runOnOperation()
override;
// Configures and runs the LUT lowering as a partial dialect conversion.
308void LowerLUTPass::runOnOperation() {
309 MLIRContext &context = getContext();
310 ConversionTarget target(context);
311 RewritePatternSet
patterns(&context);
// The comb, hw and arc dialects are legal, except arc::LutOp which is marked
// illegal.
312 target.addLegalDialect<comb::CombDialect, hw::HWDialect, arc::ArcDialect>();
313 target.addIllegalOp<arc::LutOp>();
// One calculator instance is shared by both patterns; they store it by
// reference, and it lives in this function scope while the conversion runs.
317 LutCalculator lutCalculator;
318 patterns.add<LutToInteger, LutToArray>(lutCalculator, &context);
// Apply the conversion to the operation this pass runs on. The enclosing
// check around this call is on a line missing from this excerpt.
321 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Returns a newly allocated LowerLUTPass. The enclosing function header is not
// visible in this excerpt; presumably the pass factory createLowerLUTPass --
// confirm.
326 return std::make_unique<LowerLUTPass>();
static Block * getBodyBlock(FModuleLike mod)
std::unique_ptr< mlir::Pass > createLowerLUTPass()
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.