19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/APSInt.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/Hashing.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/ErrorHandling.h"
28#define DEBUG_TYPE "infer-widths"
32#define GEN_PASS_DEF_INFERWIDTHS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
37using mlir::InferTypeOpInterface;
41using namespace firrtl;
49 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
52 if (!basetype.hasUninferredWidth())
55 if (basetype.isGround())
56 diag.attachNote() <<
"Field: \"" << str <<
"\"";
57 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
59 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
60 for (
auto &elem : bundleType.getElements())
66 uint64_t convertedFieldID = 0;
68 auto curFID = fieldID;
73 if (isa<FVectorType>(curFType))
76 convertedFieldID += curFID - subID;
81 return convertedFieldID;
93template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
95inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
102template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
105 return e.hash_value();
110#define EXPR_NAMES(x) \
111 Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
112#define EXPR_KINDS EXPR_NAMES()
113#define EXPR_CLASSES EXPR_NAMES(Expr)
120 void print(llvm::raw_ostream &os)
const;
122 std::optional<int32_t> getSolution()
const {
128 void setSolution(int32_t solution) {
130 this->solution = solution;
133 Kind getKind()
const {
return kind; }
136 Expr(Kind kind) : kind(kind) {}
137 llvm::hash_code
hash_value()
const {
return llvm::hash_value(kind); }
142 bool hasSolution =
false;
146template <
class DerivedT, Expr::Kind DerivedKind>
147struct ExprBase :
public Expr {
148 ExprBase() : Expr(DerivedKind) {}
149 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
151 if (
auto otherSame = dyn_cast<DerivedT>(other))
152 return *
static_cast<DerivedT *
>(
this) == otherSame;
158struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
159 void print(llvm::raw_ostream &os)
const {
162 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
167 Expr *constraint =
nullptr;
170 Expr *upperBound =
nullptr;
171 std::optional<int32_t> upperBoundSolution;
178struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
179 void print(llvm::raw_ostream &os)
const {
182 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
186 Expr *assigned =
nullptr;
203struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
204 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
205 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
207 return getKind() == other.getKind() && arg == other.arg;
210 return llvm::hash_combine(Expr::hash_value(), arg);
218struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
219 KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
220 void print(llvm::raw_ostream &os)
const { os << *getSolution(); }
221 bool operator==(
const KnownExpr &other)
const {
222 return *getSolution() == *other.getSolution();
225 return llvm::hash_combine(Expr::hash_value(), *getSolution());
227 int32_t getValue()
const {
return *getSolution(); }
232struct UnaryExpr :
public Expr {
233 bool operator==(
const UnaryExpr &other)
const {
234 return getKind() == other.getKind() && arg == other.arg;
237 return llvm::hash_combine(Expr::hash_value(), arg);
244 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
248template <
class DerivedT, Expr::Kind DerivedKind>
249struct UnaryExprBase :
public UnaryExpr {
250 template <
typename... Args>
251 UnaryExprBase(Args &&...args)
252 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
253 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
257struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
258 using UnaryExprBase::UnaryExprBase;
259 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
264struct BinaryExpr :
public Expr {
265 bool operator==(
const BinaryExpr &other)
const {
266 return getKind() == other.getKind() && lhs() == other.lhs() &&
267 rhs() == other.rhs();
270 return llvm::hash_combine(Expr::hash_value(), *args);
272 Expr *lhs()
const {
return args[0]; }
273 Expr *rhs()
const {
return args[1]; }
279 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
286template <
class DerivedT, Expr::Kind DerivedKind>
287struct BinaryExprBase :
public BinaryExpr {
288 template <
typename... Args>
289 BinaryExprBase(Args &&...args)
290 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
291 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
295struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
296 using BinaryExprBase::BinaryExprBase;
297 void print(llvm::raw_ostream &os)
const {
298 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
303struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
304 using BinaryExprBase::BinaryExprBase;
305 void print(llvm::raw_ostream &os)
const {
306 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
311struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
312 using BinaryExprBase::BinaryExprBase;
313 void print(llvm::raw_ostream &os)
const {
314 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
318void Expr::print(llvm::raw_ostream &os)
const {
320 [&](
auto *e) { e->print(os); });
334struct InternedSlotInfo : DenseMapInfo<T *> {
335 static T *getEmptyKey() {
337 return static_cast<T *
>(pointer);
339 static T *getTombstoneKey() {
341 return static_cast<T *
>(pointer);
343 static unsigned getHashValue(
const T *val) {
return mlir::hash_value(*val); }
344 static bool isEqual(
const T *lhs,
const T *rhs) {
345 auto empty = getEmptyKey();
346 auto tombstone = getTombstoneKey();
347 if (lhs ==
empty || rhs ==
empty || lhs == tombstone || rhs == tombstone)
355template <
typename T,
typename std::enable_if_t<
356 std::is_trivially_destructible<T>::value,
int> = 0>
357class InternedAllocator {
358 llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
359 llvm::BumpPtrAllocator &allocator;
362 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
367 template <
typename R = T,
typename... Args>
368 std::pair<R *, bool> alloc(Args &&...args) {
369 auto stackValue = R(std::forward<Args>(args)...);
370 auto *stackSlot = &stackValue;
371 auto it = interned.find(stackSlot);
372 if (it != interned.end())
373 return std::make_pair(
static_cast<R *
>(*it),
false);
374 auto heapValue =
new (allocator) R(std::move(stackValue));
375 interned.insert(heapValue);
376 return std::make_pair(heapValue,
true);
382template <
typename T,
typename std::enable_if_t<
383 std::is_trivially_destructible<T>::value,
int> = 0>
385 llvm::BumpPtrAllocator &allocator;
388 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
392 template <
typename R = T,
typename... Args>
393 R *alloc(Args &&...args) {
394 return new (allocator) R(std::forward<Args>(args)...);
425 int32_t recScale = 0;
427 int32_t nonrecBias = 0;
431 static LinIneq unsat() {
return LinIneq(
true); }
434 explicit LinIneq(
bool failed =
false) : failed(failed) {}
437 explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
440 explicit LinIneq(int32_t scale, int32_t bias) {
451 explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
455 this->recScale = recScale;
456 this->recBias = recBias;
457 this->nonrecBias = nonrecBias;
459 this->nonrecBias = std::max(recBias, nonrecBias);
470 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
471 return LinIneq(std::max(lhs.recScale, rhs.recScale),
472 std::max(lhs.recBias, rhs.recBias),
473 std::max(lhs.nonrecBias, rhs.nonrecBias),
474 lhs.failed || rhs.failed);
530 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
533 auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
534 auto enable2 = lhs.recScale > 0;
535 auto enable3 = rhs.recScale > 0;
536 auto scale1 = lhs.recScale + rhs.recScale;
537 auto scale2 = lhs.recScale;
538 auto scale3 = rhs.recScale;
539 auto bias1 = lhs.recBias + rhs.recBias;
540 auto bias2 = lhs.recBias + rhs.nonrecBias;
541 auto bias3 = rhs.recBias + lhs.nonrecBias;
542 auto maxScale = std::max(scale1, std::max(scale2, scale3));
546 std::optional<int32_t> maxBias;
547 if (enable1 && scale1 == maxScale)
549 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
551 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
556 auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias;
557 auto failed = lhs.failed || rhs.failed;
558 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
559 return LinIneq(scale1, bias1, nonrecBias, failed);
560 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
561 return LinIneq(scale2, bias2, nonrecBias, failed);
562 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
563 return LinIneq(scale3, bias3, nonrecBias, failed);
564 return LinIneq(0, 0, nonrecBias, failed);
577 if (recScale == 1 && recBias > 0)
583 void print(llvm::raw_ostream &os)
const {
585 bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
592 os << recScale <<
"*";
598 os <<
" - " << -recBias;
600 os <<
" + " << recBias;
608 if (nonrecBias != 0) {
625class ConstraintSolver {
627 ConstraintSolver() =
default;
630 auto *v = vars.alloc();
631 varExprs.push_back(v);
633 info[v].insert(currentInfo);
635 locs[v].insert(*currentLoc);
638 DerivedExpr *derived() {
639 auto *d = derivs.alloc();
640 derivedExprs.push_back(d);
643 KnownExpr *known(int32_t value) {
return alloc<KnownExpr>(knowns, value); }
644 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
645 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
646 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
647 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
648 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
652 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
654 lhs->constraint = max(lhs->constraint, rhs);
656 lhs->constraint = id(rhs);
657 return lhs->constraint;
662 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
664 lhs->upperBound = min(lhs->upperBound, rhs);
666 lhs->upperBound = id(rhs);
667 return lhs->upperBound;
670 void dumpConstraints(llvm::raw_ostream &os);
671 LogicalResult solve();
673 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
674 const ContextInfo &getContextInfo()
const {
return info; }
675 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
676 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
680 llvm::BumpPtrAllocator allocator;
681 Allocator<VarExpr> vars = {allocator};
682 Allocator<DerivedExpr> derivs = {allocator};
683 InternedAllocator<KnownExpr> knowns = {allocator};
684 InternedAllocator<IdExpr> ids = {allocator};
685 InternedAllocator<UnaryExpr> uns = {allocator};
686 InternedAllocator<BinaryExpr> bins = {allocator};
689 std::vector<VarExpr *> varExprs;
690 std::vector<DerivedExpr *> derivedExprs;
693 template <
typename R,
typename T,
typename... Args>
694 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
695 auto [expr, inserted] =
696 allocator.template alloc<R>(std::forward<Args>(args)...);
698 info[expr].insert(currentInfo);
700 locs[expr].insert(*currentLoc);
708 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
709 std::optional<Location> currentLoc;
713 ConstraintSolver(ConstraintSolver &&) =
delete;
714 ConstraintSolver(
const ConstraintSolver &) =
delete;
715 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
716 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
718 void emitUninferredWidthError(VarExpr *var);
720 LinIneq checkCycles(VarExpr *var, Expr *expr,
721 SmallPtrSetImpl<Expr *> &seenVars,
722 InFlightDiagnostic *reportInto =
nullptr,
723 unsigned indent = 1);
729void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
730 for (
auto *v : varExprs) {
732 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
734 os <<
"- " << *v <<
" unconstrained\n";
739inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
755LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
756 SmallPtrSetImpl<Expr *> &seenVars,
757 InFlightDiagnostic *reportInto,
760 TypeSwitch<Expr *, LinIneq>(expr)
762 [&](
auto *expr) {
return LinIneq(expr->getValue()); })
763 .Case<VarExpr>([&](
auto *expr) {
765 return LinIneq(1, 0);
766 if (!seenVars.insert(expr).second)
773 if (!expr->constraint)
776 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
778 seenVars.erase(expr);
781 .Case<IdExpr>([&](
auto *expr) {
782 return checkCycles(var, expr->arg, seenVars, reportInto,
785 .Case<PowExpr>([&](
auto *expr) {
790 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
791 if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
792 return LinIneq::unsat();
793 return LinIneq(1 << arg.nonrecBias);
795 .Case<AddExpr>([&](
auto *expr) {
797 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
798 checkCycles(var, expr->rhs(), seenVars, reportInto,
801 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
806 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
807 checkCycles(var, expr->rhs(), seenVars, reportInto,
810 .Default([](
auto) {
return LinIneq::unsat(); });
815 if (reportInto && !ineq.sat()) {
816 auto report = [&](Location loc) {
817 auto ¬e = reportInto->attachNote(loc);
818 note <<
"constrained width W >= ";
819 if (ineq.recScale == -1)
821 if (ineq.recScale != 1)
822 note << ineq.recScale;
824 if (ineq.recBias < 0)
825 note <<
"-" << -ineq.recBias;
826 if (ineq.recBias > 0)
827 note <<
"+" << ineq.recBias;
830 auto it = locs.find(expr);
831 if (it != locs.end())
832 for (
auto loc : it->second)
836 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
837 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
847 arg.first = operation(*arg.first);
853 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
854 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
855 if (lhs.first && rhs.first)
856 result.first = operation(*lhs.first, *rhs.first);
858 result.first = lhs.first;
860 result.first = rhs.first;
866 Frame(Expr *expr,
unsigned indent) : expr(expr), indent(indent) {}
879 std::vector<Frame> &worklist) {
881 worklist.emplace_back(expr, 1);
882 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
884 while (!worklist.empty()) {
885 auto &frame = worklist.back();
886 auto indent = frame.indent;
889 if (solution.first && !solution.second)
890 frame.expr->setSolution(*solution.first);
891 solvedExprs[frame.expr] = solution;
895 if (!isa<KnownExpr>(frame.expr)) {
897 llvm::dbgs().indent(indent * 2)
898 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
900 llvm::dbgs().indent(indent * 2) <<
"= Skipped " << *frame.expr;
901 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
910 if (frame.expr->getSolution()) {
912 if (!isa<KnownExpr>(frame.expr))
913 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
914 << *frame.expr->getSolution() <<
"\n";
916 setSolution(
ExprSolution{*frame.expr->getSolution(),
false});
922 if (!isa<KnownExpr>(frame.expr))
923 llvm::dbgs().indent(indent * 2) <<
"- Solving " << *frame.expr <<
"\n";
926 TypeSwitch<Expr *>(frame.expr)
927 .Case<KnownExpr>([&](
auto *expr) {
930 .Case<VarExpr>([&](
auto *expr) {
931 if (solvedExprs.contains(expr->constraint)) {
932 auto solution = solvedExprs[expr->constraint];
939 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
940 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
941 seenVars.erase(expr);
943 if (solution.first && *solution.first < 0)
945 return setSolution(solution);
949 if (!expr->constraint)
954 if (!seenVars.insert(expr).second)
957 worklist.emplace_back(expr->constraint, indent + 1);
958 if (expr->upperBound)
959 worklist.emplace_back(expr->upperBound, indent + 1);
961 .Case<IdExpr>([&](
auto *expr) {
962 if (solvedExprs.contains(expr->arg))
963 return setSolution(solvedExprs[expr->arg]);
964 worklist.emplace_back(expr->arg, indent + 1);
966 .Case<PowExpr>([&](
auto *expr) {
967 if (solvedExprs.contains(expr->arg))
969 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
971 worklist.emplace_back(expr->arg, indent + 1);
973 .Case<AddExpr>([&](
auto *expr) {
974 if (solvedExprs.contains(expr->lhs()) &&
975 solvedExprs.contains(expr->rhs()))
977 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
978 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
980 worklist.emplace_back(expr->lhs(), indent + 1);
981 worklist.emplace_back(expr->rhs(), indent + 1);
983 .Case<MaxExpr>([&](
auto *expr) {
984 if (solvedExprs.contains(expr->lhs()) &&
985 solvedExprs.contains(expr->rhs()))
987 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
988 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
990 worklist.emplace_back(expr->lhs(), indent + 1);
991 worklist.emplace_back(expr->rhs(), indent + 1);
993 .Case<MinExpr>([&](
auto *expr) {
994 if (solvedExprs.contains(expr->lhs()) &&
995 solvedExprs.contains(expr->rhs()))
997 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
998 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
1000 worklist.emplace_back(expr->lhs(), indent + 1);
1001 worklist.emplace_back(expr->rhs(), indent + 1);
1003 .Default([&](
auto) {
1008 return solvedExprs[expr];
1014LogicalResult ConstraintSolver::solve() {
1016 llvm::dbgs() <<
"\n";
1018 dumpConstraints(llvm::dbgs());
1023 llvm::dbgs() <<
"\n";
1024 debugHeader(
"Checking for unbreakable loops") <<
"\n\n";
1026 SmallPtrSet<Expr *, 16> seenVars;
1027 bool anyFailed =
false;
1029 for (
auto *var : varExprs) {
1030 if (!var->constraint)
1032 LLVM_DEBUG(llvm::dbgs()
1033 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1038 seenVars.insert(var);
1039 auto ineq = checkCycles(var, var->constraint, seenVars);
1048 LLVM_DEBUG(llvm::dbgs()
1049 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1057 LLVM_DEBUG(llvm::dbgs()
1058 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1060 for (
auto fieldRef : info.find(var)->second) {
1063 auto *op = fieldRef.getDefiningOp();
1064 auto diag = op ? op->emitOpError()
1065 : mlir::emitError(fieldRef.getValue().getLoc())
1067 diag <<
"is constrained to be wider than itself";
1070 seenVars.insert(var);
1071 checkCycles(var, var->constraint, seenVars, &diag);
1083 llvm::dbgs() <<
"\n";
1086 std::vector<Frame> worklist;
1087 for (
auto *var : varExprs) {
1089 if (!var->constraint) {
1090 LLVM_DEBUG(llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1091 emitUninferredWidthError(var);
1097 LLVM_DEBUG(llvm::dbgs()
1098 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1099 seenVars.insert(var);
1100 auto solution =
solveExpr(var->constraint, seenVars, worklist);
1102 if (var->upperBound && !var->upperBoundSolution)
1103 var->upperBoundSolution =
1104 solveExpr(var->upperBound, seenVars, worklist).first;
1108 if (solution.first) {
1109 if (*solution.first < 0)
1111 var->setSolution(*solution.first);
1116 if (!solution.first) {
1117 LLVM_DEBUG(llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1118 emitUninferredWidthError(var);
1122 LLVM_DEBUG(llvm::dbgs()
1123 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1124 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1127 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1128 LLVM_DEBUG(llvm::dbgs() <<
" ! Unsatisfiable " << *var
1129 <<
" <= " << var->upperBoundSolution <<
"\n");
1130 emitUninferredWidthError(var);
1136 for (
auto *derived : derivedExprs) {
1137 auto *assigned = derived->assigned;
1138 if (!assigned || !assigned->getSolution()) {
1139 LLVM_DEBUG(llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1140 derived->setSolution(0);
1142 LLVM_DEBUG(llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1143 << assigned->getSolution() <<
"\n");
1144 derived->setSolution(*assigned->getSolution());
1148 return failure(anyFailed);
1153void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1154 FieldRef fieldRef = info.find(var)->second.back();
1157 auto diag = mlir::emitError(value.getLoc(),
"uninferred width:");
1160 if (isa<BlockArgument>(value)) {
1162 }
else if (
auto *op = value.getDefiningOp()) {
1163 TypeSwitch<Operation *>(op)
1164 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1165 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1166 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1167 .Default([&](
auto) { diag <<
" value"; });
1174 if (!fieldName.empty()) {
1177 diag <<
" \"" << fieldName <<
"\"";
1180 if (!var->constraint) {
1181 diag <<
" is unconstrained";
1182 }
else if (var->getSolution() && var->upperBoundSolution &&
1183 var->getSolution() > var->upperBoundSolution) {
1184 diag <<
" cannot satisfy all width requirements";
1185 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1186 LLVM_DEBUG(llvm::dbgs() << *var->upperBound <<
"\n");
1187 auto loc = locs.find(var->constraint)->second.back();
1188 diag.attachNote(loc) <<
"width is constrained to be at least "
1189 << *var->getSolution() <<
" here:";
1190 loc = locs.find(var->upperBound)->second.back();
1191 diag.attachNote(loc) <<
"width is constrained to be at most "
1192 << *var->upperBoundSolution <<
" here:";
1194 diag <<
" width cannot be determined";
1195 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1196 auto loc = locs.find(var->constraint)->second.back();
1197 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1209class InferenceMapping {
1211 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1213 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1215 LogicalResult map(CircuitOp op);
1216 bool allWidthsKnown(Operation *op);
1217 LogicalResult mapOperation(Operation *op);
1222 void declareVars(Value value,
bool isDerived =
false);
1227 void maximumOfTypes(Value result, Value rhs, Value lhs);
1231 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1235 void constrainTypes(Expr *larger, Expr *smaller,
1236 bool imposeUpperBounds =
false,
bool equal =
false);
1245 Expr *getExpr(Value value)
const;
1248 Expr *getExpr(
FieldRef fieldRef)
const;
1252 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1256 void setExpr(Value value, Expr *expr);
1259 void setExpr(
FieldRef fieldRef, Expr *expr);
1262 bool isModuleSkipped(FModuleOp module)
const {
1263 return skippedModules.count(module);
1267 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1271 ConstraintSolver &solver;
1274 DenseMap<FieldRef, Expr *> opExprs;
1277 SmallPtrSet<Operation *, 16> skippedModules;
1278 bool allModulesSkipped =
true;
1281 SymbolTable &symtbl;
1291 return TypeSwitch<Type, bool>(type)
1292 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1294 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1295 .Default([](
auto) {
return false; });
1298LogicalResult InferenceMapping::map(CircuitOp op) {
1299 LLVM_DEBUG(llvm::dbgs()
1300 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1303 for (
auto module : op.getOps<FModuleOp>())
1304 for (auto arg : module.getArguments()) {
1305 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1309 for (
auto module : op.getOps<FModuleOp>()) {
1312 bool anyUninferred =
false;
1313 for (
auto arg : module.getArguments()) {
1318 module.walk([&](Operation *op) {
1319 for (auto type : op->getResultTypes())
1320 anyUninferred |= hasUninferredWidth(type);
1322 return WalkResult::interrupt();
1323 return WalkResult::advance();
1326 if (!anyUninferred) {
1327 LLVM_DEBUG(llvm::dbgs() <<
"Skipping fully-inferred module '"
1328 << module.getName() <<
"'\n");
1329 skippedModules.insert(module);
1333 allModulesSkipped =
false;
1337 auto result =
module.getBodyBlock()->walk(
1338 [&](Operation *op) { return WalkResult(mapOperation(op)); });
1339 if (result.wasInterrupted())
1346bool InferenceMapping::allWidthsKnown(Operation *op) {
1348 if (isa<PropAssignOp>(op))
1353 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1358 if (isa<FConnectLike, AttachOp>(op))
1362 return llvm::all_of(op->getResults(), [&](
auto result) {
1365 if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1366 if (hasUninferredWidth(type))
1372LogicalResult InferenceMapping::mapOperation(Operation *op) {
1373 if (allWidthsKnown(op))
1377 bool mappingFailed =
false;
1378 solver.setCurrentContextInfo(
1380 solver.setCurrentLocation(op->getLoc());
1381 TypeSwitch<Operation *>(op)
1382 .Case<ConstantOp>([&](
auto op) {
1385 auto v = op.getValue();
1386 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1387 : v.countLeadingZeros());
1390 setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1392 .Case<SpecialConstantOp>([&](
auto op) {
1395 .Case<InvalidValueOp>([&](
auto op) {
1398 declareVars(op.getResult(),
true);
1400 .Case<WireOp, RegOp>([&](
auto op) { declareVars(op.getResult()); })
1401 .Case<RegResetOp>([&](
auto op) {
1406 declareVars(op.getResult());
1409 constrainTypes(op.getResult(), op.getResetValue());
1411 .Case<NodeOp>([&](
auto op) {
1414 op.getResult().getType());
1418 .Case<SubfieldOp>([&](
auto op) {
1419 BundleType bundleType = op.getInput().getType();
1420 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1421 unifyTypes(
FieldRef(op.getResult(), 0),
1422 FieldRef(op.getInput(), fieldID), op.getType());
1424 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1430 .Case<SubtagOp>([&](
auto op) {
1431 FEnumType enumType = op.getInput().getType();
1432 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1433 unifyTypes(
FieldRef(op.getResult(), 0),
1434 FieldRef(op.getInput(), fieldID), op.getType());
1437 .Case<RefSubOp>([&](RefSubOp op) {
1438 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1439 op.getInput().getType().getType())
1440 .Case<FVectorType>([](
auto _) {
return 1; })
1441 .Case<BundleType>([&](
auto type) {
1442 return type.getFieldID(op.getIndex());
1444 unifyTypes(
FieldRef(op.getResult(), 0),
1445 FieldRef(op.getInput(), fieldID), op.getType());
1449 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1450 auto lhs = getExpr(op.getLhs());
1451 auto rhs = getExpr(op.getRhs());
1452 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1453 setExpr(op.getResult(), e);
1455 .Case<MulPrimOp>([&](
auto op) {
1456 auto lhs = getExpr(op.getLhs());
1457 auto rhs = getExpr(op.getRhs());
1458 auto e = solver.add(lhs, rhs);
1459 setExpr(op.getResult(), e);
1461 .Case<DivPrimOp>([&](
auto op) {
1462 auto lhs = getExpr(op.getLhs());
1464 if (op.getType().base().isSigned()) {
1465 e = solver.add(lhs, solver.known(1));
1469 setExpr(op.getResult(), e);
1471 .Case<RemPrimOp>([&](
auto op) {
1472 auto lhs = getExpr(op.getLhs());
1473 auto rhs = getExpr(op.getRhs());
1474 auto e = solver.min(lhs, rhs);
1475 setExpr(op.getResult(), e);
1477 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1478 auto lhs = getExpr(op.getLhs());
1479 auto rhs = getExpr(op.getRhs());
1480 auto e = solver.max(lhs, rhs);
1481 setExpr(op.getResult(), e);
1485 .Case<CatPrimOp>([&](
auto op) {
1486 auto lhs = getExpr(op.getLhs());
1487 auto rhs = getExpr(op.getRhs());
1488 auto e = solver.add(lhs, rhs);
1489 setExpr(op.getResult(), e);
1491 .Case<DShlPrimOp>([&](
auto op) {
1492 auto lhs = getExpr(op.getLhs());
1493 auto rhs = getExpr(op.getRhs());
1494 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1495 setExpr(op.getResult(), e);
1497 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1498 auto e = getExpr(op.getLhs());
1499 setExpr(op.getResult(), e);
1503 .Case<NegPrimOp>([&](
auto op) {
1504 auto input = getExpr(op.getInput());
1505 auto e = solver.add(input, solver.known(1));
1506 setExpr(op.getResult(), e);
1508 .Case<CvtPrimOp>([&](
auto op) {
1509 auto input = getExpr(op.getInput());
1510 auto e = op.getInput().getType().base().isSigned()
1512 : solver.add(input, solver.known(1));
1513 setExpr(op.getResult(), e);
1517 .Case<BitsPrimOp>([&](
auto op) {
1518 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1520 .Case<HeadPrimOp>([&](
auto op) {
1521 setExpr(op.getResult(), solver.known(op.getAmount()));
1523 .Case<TailPrimOp>([&](
auto op) {
1524 auto input = getExpr(op.getInput());
1525 auto e = solver.add(input, solver.known(-op.getAmount()));
1526 setExpr(op.getResult(), e);
1528 .Case<PadPrimOp>([&](
auto op) {
1529 auto input = getExpr(op.getInput());
1530 auto e = solver.max(input, solver.known(op.getAmount()));
1531 setExpr(op.getResult(), e);
1533 .Case<ShlPrimOp>([&](
auto op) {
1534 auto input = getExpr(op.getInput());
1535 auto e = solver.add(input, solver.known(op.getAmount()));
1536 setExpr(op.getResult(), e);
1538 .Case<ShrPrimOp>([&](
auto op) {
1539 auto input = getExpr(op.getInput());
1541 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1542 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1543 solver.known(minWidth));
1544 setExpr(op.getResult(), e);
1548 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1549 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1550 .Case<mlir::UnrealizedConversionCastOp>(
1551 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1555 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1556 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1557 XorRPrimOp>([&](
auto op) {
1558 auto width = op.getType().getBitWidthOrSentinel();
1559 assert(width > 0 &&
"width should have been checked by verifier");
1560 setExpr(op.getResult(), solver.known(width));
1562 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1563 auto *sel = getExpr(op.getSel());
1564 constrainTypes(solver.known(1), sel,
true);
1565 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1567 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1568 auto *sel = getExpr(op.getSel());
1569 constrainTypes(solver.known(2), sel,
true);
1570 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1571 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1572 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1575 .Case<ConnectOp, MatchingConnectOp>(
1576 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1577 .Case<RefDefineOp>([&](
auto op) {
1580 constrainTypes(op.getDest(), op.getSrc(),
true);
1582 .Case<AttachOp>([&](
auto op) {
1586 if (op.getAttached().empty())
1588 auto prev = op.getAttached()[0];
1589 for (
auto operand : op.getAttached().drop_front()) {
1590 auto e1 = getExpr(prev);
1591 auto e2 = getExpr(operand);
1592 constrainTypes(e1, e2,
true);
1593 constrainTypes(e2, e1,
true);
1599 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1600 UnclockedAssumeIntrinsicOp, CoverOp>([&](
auto) {})
1603 .Case<InstanceOp>([&](
auto op) {
1604 auto refdModule = op.getReferencedOperation(symtbl);
1605 auto module = dyn_cast<FModuleOp>(&*refdModule);
1607 auto diag = mlir::emitError(op.getLoc());
1608 diag <<
"extern module `" << op.getModuleName()
1609 <<
"` has ports of uninferred width";
1611 auto fml = cast<FModuleLike>(&*refdModule);
1612 auto ports = fml.getPorts();
1613 for (
auto &port : ports) {
1615 if (baseType && baseType.hasUninferredWidth()) {
1616 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1617 if (!baseType.isGround())
1622 diag.attachNote(op.getLoc())
1623 <<
"Only non-extern FIRRTL modules may contain unspecified "
1624 "widths to be inferred automatically.";
1625 diag.attachNote(refdModule->getLoc())
1626 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1627 mappingFailed =
true;
1634 for (
auto [result, arg] :
1635 llvm::zip(op->getResults(), module.getArguments()))
1636 unifyTypes({result, 0}, {arg, 0},
1637 type_cast<FIRRTLType>(result.getType()));
1641 .Case<MemOp>([&](MemOp op) {
1643 unsigned nonDebugPort = 0;
1644 for (
const auto &result :
llvm::enumerate(op.getResults())) {
1645 declareVars(result.value());
1646 if (!type_isa<RefType>(result.value().getType()))
1647 nonDebugPort = result.index();
1652 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1653 static const unsigned indices[] = {3, 5};
1654 static const unsigned debug[] = {0};
1656 case MemOp::PortKind::Read:
1657 case MemOp::PortKind::Write:
1658 return ArrayRef<unsigned>(indices, 1);
1659 case MemOp::PortKind::ReadWrite:
1660 return ArrayRef<unsigned>(indices);
1661 case MemOp::PortKind::Debug:
1662 return ArrayRef<unsigned>(
debug);
1664 llvm_unreachable(
"Imposible PortKind");
1671 unsigned firstFieldIndex =
1672 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1674 op.getResult(nonDebugPort),
1675 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1676 .getFieldID(firstFieldIndex));
1677 LLVM_DEBUG(llvm::dbgs() <<
"Adjusting memory port variables:\n");
1680 auto dataType = op.getDataType();
1681 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1682 auto result = op.getResult(i);
1683 if (type_isa<RefType>(result.getType())) {
1687 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1692 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1693 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1695 firstData, dataType);
1699 .Case<RefSendOp>([&](
auto op) {
1700 declareVars(op.getResult());
1701 constrainTypes(op.getResult(), op.getBase(),
true);
1703 .Case<RefResolveOp>([&](
auto op) {
1704 declareVars(op.getResult());
1705 constrainTypes(op.getResult(), op.getRef(),
true);
1707 .Case<RefCastOp>([&](
auto op) {
1708 declareVars(op.getResult());
1709 constrainTypes(op.getResult(), op.getInput(),
true);
1711 .Case<RWProbeOp>([&](
auto op) {
1712 auto ist = irn.lookup(op.getTarget());
1714 op->emitError(
"target of rwprobe could not be resolved");
1715 mappingFailed =
true;
1720 op->emitError(
"target of rwprobe resolved to unsupported target");
1721 mappingFailed =
true;
1725 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1726 unifyTypes(
FieldRef(op.getResult(), 0),
1727 FieldRef(ref.getValue(), newFID), op.getType());
1729 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1730 for (Value result : op.getResults()) {
1731 auto ty = result.getType();
1732 if (type_isa<FIRRTLType>(ty))
1733 declareVars(result);
1736 .Default([&](
auto op) {
1737 op->emitOpError(
"not supported in width inference");
1738 mappingFailed =
true;
1742 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1746 return failure(mappingFailed);
1751void InferenceMapping::declareVars(Value value,
bool isDerived) {
1754 unsigned fieldID = 0;
1756 auto width = type.getBitWidthOrSentinel();
1759 }
else if (width == -1) {
1762 solver.setCurrentContextInfo(field);
1764 setExpr(field, solver.derived());
1766 setExpr(field, solver.var());
1768 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1771 for (
auto &element : bundleType)
1772 declare(element.type);
1773 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1775 auto save = fieldID;
1776 declare(vecType.getElementType());
1778 fieldID = save + vecType.getMaxFieldID();
1779 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1781 for (
auto &element : enumType.getElements())
1782 declare(element.type);
1784 llvm_unreachable(
"Unknown type inside a bundle!");
1794void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1798 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1800 for (
auto &element : bundleType.getElements())
1801 maximize(element.type);
1802 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1804 auto save = fieldID;
1806 if (vecType.getNumElements() > 0)
1807 maximize(vecType.getElementType());
1808 fieldID = save + vecType.getMaxFieldID();
1809 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1811 for (
auto &element : enumType.getElements())
1812 maximize(element.type);
1813 }
else if (type.isGround()) {
1814 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1816 setExpr(
FieldRef(result, fieldID), e);
1819 llvm_unreachable(
"Unknown type inside a bundle!");
1834void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1841 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1843 for (
auto &element : bundleType.getElements()) {
1845 constrain(element.type, smaller, larger);
1847 constrain(element.type, larger, smaller);
1849 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1851 auto save = fieldID;
1853 if (vecType.getNumElements() > 0) {
1854 constrain(vecType.getElementType(), larger, smaller);
1856 fieldID = save + vecType.getMaxFieldID();
1857 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1859 for (
auto &element : enumType.getElements())
1860 constrain(element.type, larger, smaller);
1861 }
else if (type.isGround()) {
1863 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1864 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
1867 llvm_unreachable(
"Unknown type inside a bundle!");
1872 constrain(type, larger, smaller);
1877void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1878 bool imposeUpperBounds,
bool equal) {
1879 assert(larger &&
"Larger expression should be specified");
1880 assert(smaller &&
"Smaller expression should be specified");
1886 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1887 largerDerived->assigned = smaller;
1888 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1889 << *smaller <<
"\n");
1892 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1893 smallerDerived->assigned = larger;
1894 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1895 << *larger <<
"\n");
1901 if (
auto *largerVar = dyn_cast<VarExpr>(larger)) {
1902 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1903 LLVM_DEBUG(llvm::dbgs()
1904 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
1911 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1912 LLVM_DEBUG(llvm::dbgs()
1913 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
1923 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1924 if (imposeUpperBounds || equal) {
1925 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1926 LLVM_DEBUG(llvm::dbgs()
1927 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
1947 LLVM_DEBUG(llvm::dbgs()
1948 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1951 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1952 solver.addGeqConstraint(var, solver.known(0));
1953 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1955 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1957 for (
auto &element : bundleType) {
1958 unify(element.type);
1960 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1962 auto save = fieldID;
1964 if (vecType.getNumElements() > 0) {
1965 unify(vecType.getElementType());
1967 fieldID = save + vecType.getMaxFieldID();
1968 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1970 for (
auto &element : enumType.getElements())
1971 unify(element.type);
1973 llvm_unreachable(
"Unknown type inside a bundle!");
1981Expr *InferenceMapping::getExpr(Value value)
const {
1984 return getExpr(
FieldRef(value, 0));
1988Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
1989 auto *expr = getExprOrNull(fieldRef);
1990 assert(expr &&
"constraint expr should have been constructed for value");
1994Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
1995 auto it = opExprs.find(fieldRef);
1996 if (it != opExprs.end())
2003 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
2006 return solver.known(width);
2010void InferenceMapping::setExpr(Value value, Expr *expr) {
2017void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
2019 llvm::dbgs() <<
"Expr " << *expr <<
" for " << fieldRef.
getValue();
2021 llvm::dbgs() <<
" '" <<
getFieldName(fieldRef).first <<
"'";
2023 if (fieldName.second)
2024 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2025 llvm::dbgs() <<
"\n";
2027 opExprs[fieldRef] = expr;
2037class InferenceTypeUpdate {
2039 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2041 LogicalResult update(CircuitOp op);
2042 FailureOr<bool> updateOperation(Operation *op);
2043 FailureOr<bool> updateValue(Value value);
2047 const InferenceMapping &mapping;
2053LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2055 llvm::dbgs() <<
"\n";
2058 return mlir::failableParallelForEach(
2059 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2062 if (mapping.isModuleSkipped(op))
2064 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2065 if (failed(updateOperation(op)))
2066 return WalkResult::interrupt();
2067 return WalkResult::advance();
2068 }).wasInterrupted();
2069 return failure(isFailed);
2074FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2075 bool anyChanged =
false;
2077 for (Value v : op->getResults()) {
2078 auto result = updateValue(v);
2081 anyChanged |= *result;
2087 if (
auto con = dyn_cast<ConnectOp>(op)) {
2088 auto lhs = con.getDest();
2089 auto rhs = con.getSrc();
2090 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2091 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2094 if (!lhsType || !rhsType)
2097 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2098 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2099 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2100 OpBuilder builder(op);
2101 auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2102 rhsWidth - lhsWidth);
2103 if (type_isa<SIntType>(rhsType))
2105 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2107 LLVM_DEBUG(llvm::dbgs()
2108 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
2109 con->replaceUsesOfWith(con.getSrc(), trunc);
2115 if (
auto module = dyn_cast<FModuleOp>(op)) {
2117 bool argsChanged =
false;
2118 SmallVector<Attribute> argTypes;
2119 argTypes.reserve(module.getNumPorts());
2120 for (
auto arg : module.getArguments()) {
2121 auto result = updateValue(arg);
2124 argsChanged |= *result;
2125 argTypes.push_back(TypeAttr::get(arg.getType()));
2130 module.setPortTypesAttr(ArrayAttr::get(module.getContext(), argTypes));
2139 auto *context = type.getContext();
2141 .
Case<UIntType>([&](
auto type) {
2142 return UIntType::get(context, newWidth, type.
isConst());
2144 .Case<SIntType>([&](
auto type) {
2145 return SIntType::get(context, newWidth, type.
isConst());
2147 .Case<AnalogType>([&](
auto type) {
2148 return AnalogType::get(context, newWidth, type.
isConst());
2150 .Default([&](
auto type) {
return type; });
2154FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2156 auto type = type_dyn_cast<FIRRTLType>(value.getType());
2167 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2168 SmallVector<Type, 2> types;
2170 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2171 op->getAttrDictionary(), op->getPropertiesStorage(),
2172 op->getRegions(), types);
2176 assert(types.size() == op->getNumResults());
2177 for (
auto [result, type] :
llvm::zip(op->getResults(), types)) {
2178 LLVM_DEBUG(llvm::dbgs()
2179 <<
"Inferring " << result <<
" as " << type <<
"\n");
2180 result.setType(type);
2186 auto *context = type.getContext();
2187 unsigned fieldID = 0;
2190 auto width = type.getBitWidthOrSentinel();
2202 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2205 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2206 for (
auto &element : bundleType) {
2207 auto updatedBase = updateBase(element.type);
2210 elements.emplace_back(element.name, element.isFlip, updatedBase);
2212 return BundleType::get(context, elements, bundleType.isConst());
2214 if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2216 auto save = fieldID;
2219 if (vecType.getNumElements() > 0) {
2220 auto updatedBase = updateBase(vecType.getElementType());
2223 auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2225 fieldID = save + vecType.getMaxFieldID();
2231 if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2233 llvm::SmallVector<FEnumType::EnumElement> elements;
2234 for (
auto &element : enumType.getElements()) {
2235 auto updatedBase = updateBase(element.type);
2238 elements.emplace_back(element.name, updatedBase);
2240 return FEnumType::get(context, elements, enumType.isConst());
2242 llvm_unreachable(
"Unknown type inside a bundle!");
2249 LLVM_DEBUG(llvm::dbgs() <<
"Update " << value <<
" to " << newType <<
"\n");
2250 value.setType(newType);
2255 if (
auto op = value.getDefiningOp<ConstantOp>()) {
2256 auto k = op.getValue();
2257 auto bitwidth = op.getType().getBitWidthOrSentinel();
2258 if (k.getBitWidth() >
unsigned(bitwidth))
2259 k = k.trunc(bitwidth);
2260 op->setAttr(
"value", IntegerAttr::get(op.getContext(), k));
2263 return newType != type;
2269 assert(type.isGround() &&
"Can only pass in ground types.");
2272 Expr *expr = mapping.getExprOrNull(fieldRef);
2273 if (!expr || !expr->getSolution()) {
2278 mlir::emitError(value.getLoc(),
"width should have been inferred");
2281 int32_t solution = *expr->getSolution();
2291class InferWidthsPass
2292 :
public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2293 void runOnOperation()
override;
2297void InferWidthsPass::runOnOperation() {
2299 ConstraintSolver solver;
2300 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2301 getAnalysis<hw::InnerSymbolTableCollection>());
2302 if (failed(mapping.map(getOperation())))
2303 return signalPassFailure();
2306 if (mapping.areAllModulesSkipped())
2307 return markAllAnalysesPreserved();
2310 if (failed(solver.solve()))
2311 return signalPassFailure();
2314 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2315 return signalPassFailure();
2319 return std::make_unique<InferWidthsPass>();
assert(baseType &&"element must be base type")
static unsigned getFieldID(BundleType type, unsigned index)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
This class represents a collection of InnerSymbolTable's.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
static bool operator==(const ModulePort &a, const ModulePort &b)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
llvm::hash_code hash_value(const T &e)
This class represents the namespace in which InnerRef's can be resolved.