16 #include "mlir/IR/Operation.h"
18 #include "ortools/sat/cp_model.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/Format.h"
23 #define DEBUG_TYPE "cpsat-schedulers"
25 using namespace circt;
27 using namespace operations_research;
28 using namespace operations_research::sat;
45 return containingOp->emitError(
"problem does not include last operation");
47 CpModelBuilder cpModel;
50 DenseMap<Operation *, IntVar> taskStarts;
51 DenseMap<Operation *, IntVar> taskEnds;
52 DenseMap<Problem::OperatorType, SmallVector<IntervalVar, 4>>
53 resourcesToTaskIntervals;
58 for (
auto *task : tasks) {
69 for (
auto item : llvm::enumerate(tasks)) {
70 auto i = item.index();
71 auto *task = item.value();
72 IntVar startVar = cpModel.NewIntVar(
Domain(0, horizon))
73 .WithName((Twine(
"start_of_task_") + Twine(i)).str());
74 IntVar endVar = cpModel.NewIntVar(
Domain(0, horizon))
75 .WithName((Twine(
"end_of_task_") + Twine(i)).str());
76 taskStarts[task] = startVar;
77 taskEnds[task] = endVar;
79 unsigned duration = *prob.
getLatency(*resource);
80 IntervalVar taskInterval =
81 cpModel.NewIntervalVar(startVar, duration, endVar)
82 .WithName((Twine(
"task_interval_") + Twine(i)).str());
85 resourcesToTaskIntervals[resource.value()].emplace_back(taskInterval);
90 for (Operation *task : tasks) {
92 Operation *src = dep.getSource();
93 Operation *dst = dep.getDestination();
95 return containingOp->emitError() <<
"dependence cycle detected";
96 cpModel.AddLessOrEqual(taskEnds[src], taskStarts[dst]);
102 for (
auto resourceToTaskIntervals : resourcesToTaskIntervals) {
104 auto capacity = prob.
getLimit(resource);
105 SmallVector<IntervalVar, 4> &taskIntervals =
106 resourceToTaskIntervals.getSecond();
113 CumulativeConstraint cumu = cpModel.AddCumulative(capacity.value());
114 for (
const auto &item : llvm::enumerate(taskIntervals)) {
115 auto i = item.index();
116 auto taskInterval = item.value();
117 IntVar demandVar = cpModel.NewIntVar(
Domain(1)).WithName(
118 (Twine(
"demand_") + Twine(i) + Twine(
"_") + Twine(resource.strref()))
123 cpModel.NewFixedSizeIntervalVar(taskInterval.StartExpr(), 1);
124 cumu.AddDemand(start, demandVar);
128 cpModel.Minimize(taskEnds[lastOp]);
132 int numSolutions = 0;
133 model.Add(NewFeasibleSolutionObserver([&](
const CpSolverResponse &r) {
134 LLVM_DEBUG(dbgs() <<
"Solution " << numSolutions <<
'\n');
135 LLVM_DEBUG(dbgs() <<
"Solution status" << r.status() <<
'\n');
139 LLVM_DEBUG(dbgs() <<
"Starting solver\n");
140 const CpSolverResponse response = SolveCpModel(cpModel.Build(), &model);
142 if (response.status() == CpSolverStatus::OPTIMAL ||
143 response.status() == CpSolverStatus::FEASIBLE) {
144 for (
auto *task : tasks)
145 prob.
setStartTime(task, SolutionIntegerValue(response, taskStarts[task]));
149 return containingOp->emitError() <<
"infeasible";
std::optional< OperatorType > getLinkedOperatorType(Operation *op)
The linked operator type provides the runtime characteristics for op.
bool hasOperation(Operation *op)
Return true if op is part of this problem.
void setStartTime(Operation *op, unsigned val)
DependenceRange getDependences(Operation *op)
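Return a range object to transparently iterate over op's incoming 1) implicit def-use dependences (backed by the SSA graph), and then 2) explicit auxiliary dependences (backed by the auxiliary dependence map).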
const OperationSet & getOperations()
Return the set of operations.
mlir::StringAttr OperatorType
Operator types are distinguished by name (chosen by the client).
std::optional< unsigned > getLatency(OperatorType opr)
The latency is the number of cycles opr needs to compute its result.
Operation * getContainingOp()
Return the operation containing this problem, e.g. to emit diagnostics.
This class models a resource-constrained scheduling problem.
std::optional< unsigned > getLimit(OperatorType opr)
The limit is the maximum number of operations using opr that are allowed to start in the same time step.
Domain
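The number of values each bit of a type can assume. (Cross-reference note: this tooltip describes a different CIRCT `Domain` type. The `Domain(0, horizon)` expressions in this file construct OR-Tools `operations_research::Domain` objects, which give the range of integer values a CP-SAT variable may take.)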
LogicalResult scheduleCPSAT(SharedOperatorsProblem &prob, Operation *lastOp)
Solve the acyclic problem with shared operators using constraint programming and an external SAT solver.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.