21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/Block.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/Error.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "synth-lower-variadic"
39#define GEN_PASS_DEF_LOWERVARIADIC
40#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
53struct LowerVariadicPass :
public impl::LowerVariadicBase<LowerVariadicPass> {
54 using LowerVariadicBase::LowerVariadicBase;
55 void runOnOperation()
override;
65 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
66 Operation *op, llvm::function_ref<
bool(OpOperand &)> isInverted,
67 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
70 SmallVector<ValueWithArrivalTime> operands;
71 size_t valueNumber = 0;
73 for (
size_t i = 0, e = op->getNumOperands(); i < e; ++i) {
78 auto result = analysis->getMaxDelay(op->getOperand(i));
83 operands.push_back(ValueWithArrivalTime(op->getOperand(i), delay,
84 isInverted(op->getOpOperand(i)),
89 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
92 [&](
const ValueWithArrivalTime &lhs,
const ValueWithArrivalTime &rhs) {
93 Value combined = createBinaryOp(lhs, rhs);
96 auto delayResult = analysis->getMaxDelay(combined);
97 if (succeeded(delayResult))
98 newDelay = *delayResult;
100 return ValueWithArrivalTime(combined, newDelay,
false, valueNumber++);
103 rewriter.replaceOp(op, result.getValue());
107using OperandKey = llvm::SmallVector<std::pair<mlir::Value, bool>>;
113 llvm::hash_code hash = 0;
115 for (
const auto &pair : val) {
116 hash = llvm::hash_combine(
120 return static_cast<unsigned>(hash);
134 const std::pair<mlir::Value, bool> &rhs)
const {
135 if (lhs.first != rhs.first) {
136 auto lhsArg = llvm::dyn_cast<mlir::BlockArgument>(lhs.first);
137 auto rhsArg = llvm::dyn_cast<mlir::BlockArgument>(rhs.first);
138 if (lhsArg && rhsArg)
139 return lhsArg.getArgNumber() < rhsArg.getArgNumber();
145 auto *lhsOp = lhs.first.getDefiningOp();
146 auto *rhsOp = rhs.first.getDefiningOp();
147 return lhsOp->isBeforeInBlock(rhsOp);
149 return lhs.second < rhs.second;
155 for (
size_t i = 0, e = op.getNumOperands(); i < e; ++i)
156 key.emplace_back(op.getOperand(i), op.isInverted(i));
163 aig::AndInverterOp op, mlir::IRRewriter &rewriter,
164 llvm::DenseMap<OperandKey, mlir::Value> &seenExpressions) {
166 if (op.getNumOperands() <= 2)
170 mlir::SmallVector<Value> newValues;
171 mlir::SmallVector<bool> newInversions;
173 for (
auto it = allOperands.begin(); it != allOperands.end(); ++it) {
177 auto match = seenExpressions.find(remaining);
178 if (match != seenExpressions.end() && match->second != op.getResult()) {
179 newValues.push_back(match->second);
180 newInversions.push_back(
false);
188 newValues.push_back(it->first);
189 newInversions.push_back(it->second);
192 if (newValues.size() < allOperands.size()) {
193 rewriter.modifyOpInPlace(op, [&]() {
194 op.getOperation()->setOperands(newValues);
195 op.setInverted(newInversions);
200void LowerVariadicPass::runOnOperation() {
201 if (getOperation()->getNumRegions() != 1 ||
202 getOperation()->getRegion(0).getBlocks().size() != 1)
206 mlir::Block &bodyBlock = getOperation()->getRegion(0).getBlocks().front();
207 auto *moduleOp = getOperation();
209 if (!mlir::sortTopologically(
210 &bodyBlock, [](Value val, Operation *op) ->
bool {
211 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
212 return isa<hw::InstanceOp>(op);
213 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
216 mlir::emitError(moduleOp->getLoc())
217 <<
"Failed to topologically sort graph region blocks";
218 return signalPassFailure();
223 if (timingAware.getValue()) {
224 if (!dyn_cast<hw::HWModuleOp>(moduleOp)) {
225 moduleOp->emitWarning(
226 "Longest Path Analysis failed: expected 'hw.module', but found '")
227 << moduleOp->getName().getStringRef()
228 <<
"'. Only HWModuleOps are currently supported.";
230 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
235 SmallVector<OperationName> names;
236 for (
const auto &name : opNames)
237 names.push_back(OperationName(name, &getContext()));
240 auto shouldLower = [&](Operation *op) {
244 return llvm::find(names, op->getName()) != names.end();
247 mlir::IRRewriter rewriter(&getContext());
248 rewriter.setListener(analysis);
252 llvm::DenseMap<OperandKey, mlir::Value> seenExpressions;
254 for (
auto &op : bodyBlock.getOperations()) {
255 if (
auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
257 seenExpressions[key] = andInverterOp.getResult();
261 for (
auto &op : bodyBlock.getOperations()) {
262 if (
auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
271 for (
auto &opRef :
llvm::make_early_inc_range(bodyBlock.getOperations())) {
274 if (!shouldLower(op) || op->getNumOperands() <= 2)
277 rewriter.setInsertionPoint(op);
281 mlir::TypeSwitch<Operation *, LogicalResult>(op)
282 .Case<aig::AndInverterOp, XorInverterOp>([&](
auto op) {
284 analysis, rewriter, op,
286 [&](OpOperand &operand) {
287 return op.isInverted(operand.getOperandNumber());
290 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
291 return decltype(op)::create(
292 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
293 lhs.isInverted(), rhs.isInverted());
296 .Default([&](Operation *op) {
299 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
300 op->hasTrait<OpTrait::IsCommutative>())
302 analysis, rewriter, op,
304 [](OpOperand &) {
return false; },
306 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
307 OperationState state(op->getLoc(), op->getName());
309 ValueRange{lhs.getValue(), rhs.getValue()});
310 state.addTypes(op->getResult(0).getType());
311 auto *newOp = Operation::create(state);
312 rewriter.insert(newOp);
313 return newOp->getResult(0);
318 return signalPassFailure();
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static void simplifyWithExistingOperations(aig::AndInverterOp op, mlir::IRRewriter &rewriter, llvm::DenseMap< OperandKey, mlir::Value > &seenExpressions)
llvm::SmallVector< std::pair< mlir::Value, bool > > OperandKey
static OperandKey getSortedOperandKey(aig::AndInverterOp op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
bool operator()(const std::pair< mlir::Value, bool > &lhs, const std::pair< mlir::Value, bool > &rhs) const
static unsigned getHashValue(const OperandKey &val)
static bool isEqual(const OperandKey &lhs, const OperandKey &rhs)