20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/Threading.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/GraphTraits.h"
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
32 #define DEBUG_TYPE "infer-widths"
34 using mlir::InferTypeOpInterface;
35 using mlir::WalkOrder;
37 using namespace circt;
38 using namespace firrtl;
46 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
49 if (!basetype.hasUninferredWidth())
52 if (basetype.isGround())
53 diag.attachNote() <<
"Field: \"" << str <<
"\"";
54 else if (
auto vecType = type_dyn_cast<FVectorType>(basetype))
56 else if (
auto bundleType = type_dyn_cast<BundleType>(basetype))
57 for (
auto &elem : bundleType.getElements())
66 return TypeSwitch<Operation *, FieldRef>(ist.getOp())
67 .Case<FModuleOp>([&](
auto fmod) {
68 return FieldRef(fmod.getArgument(ist.getPort()), ist.getField());
73 auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(ist.getOp());
74 assert(symOp && symOp.getTargetResultIndex() &&
75 (symOp.supportsPerFieldSymbols() || ist.getField() == 0));
76 return FieldRef(symOp.getTargetResult(), ist.getField());
81 uint64_t convertedFieldID = 0;
83 auto curFID = fieldID;
88 if (isa<FVectorType>(curFType))
91 convertedFieldID += curFID - subID;
96 return convertedFieldID;
110 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const T &e) {
120 return e.hash_value();
// X-macro over every expression kind. `EXPR_NAMES(suffix)` pastes `suffix`
// onto each kind name: with an empty suffix it produces the bare names
// (matching the `Expr::Kind` enumerators such as `Root`), and with `Expr` it
// produces the concrete expression struct names (such as `RootExpr`).
#define EXPR_NAMES(x)                                                          \
  Root##x, Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
// Bare kind names.
#define EXPR_KINDS EXPR_NAMES()
// Concrete expression class names.
#define EXPR_CLASSES EXPR_NAMES(Expr)
133 std::optional<int32_t> solution;
137 void print(llvm::raw_ostream &os)
const;
140 Expr(Kind kind) : kind(kind) {}
145 template <
class DerivedT, Expr::Kind DerivedKind>
146 struct ExprBase :
public Expr {
147 ExprBase() : Expr(DerivedKind) {}
148 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
150 if (
auto otherSame = dyn_cast<DerivedT>(other))
151 return *
static_cast<DerivedT *
>(
this) == otherSame;
157 struct RootExpr :
public ExprBase<RootExpr, Expr::Kind::Root> {
158 RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
159 void print(llvm::raw_ostream &os)
const { os <<
"root"; }
160 std::vector<Expr *> &exprs;
164 struct VarExpr :
public ExprBase<VarExpr, Expr::Kind::Var> {
165 void print(llvm::raw_ostream &os)
const {
168 os <<
"var" << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFFF);
173 Expr *constraint =
nullptr;
176 Expr *upperBound =
nullptr;
177 std::optional<int32_t> upperBoundSolution;
184 struct DerivedExpr :
public ExprBase<DerivedExpr, Expr::Kind::Derived> {
185 void print(llvm::raw_ostream &os)
const {
188 << ((size_t)
this / llvm::PowerOf2Ceil(
sizeof(*
this)) & 0xFFF);
192 Expr *assigned =
nullptr;
209 struct IdExpr :
public ExprBase<IdExpr, Expr::Kind::Id> {
210 IdExpr(Expr *arg) : arg(arg) {
assert(arg); }
211 void print(llvm::raw_ostream &os)
const { os <<
"*" << *arg; }
213 return kind == other.kind && arg == other.arg;
224 struct KnownExpr :
public ExprBase<KnownExpr, Expr::Kind::Known> {
225 KnownExpr(int32_t
value) : ExprBase() { solution =
value; }
226 void print(llvm::raw_ostream &os)
const { os << *solution; }
227 bool operator==(
const KnownExpr &other)
const {
228 return *solution == *other.solution;
237 struct UnaryExpr :
public Expr {
238 bool operator==(
const UnaryExpr &other)
const {
239 return kind == other.kind && arg == other.arg;
249 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) {
assert(arg); }
253 template <
class DerivedT, Expr::Kind DerivedKind>
254 struct UnaryExprBase :
public UnaryExpr {
255 template <
typename... Args>
256 UnaryExprBase(Args &&...args)
257 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
258 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
262 struct PowExpr :
public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
263 using UnaryExprBase::UnaryExprBase;
264 void print(llvm::raw_ostream &os)
const { os <<
"2^" << arg; }
269 struct BinaryExpr :
public Expr {
270 bool operator==(
const BinaryExpr &other)
const {
271 return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
276 Expr *lhs()
const {
return args[0]; }
277 Expr *rhs()
const {
return args[1]; }
283 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
290 template <
class DerivedT, Expr::Kind DerivedKind>
291 struct BinaryExprBase :
public BinaryExpr {
292 template <
typename... Args>
293 BinaryExprBase(Args &&...args)
294 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
295 static bool classof(
const Expr *e) {
return e->kind == DerivedKind; }
299 struct AddExpr :
public BinaryExprBase<AddExpr, Expr::Kind::Add> {
300 using BinaryExprBase::BinaryExprBase;
301 void print(llvm::raw_ostream &os)
const {
302 os <<
"(" << *lhs() <<
" + " << *rhs() <<
")";
307 struct MaxExpr :
public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
308 using BinaryExprBase::BinaryExprBase;
309 void print(llvm::raw_ostream &os)
const {
310 os <<
"max(" << *lhs() <<
", " << *rhs() <<
")";
315 struct MinExpr :
public BinaryExprBase<MinExpr, Expr::Kind::Min> {
316 using BinaryExprBase::BinaryExprBase;
317 void print(llvm::raw_ostream &os)
const {
318 os <<
"min(" << *lhs() <<
", " << *rhs() <<
")";
322 void Expr::print(llvm::raw_ostream &os)
const {
324 [&](
auto *e) { e->print(os); });
336 template <
typename T>
337 struct InternedSlot {
339 InternedSlot(T *ptr) : ptr(ptr) {}
344 template <
typename T,
typename std::enable_if_t<
346 class InternedAllocator {
347 using Slot = InternedSlot<T>;
348 llvm::DenseSet<Slot> interned;
349 llvm::BumpPtrAllocator &allocator;
352 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
357 template <
typename R = T,
typename... Args>
358 std::pair<R *, bool> alloc(Args &&...args) {
359 auto stack_value = R(std::forward<Args>(args)...);
360 auto stack_slot = Slot(&stack_value);
361 auto it = interned.find(stack_slot);
362 if (it != interned.end())
363 return std::make_pair(
static_cast<R *
>(it->ptr),
false);
364 auto heap_value =
new (allocator) R(std::move(stack_value));
365 interned.insert(Slot(heap_value));
366 return std::make_pair(heap_value,
true);
372 template <
typename T,
typename std::enable_if_t<
375 llvm::BumpPtrAllocator &allocator;
378 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
382 template <
typename R = T,
typename... Args>
383 R *alloc(Args &&...args) {
384 return new (allocator) R(std::forward<Args>(args)...);
415 int32_t rec_scale = 0;
416 int32_t rec_bias = 0;
417 int32_t nonrec_bias = 0;
421 static LinIneq unsat() {
return LinIneq(
true); }
424 explicit LinIneq(
bool failed =
false) : failed(failed) {}
427 explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}
430 explicit LinIneq(int32_t scale, int32_t bias) {
441 explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
444 if (rec_scale != 0) {
445 this->rec_scale = rec_scale;
446 this->rec_bias = rec_bias;
447 this->nonrec_bias = nonrec_bias;
449 this->nonrec_bias = std::max(rec_bias, nonrec_bias);
460 static LinIneq max(
const LinIneq &lhs,
const LinIneq &rhs) {
461 return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
462 std::max(lhs.rec_bias, rhs.rec_bias),
463 std::max(lhs.nonrec_bias, rhs.nonrec_bias),
464 lhs.failed || rhs.failed);
520 static LinIneq add(
const LinIneq &lhs,
const LinIneq &rhs) {
523 auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
524 auto enable2 = lhs.rec_scale > 0;
525 auto enable3 = rhs.rec_scale > 0;
526 auto scale1 = lhs.rec_scale + rhs.rec_scale;
527 auto scale2 = lhs.rec_scale;
528 auto scale3 = rhs.rec_scale;
529 auto bias1 = lhs.rec_bias + rhs.rec_bias;
530 auto bias2 = lhs.rec_bias + rhs.nonrec_bias;
531 auto bias3 = rhs.rec_bias + lhs.nonrec_bias;
532 auto maxScale = std::max(scale1, std::max(scale2, scale3));
536 std::optional<int32_t> maxBias;
537 if (enable1 && scale1 == maxScale)
539 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
541 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
546 auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias;
547 auto failed = lhs.failed || rhs.failed;
548 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
549 return LinIneq(scale1, bias1, nonrec_bias, failed);
550 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
551 return LinIneq(scale2, bias2, nonrec_bias, failed);
552 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
553 return LinIneq(scale3, bias3, nonrec_bias, failed);
554 return LinIneq(0, 0, nonrec_bias, failed);
567 if (rec_scale == 1 && rec_bias > 0)
573 void print(llvm::raw_ostream &os)
const {
575 bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
579 if (rec_scale != 0) {
582 os << rec_scale <<
"*";
588 os <<
" - " << -rec_bias;
590 os <<
" + " << rec_bias;
598 if (nonrec_bias != 0) {
615 class ConstraintSolver {
617 ConstraintSolver() =
default;
620 auto v = vars.alloc();
623 info[v].insert(currentInfo);
625 locs[v].insert(*currentLoc);
628 DerivedExpr *derived() {
629 auto *d = derivs.alloc();
633 KnownExpr *known(int32_t
value) {
return alloc<KnownExpr>(knowns,
value); }
634 IdExpr *id(Expr *arg) {
return alloc<IdExpr>(ids, arg); }
635 PowExpr *pow(Expr *arg) {
return alloc<PowExpr>(uns, arg); }
636 AddExpr *add(Expr *lhs, Expr *rhs) {
return alloc<AddExpr>(bins, lhs, rhs); }
637 MaxExpr *max(Expr *lhs, Expr *rhs) {
return alloc<MaxExpr>(bins, lhs, rhs); }
638 MinExpr *min(Expr *lhs, Expr *rhs) {
return alloc<MinExpr>(bins, lhs, rhs); }
642 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
644 lhs->constraint = max(lhs->constraint, rhs);
646 lhs->constraint = id(rhs);
647 return lhs->constraint;
652 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
654 lhs->upperBound = min(lhs->upperBound, rhs);
656 lhs->upperBound = id(rhs);
657 return lhs->upperBound;
660 void dumpConstraints(llvm::raw_ostream &os);
661 LogicalResult solve();
663 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
664 const ContextInfo &getContextInfo()
const {
return info; }
665 void setCurrentContextInfo(
FieldRef fieldRef) { currentInfo = fieldRef; }
666 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
670 llvm::BumpPtrAllocator allocator;
671 Allocator<VarExpr> vars = {allocator};
672 Allocator<DerivedExpr> derivs = {allocator};
673 InternedAllocator<KnownExpr> knowns = {allocator};
674 InternedAllocator<IdExpr> ids = {allocator};
675 InternedAllocator<UnaryExpr> uns = {allocator};
676 InternedAllocator<BinaryExpr> bins = {allocator};
679 std::vector<Expr *> exprs;
680 RootExpr root = {exprs};
683 template <
typename R,
typename T,
typename... Args>
684 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
685 auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
687 exprs.push_back(it.first);
689 info[it.first].insert(currentInfo);
691 locs[it.first].insert(*currentLoc);
699 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
700 std::optional<Location> currentLoc;
704 ConstraintSolver(ConstraintSolver &&) =
delete;
705 ConstraintSolver(
const ConstraintSolver &) =
delete;
706 ConstraintSolver &operator=(ConstraintSolver &&) =
delete;
707 ConstraintSolver &operator=(
const ConstraintSolver &) =
delete;
709 void emitUninferredWidthError(VarExpr *var);
711 LinIneq checkCycles(VarExpr *var, Expr *expr,
712 SmallPtrSetImpl<Expr *> &seenVars,
713 InFlightDiagnostic *reportInto =
nullptr,
714 unsigned indent = 1);
720 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
721 for (
auto *e : exprs) {
722 if (
auto *v = dyn_cast<VarExpr>(e)) {
724 os <<
"- " << *v <<
" >= " << *v->constraint <<
"\n";
726 os <<
"- " << *v <<
" unconstrained\n";
732 inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const LinIneq &l) {
748 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
749 SmallPtrSetImpl<Expr *> &seenVars,
750 InFlightDiagnostic *reportInto,
753 TypeSwitch<Expr *, LinIneq>(expr)
754 .Case<KnownExpr>([&](
auto *expr) {
return LinIneq(*expr->solution); })
755 .Case<VarExpr>([&](
auto *expr) {
757 return LinIneq(1, 0);
758 if (!seenVars.insert(expr).second)
765 if (!expr->constraint)
768 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
770 seenVars.erase(expr);
773 .Case<IdExpr>([&](
auto *expr) {
774 return checkCycles(var, expr->arg, seenVars, reportInto,
777 .Case<PowExpr>([&](
auto *expr) {
782 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
783 if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
784 arg.nonrec_bias >= 31)
785 return LinIneq::unsat();
786 return LinIneq(1 << arg.nonrec_bias);
788 .Case<AddExpr>([&](
auto *expr) {
790 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
791 checkCycles(var, expr->rhs(), seenVars, reportInto,
794 .Case<MaxExpr, MinExpr>([&](
auto *expr) {
799 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
800 checkCycles(var, expr->rhs(), seenVars, reportInto,
803 .Default([](
auto) {
return LinIneq::unsat(); });
808 if (reportInto && !ineq.sat()) {
809 auto report = [&](Location loc) {
810 auto ¬e = reportInto->attachNote(loc);
811 note <<
"constrained width W >= ";
812 if (ineq.rec_scale == -1)
814 if (ineq.rec_scale != 1)
815 note << ineq.rec_scale;
817 if (ineq.rec_bias < 0)
818 note <<
"-" << -ineq.rec_bias;
819 if (ineq.rec_bias > 0)
820 note <<
"+" << ineq.rec_bias;
823 auto it = locs.find(expr);
824 if (it != locs.end())
825 for (
auto loc : it->second)
830 <<
"- Visited " << *expr <<
": " << ineq <<
"\n");
840 arg.first = operation(*arg.first);
846 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
847 auto result =
ExprSolution{std::nullopt, lhs.second || rhs.second};
848 if (lhs.first && rhs.first)
849 result.first = operation(*lhs.first, *rhs.first);
851 result.first = lhs.first;
853 result.first = rhs.first;
863 unsigned defaultWorklistSize) {
872 std::vector<Frame> worklist({{expr, indent}});
873 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
876 worklist.reserve(defaultWorklistSize);
878 while (!worklist.empty()) {
879 auto &frame = worklist.back();
882 if (solution.first && !solution.second)
883 frame.expr->solution = *solution.first;
884 solvedExprs[frame.expr] = solution;
888 if (!isa<KnownExpr>(frame.expr)) {
890 llvm::dbgs().indent(frame.indent * 2)
891 <<
"= Solved " << *frame.expr <<
" = " << *solution.first;
893 llvm::dbgs().indent(frame.indent * 2)
894 <<
"= Skipped " << *frame.expr;
895 llvm::dbgs() <<
" (" << (solution.second ?
"cycle broken" :
"unique")
904 if (frame.expr->solution) {
906 if (!isa<KnownExpr>(frame.expr))
907 llvm::dbgs().indent(indent * 2) <<
"- Cached " << *frame.expr <<
" = "
908 << *frame.expr->solution <<
"\n";
910 setSolution(
ExprSolution{*frame.expr->solution,
false});
916 if (!isa<KnownExpr>(frame.expr))
918 <<
"- Solving " << *frame.expr <<
"\n";
921 TypeSwitch<Expr *>(frame.expr)
922 .Case<KnownExpr>([&](
auto *expr) {
925 .Case<VarExpr>([&](
auto *expr) {
926 if (solvedExprs.contains(expr->constraint)) {
927 auto solution = solvedExprs[expr->constraint];
934 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
935 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
936 seenVars.erase(expr);
938 if (solution.first && *solution.first < 0)
940 return setSolution(solution);
944 if (!expr->constraint)
949 if (!seenVars.insert(expr).second)
952 worklist.push_back({expr->constraint, indent + 1});
953 if (expr->upperBound)
954 worklist.push_back({expr->upperBound, indent + 1});
956 .Case<IdExpr>([&](
auto *expr) {
957 if (solvedExprs.contains(expr->arg))
958 return setSolution(solvedExprs[expr->arg]);
959 worklist.push_back({expr->arg, indent + 1});
961 .Case<PowExpr>([&](
auto *expr) {
962 if (solvedExprs.contains(expr->arg))
964 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
966 worklist.push_back({expr->arg, indent + 1});
968 .Case<AddExpr>([&](
auto *expr) {
969 if (solvedExprs.contains(expr->lhs()) &&
970 solvedExprs.contains(expr->rhs()))
972 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
973 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
975 worklist.push_back({expr->lhs(), indent + 1});
976 worklist.push_back({expr->rhs(), indent + 1});
978 .Case<MaxExpr>([&](
auto *expr) {
979 if (solvedExprs.contains(expr->lhs()) &&
980 solvedExprs.contains(expr->rhs()))
982 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
983 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
985 worklist.push_back({expr->lhs(), indent + 1});
986 worklist.push_back({expr->rhs(), indent + 1});
988 .Case<MinExpr>([&](
auto *expr) {
989 if (solvedExprs.contains(expr->lhs()) &&
990 solvedExprs.contains(expr->rhs()))
992 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
993 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
995 worklist.push_back({expr->lhs(), indent + 1});
996 worklist.push_back({expr->rhs(), indent + 1});
1003 return solvedExprs[expr];
1009 LogicalResult ConstraintSolver::solve() {
1011 llvm::dbgs() <<
"\n===----- Constraints -----===\n\n";
1017 llvm::dbgs() <<
"\n===----- Checking for unbreakable loops -----===\n\n");
1018 SmallPtrSet<Expr *, 16> seenVars;
1019 bool anyFailed =
false;
1021 for (
auto *expr : exprs) {
1023 auto *var = dyn_cast<VarExpr>(expr);
1024 if (!var || !var->constraint)
1027 <<
"- Checking " << *var <<
" >= " << *var->constraint <<
"\n");
1032 seenVars.insert(var);
1033 auto ineq = checkCycles(var, var->constraint, seenVars);
1043 <<
" = Breakable since " << ineq <<
" satisfiable\n");
1052 <<
" = UNBREAKABLE since " << ineq <<
" unsatisfiable\n");
1054 for (
auto fieldRef : info.find(var)->second) {
1057 auto op = fieldRef.getDefiningOp();
1058 auto diag = op ? op->emitOpError()
1059 : mlir::emitError(fieldRef.getValue().getLoc())
1061 diag <<
"is constrained to be wider than itself";
1064 seenVars.insert(var);
1065 checkCycles(var, var->constraint, seenVars, &diag);
1076 LLVM_DEBUG(
llvm::dbgs() <<
"\n===----- Solving constraints -----===\n\n");
1077 unsigned defaultWorklistSize = exprs.size() / 2;
1078 for (
auto *expr : exprs) {
1080 auto *var = dyn_cast<VarExpr>(expr);
1085 if (!var->constraint) {
1086 LLVM_DEBUG(
llvm::dbgs() <<
"- Unconstrained " << *var <<
"\n");
1087 emitUninferredWidthError(var);
1094 <<
"- Solving " << *var <<
" >= " << *var->constraint <<
"\n");
1095 seenVars.insert(var);
1096 auto solution =
solveExpr(var->constraint, seenVars, defaultWorklistSize);
1098 if (var->upperBound && !var->upperBoundSolution)
1099 var->upperBoundSolution =
1100 solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1104 if (solution.first && *solution.first < 0)
1106 var->solution = solution.first;
1110 if (!solution.first) {
1111 LLVM_DEBUG(
llvm::dbgs() <<
" - UNSOLVED " << *var <<
"\n");
1112 emitUninferredWidthError(var);
1117 <<
" = Solved " << *var <<
" = " << solution.first <<
" ("
1118 << (solution.second ?
"cycle broken" :
"unique") <<
")\n");
1121 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1122 LLVM_DEBUG(
llvm::dbgs() <<
" ! Unsatisfiable " << *var
1123 <<
" <= " << var->upperBoundSolution <<
"\n");
1124 emitUninferredWidthError(var);
1130 for (
auto *expr : exprs) {
1132 auto *derived = dyn_cast<DerivedExpr>(expr);
1136 auto *assigned = derived->assigned;
1137 if (!assigned || !assigned->solution) {
1138 LLVM_DEBUG(
llvm::dbgs() <<
"- Unused " << *derived <<
" set to 0\n");
1139 derived->solution = 0;
1141 LLVM_DEBUG(
llvm::dbgs() <<
"- Deriving " << *derived <<
" = "
1142 << assigned->solution <<
"\n");
1143 derived->solution = *assigned->solution;
1147 return failure(anyFailed);
1152 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1153 FieldRef fieldRef = info.find(var)->second.back();
1156 auto diag = mlir::emitError(
value.getLoc(),
"uninferred width:");
1159 if (isa<BlockArgument>(
value)) {
1161 }
else if (
auto op =
value.getDefiningOp()) {
1162 TypeSwitch<Operation *>(op)
1163 .Case<WireOp>([&](
auto) { diag <<
" wire"; })
1164 .Case<RegOp, RegResetOp>([&](
auto) { diag <<
" reg"; })
1165 .Case<NodeOp>([&](
auto) { diag <<
" node"; })
1166 .Default([&](
auto) { diag <<
" value"; });
1173 if (!fieldName.empty()) {
1176 diag <<
" \"" << fieldName <<
"\"";
1179 if (!var->constraint) {
1180 diag <<
" is unconstrained";
1181 }
else if (var->solution && var->upperBoundSolution &&
1182 var->solution > var->upperBoundSolution) {
1183 diag <<
" cannot satisfy all width requirements";
1184 LLVM_DEBUG(
llvm::dbgs() << *var->constraint <<
"\n");
1185 LLVM_DEBUG(
llvm::dbgs() << *var->upperBound <<
"\n");
1186 auto loc = locs.find(var->constraint)->second.back();
1187 diag.attachNote(loc) <<
"width is constrained to be at least "
1188 << *var->solution <<
" here:";
1189 loc = locs.find(var->upperBound)->second.back();
1190 diag.attachNote(loc) <<
"width is constrained to be at most "
1191 << *var->upperBoundSolution <<
" here:";
1193 diag <<
" width cannot be determined";
1194 LLVM_DEBUG(
llvm::dbgs() << *var->constraint <<
"\n");
1195 auto loc = locs.find(var->constraint)->second.back();
1196 diag.attachNote(loc) <<
"width is constrained by an uninferred width here:";
1208 class InferenceMapping {
1210 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1211 hw::InnerSymbolTableCollection &istc)
1212 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1214 LogicalResult map(CircuitOp op);
1215 LogicalResult mapOperation(Operation *op);
1220 void declareVars(Value
value, Location loc,
bool isDerived =
false);
1223 Expr *declareVar(
FieldRef fieldRef, Location loc);
1227 Expr *declareVar(
FIRRTLType type, Location loc);
1232 void maximumOfTypes(Value result, Value rhs, Value lhs);
1236 void constrainTypes(Value larger, Value smaller,
bool equal =
false);
1240 void constrainTypes(Expr *larger, Expr *smaller,
1241 bool imposeUpperBounds =
false,
bool equal =
false);
1250 Expr *getExpr(Value
value)
const;
1253 Expr *getExpr(
FieldRef fieldRef)
const;
1257 Expr *getExprOrNull(
FieldRef fieldRef)
const;
1261 void setExpr(Value
value, Expr *expr);
1264 void setExpr(
FieldRef fieldRef, Expr *expr);
1267 bool isModuleSkipped(FModuleOp module)
const {
1268 return skippedModules.count(module);
1272 bool areAllModulesSkipped()
const {
return allModulesSkipped; }
1276 ConstraintSolver &solver;
1279 DenseMap<FieldRef, Expr *> opExprs;
1282 SmallPtrSet<Operation *, 16> skippedModules;
1283 bool allModulesSkipped =
true;
1286 SymbolTable &symtbl;
1289 hw::InnerRefNamespace irn;
1296 return TypeSwitch<Type, bool>(type)
1297 .Case<
FIRRTLBaseType>([](
auto base) {
return base.hasUninferredWidth(); })
1299 [](
auto ref) {
return ref.getType().hasUninferredWidth(); })
1300 .Default([](
auto) {
return false; });
1303 LogicalResult InferenceMapping::map(CircuitOp op) {
1305 <<
"\n===----- Mapping ops to constraint exprs -----===\n\n");
1308 for (
auto module : op.getOps<FModuleOp>())
1309 for (
auto arg : module.getArguments()) {
1310 solver.setCurrentContextInfo(
FieldRef(arg, 0));
1311 declareVars(arg, module.getLoc());
1314 for (
auto module : op.getOps<FModuleOp>()) {
1317 bool anyUninferred =
false;
1318 for (
auto arg : module.getArguments()) {
1323 module.walk([&](Operation *op) {
1324 for (
auto type : op->getResultTypes())
1327 return WalkResult::interrupt();
1328 return WalkResult::advance();
1331 if (!anyUninferred) {
1332 LLVM_DEBUG(
llvm::dbgs() <<
"Skipping fully-inferred module '"
1333 << module.getName() <<
"'\n");
1334 skippedModules.insert(module);
1338 allModulesSkipped =
false;
1342 auto result = module.getBodyBlock()->walk(
1343 [&](Operation *op) {
return WalkResult(mapOperation(op)); });
1344 if (result.wasInterrupted())
1351 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1357 bool allWidthsKnown =
true;
1358 for (
auto result : op->getResults()) {
1359 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1361 allWidthsKnown =
false;
1364 auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1368 declareVars(result, op->getLoc());
1370 allWidthsKnown =
false;
1372 if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1376 if (isa<PropAssignOp>(op))
1380 bool mappingFailed =
false;
1381 solver.setCurrentContextInfo(
1383 solver.setCurrentLocation(op->getLoc());
1384 TypeSwitch<Operation *>(op)
1385 .Case<ConstantOp>([&](
auto op) {
1389 if (
auto width = op.getType().get().getWidth())
1390 e = solver.known(*
width);
1392 auto v = op.getValue();
1393 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1394 : v.countLeadingZeros());
1397 e = solver.known(std::max(w, 1u));
1399 setExpr(op.getResult(), e);
1401 .Case<SpecialConstantOp>([&](
auto op) {
1404 .Case<InvalidValueOp>([&](
auto op) {
1409 declareVars(op.getResult(), op.getLoc(),
true);
1413 auto type = op.getType();
1414 ImplicitLocOpBuilder
builder(op->getLoc(), op);
1416 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1420 auto clone =
builder.create<InvalidValueOp>(type);
1421 declareVars(clone.getResult(), clone.getLoc(),
1426 .Case<WireOp, RegOp>(
1427 [&](
auto op) { declareVars(op.getResult(), op.getLoc()); })
1428 .Case<RegResetOp>([&](
auto op) {
1433 declareVars(op.getResult(), op.getLoc());
1436 constrainTypes(op.getResult(), op.getResetValue());
1438 .Case<NodeOp>([&](
auto op) {
1441 op.getResult().getType());
1445 .Case<SubfieldOp>([&](
auto op) {
1446 BundleType bundleType = op.getInput().getType();
1447 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1448 unifyTypes(
FieldRef(op.getResult(), 0),
1449 FieldRef(op.getInput(), fieldID), op.getType());
1451 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
1457 .Case<SubtagOp>([&](
auto op) {
1458 FEnumType enumType = op.getInput().getType();
1459 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1460 unifyTypes(
FieldRef(op.getResult(), 0),
1461 FieldRef(op.getInput(), fieldID), op.getType());
1464 .Case<RefSubOp>([&](RefSubOp op) {
1465 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1466 op.getInput().getType().getType())
1467 .Case<FVectorType>([](
auto _) {
return 1; })
1468 .Case<BundleType>([&](
auto type) {
1469 return type.getFieldID(op.getIndex());
1471 unifyTypes(
FieldRef(op.getResult(), 0),
1472 FieldRef(op.getInput(), fieldID), op.getType());
1476 .Case<AddPrimOp, SubPrimOp>([&](
auto op) {
1477 auto lhs = getExpr(op.getLhs());
1478 auto rhs = getExpr(op.getRhs());
1479 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1480 setExpr(op.getResult(), e);
1482 .Case<MulPrimOp>([&](
auto op) {
1483 auto lhs = getExpr(op.getLhs());
1484 auto rhs = getExpr(op.getRhs());
1485 auto e = solver.add(lhs, rhs);
1486 setExpr(op.getResult(), e);
1488 .Case<DivPrimOp>([&](
auto op) {
1489 auto lhs = getExpr(op.getLhs());
1491 if (op.getType().get().isSigned()) {
1492 e = solver.add(lhs, solver.known(1));
1496 setExpr(op.getResult(), e);
1498 .Case<RemPrimOp>([&](
auto op) {
1499 auto lhs = getExpr(op.getLhs());
1500 auto rhs = getExpr(op.getRhs());
1501 auto e = solver.min(lhs, rhs);
1502 setExpr(op.getResult(), e);
1504 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](
auto op) {
1505 auto lhs = getExpr(op.getLhs());
1506 auto rhs = getExpr(op.getRhs());
1507 auto e = solver.max(lhs, rhs);
1508 setExpr(op.getResult(), e);
1512 .Case<CatPrimOp>([&](
auto op) {
1513 auto lhs = getExpr(op.getLhs());
1514 auto rhs = getExpr(op.getRhs());
1515 auto e = solver.add(lhs, rhs);
1516 setExpr(op.getResult(), e);
1518 .Case<DShlPrimOp>([&](
auto op) {
1519 auto lhs = getExpr(op.getLhs());
1520 auto rhs = getExpr(op.getRhs());
1521 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1522 setExpr(op.getResult(), e);
1524 .Case<DShlwPrimOp, DShrPrimOp>([&](
auto op) {
1525 auto e = getExpr(op.getLhs());
1526 setExpr(op.getResult(), e);
1530 .Case<NegPrimOp>([&](
auto op) {
1531 auto input = getExpr(op.getInput());
1532 auto e = solver.add(input, solver.known(1));
1533 setExpr(op.getResult(), e);
1535 .Case<CvtPrimOp>([&](
auto op) {
1536 auto input = getExpr(op.getInput());
1537 auto e = op.getInput().getType().get().isSigned()
1539 : solver.add(input, solver.known(1));
1540 setExpr(op.getResult(), e);
1544 .Case<BitsPrimOp>([&](
auto op) {
1545 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1547 .Case<HeadPrimOp>([&](
auto op) {
1548 setExpr(op.getResult(), solver.known(op.getAmount()));
1550 .Case<TailPrimOp>([&](
auto op) {
1551 auto input = getExpr(op.getInput());
1552 auto e = solver.add(input, solver.known(-op.getAmount()));
1553 setExpr(op.getResult(), e);
1555 .Case<PadPrimOp>([&](
auto op) {
1556 auto input = getExpr(op.getInput());
1557 auto e = solver.max(input, solver.known(op.getAmount()));
1558 setExpr(op.getResult(), e);
1560 .Case<ShlPrimOp>([&](
auto op) {
1561 auto input = getExpr(op.getInput());
1562 auto e = solver.add(input, solver.known(op.getAmount()));
1563 setExpr(op.getResult(), e);
1565 .Case<ShrPrimOp>([&](
auto op) {
1566 auto input = getExpr(op.getInput());
1567 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1569 setExpr(op.getResult(), e);
1573 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1574 [&](
auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1575 .Case<mlir::UnrealizedConversionCastOp>(
1576 [&](
auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1580 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1581 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1582 XorRPrimOp>([&](
auto op) {
1583 auto width = op.getType().getBitWidthOrSentinel();
1584 assert(
width > 0 &&
"width should have been checked by verifier");
1585 setExpr(op.getResult(), solver.known(
width));
1587 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](
auto op) {
1588 auto *sel = getExpr(op.getSel());
1589 constrainTypes(sel, solver.known(1));
1590 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1592 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1593 auto *sel = getExpr(op.getSel());
1594 constrainTypes(sel, solver.known(2));
1595 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1596 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1597 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1600 .Case<ConnectOp, StrictConnectOp>(
1601 [&](
auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1602 .Case<RefDefineOp>([&](
auto op) {
1605 constrainTypes(op.getDest(), op.getSrc(),
true);
1608 .Case<StrictConnectOp>([&](
auto op) {
1613 constrainTypes(op.getDest(), op.getSrc());
1614 constrainTypes(op.getSrc(), op.getDest());
1616 .Case<AttachOp>([&](
auto op) {
1620 if (op.getAttached().empty())
1622 auto prev = op.getAttached()[0];
1623 for (
auto operand : op.getAttached().drop_front()) {
1624 auto e1 = getExpr(prev);
1625 auto e2 = getExpr(operand);
1626 constrainTypes(e1, e2,
true);
1627 constrainTypes(e2, e1,
true);
1633 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp, CoverOp>(
1637 .Case<InstanceOp>([&](
auto op) {
1638 auto refdModule = op.getReferencedModule(symtbl);
1639 auto module = dyn_cast<FModuleOp>(&*refdModule);
1641 auto diag = mlir::emitError(op.getLoc());
1642 diag <<
"extern module `" << op.getModuleName()
1643 <<
"` has ports of uninferred width";
1645 auto fml = cast<FModuleLike>(&*refdModule);
1646 auto ports = fml.getPorts();
1647 for (
auto &port : ports) {
1649 if (baseType && baseType.hasUninferredWidth()) {
1650 diag.attachNote(op.getLoc()) <<
"Port: " << port.name;
1651 if (!baseType.isGround())
1656 diag.attachNote(op.getLoc())
1657 <<
"Only non-extern FIRRTL modules may contain unspecified "
1658 "widths to be inferred automatically.";
1659 diag.attachNote(refdModule->getLoc())
1660 <<
"Module `" << op.getModuleName() <<
"` defined here:";
1661 mappingFailed =
true;
1668 for (
auto it : llvm::zip(op->getResults(), module.getArguments())) {
1670 type_cast<FIRRTLType>(std::get<0>(it).getType()));
1675 .Case<MemOp>([&](MemOp op) {
1677 unsigned nonDebugPort = 0;
1678 for (
const auto &result : llvm::enumerate(op.getResults())) {
1679 declareVars(result.value(), op.getLoc());
1680 if (!type_isa<RefType>(result.value().getType()))
1681 nonDebugPort = result.index();
1686 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1687 static const unsigned indices[] = {3, 5};
1688 static const unsigned debug[] = {0};
1690 case MemOp::PortKind::Read:
1691 case MemOp::PortKind::Write:
1692 return ArrayRef<unsigned>(indices, 1);
1693 case MemOp::PortKind::ReadWrite:
1694 return ArrayRef<unsigned>(indices);
1695 case MemOp::PortKind::Debug:
1696 return ArrayRef<unsigned>(debug);
1698 llvm_unreachable(
"Imposible PortKind");
1705 unsigned firstFieldIndex =
1706 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1708 op.getResult(nonDebugPort),
1709 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1710 .getFieldID(firstFieldIndex));
1711 LLVM_DEBUG(
llvm::dbgs() <<
"Adjusting memory port variables:\n");
1714 auto dataType = op.getDataType();
1715 for (
unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1716 auto result = op.getResult(i);
1717 if (type_isa<RefType>(result.getType())) {
1721 unifyTypes(firstData,
FieldRef(result, 1), dataType);
1726 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1727 for (
auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1728 unifyTypes(
FieldRef(result, portType.getFieldID(fieldIndex)),
1729 firstData, dataType);
1733 .Case<RefSendOp>([&](
auto op) {
1734 declareVars(op.getResult(), op.getLoc());
1735 constrainTypes(op.getResult(), op.getBase(),
true);
1737 .Case<RefResolveOp>([&](
auto op) {
1738 declareVars(op.getResult(), op.getLoc());
1739 constrainTypes(op.getResult(), op.getRef(),
true);
1741 .Case<RefCastOp>([&](
auto op) {
1742 declareVars(op.getResult(), op.getLoc());
1743 constrainTypes(op.getResult(), op.getInput(),
true);
1745 .Case<RWProbeOp>([&](
auto op) {
1746 declareVars(op.getResult(), op.getLoc());
1747 auto ist = irn.lookup(op.getTarget());
1749 op->emitError(
"target of rwprobe could not be resolved");
1750 mappingFailed =
true;
1755 op->emitError(
"target of rwprobe resolved to unsupported target");
1756 mappingFailed =
true;
1760 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1761 unifyTypes(
FieldRef(op.getResult(), 0),
1762 FieldRef(ref.getValue(), newFID), op.getType());
1764 .Case<mlir::UnrealizedConversionCastOp>([&](
auto op) {
1765 for (Value result : op.getResults()) {
1766 auto ty = result.getType();
1767 if (type_isa<FIRRTLType>(ty))
1768 declareVars(result, op.getLoc());
1771 .Default([&](
auto op) {
1772 op->emitOpError(
"not supported in width inference");
1773 mappingFailed =
true;
1777 if (
auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1781 return failure(mappingFailed);
1786 void InferenceMapping::declareVars(Value
value, Location loc,
bool isDerived) {
1789 unsigned fieldID = 0;
1791 auto width = type.getBitWidthOrSentinel();
1796 }
else if (
width == -1) {
1799 solver.setCurrentContextInfo(field);
1801 setExpr(field, solver.derived());
1803 setExpr(field, solver.var());
1805 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1808 for (
auto &element : bundleType) {
1809 declare(element.type);
1811 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1813 auto save = fieldID;
1814 declare(vecType.getElementType());
1816 fieldID = save + vecType.getMaxFieldID();
1817 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1819 for (
auto &element : enumType.getElements())
1820 declare(element.type);
1822 llvm_unreachable(
"Unknown type inside a bundle!");
1832 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1836 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1838 for (
auto &element : bundleType.getElements())
1839 maximize(element.type);
1840 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1842 auto save = fieldID;
1844 if (vecType.getNumElements() > 0)
1845 maximize(vecType.getElementType());
1846 fieldID = save + vecType.getMaxFieldID();
1847 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1849 for (
auto &element : enumType.getElements())
1850 maximize(element.type);
1851 }
else if (type.isGround()) {
1852 auto *e = solver.max(getExpr(
FieldRef(rhs, fieldID)),
1854 setExpr(
FieldRef(result, fieldID), e);
1857 llvm_unreachable(
"Unknown type inside a bundle!");
1872 void InferenceMapping::constrainTypes(Value larger, Value smaller,
bool equal) {
1879 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1881 for (
auto &element : bundleType.getElements()) {
1883 constrain(element.type, smaller, larger);
1885 constrain(element.type, larger, smaller);
1887 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
1889 auto save = fieldID;
1891 if (vecType.getNumElements() > 0) {
1892 constrain(vecType.getElementType(), larger, smaller);
1894 fieldID = save + vecType.getMaxFieldID();
1895 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
1897 for (
auto &element : enumType.getElements())
1898 constrain(element.type, larger, smaller);
1899 }
else if (type.isGround()) {
1901 constrainTypes(getExpr(
FieldRef(larger, fieldID)),
1902 getExpr(
FieldRef(smaller, fieldID)),
false, equal);
1905 llvm_unreachable(
"Unknown type inside a bundle!");
1910 constrain(type, larger, smaller);
1915 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1916 bool imposeUpperBounds,
bool equal) {
1917 assert(larger &&
"Larger expression should be specified");
1918 assert(smaller &&
"Smaller expression should be specified");
1924 if (
auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1925 largerDerived->assigned = smaller;
1926 LLVM_DEBUG(
llvm::dbgs() <<
"Deriving " << *largerDerived <<
" from "
1927 << *smaller <<
"\n");
1930 if (
auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1931 smallerDerived->assigned = larger;
1932 LLVM_DEBUG(
llvm::dbgs() <<
"Deriving " << *smallerDerived <<
" from "
1933 << *larger <<
"\n");
1939 if (
auto largerVar = dyn_cast<VarExpr>(larger)) {
1940 [[maybe_unused]]
auto *c = solver.addGeqConstraint(largerVar, smaller);
1942 <<
"Constrained " << *largerVar <<
" >= " << *c <<
"\n");
1949 [[maybe_unused]]
auto *leq = solver.addLeqConstraint(largerVar, smaller);
1951 <<
"Constrained " << *largerVar <<
" <= " << *leq <<
"\n");
1961 if (
auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1962 if (imposeUpperBounds || equal) {
1963 [[maybe_unused]]
auto *c = solver.addLeqConstraint(smallerVar, larger);
1965 <<
"Constrained " << *smallerVar <<
" <= " << *c <<
"\n");
1986 <<
"Unify " <<
getFieldName(lhsFieldRef).first <<
" = "
1989 if (
auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1990 solver.addGeqConstraint(var, solver.known(0));
1991 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1993 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
1995 for (
auto &element : bundleType) {
1996 unify(element.type);
1998 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2000 auto save = fieldID;
2002 if (vecType.getNumElements() > 0) {
2003 unify(vecType.getElementType());
2005 fieldID = save + vecType.getMaxFieldID();
2006 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2008 for (
auto &element : enumType.getElements())
2009 unify(element.type);
2011 llvm_unreachable(
"Unknown type inside a bundle!");
2019 Expr *InferenceMapping::getExpr(Value
value)
const {
2026 Expr *InferenceMapping::getExpr(
FieldRef fieldRef)
const {
2027 auto expr = getExprOrNull(fieldRef);
2028 assert(expr &&
"constraint expr should have been constructed for value");
2032 Expr *InferenceMapping::getExprOrNull(
FieldRef fieldRef)
const {
2033 auto it = opExprs.find(fieldRef);
2034 return it != opExprs.end() ? it->second :
nullptr;
2038 void InferenceMapping::setExpr(Value
value, Expr *expr) {
2045 void InferenceMapping::setExpr(
FieldRef fieldRef, Expr *expr) {
2051 if (fieldName.second)
2052 llvm::dbgs() <<
" (\"" << fieldName.first <<
"\")";
2055 opExprs[fieldRef] = expr;
2065 class InferenceTypeUpdate {
2067 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2069 LogicalResult update(CircuitOp op);
2075 const InferenceMapping &mapping;
2081 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2082 LLVM_DEBUG(
llvm::dbgs() <<
"\n===----- Update types -----===\n\n");
2083 return mlir::failableParallelForEach(
2084 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2087 if (mapping.isModuleSkipped(op))
2089 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2090 if (failed(updateOperation(op)))
2091 return WalkResult::interrupt();
2092 return WalkResult::advance();
2093 }).wasInterrupted();
2094 return failure(isFailed);
2100 bool anyChanged =
false;
2102 for (Value v : op->getResults()) {
2103 auto result = updateValue(v);
2106 anyChanged |= *result;
2112 if (
auto con = dyn_cast<ConnectOp>(op)) {
2113 auto lhs = con.getDest();
2114 auto rhs = con.getSrc();
2115 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2116 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2119 if (!lhsType || !rhsType)
2122 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2123 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2124 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2126 auto trunc =
builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2127 rhsWidth - lhsWidth);
2128 if (type_isa<SIntType>(rhsType))
2130 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2133 <<
"Truncating RHS to " << lhsType <<
" in " << con <<
"\n");
2134 con->replaceUsesOfWith(con.getSrc(), trunc);
2140 if (
auto module = dyn_cast<FModuleOp>(op)) {
2142 bool argsChanged =
false;
2143 SmallVector<Attribute> argTypes;
2144 argTypes.reserve(module.getNumPorts());
2145 for (
auto arg : module.getArguments()) {
2146 auto result = updateValue(arg);
2149 argsChanged |= *result;
2155 module->setAttr(FModuleLike::getPortTypesAttrName(),
2165 auto *context = type.getContext();
2167 .
Case<UIntType>([&](
auto type) {
2170 .Case<SIntType>([&](
auto type) {
2173 .Case<AnalogType>([&](
auto type) {
2176 .Default([&](
auto type) {
return type; });
2182 auto type = type_dyn_cast<FIRRTLType>(
value.getType());
2193 if (
auto op = dyn_cast_or_null<InferTypeOpInterface>(
value.getDefiningOp())) {
2194 SmallVector<Type, 2> types;
2196 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2197 op->getAttrDictionary(), op->getPropertiesStorage(),
2198 op->getRegions(), types);
2202 assert(types.size() == op->getNumResults());
2203 for (
auto it : llvm::zip(op->getResults(), types)) {
2204 LLVM_DEBUG(
llvm::dbgs() <<
"Inferring " << std::get<0>(it) <<
" as "
2205 << std::get<1>(it) <<
"\n");
2206 std::get<0>(it).setType(std::get<1>(it));
2212 auto context = type.getContext();
2213 unsigned fieldID = 0;
2216 auto width = type.getBitWidthOrSentinel();
2221 }
else if (
width == -1) {
2226 }
else if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
2229 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2230 for (
auto &element : bundleType) {
2231 auto updatedBase = updateBase(element.type);
2234 elements.emplace_back(element.name, element.isFlip, updatedBase);
2237 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
2239 auto save = fieldID;
2242 if (vecType.getNumElements() > 0) {
2243 auto updatedBase = updateBase(vecType.getElementType());
2248 fieldID = save + vecType.getMaxFieldID();
2253 }
else if (
auto enumType = type_dyn_cast<FEnumType>(type)) {
2255 llvm::SmallVector<FEnumType::EnumElement> elements;
2256 for (
auto &element : enumType.getElements()) {
2257 auto updatedBase = updateBase(element.type);
2260 elements.emplace_back(element.name, updatedBase);
2264 llvm_unreachable(
"Unknown type inside a bundle!");
2271 LLVM_DEBUG(
llvm::dbgs() <<
"Update " <<
value <<
" to " << newType <<
"\n");
2272 value.setType(newType);
2277 if (
auto op =
value.getDefiningOp<ConstantOp>()) {
2278 auto k = op.getValue();
2279 auto bitwidth = op.getType().getBitWidthOrSentinel();
2280 if (k.getBitWidth() >
unsigned(bitwidth))
2281 k = k.trunc(bitwidth);
2285 return newType != type;
2291 assert(type.isGround() &&
"Can only pass in ground types.");
2294 Expr *expr = mapping.getExprOrNull(fieldRef);
2295 if (!expr || !expr->solution) {
2300 mlir::emitError(
value.getLoc(),
"width should have been inferred");
2303 int32_t solution = *expr->solution;
2315 template <
typename T>
2316 struct DenseMapInfo<InternedSlot<T>> {
2319 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2320 return Slot(
static_cast<T *
>(pointer));
2323 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2324 return Slot(
static_cast<T *
>(pointer));
2328 auto empty = getEmptyKey().ptr;
2329 auto tombstone = getTombstoneKey().ptr;
2330 if (LHS.ptr ==
empty || RHS.ptr ==
empty || LHS.ptr == tombstone ||
2331 RHS.ptr == tombstone)
2332 return LHS.ptr == RHS.ptr;
2333 return *LHS.ptr == *RHS.ptr;
2343 class InferWidthsPass :
public InferWidthsBase<InferWidthsPass> {
2344 void runOnOperation()
override;
2348 void InferWidthsPass::runOnOperation() {
2350 ConstraintSolver solver;
2351 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2352 getAnalysis<hw::InnerSymbolTableCollection>());
2353 if (failed(mapping.map(getOperation()))) {
2354 signalPassFailure();
2357 if (mapping.areAllModulesSkipped()) {
2358 markAllAnalysesPreserved();
2363 if (failed(solver.solve())) {
2364 signalPassFailure();
2369 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2370 signalPassFailure();
2374 return std::make_unique<InferWidthsPass>();
assert(baseType &&"element must be base type")
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
static FieldRef getRefForIST(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()
llvm::hash_code hash_value(const T &e)
static Slot getEmptyKey()
static unsigned getHashValue(Slot val)
static Slot getTombstoneKey()
static bool isEqual(Slot LHS, Slot RHS)