CIRCT 23.0.0git
Loading...
Searching...
No Matches
TechMapper.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TechMapper pass, which performs technology mapping
10// by converting logic network representations (AIG operations) into
11// technology-specific gate implementations using cut-based rewriting.
12//
13// The pass uses a cut-based algorithm with priority cuts and NPN canonical
14// forms for efficient pattern matching. It processes HWModuleOp instances with
15// "hw.techlib.info" attributes as technology library patterns and maps
16// non-library modules to optimal gate implementations based on area and timing
17// optimization strategies.
18//
19//===----------------------------------------------------------------------===//
20
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
29#include <atomic>
30
31namespace circt {
32namespace synth {
33#define GEN_PASS_DEF_TECHMAPPER
34#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
35} // namespace synth
36} // namespace circt
37
38using namespace circt;
39using namespace circt::synth;
40
41#define DEBUG_TYPE "synth-tech-mapper"
42
43//===----------------------------------------------------------------------===//
44// Tech Mapper Pass
45//===----------------------------------------------------------------------===//
46
47static llvm::FailureOr<NPNClass> getNPNClassFromModule(hw::HWModuleOp module) {
48 // Get input and output ports
49 auto inputTypes = module.getInputTypes();
50 auto outputTypes = module.getOutputTypes();
51
52 unsigned numInputs = inputTypes.size();
53 unsigned numOutputs = outputTypes.size();
54 if (numOutputs != 1)
55 return module->emitError(
56 "Modules with multiple outputs are not supported yet");
57
58 // Verify all ports are single bit
59 for (auto type : inputTypes) {
60 if (!type.isInteger(1))
61 return module->emitError("All input ports must be single bit");
62 }
63 for (auto type : outputTypes) {
64 if (!type.isInteger(1))
65 return module->emitError("All output ports must be single bit");
66 }
67
68 if (numInputs > maxTruthTableInputs)
69 return module->emitError("Too many inputs for truth table generation");
70
71 SmallVector<Value> results;
72 results.reserve(numOutputs);
73 // Get the body block of the module
74 auto *bodyBlock = module.getBodyBlock();
75 assert(bodyBlock && "Module must have a body block");
76 // Collect output values from the body block
77 for (auto result : bodyBlock->getTerminator()->getOperands())
78 results.push_back(result);
79
80 // Create a truth table for the module
81 FailureOr<BinaryTruthTable> truthTable = getTruthTable(results, bodyBlock);
82 if (failed(truthTable))
83 return failure();
84
85 return NPNClass::computeNPNCanonicalForm(*truthTable);
86}
87
88/// Simple technology library encoded as a HWModuleOp.
91 SmallVector<DelayType> delay, NPNClass npnClass)
93 delay(std::move(delay)), module(module), npnClass(std::move(npnClass)) {
94
95 LLVM_DEBUG({
96 llvm::dbgs() << "Created Tech Library Pattern for module: "
97 << module.getModuleName() << "\n"
98 << "NPN Class: " << this->npnClass.truthTable.table << "\n"
99 << "Inputs: " << this->npnClass.inputPermutation.size()
100 << "\n"
101 << "Input Negation: " << this->npnClass.inputNegation << "\n"
102 << "Output Negation: " << this->npnClass.outputNegation
103 << "\n";
104 });
105 }
106
107 StringRef getPatternName() const override {
108 auto moduleCp = module;
109 return moduleCp.getModuleName();
110 }
111
112 /// Match the cut set against this library primitive
113 std::optional<MatchResult> match(CutEnumerator &enumerator,
114 const Cut &cut) const override {
116 return std::nullopt;
117
118 return MatchResult(area, delay);
119 }
120
121 /// Enable truth table matching for this pattern
123 SmallVectorImpl<NPNClass> &matchingNPNClasses) const override {
124 matchingNPNClasses.push_back(npnClass);
125 return true;
126 }
127
128 /// Rewrite the cut set using this library primitive
129 llvm::FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
130 CutEnumerator &enumerator,
131 const Cut &cut) const override {
132 const auto &network = enumerator.getLogicNetwork();
133 // Create a new instance of the module
134 SmallVector<unsigned> permutedInputIndices;
135 cut.getPermutatedInputIndices(npnClass, permutedInputIndices);
136
137 SmallVector<Value> inputs;
138 inputs.reserve(permutedInputIndices.size());
139 for (unsigned idx : permutedInputIndices) {
140 assert(idx < cut.inputs.size() && "input permutation index out of range");
141 inputs.push_back(network.getValue(cut.inputs[idx]));
142 }
143
144 auto *rootOp = network.getGate(cut.getRootIndex()).getOperation();
145 assert(rootOp && "cut root must be a valid operation");
146
147 // TODO: Give a better name to the instance
148 auto instanceOp = hw::InstanceOp::create(builder, rootOp->getLoc(), module,
149 "mapped", ArrayRef<Value>(inputs));
150 return instanceOp.getOperation();
151 }
152
153 unsigned getNumInputs() const {
154 return static_cast<hw::HWModuleOp>(module).getNumInputPorts();
155 }
156
157 unsigned getNumOutputs() const override {
158 return static_cast<hw::HWModuleOp>(module).getNumOutputPorts();
159 }
160
161 LocationAttr getLoc() const override {
162 auto module = this->module;
163 return module.getLoc();
164 }
165
166private:
167 const double area;
168 const SmallVector<DelayType> delay;
169 hw::HWModuleOp module;
171};
172
173namespace {
174struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
175 using TechMapperBase<TechMapperPass>::TechMapperBase;
176
177 void runOnOperation() override {
178 auto module = getOperation();
179
180 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
181
182 unsigned maxInputSize = 0;
183 // Consider modules with the "hw.techlib.info" attribute as library
184 // modules.
185 // TODO: This attribute should be replaced with a more structured
186 // representation of technology library information. Specifically, we should
187 // have a dedicated operation for technology library.
188 SmallVector<hw::HWModuleOp> nonLibraryModules;
189 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
190 auto techInfo =
191 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
192 if (!techInfo) {
193 // If the module does not have the techlib info, it is not a library
194 // TODO: Run mapping only when the module is under the specific
195 // hierarchy.
196 nonLibraryModules.push_back(hwModule);
197 continue;
198 }
199
200 // Get area and delay attributes
201 auto areaAttr = techInfo.getAs<FloatAttr>("area");
202 auto delayAttr = techInfo.getAs<ArrayAttr>("delay");
203 if (!areaAttr || !delayAttr) {
204 mlir::emitError(hwModule.getLoc())
205 << "Library module " << hwModule.getModuleName()
206 << " must have 'area'(float) and 'delay' (2d array to represent "
207 "input-output pair delay) attributes";
208 signalPassFailure();
209 return;
210 }
211
212 double area = areaAttr.getValue().convertToDouble();
213
214 SmallVector<DelayType> delay;
215 for (auto delayValue : delayAttr) {
216 auto delayArray = cast<ArrayAttr>(delayValue);
217 for (auto delayElement : delayArray) {
218 // FIXME: Currently we assume delay is given as integer attributes,
219 // this should be replaced once we have a proper cell op with
220 // dedicated timing attributes with units.
221 delay.push_back(
222 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
223 }
224 }
225 // Compute NPN Class for the module.
226 auto npnClass = getNPNClassFromModule(hwModule);
227 if (failed(npnClass)) {
228 signalPassFailure();
229 return;
230 }
231
232 // Create a CutRewritePattern for the library module
233 std::unique_ptr<TechLibraryPattern> pattern =
234 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
235 std::move(*npnClass));
236
237 // Update the maximum input size
238 maxInputSize = std::max(maxInputSize, pattern->getNumInputs());
239
240 // Add the pattern to the library
241 libraryPatterns.push_back(std::move(pattern));
242 }
243
244 if (libraryPatterns.empty())
245 return markAllAnalysesPreserved();
246
247 CutRewritePatternSet patternSet(std::move(libraryPatterns));
248 CutRewriterOptions options;
249 options.strategy = strategy;
250 options.maxCutInputSize = maxInputSize;
251 options.maxCutSizePerRoot = maxCutsPerRoot;
252 options.attachDebugTiming = test;
253 std::atomic<uint64_t> numCutsCreatedCount = 0;
254 std::atomic<uint64_t> numCutSetsCreatedCount = 0;
255 std::atomic<uint64_t> numCutsRewrittenCount = 0;
256 auto result = mlir::failableParallelForEach(
257 module.getContext(), nonLibraryModules, [&](hw::HWModuleOp hwModule) {
258 LLVM_DEBUG(llvm::dbgs() << "Processing non-library module: "
259 << hwModule.getName() << "\n");
260 CutRewriter rewriter(options, patternSet);
261 if (failed(rewriter.run(hwModule)))
262 return failure();
263 const auto &stats = rewriter.getStats();
264 numCutsCreatedCount.fetch_add(stats.numCutsCreated,
265 std::memory_order_relaxed);
266 numCutSetsCreatedCount.fetch_add(stats.numCutSetsCreated,
267 std::memory_order_relaxed);
268 numCutsRewrittenCount.fetch_add(stats.numCutsRewritten,
269 std::memory_order_relaxed);
270 return success();
271 });
272 if (failed(result))
273 signalPassFailure();
274 numCutsCreated += numCutsCreatedCount;
275 numCutSetsCreated += numCutSetsCreatedCount;
276 numCutsRewritten += numCutsRewrittenCount;
277 }
278};
279
280} // namespace
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Strategy strategy
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Simple technology library encoded as a HWModuleOp.
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< DelayType > delay, NPNClass npnClass)
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const override
Match the cut set against this library primitive.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const override
Rewrite the cut set using this library primitive.
const SmallVector< DelayType > delay
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
unsigned getNumInputs() const
const double area
LocationAttr getLoc() const override
Get location for this pattern(optional).
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:104
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Definition TruthTable.h:147
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.