Skip to content

Commit

Permalink
Add required objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
j3soon committed Jul 26, 2024
1 parent cbbd4c2 commit 017d081
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 20 deletions.
18 changes: 18 additions & 0 deletions core/nurse_scheduling/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class Context:
def __init__(self) -> None:
self.startdate = None
self.enddate = None
self.requirements = None
self.people = None
self.objectives = None
self.dates = None
self.n_days = None
self.n_requirements = None
self.n_people = None
self.model = None
self.shifts = None
self.map_dr_p = None
self.map_dp_r = None
self.map_d_rp = None
self.map_r_dp = None
self.map_p_dr = None
26 changes: 26 additions & 0 deletions core/nurse_scheduling/objective_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from . import utils

def all_requirements_fulfilled(ctx, args):
# Hard constraint
# For all shifts, the requirements (# of people) must be fulfilled.
# Note that a shift is represented as (d, r)
# i.e., sum_{p}(shifts[(d, r, p)]) == required_n_people, for all (d, r)
for (d, r), ps in ctx.map_dr_p.items():
actual_n_people = sum(ctx.shifts[(d, r, p)] for p in ps)
required_n_people = utils.required_n_people(ctx.requirements[r])
ctx.model.Add(actual_n_people == required_n_people)

def all_people_work_at_most_one_shift_per_day(ctx, args):
# Hard constraint
# For all people, for all days, only work at most one shift.
# Note that a shift in day `d` can be represented as `r` instead of (d, r).
# i.e., sum_{r}(shifts[(d, r, p)]) <= 1, for all (d, p)
for (d, p), rs in ctx.map_dp_r.items():
actual_n_shifts = sum(ctx.shifts[(d, r, p)] for r in rs)
maximum_n_shifts = 1
ctx.model.Add(actual_n_shifts <= maximum_n_shifts)

OBJECTIVE_TYPES_TO_FUNC = {
"all requirements fulfilled": all_requirements_fulfilled,
"all people work at most one shift per day": all_people_work_at_most_one_shift_per_day,
}
38 changes: 18 additions & 20 deletions core/nurse_scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from ortools.sat.python import cp_model

from . import export, utils
from . import export, objective_types
from .context import Context
from .dataloader import load_data


Expand All @@ -13,14 +14,17 @@ def schedule(filepath: str, validate=True, deterministic=False):
scenario = load_data(filepath, validate)

logging.info("Extracting scenario data...")
if scenario.apiVersion != "alpha":
raise NotImplementedError(f"Unsupported API version: {scenario.apiVersion}")
startdate = scenario.startdate
enddate = scenario.enddate
requirements = scenario.requirements
people = scenario.people
objectives = scenario.objectives
del scenario
n_days = (enddate - startdate).days + 1
n_people = len(people)
n_requirements = len(requirements)
n_people = len(people)
dates = [startdate + timedelta(days=d) for d in range(n_days)]

logging.info("Initializing solver model...")
Expand All @@ -33,7 +37,9 @@ def schedule(filepath: str, validate=True, deterministic=False):
# Ref: https://developers.google.com/optimization/scheduling/employee_scheduling
for d in range(n_days):
for r in range(n_requirements):
# TODO(Optimize): Skip if no people is required in that day
for p in range(n_people):
# TODO(Optimize): Skip if the person does not qualify for the requirement
shifts[(d, r, p)] = model.NewBoolVar(f"shift_d{d}_r{r}_p{p}")

logging.info("Creating maps for faster lookup...")
Expand All @@ -58,24 +64,16 @@ def schedule(filepath: str, validate=True, deterministic=False):
for p in range(n_people)
}

logging.info("Adding preferences and constraints...")
# Hard constraint
# For all shifts, the requirements (# of people) must be fulfilled.
# Note that a shift is represented as (d, r)
# i.e., sum_{p}(shifts[(d, r, p)]) == required_n_people, for all (d, r)
for (d, r), ps in map_dr_p.items():
actual_n_people = sum(shifts[(d, r, p)] for p in ps)
required_n_people = utils.required_n_people(requirements[r])
model.Add(actual_n_people == required_n_people)

# Hard constraint
# For all people, for all days, only work at most one shift.
# Note that a shift in day `d` can be represented as `r` instead of (d, r).
# i.e., sum_{r}(shifts[(d, r, p)]) <= 1, for all (d, p)
for (d, p), rs in map_dp_r.items():
actual_n_shifts = sum(shifts[(d, r, p)] for r in rs)
maximum_n_shifts = 1
model.Add(actual_n_shifts <= maximum_n_shifts)
ctx = Context()
for k in vars(ctx):
setattr(ctx, k, locals()[k])

logging.info("Adding objectives (i.e., preferences and constraints)...")
# TODO: Check no duplicated objectives
# TODO: Check no overlapping objectives
# TODO: Check all required objectives are present
for objective in objectives:
objective_types.OBJECTIVE_TYPES_TO_FUNC[objective.type](ctx, objective.args)

logging.info("Initializing solver...")
solver = cp_model.CpSolver()
Expand Down
4 changes: 4 additions & 0 deletions core/tests/testcases/example_1.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
apiVersion: alpha
description: Simple Example 1
startdate: 2023-08-18
enddate: 2023-08-20
Expand All @@ -20,3 +21,6 @@ requirements:
- id: N
description: Night shift requirement
required_people: 1
objectives:
- type: all requirements fulfilled
- type: all people work at most one shift per day

0 comments on commit 017d081

Please sign in to comment.