21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/OpDefinition.h"
24#define DEBUG_TYPE "synth-lower-variadic"
28#define GEN_PASS_DEF_LOWERVARIADIC
29#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
// Pass that lowers variadic operations (aig::AndInverterOp and commutative
// comb ops) into balanced trees of binary operations.
// The base class comes from the GEN_PASS_DEF_LOWERVARIADIC TableGen include.
42struct LowerVariadicPass :
public impl::LowerVariadicBase<LowerVariadicPass> {
// Inherit the generated base constructors (including option handling).
43 using LowerVariadicBase::LowerVariadicBase;
// Pass entry point; defined out of line.
44 void runOnOperation()
override;
// replaceWithBalancedTree: replaces the variadic `op` with a tree of binary
// operations built by a delay-aware balancing algorithm.
// (The function's opening line is not visible in this chunk.)
//   analysis       - longest-path analysis used to query arrival times; may be
//                    null when the pass is not timing-aware (see NOTE below).
//   rewriter       - rewriter used to replace `op`; its insertion point is
//                    expected to be set by the caller.
//   op             - the variadic operation being lowered.
//   isInverted     - reports whether a given operand of `op` is inverted.
//   createBinaryOp - builds one binary node from two tree inputs and returns
//                    its result value.
// Returns a LogicalResult (the return statement is not visible here).
54 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
55 Operation *op, llvm::function_ref<
bool(OpOperand &)> isInverted,
56 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
59 SmallVector<ValueWithArrivalTime> operands;
60 size_t valueNumber = 0;
// Wrap every operand of `op` with its arrival time, its inversion flag, and
// (presumably, since the line is not visible) a fresh `valueNumber`.
62 for (
size_t i = 0, e = op->getNumOperands(); i < e; ++i) {
// Query the operand's arrival time. NOTE(review): the lines that handle a
// failed query or a null `analysis` are not visible here; confirm that such
// a guard exists before this dereference.
67 auto result = analysis->getMaxDelay(op->getOperand(i));
72 operands.push_back(ValueWithArrivalTime(op->getOperand(i), delay,
73 isInverted(op->getOpOperand(i)),
// Combine the operands into a balanced tree, taking arrival times into
// account. The callback creates each binary node.
78 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
81 [&](
const ValueWithArrivalTime &lhs,
const ValueWithArrivalTime &rhs) {
82 Value combined = createBinaryOp(lhs, rhs);
// Use the analysis to get the new node's delay, keeping the prior
// `newDelay` value if the query fails. NOTE(review): `analysis` is again
// dereferenced in the visible lines; confirm a null check in lines not shown.
85 auto delayResult = analysis->getMaxDelay(combined);
86 if (succeeded(delayResult))
87 newDelay = *delayResult;
// Intermediate nodes are never inverted. Each gets the next `valueNumber`,
// presumably as a deterministic tie-breaker; confirm against
// ValueWithArrivalTime.
89 return ValueWithArrivalTime(combined, newDelay,
false, valueNumber++);
// Replace the original variadic op with the root value of the tree.
92 rewriter.replaceOp(op, result.getValue());
// Pass entry point: lowers selected variadic operations in the body block of
// the current operation into balanced binary trees.
96void LowerVariadicPass::runOnOperation() {
// Topologically sort the body block before rewriting. The readiness callback
// returns true for hw::InstanceOp (among HW-dialect ops) and for ops outside
// the comb/synth dialects; see mlir::sortTopologically for how the callback's
// result is used.
99 if (!mlir::sortTopologically(
100 getOperation().
getBodyBlock(), [](Value val, Operation *op) ->
bool {
101 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
102 return isa<hw::InstanceOp>(op);
103 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
106 mlir::emitError(getOperation().
getLoc())
107 <<
"Failed to topologically sort graph region blocks";
108 return signalPassFailure();
// Timing information is only requested when the `timingAware` option is set;
// otherwise `analysis` keeps the value from its declaration (not visible here,
// presumably nullptr).
113 if (timingAware.getValue())
114 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
116 auto moduleOp = getOperation();
// Resolve the user-supplied op names (`opNames`) into OperationName values.
119 SmallVector<OperationName> names;
120 for (
const auto &name : opNames)
121 names.push_back(OperationName(name, &getContext()));
// Decide whether an op is selected for lowering by name membership. Lines not
// visible here may handle other cases (e.g. an empty `names` list); confirm.
124 auto shouldLower = [&](Operation *op) {
128 return llvm::find(names, op->getName()) != names.end();
// The analysis is registered as the rewriter's listener, presumably so it
// observes the IR changes made by this rewriter.
131 mlir::IRRewriter rewriter(&getContext());
132 rewriter.setListener(analysis);
// Skip ops that were not selected or that already have at most two operands.
140 if (!shouldLower(op) || op->getNumOperands() <= 2)
143 rewriter.setInsertionPoint(op);
// AIG and-inverter: preserve each operand's inversion flag and build binary
// aig::AndInverterOp nodes carrying the inversion of each input.
146 if (
auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
148 analysis, rewriter, op,
150 [&](OpOperand &operand) {
151 return andInverterOp.isInverted(operand.getOperandNumber());
154 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
155 return aig::AndInverterOp::create(
156 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
157 lhs.isInverted(), rhs.isInverted());
160 return signalPassFailure();
// Commutative comb ops: no operand is inverted. Each binary node is a fresh
// op with the same name, built generically from an OperationState.
166 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
167 op->hasTrait<OpTrait::IsCommutative>()) {
169 analysis, rewriter, op,
171 [](OpOperand &) {
return false; },
173 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
// NOTE(review): only the operands and the type of result 0 are copied. The
// original op's attributes are not added to `state`, and any additional
// results are ignored. Confirm this is intended for every commutative comb
// op reaching this path.
174 OperationState state(op->getLoc(), op->getName());
175 state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
176 state.addTypes(op->getResult(0).getType());
177 auto *newOp = Operation::create(state);
178 rewriter.insert(newOp);
179 return newOp->getResult(0);
182 return signalPassFailure();
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.