From 7713318d010875d19f82ca17742b3d07f660ebb0 Mon Sep 17 00:00:00 2001 From: stefiosif Date: Wed, 14 Jan 2026 23:37:58 +0200 Subject: [PATCH] Improve type system with ResidentId and ToxicPair structs --- src-tauri/src/bounds.rs | 83 ++++++++++++++++++++++---------------- src-tauri/src/config.rs | 37 +++++++++++++++-- src-tauri/src/export.rs | 2 +- src-tauri/src/lib.rs | 2 - src-tauri/src/resident.rs | 9 +++-- src-tauri/src/schedule.rs | 65 ++++++++++++++++------------- src-tauri/src/scheduler.rs | 8 ++-- 7 files changed, 127 insertions(+), 79 deletions(-) diff --git a/src-tauri/src/bounds.rs b/src-tauri/src/bounds.rs index 38de34d..ba7a577 100644 --- a/src-tauri/src/bounds.rs +++ b/src-tauri/src/bounds.rs @@ -1,12 +1,17 @@ use std::collections::HashMap; -use crate::{config::UserConfig, resident::Resident, schedule::ShiftType, slot::Day}; +use crate::{ + config::UserConfig, + resident::{Resident, ResidentId}, + schedule::ShiftType, + slot::Day, +}; pub struct WorkloadBounds { - pub max_workloads: HashMap, - pub max_holiday_shifts: HashMap, - pub max_by_shift_type: HashMap<(String, ShiftType), u8>, - pub min_by_shift_type: HashMap<(String, ShiftType), u8>, + pub max_workloads: HashMap, + pub max_holiday_shifts: HashMap, + pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>, + pub min_by_shift_type: HashMap<(ResidentId, ShiftType), u8>, } impl WorkloadBounds { @@ -124,20 +129,20 @@ impl WorkloadBounds { .iter() .filter(|r| r.allowed_types.len() == 1) { - let stype = &res.allowed_types[0]; + let shift_type = &res.allowed_types[0]; let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0); - local_limits.insert((res.id.clone(), stype.clone()), total_limit); - local_thresholds.insert((res.id.clone(), stype.clone()), total_limit - 1); + local_limits.insert((res.id.clone(), shift_type.clone()), total_limit); + local_thresholds.insert((res.id.clone(), shift_type.clone()), total_limit - 1); for other_type in &all_shift_types { - if other_type != stype { + if other_type != shift_type { local_limits.insert((res.id.clone(), other_type.clone()), 0); local_thresholds.insert((res.id.clone(), other_type.clone()), 0); } } - if let Some(s) = global_supply.get_mut(stype) { + if let Some(s) = global_supply.get_mut(shift_type) { *s = s.saturating_sub(total_limit) } } @@ -151,37 +156,35 @@ impl WorkloadBounds { let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0) as f32; let per_type = (total_limit / 2.0).ceil() as u8; - for stype in &all_shift_types { - if res.allowed_types.contains(stype) { - local_limits.insert((res.id.clone(), stype.clone()), per_type); - local_thresholds.insert((res.id.clone(), stype.clone()), per_type - 1); - if let Some(s) = global_supply.get_mut(stype) { + for shift_type in &all_shift_types { + if res.allowed_types.contains(shift_type) { + local_limits.insert((res.id.clone(), shift_type.clone()), per_type); + local_thresholds.insert((res.id.clone(), shift_type.clone()), per_type - 1); + if let Some(s) = global_supply.get_mut(shift_type) { *s = s.saturating_sub(per_type) } } else { - local_limits.insert((res.id.clone(), stype.clone()), 0); - local_thresholds.insert((res.id.clone(), stype.clone()), 0); + local_limits.insert((res.id.clone(), shift_type.clone()), 0); + local_thresholds.insert((res.id.clone(), shift_type.clone()), 0); } } } // residents with 3 available shift types - let generalists: Vec<&Resident> = config + let res: Vec<&Resident> = config .residents .iter() .filter(|r| r.allowed_types.len() == 3) .collect(); - if !generalists.is_empty() { - for stype in &all_shift_types { - let remaining = *global_supply.get(stype).unwrap_or(&0); - let fair_slice = (remaining as f32 / generalists.len() as f32) - .ceil() - .max(0.0) as u8; + if !res.is_empty() { + for shift_type in &all_shift_types { + let remaining = *global_supply.get(shift_type).unwrap_or(&0); + let fair_slice = (remaining as f32 / res.len() as f32).ceil().max(0.0) as u8; - for res in &generalists { - local_limits.insert((res.id.clone(), stype.clone()), fair_slice); - local_thresholds.insert((res.id.clone(), stype.clone()), fair_slice - 1); + for res in &res { + local_limits.insert((res.id.clone(), shift_type.clone()), fair_slice); + local_thresholds.insert((res.id.clone(), shift_type.clone()), fair_slice - 1); } } } @@ -195,7 +198,11 @@ impl WorkloadBounds { mod tests { use rstest::{fixture, rstest}; - use crate::{bounds::WorkloadBounds, config::UserConfig, resident::Resident}; + use crate::{ + bounds::WorkloadBounds, + config::UserConfig, + resident::{Resident, ResidentId}, + }; #[fixture] fn config() -> UserConfig { @@ -212,19 +219,25 @@ mod tests { fn test_max_workloads(config: UserConfig) { let bounds = WorkloadBounds::new_with_config(&config); - assert_eq!(bounds.max_workloads["1"], 2); - assert_eq!(bounds.max_workloads["2"], 2); - assert_eq!(bounds.max_workloads["3"], 12); - assert_eq!(bounds.max_workloads["4"], 13); - assert_eq!(bounds.max_workloads["5"], 13); + assert_eq!(bounds.max_workloads[&ResidentId("1".to_string())], 2); + assert_eq!(bounds.max_workloads[&ResidentId("2".to_string())], 2); + assert_eq!(bounds.max_workloads[&ResidentId("3".to_string())], 12); + assert_eq!(bounds.max_workloads[&ResidentId("4".to_string())], 13); + assert_eq!(bounds.max_workloads[&ResidentId("5".to_string())], 13); } #[rstest] fn test_calculate_max_holiday_shifts(config: UserConfig) { let bounds = WorkloadBounds::new_with_config(&config); - let stefanos_limit = *bounds.max_holiday_shifts.get("1").unwrap(); - let iordanis_limit = *bounds.max_holiday_shifts.get("2").unwrap(); + let stefanos_limit = *bounds + .max_holiday_shifts + .get(&ResidentId("1".to_string())) + .unwrap(); + let iordanis_limit = *bounds + .max_holiday_shifts + .get(&ResidentId("2".to_string())) + .unwrap(); assert_eq!(stefanos_limit, 1); assert_eq!(iordanis_limit, 1); diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 0185d18..d84ed93 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -2,12 +2,37 @@ use chrono::Month; use serde::{Deserialize, Serialize}; use crate::{ - resident::{Resident, ResidentDTO}, + resident::{Resident, ResidentDTO, ResidentId}, slot::Day, }; const YEAR: i32 = 2026; +#[derive(Debug)] +pub struct ToxicPair((ResidentId, ResidentId)); + +impl ToxicPair { + pub fn new(res_id_1: &str, res_id_2: &str) -> Self { + Self(( + ResidentId(res_id_1.to_string()), + ResidentId(res_id_2.to_string()), + )) + } + + pub fn matches(&self, other: &ToxicPair) -> bool { + let p1 = &self.0; + let p2 = &other.0; + + (p1.0 == p2.0 && p1.1 == p2.1) || (p1.0 == p2.1 && p1.1 == p2.0) + } +} + +impl From<(ResidentId, ResidentId)> for ToxicPair { + fn from(value: (ResidentId, ResidentId)) -> Self { + Self((value.0, value.1)) + } +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct UserConfigDTO { month: usize, @@ -23,7 +48,7 @@ pub struct UserConfig { pub year: i32, pub holidays: Vec, pub residents: Vec, - pub toxic_pairs: Vec<(String, String)>, + pub toxic_pairs: Vec, } impl UserConfig { @@ -51,7 +76,7 @@ impl UserConfig { self.residents.push(resident); } - pub fn with_toxic_pairs(mut self, toxic_pairs: Vec<(String, String)>) -> Self { + pub fn with_toxic_pairs(mut self, toxic_pairs: Vec) -> Self { self.toxic_pairs = toxic_pairs; self } @@ -99,7 +124,11 @@ impl From for UserConfig { year: value.year, holidays: value.holidays, residents: value.residents.into_iter().map(Resident::from).collect(), - toxic_pairs: value.toxic_pairs, + toxic_pairs: value + .toxic_pairs + .into_iter() + .map(|p| ToxicPair::new(&p.0, &p.1)) + .collect(), } } } diff --git a/src-tauri/src/export.rs b/src-tauri/src/export.rs index df25afc..7137ea9 100644 --- a/src-tauri/src/export.rs +++ b/src-tauri/src/export.rs @@ -197,7 +197,7 @@ mod tests { #[rstest] pub fn test_export_as_doc(mut schedule: MonthlySchedule, scheduler: Scheduler) { - scheduler.search(&mut schedule, Slot::default()); + scheduler.run(&mut schedule); schedule.export_as_doc(&scheduler.config); } } diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 2c61722..cc285f5 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -27,8 +27,6 @@ struct AppState { fn generate(config: UserConfigDTO, state: tauri::State<'_, AppState>) -> MonthlySchedule { let config = UserConfig::from(config); let mut schedule = MonthlySchedule::new(); - schedule.prefill(&config); - let bounds = WorkloadBounds::new_with_config(&config); let scheduler = Scheduler::new(config, bounds); scheduler.run(&mut schedule); diff --git a/src-tauri/src/resident.rs b/src-tauri/src/resident.rs index c8cd73e..526e66f 100644 --- a/src-tauri/src/resident.rs +++ b/src-tauri/src/resident.rs @@ -5,10 +5,13 @@ use crate::{ slot::{Day, ShiftPosition, Slot}, }; +#[derive(Serialize, Deserialize, Debug, Clone, Hash, Eq, PartialEq)] +pub struct ResidentId(pub String); + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct Resident { - pub id: String, + pub id: ResidentId, pub name: String, pub negative_shifts: Vec, pub manual_shifts: Vec, @@ -32,7 +35,7 @@ pub struct ResidentDTO { impl Resident { pub fn new(id: &str, name: &str) -> Self { Self { - id: id.to_string(), + id: ResidentId(id.to_string()), name: name.to_string(), negative_shifts: Vec::new(), manual_shifts: Vec::new(), @@ -75,7 +78,7 @@ impl Resident { impl From for Resident { fn from(value: ResidentDTO) -> Self { Self { - id: value.id, + id: ResidentId(value.id), name: value.name, negative_shifts: value .negative_shifts diff --git a/src-tauri/src/schedule.rs b/src-tauri/src/schedule.rs index 37db799..bcefcd7 100644 --- a/src-tauri/src/schedule.rs +++ b/src-tauri/src/schedule.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use crate::{ bounds::WorkloadBounds, - config::UserConfig, - resident::Resident, + config::{ToxicPair, UserConfig}, + resident::ResidentId, slot::{weekday_to_greek, Day, ShiftPosition, Slot}, }; @@ -13,7 +13,7 @@ use serde::Serializer; /// each slot has one resident /// a day can span between 1 or 2 slots depending on if it is open(odd) or closed(even) #[derive(Deserialize, Debug, Clone)] -pub struct MonthlySchedule(pub HashMap); +pub struct MonthlySchedule(pub HashMap); impl MonthlySchedule { pub fn new() -> Self { @@ -28,31 +28,31 @@ impl MonthlySchedule { } } - pub fn get_resident_id(&self, slot: &Slot) -> Option<&String> { + pub fn get_resident_id(&self, slot: &Slot) -> Option<&ResidentId> { self.0.get(slot) } - pub fn current_workload(&self, resident_id: &str) -> usize { + pub fn current_workload(&self, resident_id: &ResidentId) -> usize { self.0 .values() .filter(|res_id| res_id == &resident_id) .count() } - pub fn current_holiday_workload(&self, resident: &Resident, config: &UserConfig) -> usize { + pub fn current_holiday_workload(&self, resident_id: &ResidentId, config: &UserConfig) -> usize { self.0 .iter() .filter(|(slot, res_id)| { - res_id == &&resident.id && config.is_holiday_or_weekend_slot(slot.day.0) + res_id == &resident_id && config.is_holiday_or_weekend_slot(slot.day.0) }) .count() } - pub fn count_shifts(&self, resident_id: &str, shift_type: Option) -> usize { + pub fn count_shifts(&self, resident_id: &ResidentId, shift_type: Option) -> usize { self.0 .iter() - .filter(|(slot, id)| { - if *id != resident_id { + .filter(|(&slot, id)| { + if id != &resident_id { return false; } @@ -92,8 +92,8 @@ impl MonthlySchedule { true } - pub fn insert(&mut self, slot: Slot, resident_id: &str) { - self.0.insert(slot, resident_id.to_string()); + pub fn insert(&mut self, slot: Slot, resident_id: &ResidentId) { + self.0.insert(slot, resident_id.clone()); } pub fn remove(&mut self, slot: Slot) { @@ -154,11 +154,11 @@ impl MonthlySchedule { let first_id = self.get_resident_id(&slot.previous()); let second_id = self.get_resident_id(slot); - if let (Some(f), Some(s)) = (first_id, second_id) { + if let (Some(r1), Some(r2)) = (first_id, second_id) { return config .toxic_pairs .iter() - .any(|(r1, r2)| (r1 == f && r2 == s) || (r1 == s && r2 == f)); + .any(|pair| pair.matches(&ToxicPair::from((r1.clone(), r2.clone())))); } false @@ -210,15 +210,15 @@ impl MonthlySchedule { None => return false, }; - if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) { - let current_holiday_workload = self.current_holiday_workload(resident, config); + // if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) { + let current_holiday_workload = self.current_holiday_workload(res_id, config); - if let Some(&holiday_limit) = bounds.max_holiday_shifts.get(res_id) { - if current_holiday_workload > holiday_limit as usize { - return true; - } + if let Some(&holiday_limit) = bounds.max_holiday_shifts.get(res_id) { + if current_holiday_workload > holiday_limit as usize { + return true; } } + // } false } @@ -297,7 +297,7 @@ impl MonthlySchedule { let o1 = self.count_shifts(&res.id, Some(ShiftType::OpenFirst)); let o2 = self.count_shifts(&res.id, Some(ShiftType::OpenSecond)); let cl = self.count_shifts(&res.id, Some(ShiftType::Closed)); - let sun = self.current_holiday_workload(res, config); + let sun = self.current_holiday_workload(&res.id, config); output.push_str(&format!( "{:<15} | {:<6} | {:<10} | {:<10} | {:<7} | {:<10}\n", @@ -341,8 +341,8 @@ mod tests { use crate::{ bounds::WorkloadBounds, - config::UserConfig, - resident::Resident, + config::{ToxicPair, UserConfig}, + resident::{Resident, ResidentId}, schedule::{Day, MonthlySchedule, Slot}, slot::ShiftPosition, }; @@ -364,7 +364,7 @@ mod tests { Resident::new("1", "Stefanos"), Resident::new("2", "Iordanis"), ]) - .with_toxic_pairs(vec![(("1".to_string(), "2".to_string()))]) + .with_toxic_pairs(vec![ToxicPair::new("1", "2")]) } #[fixture] @@ -382,7 +382,10 @@ mod tests { schedule.insert(slot_1, &resident.id); - assert_eq!(schedule.get_resident_id(&slot_1), Some(&"1".to_string())); + assert_eq!( + schedule.get_resident_id(&slot_1), + Some(&ResidentId("1".to_string())) + ); assert_eq!(schedule.current_workload(&resident.id), 1); assert_eq!(schedule.get_resident_id(&slot_2), None); } @@ -439,8 +442,8 @@ mod tests { let mut bounds = WorkloadBounds::new(); - bounds.max_workloads.insert("1".to_string(), 1); - bounds.max_workloads.insert("2".to_string(), 2); + bounds.max_workloads.insert(ResidentId("1".to_string()), 1); + bounds.max_workloads.insert(ResidentId("2".to_string()), 2); schedule.insert(slot_1, &stefanos.id); assert!(!schedule.is_workload_unbalanced(&slot_1, &config, &bounds)); @@ -463,8 +466,12 @@ mod tests { let mut bounds = WorkloadBounds::new(); - bounds.max_holiday_shifts.insert("1".to_string(), 1); - bounds.max_holiday_shifts.insert("2".to_string(), 1); + bounds + .max_holiday_shifts + .insert(ResidentId("1".to_string()), 1); + bounds + .max_holiday_shifts + .insert(ResidentId("2".to_string()), 1); schedule.insert(slot_1, &stefanos.id); assert!(!schedule.is_holiday_workload_imbalanced(&slot_1, &config, &bounds)); diff --git a/src-tauri/src/scheduler.rs b/src-tauri/src/scheduler.rs index 7c0e00e..6445ac6 100644 --- a/src-tauri/src/scheduler.rs +++ b/src-tauri/src/scheduler.rs @@ -1,5 +1,5 @@ use crate::{ - bounds::WorkloadBounds, config::UserConfig, schedule::MonthlySchedule, + bounds::WorkloadBounds, config::UserConfig, resident::ResidentId, schedule::MonthlySchedule, slot::Slot, }; @@ -16,6 +16,7 @@ impl Scheduler { } pub fn run(&self, schedule: &mut MonthlySchedule) -> bool { + schedule.prefill(&self.config); self.search(schedule, Slot::default()) } @@ -66,7 +67,7 @@ impl Scheduler { } /// Return all valid residents for the current slot - pub fn valid_residents(&self, slot: Slot, schedule: &MonthlySchedule) -> Vec { + pub fn valid_residents(&self, slot: Slot, schedule: &MonthlySchedule) -> Vec { let other_slot_resident_id = schedule.get_resident_id(&slot.other_position()); self.config @@ -144,7 +145,4 @@ mod tests { println!("{}", schedule.pretty_print(&scheduler.config)); } - - #[rstest] - fn test_valid_residents(mut schedule: MonthlySchedule, scheduler: Scheduler) {} }