15 #include "mlir/Pass/Pass.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/Threading.h"
25 #include "llvm/ADT/APSInt.h"
26 #include "llvm/ADT/DenseSet.h"
27 #include "llvm/ADT/Hashing.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
32 #define DEBUG_TYPE "infer-widths"
36 #define GEN_PASS_DEF_INFERWIDTHS
37 #include "circt/Dialect/FIRRTL/Passes.h.inc"
41 using mlir::InferTypeOpInterface;
42 using mlir::WalkOrder;
44 using namespace circt;
45 using namespace firrtl;
53 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
56 if (!basetype.hasUninferredWidth())
59 if (basetype.isGround())
60 diag.attachNote() <<
"Field: \"" << str <<
"\"";
61 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
63 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
64 for (
auto &elem : bundleType.getElements())
70 uint64_t convertedFieldID = 0;
72 auto curFID = fieldID;
77 if (isa<FVectorType>(curFType))
80 convertedFieldID += curFID - subID;
85 return convertedFieldID;
97 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
99 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
106 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
109 return e.hash_value();
114 #define EXPR_NAMES(x) \
115 Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
116 #define EXPR_KINDS EXPR_NAMES()
117 #define EXPR_CLASSES EXPR_NAMES(Expr)
124 void print(llvm::raw_ostream &os)
const;
126 std::optional<int32_t> getSolution()
const {
132 void setSolution(int32_t solution) {
134 this->solution = solution;
137 Kind getKind()
const {
return kind; }
140 Expr(Kind kind) : kind(kind) {}
146 bool hasSolution =
false;
150 template <
class DerivedT, Expr::Kind DerivedKind>
151 struct ExprBase :
public Expr {
152 ExprBase() : Expr(DerivedKind) {}
153 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
155 if (
auto otherSame = dyn_cast<DerivedT>(other))
156 return *
static_cast<DerivedT *
>(
this) == otherSame;
162 struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
163 void print(llvm::raw_ostream &os)
const {
166 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
171 Expr *constraint =
nullptr;
174 Expr *upperBound =
nullptr;
175 std::optional<int32_t> upperBoundSolution;
182 struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
183 void print(llvm::raw_ostream &os)
const {
186 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
190 Expr *assigned =
nullptr;
207 struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
208 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
209 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
211 return getKind() == other.getKind() && arg == other.arg;
222 struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
223 KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
224 void print(llvm::raw_ostream &os)
const { os << *getSolution(); }
225 bool operator==(
const KnownExpr &other)
const {
226 return *getSolution() == *other.getSolution();
231 int32_t getValue()
const {
return *getSolution(); }
236 struct UnaryExpr :
public Expr {
237 bool operator==(
const UnaryExpr &other)
const {
238 return getKind() == other.getKind() && arg == other.arg;
248 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
252 template <
class DerivedT, Expr::Kind DerivedKind>
253 struct UnaryExprBase :
public UnaryExpr {
254 template <
typename... Args>
255 UnaryExprBase(Args &&...args)
256 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
257 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
261 struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
262 using UnaryExprBase::UnaryExprBase;
263 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
268 struct BinaryExpr :
public Expr {
269 bool operator==(
const BinaryExpr &other)
const {
270 return getKind() == other.getKind() && lhs() == other.lhs() &&
271 rhs() == other.rhs();
276 Expr *lhs()
const {
return args[0]; }
277 Expr *rhs()
const {
return args[1]; }
283 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
290 template <
class DerivedT, Expr::Kind DerivedKind>
291 struct BinaryExprBase :
public BinaryExpr {
292 template <
typename... Args>
293 BinaryExprBase(Args &&...args)
294 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
295 static bool classof(
const Expr *e) {
return e->getKind() == DerivedKind; }
299 struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
300 using BinaryExprBase::BinaryExprBase;
301 void print(llvm::raw_ostream &os)
const {
302 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
307 struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
308 using BinaryExprBase::BinaryExprBase;
309 void print(llvm::raw_ostream &os)
const {
310 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
315 struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
316 using BinaryExprBase::BinaryExprBase;
317 void print(llvm::raw_ostream &os)
const {
318 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
322 void Expr::print(llvm::raw_ostream &os)
const {
324 [&](
auto *e) { e->print(os); });
337 template <
typename T>
338 struct InternedSlotInfo : DenseMapInfo<T *> {
339 static T *getEmptyKey() {
341 return static_cast<T *
>(pointer);
343 static T *getTombstoneKey() {
345 return static_cast<T *
>(pointer);
347 static unsigned getHashValue(
const T *val) {
return mlir::hash_value(*val); }
348 static bool isEqual(
const T *lhs,
const T *rhs) {
349 auto empty = getEmptyKey();
350 auto tombstone = getTombstoneKey();
351 if (lhs ==
empty || rhs ==
empty || lhs == tombstone || rhs == tombstone)
359 template <
typename T,
typename std::enable_if_t<
360 std::is_trivially_destructible<T>::value,
int> = 0>
361 class InternedAllocator {
362 llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
363 llvm::BumpPtrAllocator &allocator;
366 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
371 template <
typename R = T,
typename... Args>
372 std::pair<R *, bool> alloc(Args &&...args) {
373 auto stackValue = R(std::forward<Args>(args)...);
374 auto *stackSlot = &stackValue;
375 auto it = interned.find(stackSlot);
376 if (it != interned.end())
377 return std::make_pair(
static_cast<R *
>(*it),
false);
378 auto heapValue =
new (allocator) R(std::move(stackValue));
379 interned.insert(heapValue);
380 return std::make_pair(heapValue,
true);
386 template <
typename T,
typename std::enable_if_t<
387 std::is_trivially_destructible<T>::value,
int> = 0>
389 llvm::BumpPtrAllocator &allocator;
392 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
396 template <
typename R = T,
typename... Args>
397 R *alloc(Args &&...args) {
398 return new (allocator) R(std::forward<Args>(args)...);
429 int32_t recScale = 0;
431 int32_t nonrecBias = 0;
435 static LinIneq unsat() {
return LinIneq(
true); }
438 explicit LinIneq(
bool failed =
false) : failed(failed) {}
441 explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
444 explicit LinIneq(int32_t scale, int32_t bias) {
455 explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
459 this->recScale = recScale;
460 this->recBias = recBias;
461 this->nonrecBias = nonrecBias;
463 this->nonrecBias = std::max(recBias, nonrecBias);
474 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
475 return LinIneq(std::max(lhs.recScale, rhs.recScale),
476 std::max(lhs.recBias, rhs.recBias),
477 std::max(lhs.nonrecBias, rhs.nonrecBias),
478 lhs.failed || rhs.failed);
534 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
537 auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
538 auto enable2 = lhs.recScale > 0;
539 auto enable3 = rhs.recScale > 0;
540 auto scale1 = lhs.recScale + rhs.recScale;
541 auto scale2 = lhs.recScale;
542 auto scale3 = rhs.recScale;
543 auto bias1 = lhs.recBias + rhs.recBias;
544 auto bias2 = lhs.recBias + rhs.nonrecBias;
545 auto bias3 = rhs.recBias + lhs.nonrecBias;
546 auto maxScale = std::max(scale1, std::max(scale2, scale3));
550 std::optional<int32_t> maxBias;
551 if (enable1 && scale1 == maxScale)
553 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
555 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
560 auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias;
561 auto failed = lhs.failed || rhs.failed;
562 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
563 return LinIneq(scale1, bias1, nonrecBias, failed);
564 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
565 return LinIneq(scale2, bias2, nonrecBias, failed);
566 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
567 return LinIneq(scale3, bias3, nonrecBias, failed);
568 return LinIneq(0, 0, nonrecBias, failed);
581 if (recScale == 1 && recBias > 0)
587 void print(llvm::raw_ostream &os)
const {
589 bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
596 os << recScale <<
"*";
602 os <<
" - " << -recBias;
604 os <<
" + " << recBias;
612 if (nonrecBias != 0) {
629 class ConstraintSolver {
631 ConstraintSolver() =
default;
634 auto *v = vars.alloc();
635 varExprs.push_back(v);
637 info[v].insert(currentInfo);
639 locs[v].insert(*currentLoc);
642 DerivedExpr *derived() {
643 auto *d = derivs.alloc();
644 derivedExprs.push_back(d);
647 KnownExpr *known(int32_t value) {
return alloc<KnownExpr>(knowns, value); }
648 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
649 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
650 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
651 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
652 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
656 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
658 lhs->constraint = max(lhs->constraint, rhs);
660 lhs->constraint = id(rhs);
661 return lhs->constraint;
666 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
668 lhs->upperBound = min(lhs->upperBound, rhs);
670 lhs->upperBound = id(rhs);
671 return lhs->upperBound;
674 void dumpConstraints(llvm::raw_ostream &os);
675 LogicalResult solve();
677 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
678 const ContextInfo &getContextInfo()
const {
return info; }
679 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
680 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
684 llvm::BumpPtrAllocator allocator;
685 Allocator<VarExpr> vars = {allocator};
686 Allocator<DerivedExpr> derivs = {allocator};
687 InternedAllocator<KnownExpr> knowns = {allocator};
688 InternedAllocator<IdExpr> ids = {allocator};
689 InternedAllocator<UnaryExpr> uns = {allocator};
690 InternedAllocator<BinaryExpr> bins = {allocator};
693 std::vector<VarExpr *> varExprs;
694 std::vector<DerivedExpr *> derivedExprs;
697 template <
typename R,
typename T,
typename... Args>
698 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
699 auto [expr, inserted] =
700 allocator.template alloc<R>(std::forward<Args>(args)...);
702 info[expr].insert(currentInfo);
704 locs[expr].insert(*currentLoc);
712 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
713 std::optional<Location> currentLoc;
717 ConstraintSolver(ConstraintSolver &&) =
delete;
718 ConstraintSolver(
const ConstraintSolver &) =
delete;
719 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
720 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
722 void emitUninferredWidthError(VarExpr *var);
724 LinIneq checkCycles(VarExpr *var, Expr *expr,
725 SmallPtrSetImpl<Expr *> &seenVars,
726 InFlightDiagnostic *reportInto =
nullptr,
727 unsigned indent = 1);
733 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
734 for (
auto *v : varExprs) {
736 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
738 os <<
"- " << *v <<
" unconstrained\n";
743 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
759 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
760 SmallPtrSetImpl<Expr *> &seenVars,
761 InFlightDiagnostic *reportInto,
764 TypeSwitch<Expr *, LinIneq>(expr)
766 [&](
auto *expr) {
return LinIneq(expr->getValue()); })
767 .Case<VarExpr>([&](
auto *expr) {
769 return LinIneq(1, 0);
770 if (!seenVars.insert(expr).second)
777 if (!expr->constraint)
780 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
782 seenVars.erase(expr);
785 .Case<IdExpr>([&](
auto *expr) {
786 return checkCycles(var, expr->arg, seenVars, reportInto,
789 .Case<PowExpr>([&](
auto *expr) {
794 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
795 if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
796 return LinIneq::unsat();
797 return LinIneq(1 << arg.nonrecBias);
799 .Case<AddExpr>([&](
auto *expr) {
801 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
802 checkCycles(var, expr->rhs(), seenVars, reportInto,
805 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
810 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
811 checkCycles(var, expr->rhs(), seenVars, reportInto,
814 .Default([](
auto) {
return LinIneq::unsat(); });
819 if (reportInto && !ineq.sat()) {
820 auto report = [&](Location loc) {
821 auto ¬e = reportInto->attachNote(loc);
822 note <<
"constrained width W >= ";
823 if (ineq.recScale == -1)
825 if (ineq.recScale != 1)
826 note << ineq.recScale;
828 if (ineq.recBias < 0)
829 note <<
"-" << -ineq.recBias;
830 if (ineq.recBias > 0)
831 note <<
"+" << ineq.recBias;
834 auto it = locs.find(expr);
835 if (it != locs.end())
836 for (
auto loc : it->second)
840 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
841 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
851 arg.first = operation(*arg.first);
857 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
858 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
859 if (lhs.first && rhs.first)
860 result.first = operation(*lhs.first, *rhs.first);
862 result.first = lhs.first;
864 result.first = rhs.first;
870 Frame(Expr *expr,
unsigned indent) : expr(expr), indent(indent) {}
883 std::vector<Frame> &worklist) {
885 worklist.emplace_back(expr, 1);
886 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
888 while (!worklist.empty()) {
889 auto &frame = worklist.back();
890 auto indent = frame.indent;
893 if (solution.first && !solution.second)
894 frame.expr->setSolution(*solution.first);
895 solvedExprs[frame.expr] = solution;
899 if (!isa<KnownExpr>(frame.expr)) {
901 llvm::dbgs().indent(indent * 2)
902 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
904 llvm::dbgs().indent(indent * 2) <<
"= Skipped " << *frame.expr;
905 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
914 if (frame.expr->getSolution()) {
916 if (!isa<KnownExpr>(frame.expr))
917 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
918 << *frame.expr->getSolution() <<
"\n";
920 setSolution(
ExprSolution{*frame.expr->getSolution(),
false});
926 if (!isa<KnownExpr>(frame.expr))
927 llvm::dbgs().indent(indent * 2) <<
"- Solving " << *frame.expr <<
"\n";
930 TypeSwitch<Expr *>(frame.expr)
931 .Case<KnownExpr>([&](
auto *expr) {
934 .Case<VarExpr>([&](
auto *expr) {
935 if (solvedExprs.contains(expr->constraint)) {
936 auto solution = solvedExprs[expr->constraint];
943 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
944 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
945 seenVars.erase(expr);
947 if (solution.first && *solution.first < 0)
949 return setSolution(solution);
953 if (!expr->constraint)
958 if (!seenVars.insert(expr).second)
961 worklist.emplace_back(expr->constraint, indent + 1);
962 if (expr->upperBound)
963 worklist.emplace_back(expr->upperBound, indent + 1);
965 .Case<IdExpr>([&](
auto *expr) {
966 if (solvedExprs.contains(expr->arg))
967 return setSolution(solvedExprs[expr->arg]);
968 worklist.emplace_back(expr->arg, indent + 1);
970 .Case<PowExpr>([&](
auto *expr) {
971 if (solvedExprs.contains(expr->arg))
973 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
975 worklist.emplace_back(expr->arg, indent + 1);
977 .Case<AddExpr>([&](
auto *expr) {
978 if (solvedExprs.contains(expr->lhs()) &&
979 solvedExprs.contains(expr->rhs()))
981 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
982 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
984 worklist.emplace_back(expr->lhs(), indent + 1);
985 worklist.emplace_back(expr->rhs(), indent + 1);
987 .Case<MaxExpr>([&](
auto *expr) {
988 if (solvedExprs.contains(expr->lhs()) &&
989 solvedExprs.contains(expr->rhs()))
991 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
992 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
994 worklist.emplace_back(expr->lhs(), indent + 1);
995 worklist.emplace_back(expr->rhs(), indent + 1);
997 .Case<MinExpr>([&](
auto *expr) {
998 if (solvedExprs.contains(expr->lhs()) &&
999 solvedExprs.contains(expr->rhs()))
1001 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
1002 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
1004 worklist.emplace_back(expr->lhs(), indent + 1);
1005 worklist.emplace_back(expr->rhs(), indent + 1);
1007 .Default([&](
auto) {
1012 return solvedExprs[expr];
1018 LogicalResult ConstraintSolver::solve() {
1020 llvm::dbgs() <<
"\n";
1022 dumpConstraints(llvm::dbgs());
1027 llvm::dbgs() <<
"\n";
1028 debugHeader(
"Checking for unbreakable loops") <<
"\n\n";
1030 SmallPtrSet<Expr *, 16> seenVars;
1031 bool anyFailed =
false;
1033 for (
auto *var : varExprs) {
1034 if (!var->constraint)
1036 LLVM_DEBUG(llvm::dbgs()
1037 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1042 seenVars.insert(var);
1043 auto ineq = checkCycles(var, var->constraint, seenVars);
1052 LLVM_DEBUG(llvm::dbgs()
1053 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1061 LLVM_DEBUG(llvm::dbgs()
1062 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1064 for (
auto fieldRef : info.find(var)->second) {
1067 auto *op = fieldRef.getDefiningOp();
1068 auto diag = op ? op->emitOpError()
1069 : mlir::emitError(fieldRef.getValue().getLoc())
1071 diag <<
"is constrained to be wider than itself";
1074 seenVars.insert(var);
1075 checkCycles(var, var->constraint, seenVars, &diag);
1087 llvm::dbgs() <<
"\n";
1090 std::vector<Frame> worklist;
1091 for (
auto *var : varExprs) {
1093 if (!var->constraint) {
1094 LLVM_DEBUG(llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1095 emitUninferredWidthError(var);
1101 LLVM_DEBUG(llvm::dbgs()
1102 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1103 seenVars.insert(var);
1104 auto solution =
solveExpr(var->constraint, seenVars, worklist);
1106 if (var->upperBound && !var->upperBoundSolution)
1107 var->upperBoundSolution =
1108 solveExpr(var->upperBound, seenVars, worklist).first;
1112 if (solution.first) {
1113 if (*solution.first < 0)
1115 var->setSolution(*solution.first);
1120 if (!solution.first) {
1121 LLVM_DEBUG(llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1122 emitUninferredWidthError(var);
1126 LLVM_DEBUG(llvm::dbgs()
1127 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1128 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1131 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1132 LLVM_DEBUG(llvm::dbgs() <<
" ! Unsatisfiable " << *var
1133 <<
" <= " << var->upperBoundSolution <<
"\n");
1134 emitUninferredWidthError(var);
1140 for (
auto *derived : derivedExprs) {
1141 auto *assigned = derived->assigned;
1142 if (!assigned || !assigned->getSolution()) {
1143 LLVM_DEBUG(llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1144 derived->setSolution(0);
1146 LLVM_DEBUG(llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1147 << assigned->getSolution() <<
"\n");
1148 derived->setSolution(*assigned->getSolution());
1152 return failure(anyFailed);
1157 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1158 FieldRef fieldRef = info.find(var)->second.back();
1161 auto diag = mlir::emitError(value.getLoc(),
"uninferred width:");
1164 if (isa<BlockArgument>(value)) {
1166 }
else if (
auto *op = value.getDefiningOp()) {
1167 TypeSwitch<Operation *>(op)
1168 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1169 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1170 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1171 .Default([&](
auto) { diag <<
" value"; });
1178 if (!fieldName.empty()) {
1181 diag <<
" \"" << fieldName <<
"\"";
1184 if (!var->constraint) {
1185 diag <<
" is unconstrained";
1186 }
else if (var->getSolution() && var->upperBoundSolution &&
1187 var->getSolution() > var->upperBoundSolution) {
1188 diag <<
" cannot satisfy all width requirements";
1189 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1190 LLVM_DEBUG(llvm::dbgs() << *var->upperBound <<
"\n");
1191 auto loc = locs.find(var->constraint)->second.back();
1192 diag.attachNote(loc) <<
"width is constrained to be at least "
1193 << *var->getSolution() <<
" here:";
1194 loc = locs.find(var->upperBound)->second.back();
1195 diag.attachNote(loc) <<
"width is constrained to be at most "
1196 << *var->upperBoundSolution <<
" here:";
1198 diag <<
" width cannot be determined";
1199 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1200 auto loc = locs.find(var->constraint)->second.back();
1201 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1213 class InferenceMapping {
1215 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1216 hw::InnerSymbolTableCollection &istc)
1217 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1219 LogicalResult map(CircuitOp op);
1220 bool allWidthsKnown(Operation *op);
1221 LogicalResult mapOperation(Operation *op);
1226 void declareVars(Value value,
bool isDerived =
false);
1231 void maximumOfTypes(Value result, Value rhs, Value lhs);
1235 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1239 void constrainTypes(Expr *larger, Expr *smaller,
1240 bool imposeUpperBounds =
false,
bool equal =
false);
1249 Expr *getExpr(Value value)
const;
1252 Expr *getExpr(
FieldRef fieldRef)
const;
1256 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1260 void setExpr(Value value, Expr *expr);
1263 void setExpr(
FieldRef fieldRef, Expr *expr);
1266 bool isModuleSkipped(FModuleOp module)
const {
1267 return skippedModules.count(module);
1271 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1275 ConstraintSolver &solver;
1278 DenseMap<FieldRef, Expr *> opExprs;
1281 SmallPtrSet<Operation *, 16> skippedModules;
1282 bool allModulesSkipped =
true;
1285 SymbolTable &symtbl;
1288 hw::InnerRefNamespace irn;
1295 return TypeSwitch<Type, bool>(type)
1296 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1298 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1299 .Default([](
auto) {
return false; });
1302 LogicalResult InferenceMapping::map(CircuitOp op) {
1303 LLVM_DEBUG(llvm::dbgs()
1304 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1307 for (
auto module : op.getOps<FModuleOp>())
1308 for (
auto arg : module.getArguments()) {
1309 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1313 for (
auto module : op.getOps<FModuleOp>()) {
1316 bool anyUninferred =
false;
1317 for (
auto arg : module.getArguments()) {
1322 module.walk([&](Operation *op) {
1323 for (
auto type : op->getResultTypes())
1326 return WalkResult::interrupt();
1327 return WalkResult::advance();
1330 if (!anyUninferred) {
1331 LLVM_DEBUG(llvm::dbgs() <<
"Skipping fully-inferred module '"
1332 << module.getName() <<
"'\n");
1333 skippedModules.insert(module);
1337 allModulesSkipped =
false;
1341 auto result = module.getBodyBlock()->walk(
1342 [&](Operation *op) {
return WalkResult(mapOperation(op)); });
1343 if (result.wasInterrupted())
1350 bool InferenceMapping::allWidthsKnown(Operation *op) {
1352 if (isa<PropAssignOp>(op))
1357 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1362 if (isa<FConnectLike, AttachOp>(op))
1366 return llvm::all_of(op->getResults(), [&](
auto result) {
1369 if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1370 if (hasUninferredWidth(type))
1376 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1377 if (allWidthsKnown(op))
1381 bool mappingFailed =
false;
1382 solver.setCurrentContextInfo(
1384 solver.setCurrentLocation(op->getLoc());
1385 TypeSwitch<Operation *>(op)
1386 .Case<ConstantOp>([&](
auto op) {
1389 auto v = op.getValue();
1390 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1391 : v.countLeadingZeros());
1394 setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1396 .Case<SpecialConstantOp>([&](
auto op) {
1399 .Case<InvalidValueOp>([&](
auto op) {
1402 declareVars(op.getResult(),
true);
1404 .Case<WireOp, RegOp>([&](
auto op) { declareVars(op.getResult()); })
1405 .Case<RegResetOp>([&](
auto op) {
1410 declareVars(op.getResult());
1413 constrainTypes(op.getResult(), op.getResetValue());
1415 .Case<NodeOp>([&](
auto op) {
1418 op.getResult().getType());
1422 .Case<SubfieldOp>([&](
auto op) {
1423 BundleType bundleType = op.getInput().getType();
1424 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1425 unifyTypes(
FieldRef(op.getResult(), 0),
1426 FieldRef(op.getInput(), fieldID), op.getType());
1428 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1434 .Case<SubtagOp>([&](
auto op) {
1435 FEnumType enumType = op.getInput().getType();
1436 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1437 unifyTypes(
FieldRef(op.getResult(), 0),
1438 FieldRef(op.getInput(), fieldID), op.getType());
1441 .Case<RefSubOp>([&](RefSubOp op) {
1442 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1443 op.getInput().getType().getType())
1444 .Case<FVectorType>([](
auto _) {
return 1; })
1445 .Case<BundleType>([&](
auto type) {
1446 return type.getFieldID(op.getIndex());
1448 unifyTypes(
FieldRef(op.getResult(), 0),
1449 FieldRef(op.getInput(), fieldID), op.getType());
1453 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1454 auto lhs = getExpr(op.getLhs());
1455 auto rhs = getExpr(op.getRhs());
1456 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1457 setExpr(op.getResult(), e);
1459 .Case<MulPrimOp>([&](
auto op) {
1460 auto lhs = getExpr(op.getLhs());
1461 auto rhs = getExpr(op.getRhs());
1462 auto e = solver.add(lhs, rhs);
1463 setExpr(op.getResult(), e);
1465 .Case<DivPrimOp>([&](
auto op) {
1466 auto lhs = getExpr(op.getLhs());
1468 if (op.getType().base().isSigned()) {
1469 e = solver.add(lhs, solver.known(1));
1473 setExpr(op.getResult(), e);
1475 .Case<RemPrimOp>([&](
auto op) {
1476 auto lhs = getExpr(op.getLhs());
1477 auto rhs = getExpr(op.getRhs());
1478 auto e = solver.min(lhs, rhs);
1479 setExpr(op.getResult(), e);
1481 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1482 auto lhs = getExpr(op.getLhs());
1483 auto rhs = getExpr(op.getRhs());
1484 auto e = solver.max(lhs, rhs);
1485 setExpr(op.getResult(), e);
1489 .Case<CatPrimOp>([&](
auto op) {
1490 auto lhs = getExpr(op.getLhs());
1491 auto rhs = getExpr(op.getRhs());
1492 auto e = solver.add(lhs, rhs);
1493 setExpr(op.getResult(), e);
1495 .Case<DShlPrimOp>([&](
auto op) {
1496 auto lhs = getExpr(op.getLhs());
1497 auto rhs = getExpr(op.getRhs());
1498 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1499 setExpr(op.getResult(), e);
1501 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1502 auto e = getExpr(op.getLhs());
1503 setExpr(op.getResult(), e);
1507 .Case<NegPrimOp>([&](
auto op) {
1508 auto input = getExpr(op.getInput());
1509 auto e = solver.add(input, solver.known(1));
1510 setExpr(op.getResult(), e);
1512 .Case<CvtPrimOp>([&](
auto op) {
1513 auto input = getExpr(op.getInput());
1514 auto e = op.getInput().getType().base().isSigned()
1516 : solver.add(input, solver.known(1));
1517 setExpr(op.getResult(), e);
1521 .Case<BitsPrimOp>([&](
auto op) {
1522 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1524 .Case<HeadPrimOp>([&](
auto op) {
1525 setExpr(op.getResult(), solver.known(op.getAmount()));
1527 .Case<TailPrimOp>([&](
auto op) {
1528 auto input = getExpr(op.getInput());
1529 auto e = solver.add(input, solver.known(-op.getAmount()));
1530 setExpr(op.getResult(), e);
1532 .Case<PadPrimOp>([&](
auto op) {
1533 auto input = getExpr(op.getInput());
1534 auto e = solver.max(input, solver.known(op.getAmount()));
1535 setExpr(op.getResult(), e);
1537 .Case<ShlPrimOp>([&](
auto op) {
1538 auto input = getExpr(op.getInput());
1539 auto e = solver.add(input, solver.known(op.getAmount()));
1540 setExpr(op.getResult(), e);
1542 .Case<ShrPrimOp>([&](
auto op) {
1543 auto input = getExpr(op.getInput());
1545 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1546 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1547 solver.known(minWidth));
1548 setExpr(op.getResult(), e);
1552 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1553 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1554 .Case<mlir::UnrealizedConversionCastOp>(
1555 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1559 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1560 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1561 XorRPrimOp>([&](
auto op) {
1562 auto width = op.getType().getBitWidthOrSentinel();
1563 assert(
width > 0 &&
"width should have been checked by verifier");
1564 setExpr(op.getResult(), solver.known(
width));
1566 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1567 auto *sel = getExpr(op.getSel());
1568 constrainTypes(solver.known(1), sel,
true);
1569 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1571 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1572 auto *sel = getExpr(op.getSel());
1573 constrainTypes(solver.known(2), sel,
true);
1574 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1575 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1576 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1579 .Case<ConnectOp, MatchingConnectOp>(
1580 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1581 .Case<RefDefineOp>([&](
auto op) {
1584 constrainTypes(op.getDest(), op.getSrc(),
true);
1586 .Case<AttachOp>([&](
auto op) {
1590 if (op.getAttached().empty())
1592 auto prev = op.getAttached()[0];
1593 for (
auto operand : op.getAttached().drop_front()) {
1594 auto e1 = getExpr(prev);
1595 auto e2 = getExpr(operand);
1596 constrainTypes(e1, e2,
true);
1597 constrainTypes(e2, e1,
true);
1603 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1604 UnclockedAssumeIntrinsicOp, CoverOp>([&](
auto) {})
1607 .Case<InstanceOp>([&](
auto op) {
1608 auto refdModule = op.getReferencedOperation(symtbl);
1609 auto module = dyn_cast<FModuleOp>(&*refdModule);
1611 auto diag = mlir::emitError(op.getLoc());
1612 diag <<
"extern module `" << op.getModuleName()
1613 <<
"` has ports of uninferred width";
1615 auto fml = cast<FModuleLike>(&*refdModule);
1616 auto ports = fml.getPorts();
1617 for (
auto &port : ports) {
1619 if (baseType && baseType.hasUninferredWidth()) {
1620 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1621 if (!baseType.isGround())
1626 diag.attachNote(op.getLoc())
1627 <<
"Only non-extern FIRRTL modules may contain unspecified "
1628 "widths to be inferred automatically.";
1629 diag.attachNote(refdModule->getLoc())
1630 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1631 mappingFailed =
true;
1638 for (
auto [result, arg] :
1639 llvm::zip(op->getResults(), module.getArguments()))
1640 unifyTypes({result, 0}, {arg, 0},
1641 type_cast<FIRRTLType>(result.getType()));
1645 .Case<MemOp>([&](MemOp op) {
1647 unsigned nonDebugPort = 0;
1648 for (
const auto &result : llvm::enumerate(op.getResults())) {
1649 declareVars(result.value());
1650 if (!type_isa<RefType>(result.value().getType()))
1651 nonDebugPort = result.index();
1656 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1657 static const unsigned indices[] = {3, 5};
1658 static const unsigned debug[] = {0};
1660 case MemOp::PortKind::Read:
1661 case MemOp::PortKind::Write:
1662 return ArrayRef<unsigned>(indices, 1);
1663 case MemOp::PortKind::ReadWrite:
1664 return ArrayRef<unsigned>(indices);
1665 case MemOp::PortKind::Debug:
1666 return ArrayRef<unsigned>(
debug);
1668 llvm_unreachable(
"Imposible PortKind");
1675 unsigned firstFieldIndex =
1676 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1678 op.getResult(nonDebugPort),
1679 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1680 .getFieldID(firstFieldIndex));
1681 LLVM_DEBUG(llvm::dbgs() <<
"Adjusting memory port variables:\n");
1684 auto dataType = op.getDataType();
1685 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1686 auto result = op.getResult(i);
1687 if (type_isa<RefType>(result.getType())) {
1691 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1696 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1697 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1698 unifyTypes(
FieldRef(result, portType.getFieldID(fieldIndex)),
1699 firstData, dataType);
1703 .Case<RefSendOp>([&](
auto op) {
1704 declareVars(op.getResult());
1705 constrainTypes(op.getResult(), op.getBase(),
true);
1707 .Case<RefResolveOp>([&](
auto op) {
1708 declareVars(op.getResult());
1709 constrainTypes(op.getResult(), op.getRef(),
true);
1711 .Case<RefCastOp>([&](
auto op) {
1712 declareVars(op.getResult());
1713 constrainTypes(op.getResult(), op.getInput(),
true);
1715 .Case<RWProbeOp>([&](
auto op) {
1716 auto ist = irn.lookup(op.getTarget());
1718 op->emitError(
"target of rwprobe could not be resolved");
1719 mappingFailed =
true;
1724 op->emitError(
"target of rwprobe resolved to unsupported target");
1725 mappingFailed =
true;
1729 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1730 unifyTypes(
FieldRef(op.getResult(), 0),
1731 FieldRef(ref.getValue(), newFID), op.getType());
1733 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1734 for (Value result : op.getResults()) {
1735 auto ty = result.getType();
1736 if (type_isa<FIRRTLType>(ty))
1737 declareVars(result);
1740 .Default([&](
auto op) {
1741 op->emitOpError(
"not supported in width inference");
1742 mappingFailed =
true;
1746 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1750 return failure(mappingFailed);
// Walks the (possibly aggregate) type of `value` and records a solver
// expression, via setExpr, for each field whose width is the -1 sentinel,
// advancing a running `fieldID` as it goes.
// NOTE(review): the embedded line numbering has gaps, so some statements are
// not visible in this listing -- including the definition of the recursive
// `declare` helper and the branch that picks between the two setExpr calls.
1755 void InferenceMapping::declareVars(Value value,
bool isDerived) {
// Running field index inside `value`; 0 is the starting position.
1758 unsigned fieldID = 0;
// Sentinel-encoded width of the type currently being visited.
1760 auto width = type.getBitWidthOrSentinel();
1763 }
else if (
width == -1) {
// Make `field` the solver's current context before creating its expression.
1766 solver.setCurrentContextInfo(field);
// Presumably chosen by `isDerived` (condition line not shown): bind the field
// to `solver.derived()` ...
1768 setExpr(field, solver.derived());
// ... or to `solver.var()`.
1770 setExpr(field, solver.var());
1772 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: recurse into every element type.
1775 for (
auto &element : bundleType)
1776 declare(element.type);
1777 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: the element type is declared once; afterwards `fieldID` is moved
// to `save + getMaxFieldID()` rather than visiting each element.
1779 auto save = fieldID;
1780 declare(vecType.getElementType());
1782 fieldID = save + vecType.getMaxFieldID();
1783 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into every variant's type.
1785 for (
auto &element : enumType.getElements())
1786 declare(element.type);
// Any other kind of type is treated as a programming error.
1788 llvm_unreachable(
"Unknown type inside a bundle!");
// Field by field, sets the expression of `result` to a `solver.max(...)`
// expression built from the operand expressions at the same field ID.
// NOTE(review): the definition of the recursive `maximize` helper, the
// declaration of `fieldID`/`type`, and the second argument of `solver.max`
// are on lines not visible in this listing (presumably the `lhs` field).
1798 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1802 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: recurse into every element type.
1804 for (
auto &element : bundleType.getElements())
1805 maximize(element.type);
1806 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: handle the element type once (only when the vector is non-empty),
// then move `fieldID` to `save + getMaxFieldID()`.
1808 auto save = fieldID;
1810 if (vecType.getNumElements() > 0)
1811 maximize(vecType.getElementType());
1812 fieldID = save + vecType.getMaxFieldID();
1813 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into every variant's type.
1815 for (
auto &element : enumType.getElements())
1816 maximize(element.type);
1817 }
else if (type.isGround()) {
// Ground field: build the max expression and attach it to the result field.
1818 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1820 setExpr(
FieldRef(result, fieldID), e);
// Any other kind of type is treated as a programming error.
1823 llvm_unreachable(
"Unknown type inside a bundle!");
// Value-level overload: walks a type field by field and, for each ground
// field, forwards the pair of field expressions of `larger` and `smaller` to
// the Expr-level constrainTypes overload.
// NOTE(review): the definition of the recursive `constrain` helper, the
// declarations of `type`/`fieldID`, and the condition separating the two
// bundle-element calls are on lines not visible in this listing.
1838 void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1845 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1847 for (
auto &element : bundleType.getElements()) {
// The two calls differ only in operand order; presumably one of them is taken
// for flipped elements -- TODO confirm against the full source.
1849 constrain(element.type, smaller, larger);
1851 constrain(element.type, larger, smaller);
1853 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: constrain the element type once (only when non-empty), then move
// `fieldID` to `save + getMaxFieldID()`.
1855 auto save = fieldID;
1857 if (vecType.getNumElements() > 0) {
1858 constrain(vecType.getElementType(), larger, smaller);
1860 fieldID = save + vecType.getMaxFieldID();
1861 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into every variant's type.
1863 for (
auto &element : enumType.getElements())
1864 constrain(element.type, larger, smaller);
1865 }
else if (type.isGround()) {
// Ground field: delegate to the Expr overload with imposeUpperBounds = false,
// passing `equal` through unchanged.
1867 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1868 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
// Any other kind of type is treated as a programming error.
1871 llvm_unreachable(
"Unknown type inside a bundle!");
// Entry call of the recursive helper.
1876 constrain(type, larger, smaller);
// Expr-level overload: records the relationship between two width
// expressions with the solver.
//   larger / smaller   - the two expressions; both must be non-null.
//   imposeUpperBounds  - together with `equal`, enables the `<=` constraint on
//                        `smaller` when it is a VarExpr.
//   equal              - see above; presumably also guards the `<=` constraint
//                        on `larger` (its condition line is not visible).
// NOTE(review): gaps in the embedded numbering hide some statements (for
// example any early returns after the DerivedExpr cases).
1881 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1882 bool imposeUpperBounds,
bool equal) {
1883 assert(larger &&
"Larger expression should be specified");
1884 assert(smaller &&
"Smaller expression should be specified");
// A derived `larger` simply records `smaller` as the expression it is
// assigned from.
1890 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1891 largerDerived->assigned = smaller;
1892 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1893 << *smaller <<
"\n");
// Symmetric case: a derived `smaller` records `larger`.
1896 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1897 smallerDerived->assigned = larger;
1898 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1899 << *larger <<
"\n");
// When `larger` is a variable, add the lower bound larger >= smaller.
// `c` is only used by the debug print, hence [[maybe_unused]].
1905 if (
auto *largerVar = dyn_cast<VarExpr>(larger)) {
1906 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1907 LLVM_DEBUG(llvm::dbgs()
1908 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
// Additional upper bound larger <= smaller on the same variable.
1915 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1916 LLVM_DEBUG(llvm::dbgs()
1917 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
// When `smaller` is a variable, optionally add the upper bound
// smaller <= larger.
1927 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1928 if (imposeUpperBounds || equal) {
1929 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1930 LLVM_DEBUG(llvm::dbgs()
1931 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
// NOTE(review): the signature line of this function is not visible in this
// listing; judging by the "Unify" debug text and the recursive `unify` helper
// this appears to be the body of unifyTypes -- TODO confirm.
// Visible behaviour: for a field pair (lhsFieldRef, rhsFieldRef) the lhs field
// is made to share the rhs field's expression; aggregates are walked element
// by element.
1951 LLVM_DEBUG(llvm::dbgs()
1952 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
// If the lhs field already has a variable, add the constraint var >= 0 to it
// before its mapping is overwritten below.
1955 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1956 solver.addGeqConstraint(var, solver.known(0));
// From here on the lhs field maps to the same expression as the rhs field.
1957 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1959 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: recurse into every element type.
1961 for (
auto &element : bundleType) {
1962 unify(element.type);
1964 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: unify the element type once (only when non-empty), then move
// `fieldID` to `save + getMaxFieldID()`.
1966 auto save = fieldID;
1968 if (vecType.getNumElements() > 0) {
1969 unify(vecType.getElementType());
1971 fieldID = save + vecType.getMaxFieldID();
1972 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into every variant's type.
1974 for (
auto &element : enumType.getElements())
1975 unify(element.type);
// Any other kind of type is treated as a programming error.
1977 llvm_unreachable(
"Unknown type inside a bundle!");
// Returns the constraint expression for `value` as a whole by forwarding to
// the FieldRef overload with field ID 0.
// NOTE(review): lines between the signature and the return are not visible in
// this listing.
1985 Expr *InferenceMapping::getExpr(Value value)
const {
1988 return getExpr(
FieldRef(value, 0));
// Looks up the constraint expression for `fieldRef` through getExprOrNull and
// asserts that one exists. The return statement itself is on a line not
// visible in this listing.
1992 Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
1993 auto *expr = getExprOrNull(fieldRef);
1994 assert(expr &&
"constraint expr should have been constructed for value");
// Looks up `fieldRef` in the `opExprs` map. For the path that continues past
// the lookup, the visible code computes the sentinel-encoded width of a type
// and returns `solver.known(width)`.
// NOTE(review): the statement executed on a successful lookup, the
// computation of `type`, and any checks on `width` are on lines not visible
// in this listing.
1998 Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
1999 auto it = opExprs.find(fieldRef);
2000 if (it != opExprs.end())
2007 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
2010 return solver.known(
width);
// Overload of setExpr taking a whole Value instead of a FieldRef.
// NOTE(review): the body of this overload is not visible in this listing.
2014 void InferenceMapping::setExpr(Value value, Expr *expr) {
// Associates `expr` with `fieldRef` in the `opExprs` map, overwriting any
// previous entry, after writing a description of the mapping to llvm::dbgs().
// NOTE(review): the lines that wrap the printing (presumably LLVM_DEBUG) and
// that define `fieldName` are not visible in this listing.
2021 void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
// Print the expression and the value it is attached to.
2023 llvm::dbgs() <<
"Expr " << *expr <<
" for " << fieldRef.
getValue();
// Print the field's name.
2025 llvm::dbgs() <<
" '" <<
getFieldName(fieldRef).first <<
"'";
// Print the quoted name only when the flag returned alongside it is set.
2027 if (fieldName.second)
2028 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2029 llvm::dbgs() <<
"\n";
// Record the mapping.
2031 opExprs[fieldRef] = expr;
// Helper that rewrites the types in a circuit using the expressions stored in
// an InferenceMapping. It only keeps a const reference to the mapping, so the
// mapping must outlive this object.
// NOTE(review): some member declarations are on lines not visible in this
// listing.
2041 class InferenceTypeUpdate {
2043 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
// Update the modules of a circuit; fails if any operation update fails.
2045 LogicalResult update(CircuitOp op);
// Update one operation; the bool reports whether anything changed.
2046 FailureOr<bool> updateOperation(Operation *op);
// Update the type of one value; the bool reports whether the type changed.
2047 FailureOr<bool> updateValue(Value value);
// Read-only view of the inference results.
2051 const InferenceMapping &mapping;
// Runs updateOperation over the FModuleOps of the circuit using
// mlir::failableParallelForEach; the overall result fails if any module's
// callback fails.
// NOTE(review): the statement executed for skipped modules and the wrapper
// around the debug print are on lines not visible in this listing.
2057 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2059 llvm::dbgs() <<
"\n";
2062 return mlir::failableParallelForEach(
2063 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
// Modules the mapping marked as skipped get special handling (not shown).
2066 if (mapping.isModuleSkipped(op))
// Pre-order walk over the module; the walk is interrupted at the first
// operation whose update fails.
2068 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2069 if (failed(updateOperation(op)))
2070 return WalkResult::interrupt();
2071 return WalkResult::advance();
2072 }).wasInterrupted();
// An interrupted walk is reported as failure for this module.
2073 return failure(isFailed);
// Updates the types of a single operation:
//   1. every result is passed through updateValue;
//   2. for a ConnectOp whose source is wider than its destination, a
//      truncation of the source is inserted before the connect;
//   3. for an FModuleOp, every block argument is passed through updateValue
//      and the module's port types attribute is rebuilt.
// NOTE(review): failure checks, the code that fills `argTypes`, the
// assignment target on line 2109, and the final return are on lines not
// visible in this listing.
2078 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2079 bool anyChanged =
false;
// Step 1: update all results and accumulate the "changed" flag.
2081 for (Value v : op->getResults()) {
2082 auto result = updateValue(v);
2085 anyChanged |= *result;
// Step 2: connects may need the right-hand side truncated.
2091 if (
auto con = dyn_cast<ConnectOp>(op)) {
2092 auto lhs = con.getDest();
2093 auto rhs = con.getSrc();
2094 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2095 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
// Only base types on both sides are considered.
2098 if (!lhsType || !rhsType)
2101 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2102 auto rhsWidth = rhsType.getBitWidthOrSentinel();
// Both widths are non-negative and the source is strictly wider.
2103 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
// New ops are created at the position of the connect.
2104 OpBuilder builder(op);
// TailPrimOp with amount (rhsWidth - lhsWidth): narrows the source by the
// width difference.
2105 auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2106 rhsWidth - lhsWidth);
// For a signed source an AsSIntPrimOp of type `lhsType` is built from the
// truncated value.
2107 if (type_isa<SIntType>(rhsType))
2109 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2111 LLVM_DEBUG(llvm::dbgs()
2112 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
// Make the connect use the truncated value instead of the original source.
2113 con->replaceUsesOfWith(con.getSrc(), trunc);
// Step 3: modules additionally update their arguments and port types.
2119 if (
auto module = dyn_cast<FModuleOp>(op)) {
2121 bool argsChanged =
false;
2122 SmallVector<Attribute> argTypes;
2123 argTypes.reserve(module.getNumPorts());
2124 for (
auto arg : module.getArguments()) {
2125 auto result = updateValue(arg);
2128 argsChanged |= *result;
// Store the collected port types back onto the module.
2134 module.setPortTypesAttr(
ArrayAttr::get(module.getContext(), argTypes));
// NOTE(review): the signature line is not visible in this listing; this looks
// like the body of the static helper
//   resizeType(FIRRTLBaseType type, uint32_t newWidth)
// declared elsewhere in the file ("Resize a uint, sint, or analog type to a
// specific width") -- TODO confirm.
// Visible structure: a type switch with dedicated cases for UIntType,
// SIntType and AnalogType (case bodies not shown) and a default that returns
// the type unchanged.
2143 auto *context = type.getContext();
2145 .
Case<UIntType>([&](
auto type) {
2148 .Case<SIntType>([&](
auto type) {
2151 .Case<AnalogType>([&](
auto type) {
// Every other type is passed through as-is.
2154 .Default([&](
auto type) {
return type; });
// Updates the type of a single value and returns whether the type changed.
// Two strategies are visible:
//   * if the defining op implements InferTypeOpInterface, its result types
//     are re-inferred from its operands and written back;
//   * otherwise a new type is built by walking the value's type (bundles,
//     vectors, enums) with the recursive `updateBase` helper and assigned to
//     the value.
// A ConstantOp whose stored value is wider than its type has the value
// truncated to the type's width.
// NOTE(review): early exits, failure checks, the definition of `updateBase`,
// the construction of `newType`, and whatever uses the truncated constant are
// on lines not visible in this listing.
2158 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
// Non-FIRRTL types yield a null `type` here.
2160 auto type = type_dyn_cast<FIRRTLType>(value.getType());
// Strategy 1: let the defining op re-infer its own result types.
2171 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2172 SmallVector<Type, 2> types;
2174 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2175 op->getAttrDictionary(), op->getPropertiesStorage(),
2176 op->getRegions(), types);
// One inferred type per result is expected.
2180 assert(types.size() == op->getNumResults());
2181 for (
auto [result, type] : llvm::zip(op->getResults(), types)) {
2182 LLVM_DEBUG(llvm::dbgs()
2183 <<
"Inferring " << result <<
" as " << type <<
"\n");
2184 result.setType(type);
// Strategy 2: rebuild the type field by field, tracking the field ID.
2190 auto *context = type.getContext();
2191 unsigned fieldID = 0;
2194 auto width = type.getBitWidthOrSentinel();
// Bundles: update each element, keeping its name and flip flag.
2206 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2209 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2210 for (
auto &element : bundleType) {
2211 auto updatedBase = updateBase(element.type);
2214 elements.emplace_back(element.name, element.isFlip, updatedBase);
// Vectors: update the element type once (only when non-empty), then move
// `fieldID` to `save + getMaxFieldID()`.
2218 if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2220 auto save = fieldID;
2223 if (vecType.getNumElements() > 0) {
2224 auto updatedBase = updateBase(vecType.getElementType());
2229 fieldID = save + vecType.getMaxFieldID();
// Enums: update each variant, keeping its name.
2235 if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2237 llvm::SmallVector<FEnumType::EnumElement> elements;
2238 for (
auto &element : enumType.getElements()) {
2239 auto updatedBase = updateBase(element.type);
2242 elements.emplace_back(element.name, updatedBase);
// Any other kind of type is treated as a programming error.
2246 llvm_unreachable(
"Unknown type inside a bundle!");
// Write the rebuilt type back onto the value.
2253 LLVM_DEBUG(llvm::dbgs() <<
"Update " << value <<
" to " << newType <<
"\n");
2254 value.setType(newType);
// Constants: shrink the stored value when it is wider than the op's type.
2259 if (
auto op = value.getDefiningOp<ConstantOp>()) {
2260 auto k = op.getValue();
2261 auto bitwidth = op.getType().getBitWidthOrSentinel();
2262 if (k.getBitWidth() >
unsigned(bitwidth))
2263 k = k.trunc(bitwidth);
// Report whether the type actually changed.
2267 return newType != type;
// NOTE(review): the signature line is not visible in this listing; the body
// refers to `type`, `fieldRef`, `value` and `mapping`, so this is presumably
// a member of InferenceTypeUpdate that updates one ground field -- TODO
// confirm.
// Visible behaviour: asserts the type is ground, looks up the field's
// expression in the mapping, emits an error when there is no expression or no
// solution, and otherwise reads the solved width.
2273 assert(type.isGround() &&
"Can only pass in ground types.");
2276 Expr *expr = mapping.getExprOrNull(fieldRef);
// A missing expression or a missing solution is reported as an error.
2277 if (!expr || !expr->getSolution()) {
2282 mlir::emitError(value.getLoc(),
"width should have been inferred");
// Safe to dereference here: the check above guards the missing-solution path
// (its exit statement is not shown).
2285 int32_t solution = *expr->getSolution();
// The pass class: derives from the generated InferWidthsBase (CRTP) and
// overrides runOnOperation.
2295 class InferWidthsPass
2296 :
public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2297 void runOnOperation()
override;
// Pass entry point. Runs four stages in order and signals pass failure as
// soon as one of them fails:
//   1. build the InferenceMapping for the operation (mapping.map);
//   2. if every module was skipped, mark all analyses preserved and stop;
//   3. solve the collected constraints (solver.solve);
//   4. write the results back into the IR (InferenceTypeUpdate::update).
2301 void InferWidthsPass::runOnOperation() {
// Stage 1: the mapping fills `solver` and uses the symbol table and inner
// symbol table analyses.
2303 ConstraintSolver solver;
2304 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2305 getAnalysis<hw::InnerSymbolTableCollection>());
2306 if (failed(mapping.map(getOperation())))
2307 return signalPassFailure();
// Stage 2: nothing to infer.
2310 if (mapping.areAllModulesSkipped())
2311 return markAllAnalysesPreserved();
// Stage 3: solve.
2314 if (failed(solver.solve()))
2315 return signalPassFailure();
// Stage 4: apply the solution to the IR.
2318 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2319 return signalPassFailure();
// Creates a new InferWidthsPass instance.
// NOTE(review): the signature line is not visible in this listing; presumably
// this is the body of createInferWidthsPass() -- TODO confirm.
2323 return std::make_unique<InferWidthsPass>();
assert(baseType &&"element must be base type")
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundle.
Value getValue() const
Get the Value which created this location.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
bool operator==(uint64_t a, const FVInt &b)
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
llvm::hash_code hash_value(const T &e)