CIRCT 22.0.0git
Loading...
Searching...
No Matches
TechMapper.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TechMapper pass, which performs technology mapping
10// by converting logic network representations (AIG operations) into
11// technology-specific gate implementations using cut-based rewriting.
12//
13// The pass uses a cut-based algorithm with priority cuts and NPN canonical
14// forms for efficient pattern matching. It processes HWModuleOp instances with
15// "hw.techlib.info" attributes as technology library patterns and maps
16// non-library modules to optimal gate implementations based on area and timing
17// optimization strategies.
18//
19//===----------------------------------------------------------------------===//
20
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_TECHMAPPER
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace circt::synth;
39
40#define DEBUG_TYPE "synth-tech-mapper"
41
42//===----------------------------------------------------------------------===//
43// Tech Mapper Pass
44//===----------------------------------------------------------------------===//
45
46static llvm::FailureOr<NPNClass> getNPNClassFromModule(hw::HWModuleOp module) {
47 // Get input and output ports
48 auto inputTypes = module.getInputTypes();
49 auto outputTypes = module.getOutputTypes();
50
51 unsigned numInputs = inputTypes.size();
52 unsigned numOutputs = outputTypes.size();
53 if (numOutputs != 1)
54 return module->emitError(
55 "Modules with multiple outputs are not supported yet");
56
57 // Verify all ports are single bit
58 for (auto type : inputTypes) {
59 if (!type.isInteger(1))
60 return module->emitError("All input ports must be single bit");
61 }
62 for (auto type : outputTypes) {
63 if (!type.isInteger(1))
64 return module->emitError("All output ports must be single bit");
65 }
66
67 if (numInputs > maxTruthTableInputs)
68 return module->emitError("Too many inputs for truth table generation");
69
70 SmallVector<Value> results;
71 results.reserve(numOutputs);
72 // Get the body block of the module
73 auto *bodyBlock = module.getBodyBlock();
74 assert(bodyBlock && "Module must have a body block");
75 // Collect output values from the body block
76 for (auto result : bodyBlock->getTerminator()->getOperands())
77 results.push_back(result);
78
79 // Create a truth table for the module
80 FailureOr<BinaryTruthTable> truthTable = getTruthTable(results, bodyBlock);
81 if (failed(truthTable))
82 return failure();
83
84 return NPNClass::computeNPNCanonicalForm(*truthTable);
85}
86
87/// Simple technology library encoded as a HWModuleOp.
90 SmallVector<SmallVector<DelayType, 2>, 4> delay,
93 delay(std::move(delay)), module(module), npnClass(std::move(npnClass)) {
94
95 LLVM_DEBUG({
96 llvm::dbgs() << "Created Tech Library Pattern for module: "
97 << module.getModuleName() << "\n"
98 << "NPN Class: " << npnClass.truthTable.table << "\n"
99 << "Inputs: " << npnClass.inputPermutation.size() << "\n"
100 << "Input Negation: " << npnClass.inputNegation << "\n"
101 << "Output Negation: " << npnClass.outputNegation << "\n";
102 });
103 }
104
105 StringRef getPatternName() const override {
106 auto moduleCp = module;
107 return moduleCp.getModuleName();
108 }
109
110 /// Match the cut set against this library primitive
111 bool match(const Cut &cut) const override {
113 }
114
115 /// Enable truth table matching for this pattern
117 SmallVectorImpl<NPNClass> &matchingNPNClasses) const override {
118 matchingNPNClasses.push_back(npnClass);
119 return true;
120 }
121
122 /// Rewrite the cut set using this library primitive
123 llvm::FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
124 Cut &cut) const override {
125 // Create a new instance of the module
126 SmallVector<Value> inputs;
127 cut.getPermutatedInputs(npnClass, inputs);
128
129 // TODO: Give a better name to the instance
130 auto instanceOp = builder.create<hw::InstanceOp>(
131 cut.getRoot()->getLoc(), module, "mapped", ArrayRef<Value>(inputs));
132 return instanceOp.getOperation();
133 }
134
135 double getArea() const override { return area; }
136
137 DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const override {
138 return delay[inputIndex][outputIndex];
139 }
140
141 unsigned getNumInputs() const {
142 return static_cast<hw::HWModuleOp>(module).getNumInputPorts();
143 }
144
145 unsigned getNumOutputs() const override {
146 return static_cast<hw::HWModuleOp>(module).getNumOutputPorts();
147 }
148
149 LocationAttr getLoc() const override {
150 auto module = this->module;
151 return module.getLoc();
152 }
153
154private:
155 const double area;
156 const SmallVector<SmallVector<DelayType, 2>, 4> delay;
157 hw::HWModuleOp module;
159};
160
161namespace {
162struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
163 using TechMapperBase<TechMapperPass>::TechMapperBase;
164
165 void runOnOperation() override {
166 auto module = getOperation();
167
168 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
169
170 unsigned maxInputSize = 0;
171 // Consider modules with the "hw.techlib.info" attribute as library
172 // modules.
173 // TODO: This attribute should be replaced with a more structured
174 // representation of technology library information. Specifically, we should
175 // have a dedicated operation for technology library.
176 SmallVector<hw::HWModuleOp> nonLibraryModules;
177 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
178 auto techInfo =
179 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
180 if (!techInfo) {
181 // If the module does not have the techlib info, it is not a library
182 // TODO: Run mapping only when the module is under the specific
183 // hierarchy.
184 nonLibraryModules.push_back(hwModule);
185 continue;
186 }
187
188 // Get area and delay attributes
189 auto areaAttr = techInfo.getAs<FloatAttr>("area");
190 auto delayAttr = techInfo.getAs<ArrayAttr>("delay");
191 if (!areaAttr || !delayAttr) {
192 mlir::emitError(hwModule.getLoc())
193 << "Library module " << hwModule.getModuleName()
194 << " must have 'area'(float) and 'delay' (2d array to represent "
195 "input-output pair delay) attributes";
196 signalPassFailure();
197 return;
198 }
199
200 double area = areaAttr.getValue().convertToDouble();
201
202 SmallVector<SmallVector<DelayType, 2>, 4> delay;
203 for (auto delayValue : delayAttr) {
204 auto delayArray = cast<ArrayAttr>(delayValue);
205 SmallVector<DelayType, 2> delayRow;
206 for (auto delayElement : delayArray) {
207 // FIXME: Currently we assume delay is given as integer attributes,
208 // this should be replaced once we have a proper cell op with
209 // dedicated timing attributes with units.
210 delayRow.push_back(
211 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
212 }
213 delay.push_back(std::move(delayRow));
214 }
215 // Compute NPN Class for the module.
216 auto npnClass = getNPNClassFromModule(hwModule);
217 if (failed(npnClass)) {
218 signalPassFailure();
219 return;
220 }
221
222 // Create a CutRewritePattern for the library module
223 std::unique_ptr<TechLibraryPattern> pattern =
224 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
225 std::move(*npnClass));
226
227 // Update the maximum input size
228 maxInputSize = std::max(maxInputSize, pattern->getNumInputs());
229
230 // Add the pattern to the library
231 libraryPatterns.push_back(std::move(pattern));
232 }
233
234 if (libraryPatterns.empty())
235 return markAllAnalysesPreserved();
236
237 CutRewritePatternSet patternSet(std::move(libraryPatterns));
238 CutRewriterOptions options;
239 options.strategy = strategy;
240 options.maxCutInputSize = maxInputSize;
241 options.maxCutSizePerRoot = maxCutsPerRoot;
242 options.attachDebugTiming = test;
243 auto result = mlir::failableParallelForEach(
244 module.getContext(), nonLibraryModules, [&](hw::HWModuleOp hwModule) {
245 LLVM_DEBUG(llvm::dbgs() << "Processing non-library module: "
246 << hwModule.getName() << "\n");
247 CutRewriter rewriter(options, patternSet);
248 return rewriter.run(hwModule);
249 });
250 if (failed(result))
251 signalPassFailure();
252 }
253};
254
255} // namespace
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Strategy strategy
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permutated inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
int64_t DelayType
Definition CutRewriter.h:36
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:41
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Simple technology library encoded as a HWModuleOp.
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< SmallVector< DelayType, 2 >, 4 > delay, NPNClass npnClass)
double getArea() const override
Get the area cost of this pattern.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, Cut &cut) const override
Rewrite the cut set using this library primitive.
bool match(const Cut &cut) const override
Match the cut set against this library primitive.
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const override
Get the delay between specific input and output.
const SmallVector< SmallVector< DelayType, 2 >, 4 > delay
unsigned getNumInputs() const
const double area
LocationAttr getLoc() const override
Get the location of this pattern (optional).
Represents the canonical form of a boolean function under NPN equivalence.
Definition NPNClass.h:103
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
Definition NPNClass.cpp:210
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Definition NPNClass.h:146
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).