21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/OpDefinition.h"
23#include "llvm/ADT/PointerIntPair.h"
24#include "llvm/ADT/PriorityQueue.h"
26#define DEBUG_TYPE "synth-lower-variadic"
30#define GEN_PASS_DEF_LOWERVARIADIC
31#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
/// One candidate input for the balanced tree. It records:
///   - a value and whether that value is consumed inverted;
///   - the value's arrival time;
///   - an insertion number used to break ties between equal arrival times.
46class ValueWithArrivalTime {
// The value and its inversion flag are packed together. The 1-bit integer of
// the PointerIntPair is the inversion flag (read back by isInverted()).
49 llvm::PointerIntPair<Value, 1, bool> value;
// Insertion number. It is only used as a tie-breaker in operator>, so that
// entries with equal arrival times are ordered deterministically.
57 size_t valueNumbering = 0;
// Builds an entry for `value` (inverted if `invert`) arriving at `arrivalTime`.
60 ValueWithArrivalTime(Value value, int64_t arrivalTime,
bool invert,
61 size_t valueNumbering)
62 : value(value, invert), arrivalTime(arrivalTime),
63 valueNumbering(valueNumbering) {}
// Returns the underlying value, without the inversion flag.
65 Value getValue()
const {
return value.getPointer(); }
// Returns true if this input is used in inverted form.
66 bool isInverted()
const {
return value.getInt(); }
// Orders by arrival time first, then by insertion number. The priority queue
// in replaceWithBalancedTree uses std::greater, which calls this operator.
// As a result, the earliest-arriving entry is on top, and among equal arrival
// times the earliest-inserted one is on top.
71 bool operator>(
const ValueWithArrivalTime &other)
const {
72 return arrivalTime > other.arrivalTime ||
73 (arrivalTime == other.arrivalTime &&
74 valueNumbering > other.valueNumbering);
/// The synth-lower-variadic pass. Its base class is generated from the pass
/// definition pulled in via GEN_PASS_DEF_LOWERVARIADIC.
78struct LowerVariadicPass :
public impl::LowerVariadicBase<LowerVariadicPass> {
// Inherit the constructors of the generated base class.
79 using LowerVariadicBase::LowerVariadicBase;
80 void runOnOperation()
override;
// Parameters of replaceWithBalancedTree:
//   analysis       - timing analysis that supplies each input's maximum delay,
//                    used as its arrival time. NOTE(review): this is
//                    dereferenced below and is presumably null when
//                    timing-aware mode is off; confirm that push() guards it.
//   rewriter       - rewriter used to create the tree and replace `op`.
//   op             - the variadic operation being lowered.
//   isInverted     - returns whether a given operand is consumed inverted.
//   createBinaryOp - builds one two-input node from two entries and returns
//                    its result value.
90 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
91 Operation *op, llvm::function_ref<
bool(OpOperand &)> isInverted,
92 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
// Min-heap keyed on (arrival time, insertion number). top() is the entry that
// arrives earliest.
96 llvm::PriorityQueue<ValueWithArrivalTime, std::vector<ValueWithArrivalTime>,
97 std::greater<ValueWithArrivalTime>>
// Counter handed to each new entry as its tie-breaking number.
101 size_t valueNumber = 0;
// Queries the analysis for the maximum delay of `value`, then enqueues `value`
// with that delay as its arrival time and the given inversion flag. The result
// is checked with failed() by the callers below.
103 auto push = [&](Value value,
bool invert) {
108 auto result = analysis->getMaxDelay(value);
113 ValueWithArrivalTime entry(value, delay, invert, valueNumber++);
// Seed the queue with every operand of `op` and its inversion flag.
119 for (
size_t i = 0, e = op->getNumOperands(); i < e; ++i)
120 if (failed(push(op->getOperand(i), isInverted(op->getOpOperand(i)))))
// Loop until a single entry remains. Each iteration:
//   - takes the two earliest-arriving entries;
//   - combines them into a new two-input node;
//   - enqueues the node's result, which is never inverted.
// Combining early arrivals first pushes late-arriving inputs toward the root.
125 while (queue.size() >= 2) {
126 auto lhs = queue.top();
128 auto rhs = queue.top();
131 if (failed(push(createBinaryOp(lhs, rhs),
false)))
// The remaining entry is the root of the tree; it replaces `op`'s result.
136 auto result = queue.top().getValue();
137 rewriter.replaceOp(op, result);
/// Lowers the selected variadic operations in the module body into trees of
/// two-input operations, using replaceWithBalancedTree.
///   - When the timing-aware option is set, IncrementalLongestPathAnalysis
///     supplies the arrival times that drive the tree shape.
///   - The pass fails if the body block cannot be topologically sorted, or if
///     any replacement fails.
141void LowerVariadicPass::runOnOperation() {
// Topologically sort the body block first. The callback is sortTopologically's
// isOperandReady hook: it decides whether an operand of `op` may be treated as
// already ready.
//   - HW-dialect ops: only hw::InstanceOp operands are treated as ready.
//   - Ops outside the comb/synth dialects: operands are always treated as ready.
144 if (!mlir::sortTopologically(
145 getOperation().
getBodyBlock(), [](Value val, Operation *op) ->
bool {
146 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
147 return isa<hw::InstanceOp>(op);
148 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
151 mlir::emitError(getOperation().
getLoc())
152 <<
"Failed to topologically sort graph region blocks";
153 return signalPassFailure();
// The timing analysis is only requested when the timing-aware option is
// enabled; otherwise `analysis` keeps its initial value (presumably null).
158 if (timingAware.getValue())
159 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
161 auto moduleOp = getOperation();
// Resolve the requested operation names to OperationName once, up front.
164 SmallVector<OperationName> names;
165 for (
const auto &name : opNames)
166 names.push_back(OperationName(name, &getContext()));
// Returns whether `op` should be lowered, by looking up its name in `names`.
169 auto shouldLower = [&](Operation *op) {
173 return llvm::find(names, op->getName()) != names.end();
// Install `analysis` as the rewriter's listener so it can observe the IR
// changes made below. If `analysis` is null, no listener is attached.
176 mlir::IRRewriter rewriter(&getContext());
177 rewriter.setListener(analysis);
// Skip ops that are not targets, and ops that already have at most two
// operands.
185 if (!shouldLower(op) || op->getNumOperands() <= 2)
// New two-input ops are created immediately before the op being replaced.
188 rewriter.setInsertionPoint(op);
// aig::AndInverterOp: keep each operand's inversion flag, and build the tree
// from two-input AndInverterOps that carry those flags.
191 if (
auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
193 analysis, rewriter, op,
195 [&](OpOperand &operand) {
196 return andInverterOp.isInverted(operand.getOperandNumber());
199 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
200 return aig::AndInverterOp::create(
201 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
202 lhs.isInverted(), rhs.isInverted());
// The balanced-tree replacement failed.
205 return signalPassFailure();
// Commutative comb ops: no operand is inverted. Each tree node is a new op
// with the same name as `op`, two operands, and `op`'s first result type.
211 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
212 op->hasTrait<OpTrait::IsCommutative>()) {
214 analysis, rewriter, op,
216 [](OpOperand &) {
return false; },
218 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
// NOTE(review): only the operands and the result type are copied into
// `state`; none of `op`'s attributes reach the new nodes. Confirm that
// dropping them is intended; otherwise consider
// state.addAttributes(op->getAttrs()).
219 OperationState state(op->getLoc(), op->getName());
220 state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
221 state.addTypes(op->getResult(0).getType());
222 auto *newOp = Operation::create(state);
223 rewriter.insert(newOp);
224 return newOp->getResult(0);
// The balanced-tree replacement failed.
227 return signalPassFailure();
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.