CIRCT 23.0.0git
Loading...
Searching...
No Matches
TechMapper.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TechMapper pass, which performs technology mapping
10// by converting logic network representations (AIG operations) into
11// technology-specific gate implementations using cut-based rewriting.
12//
13// The pass uses a cut-based algorithm with priority cuts and NPN canonical
14// forms for efficient pattern matching. It processes HWModuleOp instances with
15// "hw.techlib.info" attributes as technology library patterns and maps
16// non-library modules to optimal gate implementations based on area and timing
17// optimization strategies.
18//
19//===----------------------------------------------------------------------===//
20
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
29#include <atomic>
30
31namespace circt {
32namespace synth {
33#define GEN_PASS_DEF_TECHMAPPER
34#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
35} // namespace synth
36} // namespace circt
37
38using namespace circt;
39using namespace circt::synth;
40
41#define DEBUG_TYPE "synth-tech-mapper"
42
43//===----------------------------------------------------------------------===//
44// Tech Mapper Pass
45//===----------------------------------------------------------------------===//
46
47static llvm::FailureOr<NPNClass> getNPNClassFromModule(hw::HWModuleOp module) {
48 // Get input and output ports
49 auto inputTypes = module.getInputTypes();
50 auto outputTypes = module.getOutputTypes();
51
52 unsigned numInputs = inputTypes.size();
53 unsigned numOutputs = outputTypes.size();
54 if (numOutputs != 1)
55 return module->emitError(
56 "Modules with multiple outputs are not supported yet");
57
58 // Verify all ports are single bit
59 for (auto type : inputTypes) {
60 if (!type.isInteger(1))
61 return module->emitError("All input ports must be single bit");
62 }
63 for (auto type : outputTypes) {
64 if (!type.isInteger(1))
65 return module->emitError("All output ports must be single bit");
66 }
67
68 if (numInputs > maxTruthTableInputs)
69 return module->emitError("Too many inputs for truth table generation");
70
71 SmallVector<Value> results;
72 results.reserve(numOutputs);
73 // Get the body block of the module
74 auto *bodyBlock = module.getBodyBlock();
75 assert(bodyBlock && "Module must have a body block");
76 // Collect output values from the body block
77 for (auto result : bodyBlock->getTerminator()->getOperands())
78 results.push_back(result);
79
80 // Create a truth table for the module
81 FailureOr<BinaryTruthTable> truthTable = getTruthTable(results, bodyBlock);
82 if (failed(truthTable))
83 return failure();
84
85 return NPNClass::computeNPNCanonicalForm(*truthTable);
86}
87
88/// Simple technology library encoded as a HWModuleOp.
91 SmallVector<DelayType> delay, NPNClass npnClass)
93 delay(std::move(delay)), module(module), npnClass(std::move(npnClass)) {
94
95 LLVM_DEBUG({
96 llvm::dbgs() << "Created Tech Library Pattern for module: "
97 << module.getModuleName() << "\n"
98 << "NPN Class: " << this->npnClass.truthTable.table << "\n"
99 << "Inputs: " << this->npnClass.inputPermutation.size()
100 << "\n"
101 << "Input Negation: " << this->npnClass.inputNegation << "\n"
102 << "Output Negation: " << this->npnClass.outputNegation
103 << "\n";
104 });
105 }
106
107 StringRef getPatternName() const override {
108 auto moduleCp = module;
109 return moduleCp.getModuleName();
110 }
111
112 /// Match the cut set against this library primitive
113 std::optional<MatchResult> match(CutEnumerator &enumerator,
114 const Cut &cut) const override {
115 if (!cut.getNPNClass(enumerator.getOptions().npnTable)
117 return std::nullopt;
118
119 return MatchResult(area, delay);
120 }
121
122 /// Enable truth table matching for this pattern
124 SmallVectorImpl<NPNClass> &matchingNPNClasses) const override {
125 matchingNPNClasses.push_back(npnClass);
126 return true;
127 }
128
129 /// Rewrite the cut set using this library primitive
130 llvm::FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
131 CutEnumerator &enumerator,
132 const Cut &cut) const override {
133 const auto &network = enumerator.getLogicNetwork();
134 // Create a new instance of the module
135 SmallVector<unsigned> permutedInputIndices;
137 permutedInputIndices);
138
139 SmallVector<Value> inputs;
140 inputs.reserve(permutedInputIndices.size());
141 for (unsigned idx : permutedInputIndices) {
142 assert(idx < cut.inputs.size() && "input permutation index out of range");
143 inputs.push_back(network.getValue(cut.inputs[idx]));
144 }
145
146 auto *rootOp = network.getGate(cut.getRootIndex()).getOperation();
147 assert(rootOp && "cut root must be a valid operation");
148
149 // TODO: Give a better name to the instance
150 auto instanceOp = hw::InstanceOp::create(builder, rootOp->getLoc(), module,
151 "mapped", ArrayRef<Value>(inputs));
152 return instanceOp.getOperation();
153 }
154
155 unsigned getNumInputs() const {
156 return static_cast<hw::HWModuleOp>(module).getNumInputPorts();
157 }
158
159 unsigned getNumOutputs() const override {
160 return static_cast<hw::HWModuleOp>(module).getNumOutputPorts();
161 }
162
163 LocationAttr getLoc() const override {
164 auto module = this->module;
165 return module.getLoc();
166 }
167
168private:
169 const double area;
170 const SmallVector<DelayType> delay;
171 hw::HWModuleOp module;
173};
174
175namespace {
176struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
177 using TechMapperBase<TechMapperPass>::TechMapperBase;
178
179 LogicalResult initialize(MLIRContext *context) override {
180 (void)context;
181 npnTable = std::make_shared<const NPNTable>();
182 return success();
183 }
184
185 void runOnOperation() override {
186 auto module = getOperation();
187
188 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
189
190 unsigned maxInputSize = 0;
191 // Consider modules with the "hw.techlib.info" attribute as library
192 // modules.
193 // TODO: This attribute should be replaced with a more structured
194 // representation of technology library information. Specifically, we should
195 // have a dedicated operation for technology library.
196 SmallVector<hw::HWModuleOp> nonLibraryModules;
197 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
198 auto techInfo =
199 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
200 if (!techInfo) {
201 // If the module does not have the techlib info, it is not a library
202 // TODO: Run mapping only when the module is under the specific
203 // hierarchy.
204 nonLibraryModules.push_back(hwModule);
205 continue;
206 }
207
208 // Get area and delay attributes
209 auto areaAttr = techInfo.getAs<FloatAttr>("area");
210 auto delayAttr = techInfo.getAs<ArrayAttr>("delay");
211 if (!areaAttr || !delayAttr) {
212 mlir::emitError(hwModule.getLoc())
213 << "Library module " << hwModule.getModuleName()
214 << " must have 'area'(float) and 'delay' (2d array to represent "
215 "input-output pair delay) attributes";
216 signalPassFailure();
217 return;
218 }
219
220 double area = areaAttr.getValue().convertToDouble();
221
222 SmallVector<DelayType> delay;
223 for (auto delayValue : delayAttr) {
224 auto delayArray = cast<ArrayAttr>(delayValue);
225 for (auto delayElement : delayArray) {
226 // FIXME: Currently we assume delay is given as integer attributes,
227 // this should be replaced once we have a proper cell op with
228 // dedicated timing attributes with units.
229 delay.push_back(
230 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
231 }
232 }
233 // Compute NPN Class for the module.
234 auto npnClass = getNPNClassFromModule(hwModule);
235 if (failed(npnClass)) {
236 signalPassFailure();
237 return;
238 }
239
240 // Create a CutRewritePattern for the library module
241 std::unique_ptr<TechLibraryPattern> pattern =
242 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
243 std::move(*npnClass));
244
245 // Update the maximum input size
246 maxInputSize = std::max(maxInputSize, pattern->getNumInputs());
247
248 // Add the pattern to the library
249 libraryPatterns.push_back(std::move(pattern));
250 }
251
252 if (libraryPatterns.empty())
253 return markAllAnalysesPreserved();
254
255 CutRewritePatternSet patternSet(std::move(libraryPatterns));
256 CutRewriterOptions options;
257 options.strategy = strategy;
258 options.maxCutInputSize = maxInputSize;
259 options.maxCutSizePerRoot = maxCutsPerRoot;
260 options.attachDebugTiming = test;
261 options.npnTable = npnTable.get();
262 std::atomic<uint64_t> numCutsCreatedCount = 0;
263 std::atomic<uint64_t> numCutSetsCreatedCount = 0;
264 std::atomic<uint64_t> numCutsRewrittenCount = 0;
265 auto result = mlir::failableParallelForEach(
266 module.getContext(), nonLibraryModules, [&](hw::HWModuleOp hwModule) {
267 LLVM_DEBUG(llvm::dbgs() << "Processing non-library module: "
268 << hwModule.getName() << "\n");
269 CutRewriter rewriter(options, patternSet);
270 if (failed(rewriter.run(hwModule)))
271 return failure();
272 const auto &stats = rewriter.getStats();
273 numCutsCreatedCount.fetch_add(stats.numCutsCreated,
274 std::memory_order_relaxed);
275 numCutSetsCreatedCount.fetch_add(stats.numCutSetsCreated,
276 std::memory_order_relaxed);
277 numCutsRewrittenCount.fetch_add(stats.numCutsRewritten,
278 std::memory_order_relaxed);
279 return success();
280 });
281 if (failed(result))
282 signalPassFailure();
283 numCutsCreated += numCutsCreatedCount;
284 numCutSetsCreated += numCutSetsCreatedCount;
285 numCutsRewritten += numCutsRewrittenCount;
286 }
287
288private:
289 std::shared_ptr<const NPNTable> npnTable;
290};
291
292} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
RewritePatternSet pattern
Strategy strategy
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Simple technology library encoded as a HWModuleOp.
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< DelayType > delay, NPNClass npnClass)
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const override
Match the cut set against this library primitive.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const override
Rewrite the cut set using this library primitive.
const SmallVector< DelayType > delay
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
unsigned getNumInputs() const
const double area
LocationAttr getLoc() const override
Get location for this pattern(optional).
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Definition TruthTable.h:149
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.