22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/Threading.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/DenseSet.h"
26 #include "llvm/ADT/GraphTraits.h"
27 #include "llvm/ADT/Hashing.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/PostOrderIterator.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/ErrorHandling.h"
34 #define DEBUG_TYPE "infer-widths"
36 using mlir::InferTypeOpInterface;
37 using mlir::WalkOrder;
39 using namespace circt;
40 using namespace firrtl;
48 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
51 if (!basetype.hasUninferredWidth())
54 if (basetype.isGround())
55 diag.attachNote() <<
"Field: \"" << str <<
"\"";
56 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
58 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
59 for (
auto &elem : bundleType.getElements())
68 return TypeSwitch<Operation *, FieldRef>(ist.getOp())
69 .Case<FModuleOp>([&](
auto fmod) {
70 return FieldRef(fmod.getArgument(ist.getPort()), ist.getField());
75 auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(ist.getOp());
76 assert(symOp && symOp.getTargetResultIndex() &&
77 (symOp.supportsPerFieldSymbols() || ist.getField() == 0));
78 return FieldRef(symOp.getTargetResult(), ist.getField());
83 uint64_t convertedFieldID = 0;
85 auto curFID = fieldID;
90 if (isa<FVectorType>(curFType))
93 convertedFieldID += curFID - subID;
98 return convertedFieldID;
110 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
112 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
119 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
122 return e.hash_value();
// X-macro enumerating every constraint-expression kind, in declaration order.
// `EXPR_NAMES(suffix)` token-pastes `suffix` onto each kind name, so:
//   EXPR_KINDS   -> Root, Var, Derived, Id, Known, Add, Pow, Max, Min
//   EXPR_CLASSES -> RootExpr, VarExpr, ..., MinExpr
#define EXPR_NAMES(suffix)                                                     \
  Root##suffix, Var##suffix, Derived##suffix, Id##suffix, Known##suffix,       \
      Add##suffix, Pow##suffix, Max##suffix, Min##suffix
#define EXPR_KINDS EXPR_NAMES()
#define EXPR_CLASSES EXPR_NAMES(Expr)
135 std::optional<int32_t> solution;
139 void print(llvm::raw_ostream &os)
const;
142 Expr(Kind kind) : kind(kind) {}
147 template <
class DerivedT, Expr::Kind DerivedKind>
148 struct ExprBase :
public Expr {
149 ExprBase() : Expr(DerivedKind) {}
150 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
152 if (
auto otherSame = dyn_cast<DerivedT>(other))
153 return *
static_cast<DerivedT *
>(
this) == otherSame;
159 struct RootExpr :
public ExprBase<RootExpr, Expr::Kind::Root> {
160 RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
161 void print(llvm::raw_ostream &os)
const { os <<
"root"; }
162 std::vector<Expr *> &exprs;
166 struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
167 void print(llvm::raw_ostream &os)
const {
170 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
175 Expr *constraint =
nullptr;
178 Expr *upperBound =
nullptr;
179 std::optional<int32_t> upperBoundSolution;
186 struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
187 void print(llvm::raw_ostream &os)
const {
190 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
194 Expr *assigned =
nullptr;
211 struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
212 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
213 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
215 return kind == other.kind && arg == other.arg;
226 struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
227 KnownExpr(int32_t value) : ExprBase() { solution = value; }
228 void print(llvm::raw_ostream &os)
const { os << *solution; }
229 bool operator==(
const KnownExpr &other)
const {
230 return *solution == *other.solution;
239 struct UnaryExpr :
public Expr {
240 bool operator==(
const UnaryExpr &other)
const {
241 return kind == other.kind && arg == other.arg;
251 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
255 template <
class DerivedT, Expr::Kind DerivedKind>
256 struct UnaryExprBase :
public UnaryExpr {
257 template <
typename... Args>
258 UnaryExprBase(Args &&...args)
259 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
260 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
264 struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
265 using UnaryExprBase::UnaryExprBase;
266 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
271 struct BinaryExpr :
public Expr {
272 bool operator==(
const BinaryExpr &other)
const {
273 return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
278 Expr *lhs()
const {
return args[0]; }
279 Expr *rhs()
const {
return args[1]; }
285 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
292 template <
class DerivedT, Expr::Kind DerivedKind>
293 struct BinaryExprBase :
public BinaryExpr {
294 template <
typename... Args>
295 BinaryExprBase(Args &&...args)
296 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
297 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
301 struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
302 using BinaryExprBase::BinaryExprBase;
303 void print(llvm::raw_ostream &os)
const {
304 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
309 struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
310 using BinaryExprBase::BinaryExprBase;
311 void print(llvm::raw_ostream &os)
const {
312 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
317 struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
318 using BinaryExprBase::BinaryExprBase;
319 void print(llvm::raw_ostream &os)
const {
320 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
324 void Expr::print(llvm::raw_ostream &os)
const {
326 [&](
auto *e) { e->print(os); });
338 template <
typename T>
339 struct InternedSlot {
341 InternedSlot(T *ptr) : ptr(ptr) {}
346 template <
typename T,
typename std::enable_if_t<
347 std::is_trivially_destructible<T>::value,
int> = 0>
348 class InternedAllocator {
349 using Slot = InternedSlot<T>;
350 llvm::DenseSet<Slot> interned;
351 llvm::BumpPtrAllocator &allocator;
354 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
359 template <
typename R = T,
typename... Args>
360 std::pair<R *, bool> alloc(Args &&...args) {
361 auto stack_value = R(std::forward<Args>(args)...);
362 auto stack_slot = Slot(&stack_value);
363 auto it = interned.find(stack_slot);
364 if (it != interned.end())
365 return std::make_pair(
static_cast<R *
>(it->ptr),
false);
366 auto heap_value =
new (allocator) R(std::move(stack_value));
367 interned.insert(Slot(heap_value));
368 return std::make_pair(heap_value,
true);
374 template <
typename T,
typename std::enable_if_t<
375 std::is_trivially_destructible<T>::value,
int> = 0>
377 llvm::BumpPtrAllocator &allocator;
380 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
384 template <
typename R = T,
typename... Args>
385 R *alloc(Args &&...args) {
386 return new (allocator) R(std::forward<Args>(args)...);
417 int32_t rec_scale = 0;
418 int32_t rec_bias = 0;
419 int32_t nonrec_bias = 0;
423 static LinIneq unsat() {
return LinIneq(
true); }
426 explicit LinIneq(
bool failed =
false) : failed(failed) {}
429 explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}
432 explicit LinIneq(int32_t scale, int32_t bias) {
443 explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
446 if (rec_scale != 0) {
447 this->rec_scale = rec_scale;
448 this->rec_bias = rec_bias;
449 this->nonrec_bias = nonrec_bias;
451 this->nonrec_bias = std::max(rec_bias, nonrec_bias);
462 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
463 return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
464 std::max(lhs.rec_bias, rhs.rec_bias),
465 std::max(lhs.nonrec_bias, rhs.nonrec_bias),
466 lhs.failed || rhs.failed);
522 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
525 auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
526 auto enable2 = lhs.rec_scale > 0;
527 auto enable3 = rhs.rec_scale > 0;
528 auto scale1 = lhs.rec_scale + rhs.rec_scale;
529 auto scale2 = lhs.rec_scale;
530 auto scale3 = rhs.rec_scale;
531 auto bias1 = lhs.rec_bias + rhs.rec_bias;
532 auto bias2 = lhs.rec_bias + rhs.nonrec_bias;
533 auto bias3 = rhs.rec_bias + lhs.nonrec_bias;
534 auto maxScale = std::max(scale1, std::max(scale2, scale3));
538 std::optional<int32_t> maxBias;
539 if (enable1 && scale1 == maxScale)
541 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
543 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
548 auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias;
549 auto failed = lhs.failed || rhs.failed;
550 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
551 return LinIneq(scale1, bias1, nonrec_bias, failed);
552 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
553 return LinIneq(scale2, bias2, nonrec_bias, failed);
554 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
555 return LinIneq(scale3, bias3, nonrec_bias, failed);
556 return LinIneq(0, 0, nonrec_bias, failed);
569 if (rec_scale == 1 && rec_bias > 0)
575 void print(llvm::raw_ostream &os)
const {
577 bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
581 if (rec_scale != 0) {
584 os << rec_scale <<
"*";
590 os <<
" - " << -rec_bias;
592 os <<
" + " << rec_bias;
600 if (nonrec_bias != 0) {
617 class ConstraintSolver {
619 ConstraintSolver() =
default;
622 auto v = vars.alloc();
625 info[v].insert(currentInfo);
627 locs[v].insert(*currentLoc);
630 DerivedExpr *derived() {
631 auto *d = derivs.alloc();
635 KnownExpr *known(int32_t value) {
return alloc<KnownExpr>(knowns, value); }
636 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
637 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
638 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
639 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
640 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
644 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
646 lhs->constraint = max(lhs->constraint, rhs);
648 lhs->constraint = id(rhs);
649 return lhs->constraint;
654 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
656 lhs->upperBound = min(lhs->upperBound, rhs);
658 lhs->upperBound = id(rhs);
659 return lhs->upperBound;
662 void dumpConstraints(llvm::raw_ostream &os);
663 LogicalResult solve();
665 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
666 const ContextInfo &getContextInfo()
const {
return info; }
667 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
668 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
672 llvm::BumpPtrAllocator allocator;
673 Allocator<VarExpr> vars = {allocator};
674 Allocator<DerivedExpr> derivs = {allocator};
675 InternedAllocator<KnownExpr> knowns = {allocator};
676 InternedAllocator<IdExpr> ids = {allocator};
677 InternedAllocator<UnaryExpr> uns = {allocator};
678 InternedAllocator<BinaryExpr> bins = {allocator};
681 std::vector<Expr *> exprs;
682 RootExpr root = {exprs};
685 template <
typename R,
typename T,
typename... Args>
686 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
687 auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
689 exprs.push_back(it.first);
691 info[it.first].insert(currentInfo);
693 locs[it.first].insert(*currentLoc);
701 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
702 std::optional<Location> currentLoc;
706 ConstraintSolver(ConstraintSolver &&) =
delete;
707 ConstraintSolver(
const ConstraintSolver &) =
delete;
708 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
709 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
711 void emitUninferredWidthError(VarExpr *var);
713 LinIneq checkCycles(VarExpr *var, Expr *expr,
714 SmallPtrSetImpl<Expr *> &seenVars,
715 InFlightDiagnostic *reportInto =
nullptr,
716 unsigned indent = 1);
722 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
723 for (
auto *e : exprs) {
724 if (
auto *v = dyn_cast<VarExpr>(e)) {
726 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
728 os <<
"- " << *v <<
" unconstrained\n";
734 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
750 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
751 SmallPtrSetImpl<Expr *> &seenVars,
752 InFlightDiagnostic *reportInto,
755 TypeSwitch<Expr *, LinIneq>(expr)
756 .Case<KnownExpr>([&](
auto *expr) {
return LinIneq(*expr->solution); })
757 .Case<VarExpr>([&](
auto *expr) {
759 return LinIneq(1, 0);
760 if (!seenVars.insert(expr).second)
767 if (!expr->constraint)
770 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
772 seenVars.erase(expr);
775 .Case<IdExpr>([&](
auto *expr) {
776 return checkCycles(var, expr->arg, seenVars, reportInto,
779 .Case<PowExpr>([&](
auto *expr) {
784 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
785 if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
786 arg.nonrec_bias >= 31)
787 return LinIneq::unsat();
788 return LinIneq(1 << arg.nonrec_bias);
790 .Case<AddExpr>([&](
auto *expr) {
792 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
793 checkCycles(var, expr->rhs(), seenVars, reportInto,
796 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
801 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
802 checkCycles(var, expr->rhs(), seenVars, reportInto,
805 .Default([](
auto) {
return LinIneq::unsat(); });
810 if (reportInto && !ineq.sat()) {
811 auto report = [&](Location loc) {
812 auto ¬e = reportInto->attachNote(loc);
813 note <<
"constrained width W >= ";
814 if (ineq.rec_scale == -1)
816 if (ineq.rec_scale != 1)
817 note << ineq.rec_scale;
819 if (ineq.rec_bias < 0)
820 note <<
"-" << -ineq.rec_bias;
821 if (ineq.rec_bias > 0)
822 note <<
"+" << ineq.rec_bias;
825 auto it = locs.find(expr);
826 if (it != locs.end())
827 for (
auto loc : it->second)
831 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
832 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
842 arg.first = operation(*arg.first);
848 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
849 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
850 if (lhs.first && rhs.first)
851 result.first = operation(*lhs.first, *rhs.first);
853 result.first = lhs.first;
855 result.first = rhs.first;
865 unsigned defaultWorklistSize) {
874 std::vector<Frame> worklist({{expr, indent}});
875 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
878 worklist.reserve(defaultWorklistSize);
880 while (!worklist.empty()) {
881 auto &frame = worklist.back();
884 if (solution.first && !solution.second)
885 frame.expr->solution = *solution.first;
886 solvedExprs[frame.expr] = solution;
890 if (!isa<KnownExpr>(frame.expr)) {
892 llvm::dbgs().indent(frame.indent * 2)
893 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
895 llvm::dbgs().indent(frame.indent * 2)
896 <<
"= Skipped " << *frame.expr;
897 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
906 if (frame.expr->solution) {
908 if (!isa<KnownExpr>(frame.expr))
909 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
910 << *frame.expr->solution <<
"\n";
912 setSolution(
ExprSolution{*frame.expr->solution,
false});
918 if (!isa<KnownExpr>(frame.expr))
919 llvm::dbgs().indent(frame.indent * 2)
920 <<
"- Solving " << *frame.expr <<
"\n";
923 TypeSwitch<Expr *>(frame.expr)
924 .Case<KnownExpr>([&](
auto *expr) {
927 .Case<VarExpr>([&](
auto *expr) {
928 if (solvedExprs.contains(expr->constraint)) {
929 auto solution = solvedExprs[expr->constraint];
936 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
937 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
938 seenVars.erase(expr);
940 if (solution.first && *solution.first < 0)
942 return setSolution(solution);
946 if (!expr->constraint)
951 if (!seenVars.insert(expr).second)
954 worklist.push_back({expr->constraint, indent + 1});
955 if (expr->upperBound)
956 worklist.push_back({expr->upperBound, indent + 1});
958 .Case<IdExpr>([&](
auto *expr) {
959 if (solvedExprs.contains(expr->arg))
960 return setSolution(solvedExprs[expr->arg]);
961 worklist.push_back({expr->arg, indent + 1});
963 .Case<PowExpr>([&](
auto *expr) {
964 if (solvedExprs.contains(expr->arg))
966 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
968 worklist.push_back({expr->arg, indent + 1});
970 .Case<AddExpr>([&](
auto *expr) {
971 if (solvedExprs.contains(expr->lhs()) &&
972 solvedExprs.contains(expr->rhs()))
974 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
975 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
977 worklist.push_back({expr->lhs(), indent + 1});
978 worklist.push_back({expr->rhs(), indent + 1});
980 .Case<MaxExpr>([&](
auto *expr) {
981 if (solvedExprs.contains(expr->lhs()) &&
982 solvedExprs.contains(expr->rhs()))
984 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
985 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
987 worklist.push_back({expr->lhs(), indent + 1});
988 worklist.push_back({expr->rhs(), indent + 1});
990 .Case<MinExpr>([&](
auto *expr) {
991 if (solvedExprs.contains(expr->lhs()) &&
992 solvedExprs.contains(expr->rhs()))
994 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
995 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
997 worklist.push_back({expr->lhs(), indent + 1});
998 worklist.push_back({expr->rhs(), indent + 1});
1000 .Default([&](
auto) {
1005 return solvedExprs[expr];
1011 LogicalResult ConstraintSolver::solve() {
1013 llvm::dbgs() <<
"\n";
1015 dumpConstraints(llvm::dbgs());
1020 llvm::dbgs() <<
"\n";
1021 debugHeader(
"Checking for unbreakable loops") <<
"\n\n";
1023 SmallPtrSet<Expr *, 16> seenVars;
1024 bool anyFailed =
false;
1026 for (
auto *expr : exprs) {
1028 auto *var = dyn_cast<VarExpr>(expr);
1029 if (!var || !var->constraint)
1031 LLVM_DEBUG(llvm::dbgs()
1032 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1037 seenVars.insert(var);
1038 auto ineq = checkCycles(var, var->constraint, seenVars);
1047 LLVM_DEBUG(llvm::dbgs()
1048 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1056 LLVM_DEBUG(llvm::dbgs()
1057 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1059 for (
auto fieldRef : info.find(var)->second) {
1062 auto op = fieldRef.getDefiningOp();
1063 auto diag = op ? op->emitOpError()
1064 : mlir::emitError(fieldRef.getValue().getLoc())
1066 diag <<
"is constrained to be wider than itself";
1069 seenVars.insert(var);
1070 checkCycles(var, var->constraint, seenVars, &diag);
1082 llvm::dbgs() <<
"\n";
1085 unsigned defaultWorklistSize = exprs.size() / 2;
1086 for (
auto *expr : exprs) {
1088 auto *var = dyn_cast<VarExpr>(expr);
1093 if (!var->constraint) {
1094 LLVM_DEBUG(llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1095 emitUninferredWidthError(var);
1101 LLVM_DEBUG(llvm::dbgs()
1102 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1103 seenVars.insert(var);
1104 auto solution =
solveExpr(var->constraint, seenVars, defaultWorklistSize);
1106 if (var->upperBound && !var->upperBoundSolution)
1107 var->upperBoundSolution =
1108 solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1112 if (solution.first && *solution.first < 0)
1114 var->solution = solution.first;
1118 if (!solution.first) {
1119 LLVM_DEBUG(llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1120 emitUninferredWidthError(var);
1124 LLVM_DEBUG(llvm::dbgs()
1125 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1126 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1129 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1130 LLVM_DEBUG(llvm::dbgs() <<
" ! Unsatisfiable " << *var
1131 <<
" <= " << var->upperBoundSolution <<
"\n");
1132 emitUninferredWidthError(var);
1138 for (
auto *expr : exprs) {
1140 auto *derived = dyn_cast<DerivedExpr>(expr);
1144 auto *assigned = derived->assigned;
1145 if (!assigned || !assigned->solution) {
1146 LLVM_DEBUG(llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1147 derived->solution = 0;
1149 LLVM_DEBUG(llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1150 << assigned->solution <<
"\n");
1151 derived->solution = *assigned->solution;
1155 return failure(anyFailed);
1160 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1161 FieldRef fieldRef = info.find(var)->second.back();
1164 auto diag = mlir::emitError(value.getLoc(),
"uninferred width:");
1167 if (isa<BlockArgument>(value)) {
1169 }
else if (
auto op = value.getDefiningOp()) {
1170 TypeSwitch<Operation *>(op)
1171 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1172 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1173 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1174 .Default([&](
auto) { diag <<
" value"; });
1181 if (!fieldName.empty()) {
1184 diag <<
" \"" << fieldName <<
"\"";
1187 if (!var->constraint) {
1188 diag <<
" is unconstrained";
1189 }
else if (var->solution && var->upperBoundSolution &&
1190 var->solution > var->upperBoundSolution) {
1191 diag <<
" cannot satisfy all width requirements";
1192 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1193 LLVM_DEBUG(llvm::dbgs() << *var->upperBound <<
"\n");
1194 auto loc = locs.find(var->constraint)->second.back();
1195 diag.attachNote(loc) <<
"width is constrained to be at least "
1196 << *var->solution <<
" here:";
1197 loc = locs.find(var->upperBound)->second.back();
1198 diag.attachNote(loc) <<
"width is constrained to be at most "
1199 << *var->upperBoundSolution <<
" here:";
1201 diag <<
" width cannot be determined";
1202 LLVM_DEBUG(llvm::dbgs() << *var->constraint <<
"\n");
1203 auto loc = locs.find(var->constraint)->second.back();
1204 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1216 class InferenceMapping {
1218 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1219 hw::InnerSymbolTableCollection &istc)
1220 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1222 LogicalResult map(CircuitOp op);
1223 LogicalResult mapOperation(Operation *op);
1228 void declareVars(Value value, Location loc,
bool isDerived =
false);
1231 Expr *declareVar(
FieldRef fieldRef, Location loc);
1235 Expr *declareVar(
FIRRTLType type, Location loc);
1240 void maximumOfTypes(Value result, Value rhs, Value lhs);
1244 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1248 void constrainTypes(Expr *larger, Expr *smaller,
1249 bool imposeUpperBounds =
false,
bool equal =
false);
1258 Expr *getExpr(Value value)
const;
1261 Expr *getExpr(
FieldRef fieldRef)
const;
1265 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1269 void setExpr(Value value, Expr *expr);
1272 void setExpr(
FieldRef fieldRef, Expr *expr);
1275 bool isModuleSkipped(FModuleOp module)
const {
1276 return skippedModules.count(module);
1280 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1284 ConstraintSolver &solver;
1287 DenseMap<FieldRef, Expr *> opExprs;
1290 SmallPtrSet<Operation *, 16> skippedModules;
1291 bool allModulesSkipped =
true;
1294 SymbolTable &symtbl;
1297 hw::InnerRefNamespace irn;
1304 return TypeSwitch<Type, bool>(type)
1305 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1307 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1308 .Default([](
auto) {
return false; });
1311 LogicalResult InferenceMapping::map(CircuitOp op) {
1312 LLVM_DEBUG(llvm::dbgs()
1313 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1316 for (
auto module : op.getOps<FModuleOp>())
1317 for (
auto arg : module.getArguments()) {
1318 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1319 declareVars(arg, module.getLoc());
1322 for (
auto module : op.getOps<FModuleOp>()) {
1325 bool anyUninferred =
false;
1326 for (
auto arg : module.getArguments()) {
1331 module.walk([&](Operation *op) {
1332 for (
auto type : op->getResultTypes())
1335 return WalkResult::interrupt();
1336 return WalkResult::advance();
1339 if (!anyUninferred) {
1340 LLVM_DEBUG(llvm::dbgs() <<
"Skipping fully-inferred module '"
1341 << module.getName() <<
"'\n");
1342 skippedModules.insert(module);
1346 allModulesSkipped =
false;
1350 auto result = module.getBodyBlock()->walk(
1351 [&](Operation *op) {
return WalkResult(mapOperation(op)); });
1352 if (result.wasInterrupted())
1359 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1365 bool allWidthsKnown =
true;
1366 for (
auto result : op->getResults()) {
1367 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1369 allWidthsKnown =
false;
1372 auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1376 declareVars(result, op->getLoc());
1378 allWidthsKnown =
false;
1380 if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1384 if (isa<PropAssignOp>(op))
1388 bool mappingFailed =
false;
1389 solver.setCurrentContextInfo(
1391 solver.setCurrentLocation(op->getLoc());
1392 TypeSwitch<Operation *>(op)
1393 .Case<ConstantOp>([&](
auto op) {
1397 if (
auto width = op.getType().base().getWidth())
1398 e = solver.known(*
width);
1400 auto v = op.getValue();
1401 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1402 : v.countLeadingZeros());
1405 e = solver.known(std::max(w, 1u));
1407 setExpr(op.getResult(), e);
1409 .Case<SpecialConstantOp>([&](
auto op) {
1412 .Case<InvalidValueOp>([&](
auto op) {
1417 declareVars(op.getResult(), op.getLoc(),
true);
1421 auto type = op.getType();
1422 ImplicitLocOpBuilder
builder(op->getLoc(), op);
1424 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1428 auto clone =
builder.create<InvalidValueOp>(type);
1429 declareVars(clone.getResult(), clone.getLoc(),
1434 .Case<WireOp, RegOp>(
1435 [&](
auto op) { declareVars(op.getResult(), op.getLoc()); })
1436 .Case<RegResetOp>([&](
auto op) {
1441 declareVars(op.getResult(), op.getLoc());
1444 constrainTypes(op.getResult(), op.getResetValue());
1446 .Case<NodeOp>([&](
auto op) {
1449 op.getResult().getType());
1453 .Case<SubfieldOp>([&](
auto op) {
1454 BundleType bundleType = op.getInput().getType();
1455 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1456 unifyTypes(
FieldRef(op.getResult(), 0),
1457 FieldRef(op.getInput(), fieldID), op.getType());
1459 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1465 .Case<SubtagOp>([&](
auto op) {
1466 FEnumType enumType = op.getInput().getType();
1467 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1468 unifyTypes(
FieldRef(op.getResult(), 0),
1469 FieldRef(op.getInput(), fieldID), op.getType());
1472 .Case<RefSubOp>([&](RefSubOp op) {
1473 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1474 op.getInput().getType().getType())
1475 .Case<FVectorType>([](
auto _) {
return 1; })
1476 .Case<BundleType>([&](
auto type) {
1477 return type.getFieldID(op.getIndex());
1479 unifyTypes(
FieldRef(op.getResult(), 0),
1480 FieldRef(op.getInput(), fieldID), op.getType());
1484 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1485 auto lhs = getExpr(op.getLhs());
1486 auto rhs = getExpr(op.getRhs());
1487 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1488 setExpr(op.getResult(), e);
1490 .Case<MulPrimOp>([&](
auto op) {
1491 auto lhs = getExpr(op.getLhs());
1492 auto rhs = getExpr(op.getRhs());
1493 auto e = solver.add(lhs, rhs);
1494 setExpr(op.getResult(), e);
1496 .Case<DivPrimOp>([&](
auto op) {
1497 auto lhs = getExpr(op.getLhs());
1499 if (op.getType().base().isSigned()) {
1500 e = solver.add(lhs, solver.known(1));
1504 setExpr(op.getResult(), e);
1506 .Case<RemPrimOp>([&](
auto op) {
1507 auto lhs = getExpr(op.getLhs());
1508 auto rhs = getExpr(op.getRhs());
1509 auto e = solver.min(lhs, rhs);
1510 setExpr(op.getResult(), e);
1512 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1513 auto lhs = getExpr(op.getLhs());
1514 auto rhs = getExpr(op.getRhs());
1515 auto e = solver.max(lhs, rhs);
1516 setExpr(op.getResult(), e);
1520 .Case<CatPrimOp>([&](
auto op) {
1521 auto lhs = getExpr(op.getLhs());
1522 auto rhs = getExpr(op.getRhs());
1523 auto e = solver.add(lhs, rhs);
1524 setExpr(op.getResult(), e);
1526 .Case<DShlPrimOp>([&](
auto op) {
1527 auto lhs = getExpr(op.getLhs());
1528 auto rhs = getExpr(op.getRhs());
1529 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1530 setExpr(op.getResult(), e);
1532 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1533 auto e = getExpr(op.getLhs());
1534 setExpr(op.getResult(), e);
1538 .Case<NegPrimOp>([&](
auto op) {
1539 auto input = getExpr(op.getInput());
1540 auto e = solver.add(input, solver.known(1));
1541 setExpr(op.getResult(), e);
1543 .Case<CvtPrimOp>([&](
auto op) {
1544 auto input = getExpr(op.getInput());
1545 auto e = op.getInput().getType().base().isSigned()
1547 : solver.add(input, solver.known(1));
1548 setExpr(op.getResult(), e);
1552 .Case<BitsPrimOp>([&](
auto op) {
1553 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1555 .Case<HeadPrimOp>([&](
auto op) {
1556 setExpr(op.getResult(), solver.known(op.getAmount()));
1558 .Case<TailPrimOp>([&](
auto op) {
1559 auto input = getExpr(op.getInput());
1560 auto e = solver.add(input, solver.known(-op.getAmount()));
1561 setExpr(op.getResult(), e);
1563 .Case<PadPrimOp>([&](
auto op) {
1564 auto input = getExpr(op.getInput());
1565 auto e = solver.max(input, solver.known(op.getAmount()));
1566 setExpr(op.getResult(), e);
1568 .Case<ShlPrimOp>([&](
auto op) {
1569 auto input = getExpr(op.getInput());
1570 auto e = solver.add(input, solver.known(op.getAmount()));
1571 setExpr(op.getResult(), e);
1573 .Case<ShrPrimOp>([&](
auto op) {
1574 auto input = getExpr(op.getInput());
1576 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1577 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1578 solver.known(minWidth));
1579 setExpr(op.getResult(), e);
1583 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1584 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1585 .Case<mlir::UnrealizedConversionCastOp>(
1586 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1590 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1591 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1592 XorRPrimOp>([&](
auto op) {
1593 auto width = op.getType().getBitWidthOrSentinel();
1594 assert(
width > 0 &&
"width should have been checked by verifier");
1595 setExpr(op.getResult(), solver.known(
width));
1597 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1598 auto *sel = getExpr(op.getSel());
1599 constrainTypes(solver.known(1), sel,
true);
1600 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1602 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1603 auto *sel = getExpr(op.getSel());
1604 constrainTypes(solver.known(2), sel,
true);
1605 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1606 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1607 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1610 .Case<ConnectOp, StrictConnectOp>(
1611 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1612 .Case<RefDefineOp>([&](
auto op) {
1615 constrainTypes(op.getDest(), op.getSrc(),
true);
1618 .Case<StrictConnectOp>([&](
auto op) {
1623 constrainTypes(op.getDest(), op.getSrc());
1624 constrainTypes(op.getSrc(), op.getDest());
1626 .Case<AttachOp>([&](
auto op) {
1630 if (op.getAttached().empty())
1632 auto prev = op.getAttached()[0];
1633 for (
auto operand : op.getAttached().drop_front()) {
1634 auto e1 = getExpr(prev);
1635 auto e2 = getExpr(operand);
1636 constrainTypes(e1, e2,
true);
1637 constrainTypes(e2, e1,
true);
1643 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1644 UnclockedAssumeIntrinsicOp, CoverOp>([&](
auto) {})
1647 .Case<InstanceOp>([&](
auto op) {
1648 auto refdModule = op.getReferencedOperation(symtbl);
1649 auto module = dyn_cast<FModuleOp>(&*refdModule);
1651 auto diag = mlir::emitError(op.getLoc());
1652 diag <<
"extern module `" << op.getModuleName()
1653 <<
"` has ports of uninferred width";
1655 auto fml = cast<FModuleLike>(&*refdModule);
1656 auto ports = fml.getPorts();
1657 for (
auto &port : ports) {
1659 if (baseType && baseType.hasUninferredWidth()) {
1660 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1661 if (!baseType.isGround())
1666 diag.attachNote(op.getLoc())
1667 <<
"Only non-extern FIRRTL modules may contain unspecified "
1668 "widths to be inferred automatically.";
1669 diag.attachNote(refdModule->getLoc())
1670 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1671 mappingFailed =
true;
1678 for (
auto it : llvm::zip(op->getResults(), module.getArguments())) {
1680 type_cast<FIRRTLType>(std::get<0>(it).getType()));
1685 .Case<MemOp>([&](MemOp op) {
1687 unsigned nonDebugPort = 0;
1688 for (
const auto &result : llvm::enumerate(op.getResults())) {
1689 declareVars(result.value(), op.getLoc());
1690 if (!type_isa<RefType>(result.value().getType()))
1691 nonDebugPort = result.index();
1696 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1697 static const unsigned indices[] = {3, 5};
1698 static const unsigned debug[] = {0};
1700 case MemOp::PortKind::Read:
1701 case MemOp::PortKind::Write:
1702 return ArrayRef<unsigned>(indices, 1);
1703 case MemOp::PortKind::ReadWrite:
1704 return ArrayRef<unsigned>(indices);
1705 case MemOp::PortKind::Debug:
1706 return ArrayRef<unsigned>(
debug);
1708 llvm_unreachable(
"Imposible PortKind");
1715 unsigned firstFieldIndex =
1716 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1718 op.getResult(nonDebugPort),
1719 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1720 .getFieldID(firstFieldIndex));
1721 LLVM_DEBUG(llvm::dbgs() <<
"Adjusting memory port variables:\n");
1724 auto dataType = op.getDataType();
1725 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1726 auto result = op.getResult(i);
1727 if (type_isa<RefType>(result.getType())) {
1731 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1736 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1737 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1738 unifyTypes(
FieldRef(result, portType.getFieldID(fieldIndex)),
1739 firstData, dataType);
1743 .Case<RefSendOp>([&](
auto op) {
1744 declareVars(op.getResult(), op.getLoc());
1745 constrainTypes(op.getResult(), op.getBase(),
true);
1747 .Case<RefResolveOp>([&](
auto op) {
1748 declareVars(op.getResult(), op.getLoc());
1749 constrainTypes(op.getResult(), op.getRef(),
true);
1751 .Case<RefCastOp>([&](
auto op) {
1752 declareVars(op.getResult(), op.getLoc());
1753 constrainTypes(op.getResult(), op.getInput(),
true);
1755 .Case<RWProbeOp>([&](
auto op) {
1756 declareVars(op.getResult(), op.getLoc());
1757 auto ist = irn.lookup(op.getTarget());
1759 op->emitError(
"target of rwprobe could not be resolved");
1760 mappingFailed =
true;
1765 op->emitError(
"target of rwprobe resolved to unsupported target");
1766 mappingFailed =
true;
1770 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1771 unifyTypes(
FieldRef(op.getResult(), 0),
1772 FieldRef(ref.getValue(), newFID), op.getType());
1774 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1775 for (Value result : op.getResults()) {
1776 auto ty = result.getType();
1777 if (type_isa<FIRRTLType>(ty))
1778 declareVars(result, op.getLoc());
1781 .Default([&](
auto op) {
1782 op->emitOpError(
"not supported in width inference");
1783 mappingFailed =
true;
1787 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1791 return failure(mappingFailed);
1796 void InferenceMapping::declareVars(Value value, Location loc,
bool isDerived) {
1799 unsigned fieldID = 0;
1801 auto width = type.getBitWidthOrSentinel();
1806 }
else if (
width == -1) {
1809 solver.setCurrentContextInfo(field);
1811 setExpr(field, solver.derived());
1813 setExpr(field, solver.var());
1815 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1818 for (
auto &element : bundleType) {
1819 declare(element.type);
1821 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1823 auto save = fieldID;
1824 declare(vecType.getElementType());
1826 fieldID = save + vecType.getMaxFieldID();
1827 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1829 for (
auto &element : enumType.getElements())
1830 declare(element.type);
1832 llvm_unreachable(
"Unknown type inside a bundle!");
1842 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1846 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1848 for (
auto &element : bundleType.getElements())
1849 maximize(element.type);
1850 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1852 auto save = fieldID;
1854 if (vecType.getNumElements() > 0)
1855 maximize(vecType.getElementType());
1856 fieldID = save + vecType.getMaxFieldID();
1857 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1859 for (
auto &element : enumType.getElements())
1860 maximize(element.type);
1861 }
else if (type.isGround()) {
1862 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1864 setExpr(
FieldRef(result, fieldID), e);
1867 llvm_unreachable(
"Unknown type inside a bundle!");
1882 void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1889 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1891 for (
auto &element : bundleType.getElements()) {
1893 constrain(element.type, smaller, larger);
1895 constrain(element.type, larger, smaller);
1897 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1899 auto save = fieldID;
1901 if (vecType.getNumElements() > 0) {
1902 constrain(vecType.getElementType(), larger, smaller);
1904 fieldID = save + vecType.getMaxFieldID();
1905 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1907 for (
auto &element : enumType.getElements())
1908 constrain(element.type, larger, smaller);
1909 }
else if (type.isGround()) {
1911 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1912 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
1915 llvm_unreachable(
"Unknown type inside a bundle!");
1920 constrain(type, larger, smaller);
1925 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1926 bool imposeUpperBounds,
bool equal) {
1927 assert(larger &&
"Larger expression should be specified");
1928 assert(smaller &&
"Smaller expression should be specified");
1934 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1935 largerDerived->assigned = smaller;
1936 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1937 << *smaller <<
"\n");
1940 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1941 smallerDerived->assigned = larger;
1942 LLVM_DEBUG(llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1943 << *larger <<
"\n");
1949 if (
auto largerVar = dyn_cast<VarExpr>(larger)) {
1950 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1951 LLVM_DEBUG(llvm::dbgs()
1952 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
1959 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1960 LLVM_DEBUG(llvm::dbgs()
1961 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
1971 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1972 if (imposeUpperBounds || equal) {
1973 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1974 LLVM_DEBUG(llvm::dbgs()
1975 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
1995 LLVM_DEBUG(llvm::dbgs()
1996 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1999 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
2000 solver.addGeqConstraint(var, solver.known(0));
2001 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
2003 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2005 for (
auto &element : bundleType) {
2006 unify(element.type);
2008 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2010 auto save = fieldID;
2012 if (vecType.getNumElements() > 0) {
2013 unify(vecType.getElementType());
2015 fieldID = save + vecType.getMaxFieldID();
2016 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2018 for (
auto &element : enumType.getElements())
2019 unify(element.type);
2021 llvm_unreachable(
"Unknown type inside a bundle!");
2029 Expr *InferenceMapping::getExpr(Value value)
const {
2032 return getExpr(
FieldRef(value, 0));
2036 Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
2037 auto expr = getExprOrNull(fieldRef);
2038 assert(expr &&
"constraint expr should have been constructed for value");
2042 Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
2043 auto it = opExprs.find(fieldRef);
2044 return it != opExprs.end() ? it->second :
nullptr;
2048 void InferenceMapping::setExpr(Value value, Expr *expr) {
2055 void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
2057 llvm::dbgs() <<
"Expr " << *expr <<
" for " << fieldRef.
getValue();
2059 llvm::dbgs() <<
" '" <<
getFieldName(fieldRef).first <<
"'";
2061 if (fieldName.second)
2062 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2063 llvm::dbgs() <<
"\n";
2065 opExprs[fieldRef] = expr;
2075 class InferenceTypeUpdate {
2077 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2079 LogicalResult update(CircuitOp op);
2085 const InferenceMapping &mapping;
2091 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2093 llvm::dbgs() <<
"\n";
2096 return mlir::failableParallelForEach(
2097 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2100 if (mapping.isModuleSkipped(op))
2102 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2103 if (failed(updateOperation(op)))
2104 return WalkResult::interrupt();
2105 return WalkResult::advance();
2106 }).wasInterrupted();
2107 return failure(isFailed);
2113 bool anyChanged =
false;
2115 for (Value v : op->getResults()) {
2116 auto result = updateValue(v);
2119 anyChanged |= *result;
2125 if (
auto con = dyn_cast<ConnectOp>(op)) {
2126 auto lhs = con.getDest();
2127 auto rhs = con.getSrc();
2128 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2129 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2132 if (!lhsType || !rhsType)
2135 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2136 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2137 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2139 auto trunc =
builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2140 rhsWidth - lhsWidth);
2141 if (type_isa<SIntType>(rhsType))
2143 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2145 LLVM_DEBUG(llvm::dbgs()
2146 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
2147 con->replaceUsesOfWith(con.getSrc(), trunc);
2153 if (
auto module = dyn_cast<FModuleOp>(op)) {
2155 bool argsChanged =
false;
2156 SmallVector<Attribute> argTypes;
2157 argTypes.reserve(module.getNumPorts());
2158 for (
auto arg : module.getArguments()) {
2159 auto result = updateValue(arg);
2162 argsChanged |= *result;
2168 module->setAttr(FModuleLike::getPortTypesAttrName(),
2178 auto *context = type.getContext();
2180 .
Case<UIntType>([&](
auto type) {
2183 .Case<SIntType>([&](
auto type) {
2186 .Case<AnalogType>([&](
auto type) {
2189 .Default([&](
auto type) {
return type; });
2195 auto type = type_dyn_cast<FIRRTLType>(value.getType());
2206 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2207 SmallVector<Type, 2> types;
2209 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2210 op->getAttrDictionary(), op->getPropertiesStorage(),
2211 op->getRegions(), types);
2215 assert(types.size() == op->getNumResults());
2216 for (
auto it : llvm::zip(op->getResults(), types)) {
2217 LLVM_DEBUG(llvm::dbgs() <<
"Inferring " << std::get<0>(it) <<
" as "
2218 << std::get<1>(it) <<
"\n");
2219 std::get<0>(it).setType(std::get<1>(it));
2225 auto context = type.getContext();
2226 unsigned fieldID = 0;
2229 auto width = type.getBitWidthOrSentinel();
2234 }
else if (
width == -1) {
2239 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2242 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2243 for (
auto &element : bundleType) {
2244 auto updatedBase = updateBase(element.type);
2247 elements.emplace_back(element.name, element.isFlip, updatedBase);
2250 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2252 auto save = fieldID;
2255 if (vecType.getNumElements() > 0) {
2256 auto updatedBase = updateBase(vecType.getElementType());
2261 fieldID = save + vecType.getMaxFieldID();
2266 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2268 llvm::SmallVector<FEnumType::EnumElement> elements;
2269 for (
auto &element : enumType.getElements()) {
2270 auto updatedBase = updateBase(element.type);
2273 elements.emplace_back(element.name, updatedBase);
2277 llvm_unreachable(
"Unknown type inside a bundle!");
2284 LLVM_DEBUG(llvm::dbgs() <<
"Update " << value <<
" to " << newType <<
"\n");
2285 value.setType(newType);
2290 if (
auto op = value.getDefiningOp<ConstantOp>()) {
2291 auto k = op.getValue();
2292 auto bitwidth = op.getType().getBitWidthOrSentinel();
2293 if (k.getBitWidth() >
unsigned(bitwidth))
2294 k = k.trunc(bitwidth);
2298 return newType != type;
2304 assert(type.isGround() &&
"Can only pass in ground types.");
2307 Expr *expr = mapping.getExprOrNull(fieldRef);
2308 if (!expr || !expr->solution) {
2313 mlir::emitError(value.getLoc(),
"width should have been inferred");
2316 int32_t solution = *expr->solution;
2328 template <
typename T>
2329 struct DenseMapInfo<InternedSlot<T>> {
2332 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2333 return Slot(
static_cast<T *
>(pointer));
2336 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2337 return Slot(
static_cast<T *
>(pointer));
2341 auto empty = getEmptyKey().ptr;
2342 auto tombstone = getTombstoneKey().ptr;
2343 if (LHS.ptr ==
empty || RHS.ptr ==
empty || LHS.ptr == tombstone ||
2344 RHS.ptr == tombstone)
2345 return LHS.ptr == RHS.ptr;
2346 return *LHS.ptr == *RHS.ptr;
2356 class InferWidthsPass :
public InferWidthsBase<InferWidthsPass> {
2357 void runOnOperation()
override;
2361 void InferWidthsPass::runOnOperation() {
2363 ConstraintSolver solver;
2364 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2365 getAnalysis<hw::InnerSymbolTableCollection>());
2366 if (failed(mapping.map(getOperation()))) {
2367 signalPassFailure();
2370 if (mapping.areAllModulesSkipped()) {
2371 markAllAnalysesPreserved();
2376 if (failed(solver.solve())) {
2377 signalPassFailure();
2382 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2383 signalPassFailure();
2387 return std::make_unique<InferWidthsPass>();
assert(baseType &&"element must be base type")
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
static FieldRef getRefForIST(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
llvm::hash_code hash_value(const T &e)
static Slot getEmptyKey()
static unsigned getHashValue(Slot val)
static Slot getTombstoneKey()
static bool isEqual(Slot LHS, Slot RHS)