CIRCT 23.0.0git
Loading...
Searching...
No Matches
SynthToComb.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main Synth to Comb Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20
// Pull in the generated definition of the ConvertSynthToComb pass base class
// (used below as impl::ConvertSynthToCombBase) from the pass registry include.
namespace circt {
#define GEN_PASS_DEF_CONVERTSYNTHTOCOMB
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;
28
29//===----------------------------------------------------------------------===//
30// Conversion patterns
31//===----------------------------------------------------------------------===//
32
33namespace {
34
35struct SynthChoiceOpConversion : OpConversionPattern<synth::ChoiceOp> {
36 using OpConversionPattern<synth::ChoiceOp>::OpConversionPattern;
37 LogicalResult
38 matchAndRewrite(synth::ChoiceOp op, OpAdaptor adaptor,
39 ConversionPatternRewriter &rewriter) const override {
40 // Use the first input as the output, and ignore the rest.
41 rewriter.replaceOp(op, adaptor.getInputs().front());
42 return success();
43 }
44};
45
46struct SynthAndInverterOpConversion
47 : OpConversionPattern<synth::aig::AndInverterOp> {
48 using OpConversionPattern<synth::aig::AndInverterOp>::OpConversionPattern;
49 LogicalResult
50 matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 // Convert to comb.and + comb.xor + hw.constant
53 auto width = op.getResult().getType().getIntOrFloatBitWidth();
54 auto allOnes =
55 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getAllOnes(width));
56 SmallVector<Value> operands;
57 operands.reserve(op.getNumOperands());
58 for (auto [input, inverted] : llvm::zip(op.getOperands(), op.getInverted()))
59 operands.push_back(inverted ? rewriter.createOrFold<comb::XorOp>(
60 op.getLoc(), input, allOnes, true)
61 : input);
62 // NOTE: Use createOrFold to avoid creating a new operation if possible.
63 rewriter.replaceOp(
64 op, rewriter.createOrFold<comb::AndOp>(op.getLoc(), operands, true));
65 return success();
66 }
67};
68
69struct SynthMajorityInverterOpConversion
70 : OpConversionPattern<synth::mig::MajorityInverterOp> {
72 synth::mig::MajorityInverterOp>::OpConversionPattern;
73 LogicalResult
74 matchAndRewrite(synth::mig::MajorityInverterOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter) const override {
76 auto getOperand = [&](unsigned idx) {
77 auto input = adaptor.getInputs()[idx];
78 if (!op.getInverted()[idx])
79 return input;
80 auto width = input.getType().getIntOrFloatBitWidth();
81 auto allOnes = hw::ConstantOp::create(rewriter, op.getLoc(),
82 APInt::getAllOnes(width));
83 return rewriter.createOrFold<comb::XorOp>(op.getLoc(), input, allOnes,
84 true);
85 };
86
87 if (op.getNumOperands() == 1) {
88 rewriter.replaceOp(op, getOperand(0));
89 return success();
90 }
91
92 SmallVector<Value> inputs;
93 inputs.reserve(op.getNumOperands());
94 for (size_t i = 0, e = op.getNumOperands(); i < e; ++i)
95 inputs.push_back(getOperand(i));
96
97 // MAJ_n(x_0, ..., x_n) is the OR of all conjunctions over threshold-sized
98 // subsets, where threshold = floor(n / 2) + 1.
99 auto getProduct = [&](ArrayRef<unsigned> indices) {
100 SmallVector<Value> productOperands;
101 productOperands.reserve(indices.size());
102 for (auto idx : indices)
103 productOperands.push_back(inputs[idx]);
104 return rewriter.createOrFold<comb::AndOp>(op.getLoc(), productOperands,
105 true);
106 };
107
108 SmallVector<Value> operands;
109 SmallVector<unsigned> subset;
110 const unsigned threshold = op.getNumOperands() / 2 + 1;
111
112 auto enumerateProducts = [&](auto &&self, unsigned start) -> void {
113 if (subset.size() == threshold) {
114 operands.push_back(getProduct(subset));
115 return;
116 }
117
118 const unsigned remaining = threshold - subset.size();
119 assert(start + remaining <= op.getNumOperands() &&
120 "Not enough operands left to reach threshold");
121 for (unsigned i = start, e = op.getNumOperands() - remaining; i <= e;
122 ++i) {
123 subset.push_back(i);
124 self(self, i + 1);
125 subset.pop_back();
126 }
127 };
128 enumerateProducts(enumerateProducts, 0);
129
130 rewriter.replaceOp(
131 op, rewriter.createOrFold<comb::OrOp>(op.getLoc(), operands, true));
132 return success();
133 }
134};
135
136} // namespace
137
138//===----------------------------------------------------------------------===//
139// Convert Synth to Comb pass
140//===----------------------------------------------------------------------===//
141
142namespace {
143struct ConvertSynthToCombPass
144 : public impl::ConvertSynthToCombBase<ConvertSynthToCombPass> {
145
146 void runOnOperation() override;
147 using ConvertSynthToCombBase<ConvertSynthToCombPass>::ConvertSynthToCombBase;
148};
149} // namespace
150
151static void populateSynthToCombConversionPatterns(RewritePatternSet &patterns) {
152 patterns.add<SynthChoiceOpConversion, SynthAndInverterOpConversion,
153 SynthMajorityInverterOpConversion>(patterns.getContext());
154}
155
156void ConvertSynthToCombPass::runOnOperation() {
157 ConversionTarget target(getContext());
158 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
159 target.addIllegalDialect<synth::SynthDialect>();
160
161 RewritePatternSet patterns(&getContext());
163
164 if (failed(mlir::applyPartialConversion(getOperation(), target,
165 std::move(patterns))))
166 return signalPassFailure();
167}
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static void populateSynthToCombConversionPatterns(RewritePatternSet &patterns)
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1