21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/Visitors.h"
25#include "mlir/Support/LLVM.h"
26#include "mlir/Transforms/RegionUtils.h"
27#include "llvm/ADT/DenseMap.h"
28#include "llvm/ADT/DenseMapInfo.h"
29#include "llvm/ADT/PointerIntPair.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/DebugLog.h"
32#include "llvm/Support/LogicalResult.h"
34#define DEBUG_TYPE "synth-structural-hash"
38#define GEN_PASS_DEF_STRUCTURALHASH
39#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
56 llvm::SmallVector<llvm::PointerIntPair<Value, 1>, 3> inps)
66 hash = llvm::hash_combine(hash, operand.getOpaqueValue());
67 return static_cast<unsigned>(hash);
79struct StructuralHashPass
80 :
public impl::StructuralHashBase<StructuralHashPass> {
81 void runOnOperation()
override;
91class StructuralHashDriver {
93 StructuralHashDriver() =
default;
94 void visitOp(BooleanLogicOpInterface op);
95 void visitUnaryOp(BooleanLogicOpInterface op);
96 void visitVariadicOp(BooleanLogicOpInterface op);
97 uint64_t getNumber(Value v);
102 llvm::LogicalResult
run(Operation *op);
106 llvm::LogicalResult runOnBlock(Block &block);
109 DenseMap<Value, uint64_t> valueNumber;
110 uint64_t constantCounter = 0;
113 DenseMap<StructuralHashKey, BooleanLogicOpInterface> hashTable;
123 DenseMap<Value, Value> inversion;
127void StructuralHashDriver::visitOp(BooleanLogicOpInterface op) {
131 if (op.getInputs().size() == 1) {
142void StructuralHashDriver::visitUnaryOp(BooleanLogicOpInterface logicOp) {
143 Operation *op = logicOp.getOperation();
144 auto [input, inverted] = logicOp.getInputPair(0);
146 op->replaceAllUsesWith(ArrayRef<Value>{input});
150 auto it = inversion.find(input);
151 if (it != inversion.end()) {
153 op->replaceAllUsesWith(ArrayRef<Value>{it->second});
157 inversion[logicOp.getResult()] = input;
163void StructuralHashDriver::visitVariadicOp(BooleanLogicOpInterface logicOp) {
164 Operation *op = logicOp.getOperation();
165 auto inversions = logicOp.getInverted();
169 for (
auto [input, inverted] :
llvm::zip(op->getOperands(), inversions)) {
170 bool isInverted = inverted;
172 auto it = inversion.find(input);
173 if (it != inversion.end()) {
176 isInverted = !isInverted;
179 key.operandPairs.push_back(
180 llvm::PointerIntPair<Value, 1>(input, isInverted));
183 (void)getNumber(input);
188 if (logicOp.areInputsPermutationInvariant()) {
189 llvm::sort(key.operandPairs, [&](
auto a,
auto b) {
190 size_t aNum = getNumber(a.getPointer());
191 size_t bNum = getNumber(b.getPointer());
194 return a.getInt() < b.getInt();
199 auto [it, inserted] = hashTable.try_emplace(key, logicOp);
202 op->setOperands(llvm::to_vector<3>(llvm::map_range(
203 key.operandPairs, [](
auto p) { return p.getPointer(); })));
204 SmallVector<bool, 3> newInversion(
205 llvm::map_range(key.operandPairs, [](
auto p) { return p.getInt(); }));
206 logicOp.setInverted(newInversion);
208 (void)getNumber(logicOp.getResult());
210 LDBG() <<
"Structural Hash: Replacing " << *op <<
" with " << *(it->second)
215 if (name && !it->second->hasAttr(
"sv.namehint"))
216 it->second->setAttr(
"sv.namehint", name);
217 op->replaceAllUsesWith(it->second);
224uint64_t StructuralHashDriver::getNumber(Value v) {
225 auto it = valueNumber.find(v);
226 if (it != valueNumber.end())
231 if (
auto *op = v.getDefiningOp();
232 op && op->hasTrait<mlir::OpTrait::ConstantLike>()) {
233 auto [it, inserted] = valueNumber.try_emplace(
234 v, std::numeric_limits<uint64_t>::max() - constantCounter++);
238 return valueNumber.try_emplace(v, valueNumber.size() - constantCounter)
242llvm::LogicalResult StructuralHashDriver::runOnBlock(Block &block) {
245 auto isOperationReady = [&](Value value, Operation *op) ->
bool {
247 return !isa<BooleanLogicOpInterface>(op);
250 if (!mlir::sortTopologically(&block, isOperationReady))
253 for (
auto arg : block.getArguments())
254 (void)getNumber(arg);
258 llvm::make_early_inc_range(block.getOps<BooleanLogicOpInterface>())) {
262 return mlir::success();
265llvm::LogicalResult StructuralHashDriver::run(Operation *moduleOp) {
266 auto result = moduleOp->walk([&](Block *block) {
267 return failed(runOnBlock(*block)) ? WalkResult::interrupt()
268 : WalkResult::advance();
270 if (result.wasInterrupted())
274 mlir::PatternRewriter rewriter(moduleOp->getContext());
275 (void)mlir::runRegionDCE(rewriter, moduleOp->getRegions());
276 return mlir::success();
279void StructuralHashPass::runOnOperation() {
280 auto *topOp = getOperation();
281 StructuralHashDriver driver;
282 if (failed(driver.run(topOp)))
283 return signalPassFailure();
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
llvm::hash_code hash_value(const DenseSet< T > &set)
A struct that represents the key used for structural hashing.
StructuralHashKey(OperationName name, llvm::SmallVector< llvm::PointerIntPair< Value, 1 >, 3 > inps)
Constructor.
llvm::SmallVector< llvm::PointerIntPair< Value, 1 >, 3 > operandPairs
static bool isEqual(const StructuralHashKey &lhs, const StructuralHashKey &rhs)
static unsigned getHashValue(const StructuralHashKey &key)