20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
28#define DEBUG_TYPE "synth-lower-word-to-bits"
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp,
comb::AndOp,
/// Construct a bit-blaster bound to `moduleOp`. The handle is stored in the
/// `moduleOp` member; no other work is done here.
61 explicit BitBlaster(
hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
/// Running total of result bits processed; lowerOp adds the bit width of each
/// value it lowers.
71 size_t numLoweredBits = 0;
/// Running total of result bits that were already known; lowerOp adds the
/// popcount of (known.Zero | known.One).
74 size_t numLoweredConstants = 0;
/// Operation counter read back by LowerWordToBitsPass::runOnOperation.
/// NOTE(review): the place where it is incremented is not visible in this
/// listing.
77 size_t numLoweredOps = 0;
/// Return the per-bit values for `value`, lowering and caching them in
/// `loweredValues` on first request.
86 ArrayRef<Value> lowerValueToBits(Value value);
/// Lower an operation that carries an inversion mask (`getInverted()`), one
/// result bit at a time. Dispatched for aig::AndInverterOp and
/// mig::MajorityInverterOp.
87 template <
typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
/// Lower a comb logic operation one result bit at a time, forwarding its
/// two-state attribute. Dispatched for comb::AndOp, comb::OrOp, comb::XorOp.
89 template <
typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
/// Shared per-bit lowering driver: `createOp` is called with the i-th bit of
/// every operand and must return the i-th result bit.
/// NOTE(review): the return-type line of this declaration is not visible in
/// this listing; the definition returns ArrayRef<Value>.
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
/// Return bit `index` of `value`.
98 Value extractBit(Value value,
size_t index);
/// Return the cached 1-bit constant for `value`, creating it on first use.
106 Value getBoolConstant(
bool value);
/// Record `bits` as the lowering of `value` in `loweredValues` and return a
/// view of the stored vector. Asserts that `value` had no entry yet.
/// NOTE(review): the returned ArrayRef points at storage owned by
/// `loweredValues`; presumably it must not be held across later insertions
/// into that map -- confirm against MapVector/SmallVector storage rules.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second &&
"value already inserted");
116 return it.first->second;
/// Store `bits` as the known-bits entry for `value` in `knownBits` and return
/// a reference to the stored entry.
/// NOTE(review): unlike insertBits, the `second` flag of the insert result is
/// not checked here, so a pre-existing entry would be returned as-is and the
/// new `bits` dropped -- confirm this is intended.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
/// Cache from a value to its per-bit lowering. run() walks this map in reverse
/// insertion order when rebuilding the multi-bit values.
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
/// Cache from a value to its computed known-bits information.
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
/// Lazily created 1-bit constants, indexed by the boolean value (see
/// getBoolConstant).
132 std::array<Value, 2> constants;
/// Compute the known-zero / known-one bits of `value` and cache the result in
/// `knownBits`.
/// NOTE(review): several lines of this definition (the cache-hit return, some
/// guards, closing braces) are not visible in this listing; the comments below
/// describe only what is shown.
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
// Look for an already cached result.
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
// Cache a freshly constructed KnownBits of the value's width. Presumably this
// is the path for values without a defining op -- the guard is not visible.
155 return insertKnownBits(value, llvm::KnownBits(width));
157 llvm::KnownBits result(width);
158 if (
auto aig = dyn_cast<aig::AndInverterOp>(op)) {
// Start from "every bit known one" and AND in each operand's known bits.
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
163 for (
auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
// Swapping Zero and One complements the operand's known bits; presumably done
// only when `inverted` is set -- the condition is not visible here.
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
171 }
else if (
auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
// Only the three-input majority form is handled in this branch.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (
auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
// Complement the i-th operand's known bits (presumably only when inverted).
180 std::swap(operandsKnownBits[i].
Zero, operandsKnownBits[i].
One);
// Majority of three inputs: (a & b) | (a & c) | (b & c).
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
// Fallback: delegate to the comb dialect's known-bits analysis.
191 result = comb::computeKnownBits(value);
194 return insertKnownBits(value, std::move(result));
/// Return bit `index` of `value`. The visible cases look through concat,
/// extract, replicate and constant definitions; any other definition falls
/// back to lowerValueToBits(value)[index].
/// NOTE(review): some lines of this definition (early returns, several
/// `.Case` headers, closing braces) are not visible in this listing.
197Value BitBlaster::extractBit(Value value,
size_t index) {
// Values of width <= 1 are special-cased; the statement taken on this branch
// is not visible here.
198 if (hw::getBitWidth(value.getType()) <= 1)
201 auto *op = value.getDefiningOp();
// Presumably the path for values with no defining op -- guard not visible.
205 return lowerValueToBits(value)[index];
207 return TypeSwitch<Operation *, Value>(op)
// Walk the operands from last to first, subtracting each operand's width from
// `index` until it falls inside one of them.
209 for (
auto operand :
llvm::reverse(op.getOperands())) {
210 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): the condition accepts a width of 0 while the message speaks of
// "zero width"; the message and the check appear to disagree -- confirm which
// one is intended.
211 assert(width >= 0 &&
"operand has zero width");
212 if (index <
static_cast<size_t>(width))
213 return extractBit(operand, index);
214 index -=
static_cast<size_t>(width);
216 llvm_unreachable(
"index out of bounds");
// Extract-style op `ext`: shift the index by the op's low bit and recurse
// into its input.
219 return extractBit(ext.getInput(),
220 static_cast<size_t>(ext.getLowBit()) + index);
// Replicate: the requested bit is bit (index mod operand width) of the input.
222 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
223 return extractBit(op.getInput(),
224 index %
static_cast<size_t>(hw::getBitWidth(
225 op.getOperand().getType())));
// Constant-valued op: read the bit directly and return the shared boolean
// constant for it.
228 auto value = op.getValue();
229 return getBoolConstant(value[index]);
// Anything else: lower the whole value and pick the requested bit.
231 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
/// Return the per-bit values of `value`, consulting the `loweredValues` cache
/// first and otherwise lowering it according to its defining operation.
/// NOTE(review): some guards, returns and closing braces of this definition
/// are not visible in this listing.
234ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
// Look for a cached lowering.
235 auto *it = loweredValues.find(value);
236 if (it != loweredValues.end())
239 auto width = hw::getBitWidth(value.getType());
// The value is recorded as its own single "bit"; presumably guarded by a
// width check (width <= 1) that is not visible here.
241 return insertBits(value, {value});
243 auto *op = value.getDefiningOp();
// Materialize explicit bit extracts right after the value's definition point.
// Presumably the path for values with no defining op -- guard not visible.
245 SmallVector<Value> results;
246 OpBuilder builder(value.getContext());
247 builder.setInsertionPointAfterValue(value);
248 comb::extractBits(builder, value, results);
249 return insertBits(value, std::move(results));
// Dispatch on the defining operation.
252 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
253 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
254 [&](
auto op) {
return lowerInvertibleOperations(op); })
255 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
256 [&](
auto op) {
return lowerCombLogicOperations(op); })
257 .Case<comb::MuxOp>([&](
comb::MuxOp op) {
return lowerCombMux(op); })
// Any other operation is kept; its result is split with explicit bit extracts
// created at the operation's position.
258 .Default([&](
auto op) {
259 OpBuilder builder(value.getContext());
260 builder.setInsertionPoint(op);
261 SmallVector<Value> results;
262 comb::extractBits(builder, value, results);
264 return insertBits(value, std::move(results));
/// Drive the lowering for the whole module. The visible code reports a
/// combinational cycle as an error, walks the module's operations to lower
/// their first result, and then replaces each lowered multi-bit value with a
/// concat of its bits.
/// NOTE(review): many lines of this definition are not visible in this
/// listing (the call that receives the first lambda, the walk filter, erasure
/// of dead ops, the final return); comments describe only what is shown.
268LogicalResult BitBlaster::run() {
// Predicate taking (value, op) and returning bool, passed together with
// `moduleOp` to a call whose name is not visible here.
272 moduleOp, [](Value value, Operation *op) ->
bool {
276 comb::ReplicateOp>(op));
// Emitted at the module's location; presumably on failure of the call above.
279 return mlir::emitError(moduleOp.getLoc(),
"there is a combinational cycle");
// Lower result 0 of visited operations; the returned view is deliberately
// discarded because only the cache side effect is wanted.
283 moduleOp.walk([&](Operation *op) {
286 (
void)lowerValueToBits(op->getResult(0));
// Visit lowered values in reverse insertion order.
290 for (
auto &[value, results] :
291 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
// Values of width <= 1 are special-cased (statement not visible).
292 if (hw::getBitWidth(value.getType()) <= 1)
295 auto *op = value.getDefiningOp();
// Values without uses are special-cased (body not visible).
299 if (value.use_empty()) {
307 OpBuilder builder(op);
// lowerOp stores bit 0 first, whereas extractBit treats the last concat
// operand as the lowest bits, so the order is flipped before building the
// concat.
308 std::reverse(results.begin(), results.end());
309 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
310 value.replaceAllUsesWith(concat);
/// Return the shared 1-bit constant for `value`, creating it on first request
/// with a builder positioned at the beginning of the module's body block and
/// caching it in `constants[value]`.
/// NOTE(review): the line that creates the constant and the closing braces
/// are not visible in this listing.
318Value BitBlaster::getBoolConstant(
bool value) {
// `value` (false/true) is used directly as the index 0/1 into `constants`.
319 if (!constants[value]) {
320 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
322 builder.getI1Type(), value);
324 return constants[value];
/// Lower `op` bit by bit via lowerOp. Each per-bit operation is built with
/// createOrFold<OpTy> at `op`'s location, using the per-bit operands and the
/// original op's `getInverted()` mask.
327template <
typename OpTy>
328ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
// Factory invoked by lowerOp once per result bit.
329 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
330 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
332 return lowerOp(op, createOp);
/// Lower `op` bit by bit via lowerOp. Each per-bit operation is built with
/// createOrFold<OpTy> at `op`'s location, using the per-bit operands and the
/// original op's two-state attribute.
335template <
typename OpTy>
336ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
// Factory invoked by lowerOp once per result bit.
337 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
338 return builder.createOrFold<OpTy>(op.getLoc(), operands,
339 op.getTwoStateAttr());
341 return lowerOp(op, createOp);
/// Lower a comb::MuxOp bit by bit via lowerOp. Each per-bit mux is built with
/// createOrFold at `op`'s location from the three per-bit operands, keeping
/// the original two-state attribute.
344ArrayRef<Value> BitBlaster::lowerCombMux(
comb::MuxOp op) {
// Factory invoked by lowerOp once per result bit; operands[0..2] are passed
// through in the order lowerOp collected them from the original op.
345 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
346 return builder.createOrFold<
comb::MuxOp>(op.getLoc(), operands[0],
347 operands[1], operands[2],
348 op.getTwoStateAttr());
350 return lowerOp(op, createOp);
/// Shared per-bit lowering driver. For each bit i of `op`'s first result it
/// either reuses a boolean constant (visible at line 377) or calls `createOp`
/// with bit i of every operand, then records the collected bits through
/// insertBits. Also updates numLoweredBits / numLoweredConstants and
/// propagates an "sv.namehint" attribute as "<name>[i]".
/// NOTE(review): the `op` parameter line, the definition of `known`, the
/// known-bit guard and closing braces are not visible in this listing.
353ArrayRef<Value> BitBlaster::lowerOp(
355 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
356 auto value = op->getResult(0);
// New per-bit operations are inserted at the position of the original op.
357 OpBuilder builder(op);
358 auto width = hw::getBitWidth(value.getType());
359 assert(width > 1 &&
"expected multi-bit operation");
// Bits whose value is known either way (zero or one).
362 APInt knownMask = known.Zero | known.One;
365 numLoweredConstants += knownMask.popcount();
366 numLoweredBits += width;
369 SmallVector<Value> results;
370 results.reserve(width);
372 for (int64_t i = 0; i < width; ++i) {
// NOTE(review): `operands` is constructed and reserved before the constant
// shortcut below; if that shortcut skips the rest of the iteration, this
// allocation is unused for known bits -- confirm against the full source.
373 SmallVector<Value> operands;
374 operands.reserve(op->getNumOperands());
// Reuse the shared boolean constant for this bit; presumably guarded by
// knownMask[i] (guard not visible).
377 results.push_back(getBoolConstant(known.One[i]));
// Gather bit i of every operand.
382 for (
auto operand : op->getOperands())
383 operands.push_back(extractBit(operand, i));
386 auto result = createOp(builder, operands);
387 results.push_back(result);
// Carry the name hint over to the per-bit op, suffixed with the bit index.
// createOp may fold to a value with no defining op, hence the null check.
390 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
391 auto newName = StringAttr::get(
392 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
393 if (
auto *loweredOp = result.getDefiningOp())
394 loweredOp->setAttr(
"sv.namehint", newName);
// Exactly one entry per bit must have been produced.
398 assert(results.size() ==
static_cast<size_t>(width));
399 return insertBits(value, std::move(results));
/// Pass wrapper deriving from the generated impl::LowerWordToBitsBase; its
/// only visible member is the runOnOperation override.
/// NOTE(review): the closing of this struct is not visible in this listing.
407struct LowerWordToBitsPass
408 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
409 void runOnOperation()
override;
/// Run a BitBlaster over the pass's operation. If it fails, signal pass
/// failure and return; otherwise add the driver's counters to the pass's
/// counters of the same names.
/// NOTE(review): the closing brace is not visible in this listing.
413void LowerWordToBitsPass::runOnOperation() {
414 BitBlaster driver(getOperation());
415 if (failed(driver.run()))
416 return signalPassFailure();
// Accumulate into the pass-level numLoweredBits / numLoweredConstants /
// numLoweredOps.
419 numLoweredBits += driver.numLoweredBits;
420 numLoweredConstants += driver.numLoweredConstants;
421 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the given operation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)