19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/APSInt.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/Hashing.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/ErrorHandling.h"
28#define DEBUG_TYPE "infer-widths"
32#define GEN_PASS_DEF_INFERWIDTHS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
37using mlir::InferTypeOpInterface;
41using namespace firrtl;
49 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
52 if (!basetype.hasUninferredWidth())
55 if (basetype.isGround())
56 diag.attachNote() <<
"Field: \"" << str <<
"\"";
57 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
59 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
60 for (
auto &elem : bundleType.getElements())
66 uint64_t convertedFieldID = 0;
68 auto curFID = fieldID;
73 if (isa<FVectorType>(curFType))
76 convertedFieldID += curFID - subID;
81 return convertedFieldID;
93template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
95inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
102template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
105 return e.hash_value();
110#define EXPR_NAMES(x) \
111 Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
112#define EXPR_KINDS EXPR_NAMES()
113#define EXPR_CLASSES EXPR_NAMES(Expr)
120 void print(llvm::raw_ostream &os)
const;
122 std::optional<int32_t> getSolution()
const {
128 void setSolution(int32_t solution) {
130 this->solution = solution;
133 Kind getKind()
const {
return kind; }
136 Expr(Kind kind) : kind(kind) {}
142 bool hasSolution =
false;
146template <
class DerivedT, Expr::Kind DerivedKind>
147struct ExprBase :
public Expr {
148 ExprBase() : Expr(DerivedKind) {}
149 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
151 if (
auto otherSame = dyn_cast<DerivedT>(other))
152 return *
static_cast<DerivedT *
>(
this) == otherSame;
158struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
159 void print(llvm::raw_ostream &os)
const {
162 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
167 Expr *constraint =
nullptr;
170 Expr *upperBound =
nullptr;
171 std::optional<int32_t> upperBoundSolution;
178struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
179 void print(llvm::raw_ostream &os)
const {
182 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
186 Expr *assigned =
nullptr;
203struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
204 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
205 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
207 return getKind() == other.getKind() && arg == other.arg;
210 return llvm::hash_combine(Expr::hash_value(), arg);
218struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
219 KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
220 void print(llvm::raw_ostream &os)
const { os << *getSolution(); }
221 bool operator==(
const KnownExpr &other)
const {
222 return *getSolution() == *other.getSolution();
225 return llvm::hash_combine(Expr::hash_value(), *getSolution());
227 int32_t getValue()
const {
return *getSolution(); }
232struct UnaryExpr :
public Expr {
233 bool operator==(
const UnaryExpr &other)
const {
234 return getKind() == other.getKind() && arg == other.arg;
237 return llvm::hash_combine(Expr::hash_value(), arg);
244 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
248template <
class DerivedT, Expr::Kind DerivedKind>
249struct UnaryExprBase :
public UnaryExpr {
250 template <
typename... Args>
251 UnaryExprBase(Args &&...args)
252 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
253 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
257struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
258 using UnaryExprBase::UnaryExprBase;
259 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
264struct BinaryExpr :
public Expr {
265 bool operator==(
const BinaryExpr &other)
const {
266 return getKind() == other.getKind() && lhs() == other.lhs() &&
267 rhs() == other.rhs();
270 return llvm::hash_combine(Expr::hash_value(), *args);
272 Expr *lhs()
const {
return args[0]; }
273 Expr *rhs()
const {
return args[1]; }
279 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
286template <
class DerivedT, Expr::Kind DerivedKind>
287struct BinaryExprBase :
public BinaryExpr {
288 template <
typename... Args>
289 BinaryExprBase(Args &&...args)
290 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
291 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
295struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
296 using BinaryExprBase::BinaryExprBase;
297 void print(llvm::raw_ostream &os)
const {
298 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
303struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
304 using BinaryExprBase::BinaryExprBase;
305 void print(llvm::raw_ostream &os)
const {
306 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
311struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
312 using BinaryExprBase::BinaryExprBase;
313 void print(llvm::raw_ostream &os)
const {
314 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
318void Expr::print(llvm::raw_ostream &os)
const {
320 [&](
auto *e) { e->print(os); });
334struct InternedSlotInfo : DenseMapInfo<T *> {
335 static unsigned getHashValue(
const T *val) {
return mlir::hash_value(*val); }
336 static bool isEqual(
const T *lhs,
const T *rhs) {
345template <
typename T,
typename std::enable_if_t<
346 std::is_trivially_destructible<T>::value,
int> = 0>
347class InternedAllocator {
348 llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
349 llvm::BumpPtrAllocator &allocator;
352 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
357 template <
typename R = T,
typename... Args>
358 std::pair<R *, bool> alloc(Args &&...args) {
359 auto stackValue = R(std::forward<Args>(args)...);
360 auto *stackSlot = &stackValue;
361 auto it = interned.find(stackSlot);
362 if (it != interned.end())
363 return std::make_pair(
static_cast<R *
>(*it),
false);
364 auto heapValue =
new (allocator) R(std::move(stackValue));
365 interned.insert(heapValue);
366 return std::make_pair(heapValue,
true);
372template <
typename T,
typename std::enable_if_t<
373 std::is_trivially_destructible<T>::value,
int> = 0>
375 llvm::BumpPtrAllocator &allocator;
378 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
382 template <
typename R = T,
typename... Args>
383 R *alloc(Args &&...args) {
384 return new (allocator) R(std::forward<Args>(args)...);
415 int32_t recScale = 0;
417 int32_t nonrecBias = 0;
421 static LinIneq unsat() {
return LinIneq(
true); }
424 explicit LinIneq(
bool failed =
false) : failed(failed) {}
427 explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
430 explicit LinIneq(int32_t scale, int32_t bias) {
441 explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
445 this->recScale = recScale;
446 this->recBias = recBias;
447 this->nonrecBias = nonrecBias;
449 this->nonrecBias = std::max(recBias, nonrecBias);
460 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
461 return LinIneq(std::max(lhs.recScale, rhs.recScale),
462 std::max(lhs.recBias, rhs.recBias),
463 std::max(lhs.nonrecBias, rhs.nonrecBias),
464 lhs.failed || rhs.failed);
520 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
523 auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
524 auto enable2 = lhs.recScale > 0;
525 auto enable3 = rhs.recScale > 0;
526 auto scale1 = lhs.recScale + rhs.recScale;
527 auto scale2 = lhs.recScale;
528 auto scale3 = rhs.recScale;
529 auto bias1 = lhs.recBias + rhs.recBias;
530 auto bias2 = lhs.recBias + rhs.nonrecBias;
531 auto bias3 = rhs.recBias + lhs.nonrecBias;
532 auto maxScale = std::max(scale1, std::max(scale2, scale3));
536 std::optional<int32_t> maxBias;
537 if (enable1 && scale1 == maxScale)
539 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
541 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
546 auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias;
547 auto failed = lhs.failed || rhs.failed;
548 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
549 return LinIneq(scale1, bias1, nonrecBias, failed);
550 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
551 return LinIneq(scale2, bias2, nonrecBias, failed);
552 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
553 return LinIneq(scale3, bias3, nonrecBias, failed);
554 return LinIneq(0, 0, nonrecBias, failed);
567 if (recScale == 1 && recBias > 0)
573 void print(llvm::raw_ostream &os)
const {
575 bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
582 os << recScale <<
"*";
588 os <<
" - " << -recBias;
590 os <<
" + " << recBias;
598 if (nonrecBias != 0) {
615class ConstraintSolver {
617 ConstraintSolver() =
default;
620 auto *v = vars.alloc();
621 varExprs.push_back(v);
623 info[v].insert(currentInfo);
625 locs[v].insert(*currentLoc);
628 DerivedExpr *derived() {
629 auto *
d = derivs.alloc();
630 derivedExprs.push_back(d);
633 KnownExpr *known(int32_t value) {
return alloc<KnownExpr>(knowns, value); }
634 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
635 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
636 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
637 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
638 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
642 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
644 lhs->constraint = max(lhs->constraint, rhs);
646 lhs->constraint = id(rhs);
647 return lhs->constraint;
652 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
654 lhs->upperBound = min(lhs->upperBound, rhs);
656 lhs->upperBound = id(rhs);
657 return lhs->upperBound;
660 void dumpConstraints(llvm::raw_ostream &os);
661 LogicalResult solve();
663 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
664 const ContextInfo &getContextInfo()
const {
return info; }
665 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
666 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
670 llvm::BumpPtrAllocator allocator;
671 Allocator<VarExpr> vars = {allocator};
672 Allocator<DerivedExpr> derivs = {allocator};
673 InternedAllocator<KnownExpr> knowns = {allocator};
674 InternedAllocator<IdExpr> ids = {allocator};
675 InternedAllocator<UnaryExpr> uns = {allocator};
676 InternedAllocator<BinaryExpr> bins = {allocator};
679 std::vector<VarExpr *> varExprs;
680 std::vector<DerivedExpr *> derivedExprs;
683 template <
typename R,
typename T,
typename... Args>
684 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
685 auto [expr, inserted] =
686 allocator.template alloc<R>(std::forward<Args>(args)...);
688 info[expr].insert(currentInfo);
690 locs[expr].insert(*currentLoc);
698 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
699 std::optional<Location> currentLoc;
703 ConstraintSolver(ConstraintSolver &&) =
delete;
704 ConstraintSolver(
const ConstraintSolver &) =
delete;
705 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
706 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
708 void emitUninferredWidthError(VarExpr *var);
710 LinIneq checkCycles(VarExpr *var, Expr *expr,
711 SmallPtrSetImpl<Expr *> &seenVars,
712 InFlightDiagnostic *reportInto =
nullptr,
713 unsigned indent = 1);
719void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
720 for (
auto *v : varExprs) {
722 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
724 os <<
"- " << *v <<
" unconstrained\n";
729inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
745LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
746 SmallPtrSetImpl<Expr *> &seenVars,
747 InFlightDiagnostic *reportInto,
750 TypeSwitch<Expr *, LinIneq>(expr)
752 [&](
auto *expr) {
return LinIneq(expr->getValue()); })
753 .Case<VarExpr>([&](
auto *expr) {
755 return LinIneq(1, 0);
756 if (!seenVars.insert(expr).second)
763 if (!expr->constraint)
766 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
768 seenVars.erase(expr);
771 .Case<IdExpr>([&](
auto *expr) {
772 return checkCycles(var, expr->arg, seenVars, reportInto,
775 .Case<PowExpr>([&](
auto *expr) {
780 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
781 if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
782 return LinIneq::unsat();
783 return LinIneq(1 << arg.nonrecBias);
785 .Case<AddExpr>([&](
auto *expr) {
787 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
788 checkCycles(var, expr->rhs(), seenVars, reportInto,
791 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
796 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
797 checkCycles(var, expr->rhs(), seenVars, reportInto,
800 .Default([](
auto) {
return LinIneq::unsat(); });
805 if (reportInto && !ineq.sat()) {
806 auto report = [&](Location loc) {
807 auto ¬e = reportInto->attachNote(loc);
808 note <<
"constrained width W >= ";
809 if (ineq.recScale == -1)
811 if (ineq.recScale != 1)
812 note << ineq.recScale;
814 if (ineq.recBias < 0)
815 note <<
"-" << -ineq.recBias;
816 if (ineq.recBias > 0)
817 note <<
"+" << ineq.recBias;
820 auto it = locs.find(expr);
821 if (it != locs.end())
822 for (
auto loc : it->second)
826 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
827 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
837 arg.first = operation(*arg.first);
843 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
844 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
845 if (lhs.first && rhs.first)
846 result.first = operation(*lhs.first, *rhs.first);
848 result.first = lhs.first;
850 result.first = rhs.first;
856 Frame(Expr *expr,
unsigned indent) : expr(expr), indent(indent) {}
869 std::vector<Frame> &worklist) {
871 worklist.emplace_back(expr, 1);
872 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
874 while (!worklist.empty()) {
875 auto &frame = worklist.back();
876 auto indent = frame.indent;
879 if (solution.first && !solution.second)
880 frame.expr->setSolution(*solution.first);
881 solvedExprs[frame.expr] = solution;
885 if (!isa<KnownExpr>(frame.expr)) {
887 llvm::dbgs().indent(indent * 2)
888 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
890 llvm::dbgs().indent(indent * 2) <<
"= Skipped " << *frame.expr;
891 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
900 if (frame.expr->getSolution()) {
902 if (!isa<KnownExpr>(frame.expr))
903 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
904 << *frame.expr->getSolution() <<
"\n";
906 setSolution(
ExprSolution{*frame.expr->getSolution(),
false});
912 if (!isa<KnownExpr>(frame.expr))
913 llvm::dbgs().indent(indent * 2) <<
"- Solving " << *frame.expr <<
"\n";
916 TypeSwitch<Expr *>(frame.expr)
917 .Case<KnownExpr>([&](
auto *expr) {
920 .Case<VarExpr>([&](
auto *expr) {
921 if (solvedExprs.contains(expr->constraint)) {
922 auto solution = solvedExprs[expr->constraint];
929 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
930 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
931 seenVars.erase(expr);
933 if (solution.first && *solution.first < 0)
935 return setSolution(solution);
939 if (!expr->constraint)
944 if (!seenVars.insert(expr).second)
947 worklist.emplace_back(expr->constraint, indent + 1);
948 if (expr->upperBound)
949 worklist.emplace_back(expr->upperBound, indent + 1);
951 .Case<IdExpr>([&](
auto *expr) {
952 if (solvedExprs.contains(expr->arg))
953 return setSolution(solvedExprs[expr->arg]);
954 worklist.emplace_back(expr->arg, indent + 1);
956 .Case<PowExpr>([&](
auto *expr) {
957 if (solvedExprs.contains(expr->arg))
959 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
961 worklist.emplace_back(expr->arg, indent + 1);
963 .Case<AddExpr>([&](
auto *expr) {
964 if (solvedExprs.contains(expr->lhs()) &&
965 solvedExprs.contains(expr->rhs()))
967 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
968 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
970 worklist.emplace_back(expr->lhs(), indent + 1);
971 worklist.emplace_back(expr->rhs(), indent + 1);
973 .Case<MaxExpr>([&](
auto *expr) {
974 if (solvedExprs.contains(expr->lhs()) &&
975 solvedExprs.contains(expr->rhs()))
977 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
978 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
980 worklist.emplace_back(expr->lhs(), indent + 1);
981 worklist.emplace_back(expr->rhs(), indent + 1);
983 .Case<MinExpr>([&](
auto *expr) {
984 if (solvedExprs.contains(expr->lhs()) &&
985 solvedExprs.contains(expr->rhs()))
987 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
988 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
990 worklist.emplace_back(expr->lhs(), indent + 1);
991 worklist.emplace_back(expr->rhs(), indent + 1);
998 return solvedExprs[expr];
1004LogicalResult ConstraintSolver::solve() {
1006 llvm::dbgs() <<
"\n";
1008 dumpConstraints(llvm::dbgs());
1013 llvm::dbgs() <<
"\n";
1014 debugHeader(
"Checking for unbreakable loops") <<
"\n\n";
1016 SmallPtrSet<Expr *, 16> seenVars;
1017 bool anyFailed =
false;
1019 for (
auto *var : varExprs) {
1020 if (!var->constraint)
1022 LLVM_DEBUG(llvm::dbgs()
1023 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1028 seenVars.insert(var);
1029 auto ineq = checkCycles(var, var->constraint, seenVars);
1038 LLVM_DEBUG(llvm::dbgs()
1039 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1047 LLVM_DEBUG(llvm::dbgs()
1048 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1050 for (
auto fieldRef :
info.find(var)->second) {
1053 auto *op = fieldRef.getDefiningOp();
1054 auto diag = op ? op->emitOpError()
1055 : mlir::emitError(fieldRef.getValue().getLoc())
1057 diag <<
"is constrained to be wider than itself";
1060 seenVars.insert(var);
1061 checkCycles(var, var->constraint, seenVars, &diag);
1073 llvm::dbgs() <<
"\n";
1076 std::vector<Frame> worklist;
1077 for (
auto *var : varExprs) {
1079 if (!var->constraint) {
1080 LLVM_DEBUG(llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1081 emitUninferredWidthError(var);
1087 LLVM_DEBUG(llvm::dbgs()
1088 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1089 seenVars.insert(var);
1090 auto solution =
solveExpr(var->constraint, seenVars, worklist);
1092 if (var->upperBound && !var->upperBoundSolution)
1093 var->upperBoundSolution =
1094 solveExpr(var->upperBound, seenVars, worklist).first;
1098 if (solution.first) {
1099 if (*solution.first < 0)
1101 var->setSolution(*solution.first);
1106 if (!solution.first) {
1107 LLVM_DEBUG(llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1108 emitUninferredWidthError(var);
1112 LLVM_DEBUG(llvm::dbgs()
1113 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1114 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1117 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1118 LLVM_DEBUG(llvm::dbgs() <<
" ! Unsatisfiable " << *var
1119 <<
" <= " << var->upperBoundSolution <<
"\n");
1120 emitUninferredWidthError(var);
1126 for (
auto *derived : derivedExprs) {
1127 auto *assigned = derived->assigned;
1128 if (!assigned || !assigned->getSolution()) {
1129 LLVM_DEBUG(llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1130 derived->setSolution(0);
1132 LLVM_DEBUG(llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1133 << assigned->getSolution() <<
"\n");
1134 derived->setSolution(*assigned->getSolution());
1138 return failure(anyFailed);
1143void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1147 auto diag = mlir::emitError(value.getLoc(),
"uninferred width:");
1150 if (isa<BlockArgument>(value)) {
1152 }
else if (
auto *op = value.getDefiningOp()) {
1153 TypeSwitch<Operation *>(op)
1154 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1155 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1156 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1157 .Default([&](
auto) { diag <<
" value"; });
1164 if (!fieldName.empty()) {
1167 diag <<
" \"" << fieldName <<
"\"";
1170 if (!var->constraint) {
1171 diag <<
" is unconstrained";
1172 }
else if (var->getSolution() && var->upperBoundSolution &&
1173 var->getSolution() > var->upperBoundSolution) {
1174 diag <<
" cannot satisfy all width requirements";
1175 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1176 LLVM_DEBUG(llvm::dbgs() << *var->upperBound <<
"\n");
1177 auto loc = locs.find(var->constraint)->second.back();
1178 diag.attachNote(loc) <<
"width is constrained to be at least "
1179 << *var->getSolution() <<
" here:";
1180 loc = locs.find(var->upperBound)->second.back();
1181 diag.attachNote(loc) <<
"width is constrained to be at most "
1182 << *var->upperBoundSolution <<
" here:";
1184 diag <<
" width cannot be determined";
1185 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1186 auto loc = locs.find(var->constraint)->second.back();
1187 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1199class InferenceMapping {
1201 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1203 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1205 LogicalResult map(CircuitOp op);
1206 bool allWidthsKnown(Operation *op);
1207 LogicalResult mapOperation(Operation *op);
1212 void declareVars(Value value,
bool isDerived =
false);
1217 void maximumOfTypes(Value result, Value rhs, Value lhs);
1221 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1225 void constrainTypes(Expr *larger, Expr *smaller,
1226 bool imposeUpperBounds =
false,
bool equal =
false);
1235 Expr *getExpr(Value value)
const;
1238 Expr *getExpr(
FieldRef fieldRef)
const;
1242 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1246 void setExpr(Value value, Expr *expr);
1249 void setExpr(
FieldRef fieldRef, Expr *expr);
1252 bool isModuleSkipped(FModuleOp module)
const {
1253 return skippedModules.count(module);
1257 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1261 ConstraintSolver &solver;
1264 DenseMap<FieldRef, Expr *> opExprs;
1267 SmallPtrSet<Operation *, 16> skippedModules;
1268 bool allModulesSkipped =
true;
1271 SymbolTable &symtbl;
1281 return TypeSwitch<Type, bool>(type)
1282 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1284 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1285 .Default([](
auto) {
return false; });
1288LogicalResult InferenceMapping::map(CircuitOp op) {
1289 LLVM_DEBUG(llvm::dbgs()
1290 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1293 for (
auto module : op.getOps<FModuleOp>())
1294 for (auto arg : module.getArguments()) {
1295 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1299 for (
auto module : op.getOps<FModuleOp>()) {
1302 bool anyUninferred =
false;
1303 for (
auto arg : module.getArguments()) {
1308 module.walk([&](Operation *op) {
1309 for (auto type : op->getResultTypes())
1310 anyUninferred |= hasUninferredWidth(type);
1312 return WalkResult::interrupt();
1313 return WalkResult::advance();
1316 if (!anyUninferred) {
1317 LLVM_DEBUG(llvm::dbgs() <<
"Skipping fully-inferred module '"
1318 << module.getName() <<
"'\n");
1319 skippedModules.insert(module);
1323 allModulesSkipped =
false;
1327 auto result =
module.getBodyBlock()->walk(
1328 [&](Operation *op) { return WalkResult(mapOperation(op)); });
1329 if (result.wasInterrupted())
1336bool InferenceMapping::allWidthsKnown(Operation *op) {
1338 if (isa<PropAssignOp>(op))
1343 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1348 if (isa<FConnectLike, AttachOp>(op))
1352 return llvm::all_of(op->getResults(), [&](
auto result) {
1355 if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1356 if (hasUninferredWidth(type))
1362LogicalResult InferenceMapping::mapOperation(Operation *op) {
1363 if (allWidthsKnown(op))
1367 bool mappingFailed =
false;
1368 solver.setCurrentContextInfo(
1370 solver.setCurrentLocation(op->getLoc());
1371 TypeSwitch<Operation *>(op)
1372 .Case<ConstantOp>([&](
auto op) {
1375 auto v = op.getValue();
1376 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1377 : v.countLeadingZeros());
1380 setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1382 .Case<SpecialConstantOp>([&](
auto op) {
1385 .Case<InvalidValueOp>([&](
auto op) {
1388 declareVars(op.getResult(),
true);
1390 .Case<WireOp, RegOp>([&](
auto op) { declareVars(op.getResult()); })
1391 .Case<RegResetOp>([&](
auto op) {
1396 declareVars(op.getResult());
1399 constrainTypes(op.getResult(), op.getResetValue());
1401 .Case<NodeOp>([&](
auto op) {
1404 op.getResult().getType());
1408 .Case<SubfieldOp>([&](
auto op) {
1409 BundleType bundleType = op.getInput().getType();
1410 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1411 unifyTypes(
FieldRef(op.getResult(), 0),
1412 FieldRef(op.getInput(), fieldID), op.getType());
1414 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1420 .Case<RefSubOp>([&](RefSubOp op) {
1421 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1422 op.getInput().getType().getType())
1423 .Case<FVectorType>([](
auto _) {
return 1; })
1424 .Case<BundleType>([&](
auto type) {
1425 return type.getFieldID(op.getIndex());
1427 unifyTypes(
FieldRef(op.getResult(), 0),
1428 FieldRef(op.getInput(), fieldID), op.getType());
1432 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1433 auto lhs = getExpr(op.getLhs());
1434 auto rhs = getExpr(op.getRhs());
1435 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1436 setExpr(op.getResult(), e);
1438 .Case<MulPrimOp>([&](
auto op) {
1439 auto lhs = getExpr(op.getLhs());
1440 auto rhs = getExpr(op.getRhs());
1441 auto e = solver.add(lhs, rhs);
1442 setExpr(op.getResult(), e);
1444 .Case<DivPrimOp>([&](
auto op) {
1445 auto lhs = getExpr(op.getLhs());
1447 if (op.getType().base().isSigned()) {
1448 e = solver.add(lhs, solver.known(1));
1452 setExpr(op.getResult(), e);
1454 .Case<RemPrimOp>([&](
auto op) {
1455 auto lhs = getExpr(op.getLhs());
1456 auto rhs = getExpr(op.getRhs());
1457 auto e = solver.min(lhs, rhs);
1458 setExpr(op.getResult(), e);
1460 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1461 auto lhs = getExpr(op.getLhs());
1462 auto rhs = getExpr(op.getRhs());
1463 auto e = solver.max(lhs, rhs);
1464 setExpr(op.getResult(), e);
1467 .Case<CatPrimOp>([&](
auto op) {
1468 if (op.getInputs().empty()) {
1469 setExpr(op.getResult(), solver.known(0));
1472 auto result = getExpr(op.getInputs().front());
1473 for (
auto operand : op.getInputs().drop_front()) {
1474 auto operandExpr = getExpr(operand);
1475 result = solver.add(result, operandExpr);
1477 setExpr(op.getResult(), result);
1480 .Case<DShlPrimOp>([&](
auto op) {
1481 auto lhs = getExpr(op.getLhs());
1482 auto rhs = getExpr(op.getRhs());
1483 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1484 setExpr(op.getResult(), e);
1486 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1487 auto e = getExpr(op.getLhs());
1488 setExpr(op.getResult(), e);
1492 .Case<NegPrimOp>([&](
auto op) {
1493 auto input = getExpr(op.getInput());
1494 auto e = solver.add(input, solver.known(1));
1495 setExpr(op.getResult(), e);
1497 .Case<CvtPrimOp>([&](
auto op) {
1498 auto input = getExpr(op.getInput());
1499 auto e = op.getInput().getType().base().isSigned()
1501 : solver.add(input, solver.known(1));
1502 setExpr(op.getResult(), e);
1506 .Case<BitsPrimOp>([&](
auto op) {
1507 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1509 .Case<HeadPrimOp>([&](
auto op) {
1510 setExpr(op.getResult(), solver.known(op.getAmount()));
1512 .Case<TailPrimOp>([&](
auto op) {
1513 auto input = getExpr(op.getInput());
1514 auto e = solver.add(input, solver.known(-op.getAmount()));
1515 setExpr(op.getResult(), e);
1517 .Case<PadPrimOp>([&](
auto op) {
1518 auto input = getExpr(op.getInput());
1519 auto e = solver.max(input, solver.known(op.getAmount()));
1520 setExpr(op.getResult(), e);
1522 .Case<ShlPrimOp>([&](
auto op) {
1523 auto input = getExpr(op.getInput());
1524 auto e = solver.add(input, solver.known(op.getAmount()));
1525 setExpr(op.getResult(), e);
1527 .Case<ShrPrimOp>([&](
auto op) {
1528 auto input = getExpr(op.getInput());
1530 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1531 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1532 solver.known(minWidth));
1533 setExpr(op.getResult(), e);
1537 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1538 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1539 .Case<mlir::UnrealizedConversionCastOp>(
1540 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1544 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1545 AsClockPrimOp, AsAsyncResetPrimOp, AsResetPrimOp, AndRPrimOp,
1546 OrRPrimOp, XorRPrimOp>([&](
auto op) {
1547 auto width = op.getType().getBitWidthOrSentinel();
1548 assert(width > 0 &&
"width should have been checked by verifier");
1549 setExpr(op.getResult(), solver.known(width));
1551 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1552 auto *sel = getExpr(op.getSel());
1553 constrainTypes(solver.known(1), sel,
true);
1554 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1556 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1557 auto *sel = getExpr(op.getSel());
1558 constrainTypes(solver.known(2), sel,
true);
1559 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1560 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1561 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1564 .Case<ConnectOp, MatchingConnectOp>(
1565 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1566 .Case<RefDefineOp>([&](
auto op) {
1569 constrainTypes(op.getDest(), op.getSrc(),
true);
1571 .Case<AttachOp>([&](
auto op) {
1575 if (op.getAttached().empty())
1577 auto prev = op.getAttached()[0];
1578 for (
auto operand : op.getAttached().drop_front()) {
1579 auto e1 = getExpr(prev);
1580 auto e2 = getExpr(operand);
1581 constrainTypes(e1, e2,
true);
1582 constrainTypes(e2, e1,
true);
1588 .Case<AssertOp, AssumeOp, CoverOp, DomainDefineOp, FFlushOp, PrintFOp,
1589 SkipOp, StopOp, UnclockedAssumeIntrinsicOp, WhenOp>([&](
auto) {})
1592 .Case<InstanceOp>([&](
auto op) {
1593 auto refdModule = op.getReferencedOperation(symtbl);
1594 auto module = dyn_cast<FModuleOp>(&*refdModule);
1596 auto diag = mlir::emitError(op.getLoc());
1597 diag <<
"extern module `" << op.getModuleName()
1598 <<
"` has ports of uninferred width";
1600 auto fml = cast<FModuleLike>(&*refdModule);
1601 auto ports = fml.getPorts();
1602 for (
auto &port : ports) {
1604 if (baseType && baseType.hasUninferredWidth()) {
1605 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1606 if (!baseType.isGround())
1611 diag.attachNote(op.getLoc())
1612 <<
"Only non-extern FIRRTL modules may contain unspecified "
1613 "widths to be inferred automatically.";
1614 diag.attachNote(refdModule->getLoc())
1615 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1616 mappingFailed =
true;
1623 for (
auto [result, arg] :
1624 llvm::zip(op->getResults(), module.getArguments()))
1625 unifyTypes({result, 0}, {arg, 0},
1626 type_cast<FIRRTLType>(result.getType()));
1630 .Case<MemOp>([&](MemOp op) {
1632 unsigned nonDebugPort = 0;
1633 for (
const auto &result :
llvm::enumerate(op.getResults())) {
1634 declareVars(result.value());
1635 if (!type_isa<RefType>(result.value().getType()))
1636 nonDebugPort = result.index();
1641 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1642 static const unsigned indices[] = {3, 5};
1643 static const unsigned debug[] = {0};
1645 case MemOp::PortKind::Read:
1646 case MemOp::PortKind::Write:
1647 return ArrayRef<unsigned>(indices, 1);
1648 case MemOp::PortKind::ReadWrite:
1649 return ArrayRef<unsigned>(indices);
1650 case MemOp::PortKind::Debug:
1651 return ArrayRef<unsigned>(
debug);
1653 llvm_unreachable(
"Imposible PortKind");
1660 unsigned firstFieldIndex =
1661 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1663 op.getResult(nonDebugPort),
1664 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1665 .getFieldID(firstFieldIndex));
1666 LLVM_DEBUG(llvm::dbgs() <<
"Adjusting memory port variables:\n");
1669 auto dataType = op.getDataType();
1670 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1671 auto result = op.getResult(i);
1672 if (type_isa<RefType>(result.getType())) {
1676 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1681 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1682 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1684 firstData, dataType);
1688 .Case<RefSendOp>([&](
auto op) {
1689 declareVars(op.getResult());
1690 constrainTypes(op.getResult(), op.getBase(),
true);
1692 .Case<RefResolveOp>([&](
auto op) {
1693 declareVars(op.getResult());
1694 constrainTypes(op.getResult(), op.getRef(),
true);
1696 .Case<RefCastOp>([&](
auto op) {
1697 declareVars(op.getResult());
1698 constrainTypes(op.getResult(), op.getInput(),
true);
1700 .Case<RWProbeOp>([&](
auto op) {
1701 auto ist = irn.lookup(op.getTarget());
1703 op->emitError(
"target of rwprobe could not be resolved");
1704 mappingFailed =
true;
1709 op->emitError(
"target of rwprobe resolved to unsupported target");
1710 mappingFailed =
true;
1714 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1715 unifyTypes(
FieldRef(op.getResult(), 0),
1716 FieldRef(ref.getValue(), newFID), op.getType());
1718 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1719 for (Value result : op.getResults()) {
1720 auto ty = result.getType();
1721 if (type_isa<FIRRTLType>(ty))
1722 declareVars(result);
1725 .Default([&](
auto op) {
1726 op->emitOpError(
"not supported in width inference");
1727 mappingFailed =
true;
1731 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1735 return failure(mappingFailed);
1740void InferenceMapping::declareVars(Value value,
bool isDerived) {
1743 unsigned fieldID = 0;
1745 auto width = type.getBitWidthOrSentinel();
1748 }
else if (width == -1) {
1751 solver.setCurrentContextInfo(field);
1753 setExpr(field, solver.derived());
1755 setExpr(field, solver.var());
1757 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1760 for (
auto &element : bundleType)
1761 declare(element.type);
1762 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1764 auto save = fieldID;
1765 declare(vecType.getElementType());
1767 fieldID = save + vecType.getMaxFieldID();
1769 llvm_unreachable(
"Unknown type inside a bundle!");
1779void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1783 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1785 for (
auto &element : bundleType.getElements())
1786 maximize(element.type);
1787 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1789 auto save = fieldID;
1791 if (vecType.getNumElements() > 0)
1792 maximize(vecType.getElementType());
1793 fieldID = save + vecType.getMaxFieldID();
1794 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1795 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1797 setExpr(
FieldRef(result, fieldID), e);
1799 }
else if (type.isGround()) {
1800 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1802 setExpr(
FieldRef(result, fieldID), e);
1805 llvm_unreachable(
"Unknown type inside a bundle!");
1820void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1827 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1829 for (
auto &element : bundleType.getElements()) {
1831 constrain(element.type, smaller, larger);
1833 constrain(element.type, larger, smaller);
1835 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1837 auto save = fieldID;
1839 if (vecType.getNumElements() > 0) {
1840 constrain(vecType.getElementType(), larger, smaller);
1842 fieldID = save + vecType.getMaxFieldID();
1843 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1844 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1845 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
1847 }
else if (type.isGround()) {
1849 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1850 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
1853 llvm_unreachable(
"Unknown type inside a bundle!");
1858 constrain(type, larger, smaller);
1863void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1864 bool imposeUpperBounds,
bool equal) {
1865 assert(larger &&
"Larger expression should be specified");
1866 assert(smaller &&
"Smaller expression should be specified");
1872 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1873 largerDerived->assigned = smaller;
1874 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1875 << *smaller <<
"\n");
1878 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1879 smallerDerived->assigned = larger;
1880 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1881 << *larger <<
"\n");
1887 if (
auto *largerVar = dyn_cast<VarExpr>(larger)) {
1888 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1889 LLVM_DEBUG(llvm::dbgs()
1890 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
1897 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1898 LLVM_DEBUG(llvm::dbgs()
1899 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
1909 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1910 if (imposeUpperBounds || equal) {
1911 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1912 LLVM_DEBUG(llvm::dbgs()
1913 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
1933 LLVM_DEBUG(llvm::dbgs()
1934 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1937 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1938 solver.addGeqConstraint(var, solver.known(0));
1939 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1941 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1943 for (
auto &element : bundleType) {
1944 unify(element.type);
1946 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1948 auto save = fieldID;
1950 if (vecType.getNumElements() > 0) {
1951 unify(vecType.getElementType());
1953 fieldID = save + vecType.getMaxFieldID();
1954 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1957 LLVM_DEBUG(llvm::dbgs()
1958 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1960 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1963 llvm_unreachable(
"Unknown type inside a bundle!");
1971Expr *InferenceMapping::getExpr(Value value)
const {
1974 return getExpr(
FieldRef(value, 0));
1978Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
1979 auto *expr = getExprOrNull(fieldRef);
1980 assert(expr &&
"constraint expr should have been constructed for value");
1984Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
1985 auto it = opExprs.find(fieldRef);
1986 if (it != opExprs.end())
1993 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
1996 return solver.known(width);
2000void InferenceMapping::setExpr(Value value, Expr *expr) {
2007void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
2009 llvm::dbgs() <<
"Expr " << *expr <<
" for " << fieldRef.
getValue();
2011 llvm::dbgs() <<
" '" <<
getFieldName(fieldRef).first <<
"'";
2013 if (fieldName.second)
2014 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2015 llvm::dbgs() <<
"\n";
2017 opExprs[fieldRef] = expr;
2027class InferenceTypeUpdate {
2029 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2031 LogicalResult update(CircuitOp op);
2032 FailureOr<bool> updateOperation(Operation *op);
2033 FailureOr<bool> updateValue(Value value);
2037 const InferenceMapping &mapping;
2043LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2045 llvm::dbgs() <<
"\n";
2048 return mlir::failableParallelForEach(
2049 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2052 if (mapping.isModuleSkipped(op))
2054 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2055 if (failed(updateOperation(op)))
2056 return WalkResult::interrupt();
2057 return WalkResult::advance();
2058 }).wasInterrupted();
2059 return failure(isFailed);
2064FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2065 bool anyChanged =
false;
2067 for (Value v : op->getResults()) {
2068 auto result = updateValue(v);
2071 anyChanged |= *result;
2077 if (
auto con = dyn_cast<ConnectOp>(op)) {
2078 auto lhs = con.getDest();
2079 auto rhs = con.getSrc();
2080 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2081 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2084 if (!lhsType || !rhsType)
2087 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2088 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2089 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2090 OpBuilder builder(op);
2091 auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2092 rhsWidth - lhsWidth);
2093 if (type_isa<SIntType>(rhsType))
2095 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2097 LLVM_DEBUG(llvm::dbgs()
2098 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
2099 con->replaceUsesOfWith(con.getSrc(), trunc);
2105 if (
auto module = dyn_cast<FModuleOp>(op)) {
2107 bool argsChanged =
false;
2108 SmallVector<Attribute> argTypes;
2109 argTypes.reserve(module.getNumPorts());
2110 for (
auto arg : module.getArguments()) {
2111 auto result = updateValue(arg);
2114 argsChanged |= *result;
2115 argTypes.push_back(TypeAttr::get(arg.getType()));
2120 module.setPortTypesAttr(ArrayAttr::get(module.getContext(), argTypes));
2129 auto *
context = type.getContext();
2131 .
Case<UIntType>([&](
auto type) {
2134 .Case<SIntType>([&](
auto type) {
2137 .Case<AnalogType>([&](
auto type) {
2140 .Default([&](
auto type) {
return type; });
2144FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2146 auto type = type_dyn_cast<FIRRTLType>(value.getType());
2157 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2158 SmallVector<Type, 2> types;
2160 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2161 op->getAttrDictionary(), op->getPropertiesStorage(),
2162 op->getRegions(), types);
2166 assert(types.size() == op->getNumResults());
2167 for (
auto [result, type] :
llvm::zip(op->getResults(), types)) {
2168 LLVM_DEBUG(llvm::dbgs()
2169 <<
"Inferring " << result <<
" as " << type <<
"\n");
2170 result.setType(type);
2176 auto *
context = type.getContext();
2177 unsigned fieldID = 0;
2180 auto width = type.getBitWidthOrSentinel();
2192 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2195 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2196 for (
auto &element : bundleType) {
2197 auto updatedBase = updateBase(element.type);
2200 elements.emplace_back(element.name, element.isFlip, updatedBase);
2202 return BundleType::get(
context, elements, bundleType.isConst());
2204 if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2206 auto save = fieldID;
2209 if (vecType.getNumElements() > 0) {
2210 auto updatedBase = updateBase(vecType.getElementType());
2213 auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2215 fieldID = save + vecType.getMaxFieldID();
2221 llvm_unreachable(
"Unknown type inside a bundle!");
2228 LLVM_DEBUG(llvm::dbgs() <<
"Update " << value <<
" to " << newType <<
"\n");
2229 value.setType(newType);
2234 if (
auto op = value.getDefiningOp<ConstantOp>()) {
2235 auto k = op.getValue();
2236 auto bitwidth = op.getType().getBitWidthOrSentinel();
2237 if (k.getBitWidth() >
unsigned(bitwidth))
2238 k = k.trunc(bitwidth);
2239 op->setAttr(
"value", IntegerAttr::get(op.getContext(), k));
2242 return newType != type;
2248 assert(type.isGround() &&
"Can only pass in ground types.");
2251 Expr *expr = mapping.getExprOrNull(fieldRef);
2252 if (!expr || !expr->getSolution()) {
2257 mlir::emitError(value.getLoc(),
"width should have been inferred");
2260 int32_t solution = *expr->getSolution();
2270class InferWidthsPass
2271 :
public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2272 void runOnOperation()
override;
2276void InferWidthsPass::runOnOperation() {
2278 ConstraintSolver solver;
2279 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2280 getAnalysis<hw::InnerSymbolTableCollection>());
2281 if (failed(mapping.map(getOperation())))
2282 return signalPassFailure();
2285 if (mapping.areAllModulesSkipped())
2286 return markAllAnalysesPreserved();
2289 if (failed(solver.solve()))
2290 return signalPassFailure();
2293 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2294 return signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static void print(TypedAttr val, llvm::raw_ostream &os)
static unsigned getFieldID(BundleType type, unsigned index)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
bool isConst() const
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
This class represents a collection of InnerSymbolTable's.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
static bool operator==(const ModulePort &a, const ModulePort &b)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
llvm::hash_code hash_value(const DenseSet< T > &set)
llvm::hash_code hash_value(const T &e)
This class represents the namespace in which InnerRef's can be resolved.