From bb46e9710d54996853d6ad453ab6c60ced1930b1 Mon Sep 17 00:00:00 2001 From: Johnson Sun Date: Sun, 28 Jul 2024 03:34:11 +0800 Subject: [PATCH] feat(core): Add basic report mechanism --- core/nurse_scheduling/preference_types.py | 8 ++++++-- core/nurse_scheduling/report.py | 7 +++++++ core/nurse_scheduling/scheduler.py | 10 +++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 core/nurse_scheduling/report.py diff --git a/core/nurse_scheduling/preference_types.py b/core/nurse_scheduling/preference_types.py index f9a2f88..48d3817 100644 --- a/core/nurse_scheduling/preference_types.py +++ b/core/nurse_scheduling/preference_types.py @@ -1,5 +1,7 @@ from . import utils from .context import Context +from .report import Report + def all_requirements_fulfilled(ctx: Context, preference, preference_idx): # Hard constraint @@ -34,16 +36,17 @@ def assign_shifts_evenly(ctx: Context, preference, preference_idx): # Construct: L2 = actual_n_shifts - target_n_shifts) ** 2 MAX = max(ctx.n_days - target_n_shifts, target_n_shifts) - diff_var_name = f"{unique_var_prefix}_diff" + diff_var_name = f"{unique_var_prefix}diff" ctx.model_vars[diff_var_name] = diff = ctx.model.NewIntVar(0, MAX, diff_var_name) ctx.model.add_abs_equality(diff, actual_n_shifts - target_n_shifts) - L2_var_name = f"{unique_var_prefix}_L2" + L2_var_name = f"{unique_var_prefix}L2" ctx.model_vars[L2_var_name] = L2 = ctx.model.NewIntVar(0, MAX**2, L2_var_name) ctx.model.AddMultiplicationEquality(L2, diff, diff) # Add the objective weight = -1000000 ctx.objective += weight * L2 + ctx.reports.append(Report(f"assign_shifts_evenly_L2_p_{p}", L2, lambda x: x == 0)) def shift_request(ctx: Context, preference, preference_idx): # Soft constraint @@ -60,6 +63,7 @@ def shift_request(ctx: Context, preference, preference_idx): # Add the objective weight = 1 ctx.objective += weight * ctx.shifts[(d, r, p)] + ctx.reports.append(Report(f"shift_request_p_{p}_d_{d}_r_{r}", ctx.shifts[(d, r, p)], lambda x: x == 1)) PREFERENCE_TYPES_TO_FUNC = { "all requirements fulfilled": all_requirements_fulfilled, diff --git a/core/nurse_scheduling/report.py b/core/nurse_scheduling/report.py new file mode 100644 index 0000000..3c37997 --- /dev/null +++ b/core/nurse_scheduling/report.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + +@dataclass +class Report: + description: str + variable: 'typing.Any' + skip_condition: 'typing.Any' = lambda x: False diff --git a/core/nurse_scheduling/scheduler.py b/core/nurse_scheduling/scheduler.py index 3e07c42..cf3338e 100644 --- a/core/nurse_scheduling/scheduler.py +++ b/core/nurse_scheduling/scheduler.py @@ -1,12 +1,14 @@ import itertools import logging from datetime import timedelta +from typing import List from ortools.sat.python import cp_model from . import export, preference_types from .context import Context from .dataloader import load_data +from .report import Report def schedule(filepath: str, validate=True, deterministic=False): @@ -31,6 +33,7 @@ def schedule(filepath: str, validate=True, deterministic=False): logging.info("Initializing solver model...") ctx.model = cp_model.CpModel() ctx.model_vars = {} + ctx.reports: List[Report] = [] ctx.shifts = {} """A set of indicator variables that are 1 if and only if a person (p) is assigned to a shift (d, r).""" @@ -125,10 +128,15 @@ def on_solution_callback(self): logging.info(f" - conflicts: {solver.NumConflicts()}") logging.info(f" - branches : {solver.NumBranches()}") logging.info(f" - wall time: {solver.WallTime()}s") - logging.info("Variables:") for k, v in ctx.model_vars.items(): logging.info(f" - {k}: {solver.Value(v)}") + logging.info("Reports:") + for report in ctx.reports: + val = solver.Value(report.variable) + if report.skip_condition(val): + continue + logging.info(f" - {report.description}: {val}") logging.info(f"Done.")