20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
28#define DEBUG_TYPE "synth-lower-word-to-bits"
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
// Predicate body (presumably shouldLowerOperation): matches the op kinds
// that get bit-blasted. Only the start of the isa<> list is visible here;
// the TypeSwitch in lowerValueToBits also dispatches comb::OrOp,
// comb::XorOp and comb::MuxOp — NOTE(review): confirm the rest of this list
// stays in sync with that dispatch.
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp,
comb::AndOp,
/// Construct a bit-blaster bound to the hw.module it will rewrite.
61 explicit BitBlaster(
hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
// Statistics. LowerWordToBitsPass::runOnOperation adds these counters to the
// pass statistics once run() has succeeded.
/// Total result bits of all ops lowered by lowerOp (it adds each op's width).
71 size_t numLoweredBits = 0;
/// Result bits whose value was known to be 0 or 1 and that were therefore
/// replaced by a shared i1 constant instead of a new single-bit op.
74 size_t numLoweredConstants = 0;
/// Number of word-level operations lowered (the increment site is not
/// visible in this extract).
77 size_t numLoweredOps = 0;
/// Return the single-bit values of `value` (bit 0 first), lowering its
/// defining op on first request and caching the result in loweredValues.
86 ArrayRef<Value> lowerValueToBits(Value value);
/// Bit-blast an op that carries per-input inversion flags (instantiated for
/// aig::AndInverterOp and mig::MajorityInverterOp); every result bit keeps
/// the original inversion flags.
87 template <
typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
/// Bit-blast a comb bitwise op (instantiated for comb::AndOp, comb::OrOp and
/// comb::XorOp); every result bit keeps the original twoState attribute.
89 template <
typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
/// Shared per-bit driver: bits known to be constant reuse a cached i1
/// constant; every other bit i is built by `createOp` from bit i of each
/// operand.
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
/// Return a single-bit value for bit `index` of `value`, looking through
/// bit-rearranging ops where possible before falling back to lowering.
98 Value extractBit(Value value,
size_t index);
/// Return the cached i1 constant for `value`, creating it on first use.
106 Value getBoolConstant(
bool value);
/// Cache `bits` as the lowering of `value` and return a view of the stored
/// vector. Each value may be lowered at most once (asserted).
/// NOTE(review): the returned ArrayRef points into loweredValues' storage.
/// A later insertion into the MapVector may move that storage, so callers
/// should consume the view before lowering anything else — confirm against
/// the llvm::MapVector invalidation rules.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second &&
"value already inserted");
116 return it.first->second;
/// Cache `bits` as the known-bits summary of `value` and return a reference
/// to the cached entry. If `value` already has an entry, MapVector::insert
/// keeps the existing entry and that entry is returned. Unlike insertBits,
/// no assertion guards against a duplicate.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
/// Lowered bits of every multi-bit value, kept in insertion order. run()
/// walks this in reverse to materialize the comb.concat replacements.
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
/// Memoized known-bits summaries, filled by computeKnownBits.
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
/// Lazily created i1 constants indexed by the boolean value
/// (constants[false] / constants[true]); see getBoolConstant.
132 std::array<Value, 2> constants;
/// Compute which bits of `value` are statically known to be 0 or 1, and
/// memoize the answer in knownBits.
///  - aig::AndInverterOp: AND of all (optionally inverted) inputs.
///  - mig::MajorityInverterOp with three operands: majority of the
///    (optionally inverted) inputs, maj(a, b, c) = ab | ac | bc.
///  - otherwise: defer to comb::computeKnownBits.
/// Some values get an all-unknown KnownBits early. The condition that
/// selects them is not visible in this extract.
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
// Return the memoized result if there is one.
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
// Nothing is known about this value. This is presumably the path taken when
// there is no defining op; the guarding condition is not visible here.
155 return insertKnownBits(value, llvm::KnownBits(width));
157 llvm::KnownBits result(width);
158 if (
auto aig = dyn_cast<aig::AndInverterOp>(op)) {
// Start from the AND identity (every bit known to be 1) and fold in each
// operand. An inverted operand has its known-zero and known-one masks
// swapped, which is the known-bits form of a bitwise NOT.
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
163 for (
auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
171 }
else if (
auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
// Only the three-input majority is modelled here. Gather each input's known
// bits (swapping Zero/One for inverted inputs), then combine them with
// maj(a, b, c) = (a & b) | (a & c) | (b & c). What happens for other operand
// counts is not visible in this extract.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (
auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
180 std::swap(operandsKnownBits[i].
Zero, operandsKnownBits[i].
One);
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
// Fallback for other ops: reuse the comb dialect's known-bits analysis.
191 result = comb::computeKnownBits(value);
// Memoize and return a reference to the cached copy.
194 return insertKnownBits(value, std::move(result));
/// Return a single-bit value holding bit `index` of `value`.
/// Values at most one bit wide are handled up front; the early-exit body is
/// not visible in this extract. For a defining op that only rearranges bits,
/// extractBit recurses through the op, so no word-level intermediate has to
/// be bit-blasted. The visible cases are a concatenation, an extract using
/// getLowBit(), comb::ReplicateOp and a constant.
/// Any other op falls back to lowerValueToBits(value)[index].
197Value BitBlaster::extractBit(Value value,
size_t index) {
198 if (hw::getBitWidth(value.getType()) <= 1)
201 auto *op = value.getDefiningOp();
// Fallback path that uses the value's own bit lowering. Its guard is not
// visible here; presumably it applies when the value has no defining op.
205 return lowerValueToBits(value)[index];
207 return TypeSwitch<Operation *, Value>(op)
// Concatenation: visit operands last-to-first, starting with the operand
// that holds the least-significant bits. Rebase `index` into whichever
// operand contains it.
209 for (
auto operand :
llvm::reverse(op.getOperands())) {
210 auto width = hw::getBitWidth(operand.getType());
// NOTE(review): this check only rejects negative widths (presumably the
// "unknown width" result of hw::getBitWidth). Zero-width operands pass and
// are simply skipped by the loop, so the message "operand has zero width"
// misdescribes the failure. Consider "operand has unknown width".
211 assert(width >= 0 &&
"operand has zero width");
212 if (index <
static_cast<size_t>(width))
213 return extractBit(operand, index);
214 index -=
static_cast<size_t>(width);
// Only reachable if `index` exceeds the total concatenated width.
216 llvm_unreachable(
"index out of bounds");
// Extract: bit `index` of the result is bit (lowBit + index) of the input.
219 return extractBit(ext.getInput(),
220 static_cast<size_t>(ext.getLowBit()) + index);
// Replicate: the result repeats its input, so wrap `index` modulo the input
// width. NOTE(review): getInput() and getOperand() presumably name the same
// single operand — confirm.
222 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
223 return extractBit(op.getInput(),
224 index %
static_cast<size_t>(hw::getBitWidth(
225 op.getOperand().getType())));
// Constant: read the requested bit and return the shared i1 constant for it.
228 auto value = op.getValue();
229 return getBoolConstant(value[index]);
// Anything else: lower the whole value and pick the requested bit.
231 .Default([&](
auto op) {
return lowerValueToBits(value)[index]; });
/// Return the single-bit values of `value`, ordered from bit 0 upward.
/// (run() reverses them before building a comb.concat.) The bits are
/// computed on first request and cached in loweredValues.
///  - Already-cached values are returned directly.
///  - A value of at most one bit is its own lowering (guard not fully visible).
///  - Some values are split with comb::extractBits right after their
///    definition (the selecting condition is not visible in this extract).
///  - Lowerable ops dispatch to the per-op lowering helpers below.
234ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
235 auto *it = loweredValues.find(value);
236 if (it != loweredValues.end())
239 auto width = hw::getBitWidth(value.getType());
// Presumably reached for values of at most one bit: they lower to themselves.
241 return insertBits(value, {value});
243 auto *op = value.getDefiningOp();
// Split the word into bits with comb::extractBits, inserted right after the
// value's definition.
245 SmallVector<Value> results;
246 OpBuilder builder(value.getContext());
247 builder.setInsertionPointAfterValue(value);
248 comb::extractBits(builder, value, results);
249 return insertBits(value, std::move(results));
// Dispatch on the kind of the defining op.
252 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
253 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
254 [&](
auto op) {
return lowerInvertibleOperations(op); })
255 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
256 [&](
auto op) {
return lowerCombLogicOperations(op); })
257 .Case<comb::MuxOp>([&](
comb::MuxOp op) {
return lowerCombMux(op); })
// Any other op: split its result into bits with comb::extractBits.
// NOTE(review): this inserts *before* `op`, the op that defines `value`,
// whereas the path above uses setInsertionPointAfterValue. That is only
// valid if the hw.module body does not require dominance. Confirm whether
// setInsertionPointAfter(op) was intended.
258 .Default([&](
auto op) {
259 OpBuilder builder(value.getContext());
260 builder.setInsertionPoint(op);
261 SmallVector<Value> results;
262 comb::extractBits(builder, value, results);
264 return insertBits(value, std::move(results));
/// Bit-blast the module in three phases:
///  1. Topologically sort the module body and emit
///     "there is a combinational cycle" if that fails. The readiness
///     predicate is only partly visible; it mentions comb::ReplicateOp.
///  2. Walk the module and lower op results to bits. The filter that selects
///     which ops get lowered is not visible in this extract.
///  3. Visit the lowered values in reverse insertion order and skip
///     single-bit values. For each wide value, either handle the use_empty
///     case (body not visible) or replace all of its uses with a comb.concat
///     of its bits, reordered most-significant first.
268LogicalResult BitBlaster::run() {
272 moduleOp, [](Value value, Operation *op) ->
bool {
276 comb::ReplicateOp>(op));
279 return mlir::emitError(moduleOp.getLoc(),
"there is a combinational cycle");
// Phase 2: populate loweredValues.
283 moduleOp.walk([&](Operation *op) {
286 (
void)lowerValueToBits(op->getResult(0));
// Phase 3: reassemble wide values from their bits. make_early_inc_range
// advances past the current entry before the body runs (presumably to
// tolerate modification during the loop; the code that needs it is not
// visible).
290 for (
auto &[value, results] :
291 llvm::make_early_inc_range(
llvm::reverse(loweredValues))) {
// Single-bit values are their own lowering; nothing to reassemble.
292 if (hw::getBitWidth(value.getType()) <= 1)
295 auto *op = value.getDefiningOp();
// Unused wide value: the handling is not visible in this extract.
299 if (value.use_empty()) {
307 OpBuilder builder(op);
// comb.concat takes its most-significant operand first (see the concat case
// in extractBit), while the cached bits are bit-0 first, so the cached
// vector is reversed in place before building the concat.
308 std::reverse(results.begin(), results.end());
309 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
310 value.replaceAllUsesWith(
concat);
/// Return the i1 constant equal to `value`. It is created once, at the
/// beginning of the module body, and cached in constants[value] for reuse.
/// Part of the creating call is not visible in this extract.
318Value BitBlaster::getBoolConstant(
bool value) {
319 if (!constants[value]) {
320 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
322 builder.getI1Type(), value);
324 return constants[value];
/// Bit-blast an aig::AndInverterOp or mig::MajorityInverterOp (the two
/// instantiations dispatched from lowerValueToBits). Result bit i is a new
/// OpTy built with createOrFold over bit i of each input, reusing the
/// original op's inversion flags.
327template <
typename OpTy>
328ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
329 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
330 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
332 return lowerOp(op, createOp);
/// Bit-blast a comb::AndOp, comb::OrOp or comb::XorOp (the instantiations
/// dispatched from lowerValueToBits). Result bit i is a new OpTy built with
/// createOrFold over bit i of each operand, keeping the twoState attribute.
335template <
typename OpTy>
336ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
337 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
338 return builder.createOrFold<OpTy>(op.getLoc(), operands,
339 op.getTwoStateAttr());
341 return lowerOp(op, createOp);
/// Bit-blast a comb::MuxOp. Result bit i is mux(operands[0], operands[1],
/// operands[2]) over the bit-i slices, built with createOrFold and the
/// original twoState attribute. lowerOp feeds every operand through
/// extractBit. The one-bit condition therefore takes extractBit's
/// width <= 1 early exit, presumably yielding the condition itself for
/// every bit.
344ArrayRef<Value> BitBlaster::lowerCombMux(
comb::MuxOp op) {
345 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
346 return builder.createOrFold<
comb::MuxOp>(op.getLoc(), operands[0],
347 operands[1], operands[2],
348 op.getTwoStateAttr());
350 return lowerOp(op, createOp);
/// Shared bit-blasting driver for a single-result, multi-bit `op`.
/// For each result bit i:
///  - if the known bits say the bit is constant, reuse the shared i1 constant;
///  - otherwise pass bit i of every operand (via extractBit) to `createOp`,
///    which builds the new single-bit value.
/// An "sv.namehint" on `op` is propagated to each new bit as "<name>[i]".
/// Statistics are updated, and the bits are cached with insertBits, bit 0
/// first.
353ArrayRef<Value> BitBlaster::lowerOp(
355 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
356 auto value = op->getResult(0);
// New single-bit ops are created immediately before `op`.
357 OpBuilder builder(op);
358 auto width = hw::getBitWidth(value.getType());
359 assert(width > 1 &&
"expected multi-bit operation");
// Mask of bits known to be either 0 or 1. The `known` summary is computed on
// lines not visible here, presumably via computeKnownBits(value).
362 APInt knownMask = known.Zero | known.One;
365 numLoweredConstants += knownMask.popcount();
366 numLoweredBits += width;
369 SmallVector<Value> results;
370 results.reserve(width);
372 for (int64_t i = 0; i < width; ++i) {
373 SmallVector<Value> operands;
374 operands.reserve(op->getNumOperands());
// Known-constant bit (guard not visible): use the cached i1 constant.
377 results.push_back(getBoolConstant(known.One[i]));
382 for (
auto operand : op->getOperands())
383 operands.push_back(extractBit(operand, i));
386 auto result = createOp(builder, operands);
387 results.push_back(result);
// Propagate a name hint as "<name>[i]" to the op that produced this bit.
// NOTE(review): every createOp here uses createOrFold, which may return an
// already-existing value (for example an operand bit or a constant). In that
// case this overwrites the namehint of an unrelated, possibly shared op.
// Consider renaming only ops that were newly created.
390 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint")) {
391 auto newName = StringAttr::get(
392 op->getContext(), name.getValue() +
"[" + std::to_string(i) +
"]");
393 if (
auto *loweredOp = result.getDefiningOp())
394 loweredOp->setAttr(
"sv.namehint", newName);
// Exactly one value (new op result or constant) per result bit.
398 assert(results.size() ==
static_cast<size_t>(width));
399 return insertBits(value, std::move(results));
/// Pass wrapper around BitBlaster. The base class is the tablegen-generated
/// impl::LowerWordToBitsBase enabled by GEN_PASS_DEF_LOWERWORDTOBITS.
407struct LowerWordToBitsPass
408 :
public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
409 void runOnOperation()
override;
/// Bit-blast the current module with a fresh BitBlaster. If run() fails
/// (e.g. a combinational cycle), signal pass failure. Otherwise, accumulate
/// the driver's counters into the pass statistics.
413void LowerWordToBitsPass::runOnOperation() {
414 BitBlaster driver(getOperation());
415 if (failed(driver.run()))
416 return signalPassFailure();
// Statistics are only recorded for successful runs.
419 numLoweredBits += driver.numLoweredBits;
420 numLoweredConstants += driver.numLoweredConstants;
421 numLoweredOps += driver.numLoweredOps;
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)