13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/Pass/Pass.h"
15 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "arc-lookup-tables"
21 #define GEN_PASS_DEF_MAKETABLES
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
// Threshold compared against an arc's op count in runOnArc: arcs with fewer
// ops than this are skipped ("not enough ops").
32 static constexpr
int tableMinOpCount = 20;
// Size budget for a generated table. runOnArc compares the number of table
// entries against tableMaxSize / numOutputBits, i.e. the budget applies to
// entries * output bits (the value logged as the table size in bits).
33 static constexpr
int tableMaxSize = 32768;
// Pass implementation built on the generated `MakeTablesBase` CRTP base.
// NOTE(review): the remainder of the struct (including its closing brace) is
// not visible in this listing.
35 struct MakeTablesPass :
public arc::impl::MakeTablesBase<MakeTablesPass> {
// Pass entry point; its definition iterates the `DefineOp`s of the operation
// returned by getOperation().
36 void runOnOperation()
override;
// Per-definition worker that analyses a single `DefineOp` and, when the
// visible size/op-count checks pass, builds lookup tables for its outputs.
37 void runOnArc(DefineOp defineOp);
/// Build a mask with the low `nbits` bits set.
///
/// \param nbits Number of low-order bits to set. Values of 32 or more yield
///              an all-ones mask.
/// \return `2^nbits - 1` for `nbits < 32`, otherwise `0xFFFFFFFF`.
static inline uint32_t bitsMask(uint32_t nbits) {
  // Shifting a 32-bit value by its full width is undefined behaviour, so the
  // full-width case is answered directly instead of being computed.
  if (nbits >= 32)
    return ~uint32_t{0};
  // Shift an unsigned one so that nbits == 31 does not touch the sign bit of
  // a signed `int`.
  return (uint32_t{1} << nbits) - 1;
}
47 static inline uint32_t
bitsGet(uint32_t x, uint32_t lb, uint32_t ub) {
48 return (x >> lb) &
bitsMask(ub - lb + 1);
// Pass entry point: fetches the operation this pass runs on and iterates over
// every `DefineOp` it contains.
51 void MakeTablesPass::runOnOperation() {
52 auto module = getOperation();
// NOTE(review): the loop body is not visible in this listing; presumably each
// `DefineOp` is handed to runOnArc — confirm against the full source.
53 for (
auto op : module.getOps<DefineOp>())
// Tries to replace the computation inside `defineOp` with lookup tables: one
// constant array per output operand, indexed by the concatenation of the
// arc's arguments.
//
// NOTE(review): this listing omits a number of source lines (the embedded
// line numbers skip), so the bodies of several early exits, the declaration
// of `numOps` and `bits`, and the tail of some statements are not visible.
// Comments below describe only what the visible statements establish.
57 void MakeTablesPass::runOnArc(DefineOp defineOp) {
// Accumulate the total bit width of all arguments. Each argument type is
// probed with dyn_cast<IntegerType>; the handling of a failed cast is on
// elided lines.
59 unsigned numInputBits = 0;
60 for (
auto &type : defineOp.getArgumentTypes()) {
61 auto intType = dyn_cast<IntegerType>(type);
64 numInputBits += intType.getWidth();
// Zero input bits is special-cased (action elided).
66 if (numInputBits == 0)
// Walk the body without its terminator and single out ops that are not
// constant-like. Presumably these are what `numOps` counts — the statement
// itself is elided; confirm.
71 for (
auto &op : defineOp.getBodyBlock().without_terminator())
72 if (!op.hasTrait<OpTrait::ConstantLike>())
// Accumulate the total bit width of the values returned through the
// terminating arc::OutputOp, again probing each type as an IntegerType.
76 unsigned numOutputBits = 0;
77 auto outputOp = cast<arc::OutputOp>(defineOp.getBodyBlock().getTerminator());
78 for (
auto type : outputOp.getOperandTypes()) {
79 auto intType = dyn_cast<IntegerType>(type);
82 numOutputBits += intType.getWidth();
// Zero output bits is special-cased (action elided). Note that
// numOutputBits is used as a divisor further down.
84 if (numOutputBits == 0)
87 LLVM_DEBUG(llvm::dbgs() <<
"Making lookup tables in `" << defineOp.getName()
89 LLVM_DEBUG(llvm::dbgs() <<
"- " << numInputBits <<
" input bits, "
90 << numOutputBits <<
" output bits, " << numOps
// Reject arcs with 31 or more input bits. The enumeration loop below keeps
// its counter in a signed `int` initialised from (1U << numInputBits) - 1,
// which presumably motivates this limit.
94 if (numInputBits >= 31) {
95 LLVM_DEBUG(llvm::dbgs() <<
"- Skip; too many input bits\n");
// Reject arcs whose op count is below the tableMinOpCount threshold.
98 if (numOps < tableMinOpCount) {
99 LLVM_DEBUG(llvm::dbgs() <<
"- Skip; not enough ops\n");
// One table entry per possible input value: 2^numInputBits.
103 unsigned numTableEntries = 1U << numInputBits;
// Dividing the budget by the output width bounds
// numTableEntries * numOutputBits (the table size in bits logged below)
// by tableMaxSize.
104 if (numTableEntries > tableMaxSize / numOutputBits) {
105 LLVM_DEBUG(llvm::dbgs() <<
"- Skip; table too large\n");
108 LLVM_DEBUG(llvm::dbgs() <<
"- Creating table of "
109 << numTableEntries * numOutputBits <<
" bits\n");
// Snapshot the current non-terminator ops before the builder below starts
// inserting new ops at the beginning of the same block.
112 SmallVector<Operation *, 64> tabularizedOps;
113 for (
auto &op : defineOp.getBodyBlock().without_terminator())
114 tabularizedOps.push_back(&op);
// Builder positioned at the start of the arc body, using the arc's location.
117 auto builder = ImplicitLocOpBuilder::atBlockBegin(defineOp.getLoc(),
118 &defineOp.getBodyBlock());
// The arguments are reversed before being combined into the table index.
// Presumably this places the first argument in the least-significant bits,
// matching the ascending `bits` offset used with bitsGet below — confirm.
// The remainder of the `concatInputs` expression is elided.
119 SmallVector<Value> inputsToConcat(defineOp.getArguments());
120 std::reverse(inputsToConcat.begin(), inputsToConcat.end());
121 auto concatInputs = inputsToConcat.size() > 1
// `tables` holds one attribute list per output operand; `values` maps each
// SSA value to the constant attribute it takes for the input pattern
// currently being evaluated.
126 SmallVector<SmallVector<Attribute, 0>> tables;
127 DenseMap<Value, Attribute> values;
128 tables.resize(outputOp->getNumOperands());
// Enumerate every input bit pattern, counting down from the largest value
// to zero, so table entries are appended in that order.
130 for (
int input = (1U << numInputBits) - 1; input >= 0; input--) {
// Slice `input` into one integer constant per argument: bits
// [bits, bits + w - 1] become that argument's value. The declaration and
// update of `bits` are on elided lines.
// NOTE(review): the dyn_cast result is dereferenced without a null check;
// if the argument types are already known to be integers here, cast<> would
// state that intent — confirm against the elided type checks above.
134 for (
auto arg : defineOp.getArguments()) {
135 auto w = dyn_cast<IntegerType>(arg.getType()).getWidth();
136 values[arg] = builder.getIntegerAttr(arg.getType(),
137 bitsGet(input, bits, bits + w - 1));
// Evaluate each recorded op by constant folding: gather the attributes
// already known for its operands and pass them to fold().
142 SmallVector<Attribute> constants;
143 for (
auto *operation : tabularizedOps) {
145 for (
auto operand : operation->getOperands())
146 constants.push_back(values[operand]);
148 SmallVector<OpFoldResult, 8> resultValues;
// A fold failure is reported as a skip (the exit itself is elided).
149 if (failed(operation->fold(constants, resultValues))) {
150 LLVM_DEBUG(llvm::dbgs() <<
"- Skip; operation folder failed\n");
// Record the folded value of every result. A fold result may be an
// Attribute or a Value; for the Value case the attribute already stored for
// that value is looked up (the guarding condition is on an elided line).
154 for (
auto [result, resultValue] :
155 llvm::zip(operation->getResults(), resultValues)) {
156 auto attr = dyn_cast<Attribute>(resultValue);
158 attr = values[dyn_cast<Value>(resultValue)];
159 values[result] = attr;
// Append the attribute computed for each output operand to its table.
164 for (
auto [table, outputOperand] :
165 llvm::zip(tables, outputOp->getOpOperands())) {
166 table.push_back(dyn_cast<Attribute>(values[outputOperand.get()]));
// Materialise each table as an hw::AggregateConstantOp and rewire the
// corresponding output operand to an hw::ArrayGetOp that indexes the table
// with the concatenated inputs. The first argument of the create call is
// elided.
171 for (
auto [table, outputOperand] :
172 llvm::zip(tables, outputOp->getOpOperands())) {
173 auto array = builder.create<hw::AggregateConstantOp>(
175 builder.getArrayAttr(table));
176 outputOperand.set(builder.create<
hw::ArrayGetOp>(array, concatInputs));
// Final pass over the ops recorded earlier; the loop body is elided
// (presumably the now-unused ops are removed — confirm).
179 for (
auto *op : tabularizedOps) {
186 return std::make_unique<MakeTablesPass>();
static uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub)
static uint32_t bitsMask(uint32_t nbits)
std::unique_ptr< mlir::Pass > createMakeTablesPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.