CIRCT 22.0.0git
Loading...
Searching...
No Matches
TechMapper.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TechMapper pass, which performs technology mapping
10// by converting logic network representations (AIG operations) into
11// technology-specific gate implementations using cut-based rewriting.
12//
13// The pass uses a cut-based algorithm with priority cuts and NPN canonical
14// forms for efficient pattern matching. It processes HWModuleOp instances with
15// "hw.techlib.info" attributes as technology library patterns and maps
16// non-library modules to optimal gate implementations based on area and timing
17// optimization strategies.
18//
19//===----------------------------------------------------------------------===//
20
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_TECHMAPPER
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace circt::synth;
39
40#define DEBUG_TYPE "synth-tech-mapper"
41
42//===----------------------------------------------------------------------===//
43// Tech Mapper Pass
44//===----------------------------------------------------------------------===//
45
46static llvm::FailureOr<NPNClass> getNPNClassFromModule(hw::HWModuleOp module) {
47 // Get input and output ports
48 auto inputTypes = module.getInputTypes();
49 auto outputTypes = module.getOutputTypes();
50
51 unsigned numInputs = inputTypes.size();
52 unsigned numOutputs = outputTypes.size();
53 if (numOutputs != 1)
54 return module->emitError(
55 "Modules with multiple outputs are not supported yet");
56
57 // Verify all ports are single bit
58 for (auto type : inputTypes) {
59 if (!type.isInteger(1))
60 return module->emitError("All input ports must be single bit");
61 }
62 for (auto type : outputTypes) {
63 if (!type.isInteger(1))
64 return module->emitError("All output ports must be single bit");
65 }
66
67 if (numInputs > maxTruthTableInputs)
68 return module->emitError("Too many inputs for truth table generation");
69
70 SmallVector<Value> results;
71 results.reserve(numOutputs);
72 // Get the body block of the module
73 auto *bodyBlock = module.getBodyBlock();
74 assert(bodyBlock && "Module must have a body block");
75 // Collect output values from the body block
76 for (auto result : bodyBlock->getTerminator()->getOperands())
77 results.push_back(result);
78
79 // Create a truth table for the module
80 FailureOr<BinaryTruthTable> truthTable = getTruthTable(results, bodyBlock);
81 if (failed(truthTable))
82 return failure();
83
84 return NPNClass::computeNPNCanonicalForm(*truthTable);
85}
86
87/// Simple technology library encoded as a HWModuleOp.
90 SmallVector<SmallVector<DelayType, 2>, 4> delay,
93 delay(std::move(delay)), module(module), npnClass(std::move(npnClass)) {
94
95 LLVM_DEBUG({
96 llvm::dbgs() << "Created Tech Library Pattern for module: "
97 << module.getModuleName() << "\n"
98 << "NPN Class: " << npnClass.truthTable.table << "\n"
99 << "Inputs: " << npnClass.inputPermutation.size() << "\n"
100 << "Input Negation: " << npnClass.inputNegation << "\n"
101 << "Output Negation: " << npnClass.outputNegation << "\n";
102 });
103 }
104
105 StringRef getPatternName() const override {
106 auto moduleCp = module;
107 return moduleCp.getModuleName();
108 }
109
110 /// Match the cut set against this library primitive
111 bool match(const Cut &cut) const override {
113 }
114
115 /// Enable truth table matching for this pattern
117 SmallVectorImpl<NPNClass> &matchingNPNClasses) const override {
118 matchingNPNClasses.push_back(npnClass);
119 return true;
120 }
121
122 /// Rewrite the cut set using this library primitive
123 llvm::FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
124 Cut &cut) const override {
125 // Create a new instance of the module
126 SmallVector<Value> inputs;
127 cut.getPermutatedInputs(npnClass, inputs);
128
129 // TODO: Give a better name to the instance
130 auto instanceOp =
131 hw::InstanceOp::create(builder, cut.getRoot()->getLoc(), module,
132 "mapped", ArrayRef<Value>(inputs));
133 return instanceOp.getOperation();
134 }
135
136 double getArea() const override { return area; }
137
138 DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const override {
139 return delay[inputIndex][outputIndex];
140 }
141
142 unsigned getNumInputs() const {
143 return static_cast<hw::HWModuleOp>(module).getNumInputPorts();
144 }
145
146 unsigned getNumOutputs() const override {
147 return static_cast<hw::HWModuleOp>(module).getNumOutputPorts();
148 }
149
150 LocationAttr getLoc() const override {
151 auto module = this->module;
152 return module.getLoc();
153 }
154
155private:
156 const double area;
157 const SmallVector<SmallVector<DelayType, 2>, 4> delay;
158 hw::HWModuleOp module;
160};
161
162namespace {
163struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
164 using TechMapperBase<TechMapperPass>::TechMapperBase;
165
166 void runOnOperation() override {
167 auto module = getOperation();
168
169 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
170
171 unsigned maxInputSize = 0;
172 // Consider modules with the "hw.techlib.info" attribute as library
173 // modules.
174 // TODO: This attribute should be replaced with a more structured
175 // representation of technology library information. Specifically, we should
176 // have a dedicated operation for technology library.
177 SmallVector<hw::HWModuleOp> nonLibraryModules;
178 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
179 auto techInfo =
180 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
181 if (!techInfo) {
182 // If the module does not have the techlib info, it is not a library
183 // TODO: Run mapping only when the module is under the specific
184 // hierarchy.
185 nonLibraryModules.push_back(hwModule);
186 continue;
187 }
188
189 // Get area and delay attributes
190 auto areaAttr = techInfo.getAs<FloatAttr>("area");
191 auto delayAttr = techInfo.getAs<ArrayAttr>("delay");
192 if (!areaAttr || !delayAttr) {
193 mlir::emitError(hwModule.getLoc())
194 << "Library module " << hwModule.getModuleName()
195 << " must have 'area'(float) and 'delay' (2d array to represent "
196 "input-output pair delay) attributes";
197 signalPassFailure();
198 return;
199 }
200
201 double area = areaAttr.getValue().convertToDouble();
202
203 SmallVector<SmallVector<DelayType, 2>, 4> delay;
204 for (auto delayValue : delayAttr) {
205 auto delayArray = cast<ArrayAttr>(delayValue);
206 SmallVector<DelayType, 2> delayRow;
207 for (auto delayElement : delayArray) {
208 // FIXME: Currently we assume delay is given as integer attributes,
209 // this should be replaced once we have a proper cell op with
210 // dedicated timing attributes with units.
211 delayRow.push_back(
212 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
213 }
214 delay.push_back(std::move(delayRow));
215 }
216 // Compute NPN Class for the module.
217 auto npnClass = getNPNClassFromModule(hwModule);
218 if (failed(npnClass)) {
219 signalPassFailure();
220 return;
221 }
222
223 // Create a CutRewritePattern for the library module
224 std::unique_ptr<TechLibraryPattern> pattern =
225 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
226 std::move(*npnClass));
227
228 // Update the maximum input size
229 maxInputSize = std::max(maxInputSize, pattern->getNumInputs());
230
231 // Add the pattern to the library
232 libraryPatterns.push_back(std::move(pattern));
233 }
234
235 if (libraryPatterns.empty())
236 return markAllAnalysesPreserved();
237
238 CutRewritePatternSet patternSet(std::move(libraryPatterns));
239 CutRewriterOptions options;
240 options.strategy = strategy;
241 options.maxCutInputSize = maxInputSize;
242 options.maxCutSizePerRoot = maxCutsPerRoot;
243 options.attachDebugTiming = test;
244 auto result = mlir::failableParallelForEach(
245 module.getContext(), nonLibraryModules, [&](hw::HWModuleOp hwModule) {
246 LLVM_DEBUG(llvm::dbgs() << "Processing non-library module: "
247 << hwModule.getName() << "\n");
248 CutRewriter rewriter(options, patternSet);
249 return rewriter.run(hwModule);
250 });
251 if (failed(result))
252 signalPassFailure();
253 }
254};
255
256} // namespace
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Strategy strategy
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permutated inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
int64_t DelayType
Definition CutRewriter.h:36
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:41
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Simple technology library encoded as a HWModuleOp.
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< SmallVector< DelayType, 2 >, 4 > delay, NPNClass npnClass)
double getArea() const override
Get the area cost of this pattern.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, Cut &cut) const override
Rewrite the cut set using this library primitive.
bool match(const Cut &cut) const override
Match the cut set against this library primitive.
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const override
Get the delay between specific input and output.
const SmallVector< SmallVector< DelayType, 2 >, 4 > delay
unsigned getNumInputs() const
const double area
LocationAttr getLoc() const override
Get location for this pattern(optional).
Represents the canonical form of a boolean function under NPN equivalence.
Definition NPNClass.h:103
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
Definition NPNClass.cpp:210
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Definition NPNClass.h:146
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).