15 #include "mlir/Pass/Pass.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Threading.h"
26 #include "llvm/ADT/APSInt.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/GraphTraits.h"
29 #include "llvm/ADT/Hashing.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/PostOrderIterator.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/ErrorHandling.h"
36 #define DEBUG_TYPE "infer-widths"
40 #define GEN_PASS_DEF_INFERWIDTHS
41 #include "circt/Dialect/FIRRTL/Passes.h.inc"
45 using mlir::InferTypeOpInterface;
46 using mlir::WalkOrder;
48 using namespace circt;
49 using namespace firrtl;
57 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
60 if (!basetype.hasUninferredWidth())
63 if (basetype.isGround())
64 diag.attachNote() <<
"Field: \"" << str <<
"\"";
65 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
67 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
68 for (
auto &elem : bundleType.getElements())
74 uint64_t convertedFieldID = 0;
76 auto curFID = fieldID;
81 if (isa<FVectorType>(curFType))
84 convertedFieldID += curFID - subID;
89 return convertedFieldID;
101 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
103 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
110 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
113 return e.hash_value();
// Master list of expression kinds. The argument `x` is token-pasted onto each
// base name as a suffix, so the same list yields both the bare kind names and
// the corresponding class names. The order of the entries is significant and
// is kept exactly as before.
#define EXPR_NAMES(x) \
  Root##x, \
  Var##x, \
  Derived##x, \
  Id##x, \
  Known##x, \
  Add##x, \
  Pow##x, \
  Max##x, \
  Min##x
