20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
28#define DEBUG_TYPE "synth-lower-word-to-bits"
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
61 explicit BitBlaster(
hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
/// Total number of result bits produced while bit-blasting multi-bit ops
/// (incremented by the op width in lowerOp).
size_t numLoweredBits{0};
/// Number of those bits whose value was known and therefore materialized as a
/// shared boolean constant instead of a new operation.
size_t numLoweredConstants{0};
/// Number of lowered operations; copied into the pass statistic of the same
/// name by LowerWordToBitsPass::runOnOperation.
size_t numLoweredOps{0};
/// Return the per-bit values of `value` (element i is bit i), lowering it on
/// first request and memoizing the result in `loweredValues`.
86 ArrayRef<Value> lowerValueToBits(Value value);
/// Bit-blast an operation that carries a per-operand inversion mask (used for
/// aig::AndInverterOp); each per-bit op is rebuilt with the same mask.
87 template <
typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
/// Bit-blast a variadic comb logic operation (comb.and/or/xor), propagating
/// its two-state attribute to every single-bit op.
89 template <
typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
/// Shared lowering driver: for each result bit, `createOp` is invoked with the
/// corresponding bits of the operands to build a single-bit operation.
/// NOTE(review): the return-type line of this declaration is not visible in
/// this listing; the out-of-line definition returns ArrayRef<Value>.
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
/// Return the single-bit value holding bit `index` of `value`. Where possible
/// it looks through structural ops (e.g. comb::ReplicateOp); otherwise it
/// falls back to lowerValueToBits.
98 Value extractBit(Value value,
size_t index);
/// Return the cached i1 constant for `value`, creating it on first use.
106 Value getBoolConstant(
bool value);
/// Record `bits` as the lowering of `value` and return a view of the stored
/// vector. Each value may be lowered only once (asserted).
/// NOTE(review): the returned ArrayRef points into `loweredValues`. If that
/// container is vector-backed, a later insertion may move the stored
/// SmallVector and invalidate views into its inline storage. Confirm callers
/// never hold the result across another insertion.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second &&
"value already inserted");
116 return it.first->second;
/// Memoize `bits` as the known bits of `value` and return the stored entry.
/// Unlike insertBits there is no duplicate assertion. If `value` already has
/// an entry, `insert` leaves it unchanged and that existing entry is returned.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
132 std::array<Value, 2> constants;
/// Compute the known-zero / known-one bits of `value`, memoized in `knownBits`.
/// AND-inverter and choice ops are handled explicitly. Otherwise the result
/// comes from comb::computeKnownBits (the exact condition for that branch is
/// not visible in this listing).
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
// Nothing can be inferred on this path (its guarding condition is not visible
// here): record a KnownBits with every bit unknown.
155 return insertKnownBits(value, llvm::KnownBits(width));
157 llvm::KnownBits result(width);
// AND-inverter: start from "all ones known" (the identity for AND), then AND
// in each operand's known bits. For inverted inputs Zero/One are swapped, so
// the operand contributes its complement.
158 if (
auto aig = dyn_cast<aig::AndInverterOp>(op)) {
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
163 for (
auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
171 }
else if (
auto choice = dyn_cast<ChoiceOp>(op)) {
// Choice: bits known in any input are merged (OR) into the result. Because
// the loop skips the first input, `result` is presumably initialized from it
// on a line not visible here.
// NOTE(review): this assumes all choice inputs compute the same function.
// Confirm that, since conflicting inputs would set a bit in both Zero and One.
173 for (
auto input : choice.getInputs().drop_front()) {
175 result.One |= known.One;
176 result.Zero |= known.Zero;
// Fallback for other operations.
182 result = comb::computeKnownBits(value);
// Memoize and return a reference to the stored entry.
185 return insertKnownBits(value, std::move(result));
/// Return a single-bit value equal to bit `index` of `value`.
/// Where the defining op is structural (extract, replicate, a constant, or an
/// op whose operands are scanned last-to-first), it recurses into the relevant
/// operand instead of materializing every bit. Otherwise it falls back to
/// lowerValueToBits(value)[index].
188Value BitBlaster::extractBit(Value value,
size_t index) {
// Values of width <= 1 are special-cased (the returned value is on a line
// not visible here).
189 if (hw::getBitWidth(value.getType()) <= 1)
192 auto *op = value.getDefiningOp();
// Fallback path (guard not visible here): materialize all bits of `value`.
196 return lowerValueToBits(value)[index];
198 return TypeSwitch<Operation *, Value>(op)
// Operands are visited last-to-first, subtracting each operand's width until
// `index` falls inside one. That matches a concatenation whose first operand
// holds the most significant bits (the case header is not visible here).
200 for (
auto operand :
llvm::reverse(op.getOperands())) {
201 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): the condition only rejects negative (unknown) widths,
// while the message says "zero width". The wording looks inaccurate.
202 assert(width >= 0 &&
"operand has zero width");
203 if (index <
static_cast<size_t>(width))
204 return extractBit(operand, index);
205 index -=
static_cast<size_t>(width);
207 llvm_unreachable(
"index out of bounds");
// Extraction: bit `index` of the result is bit `lowBit + index` of the input.
210 return extractBit(ext.getInput(),
211 static_cast<size_t>(ext.getLowBit()) + index);
// Replication: the input pattern repeats, so reduce `index` modulo the
// input's width.
213 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
214 return extractBit(op.getInput(),
215 index %
static_cast<size_t>(hw::getBitWidth(
216 op.getOperand().getType())));
// Constant (case header not visible here): read bit `index` of the constant
// value directly and reuse the shared boolean constant.
219 auto value = op.getValue();
220 return getBoolConstant(value[index]);
// Anything else: lower the whole value and pick the requested bit.
222 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
/// Return the bits of `value` (bit i at index i), memoized in loweredValues.
/// - Values of width <= 1 (presumably; the guard is not visible) map to
///   themselves.
/// - ChoiceOp, aig::AndInverterOp, comb.and/or/xor and comb.mux are
///   bit-blasted into per-bit operations.
/// - Anything else is split into bits with comb::extractBits.
225ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
226 auto *it = loweredValues.find(value);
227 if (it != loweredValues.end())
230 auto width = hw::getBitWidth(value.getType());
232 return insertBits(value, {value});
234 auto *op = value.getDefiningOp();
// Values reaching this path (guard not visible; presumably those without a
// defining op, i.e. block arguments) are split with comb::extractBits. The
// extracts are inserted right after the value's definition.
236 SmallVector<Value> results;
237 OpBuilder builder(value.getContext());
238 builder.setInsertionPointAfterValue(value);
239 comb::extractBits(builder, value, results);
240 return insertBits(value, std::move(results));
243 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
// Choice: rebuild one per-bit choice whose result type is taken from the
// first operand bit.
244 .Case<ChoiceOp>([&](ChoiceOp op) {
245 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
246 return builder.createOrFold<ChoiceOp>(
247 op.getLoc(), operands[0].getType(), operands);
249 return lowerOp(op, createOp);
251 .Case<aig::AndInverterOp>(
252 [&](
auto op) {
return lowerInvertibleOperations(op); })
253 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
254 [&](
auto op) {
return lowerCombLogicOperations(op); })
255 .Case<comb::MuxOp>([&](
comb::MuxOp op) {
return lowerCombMux(op); })
// Unsupported op: split the value into bits with extracts.
// NOTE(review): unlike the path above, the extracts are inserted *before*
// `op`, which defines `value`, so they use `value` ahead of its definition.
// That is only valid if the enclosing region does not require dominance
// (e.g. a graph region). Confirm, or consider inserting after `op`.
256 .Default([&](
auto op) {
257 OpBuilder builder(value.getContext());
258 builder.setInsertionPoint(op);
259 SmallVector<Value> results;
260 comb::extractBits(builder, value, results);
262 return insertBits(value, std::move(results));
/// Bit-blast moduleOp in three phases:
///  1. Topologically sort the module's blocks (presumably via
///     topologicallySortGraphRegionBlocks; the call line is not visible).
///     If no order exists, emit "there is a combinational cycle" and fail.
///  2. Walk the module and lower the first result of each selected op to bits.
///  3. Visit lowered values in reverse insertion order. Replace remaining uses
///     of each multi-bit value with a comb.concat of its bits.
266LogicalResult BitBlaster::run() {
// Readiness callback for the sort; it classifies operands by their defining
// op (only the tail of the expression is visible here).
270 moduleOp, [](Value value, Operation *op) ->
bool {
274 comb::ReplicateOp>(op));
277 return mlir::emitError(moduleOp.getLoc(),
"there is a combinational cycle");
281 moduleOp.walk([&](Operation *op) {
// The result is cached in loweredValues; the returned view is not needed.
284 (
void)lowerValueToBits(op->getResult(0));
// Reverse order visits the most recently lowered values first.
288 for (
auto &[value, results] :
289 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
290 if (hw::getBitWidth(value.getType()) <= 1)
293 auto *op = value.getDefiningOp();
297 if (value.use_empty()) {
305 OpBuilder builder(op);
// `results` holds bit 0 first (see lowerOp's loop). Reverse it so the most
// significant bit becomes the first concat operand.
// NOTE(review): this reverses the cached vector in loweredValues in place
// (`results` is bound by reference). That is only safe if the cache is no
// longer consulted for this value afterwards.
306 std::reverse(results.begin(), results.end());
307 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
308 value.replaceAllUsesWith(concat);
/// Return the shared i1 constant for `value`. It is created lazily at the
/// start of the module body the first time it is requested, so each
/// BitBlaster creates at most two such constants (false and true).
316Value BitBlaster::getBoolConstant(
bool value) {
317 if (!constants[value]) {
// Placed at the top of the module body, presumably so the constant
// precedes every use regardless of where the request comes from.
318 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
320 builder.getI1Type(), value);
322 return constants[value];
/// Bit-blast `op`, an operation with an inversion mask (e.g.
/// aig::AndInverterOp). Each single-bit replacement is built from the
/// per-bit operands with the original op's inversion mask; createOrFold may
/// fold it to an existing value.
325template <
typename OpTy>
326ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
327 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
328 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
330 return lowerOp(op, createOp);
/// Bit-blast a variadic comb logic op (comb.and/or/xor). Each single-bit op is
/// built from bit i of every operand and reuses the original op's two-state
/// attribute.
333template <
typename OpTy>
334ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
335 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
336 return builder.createOrFold<OpTy>(op.getLoc(), operands,
337 op.getTwoStateAttr());
339 return lowerOp(op, createOp);
/// Bit-blast a comb.mux: bit i of the result selects between bit i of the two
/// data operands (operands[1] and operands[2]). The condition is operands[0].
/// extractBit special-cases values of width <= 1 (presumably returning them
/// unchanged), so a 1-bit condition would be shared by every per-bit mux.
342ArrayRef<Value> BitBlaster::lowerCombMux(
comb::MuxOp op) {
343 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
344 return builder.createOrFold<
comb::MuxOp>(op.getLoc(), operands[0],
345 operands[1], operands[2],
346 op.getTwoStateAttr());
348 return lowerOp(op, createOp);
/// Shared bit-blasting driver for an operation with a multi-bit first result.
/// For each bit i of that result, either a shared boolean constant is used
/// (presumably when the bit is known; the guard line is not visible) or
/// `createOp` is called with bit i of every operand. The per-bit values are
/// stored LSB-first via insertBits and returned. Also updates the
/// numLoweredBits / numLoweredConstants counters.
351ArrayRef<Value> BitBlaster::lowerOp(
353 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
354 auto value = op->getResult(0);
// New per-bit ops are inserted immediately before the original op.
355 OpBuilder builder(op);
356 auto width = hw::getBitWidth(value.getType());
357 assert(width > 1 &&
"expected multi-bit operation");
// Bits whose value is already known (either zero or one); `known` comes
// from a line not visible here, presumably computeKnownBits(value).
360 APInt knownMask = known.Zero | known.One;
363 numLoweredConstants += knownMask.popcount();
364 numLoweredBits += width;
367 SmallVector<Value> results;
368 results.reserve(width);
370 for (int64_t i = 0; i < width; ++i) {
// NOTE(review): `operands` is constructed and reserved on every iteration,
// before the known-bit shortcut below. Hoisting it out of the loop and
// clearing it per bit would avoid repeated allocation.
371 SmallVector<Value> operands;
372 operands.reserve(op->getNumOperands());
// Known bit: reuse the shared true/false constant.
375 results.push_back(getBoolConstant(known.One[i]));
380 for (
auto operand : op->getOperands())
381 operands.push_back(extractBit(operand, i));
384 auto result = createOp(builder, operands);
385 results.push_back(result);
// Propagate a name hint as "<hint>[i]". Results that folded to a value
// without a defining op are left unnamed.
388 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
389 auto newName = StringAttr::get(
390 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
391 if (
auto *loweredOp = result.getDefiningOp())
392 loweredOp->setAttr(
"sv.namehint", newName);
396 assert(results.size() ==
static_cast<size_t>(width));
397 return insertBits(value, std::move(results));
/// Pass that bit-blasts multi-bit logic by running a BitBlaster on the
/// operation it is scheduled on (an hw::HWModuleOp, per BitBlaster's
/// constructor).
405struct LowerWordToBitsPass
406 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
407 void runOnOperation()
override;
/// Run BitBlaster on the current module. If it fails (e.g. it detects a
/// combinational cycle), signal pass failure; otherwise add the driver's
/// counters to the pass statistics.
411void LowerWordToBitsPass::runOnOperation() {
412 BitBlaster driver(getOperation());
413 if (failed(driver.run()))
414 return signalPassFailure();
// Accumulate this module's counters into the pass-wide statistics.
417 numLoweredBits += driver.numLoweredBits;
418 numLoweredConstants += driver.numLoweredConstants;
419 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)