20#include "mlir/IR/Operation.h"
21#include "mlir/IR/Value.h"
22#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/KnownBits.h"
28#include "llvm/Support/LogicalResult.h"
31#define DEBUG_TYPE "synth-lower-word-to-bits"
35#define GEN_PASS_DEF_LOWERWORDTOBITS
36#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
/// Creates a bit-blaster rooted at `topOp`. `run()` walks the operations
/// nested under `topOp` and replaces multi-bit values by their per-bit logic.
66 explicit BitBlaster(Operation *topOp) : topOp(topOp) {}
// Statistics. `LowerWordToBitsPass::runOnOperation` adds these to the pass
// statistics of the same names after `run()` succeeds.
// Total number of result bits produced by `lowerOp` (incremented by width).
76 size_t numLoweredBits = 0;
// Number of result bits `lowerOp` replaced by i1 constants because their value
// was known (popcount of the known-zero | known-one mask).
79 size_t numLoweredConstants = 0;
// NOTE(review): the increment site of this counter is not visible in this
// excerpt; presumably it counts operations that were bit-blasted.
82 size_t numLoweredOps = 0;
/// Returns one i1 value per bit of `value` (index 0 is the least significant
/// bit), computed on first use and memoized in `loweredValues`.
91 ArrayRef<Value> lowerValueToBits(Value value);
/// Bit-blasts an op implementing `BooleanLogicOpInterface` by cloning it once
/// per result bit via `cloneWithSameInversion`.
92 ArrayRef<Value> lowerBooleanLogicOperation(BooleanLogicOpInterface op);
/// Bit-blasts comb.and / comb.or / comb.xor, keeping the `twoState` attribute.
93 template <
typename OpTy>
94 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
/// Shared per-bit driver: for each result bit, either substitutes a known
/// constant or calls `createOp` on the matching bit of every operand.
97 lowerOp(Operation *op,
98 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
/// Returns bit `index` of `value`, looking through cheap producers (concat,
/// extract, replicate, constant) before falling back to full lowering.
102 Value extractBit(Value value,
size_t index);
/// Returns an i1 constant for `value`, cached per block and created at the
/// start of `block`.
110 Value getBoolConstant(
bool value, Block *block);
/// Records `bits` as the bit-level lowering of `value` and returns a view of
/// the stored vector. Each value may be inserted only once (asserted).
/// NOTE(review): the returned ArrayRef points into storage owned by
/// `loweredValues`. If that container reallocates on a later insertion, the
/// view may dangle, so callers should not hold it across further lowering.
117 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
118 auto it = loweredValues.insert({value, std::move(bits)});
119 assert(it.second &&
"value already inserted");
120 return it.first->second;
/// Memoizes `bits` as the known bits of `value` and returns a reference to the
/// stored entry. Unlike `insertBits`, a duplicate insertion is not asserted
/// against. Assuming LLVM map `insert` semantics, an existing entry is kept
/// and returned unchanged.
/// NOTE(review): the returned reference points into `knownBits`. A later
/// insertion that grows the container may invalidate it.
124 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
125 auto it = knownBits.insert({value, std::move(bits)});
126 return it.first->second;
// Per-block cache of the i1 constants handed out by `getBoolConstant`.
// Slot 0 holds `false` and slot 1 holds `true`. Slots stay null (the
// default-constructed Value) until first requested.
136 llvm::DenseMap<Block *, std::array<Value, 2>> constantsByBlock;
/// Computes, and memoizes in `knownBits`, which bits of `value` are known to
/// be zero or one.
///  - A cache hit in `knownBits` short-circuits the computation (the return
///    statement itself is not visible in this excerpt).
///  - Ops implementing `BooleanLogicOpInterface` delegate to the interface.
///    It queries operand known bits through the callback, which presumably
///    returns the known bits of operand `i`.
///  - For `ChoiceOp`, the Zero/One masks of the inputs are OR-ed together, so
///    a bit known in any input is treated as known for the choice. This is
///    sound only if the inputs are equivalent, presumably the ChoiceOp
///    contract; confirm.
///  - Other defining ops fall back to `comb::computeKnownBits`.
/// NOTE(review): the result is a reference into `knownBits`, and the interface
/// callback hands out such references while this function may insert new
/// entries recursively. If an interface implementation holds one operand's
/// reference while requesting another, a container reallocation could leave
/// the first reference dangling.
148const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
150 auto *it = knownBits.find(value);
151 if (it != knownBits.end())
154 auto width = hw::getBitWidth(value.getType());
155 auto *op = value.getDefiningOp();
// NOTE(review): the guard for this early return is not visible here.
// Presumably it covers values with no defining op and records all bits as
// unknown.
159 return insertKnownBits(value, llvm::KnownBits(width));
161 llvm::KnownBits result(width);
162 if (
auto logicOp = dyn_cast<BooleanLogicOpInterface>(op))
164 logicOp.computeKnownBits([&](
unsigned i) ->
const llvm::KnownBits & {
167 else if (
auto choice = dyn_cast<ChoiceOp>(op)) {
// The first input is skipped by drop_front(). Presumably `result` was seeded
// from it on a line not visible in this excerpt.
169 for (
auto input : choice.getInputs().drop_front()) {
171 result.One |= known.One;
172 result.Zero |= known.Zero;
// Fallback for all other defining ops (surrounding branch structure is only
// partly visible).
178 result = comb::computeKnownBits(value);
181 return insertKnownBits(value, std::move(result));
/// Returns the single-bit value for bit `index` of `value` (bit 0 is the least
/// significant bit). When the producer makes it cheap, the bit is taken from
/// the producer's inputs instead of lowering `value` as a whole:
///  - concatenation (presumably comb.concat): operands are scanned from last
///    to first, because the last operand holds the least significant bits;
///  - extraction (presumably comb.extract): bit `lowBit + index` of the input;
///  - comb.replicate: bit `index % inputWidth` of the input;
///  - constant (presumably hw.constant): a cached i1 constant for that bit.
/// Otherwise the whole value is lowered via `lowerValueToBits`.
184Value BitBlaster::extractBit(Value value,
size_t index) {
// Values at most one bit wide are handled before any producer lookup.
185 if (hw::getBitWidth(value.getType()) <= 1)
188 auto *op = value.getDefiningOp();
// NOTE(review): the condition guarding this return is not visible here.
// Presumably it covers values without a defining op or already-lowered values.
192 return lowerValueToBits(value)[index];
194 return TypeSwitch<Operation *, Value>(op)
196 for (
auto operand :
llvm::reverse(op.getOperands())) {
197 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): this condition rejects negative widths (presumably how
// hw::getBitWidth reports unsupported types; confirm). It does not reject a
// zero width, which passes, so the message text is misleading.
198 assert(width >= 0 &&
"operand has zero width");
199 if (index <
static_cast<size_t>(width))
200 return extractBit(operand, index);
// Skip past this operand's bits toward more significant operands.
201 index -=
static_cast<size_t>(width);
203 llvm_unreachable(
"index out of bounds");
206 return extractBit(ext.getInput(),
207 static_cast<size_t>(ext.getLowBit()) + index);
209 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
// Replication repeats the input, so bit `index` maps to input bit
// `index % inputWidth`.
210 return extractBit(op.getInput(),
211 index %
static_cast<size_t>(hw::getBitWidth(
212 op.getOperand().getType())));
// Here `value` is the constant's APInt and shadows the function parameter.
215 auto value = op.getValue();
216 return getBoolConstant(value[index], op->getBlock());
218 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
/// Returns the per-bit lowering of `value` (index 0 is the least significant
/// bit). The lowering is computed on first use and memoized in
/// `loweredValues`.
///  - Narrow values (presumably width <= 1; the guard is not visible) map to
///    themselves.
///  - Values without a defining op (presumably block arguments) are split with
///    `comb::extractBits` right after the value.
///  - ChoiceOp, BooleanLogicOpInterface ops, comb.and/or/xor and comb.mux are
///    rebuilt bit by bit through `lowerOp` (directly or via the lowering
///    helpers).
///  - Any other producer is split with `comb::extractBits`.
/// TypeSwitch dispatches to the first matching case, so the ChoiceOp case
/// takes precedence over the BooleanLogicOpInterface case.
221ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
222 auto *it = loweredValues.find(value);
223 if (it != loweredValues.end())
226 auto width = hw::getBitWidth(value.getType());
228 return insertBits(value, {value});
230 auto *op = value.getDefiningOp();
232 SmallVector<Value> results;
233 OpBuilder builder(value.getContext());
234 builder.setInsertionPointAfterValue(value);
235 comb::extractBits(builder, value, results);
236 return insertBits(value, std::move(results));
239 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
240 .Case<ChoiceOp>([&](ChoiceOp op) {
// The per-bit choice takes its result type from the first (single-bit)
// operand.
241 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
242 return builder.createOrFold<ChoiceOp>(
243 op.getLoc(), operands[0].getType(), operands);
245 return lowerOp(op, createOp);
247 .Case<BooleanLogicOpInterface>(
248 [&](
auto op) {
return lowerBooleanLogicOperation(op); })
249 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
250 [&](
auto op) {
return lowerCombLogicOperations(op); })
251 .Case<comb::MuxOp>([&](
comb::MuxOp op) {
return lowerCombMux(op); })
252 .Default([&](
auto op) {
// NOTE(review): this inserts the extracts *before* `op`, i.e. ahead of the
// definition of `value` that they read. That is only legal in a graph region
// and breaks the topological order set up in run(). The block-argument path
// above uses setInsertionPointAfterValue instead; confirm this is intended.
253 OpBuilder builder(value.getContext());
254 builder.setInsertionPoint(op);
255 SmallVector<Value> results;
256 comb::extractBits(builder, value, results);
258 return insertBits(value, std::move(results));
/// Bit-blasts the selected operations under `topOp`:
///  1. Topologically sorts the blocks so operands are processed before their
///     users (presumably via `topologicallySortGraphRegionBlocks`; the
///     readiness predicate is only partly visible). On failure it emits
///     "there is a combinational cycle" and returns failure.
///  2. Walks the operations and lowers the first result of each selected op
///     via `lowerValueToBits`. The selection condition is not visible here;
///     presumably it is `shouldLowerOperation`.
///  3. Visits the memoized lowerings in reverse insertion order. Each
///     multi-bit value is either handled as dead (`use_empty` branch, body not
///     visible) or has all uses replaced by a comb.concat of its bits.
262LogicalResult BitBlaster::run() {
266 topOp, [](Value value, Operation *op) ->
bool {
270 comb::ReplicateOp>(op));
273 return mlir::emitError(topOp->getLoc(),
"there is a combinational cycle");
277 topOp->walk([&](Operation *op) {
// The returned view is not needed; lowering is done for its memoized effect.
280 (
void)lowerValueToBits(op->getResult(0));
// early_inc_range guards the loop against changes to the current entry while
// replacing uses.
284 for (
auto &[value, results] :
285 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
// Values at most one bit wide were mapped to themselves; nothing to replace.
286 if (hw::getBitWidth(value.getType()) <= 1)
289 auto *op = value.getDefiningOp();
293 if (value.use_empty()) {
301 OpBuilder builder(op);
// Stored bits are LSB-first, while comb.concat takes its most significant
// operand first, so reverse before building the concat. This reverses the
// vector stored in `loweredValues` in place.
302 std::reverse(results.begin(), results.end());
303 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
304 value.replaceAllUsesWith(concat);
/// Returns an i1 constant equal to `value` for use in `block`. One constant
/// per (block, value) pair is created lazily at the beginning of the block,
/// with an unknown location, and reused afterwards.
/// The first line of the create call is not visible here. Presumably it
/// stores the new constant into `blockConstants[value]`.
312Value BitBlaster::getBoolConstant(
bool value, Block *block) {
313 auto &blockConstants = constantsByBlock[block];
314 if (!blockConstants[value]) {
315 auto builder = OpBuilder::atBlockBegin(block);
317 builder, builder.getUnknownLoc(), builder.getI1Type(), value);
319 return blockConstants[value];
/// Bit-blasts an op implementing `BooleanLogicOpInterface`. Each result bit is
/// rebuilt with `cloneWithSameInversion` on the per-bit operands, presumably
/// keeping the original op's operand-inversion flags. The per-bit loop is
/// delegated to `lowerOp`.
323BitBlaster::lowerBooleanLogicOperation(BooleanLogicOpInterface op) {
324 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
325 return op.cloneWithSameInversion(builder, operands);
327 return lowerOp(op.getOperation(), createOp);
/// Bit-blasts comb.and / comb.or / comb.xor. Each result bit becomes a new op
/// of the same kind over the corresponding operand bits, carrying over the
/// original `twoState` attribute. `createOrFold` may fold the new op to an
/// existing value.
330template <
typename OpTy>
331ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
332 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
333 return builder.createOrFold<OpTy>(op.getLoc(), operands,
334 op.getTwoStateAttr());
336 return lowerOp(op, createOp);
/// Bit-blasts a comb.mux into one single-bit comb.mux per result bit, keeping
/// the `twoState` attribute. operands[0] is the condition. It is already one
/// bit wide, so extractBit presumably returns it unchanged for every bit
/// index. The true/false operands contribute bit `i` each.
339ArrayRef<Value> BitBlaster::lowerCombMux(
comb::MuxOp op) {
340 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
341 return builder.createOrFold<
comb::MuxOp>(op.getLoc(), operands[0],
342 operands[1], operands[2],
343 op.getTwoStateAttr());
345 return lowerOp(op, createOp);
/// Shared bit-blasting driver for the multi-bit first result of `op`
/// (asserted to be wider than one bit). For every result bit `i`, LSB first,
/// it does one of two things:
///  - substitutes a cached i1 constant when the known-bits analysis proves
///    the bit's value;
///  - otherwise, gathers bit `i` of every operand with `extractBit` and calls
///    `createOp` to build a single-bit replacement. If `op` has an
///    `sv.namehint`, the new op is tagged `name[i]`.
/// It also updates `numLoweredConstants` and `numLoweredBits`, and memoizes
/// the bits for `op`'s result via `insertBits`.
348ArrayRef<Value> BitBlaster::lowerOp(
350 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
351 auto value = op->getResult(0);
352 OpBuilder builder(op);
353 auto width = hw::getBitWidth(value.getType());
354 assert(width > 1 &&
"expected multi-bit operation");
// `known` is obtained on a line not visible here, presumably
// computeKnownBits(value).
// NOTE(review): if it is bound by reference, the recursive lowering triggered
// by extractBit below could insert into `knownBits` and invalidate it before
// `known.One[i]` is read.
357 APInt knownMask = known.Zero | known.One;
360 numLoweredConstants += knownMask.popcount();
361 numLoweredBits += width;
364 SmallVector<Value> results;
365 results.reserve(width);
367 for (int64_t i = 0; i < width; ++i) {
368 SmallVector<Value> operands;
369 operands.reserve(op->getNumOperands());
// Known bit: reuse the shared constant instead of building logic. The guard
// is not visible; presumably it tests knownMask[i].
372 results.push_back(getBoolConstant(known.One[i], op->getBlock()));
377 for (
auto operand : op->getOperands())
378 operands.push_back(extractBit(operand, i));
381 auto result = createOp(builder, operands);
382 results.push_back(result);
// NOTE(review): createOp may fold (createOrFold) to a value that already
// existed, e.g. an operand bit. Its defining op would then have its own
// `sv.namehint` overwritten below, not just the newly created op's.
385 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
386 auto newName = StringAttr::get(
387 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
388 if (
auto *loweredOp = result.getDefiningOp())
389 loweredOp->setAttr(
"sv.namehint", newName);
393 assert(results.size() ==
static_cast<size_t>(width));
394 return insertBits(value, std::move(results));
/// Pass wrapper that runs `BitBlaster` on the pass's root operation and copies
/// the driver's counters into the pass statistics.
402struct LowerWordToBitsPass
403 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
404 void runOnOperation()
override;
/// Runs the bit-blaster on the current operation. Signals pass failure if
/// `run()` fails, e.g. on a combinational cycle. Otherwise adds the driver's
/// counters to the pass statistics.
408void LowerWordToBitsPass::runOnOperation() {
409 BitBlaster driver(getOperation());
410 if (failed(driver.run()))
411 return signalPassFailure();
414 numLoweredBits += driver.numLoweredBits;
415 numLoweredConstants += driver.numLoweredConstants;
416 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)