23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
33#define GEN_PASS_DEF_TECHMAPPER
34#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
41#define DEBUG_TYPE "synth-tech-mapper"
49 auto inputTypes =
module.getInputTypes();
50 auto outputTypes =
module.getOutputTypes();
52 unsigned numInputs = inputTypes.size();
53 unsigned numOutputs = outputTypes.size();
55 return module->emitError(
56 "Modules with multiple outputs are not supported yet");
59 for (
auto type : inputTypes) {
60 if (!type.isInteger(1))
61 return module->emitError(
"All input ports must be single bit");
63 for (
auto type : outputTypes) {
64 if (!type.isInteger(1))
65 return module->emitError(
"All output ports must be single bit");
69 return module->emitError("Too many inputs for truth table generation");
71 SmallVector<Value> results;
72 results.reserve(numOutputs);
74 auto *bodyBlock =
module.getBodyBlock();
75 assert(bodyBlock &&
"Module must have a body block");
77 for (
auto result : bodyBlock->getTerminator()->getOperands())
78 results.push_back(result);
81 FailureOr<BinaryTruthTable> truthTable =
getTruthTable(results, bodyBlock);
82 if (failed(truthTable))
96 llvm::dbgs() <<
"Created Tech Library Pattern for module: "
97 <<
module.getModuleName() << "\n"
98 << "NPN Class: " << this->npnClass.truthTable.table << "\n"
99 << "Inputs: " << this->npnClass.inputPermutation.size()
101 << "Input Negation: " << this->npnClass.inputNegation << "\n"
102 << "Output Negation: " << this->npnClass.outputNegation
108 auto moduleCp =
module;
109 return moduleCp.getModuleName();
114 const Cut &cut)
const override {
124 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const override {
125 matchingNPNClasses.push_back(
npnClass);
130 llvm::FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
132 const Cut &cut)
const override {
135 SmallVector<unsigned> permutedInputIndices;
137 permutedInputIndices);
139 SmallVector<Value> inputs;
140 inputs.reserve(permutedInputIndices.size());
141 for (
unsigned idx : permutedInputIndices) {
142 assert(idx < cut.
inputs.size() &&
"input permutation index out of range");
143 inputs.push_back(network.getValue(cut.
inputs[idx]));
146 auto *rootOp = network.getGate(cut.
getRootIndex()).getOperation();
147 assert(rootOp &&
"cut root must be a valid operation");
150 auto instanceOp = hw::InstanceOp::create(builder, rootOp->getLoc(), module,
151 "mapped", ArrayRef<Value>(inputs));
152 return instanceOp.getOperation();
164 auto module = this->module;
165 return module.getLoc();
176struct TechMapperPass :
public impl::TechMapperBase<TechMapperPass> {
177 using TechMapperBase<TechMapperPass>::TechMapperBase;
179 LogicalResult initialize(MLIRContext *
context)
override {
181 npnTable = std::make_shared<const NPNTable>();
185 void runOnOperation()
override {
186 auto module = getOperation();
188 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
190 unsigned maxInputSize = 0;
196 SmallVector<hw::HWModuleOp> nonLibraryModules;
197 for (
auto hwModule :
module.getOps<hw::HWModuleOp>()) {
199 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
204 nonLibraryModules.push_back(hwModule);
209 auto areaAttr = techInfo.getAs<FloatAttr>(
"area");
210 auto delayAttr = techInfo.getAs<ArrayAttr>(
"delay");
211 if (!areaAttr || !delayAttr) {
212 mlir::emitError(hwModule.getLoc())
213 <<
"Library module " << hwModule.getModuleName()
214 <<
" must have 'area'(float) and 'delay' (2d array to represent "
215 "input-output pair delay) attributes";
220 double area = areaAttr.getValue().convertToDouble();
222 SmallVector<DelayType> delay;
223 for (
auto delayValue : delayAttr) {
224 auto delayArray = cast<ArrayAttr>(delayValue);
225 for (
auto delayElement : delayArray) {
230 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
235 if (failed(npnClass)) {
241 std::unique_ptr<TechLibraryPattern>
pattern =
242 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
243 std::move(*npnClass));
246 maxInputSize = std::max(maxInputSize,
pattern->getNumInputs());
249 libraryPatterns.push_back(std::move(
pattern));
252 if (libraryPatterns.empty())
253 return markAllAnalysesPreserved();
262 std::atomic<uint64_t> numCutsCreatedCount = 0;
263 std::atomic<uint64_t> numCutSetsCreatedCount = 0;
264 std::atomic<uint64_t> numCutsRewrittenCount = 0;
265 auto result = mlir::failableParallelForEach(
266 module.getContext(), nonLibraryModules, [&](
hw::HWModuleOp hwModule) {
267 LLVM_DEBUG(llvm::dbgs() <<
"Processing non-library module: "
268 << hwModule.getName() <<
"\n");
269 CutRewriter rewriter(options, patternSet);
270 if (failed(rewriter.run(hwModule)))
272 const auto &stats = rewriter.getStats();
273 numCutsCreatedCount.fetch_add(stats.numCutsCreated,
274 std::memory_order_relaxed);
275 numCutSetsCreatedCount.fetch_add(stats.numCutSetsCreated,
276 std::memory_order_relaxed);
277 numCutsRewrittenCount.fetch_add(stats.numCutsRewritten,
278 std::memory_order_relaxed);
283 numCutsCreated += numCutsCreatedCount;
284 numCutSetsCreated += numCutSetsCreatedCount;
285 numCutsRewritten += numCutsRewrittenCount;
289 std::shared_ptr<const NPNTable> npnTable;
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
RewritePatternSet pattern
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted inputs for this cut based on the given pattern NPN.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Simple technology library encoded as a HWModuleOp.
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< DelayType > delay, NPNClass npnClass)
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const override
Match the cut set against this library primitive.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const override
Rewrite the cut set using this library primitive.
const SmallVector< DelayType > delay
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
unsigned getNumInputs() const
LocationAttr getLoc() const override
Get location for this pattern (optional).
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.