// Bare kind names: the list expanded with an empty suffix.
#define EXPR_KINDS EXPR_NAMES()
// Class names: the list expanded with the `Expr` suffix.
#define EXPR_CLASSES EXPR_NAMES(Expr)
126 std::optional<int32_t> solution;
130 void print(llvm::raw_ostream &os)
const;
133 Expr(Kind kind) : kind(kind) {}
138 template <
class DerivedT, Expr::Kind DerivedKind>
139 struct ExprBase :
public Expr {
140 ExprBase() : Expr(DerivedKind) {}
141 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
143 if (
auto otherSame = dyn_cast<DerivedT>(other))
144 return *
static_cast<DerivedT *
>(
this) == otherSame;
150 struct RootExpr :
public ExprBase<RootExpr, Expr::Kind::Root> {
151 RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
152 void print(llvm::raw_ostream &os)
const { os <<
"root"; }
153 std::vector<Expr *> &exprs;
157 struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
158 void print(llvm::raw_ostream &os)
const {
161 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
166 Expr *constraint =
nullptr;
169 Expr *upperBound =
nullptr;
170 std::optional<int32_t> upperBoundSolution;
177 struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
178 void print(llvm::raw_ostream &os)
const {
181 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
185 Expr *assigned =
nullptr;
202 struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
203 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
204 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
206 return kind == other.kind && arg == other.arg;
217 struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
218 KnownExpr(int32_t value) : ExprBase() { solution = value; }
219 void print(llvm::raw_ostream &os)
const { os << *solution; }
220 bool operator==(
const KnownExpr &other)
const {
221 return *solution == *other.solution;
230 struct UnaryExpr :
public Expr {
231 bool operator==(
const UnaryExpr &other)
const {
232 return kind == other.kind && arg == other.arg;
242 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
246 template <
class DerivedT, Expr::Kind DerivedKind>
247 struct UnaryExprBase :
public UnaryExpr {
248 template <
typename... Args>
249 UnaryExprBase(Args &&...args)
250 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
251 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
255 struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
256 using UnaryExprBase::UnaryExprBase;
257 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
262 struct BinaryExpr :
public Expr {
263 bool operator==(
const BinaryExpr &other)
const {
264 return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
269 Expr *lhs()
const {
return args[0]; }
270 Expr *rhs()
const {
return args[1]; }
276 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
283 template <
class DerivedT, Expr::Kind DerivedKind>
284 struct BinaryExprBase :
public BinaryExpr {
285 template <
typename... Args>
286 BinaryExprBase(Args &&...args)
287 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
288 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
292 struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
293 using BinaryExprBase::BinaryExprBase;
294 void print(llvm::raw_ostream &os)
const {
295 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
300 struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
301 using BinaryExprBase::BinaryExprBase;
302 void print(llvm::raw_ostream &os)
const {
303 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
308 struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
309 using BinaryExprBase::BinaryExprBase;
310 void print(llvm::raw_ostream &os)
const {
311 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
315 void Expr::print(llvm::raw_ostream &os)
const {
317 [&](
auto *e) { e->print(os); });
329 template <
typename T>
330 struct InternedSlot {
332 InternedSlot(T *ptr) : ptr(ptr) {}
337 template <
typename T,
typename std::enable_if_t<
338 std::is_trivially_destructible<T>::value,
int> = 0>
339 class InternedAllocator {
340 using Slot = InternedSlot<T>;
341 llvm::DenseSet<Slot> interned;
342 llvm::BumpPtrAllocator &allocator;
345 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
// Allocate an `R` built from `args`, reusing a previously interned object
// when the `interned` set already holds a matching slot.
//
// Returns a pair of the object pointer and a flag that is `true` only when a
// new object was allocated by this call.
// NOTE(review): what counts as "matching" is decided by the DenseSet traits
// for `Slot`, which are not visible in this excerpt.
350 template <
typename R = T,
typename... Args>
351 std::pair<R *, bool> alloc(Args &&...args) {
// Build a temporary candidate first so the set can be probed without
// touching the bump allocator.
352 auto stack_value = R(std::forward<Args>(args)...);
353 auto stack_slot = Slot(&stack_value);
354 auto it = interned.find(stack_slot);
// Hit: hand back the existing object and report that nothing was allocated.
355 if (it != interned.end())
356 return std::make_pair(
static_cast<R *
>(it->ptr),
false);
// Miss: move the candidate into allocator-owned storage and remember it.
357 auto heap_value =
new (allocator) R(std::move(stack_value));
358 interned.insert(Slot(heap_value));
359 return std::make_pair(heap_value,
true);
365 template <
typename T,
typename std::enable_if_t<
366 std::is_trivially_destructible<T>::value,
int> = 0>
368 llvm::BumpPtrAllocator &allocator;
371 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
375 template <
typename R = T,
typename... Args>
376 R *alloc(Args &&...args) {
377 return new (allocator) R(std::forward<Args>(args)...);
408 int32_t rec_scale = 0;
409 int32_t rec_bias = 0;
410 int32_t nonrec_bias = 0;
414 static LinIneq unsat() {
return LinIneq(
true); }
417 explicit LinIneq(
bool failed =
false) : failed(failed) {}
420 explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}
423 explicit LinIneq(int32_t scale, int32_t bias) {
434 explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
437 if (rec_scale != 0) {
438 this->rec_scale = rec_scale;
439 this->rec_bias = rec_bias;
440 this->nonrec_bias = nonrec_bias;
442 this->nonrec_bias = std::max(rec_bias, nonrec_bias);
// Combine two inequalities by taking the larger of each component.
//
// The result is built through the four-argument constructor from the
// pairwise `std::max` of `rec_scale`, `rec_bias` and `nonrec_bias`; it is
// marked failed if either input is failed.
453 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
454 return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
455 std::max(lhs.rec_bias, rhs.rec_bias),
456 std::max(lhs.nonrec_bias, rhs.nonrec_bias),
457 lhs.failed || rhs.failed);
513 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
516 auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
517 auto enable2 = lhs.rec_scale > 0;
518 auto enable3 = rhs.rec_scale > 0;
519 auto scale1 = lhs.rec_scale + rhs.rec_scale;
520 auto scale2 = lhs.rec_scale;
521 auto scale3 = rhs.rec_scale;
522 auto bias1 = lhs.rec_bias + rhs.rec_bias;
523 auto bias2 = lhs.rec_bias + rhs.nonrec_bias;
524 auto bias3 = rhs.rec_bias + lhs.nonrec_bias;
525 auto maxScale = std::max(scale1, std::max(scale2, scale3));
529 std::optional<int32_t> maxBias;
530 if (enable1 && scale1 == maxScale)
532 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
534 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
539 auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias;
540 auto failed = lhs.failed || rhs.failed;
541 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
542 return LinIneq(scale1, bias1, nonrec_bias, failed);
543 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
544 return LinIneq(scale2, bias2, nonrec_bias, failed);
545 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
546 return LinIneq(scale3, bias3, nonrec_bias, failed);
547 return LinIneq(0, 0, nonrec_bias, failed);
560 if (rec_scale == 1 && rec_bias > 0)
566 void print(llvm::raw_ostream &os)
const {
568 bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
572 if (rec_scale != 0) {
575 os << rec_scale <<
"*";
581 os <<
" - " << -rec_bias;
583 os <<
" + " << rec_bias;
591 if (nonrec_bias != 0) {
608 class ConstraintSolver {
610 ConstraintSolver() =
default;
613 auto v = vars.alloc();
616 info[v].insert(currentInfo);
618 locs[v].insert(*currentLoc);
621 DerivedExpr *derived() {
622 auto *d = derivs.alloc();
626 KnownExpr *known(int32_t value) {
return alloc<KnownExpr>(knowns, value); }
627 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
628 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
629 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
630 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
631 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
635 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
637 lhs->constraint = max(lhs->constraint, rhs);
639 lhs->constraint = id(rhs);
640 return lhs->constraint;
645 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
647 lhs->upperBound = min(lhs->upperBound, rhs);
649 lhs->upperBound = id(rhs);
650 return lhs->upperBound;
653 void dumpConstraints(llvm::raw_ostream &os);
654 LogicalResult solve();
656 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
657 const ContextInfo &getContextInfo()
const {
return info; }
658 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
659 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
663 llvm::BumpPtrAllocator allocator;
664 Allocator<VarExpr> vars = {allocator};
665 Allocator<DerivedExpr> derivs = {allocator};
666 InternedAllocator<KnownExpr> knowns = {allocator};
667 InternedAllocator<IdExpr> ids = {allocator};
668 InternedAllocator<UnaryExpr> uns = {allocator};
669 InternedAllocator<BinaryExpr> bins = {allocator};
672 std::vector<Expr *> exprs;
673 RootExpr root = {exprs};
676 template <
typename R,
typename T,
typename... Args>
677 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
678 auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
680 exprs.push_back(it.first);
682 info[it.first].insert(currentInfo);
684 locs[it.first].insert(*currentLoc);
692 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
693 std::optional<Location> currentLoc;
// The solver can be neither copied nor moved: all four copy/move special
// members are deleted.
// NOTE(review): presumably because `root` holds a reference to the `exprs`
// member and the allocators reference `allocator`; confirm the intent.
697 ConstraintSolver(ConstraintSolver &&) =
delete;
698 ConstraintSolver(
const ConstraintSolver &) =
delete;
699 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
700 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
702 void emitUninferredWidthError(VarExpr *var);
704 LinIneq checkCycles(VarExpr *var, Expr *expr,
705 SmallPtrSetImpl<Expr *> &seenVars,
706 InFlightDiagnostic *reportInto =
nullptr,
707 unsigned indent = 1);
713 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
714 for (
auto *e : exprs) {
715 if (
auto *v = dyn_cast<VarExpr>(e)) {
717 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
719 os <<
"- " << *v <<
" unconstrained\n";
725 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
741 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
742 SmallPtrSetImpl<Expr *> &seenVars,
743 InFlightDiagnostic *reportInto,
746 TypeSwitch<Expr *, LinIneq>(expr)
747 .Case<KnownExpr>([&](
auto *expr) {
return LinIneq(*expr->solution); })
748 .Case<VarExpr>([&](
auto *expr) {
750 return LinIneq(1, 0);
751 if (!seenVars.insert(expr).second)
758 if (!expr->constraint)
761 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
763 seenVars.erase(expr);
766 .Case<IdExpr>([&](
auto *expr) {
767 return checkCycles(var, expr->arg, seenVars, reportInto,
770 .Case<PowExpr>([&](
auto *expr) {
775 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
776 if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
777 arg.nonrec_bias >= 31)
778 return LinIneq::unsat();
779 return LinIneq(1 << arg.nonrec_bias);
781 .Case<AddExpr>([&](
auto *expr) {
783 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
784 checkCycles(var, expr->rhs(), seenVars, reportInto,
787 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
792 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
793 checkCycles(var, expr->rhs(), seenVars, reportInto,
796 .Default([](
auto) {
return LinIneq::unsat(); });
801 if (reportInto && !ineq.sat()) {
802 auto report = [&](Location loc) {
803 auto ¬e = reportInto->attachNote(loc);
804 note <<
"constrained width W >= ";
805 if (ineq.rec_scale == -1)
807 if (ineq.rec_scale != 1)
808 note << ineq.rec_scale;
810 if (ineq.rec_bias < 0)
811 note <<
"-" << -ineq.rec_bias;
812 if (ineq.rec_bias > 0)
813 note <<
"+" << ineq.rec_bias;
816 auto it = locs.find(expr);
817 if (it != locs.end())
818 for (
auto loc : it->second)
822 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
823 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
833 arg.first = operation(*arg.first);
839 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
840 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
841 if (lhs.first && rhs.first)
842 result.first = operation(*lhs.first, *rhs.first);
844 result.first = lhs.first;
846 result.first = rhs.first;
856 unsigned defaultWorklistSize) {
865 std::vector<Frame> worklist({{expr, indent}});
866 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
869 worklist.reserve(defaultWorklistSize);
871 while (!worklist.empty()) {
872 auto &frame = worklist.back();
875 if (solution.first && !solution.second)
876 frame.expr->solution = *solution.first;
877 solvedExprs[frame.expr] = solution;
881 if (!isa<KnownExpr>(frame.expr)) {
883 llvm::dbgs().indent(frame.indent * 2)
884 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
886 llvm::dbgs().indent(frame.indent * 2)
887 <<
"= Skipped " << *frame.expr;
888 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
897 if (frame.expr->solution) {
899 if (!isa<KnownExpr>(frame.expr))
900 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
901 << *frame.expr->solution <<
"\n";
903 setSolution(
ExprSolution{*frame.expr->solution,
false});
909 if (!isa<KnownExpr>(frame.expr))
910 llvm::dbgs().indent(frame.indent * 2)
911 <<
"- Solving " << *frame.expr <<
"\n";
914 TypeSwitch<Expr *>(frame.expr)
915 .Case<KnownExpr>([&](
auto *expr) {
918 .Case<VarExpr>([&](
auto *expr) {
919 if (solvedExprs.contains(expr->constraint)) {
920 auto solution = solvedExprs[expr->constraint];
927 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
928 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
929 seenVars.erase(expr);
931 if (solution.first && *solution.first < 0)
933 return setSolution(solution);
937 if (!expr->constraint)
942 if (!seenVars.insert(expr).second)
945 worklist.push_back({expr->constraint, indent + 1});
946 if (expr->upperBound)
947 worklist.push_back({expr->upperBound, indent + 1});
949 .Case<IdExpr>([&](
auto *expr) {
950 if (solvedExprs.contains(expr->arg))
951 return setSolution(solvedExprs[expr->arg]);
952 worklist.push_back({expr->arg, indent + 1});
954 .Case<PowExpr>([&](
auto *expr) {
955 if (solvedExprs.contains(expr->arg))
957 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
959 worklist.push_back({expr->arg, indent + 1});
961 .Case<AddExpr>([&](
auto *expr) {
962 if (solvedExprs.contains(expr->lhs()) &&
963 solvedExprs.contains(expr->rhs()))
965 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
966 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
968 worklist.push_back({expr->lhs(), indent + 1});
969 worklist.push_back({expr->rhs(), indent + 1});
971 .Case<MaxExpr>([&](
auto *expr) {
972 if (solvedExprs.contains(expr->lhs()) &&
973 solvedExprs.contains(expr->rhs()))
975 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
976 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
978 worklist.push_back({expr->lhs(), indent + 1});
979 worklist.push_back({expr->rhs(), indent + 1});
981 .Case<MinExpr>([&](
auto *expr) {
982 if (solvedExprs.contains(expr->lhs()) &&
983 solvedExprs.contains(expr->rhs()))
985 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
986 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
988 worklist.push_back({expr->lhs(), indent + 1});
989 worklist.push_back({expr->rhs(), indent + 1});
996 return solvedExprs[expr];
1002 LogicalResult ConstraintSolver::solve() {
1004 llvm::dbgs() <<
"\n";
1006 dumpConstraints(llvm::dbgs());
1011 llvm::dbgs() <<
"\n";
1012 debugHeader(
"Checking for unbreakable loops") <<
"\n\n";
1014 SmallPtrSet<Expr *, 16> seenVars;
1015 bool anyFailed =
false;
1017 for (
auto *expr : exprs) {
1019 auto *var = dyn_cast<VarExpr>(expr);
1020 if (!var || !var->constraint)
1022 LLVM_DEBUG(llvm::dbgs()
1023 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1028 seenVars.insert(var);
1029 auto ineq = checkCycles(var, var->constraint, seenVars);
1038 LLVM_DEBUG(llvm::dbgs()
1039 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1047 LLVM_DEBUG(llvm::dbgs()
1048 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1050 for (
auto fieldRef : info.find(var)->second) {
1053 auto op = fieldRef.getDefiningOp();
1054 auto diag = op ? op->emitOpError()
1055 : mlir::emitError(fieldRef.getValue().getLoc())
1057 diag <<
"is constrained to be wider than itself";
1060 seenVars.insert(var);
1061 checkCycles(var, var->constraint, seenVars, &diag);
1073 llvm::dbgs() <<
"\n";
1076 unsigned defaultWorklistSize = exprs.size() / 2;
1077 for (
auto *expr : exprs) {
1079 auto *var = dyn_cast<VarExpr>(expr);
1084 if (!var->constraint) {
1085 LLVM_DEBUG(llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1086 emitUninferredWidthError(var);
1092 LLVM_DEBUG(llvm::dbgs()
1093 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1094 seenVars.insert(var);
1095 auto solution =
solveExpr(var->constraint, seenVars, defaultWorklistSize);
1097 if (var->upperBound && !var->upperBoundSolution)
1098 var->upperBoundSolution =
1099 solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1103 if (solution.first && *solution.first < 0)
1105 var->solution = solution.first;
1109 if (!solution.first) {
1110 LLVM_DEBUG(llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1111 emitUninferredWidthError(var);
1115 LLVM_DEBUG(llvm::dbgs()
1116 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1117 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1120 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1121 LLVM_DEBUG(llvm::dbgs() <<
" ! Unsatisfiable " << *var
1122 <<
" <= " << var->upperBoundSolution <<
"\n");
1123 emitUninferredWidthError(var);
1129 for (
auto *expr : exprs) {
1131 auto *derived = dyn_cast<DerivedExpr>(expr);
1135 auto *assigned = derived->assigned;
1136 if (!assigned || !assigned->solution) {
1137 LLVM_DEBUG(llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1138 derived->solution = 0;
1140 LLVM_DEBUG(llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1141 << assigned->solution <<
"\n");
1142 derived->solution = *assigned->solution;
1146 return failure(anyFailed);
1151 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1152 FieldRef fieldRef = info.find(var)->second.back();
1155 auto diag = mlir::emitError(value.getLoc(),
"uninferred width:");
1158 if (isa<BlockArgument>(value)) {
1160 }
else if (
auto op = value.getDefiningOp()) {
1161 TypeSwitch<Operation *>(op)
1162 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1163 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1164 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1165 .Default([&](
auto) { diag <<
" value"; });
1172 if (!fieldName.empty()) {
1175 diag <<
" \"" << fieldName <<
"\"";
1178 if (!var->constraint) {
1179 diag <<
" is unconstrained";
1180 }
else if (var->solution && var->upperBoundSolution &&
1181 var->solution > var->upperBoundSolution) {
1182 diag <<
" cannot satisfy all width requirements";
1183 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1184 LLVM_DEBUG(llvm::dbgs() << *var->upperBound <<
"\n");
1185 auto loc = locs.find(var->constraint)->second.back();
1186 diag.attachNote(loc) <<
"width is constrained to be at least "
1187 << *var->solution <<
" here:";
1188 loc = locs.find(var->upperBound)->second.back();
1189 diag.attachNote(loc) <<
"width is constrained to be at most "
1190 << *var->upperBoundSolution <<
" here:";
1192 diag <<
" width cannot be determined";
1193 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1194 auto loc = locs.find(var->constraint)->second.back();
1195 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1207 class InferenceMapping {
1209 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1210 hw::InnerSymbolTableCollection &istc)
1211 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1213 LogicalResult map(CircuitOp op);
1214 LogicalResult mapOperation(Operation *op);
1219 void declareVars(Value value, Location loc,
bool isDerived =
false);
1222 Expr *declareVar(
FieldRef fieldRef, Location loc);
1226 Expr *declareVar(
FIRRTLType type, Location loc);
1231 void maximumOfTypes(Value result, Value rhs, Value lhs);
1235 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1239 void constrainTypes(Expr *larger, Expr *smaller,
1240 bool imposeUpperBounds =
false,
bool equal =
false);
1249 Expr *getExpr(Value value)
const;
1252 Expr *getExpr(
FieldRef fieldRef)
const;
1256 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1260 void setExpr(Value value, Expr *expr);
1263 void setExpr(
FieldRef fieldRef, Expr *expr);
// Report whether `module` is present in `skippedModules`, the set that
// `map` fills with modules it found to be fully inferred.
1266 bool isModuleSkipped(FModuleOp module)
const {
1267 return skippedModules.count(module);
1271 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1275 ConstraintSolver &solver;
1278 DenseMap<FieldRef, Expr *> opExprs;
1281 SmallPtrSet<Operation *, 16> skippedModules;
1282 bool allModulesSkipped =
true;
1285 SymbolTable &symtbl;
1288 hw::InnerRefNamespace irn;
1295 return TypeSwitch<Type, bool>(type)
1296 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1298 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1299 .Default([](
auto) {
return false; });
1302 LogicalResult InferenceMapping::map(CircuitOp op) {
1303 LLVM_DEBUG(llvm::dbgs()
1304 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1307 for (
auto module : op.getOps<FModuleOp>())
1308 for (
auto arg : module.getArguments()) {
1309 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1310 declareVars(arg, module.getLoc());
1313 for (
auto module : op.getOps<FModuleOp>()) {
1316 bool anyUninferred =
false;
1317 for (
auto arg : module.getArguments()) {
1322 module.walk([&](Operation *op) {
1323 for (
auto type : op->getResultTypes())
1326 return WalkResult::interrupt();
1327 return WalkResult::advance();
1330 if (!anyUninferred) {
1331 LLVM_DEBUG(llvm::dbgs() <<
"Skipping fully-inferred module '"
1332 << module.getName() <<
"'\n");
1333 skippedModules.insert(module);
1337 allModulesSkipped =
false;
1341 auto result = module.getBodyBlock()->walk(
1342 [&](Operation *op) {
return WalkResult(mapOperation(op)); });
1343 if (result.wasInterrupted())
1350 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1356 bool allWidthsKnown =
true;
1357 for (
auto result : op->getResults()) {
1358 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1360 allWidthsKnown =
false;
1363 auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1367 declareVars(result, op->getLoc());
1369 allWidthsKnown =
false;
1371 if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1375 if (isa<PropAssignOp>(op))
1379 bool mappingFailed =
false;
1380 solver.setCurrentContextInfo(
1382 solver.setCurrentLocation(op->getLoc());
1383 TypeSwitch<Operation *>(op)
1384 .Case<ConstantOp>([&](
auto op) {
1388 if (
auto width = op.getType().base().getWidth())
1389 e = solver.known(*
width);
1391 auto v = op.getValue();
1392 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1393 : v.countLeadingZeros());
1396 e = solver.known(std::max(w, 1u));
1398 setExpr(op.getResult(), e);
1400 .Case<SpecialConstantOp>([&](
auto op) {
1403 .Case<InvalidValueOp>([&](
auto op) {
1408 declareVars(op.getResult(), op.getLoc(),
true);
1412 auto type = op.getType();
1413 ImplicitLocOpBuilder builder(op->getLoc(), op);
1415 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1419 auto clone = builder.create<InvalidValueOp>(type);
1420 declareVars(clone.getResult(), clone.getLoc(),
1425 .Case<WireOp, RegOp>(
1426 [&](
auto op) { declareVars(op.getResult(), op.getLoc()); })
1427 .Case<RegResetOp>([&](
auto op) {
1432 declareVars(op.getResult(), op.getLoc());
1435 constrainTypes(op.getResult(), op.getResetValue());
1437 .Case<NodeOp>([&](
auto op) {
1440 op.getResult().getType());
1444 .Case<SubfieldOp>([&](
auto op) {
1445 BundleType bundleType = op.getInput().getType();
1446 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1447 unifyTypes(
FieldRef(op.getResult(), 0),
1448 FieldRef(op.getInput(), fieldID), op.getType());
1450 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1456 .Case<SubtagOp>([&](
auto op) {
1457 FEnumType enumType = op.getInput().getType();
1458 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1459 unifyTypes(
FieldRef(op.getResult(), 0),
1460 FieldRef(op.getInput(), fieldID), op.getType());
1463 .Case<RefSubOp>([&](RefSubOp op) {
1464 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1465 op.getInput().getType().getType())
1466 .Case<FVectorType>([](
auto _) {
return 1; })
1467 .Case<BundleType>([&](
auto type) {
1468 return type.getFieldID(op.getIndex());
1470 unifyTypes(
FieldRef(op.getResult(), 0),
1471 FieldRef(op.getInput(), fieldID), op.getType());
1475 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1476 auto lhs = getExpr(op.getLhs());
1477 auto rhs = getExpr(op.getRhs());
1478 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1479 setExpr(op.getResult(), e);
1481 .Case<MulPrimOp>([&](
auto op) {
1482 auto lhs = getExpr(op.getLhs());
1483 auto rhs = getExpr(op.getRhs());
1484 auto e = solver.add(lhs, rhs);
1485 setExpr(op.getResult(), e);
1487 .Case<DivPrimOp>([&](
auto op) {
1488 auto lhs = getExpr(op.getLhs());
1490 if (op.getType().base().isSigned()) {
1491 e = solver.add(lhs, solver.known(1));
1495 setExpr(op.getResult(), e);
1497 .Case<RemPrimOp>([&](
auto op) {
1498 auto lhs = getExpr(op.getLhs());
1499 auto rhs = getExpr(op.getRhs());
1500 auto e = solver.min(lhs, rhs);
1501 setExpr(op.getResult(), e);
1503 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1504 auto lhs = getExpr(op.getLhs());
1505 auto rhs = getExpr(op.getRhs());
1506 auto e = solver.max(lhs, rhs);
1507 setExpr(op.getResult(), e);
1511 .Case<CatPrimOp>([&](
auto op) {
1512 auto lhs = getExpr(op.getLhs());
1513 auto rhs = getExpr(op.getRhs());
1514 auto e = solver.add(lhs, rhs);
1515 setExpr(op.getResult(), e);
1517 .Case<DShlPrimOp>([&](
auto op) {
1518 auto lhs = getExpr(op.getLhs());
1519 auto rhs = getExpr(op.getRhs());
1520 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1521 setExpr(op.getResult(), e);
1523 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1524 auto e = getExpr(op.getLhs());
1525 setExpr(op.getResult(), e);
1529 .Case<NegPrimOp>([&](
auto op) {
1530 auto input = getExpr(op.getInput());
1531 auto e = solver.add(input, solver.known(1));
1532 setExpr(op.getResult(), e);
1534 .Case<CvtPrimOp>([&](
auto op) {
1535 auto input = getExpr(op.getInput());
1536 auto e = op.getInput().getType().base().isSigned()
1538 : solver.add(input, solver.known(1));
1539 setExpr(op.getResult(), e);
1543 .Case<BitsPrimOp>([&](
auto op) {
1544 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1546 .Case<HeadPrimOp>([&](
auto op) {
1547 setExpr(op.getResult(), solver.known(op.getAmount()));
1549 .Case<TailPrimOp>([&](
auto op) {
1550 auto input = getExpr(op.getInput());
1551 auto e = solver.add(input, solver.known(-op.getAmount()));
1552 setExpr(op.getResult(), e);
1554 .Case<PadPrimOp>([&](
auto op) {
1555 auto input = getExpr(op.getInput());
1556 auto e = solver.max(input, solver.known(op.getAmount()));
1557 setExpr(op.getResult(), e);
1559 .Case<ShlPrimOp>([&](
auto op) {
1560 auto input = getExpr(op.getInput());
1561 auto e = solver.add(input, solver.known(op.getAmount()));
1562 setExpr(op.getResult(), e);
1564 .Case<ShrPrimOp>([&](
auto op) {
1565 auto input = getExpr(op.getInput());
1567 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1568 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1569 solver.known(minWidth));
1570 setExpr(op.getResult(), e);
1574 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1575 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1576 .Case<mlir::UnrealizedConversionCastOp>(
1577 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1581 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1582 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1583 XorRPrimOp>([&](
auto op) {
1584 auto width = op.getType().getBitWidthOrSentinel();
1585 assert(
width > 0 &&
"width should have been checked by verifier");
1586 setExpr(op.getResult(), solver.known(
width));
1588 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1589 auto *sel = getExpr(op.getSel());
1590 constrainTypes(solver.known(1), sel,
true);
1591 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1593 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1594 auto *sel = getExpr(op.getSel());
1595 constrainTypes(solver.known(2), sel,
true);
1596 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1597 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1598 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1601 .Case<ConnectOp, MatchingConnectOp>(
1602 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1603 .Case<RefDefineOp>([&](
auto op) {
1606 constrainTypes(op.getDest(), op.getSrc(),
true);
1609 .Case<MatchingConnectOp>([&](
auto op) {
1614 constrainTypes(op.getDest(), op.getSrc());
1615 constrainTypes(op.getSrc(), op.getDest());
1617 .Case<AttachOp>([&](
auto op) {
1621 if (op.getAttached().empty())
1623 auto prev = op.getAttached()[0];
1624 for (
auto operand : op.getAttached().drop_front()) {
1625 auto e1 = getExpr(prev);
1626 auto e2 = getExpr(operand);
1627 constrainTypes(e1, e2,
true);
1628 constrainTypes(e2, e1,
true);
1634 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1635 UnclockedAssumeIntrinsicOp, CoverOp>([&](
auto) {})
1638 .Case<InstanceOp>([&](
auto op) {
1639 auto refdModule = op.getReferencedOperation(symtbl);
1640 auto module = dyn_cast<FModuleOp>(&*refdModule);
1642 auto diag = mlir::emitError(op.getLoc());
1643 diag <<
"extern module `" << op.getModuleName()
1644 <<
"` has ports of uninferred width";
1646 auto fml = cast<FModuleLike>(&*refdModule);
1647 auto ports = fml.getPorts();
1648 for (
auto &port : ports) {
1650 if (baseType && baseType.hasUninferredWidth()) {
1651 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1652 if (!baseType.isGround())
1657 diag.attachNote(op.getLoc())
1658 <<
"Only non-extern FIRRTL modules may contain unspecified "
1659 "widths to be inferred automatically.";
1660 diag.attachNote(refdModule->getLoc())
1661 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1662 mappingFailed =
true;
1669 for (
auto it : llvm::zip(op->getResults(), module.getArguments())) {
1671 type_cast<FIRRTLType>(std::get<0>(it).getType()));
1676 .Case<MemOp>([&](MemOp op) {
1678 unsigned nonDebugPort = 0;
1679 for (
const auto &result : llvm::enumerate(op.getResults())) {
1680 declareVars(result.value(), op.getLoc());
1681 if (!type_isa<RefType>(result.value().getType()))
1682 nonDebugPort = result.index();
1687 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1688 static const unsigned indices[] = {3, 5};
1689 static const unsigned debug[] = {0};
1691 case MemOp::PortKind::Read:
1692 case MemOp::PortKind::Write:
1693 return ArrayRef<unsigned>(indices, 1);
1694 case MemOp::PortKind::ReadWrite:
1695 return ArrayRef<unsigned>(indices);
1696 case MemOp::PortKind::Debug:
1697 return ArrayRef<unsigned>(
debug);
1699 llvm_unreachable(
"Imposible PortKind");
1706 unsigned firstFieldIndex =
1707 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1709 op.getResult(nonDebugPort),
1710 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1711 .getFieldID(firstFieldIndex));
1712 LLVM_DEBUG(llvm::dbgs() <<
"Adjusting memory port variables:\n");
1715 auto dataType = op.getDataType();
1716 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1717 auto result = op.getResult(i);
1718 if (type_isa<RefType>(result.getType())) {
1722 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1727 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1728 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1729 unifyTypes(
FieldRef(result, portType.getFieldID(fieldIndex)),
1730 firstData, dataType);
1734 .Case<RefSendOp>([&](
auto op) {
1735 declareVars(op.getResult(), op.getLoc());
1736 constrainTypes(op.getResult(), op.getBase(),
true);
1738 .Case<RefResolveOp>([&](
auto op) {
1739 declareVars(op.getResult(), op.getLoc());
1740 constrainTypes(op.getResult(), op.getRef(),
true);
1742 .Case<RefCastOp>([&](
auto op) {
1743 declareVars(op.getResult(), op.getLoc());
1744 constrainTypes(op.getResult(), op.getInput(),
true);
1746 .Case<RWProbeOp>([&](
auto op) {
1747 declareVars(op.getResult(), op.getLoc());
1748 auto ist = irn.lookup(op.getTarget());
1750 op->emitError(
"target of rwprobe could not be resolved");
1751 mappingFailed =
true;
1756 op->emitError(
"target of rwprobe resolved to unsupported target");
1757 mappingFailed =
true;
1761 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1762 unifyTypes(
FieldRef(op.getResult(), 0),
1763 FieldRef(ref.getValue(), newFID), op.getType());
1765 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1766 for (Value result : op.getResults()) {
1767 auto ty = result.getType();
1768 if (type_isa<FIRRTLType>(ty))
1769 declareVars(result, op.getLoc());
1772 .Default([&](
auto op) {
1773 op->emitOpError(
"not supported in width inference");
1774 mappingFailed =
true;
1778 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1782 return failure(mappingFailed);
/// Declare constraint expressions for the fields of `value`, dispatching on
/// the structure of its type (ground / bundle / vector / enum).
///
/// NOTE(review): this listing is lossy -- the recursive `declare` helper's
/// header, several guards and the closing braces are not visible. The
/// comments below describe only the statements that are shown.
1787 void InferenceMapping::declareVars(Value value, Location loc,
bool isDerived) {
// Running field ID within `value`; it is saved and advanced by the vector
// case below.
1790 unsigned fieldID = 0;
1792 auto width = type.getBitWidthOrSentinel();
1797 }
else if (
width == -1) {
// A width of -1 (presumably the "not yet inferred" sentinel) gets a fresh
// solver expression attached to this field.
1800 solver.setCurrentContextInfo(field);
// Both a derived and a plain variable expression are assigned in the visible
// code; the condition choosing between them is not shown (presumably
// `isDerived` -- TODO confirm).
1802 setExpr(field, solver.derived());
1804 setExpr(field, solver.var());
1806 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: recurse into the type of every element.
1809 for (
auto &element : bundleType) {
1810 declare(element.type);
1812 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: the element type is declared once, then `fieldID` is set to the
// saved start plus the vector's max field ID.
1814 auto save = fieldID;
1815 declare(vecType.getElementType());
1817 fieldID = save + vecType.getMaxFieldID();
1818 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into the type of every element.
1820 for (
auto &element : enumType.getElements())
1821 declare(element.type);
// Any other type is treated as impossible at this point.
1823 llvm_unreachable(
"Unknown type inside a bundle!");
/// For each ground field of the shared type, bind the field of `result` to a
/// `solver.max(...)` expression built from the corresponding operand fields.
///
/// NOTE(review): the recursive `maximize` helper's header, the `fieldID`
/// bookkeeping for ground fields and the second argument of `solver.max`
/// are not visible in this listing (presumably the `lhs` field -- TODO
/// confirm).
1833 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1837 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: recurse into each element type.
1839 for (
auto &element : bundleType.getElements())
1840 maximize(element.type);
1841 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: visit the element type only when the vector is non-empty, then
// set `fieldID` to the saved start plus the vector's max field ID.
1843 auto save = fieldID;
1845 if (vecType.getNumElements() > 0)
1846 maximize(vecType.getElementType());
1847 fieldID = save + vecType.getMaxFieldID();
1848 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into each element type.
1850 for (
auto &element : enumType.getElements())
1851 maximize(element.type);
1852 }
else if (type.isGround()) {
// Ground field: build the max expression and store it for the result field
// with the same field ID.
1853 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1855 setExpr(
FieldRef(result, fieldID), e);
// Any other type is treated as impossible at this point.
1858 llvm_unreachable(
"Unknown type inside a bundle!");
/// Walk the type shared by `larger` and `smaller` and, for every ground
/// field, forward the pair of field expressions to the `Expr *` overload of
/// constrainTypes, passing `equal` through.
///
/// NOTE(review): the recursive `constrain` helper's header, where `type` and
/// `fieldID` come from, and the closing braces are not visible in this
/// listing.
1873 void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1880 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: two recursive calls are visible, one with the operands swapped.
// The condition selecting between them is not shown (presumably the
// element's flip flag -- TODO confirm).
1882 for (
auto &element : bundleType.getElements()) {
1884 constrain(element.type, smaller, larger);
1886 constrain(element.type, larger, smaller);
1888 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: constrain the element type only for non-empty vectors, then set
// `fieldID` to the saved start plus the vector's max field ID.
1890 auto save = fieldID;
1892 if (vecType.getNumElements() > 0) {
1893 constrain(vecType.getElementType(), larger, smaller);
1895 fieldID = save + vecType.getMaxFieldID();
1896 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: recurse into each element type, operands in the original order.
1898 for (
auto &element : enumType.getElements())
1899 constrain(element.type, larger, smaller);
1900 }
else if (type.isGround()) {
// Ground field: delegate to the expression-level overload with
// imposeUpperBounds = false and the caller's `equal`.
1902 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1903 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
// Any other type is treated as impossible at this point.
1906 llvm_unreachable(
"Unknown type inside a bundle!");
// Start the recursive walk with the operands in the caller's order.
1911 constrain(type, larger, smaller);
/// Record a relation between two constraint expressions in the solver.
///
/// Both expressions must be non-null (asserted). Depending on the dynamic
/// kind of each expression, the visible code either records an assignment on
/// a DerivedExpr or adds >= / <= constraints on a VarExpr.
///
/// NOTE(review): returns, closing braces and the guard around the `<=`
/// constraint on `largerVar` are not visible in this listing.
1916 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1917 bool imposeUpperBounds,
bool equal) {
1918 assert(larger &&
"Larger expression should be specified");
1919 assert(smaller &&
"Smaller expression should be specified");
// A derived `larger` simply records `smaller` as the expression it is
// assigned from.
1925 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1926 largerDerived->assigned = smaller;
1927 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1928 << *smaller <<
"\n");
// Symmetrically, a derived `smaller` records `larger`.
1931 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1932 smallerDerived->assigned = larger;
1933 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1934 << *larger <<
"\n");
// A variable `larger` gets the lower bound larger >= smaller.
1940 if (
auto largerVar = dyn_cast<VarExpr>(larger)) {
1941 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1942 LLVM_DEBUG(llvm::dbgs()
1943 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
// An upper bound larger <= smaller is also added here; its guarding
// condition is not visible (presumably `equal` -- TODO confirm).
1950 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1951 LLVM_DEBUG(llvm::dbgs()
1952 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
// A variable `smaller` gets the upper bound smaller <= larger only when
// upper bounds were requested or the two sides must be equal.
1962 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1963 if (imposeUpperBounds || equal) {
1964 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1965 LLVM_DEBUG(llvm::dbgs()
1966 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
1986 LLVM_DEBUG(llvm::dbgs()
1987 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1990 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1991 solver.addGeqConstraint(var, solver.known(0));
1992 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1994 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1996 for (
auto &element : bundleType) {
1997 unify(element.type);
1999 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2001 auto save = fieldID;
2003 if (vecType.getNumElements() > 0) {
2004 unify(vecType.getElementType());
2006 fieldID = save + vecType.getMaxFieldID();
2007 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2009 for (
auto &element : enumType.getElements())
2010 unify(element.type);
2012 llvm_unreachable(
"Unknown type inside a bundle!");
/// Get the constraint expression for `value` by forwarding to the FieldRef
/// overload with field ID 0.
/// NOTE(review): lines between the signature and the return are not visible
/// in this listing.
2020 Expr *InferenceMapping::getExpr(Value value)
const {
2023 return getExpr(
FieldRef(value, 0));
/// Get the constraint expression for `fieldRef`, asserting that one has been
/// registered (lookup itself is done by getExprOrNull).
/// NOTE(review): the return statement is not visible in this listing.
2027 Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
2028 auto expr = getExprOrNull(fieldRef);
2029 assert(expr &&
"constraint expr should have been constructed for value");
/// Look up `fieldRef` in `opExprs`; returns the stored expression, or
/// nullptr when there is no entry for it.
2033 Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
2034 auto it = opExprs.find(fieldRef);
2035 return it != opExprs.end() ? it->second :
nullptr;
2039 void InferenceMapping::setExpr(Value value, Expr *expr) {
/// Associate `expr` with `fieldRef` in `opExprs`, replacing any expression
/// already stored for that field, after printing a debug trace.
/// NOTE(review): the wrapper around the debug output and the definition of
/// `fieldName` are not visible in this listing.
2046 void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
// Debug trace: the expression, the value it belongs to, and the field name.
2048 llvm::dbgs() <<
"Expr " << *expr <<
" for " << fieldRef.
getValue();
2050 llvm::dbgs() <<
" '" <<
getFieldName(fieldRef).first <<
"'";
2052 if (fieldName.second)
2053 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2054 llvm::dbgs() <<
"\n";
// Record (or overwrite) the mapping.
2056 opExprs[fieldRef] = expr;
/// Helper that rewrites types in the IR using the results held by an
/// InferenceMapping. It keeps only a const reference to the mapping, so the
/// mapping must outlive this object.
/// NOTE(review): access specifiers and the closing brace are not visible in
/// this listing.
2066 class InferenceTypeUpdate {
2068 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
// Entry point for a whole circuit.
2070 LogicalResult update(CircuitOp op);
// Per-operation and per-value updates; the bool presumably reports whether
// anything changed (see the `anyChanged` / `newType != type` uses in the
// definitions).
2071 FailureOr<bool> updateOperation(Operation *op);
2072 FailureOr<bool> updateValue(Value value);
2076 const InferenceMapping &mapping;
/// Update the types of every FModuleOp in the circuit, processing modules
/// through mlir::failableParallelForEach. Fails if updating any operation
/// fails.
/// NOTE(review): the action taken for skipped modules and the closing of the
/// lambda are not visible in this listing.
2082 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2084 llvm::dbgs() <<
"\n";
2087 return mlir::failableParallelForEach(
2088 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
// Modules the mapping marked as skipped are special-cased here.
2091 if (mapping.isModuleSkipped(op))
// Pre-order walk over the module; the walk is interrupted on the first
// operation that fails to update.
2093 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2094 if (failed(updateOperation(op)))
2095 return WalkResult::interrupt();
2096 return WalkResult::advance();
2097 }).wasInterrupted();
2098 return failure(isFailed);
/// Update the result types of `op`, then apply two op-specific fix-ups:
/// truncating the source of a ConnectOp whose destination is narrower, and
/// refreshing the port types of an FModuleOp.
///
/// NOTE(review): failure checks on `updateValue` results, early returns, the
/// final return and closing braces are not visible in this listing.
2103 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2104 bool anyChanged =
false;
// Update every result value, accumulating whether any of them changed.
2106 for (Value v : op->getResults()) {
2107 auto result = updateValue(v);
2110 anyChanged |= *result;
// Connects: when both sides are base types with known (non-negative) widths
// and the destination is narrower than the source, insert a truncation.
2116 if (
auto con = dyn_cast<ConnectOp>(op)) {
2117 auto lhs = con.getDest();
2118 auto rhs = con.getSrc();
2119 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2120 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2123 if (!lhsType || !rhsType)
2126 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2127 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2128 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
// The new ops are created immediately before the connect.
2129 OpBuilder builder(op);
// TailPrimOp is given the width difference as its amount.
2130 auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2131 rhsWidth - lhsWidth);
// For a signed source an AsSIntPrimOp to the destination type is created
// from the truncated value (the assignment target is not visible here).
2132 if (type_isa<SIntType>(rhsType))
2134 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2136 LLVM_DEBUG(llvm::dbgs()
2137 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
// Redirect the connect to read from the truncated value.
2138 con->replaceUsesOfWith(con.getSrc(), trunc);
// Modules: update each block argument and rewrite the port-types attribute.
2144 if (
auto module = dyn_cast<FModuleOp>(op)) {
2146 bool argsChanged =
false;
2147 SmallVector<Attribute> argTypes;
2148 argTypes.reserve(module.getNumPorts());
2149 for (
auto arg : module.getArguments()) {
2150 auto result = updateValue(arg);
2153 argsChanged |= *result;
2159 module->setAttr(FModuleLike::getPortTypesAttrName(),
2169 auto *context = type.getContext();
2171 .
Case<UIntType>([&](
auto type) {
2174 .Case<SIntType>([&](
auto type) {
2177 .Case<AnalogType>([&](
auto type) {
2180 .Default([&](
auto type) {
return type; });
/// Update the type of a single value and report whether it changed
/// (`newType != type`).
///
/// Values defined by an op implementing InferTypeOpInterface have their
/// result types recomputed through `inferReturnTypes`; otherwise the visible
/// code walks the type structure (bundle / vector / enum) and rebuilds it
/// from updated element types.
///
/// NOTE(review): this listing is lossy -- early returns, failure checks, the
/// recursive `updateBase` helper's header, the construction of `newType`
/// and closing braces are not visible. Comments cover shown statements only.
2184 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2186 auto type = type_dyn_cast<FIRRTLType>(value.getType());
// Ops that can infer their own result types: recompute the types from the
// op's operands, attributes, properties and regions, then assign them.
2197 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2198 SmallVector<Type, 2> types;
2200 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2201 op->getAttrDictionary(), op->getPropertiesStorage(),
2202 op->getRegions(), types);
// One inferred type is expected per result.
2206 assert(types.size() == op->getNumResults());
2207 for (
auto it : llvm::zip(op->getResults(), types)) {
2208 LLVM_DEBUG(llvm::dbgs() <<
"Inferring " << std::get<0>(it) <<
" as "
2209 << std::get<1>(it) <<
"\n");
2210 std::get<0>(it).setType(std::get<1>(it));
2216 auto context = type.getContext();
// Running field ID used while walking the type structure.
2217 unsigned fieldID = 0;
2220 auto width = type.getBitWidthOrSentinel();
2225 }
else if (
width == -1) {
2230 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
// Bundles: rebuild the element list, keeping each element's name and flip
// flag and substituting the updated element type.
2233 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2234 for (
auto &element : bundleType) {
2235 auto updatedBase = updateBase(element.type);
2238 elements.emplace_back(element.name, element.isFlip, updatedBase);
2241 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
// Vectors: update the element type only for non-empty vectors, then set
// `fieldID` to the saved start plus the vector's max field ID.
2243 auto save = fieldID;
2246 if (vecType.getNumElements() > 0) {
2247 auto updatedBase = updateBase(vecType.getElementType());
2252 fieldID = save + vecType.getMaxFieldID();
2257 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
// Enums: rebuild the element list with updated element types.
2259 llvm::SmallVector<FEnumType::EnumElement> elements;
2260 for (
auto &element : enumType.getElements()) {
2261 auto updatedBase = updateBase(element.type);
2264 elements.emplace_back(element.name, updatedBase);
// Any other type is treated as impossible at this point.
2268 llvm_unreachable(
"Unknown type inside a bundle!");
// Install the rebuilt type on the value.
2275 LLVM_DEBUG(llvm::dbgs() <<
"Update " << value <<
" to " << newType <<
"\n");
2276 value.setType(newType);
// Constants: if the stored value has more bits than the op's type, truncate
// it to that bit width (where the truncated value is stored back is not
// visible here).
2281 if (
auto op = value.getDefiningOp<ConstantOp>()) {
2282 auto k = op.getValue();
2283 auto bitwidth = op.getType().getBitWidthOrSentinel();
2284 if (k.getBitWidth() >
unsigned(bitwidth))
2285 k = k.trunc(bitwidth);
2289 return newType != type;
2295 assert(type.isGround() &&
"Can only pass in ground types.");
2298 Expr *expr = mapping.getExprOrNull(fieldRef);
2299 if (!expr || !expr->solution) {
2304 mlir::emitError(value.getLoc(),
"width should have been inferred");
2307 int32_t solution = *expr->solution;
/// DenseMapInfo specialization for InternedSlot<T>.
///
/// NOTE(review): the member function signatures (see the getEmptyKey /
/// getTombstoneKey / getHashValue / isEqual names in the surrounding
/// listing), the `Slot` alias, the hash function and the closing braces are
/// not visible in this listing; the comments mark where each visible body
/// appears to start.
2319 template <
typename T>
2320 struct DenseMapInfo<InternedSlot<T>> {
// Empty key: the `void *` empty-key sentinel cast to `T *` and wrapped in a
// Slot.
2323 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2324 return Slot(
static_cast<T *
>(pointer));
// Tombstone key: same scheme, using the `void *` tombstone sentinel.
2327 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2328 return Slot(
static_cast<T *
>(pointer));
// Equality: if either side holds a sentinel pointer, compare the pointers
// themselves so a sentinel is never dereferenced; otherwise compare the
// pointed-to objects with `==`.
2332 auto empty = getEmptyKey().ptr;
2333 auto tombstone = getTombstoneKey().ptr;
2334 if (LHS.ptr ==
empty || RHS.ptr ==
empty || LHS.ptr == tombstone ||
2335 RHS.ptr == tombstone)
2336 return LHS.ptr == RHS.ptr;
2337 return *LHS.ptr == *RHS.ptr;
/// The width inference pass. Derives from the generated InferWidthsBase
/// class and overrides runOnOperation.
/// NOTE(review): the closing of the class is not visible in this listing.
2347 class InferWidthsPass
2348 :
public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2349 void runOnOperation()
override;
/// Run width inference on the pass's root operation in three visible stages:
/// map the IR to constraints, solve them, then write the results back to
/// the IR types. Each stage signals pass failure when it fails.
/// NOTE(review): the early `return`s that presumably follow each failure /
/// the all-skipped case, and the closing braces, are not visible in this
/// listing.
2353 void InferWidthsPass::runOnOperation() {
// Stage 1: collect constraints, using the symbol table and inner symbol
// table analyses.
2355 ConstraintSolver solver;
2356 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2357 getAnalysis<hw::InnerSymbolTableCollection>());
2358 if (failed(mapping.map(getOperation()))) {
2359 signalPassFailure();
// If every module was skipped, mark all analyses as preserved.
2362 if (mapping.areAllModulesSkipped()) {
2363 markAllAnalysesPreserved();
// Stage 2: solve the collected constraints.
2368 if (failed(solver.solve())) {
2369 signalPassFailure();
// Stage 3: apply the mapping's results to the types in the IR.
2374 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2375 signalPassFailure();
2379 return std::make_unique<InferWidthsPass>();
assert(baseType &&"element must be base type")
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundle.
Value getValue() const
Get the Value which created this location.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
llvm::hash_code hash_value(const T &e)
static Slot getEmptyKey()
static unsigned getHashValue(Slot val)
static Slot getTombstoneKey()
static bool isEqual(Slot LHS, Slot RHS)