CIRCT 23.0.0git
Loading...
Searching...
No Matches
TechMapper.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TechMapper pass, which performs technology mapping
10// by converting logic network representations (AIG operations) into
11// technology-specific gate implementations using cut-based rewriting.
12//
13// The pass uses a cut-based algorithm with priority cuts and NPN canonical
14// forms for efficient pattern matching. It processes HWModuleOp instances with
15// "synth.mapping_cost" attributes as technology library patterns and maps
16// non-library modules to optimal gate implementations based on area and timing
17// optimization strategies.
18//
19//===----------------------------------------------------------------------===//
20
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Support/WalkResult.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/Support/Debug.h"
31#include <atomic>
32
33namespace circt {
34namespace synth {
35#define GEN_PASS_DEF_TECHMAPPER
36#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
37} // namespace synth
38} // namespace circt
39
40using namespace circt;
41using namespace circt::synth;
42
43#define DEBUG_TYPE "synth-tech-mapper"
44
45//===----------------------------------------------------------------------===//
46// Tech Mapper Pass
47//===----------------------------------------------------------------------===//
48
49static llvm::FailureOr<NPNClass> getNPNClassFromModule(hw::HWModuleOp module) {
50 // Get input and output ports
51 auto inputTypes = module.getInputTypes();
52 auto outputTypes = module.getOutputTypes();
53
54 unsigned numInputs = inputTypes.size();
55 unsigned numOutputs = outputTypes.size();
56 if (numOutputs != 1)
57 return module->emitError(
58 "Modules with multiple outputs are not supported yet");
59
60 // Verify all ports are single bit
61 for (auto type : inputTypes) {
62 if (!type.isInteger(1))
63 return module->emitError("All input ports must be single bit");
64 }
65 for (auto type : outputTypes) {
66 if (!type.isInteger(1))
67 return module->emitError("All output ports must be single bit");
68 }
69
70 if (numInputs > maxTruthTableInputs)
71 return module->emitError("Too many inputs for truth table generation");
72
73 SmallVector<Value> results;
74 results.reserve(numOutputs);
75 // Get the body block of the module
76 auto *bodyBlock = module.getBodyBlock();
77 assert(bodyBlock && "Module must have a body block");
78 // Collect output values from the body block
79 for (auto result : bodyBlock->getTerminator()->getOperands())
80 results.push_back(result);
81
82 // Create a truth table for the module
83 FailureOr<BinaryTruthTable> truthTable = getTruthTable(results, bodyBlock);
84 if (failed(truthTable))
85 return failure();
86
87 return NPNClass::computeNPNCanonicalForm(*truthTable);
88}
89
90/// Simple technology library encoded as a HWModuleOp.
93 SmallVector<DelayType> delay, NPNClass npnClass)
95 delay(std::move(delay)), module(module), npnClass(std::move(npnClass)) {
96
97 LLVM_DEBUG({
98 llvm::dbgs() << "Created Tech Library Pattern for module: "
99 << module.getModuleName() << "\n"
100 << "NPN Class: " << this->npnClass.truthTable.table << "\n"
101 << "Inputs: " << this->npnClass.inputPermutation.size()
102 << "\n"
103 << "Input Negation: " << this->npnClass.inputNegation << "\n"
104 << "Output Negation: " << this->npnClass.outputNegation
105 << "\n";
106 });
107 }
108
109 StringRef getPatternName() const override {
110 auto moduleCp = module;
111 return moduleCp.getModuleName();
112 }
113
114 /// Match the cut set against this library primitive
115 std::optional<MatchResult> match(CutEnumerator &enumerator,
116 const Cut &cut) const override {
117 if (!cut.getNPNClass(enumerator.getOptions().npnTable)
119 return std::nullopt;
120
121 return MatchResult(area, delay);
122 }
123
124 /// Enable truth table matching for this pattern
126 SmallVectorImpl<NPNClass> &matchingNPNClasses) const override {
127 matchingNPNClasses.push_back(npnClass);
128 return true;
129 }
130
131 /// Rewrite the cut set using this library primitive
132 llvm::FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
133 CutEnumerator &enumerator,
134 const Cut &cut) const override {
135 const auto &network = enumerator.getLogicNetwork();
136 // Create a new instance of the module
137 SmallVector<unsigned> permutedInputIndices;
139 permutedInputIndices);
140
141 SmallVector<Value> inputs;
142 inputs.reserve(permutedInputIndices.size());
143 for (unsigned idx : permutedInputIndices) {
144 assert(idx < cut.inputs.size() && "input permutation index out of range");
145 inputs.push_back(network.getValue(cut.inputs[idx]));
146 }
147
148 auto *rootOp = network.getGate(cut.getRootIndex()).getOperation();
149 assert(rootOp && "cut root must be a valid operation");
150
151 // TODO: Give a better name to the instance
152 auto instanceOp = hw::InstanceOp::create(builder, rootOp->getLoc(), module,
153 "mapped", ArrayRef<Value>(inputs));
154 return instanceOp.getOperation();
155 }
156
157 unsigned getNumInputs() const {
158 return static_cast<hw::HWModuleOp>(module).getNumInputPorts();
159 }
160
161 unsigned getNumOutputs() const override {
162 return static_cast<hw::HWModuleOp>(module).getNumOutputPorts();
163 }
164
165 LocationAttr getLoc() const override {
166 auto module = this->module;
167 return module.getLoc();
168 }
169
170private:
171 const double area;
172 const SmallVector<DelayType> delay;
173 hw::HWModuleOp module;
175};
176
177namespace {
178struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
179 using TechMapperBase<TechMapperPass>::TechMapperBase;
180
181 LogicalResult initialize(MLIRContext *context) override {
182 (void)context;
183 npnTable = std::make_shared<const NPNTable>();
184 return success();
185 }
186
187 void runOnOperation() override {
188 auto module = getOperation();
189
190 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
191
192 unsigned maxInputSize = 0;
193 // Consider modules with the "synth.mapping_cost" attribute as library
194 // modules.
195 SmallVector<hw::HWModuleOp> nonLibraryModules;
196 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
197
198 auto mappingCost =
199 hwModule->getAttrOfType<MappingCostAttr>("synth.mapping_cost");
200 if (!mappingCost) {
201 nonLibraryModules.push_back(hwModule);
202 continue;
203 }
204
205 double area = mappingCost.getArea().getValue().convertToDouble();
206
207 StringAttr outputName;
208 hw::ModulePortInfo ports(hwModule.getPortList());
209 for (const auto &port : ports.getOutputs()) {
210 if (outputName) {
211 hwModule.emitError(
212 "Modules with multiple outputs are not supported yet");
213 signalPassFailure();
214 return;
215 }
216 outputName = port.name;
217 }
218 if (!outputName) {
219 hwModule.emitError("expected library module to have an output");
220 signalPassFailure();
221 return;
222 }
223
224 llvm::DenseMap<StringAttr, DelayType> delayByInput;
225 for (auto attr : mappingCost.getArcs()) {
226 auto arc = cast<LinearTimingArcAttr>(attr);
227 if (!arc) {
228 hwModule.emitError(
229 "expected synth.linear_timing_arc in synth.mapping_cost arcs");
230 signalPassFailure();
231 return;
232 }
233
234 if (arc.getPin() != outputName) {
235 hwModule.emitError("mapping cost arc output '")
236 << arc.getPin().getValue() << "' does not match module output '"
237 << outputName.getValue() << "'";
238 signalPassFailure();
239 return;
240 }
241
242 int64_t intrinsicDelay = arc.getIntrinsic();
243
244 // TechMapper currently preserves the old integer per-pin delay model.
245 // The sensitivity, polarity, and input capacitance fields are carried
246 // in the attribute for future load-aware mapping.
247 if (!delayByInput
248 .try_emplace(arc.getRelatedPin(),
249 static_cast<DelayType>(intrinsicDelay))
250 .second) {
251 hwModule.emitError("duplicate mapping cost arc for input '")
252 << arc.getRelatedPin().getValue() << "'";
253 signalPassFailure();
254 return;
255 }
256 }
257
258 SmallVector<DelayType> delay;
259 for (const auto &port : hwModule.getPortList()) {
260 if (!port.isInput())
261 continue;
262
263 auto it = delayByInput.find(port.name);
264 if (it == delayByInput.end()) {
265 hwModule.emitError("missing mapping cost arc for input '")
266 << port.name.getValue() << "'";
267 signalPassFailure();
268 return;
269 }
270
271 delay.push_back(it->second);
272 }
273
274 if (delay.size() != delayByInput.size()) {
275 hwModule.emitError(
276 "synth.mapping_cost arcs do not match module inputs");
277 signalPassFailure();
278 return;
279 }
280
281 // Compute NPN Class for the module.
282 auto npnClass = getNPNClassFromModule(hwModule);
283 if (failed(npnClass)) {
284 signalPassFailure();
285 return;
286 }
287
288 // Create a CutRewritePattern for the library module
289 std::unique_ptr<TechLibraryPattern> pattern =
290 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
291 std::move(*npnClass));
292
293 // Update the maximum input size
294 maxInputSize = std::max(maxInputSize, pattern->getNumInputs());
295
296 // Add the pattern to the library
297 libraryPatterns.push_back(std::move(pattern));
298 }
299
300 if (libraryPatterns.empty())
301 return markAllAnalysesPreserved();
302
303 CutRewritePatternSet patternSet(std::move(libraryPatterns));
304 CutRewriterOptions options;
305 options.strategy = strategy;
306 options.maxCutInputSize = maxInputSize;
307 options.maxCutSizePerRoot = maxCutsPerRoot;
308 options.attachDebugTiming = test;
309 options.npnTable = npnTable.get();
310 std::atomic<uint64_t> numCutsCreatedCount = 0;
311 std::atomic<uint64_t> numCutSetsCreatedCount = 0;
312 std::atomic<uint64_t> numCutsRewrittenCount = 0;
313 auto result = mlir::failableParallelForEach(
314 module.getContext(), nonLibraryModules, [&](hw::HWModuleOp hwModule) {
315 LLVM_DEBUG(llvm::dbgs() << "Processing non-library module: "
316 << hwModule.getName() << "\n");
317 CutRewriter rewriter(options, patternSet);
318 if (failed(rewriter.run(hwModule)))
319 return failure();
320 const auto &stats = rewriter.getStats();
321 numCutsCreatedCount.fetch_add(stats.numCutsCreated,
322 std::memory_order_relaxed);
323 numCutSetsCreatedCount.fetch_add(stats.numCutSetsCreated,
324 std::memory_order_relaxed);
325 numCutsRewrittenCount.fetch_add(stats.numCutsRewritten,
326 std::memory_order_relaxed);
327 return success();
328 });
329 if (failed(result))
330 signalPassFailure();
331 numCutsCreated += numCutsCreatedCount;
332 numCutSetsCreated += numCutSetsCreatedCount;
333 numCutsRewritten += numCutsRewrittenCount;
334 }
335
336private:
337 std::shared_ptr<const NPNTable> npnTable;
338};
339
340} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
RewritePatternSet pattern
Strategy strategy
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
Definition arc.py:1
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Simple technology library encoded as a HWModuleOp.
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< DelayType > delay, NPNClass npnClass)
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const override
Match the cut set against this library primitive.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const override
Rewrite the cut set using this library primitive.
const SmallVector< DelayType > delay
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
unsigned getNumInputs() const
const double area
LocationAttr getLoc() const override
Get location for this pattern(optional).
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Definition TruthTable.h:149
This holds a decoded list of input/inout and output ports for a module or instance.
PortDirectionRange getOutputs()
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.