24#include "mlir/IR/Builders.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Support/WalkResult.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/Support/Debug.h"
35#define GEN_PASS_DEF_TECHMAPPER
36#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
43#define DEBUG_TYPE "synth-tech-mapper"
51 auto inputTypes =
module.getInputTypes();
52 auto outputTypes =
module.getOutputTypes();
54 unsigned numInputs = inputTypes.size();
55 unsigned numOutputs = outputTypes.size();
57 return module->emitError(
58 "Modules with multiple outputs are not supported yet");
61 for (
auto type : inputTypes) {
62 if (!type.isInteger(1))
63 return module->emitError(
"All input ports must be single bit");
65 for (
auto type : outputTypes) {
66 if (!type.isInteger(1))
67 return module->emitError(
"All output ports must be single bit");
71 return module->emitError("Too many inputs for truth table generation");
73 SmallVector<Value> results;
74 results.reserve(numOutputs);
76 auto *bodyBlock =
module.getBodyBlock();
77 assert(bodyBlock &&
"Module must have a body block");
79 for (
auto result : bodyBlock->getTerminator()->getOperands())
80 results.push_back(result);
83 FailureOr<BinaryTruthTable> truthTable =
getTruthTable(results, bodyBlock);
84 if (failed(truthTable))
98 llvm::dbgs() <<
"Created Tech Library Pattern for module: "
99 <<
module.getModuleName() << "\n"
100 << "NPN Class: " << this->npnClass.truthTable.table << "\n"
101 << "Inputs: " << this->npnClass.inputPermutation.size()
103 << "Input Negation: " << this->npnClass.inputNegation << "\n"
104 << "Output Negation: " << this->npnClass.outputNegation
110 auto moduleCp =
module;
111 return moduleCp.getModuleName();
116 const Cut &cut)
const override {
126 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const override {
127 matchingNPNClasses.push_back(
npnClass);
132 llvm::FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
134 const Cut &cut)
const override {
137 SmallVector<unsigned> permutedInputIndices;
139 permutedInputIndices);
141 SmallVector<Value> inputs;
142 inputs.reserve(permutedInputIndices.size());
143 for (
unsigned idx : permutedInputIndices) {
144 assert(idx < cut.
inputs.size() &&
"input permutation index out of range");
145 inputs.push_back(network.getValue(cut.
inputs[idx]));
148 auto *rootOp = network.getGate(cut.
getRootIndex()).getOperation();
149 assert(rootOp &&
"cut root must be a valid operation");
152 auto instanceOp = hw::InstanceOp::create(builder, rootOp->getLoc(), module,
153 "mapped", ArrayRef<Value>(inputs));
154 return instanceOp.getOperation();
166 auto module = this->module;
167 return module.getLoc();
178struct TechMapperPass :
public impl::TechMapperBase<TechMapperPass> {
179 using TechMapperBase<TechMapperPass>::TechMapperBase;
181 LogicalResult initialize(MLIRContext *
context)
override {
183 npnTable = std::make_shared<const NPNTable>();
187 void runOnOperation()
override {
188 auto module = getOperation();
190 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
192 unsigned maxInputSize = 0;
195 SmallVector<hw::HWModuleOp> nonLibraryModules;
196 for (
auto hwModule :
module.getOps<hw::HWModuleOp>()) {
199 hwModule->getAttrOfType<MappingCostAttr>("synth.mapping_cost");
201 nonLibraryModules.push_back(hwModule);
205 double area = mappingCost.getArea().getValue().convertToDouble();
207 StringAttr outputName;
212 "Modules with multiple outputs are not supported yet");
216 outputName = port.name;
219 hwModule.emitError(
"expected library module to have an output");
224 llvm::DenseMap<StringAttr, DelayType> delayByInput;
225 for (
auto attr : mappingCost.getArcs()) {
226 auto arc = cast<LinearTimingArcAttr>(attr);
229 "expected synth.linear_timing_arc in synth.mapping_cost arcs");
234 if (
arc.getPin() != outputName) {
235 hwModule.emitError(
"mapping cost arc output '")
236 <<
arc.getPin().getValue() <<
"' does not match module output '"
237 << outputName.getValue() <<
"'";
242 int64_t intrinsicDelay =
arc.getIntrinsic();
248 .try_emplace(
arc.getRelatedPin(),
251 hwModule.emitError(
"duplicate mapping cost arc for input '")
252 <<
arc.getRelatedPin().getValue() <<
"'";
258 SmallVector<DelayType> delay;
259 for (
const auto &port : hwModule.getPortList()) {
263 auto it = delayByInput.find(port.name);
264 if (it == delayByInput.end()) {
265 hwModule.emitError(
"missing mapping cost arc for input '")
266 << port.name.getValue() <<
"'";
271 delay.push_back(it->second);
274 if (delay.size() != delayByInput.size()) {
276 "synth.mapping_cost arcs do not match module inputs");
283 if (failed(npnClass)) {
289 std::unique_ptr<TechLibraryPattern>
pattern =
290 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
291 std::move(*npnClass));
294 maxInputSize = std::max(maxInputSize,
pattern->getNumInputs());
297 libraryPatterns.push_back(std::move(
pattern));
300 if (libraryPatterns.empty())
301 return markAllAnalysesPreserved();
310 std::atomic<uint64_t> numCutsCreatedCount = 0;
311 std::atomic<uint64_t> numCutSetsCreatedCount = 0;
312 std::atomic<uint64_t> numCutsRewrittenCount = 0;
313 auto result = mlir::failableParallelForEach(
314 module.getContext(), nonLibraryModules, [&](
hw::HWModuleOp hwModule) {
315 LLVM_DEBUG(llvm::dbgs() <<
"Processing non-library module: "
316 << hwModule.getName() <<
"\n");
317 CutRewriter rewriter(options, patternSet);
318 if (failed(rewriter.run(hwModule)))
320 const auto &stats = rewriter.getStats();
321 numCutsCreatedCount.fetch_add(stats.numCutsCreated,
322 std::memory_order_relaxed);
323 numCutSetsCreatedCount.fetch_add(stats.numCutSetsCreated,
324 std::memory_order_relaxed);
325 numCutsRewrittenCount.fetch_add(stats.numCutsRewritten,
326 std::memory_order_relaxed);
331 numCutsCreated += numCutsCreatedCount;
332 numCutSetsCreated += numCutSetsCreatedCount;
333 numCutsRewritten += numCutsRewrittenCount;
337 std::shared_ptr<const NPNTable> npnTable;
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
RewritePatternSet pattern
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut, based on the given pattern's NPN class.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Simple technology library encoded as a HWModuleOp.
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< DelayType > delay, NPNClass npnClass)
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const override
Match the cut set against this library primitive.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const override
Rewrite the cut set using this library primitive.
const SmallVector< DelayType > delay
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
unsigned getNumInputs() const
LocationAttr getLoc() const override
Get the location for this pattern (optional).
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getOutputs()
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.