Improve type system with ResidentId and ToxicPair structs
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{config::UserConfig, resident::Resident, schedule::ShiftType, slot::Day};
|
||||
use crate::{
|
||||
config::UserConfig,
|
||||
resident::{Resident, ResidentId},
|
||||
schedule::ShiftType,
|
||||
slot::Day,
|
||||
};
|
||||
|
||||
pub struct WorkloadBounds {
|
||||
pub max_workloads: HashMap<String, u8>,
|
||||
pub max_holiday_shifts: HashMap<String, u8>,
|
||||
pub max_by_shift_type: HashMap<(String, ShiftType), u8>,
|
||||
pub min_by_shift_type: HashMap<(String, ShiftType), u8>,
|
||||
pub max_workloads: HashMap<ResidentId, u8>,
|
||||
pub max_holiday_shifts: HashMap<ResidentId, u8>,
|
||||
pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
|
||||
pub min_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
|
||||
}
|
||||
|
||||
impl WorkloadBounds {
|
||||
@@ -124,20 +129,20 @@ impl WorkloadBounds {
|
||||
.iter()
|
||||
.filter(|r| r.allowed_types.len() == 1)
|
||||
{
|
||||
let stype = &res.allowed_types[0];
|
||||
let shift_type = &res.allowed_types[0];
|
||||
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
|
||||
|
||||
local_limits.insert((res.id.clone(), stype.clone()), total_limit);
|
||||
local_thresholds.insert((res.id.clone(), stype.clone()), total_limit - 1);
|
||||
local_limits.insert((res.id.clone(), shift_type.clone()), total_limit);
|
||||
local_thresholds.insert((res.id.clone(), shift_type.clone()), total_limit - 1);
|
||||
|
||||
for other_type in &all_shift_types {
|
||||
if other_type != stype {
|
||||
if other_type != shift_type {
|
||||
local_limits.insert((res.id.clone(), other_type.clone()), 0);
|
||||
local_thresholds.insert((res.id.clone(), other_type.clone()), 0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(s) = global_supply.get_mut(stype) {
|
||||
if let Some(s) = global_supply.get_mut(shift_type) {
|
||||
*s = s.saturating_sub(total_limit)
|
||||
}
|
||||
}
|
||||
@@ -151,37 +156,35 @@ impl WorkloadBounds {
|
||||
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0) as f32;
|
||||
let per_type = (total_limit / 2.0).ceil() as u8;
|
||||
|
||||
for stype in &all_shift_types {
|
||||
if res.allowed_types.contains(stype) {
|
||||
local_limits.insert((res.id.clone(), stype.clone()), per_type);
|
||||
local_thresholds.insert((res.id.clone(), stype.clone()), per_type - 1);
|
||||
if let Some(s) = global_supply.get_mut(stype) {
|
||||
for shift_type in &all_shift_types {
|
||||
if res.allowed_types.contains(shift_type) {
|
||||
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
|
||||
local_thresholds.insert((res.id.clone(), shift_type.clone()), per_type - 1);
|
||||
if let Some(s) = global_supply.get_mut(shift_type) {
|
||||
*s = s.saturating_sub(per_type)
|
||||
}
|
||||
} else {
|
||||
local_limits.insert((res.id.clone(), stype.clone()), 0);
|
||||
local_thresholds.insert((res.id.clone(), stype.clone()), 0);
|
||||
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
|
||||
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// residents with 3 available shift types
|
||||
let generalists: Vec<&Resident> = config
|
||||
let res: Vec<&Resident> = config
|
||||
.residents
|
||||
.iter()
|
||||
.filter(|r| r.allowed_types.len() == 3)
|
||||
.collect();
|
||||
|
||||
if !generalists.is_empty() {
|
||||
for stype in &all_shift_types {
|
||||
let remaining = *global_supply.get(stype).unwrap_or(&0);
|
||||
let fair_slice = (remaining as f32 / generalists.len() as f32)
|
||||
.ceil()
|
||||
.max(0.0) as u8;
|
||||
if !res.is_empty() {
|
||||
for shift_type in &all_shift_types {
|
||||
let remaining = *global_supply.get(shift_type).unwrap_or(&0);
|
||||
let fair_slice = (remaining as f32 / res.len() as f32).ceil().max(0.0) as u8;
|
||||
|
||||
for res in &generalists {
|
||||
local_limits.insert((res.id.clone(), stype.clone()), fair_slice);
|
||||
local_thresholds.insert((res.id.clone(), stype.clone()), fair_slice - 1);
|
||||
for res in &res {
|
||||
local_limits.insert((res.id.clone(), shift_type.clone()), fair_slice);
|
||||
local_thresholds.insert((res.id.clone(), shift_type.clone()), fair_slice - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,7 +198,11 @@ impl WorkloadBounds {
|
||||
mod tests {
|
||||
use rstest::{fixture, rstest};
|
||||
|
||||
use crate::{bounds::WorkloadBounds, config::UserConfig, resident::Resident};
|
||||
use crate::{
|
||||
bounds::WorkloadBounds,
|
||||
config::UserConfig,
|
||||
resident::{Resident, ResidentId},
|
||||
};
|
||||
|
||||
#[fixture]
|
||||
fn config() -> UserConfig {
|
||||
@@ -212,19 +219,25 @@ mod tests {
|
||||
fn test_max_workloads(config: UserConfig) {
|
||||
let bounds = WorkloadBounds::new_with_config(&config);
|
||||
|
||||
assert_eq!(bounds.max_workloads["1"], 2);
|
||||
assert_eq!(bounds.max_workloads["2"], 2);
|
||||
assert_eq!(bounds.max_workloads["3"], 12);
|
||||
assert_eq!(bounds.max_workloads["4"], 13);
|
||||
assert_eq!(bounds.max_workloads["5"], 13);
|
||||
assert_eq!(bounds.max_workloads[&ResidentId("1".to_string())], 2);
|
||||
assert_eq!(bounds.max_workloads[&ResidentId("2".to_string())], 2);
|
||||
assert_eq!(bounds.max_workloads[&ResidentId("3".to_string())], 12);
|
||||
assert_eq!(bounds.max_workloads[&ResidentId("4".to_string())], 13);
|
||||
assert_eq!(bounds.max_workloads[&ResidentId("5".to_string())], 13);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_calculate_max_holiday_shifts(config: UserConfig) {
|
||||
let bounds = WorkloadBounds::new_with_config(&config);
|
||||
|
||||
let stefanos_limit = *bounds.max_holiday_shifts.get("1").unwrap();
|
||||
let iordanis_limit = *bounds.max_holiday_shifts.get("2").unwrap();
|
||||
let stefanos_limit = *bounds
|
||||
.max_holiday_shifts
|
||||
.get(&ResidentId("1".to_string()))
|
||||
.unwrap();
|
||||
let iordanis_limit = *bounds
|
||||
.max_holiday_shifts
|
||||
.get(&ResidentId("2".to_string()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stefanos_limit, 1);
|
||||
assert_eq!(iordanis_limit, 1);
|
||||
|
||||
@@ -2,12 +2,37 @@ use chrono::Month;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
resident::{Resident, ResidentDTO},
|
||||
resident::{Resident, ResidentDTO, ResidentId},
|
||||
slot::Day,
|
||||
};
|
||||
|
||||
const YEAR: i32 = 2026;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToxicPair((ResidentId, ResidentId));
|
||||
|
||||
impl ToxicPair {
|
||||
pub fn new(res_id_1: &str, res_id_2: &str) -> Self {
|
||||
Self((
|
||||
ResidentId(res_id_1.to_string()),
|
||||
ResidentId(res_id_2.to_string()),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn matches(&self, other: &ToxicPair) -> bool {
|
||||
let p1 = &self.0;
|
||||
let p2 = &other.0;
|
||||
|
||||
(p1.0 == p2.0 && p1.1 == p2.1) || (p1.0 == p2.1 && p1.1 == p2.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(ResidentId, ResidentId)> for ToxicPair {
|
||||
fn from(value: (ResidentId, ResidentId)) -> Self {
|
||||
Self((value.0, value.1))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct UserConfigDTO {
|
||||
month: usize,
|
||||
@@ -23,7 +48,7 @@ pub struct UserConfig {
|
||||
pub year: i32,
|
||||
pub holidays: Vec<usize>,
|
||||
pub residents: Vec<Resident>,
|
||||
pub toxic_pairs: Vec<(String, String)>,
|
||||
pub toxic_pairs: Vec<ToxicPair>,
|
||||
}
|
||||
|
||||
impl UserConfig {
|
||||
@@ -51,7 +76,7 @@ impl UserConfig {
|
||||
self.residents.push(resident);
|
||||
}
|
||||
|
||||
pub fn with_toxic_pairs(mut self, toxic_pairs: Vec<(String, String)>) -> Self {
|
||||
pub fn with_toxic_pairs(mut self, toxic_pairs: Vec<ToxicPair>) -> Self {
|
||||
self.toxic_pairs = toxic_pairs;
|
||||
self
|
||||
}
|
||||
@@ -99,7 +124,11 @@ impl From<UserConfigDTO> for UserConfig {
|
||||
year: value.year,
|
||||
holidays: value.holidays,
|
||||
residents: value.residents.into_iter().map(Resident::from).collect(),
|
||||
toxic_pairs: value.toxic_pairs,
|
||||
toxic_pairs: value
|
||||
.toxic_pairs
|
||||
.into_iter()
|
||||
.map(|p| ToxicPair::new(&p.0, &p.1))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ mod tests {
|
||||
|
||||
#[rstest]
|
||||
pub fn test_export_as_doc(mut schedule: MonthlySchedule, scheduler: Scheduler) {
|
||||
scheduler.search(&mut schedule, Slot::default());
|
||||
scheduler.run(&mut schedule);
|
||||
schedule.export_as_doc(&scheduler.config);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,8 +27,6 @@ struct AppState {
|
||||
fn generate(config: UserConfigDTO, state: tauri::State<'_, AppState>) -> MonthlySchedule {
|
||||
let config = UserConfig::from(config);
|
||||
let mut schedule = MonthlySchedule::new();
|
||||
schedule.prefill(&config);
|
||||
|
||||
let bounds = WorkloadBounds::new_with_config(&config);
|
||||
let scheduler = Scheduler::new(config, bounds);
|
||||
scheduler.run(&mut schedule);
|
||||
|
||||
@@ -5,10 +5,13 @@ use crate::{
|
||||
slot::{Day, ShiftPosition, Slot},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct ResidentId(pub String);
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Resident {
|
||||
pub id: String,
|
||||
pub id: ResidentId,
|
||||
pub name: String,
|
||||
pub negative_shifts: Vec<Day>,
|
||||
pub manual_shifts: Vec<Slot>,
|
||||
@@ -32,7 +35,7 @@ pub struct ResidentDTO {
|
||||
impl Resident {
|
||||
pub fn new(id: &str, name: &str) -> Self {
|
||||
Self {
|
||||
id: id.to_string(),
|
||||
id: ResidentId(id.to_string()),
|
||||
name: name.to_string(),
|
||||
negative_shifts: Vec::new(),
|
||||
manual_shifts: Vec::new(),
|
||||
@@ -75,7 +78,7 @@ impl Resident {
|
||||
impl From<ResidentDTO> for Resident {
|
||||
fn from(value: ResidentDTO) -> Self {
|
||||
Self {
|
||||
id: value.id,
|
||||
id: ResidentId(value.id),
|
||||
name: value.name,
|
||||
negative_shifts: value
|
||||
.negative_shifts
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
bounds::WorkloadBounds,
|
||||
config::UserConfig,
|
||||
resident::Resident,
|
||||
config::{ToxicPair, UserConfig},
|
||||
resident::ResidentId,
|
||||
slot::{weekday_to_greek, Day, ShiftPosition, Slot},
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@ use serde::Serializer;
|
||||
/// each slot has one resident
|
||||
/// a day can span between 1 or 2 slots depending on if it is open(odd) or closed(even)
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct MonthlySchedule(pub HashMap<Slot, String>);
|
||||
pub struct MonthlySchedule(pub HashMap<Slot, ResidentId>);
|
||||
|
||||
impl MonthlySchedule {
|
||||
pub fn new() -> Self {
|
||||
@@ -28,31 +28,31 @@ impl MonthlySchedule {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_resident_id(&self, slot: &Slot) -> Option<&String> {
|
||||
pub fn get_resident_id(&self, slot: &Slot) -> Option<&ResidentId> {
|
||||
self.0.get(slot)
|
||||
}
|
||||
|
||||
pub fn current_workload(&self, resident_id: &str) -> usize {
|
||||
pub fn current_workload(&self, resident_id: &ResidentId) -> usize {
|
||||
self.0
|
||||
.values()
|
||||
.filter(|res_id| res_id == &resident_id)
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn current_holiday_workload(&self, resident: &Resident, config: &UserConfig) -> usize {
|
||||
pub fn current_holiday_workload(&self, resident_id: &ResidentId, config: &UserConfig) -> usize {
|
||||
self.0
|
||||
.iter()
|
||||
.filter(|(slot, res_id)| {
|
||||
res_id == &&resident.id && config.is_holiday_or_weekend_slot(slot.day.0)
|
||||
res_id == &resident_id && config.is_holiday_or_weekend_slot(slot.day.0)
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn count_shifts(&self, resident_id: &str, shift_type: Option<ShiftType>) -> usize {
|
||||
pub fn count_shifts(&self, resident_id: &ResidentId, shift_type: Option<ShiftType>) -> usize {
|
||||
self.0
|
||||
.iter()
|
||||
.filter(|(slot, id)| {
|
||||
if *id != resident_id {
|
||||
.filter(|(&slot, id)| {
|
||||
if id != &resident_id {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -92,8 +92,8 @@ impl MonthlySchedule {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, slot: Slot, resident_id: &str) {
|
||||
self.0.insert(slot, resident_id.to_string());
|
||||
pub fn insert(&mut self, slot: Slot, resident_id: &ResidentId) {
|
||||
self.0.insert(slot, resident_id.clone());
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, slot: Slot) {
|
||||
@@ -154,11 +154,11 @@ impl MonthlySchedule {
|
||||
let first_id = self.get_resident_id(&slot.previous());
|
||||
let second_id = self.get_resident_id(slot);
|
||||
|
||||
if let (Some(f), Some(s)) = (first_id, second_id) {
|
||||
if let (Some(r1), Some(r2)) = (first_id, second_id) {
|
||||
return config
|
||||
.toxic_pairs
|
||||
.iter()
|
||||
.any(|(r1, r2)| (r1 == f && r2 == s) || (r1 == s && r2 == f));
|
||||
.any(|pair| pair.matches(&ToxicPair::from((r1.clone(), r2.clone()))));
|
||||
}
|
||||
|
||||
false
|
||||
@@ -210,15 +210,15 @@ impl MonthlySchedule {
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) {
|
||||
let current_holiday_workload = self.current_holiday_workload(resident, config);
|
||||
// if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) {
|
||||
let current_holiday_workload = self.current_holiday_workload(res_id, config);
|
||||
|
||||
if let Some(&holiday_limit) = bounds.max_holiday_shifts.get(res_id) {
|
||||
if current_holiday_workload > holiday_limit as usize {
|
||||
return true;
|
||||
}
|
||||
if let Some(&holiday_limit) = bounds.max_holiday_shifts.get(res_id) {
|
||||
if current_holiday_workload > holiday_limit as usize {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// }
|
||||
|
||||
false
|
||||
}
|
||||
@@ -297,7 +297,7 @@ impl MonthlySchedule {
|
||||
let o1 = self.count_shifts(&res.id, Some(ShiftType::OpenFirst));
|
||||
let o2 = self.count_shifts(&res.id, Some(ShiftType::OpenSecond));
|
||||
let cl = self.count_shifts(&res.id, Some(ShiftType::Closed));
|
||||
let sun = self.current_holiday_workload(res, config);
|
||||
let sun = self.current_holiday_workload(&res.id, config);
|
||||
|
||||
output.push_str(&format!(
|
||||
"{:<15} | {:<6} | {:<10} | {:<10} | {:<7} | {:<10}\n",
|
||||
@@ -341,8 +341,8 @@ mod tests {
|
||||
|
||||
use crate::{
|
||||
bounds::WorkloadBounds,
|
||||
config::UserConfig,
|
||||
resident::Resident,
|
||||
config::{ToxicPair, UserConfig},
|
||||
resident::{Resident, ResidentId},
|
||||
schedule::{Day, MonthlySchedule, Slot},
|
||||
slot::ShiftPosition,
|
||||
};
|
||||
@@ -364,7 +364,7 @@ mod tests {
|
||||
Resident::new("1", "Stefanos"),
|
||||
Resident::new("2", "Iordanis"),
|
||||
])
|
||||
.with_toxic_pairs(vec![(("1".to_string(), "2".to_string()))])
|
||||
.with_toxic_pairs(vec![ToxicPair::new("1", "2")])
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
@@ -382,7 +382,10 @@ mod tests {
|
||||
|
||||
schedule.insert(slot_1, &resident.id);
|
||||
|
||||
assert_eq!(schedule.get_resident_id(&slot_1), Some(&"1".to_string()));
|
||||
assert_eq!(
|
||||
schedule.get_resident_id(&slot_1),
|
||||
Some(&ResidentId("1".to_string()))
|
||||
);
|
||||
assert_eq!(schedule.current_workload(&resident.id), 1);
|
||||
assert_eq!(schedule.get_resident_id(&slot_2), None);
|
||||
}
|
||||
@@ -439,8 +442,8 @@ mod tests {
|
||||
|
||||
let mut bounds = WorkloadBounds::new();
|
||||
|
||||
bounds.max_workloads.insert("1".to_string(), 1);
|
||||
bounds.max_workloads.insert("2".to_string(), 2);
|
||||
bounds.max_workloads.insert(ResidentId("1".to_string()), 1);
|
||||
bounds.max_workloads.insert(ResidentId("2".to_string()), 2);
|
||||
|
||||
schedule.insert(slot_1, &stefanos.id);
|
||||
assert!(!schedule.is_workload_unbalanced(&slot_1, &config, &bounds));
|
||||
@@ -463,8 +466,12 @@ mod tests {
|
||||
|
||||
let mut bounds = WorkloadBounds::new();
|
||||
|
||||
bounds.max_holiday_shifts.insert("1".to_string(), 1);
|
||||
bounds.max_holiday_shifts.insert("2".to_string(), 1);
|
||||
bounds
|
||||
.max_holiday_shifts
|
||||
.insert(ResidentId("1".to_string()), 1);
|
||||
bounds
|
||||
.max_holiday_shifts
|
||||
.insert(ResidentId("2".to_string()), 1);
|
||||
|
||||
schedule.insert(slot_1, &stefanos.id);
|
||||
assert!(!schedule.is_holiday_workload_imbalanced(&slot_1, &config, &bounds));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
bounds::WorkloadBounds, config::UserConfig, schedule::MonthlySchedule,
|
||||
bounds::WorkloadBounds, config::UserConfig, resident::ResidentId, schedule::MonthlySchedule,
|
||||
slot::Slot,
|
||||
};
|
||||
|
||||
@@ -16,6 +16,7 @@ impl Scheduler {
|
||||
}
|
||||
|
||||
pub fn run(&self, schedule: &mut MonthlySchedule) -> bool {
|
||||
schedule.prefill(&self.config);
|
||||
self.search(schedule, Slot::default())
|
||||
}
|
||||
|
||||
@@ -66,7 +67,7 @@ impl Scheduler {
|
||||
}
|
||||
|
||||
/// Return all valid residents for the current slot
|
||||
pub fn valid_residents(&self, slot: Slot, schedule: &MonthlySchedule) -> Vec<String> {
|
||||
pub fn valid_residents(&self, slot: Slot, schedule: &MonthlySchedule) -> Vec<ResidentId> {
|
||||
let other_slot_resident_id = schedule.get_resident_id(&slot.other_position());
|
||||
|
||||
self.config
|
||||
@@ -144,7 +145,4 @@ mod tests {
|
||||
|
||||
println!("{}", schedule.pretty_print(&scheduler.config));
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_valid_residents(mut schedule: MonthlySchedule, scheduler: Scheduler) {}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user