23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Threading.h"
25#include "mlir/Support/WalkResult.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
32#define GEN_PASS_DEF_TECHMAPPER
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
40#define DEBUG_TYPE "synth-tech-mapper"
48 auto inputTypes =
module.getInputTypes();
49 auto outputTypes =
module.getOutputTypes();
51 unsigned numInputs = inputTypes.size();
52 unsigned numOutputs = outputTypes.size();
54 return module->emitError(
55 "Modules with multiple outputs are not supported yet");
58 for (
auto type : inputTypes) {
59 if (!type.isInteger(1))
60 return module->emitError(
"All input ports must be single bit");
62 for (
auto type : outputTypes) {
63 if (!type.isInteger(1))
64 return module->emitError(
"All output ports must be single bit");
68 return module->emitError("Too many inputs for truth table generation");
70 SmallVector<Value> results;
71 results.reserve(numOutputs);
73 auto *bodyBlock =
module.getBodyBlock();
74 assert(bodyBlock &&
"Module must have a body block");
76 for (
auto result : bodyBlock->getTerminator()->getOperands())
77 results.push_back(result);
80 FailureOr<BinaryTruthTable> truthTable =
getTruthTable(results, bodyBlock);
81 if (failed(truthTable))
90 SmallVector<SmallVector<DelayType, 2>, 4>
delay,
96 llvm::dbgs() <<
"Created Tech Library Pattern for module: "
97 <<
module.getModuleName() << "\n"
98 << "NPN Class: " << npnClass.truthTable.table << "\n"
99 << "Inputs: " << npnClass.inputPermutation.size() << "\n"
100 << "Input Negation: " << npnClass.inputNegation << "\n"
101 << "Output Negation: " << npnClass.outputNegation << "\n";
106 auto moduleCp =
module;
107 return moduleCp.getModuleName();
117 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const override {
118 matchingNPNClasses.push_back(
npnClass);
123 llvm::FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
124 Cut &cut)
const override {
126 SmallVector<Value> inputs;
130 auto instanceOp = builder.create<hw::InstanceOp>(
131 cut.
getRoot()->getLoc(), module,
"mapped", ArrayRef<Value>(inputs));
132 return instanceOp.getOperation();
138 return delay[inputIndex][outputIndex];
150 auto module = this->module;
151 return module.getLoc();
156 const SmallVector<SmallVector<DelayType, 2>, 4>
delay;
162struct TechMapperPass :
public impl::TechMapperBase<TechMapperPass> {
163 using TechMapperBase<TechMapperPass>::TechMapperBase;
165 void runOnOperation()
override {
166 auto module = getOperation();
168 SmallVector<std::unique_ptr<CutRewritePattern>> libraryPatterns;
170 unsigned maxInputSize = 0;
176 SmallVector<hw::HWModuleOp> nonLibraryModules;
177 for (
auto hwModule :
module.getOps<hw::HWModuleOp>()) {
179 hwModule->getAttrOfType<DictionaryAttr>("hw.techlib.info");
184 nonLibraryModules.push_back(hwModule);
189 auto areaAttr = techInfo.getAs<FloatAttr>(
"area");
190 auto delayAttr = techInfo.getAs<ArrayAttr>(
"delay");
191 if (!areaAttr || !delayAttr) {
192 mlir::emitError(hwModule.getLoc())
193 <<
"Library module " << hwModule.getModuleName()
194 <<
" must have 'area'(float) and 'delay' (2d array to represent "
195 "input-output pair delay) attributes";
200 double area = areaAttr.getValue().convertToDouble();
202 SmallVector<SmallVector<DelayType, 2>, 4> delay;
203 for (
auto delayValue : delayAttr) {
204 auto delayArray = cast<ArrayAttr>(delayValue);
205 SmallVector<DelayType, 2> delayRow;
206 for (
auto delayElement : delayArray) {
211 cast<mlir::IntegerAttr>(delayElement).getValue().getZExtValue());
213 delay.push_back(std::move(delayRow));
217 if (failed(npnClass)) {
223 std::unique_ptr<TechLibraryPattern>
pattern =
224 std::make_unique<TechLibraryPattern>(hwModule, area, std::move(delay),
225 std::move(*npnClass));
228 maxInputSize = std::max(maxInputSize,
pattern->getNumInputs());
231 libraryPatterns.push_back(std::move(
pattern));
234 if (libraryPatterns.empty())
235 return markAllAnalysesPreserved();
243 auto result = mlir::failableParallelForEach(
244 module.getContext(), nonLibraryModules, [&](
hw::HWModuleOp hwModule) {
245 LLVM_DEBUG(llvm::dbgs() <<
"Processing non-library module: "
246 << hwModule.getName() <<
"\n");
247 CutRewriter rewriter(options, patternSet);
248 return rewriter.run(hwModule);
assert(baseType &&"element must be base type")
RewritePatternSet pattern
static llvm::FailureOr< NPNClass > getNPNClassFromModule(hw::HWModuleOp module)
Manages a collection of rewriting patterns for combinational logic optimization.
Represents a cut in the combinational logic network.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permutated inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Simple technology library encoded as a HWModuleOp.
StringRef getPatternName() const override
Get the name of this pattern. Used for debugging.
hw::HWModuleOp NPNClass npnClass
TechLibraryPattern(hw::HWModuleOp module, double area, SmallVector< SmallVector< DelayType, 2 >, 4 > delay, NPNClass npnClass)
double getArea() const override
Get the area cost of this pattern.
llvm::FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, Cut &cut) const override
Rewrite the cut set using this library primitive.
bool match(const Cut &cut) const override
Match the cut set against this library primitive.
bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const override
Enable truth table matching for this pattern.
unsigned getNumOutputs() const override
Get the number of outputs this pattern produces.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const override
Get the delay between specific input and output.
const SmallVector< SmallVector< DelayType, 2 >, 4 > delay
unsigned getNumInputs() const
LocationAttr getLoc() const override
Get location for this pattern(optional).
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
bool equivalentOtherThanPermutation(const NPNClass &other) const
Equality comparison for NPN classes.
Base class for cut rewriting patterns used in combinational logic optimization.
mlir::MLIRContext * getContext() const
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).