20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
28#define DEBUG_TYPE "synth-lower-word-to-bits"
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
// Presumably the body of shouldLowerOperation: AIG and/inverter, MIG
// majority/inverter and comb.and ops (plus further op types on lines not
// visible in this listing) are the candidates for bit-level lowering.
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp,
comb::AndOp,
// Creates a bit-blaster for a single HW module. The module is kept so that
// the shared i1 constants can be materialized at the start of its body block.
61 explicit BitBlaster(
hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
// Statistics gathered during run(); LowerWordToBitsPass::runOnOperation adds
// them to the pass statistics of the same names.
// Total number of result bits of the multi-bit ops handled by lowerOp.
71 size_t numLoweredBits = 0;
// Number of those bits whose value was fully known and emitted as a constant.
74 size_t numLoweredConstants = 0;
// Number of operations lowered to bit level (the increment site is not
// visible in this listing).
77 size_t numLoweredOps = 0;
// Returns the per-bit values of `value`, least-significant bit first,
// lowering it and caching the result on first request.
86 ArrayRef<Value> lowerValueToBits(Value value);
// Lowers an AIG/MIG op bit-by-bit, keeping its per-operand inversion flags.
87 template <
typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
// Lowers a comb and/or/xor op bit-by-bit, keeping its two-state attribute.
89 template <
typename OpTy>
90 ArrayRef<Value> lowerCombOperations(OpTy op);
// Shared driver: for each result bit, either emits a known constant or calls
// `createOp` on the corresponding operand bits. (The return type is on a line
// not shown in this listing.)
92 lowerOp(Operation *op,
93 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
// Returns a single-bit value for bit `index` (0 = LSB) of `value`, looking
// through replicate producers (and, judging by the visible case bodies,
// concat/extract/constant producers) where possible.
97 Value extractBit(Value value,
size_t index);
// Returns the shared i1 constant for `value`, creating it on first use.
105 Value getBoolConstant(
bool value);
// Records the bit-level lowering of `value` and returns a view of the stored
// bits. Each value may be recorded only once (asserted).
// NOTE(review): the returned ArrayRef points into loweredValues' element
// storage. A later insertion that grows the map can move that storage (a
// SmallVector's inline buffer moves with its owner), so callers should use
// the result before lowering any further values.
112 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
113 auto it = loweredValues.insert({value, std::move(bits)});
114 assert(it.second &&
"value already inserted");
115 return it.first->second;
// Caches the known-bits result for `value` and returns the cached entry.
// MapVector::insert does not overwrite, so an existing entry is kept unchanged
// and returned instead. Unlike insertBits, duplicates are not asserted.
// The returned reference aliases knownBits' storage, which later insertions
// may move.
119 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
120 auto it = knownBits.insert({value, std::move(bits)});
121 return it.first->second;
// Bit-level lowering of every processed value (index 0 = LSB), in insertion
// order. run() later walks it in reverse to replace multi-bit values with a
// concat of their bits.
125 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
// Cache of the known-zero/known-one facts produced by computeKnownBits.
128 llvm::MapVector<Value, llvm::KnownBits> knownBits;
// Lazily created i1 constants, indexed by the boolean value (false, true).
131 std::array<Value, 2> constants;
// Computes and caches the known-zero/known-one bits of `value`.
// - AIG producers and 3-input MIG producers are modelled directly.
// - Other values fall back to comb::computeKnownBits or to "nothing known".
// The exact conditions for those fall-backs are on lines not visible here.
143const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
// Reuse a previously computed result.
145 auto *it = knownBits.find(value);
146 if (it != knownBits.end())
149 auto width = hw::getBitWidth(value.getType());
150 auto *op = value.getDefiningOp();
// Nothing known (both masks zero); the guarding condition is not shown.
154 return insertKnownBits(value, llvm::KnownBits(width));
156 llvm::KnownBits result(width);
157 if (
auto aig = dyn_cast<aig::AndInverterOp>(op)) {
// Start from all-ones (the identity of AND), then fold in each operand.
159 result.One = APInt::getAllOnes(width);
160 result.Zero = APInt::getZero(width);
162 for (
auto [operand, inverted] :
163 llvm::zip(aig.getInputs(), aig.getInverted())) {
// Inverting an input swaps its known-zero and known-one masks (presumably
// guarded by `inverted` on a line not shown).
167 std::swap(operandKnownBits.Zero, operandKnownBits.One);
168 result &= operandKnownBits;
170 }
else if (
auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
// Only the 3-input majority function is modelled.
172 if (mig.getNumOperands() == 3) {
173 std::array<llvm::KnownBits, 3> operandsKnownBits;
174 for (
auto [i, operand, inverted] :
175 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
179 std::swap(operandsKnownBits[i].
Zero, operandsKnownBits[i].
One);
// maj(a, b, c) = (a & b) | (a & c) | (b & c), evaluated on known bits.
182 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
183 (operandsKnownBits[0] & operandsKnownBits[2]) |
184 (operandsKnownBits[1] & operandsKnownBits[2]);
190 result = comb::computeKnownBits(value);
193 return insertKnownBits(value, std::move(result));
// Returns a single-bit value holding bit `index` (0 = LSB) of `value`.
// Where the producer allows it, the bit is taken straight from the producer's
// operands, so `value` itself need not be fully lowered to bits.
196Value BitBlaster::extractBit(Value value,
size_t index) {
// Values of width <= 1 are handled on a line not shown (presumably returned
// as-is).
197 if (hw::getBitWidth(value.getType()) <= 1)
200 auto *op = value.getDefiningOp();
// Fall back to the generic lowering (guarding condition not shown).
204 return lowerValueToBits(value)[index];
206 return TypeSwitch<Operation *, Value>(op)
// Concatenation (case label not shown): walk the operands from the last one,
// which holds the least-significant bits. Subtract each operand's width until
// `index` falls inside one of them.
208 for (
auto operand :
llvm::reverse(op.getOperands())) {
209 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): this assert only rejects negative (presumably unknown)
// widths. Zero-width operands pass and are skipped by the subtraction
// below, so the message "operand has zero width" does not describe the
// checked condition. Something like "operand has unknown width" would.
210 assert(width >= 0 &&
"operand has zero width");
211 if (index <
static_cast<size_t>(width))
212 return extractBit(operand, index);
213 index -=
static_cast<size_t>(width);
215 llvm_unreachable(
"index out of bounds");
// Extraction: bit `index` of the result is bit (lowBit + index) of the input.
218 return extractBit(ext.getInput(),
219 static_cast<size_t>(ext.getLowBit()) + index);
// Replication: the input pattern repeats, so wrap `index` by the input width.
221 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
222 return extractBit(op.getInput(),
223 index %
static_cast<size_t>(hw::getBitWidth(
224 op.getOperand().getType())));
// Constant: read the requested bit and reuse the shared i1 constant.
227 auto value = op.getValue();
228 return getBoolConstant(value[index]);
// Anything else: lower the whole value and pick out the bit.
230 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
// Returns the bits of `value` (0 = LSB), computing and caching them on first
// request.
// - AIG/MIG and comb.and/or/xor producers are lowered bit-by-bit.
// - Any other value is split with comb::extractBits.
233ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
// Reuse a cached lowering.
234 auto *it = loweredValues.find(value);
235 if (it != loweredValues.end())
238 auto width = hw::getBitWidth(value.getType());
// A narrow value (presumably width <= 1; the guard is not shown) is its own
// single bit.
240 return insertBits(value, {value});
242 auto *op = value.getDefiningOp();
// Values not lowered structurally are split with extract ops placed right
// after the value's definition. The guard is on a line not shown; presumably
// it covers block arguments and non-lowerable producers.
244 SmallVector<Value> results;
245 OpBuilder builder(value.getContext());
246 builder.setInsertionPointAfterValue(value);
247 comb::extractBits(builder, value, results);
248 return insertBits(value, std::move(results));
251 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
252 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
253 [&](
auto op) {
return lowerInvertibleOperations(op); })
254 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
255 [&](
auto op) {
return lowerCombOperations(op); })
// Other producers: split with extract ops.
// NOTE(review): these extracts are inserted *before* the defining op, while
// the branch above inserts after the value. This presumably relies on the
// module body being a graph region (run() orders it with a graph-region
// topological sort). Confirm the asymmetry is intended.
256 .Default([&](
auto op) {
257 OpBuilder builder(value.getContext());
258 builder.setInsertionPoint(op);
259 SmallVector<Value> results;
260 comb::extractBits(builder, value, results);
262 return insertBits(value, std::move(results));
// Entry point for one module. It:
//   1. topologically orders the module body;
//   2. lowers the results of selected ops to bits;
//   3. rebuilds each multi-bit lowered value as a concat of its bits and
//      redirects the original uses to it.
// Fails with an error if the ordering finds a combinational cycle.
266LogicalResult BitBlaster::run() {
// Operand-readiness predicate for the topological sort. The call and most of
// the predicate are on lines not shown; a comb::ReplicateOp type check is the
// visible tail.
270 moduleOp, [](Value value, Operation *op) ->
bool {
274 comb::ReplicateOp>(op));
277 return mlir::emitError(moduleOp.getLoc(),
"there is a combinational cycle");
// Lower the first result of each selected op (selection filter not shown).
281 moduleOp.walk([&](Operation *op) {
284 (
void)lowerValueToBits(op->getResult(0));
// Visit lowered values from most recently to least recently inserted.
288 for (
auto &[value, results] :
289 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
// Single-bit values were lowered to themselves; nothing to rebuild.
290 if (hw::getBitWidth(value.getType()) <= 1)
293 auto *op = value.getDefiningOp();
// Unused values are handled separately (body not shown).
297 if (value.use_empty()) {
// Bits are stored LSB-first while concat takes the most-significant operand
// first, hence the in-place reversal before building the concat.
305 OpBuilder builder(op);
306 std::reverse(results.begin(), results.end());
308 value.replaceAllUsesWith(
concat);
// Returns the module-wide i1 constant for `value`. The first time a value is
// requested, the constant is created at the start of the module body with an
// unknown location. Presumably it is stored into constants[value] on the
// line not shown.
316Value BitBlaster::getBoolConstant(
bool value) {
317 if (!constants[value]) {
318 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
320 builder.getUnknownLoc(), builder.getI1Type(), value);
322 return constants[value];
// Lowers an AIG/MIG-style op bit-by-bit. Each result bit is rebuilt as the
// same op kind over the matching operand bits, reusing the original
// per-operand inversion flags; createOrFold may fold it to a simpler value.
325template <
typename OpTy>
326ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
327 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
328 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
330 return lowerOp(op, createOp);
// Lowers a comb and/or/xor op bit-by-bit. Each result bit is rebuilt as the
// same op kind over the matching operand bits, reusing the original
// two-state attribute; createOrFold may fold it to a simpler value.
333template <
typename OpTy>
334ArrayRef<Value> BitBlaster::lowerCombOperations(OpTy op) {
335 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
336 return builder.createOrFold<OpTy>(op.getLoc(), operands,
337 op.getTwoStateAttr());
339 return lowerOp(op, createOp);
// Lowers the multi-bit first result of `op` to one value per bit (0 = LSB).
// - Bits fully determined by known-bits analysis become shared i1 constants.
// - Every other bit is rebuilt by `createOp` from the matching operand bits.
// - An "sv.namehint" on `op` is propagated to each rebuilt bit as "<hint>[i]".
342ArrayRef<Value> BitBlaster::lowerOp(
344 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
345 auto value = op->getResult(0);
346 OpBuilder builder(op);
347 auto width = hw::getBitWidth(value.getType());
348 assert(width > 1 &&
"expected multi-bit operation");
// NOTE(review): `known` is declared on a line not shown. It should be a copy
// rather than a reference into the knownBits cache. The extractBit calls in
// the loop below can recursively lower operands, which may insert into that
// cache and move its storage. TODO confirm.
// A bit is constant exactly when it is known zero or known one.
351 APInt knownMask = known.Zero | known.One;
354 numLoweredConstants += knownMask.popcount();
355 numLoweredBits += width;
358 SmallVector<Value> results;
359 results.reserve(width);
361 for (int64_t i = 0; i < width; ++i) {
362 SmallVector<Value> operands;
363 operands.reserve(op->getNumOperands());
// Known bit (guard not shown): emit the constant instead of an op.
366 results.push_back(getBoolConstant(known.One[i]));
// Otherwise gather bit i of every operand and rebuild the op on those bits.
371 for (
auto operand : op->getOperands())
372 operands.push_back(extractBit(operand, i));
375 auto result = createOp(builder, operands);
376 results.push_back(result);
// The createOp callbacks in this file use createOrFold, which may return an
// existing value (e.g. a block argument), so the hint is only attached when
// there is a defining op.
// NOTE(review): if the fold returns an existing operand value, this
// overwrites the namehint of that value's producer. Confirm this is intended.
379 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
380 auto newName = StringAttr::get(
381 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
382 if (
auto *loweredOp = result.getDefiningOp())
383 loweredOp->setAttr(
"sv.namehint", newName);
387 assert(results.size() ==
static_cast<size_t>(width));
388 return insertBits(value, std::move(results));
// Pass wrapper that runs a BitBlaster on the operation it is scheduled on.
// BitBlaster is constructed from that operation as an hw::HWModuleOp.
396struct LowerWordToBitsPass
397 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
398 void runOnOperation()
override;
// Bit-blasts the current module. On failure it signals pass failure;
// otherwise it adds the driver's counters to the pass statistics.
402void LowerWordToBitsPass::runOnOperation() {
403 BitBlaster driver(getOperation());
404 if (failed(driver.run()))
405 return signalPassFailure();
408 numLoweredBits += driver.numLoweredBits;
409 numLoweredConstants += driver.numLoweredConstants;
410 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)