20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
28#define DEBUG_TYPE "synth-lower-word-to-bits"
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
// Operation kinds accepted here are the ones split into single-bit operations.
// NOTE(review): keep this list in sync with the cases handled in
// BitBlaster::lowerValueToBits.
46 return isa<ChoiceOp, aig::AndInverterOp, mig::MajorityInverterOp,
comb::AndOp,
61 explicit BitBlaster(
hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
/// Total number of result bits produced by lowering multi-bit operations
/// (lowerOp adds the full width of every operation it lowers).
size_t numLoweredBits = 0;

/// Number of result bits that known-bits analysis proved constant and that
/// were therefore replaced by a shared i1 constant instead of new logic.
size_t numLoweredConstants = 0;

/// Number of operations lowered to bit-level operations; accumulated into the
/// pass statistic of the same name.
size_t numLoweredOps = 0;
/// Return the per-bit decomposition of `value` (index i holds bit i),
/// lowering its defining operation on first use and caching the result.
86 ArrayRef<Value> lowerValueToBits(Value value);
/// Lower an operation with per-input inversion flags (AIG/MIG) bit by bit.
87 template <
typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
/// Lower a comb bitwise logic operation (and/or/xor) bit by bit, preserving
/// its two-state attribute.
89 template <
typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
/// Shared driver: builds one single-bit operation per result bit through
/// `createOp`, using shared constants for bits whose value the known-bits
/// analysis determines.
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
/// Return bit `index` of `value` as a single-bit value, looking through some
/// producer operations (e.g. extract, replicate) where possible.
98 Value extractBit(Value value,
size_t index);
/// Return the cached i1 constant for `value`, creating it at the start of the
/// module body on first use.
106 Value getBoolConstant(
bool value);
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second &&
"value already inserted");
116 return it.first->second;
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
132 std::array<Value, 2> constants;
/// Compute, and memoize in `knownBits`, which bits of `value` are known to be
/// zero or one. AIG, MIG and choice operations are analysed here from their
/// operands' (recursively computed) known bits. All other defining operations
/// fall back to comb::computeKnownBits.
///
/// NOTE(review): the returned reference points into the `knownBits` MapVector
/// and may be invalidated by later insertions. Callers should copy it before
/// triggering further analysis or lowering.
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
// Memoized result from an earlier query.
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
// Fallback (guard condition not visible here): record every bit as unknown.
155 return insertKnownBits(value, llvm::KnownBits(width));
157 llvm::KnownBits result(width);
// AIG: AND of all (possibly inverted) inputs. Start from "every bit known one"
// and intersect with each input's known bits.
158 if (
auto aig = dyn_cast<aig::AndInverterOp>(op)) {
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
163 for (
auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
// An inverted input swaps which bits are known zero and known one.
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
// KnownBits AND: known one only if known one in every input, known zero
// if known zero in any input.
169 result &= operandKnownBits;
171 }
else if (
auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
// MIG: only the 3-input majority is modelled; other arities leave `result`
// fully unknown.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (
auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
// An inverted input swaps which bits are known zero and known one.
180 std::swap(operandsKnownBits[i].
Zero, operandsKnownBits[i].
One);
// maj(a, b, c) = (a & b) | (a & c) | (b & c), evaluated on known bits.
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
187 }
else if (
auto choice = dyn_cast<ChoiceOp>(op)) {
// Choice: the known bits of all inputs are unioned. This is presumably valid
// because every choice input computes the same value (TODO confirm). The
// first input is handled on a line not visible here (note drop_front).
189 for (
auto input : choice.getInputs().drop_front()) {
191 result.One |= known.One;
192 result.Zero |= known.Zero;
// Any other defining op: defer to comb's generic known-bits analysis, which
// cannot consult this BitBlaster's cache.
198 result = comb::computeKnownBits(value);
// Memoize and return a reference to the cached entry.
201 return insertKnownBits(value, std::move(result));
/// Return bit `index` (LSB = 0) of `value` as a single-bit value. Some
/// producers (a concat-style case, extract, replicate, constant) are looked
/// through so their bits can be read directly. Other producers are lowered
/// with lowerValueToBits and indexed.
204Value BitBlaster::extractBit(Value value,
size_t index) {
205 if (hw::getBitWidth(value.getType()) <= 1)
208 auto *op = value.getDefiningOp();
// Fallback for values that cannot be looked through (guard condition not
// visible here).
212 return lowerValueToBits(value)[index];
214 return TypeSwitch<Operation *, Value>(op)
// Concat-style case: operands are listed MSB first, so walk them in reverse.
// Subtract each operand's width until bit `index` falls inside one operand.
216 for (
auto operand :
llvm::reverse(op.getOperands())) {
217 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): the condition only rejects negative (unknown) widths.
// Zero-width operands pass and are simply skipped, so the message
// "zero width" is misleading.
218 assert(width >= 0 &&
"operand has zero width");
219 if (index <
static_cast<size_t>(width))
220 return extractBit(operand, index);
221 index -=
static_cast<size_t>(width);
223 llvm_unreachable(
"index out of bounds");
// Extract: result bit `index` is input bit (lowBit + index).
226 return extractBit(ext.getInput(),
227 static_cast<size_t>(ext.getLowBit()) + index);
// Replicate: the input pattern repeats, so reduce the index modulo the
// input width.
229 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
230 return extractBit(op.getInput(),
231 index %
static_cast<size_t>(hw::getBitWidth(
232 op.getOperand().getType())));
// Constant: read the bit and reuse the shared i1 constant. (This local
// `value` shadows the function parameter.)
235 auto value = op.getValue();
236 return getBoolConstant(value[index]);
// Any other producer: lower the value to bits and pick the requested one.
238 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
/// Return the per-bit decomposition of `value` (index i holds bit i),
/// computing and caching it in `loweredValues` on first request. Lowerable
/// operations are rebuilt bit by bit; other values are split with
/// comb::extractBits.
241ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
// Cached decomposition from an earlier request.
242 auto *it = loweredValues.find(value);
243 if (it != loweredValues.end())
246 auto width = hw::getBitWidth(value.getType());
// A value that is already single-bit (presumably width <= 1; the check is on
// a line not visible here) is its own only bit.
248 return insertBits(value, {value});
250 auto *op = value.getDefiningOp();
// Values whose producer is not lowered here (guard not visible; presumably
// block arguments and non-lowerable ops) are split with extract ops placed
// right after the value.
252 SmallVector<Value> results;
253 OpBuilder builder(value.getContext());
254 builder.setInsertionPointAfterValue(value);
255 comb::extractBits(builder, value, results);
256 return insertBits(value, std::move(results));
// Dispatch to the per-kind bit-level lowering.
259 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
// Choice: rebuild a single-bit choice over bit i of each input, taking the
// result type from the first operand.
260 .Case<ChoiceOp>([&](ChoiceOp op) {
261 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
262 return builder.createOrFold<ChoiceOp>(
263 op.getLoc(), operands[0].getType(), operands);
265 return lowerOp(op, createOp);
267 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
268 [&](
auto op) {
return lowerInvertibleOperations(op); })
269 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
270 [&](
auto op) {
return lowerCombLogicOperations(op); })
271 .Case<comb::MuxOp>([&](
comb::MuxOp op) {
return lowerCombMux(op); })
// Fallback: split the value with extract ops.
// NOTE(review): unlike the non-lowerable path above, which inserts after
// the value, this inserts *before* the defining op. The extracts then
// precede the value they read, which is only legal in a graph region.
// Confirm this is intended.
272 .Default([&](
auto op) {
273 OpBuilder builder(value.getContext());
274 builder.setInsertionPoint(op);
275 SmallVector<Value> results;
276 comb::extractBits(builder, value, results);
278 return insertBits(value, std::move(results));
/// Bit-blast the module. Visible steps:
/// (1) topologically sort the module body, reporting a combinational cycle if
///     that fails;
/// (2) walk the module and lower operation results to bits;
/// (3) in reverse lowering order, replace each multi-bit original value with
///     a concat of its bits.
282LogicalResult BitBlaster::run() {
// Readiness callback for the topological sort (the call itself and most of
// the predicate are on lines not visible here).
286 moduleOp, [](Value value, Operation *op) ->
bool {
290 comb::ReplicateOp>(op));
// A failed sort is reported as a combinational cycle.
293 return mlir::emitError(moduleOp.getLoc(),
"there is a combinational cycle");
// Lower the result of each selected operation (selection condition not
// visible here); the result of lowerValueToBits is only needed for caching.
297 moduleOp.walk([&](Operation *op) {
300 (
void)lowerValueToBits(op->getResult(0));
// Revisit lowered values, most recently lowered first, and materialize a
// concat for the original multi-bit values. Presumably the reverse order is
// so users are handled before the values they use (TODO confirm).
304 for (
auto &[value, results] :
305 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
// Single-bit values were never split (presumably skipped here).
306 if (hw::getBitWidth(value.getType()) <= 1)
309 auto *op = value.getDefiningOp();
// Dead values need no concat (handling is on lines not visible here).
313 if (value.use_empty()) {
321 OpBuilder builder(op);
// Bits are stored LSB first, but comb.concat takes operands MSB first.
// NOTE: this reverses the cached bit list in place.
322 std::reverse(results.begin(), results.end());
323 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
324 value.replaceAllUsesWith(concat);
/// Return the shared i1 constant for `value`. It is created once, at the start
/// of the module body, and cached in `constants` for reuse.
332Value BitBlaster::getBoolConstant(
bool value) {
// Create on first use only.
333 if (!constants[value]) {
334 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
336 builder.getI1Type(), value);
338 return constants[value];
341template <
typename OpTy>
342ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
343 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
344 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
346 return lowerOp(op, createOp);
349template <
typename OpTy>
350ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
351 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
352 return builder.createOrFold<OpTy>(op.getLoc(), operands,
353 op.getTwoStateAttr());
355 return lowerOp(op, createOp);
358ArrayRef<Value> BitBlaster::lowerCombMux(
comb::MuxOp op) {
359 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
360 return builder.createOrFold<
comb::MuxOp>(op.getLoc(), operands[0],
361 operands[1], operands[2],
362 op.getTwoStateAttr());
364 return lowerOp(op, createOp);
/// Lower the multi-bit, single-result operation `op` into `width` single-bit
/// results, stored LSB first and cached via insertBits.
/// - Bits that the known-bits analysis determines become shared i1 constants.
/// - Every other bit is built by `createOp` from bit i of each operand.
367ArrayRef<Value> BitBlaster::lowerOp(
369 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
370 auto value = op->getResult(0);
// New single-bit ops are inserted just before `op`.
371 OpBuilder builder(op);
372 auto width = hw::getBitWidth(value.getType());
373 assert(width > 1 &&
"expected multi-bit operation");
// A bit is fully determined when it is known zero or known one.
// NOTE(review): `known` is declared on a line not visible here. If it is a
// reference into the knownBits cache, the extractBit() calls in the loop
// below can lower operands and insert into that cache, invalidating it.
// Copying the KnownBits would be safer.
376 APInt knownMask = known.Zero | known.One;
379 numLoweredConstants += knownMask.popcount();
380 numLoweredBits += width;
383 SmallVector<Value> results;
384 results.reserve(width);
386 for (int64_t i = 0; i < width; ++i) {
387 SmallVector<Value> operands;
388 operands.reserve(op->getNumOperands());
// Known bit: reuse the shared constant instead of building logic (the guard
// is on a line not visible here).
391 results.push_back(getBoolConstant(known.One[i]));
// Unknown bit: gather bit i of every operand and build the single-bit op.
396 for (
auto operand : op->getOperands())
397 operands.push_back(extractBit(operand, i));
400 auto result = createOp(builder, operands);
401 results.push_back(result);
// Propagate the name hint with a per-bit suffix, e.g. "x[3]".
// NOTE(review): createOrFold may return an existing value (for example an
// operand bit) when the op folds. This would then overwrite the
// sv.namehint of an unrelated, pre-existing op. Consider renaming only
// ops created here.
404 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
405 auto newName = StringAttr::get(
406 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
407 if (
auto *loweredOp = result.getDefiningOp())
408 loweredOp->setAttr(
"sv.namehint", newName);
412 assert(results.size() ==
static_cast<size_t>(width));
413 return insertBits(value, std::move(results));
421struct LowerWordToBitsPass
422 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
423 void runOnOperation()
override;
427void LowerWordToBitsPass::runOnOperation() {
428 BitBlaster driver(getOperation());
429 if (failed(driver.run()))
430 return signalPassFailure();
433 numLoweredBits += driver.numLoweredBits;
434 numLoweredConstants += driver.numLoweredConstants;
435 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the given operation `op`.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)