CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerVariadic.cpp
Go to the documentation of this file.
1//===- LowerVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers variadic operations to binary operations using a
10// delay-aware algorithm for commutative operations.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/OpDefinition.h"
23#include "llvm/ADT/PointerIntPair.h"
24#include "llvm/ADT/PriorityQueue.h"
25
26#define DEBUG_TYPE "synth-lower-variadic"
27
28namespace circt {
29namespace synth {
30#define GEN_PASS_DEF_LOWERVARIADIC
31#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
32} // namespace synth
33} // namespace circt
34
35using namespace circt;
36using namespace synth;
37
38//===----------------------------------------------------------------------===//
39// Lower Variadic pass
40//===----------------------------------------------------------------------===//
41
42namespace {
43
44/// Helper class for delay-aware variadic operation lowering.
45/// Stores a value along with its arrival time for priority queue ordering.
46class ValueWithArrivalTime {
47 /// The value and an optional inversion flag packed together.
48 /// The inversion flag is used for AndInverterOp lowering.
49 llvm::PointerIntPair<Value, 1, bool> value;
50
51 /// The arrival time (delay) of this value in the circuit.
52 int64_t arrivalTime;
53
54 /// Value numbering for deterministic ordering when arrival times are equal.
55 /// This ensures consistent results across runs when multiple values have
56 /// the same delay.
57 size_t valueNumbering = 0;
58
59public:
60 ValueWithArrivalTime(Value value, int64_t arrivalTime, bool invert,
61 size_t valueNumbering)
62 : value(value, invert), arrivalTime(arrivalTime),
63 valueNumbering(valueNumbering) {}
64
65 Value getValue() const { return value.getPointer(); }
66 bool isInverted() const { return value.getInt(); }
67
68 /// Comparison operator for priority queue. Values with earlier arrival times
69 /// have higher priority. When arrival times are equal, use value numbering
70 /// for determinism.
71 bool operator>(const ValueWithArrivalTime &other) const {
72 return arrivalTime > other.arrivalTime ||
73 (arrivalTime == other.arrivalTime &&
74 valueNumbering > other.valueNumbering);
75 }
76};
77
78struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
79 using LowerVariadicBase::LowerVariadicBase;
80 void runOnOperation() override;
81};
82
83} // namespace
84
85/// Construct a balanced binary tree from a variadic operation using a
86/// delay-aware algorithm. This function builds the tree by repeatedly combining
87/// the two values with the earliest arrival times, which minimizes the critical
88/// path delay.
89static LogicalResult replaceWithBalancedTree(
90 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
91 Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
92 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
93 createBinaryOp) {
94 // Min-heap priority queue ordered by arrival time.
95 // Values with earlier arrival times are processed first.
96 llvm::PriorityQueue<ValueWithArrivalTime, std::vector<ValueWithArrivalTime>,
97 std::greater<ValueWithArrivalTime>>
98 queue;
99
100 // Counter for deterministic ordering when arrival times are equal.
101 size_t valueNumber = 0;
102
103 auto push = [&](Value value, bool invert) {
104 int64_t delay = 0;
105 // If analysis is available, use it to compute the delay.
106 // If not available, use zero delay and `valueNumber` will be used instead.
107 if (analysis) {
108 auto result = analysis->getMaxDelay(value);
109 if (failed(result))
110 return failure();
111 delay = *result;
112 }
113 ValueWithArrivalTime entry(value, delay, invert, valueNumber++);
114 queue.push(entry);
115 return success();
116 };
117
118 // Enqueue all operands with their arrival times and inversion flags.
119 for (size_t i = 0, e = op->getNumOperands(); i < e; ++i)
120 if (failed(push(op->getOperand(i), isInverted(op->getOpOperand(i)))))
121 return failure();
122
123 // Build balanced tree by repeatedly combining the two earliest values.
124 // This greedy approach minimizes the maximum depth of late-arriving signals.
125 while (queue.size() >= 2) {
126 auto lhs = queue.top();
127 queue.pop();
128 auto rhs = queue.top();
129 queue.pop();
130 // Create and enqueue the combined value.
131 if (failed(push(createBinaryOp(lhs, rhs), /*inverted=*/false)))
132 return failure();
133 }
134
135 // Get the final result and replace the original operation.
136 auto result = queue.top().getValue();
137 rewriter.replaceOp(op, result);
138 return success();
139}
140
141void LowerVariadicPass::runOnOperation() {
142 // Topologically sort operations in graph regions to ensure operands are
143 // defined before uses.
144 if (!mlir::sortTopologically(
145 getOperation().getBodyBlock(), [](Value val, Operation *op) -> bool {
146 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
147 return isa<hw::InstanceOp>(op);
148 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
149 op->getDialect());
150 })) {
151 mlir::emitError(getOperation().getLoc())
152 << "Failed to topologically sort graph region blocks";
153 return signalPassFailure();
154 }
155
156 // Get longest path analysis if timing-aware lowering is enabled.
157 synth::IncrementalLongestPathAnalysis *analysis = nullptr;
158 if (timingAware.getValue())
159 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
160
161 auto moduleOp = getOperation();
162
163 // Build set of operation names to lower if specified.
164 SmallVector<OperationName> names;
165 for (const auto &name : opNames)
166 names.push_back(OperationName(name, &getContext()));
167
168 // Return true if the operation should be lowered.
169 auto shouldLower = [&](Operation *op) {
170 // If no names specified, lower all variadic ops.
171 if (names.empty())
172 return true;
173 return llvm::find(names, op->getName()) != names.end();
174 };
175
176 mlir::IRRewriter rewriter(&getContext());
177 rewriter.setListener(analysis);
178
179 // FIXME: Currently only top-level operations are lowered due to the lack of
180 // topological sorting in across nested regions.
181 for (auto &opRef :
182 llvm::make_early_inc_range(moduleOp.getBodyBlock()->getOperations())) {
183 auto *op = &opRef;
184 // Skip operations that don't need lowering or are already binary.
185 if (!shouldLower(op) || op->getNumOperands() <= 2)
186 continue;
187
188 rewriter.setInsertionPoint(op);
189
190 // Handle AndInverterOp specially to preserve inversion flags.
191 if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
192 auto result = replaceWithBalancedTree(
193 analysis, rewriter, op,
194 // Check if each operand is inverted.
195 [&](OpOperand &operand) {
196 return andInverterOp.isInverted(operand.getOperandNumber());
197 },
198 // Create binary AndInverterOp with inversion flags.
199 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
200 return aig::AndInverterOp::create(
201 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
202 lhs.isInverted(), rhs.isInverted());
203 });
204 if (failed(result))
205 return signalPassFailure();
206 continue;
207 }
208
209 // Handle commutative operations (and, or, xor, mul, add, etc.) using
210 // delay-aware lowering to minimize critical path.
211 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
212 op->hasTrait<OpTrait::IsCommutative>()) {
213 auto result = replaceWithBalancedTree(
214 analysis, rewriter, op,
215 // No inversion flags for standard commutative operations.
216 [](OpOperand &) { return false; },
217 // Create binary operation with the same operation type.
218 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
219 OperationState state(op->getLoc(), op->getName());
220 state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
221 state.addTypes(op->getResult(0).getType());
222 auto *newOp = Operation::create(state);
223 rewriter.insert(newOp);
224 return newOp->getResult(0);
225 });
226 if (failed(result))
227 return signalPassFailure();
228 }
229 }
230}
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1