use std::collections::HashMap; use crate::{ config::UserConfig, resident::{Resident, ResidentId}, schedule::ShiftType, slot::Slot, }; #[derive(Default)] pub struct WorkloadBounds { pub max_workloads: HashMap, pub max_holiday_shifts: HashMap, pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>, } impl WorkloadBounds { pub fn new_with_config(config: &UserConfig) -> Self { let residents = &config.residents; let total_slots = config.total_slots; let total_holiday_slots = config.total_holiday_slots; let mut bounds = Self::default(); bounds.calculate_max_workloads(residents, total_slots); debug_assert!(bounds.max_workloads.values().sum::() >= total_slots); bounds.calculate_max_holiday_shifts(residents, total_holiday_slots); debug_assert!(bounds.max_holiday_shifts.values().sum::() >= total_holiday_slots); bounds.calculate_max_by_shift_type(residents); debug_assert!(bounds.max_by_shift_type.values().sum::() >= total_slots); bounds } pub fn calculate_max_workloads(&mut self, residents: &Vec, total_slots: u8) { let non_manual_residents: Vec<_> = residents .iter() .filter(|r| r.max_shifts.is_none()) .collect(); // all residents' max workload were manually inserted if non_manual_residents.is_empty() { for r in residents { self.max_workloads.insert(r.id, r.max_shifts.unwrap_or(0)); } return; } let total_manual_workload: u8 = residents.iter().map(|r| r.max_shifts.unwrap_or(0)).sum(); let total_reduced_workload: u8 = residents .iter() .map(|r| if r.reduced_load { 1 } else { 0 }) .sum(); let remaining_slots = total_slots - total_manual_workload + total_reduced_workload; let workload_share = remaining_slots.div_ceil(non_manual_residents.len() as u8); for r in residents { let max_workload = match r.max_shifts { Some(max_shifts) => max_shifts, None if r.reduced_load => workload_share - 1, None => workload_share, }; self.max_workloads.insert(r.id, max_workload); } } pub fn calculate_max_holiday_shifts( &mut self, residents: &Vec, total_holiday_slots: u8, ) { let total_residents = residents.len(); let holiday_share = total_holiday_slots.div_ceil(total_residents as u8); for r in residents { self.max_holiday_shifts.insert(r.id, holiday_share); } } pub fn calculate_max_by_shift_type(&mut self, residents: &Vec) { let mut upper_limits = HashMap::new(); let shift_types = [ ShiftType::OpenFirst, ShiftType::OpenSecond, ShiftType::Closed, ]; for r in residents { let total_limit = *self.max_workloads.get(&r.id).unwrap_or(&0); let n_allowed = r.allowed_types.len(); for shift_type in shift_types { let limit = if r.allowed_types.contains(&shift_type) { if n_allowed == 1 { total_limit } else { (total_limit as f32 / n_allowed as f32).floor() as u8 + 1 } } else { 0 }; upper_limits.insert((r.id, shift_type), limit); } } self.max_by_shift_type = upper_limits; } } #[derive(Default, Clone, Debug)] pub struct WorkloadTracker { total_counts: HashMap, type_counts: HashMap<(ResidentId, ShiftType), u8>, holiday_counts: HashMap, } impl WorkloadTracker { pub fn insert(&mut self, r_id: ResidentId, config: &UserConfig, slot: Slot) { *self.total_counts.entry(r_id).or_insert(0) += 1; *self .type_counts .entry((r_id, slot.shift_type())) .or_insert(0) += 1; if config.is_holiday_or_weekend_slot(slot) { *self.holiday_counts.entry(r_id).or_insert(0) += 1; } } pub fn remove(&mut self, r_id: ResidentId, config: &UserConfig, slot: Slot) { if let Some(count) = self.total_counts.get_mut(&r_id) { *count = count.saturating_sub(1); } if let Some(count) = self.type_counts.get_mut(&(r_id, slot.shift_type())) { *count = count.saturating_sub(1); } if config.is_holiday_or_weekend_slot(slot) { if let Some(count) = self.holiday_counts.get_mut(&r_id) { *count = count.saturating_sub(1); } } } pub fn current_workload(&self, r_id: &ResidentId) -> u8 { *self.total_counts.get(r_id).unwrap_or(&0) } pub fn current_holiday_workload(&self, r_id: &ResidentId) -> u8 { *self.holiday_counts.get(r_id).unwrap_or(&0) } pub fn current_shift_type_workload(&self, r_id: &ResidentId, shift_type: ShiftType) -> u8 { *self.type_counts.get(&(*r_id, shift_type)).unwrap_or(&0) } pub fn reached_workload_limit(&self, bounds: &WorkloadBounds, r_id: &ResidentId) -> bool { let current_load = self.current_workload(r_id); if let Some(&max) = bounds.max_workloads.get(r_id) { return current_load >= max; } false } pub fn reached_holiday_limit(&self, bounds: &WorkloadBounds, r_id: &ResidentId) -> bool { let current_load = self.current_holiday_workload(r_id); if let Some(&max) = bounds.max_holiday_shifts.get(r_id) { return current_load >= max; } false } pub fn reached_shift_type_limit( &self, bounds: &WorkloadBounds, r_id: &ResidentId, shift_type: ShiftType, ) -> bool { let current_load = self.current_shift_type_workload(r_id, shift_type); if let Some(&max) = bounds.max_by_shift_type.get(&(*r_id, shift_type)) { return current_load >= max; } false } } #[cfg(test)] mod tests { use crate::{ config::UserConfig, fixtures::{complex_config, hard_config, minimal_config}, resident::ResidentId, schedule::ShiftType, slot::{Day, ShiftPosition, Slot}, workload::{WorkloadBounds, WorkloadTracker}, }; use rstest::rstest; // Testing WorkloadBounds #[rstest] fn test_max_workloads(mut minimal_config: UserConfig) { minimal_config.update_month(4); let bounds = WorkloadBounds::new_with_config(&minimal_config); assert_eq!(9, bounds.max_workloads[&ResidentId(1)]); assert_eq!(9, bounds.max_workloads[&ResidentId(2)]); assert_eq!(9, bounds.max_workloads[&ResidentId(3)]); assert_eq!(9, bounds.max_workloads[&ResidentId(4)]); assert_eq!(9, bounds.max_workloads[&ResidentId(5)]); } #[rstest] fn test_calculate_max_workloads_minimal(minimal_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&minimal_config.residents, minimal_config.total_slots); assert_eq!(9, *bounds.max_workloads.get(&ResidentId(1)).unwrap()); assert_eq!(9, *bounds.max_workloads.get(&ResidentId(2)).unwrap()); assert_eq!(9, *bounds.max_workloads.get(&ResidentId(3)).unwrap()); assert_eq!(9, *bounds.max_workloads.get(&ResidentId(4)).unwrap()); assert_eq!(9, *bounds.max_workloads.get(&ResidentId(5)).unwrap()); } #[rstest] fn test_calculate_max_workloads_complex(complex_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&complex_config.residents, complex_config.total_slots); assert_eq!(3, *bounds.max_workloads.get(&ResidentId(1)).unwrap()); assert_eq!(3, *bounds.max_workloads.get(&ResidentId(2)).unwrap()); assert_eq!(3, *bounds.max_workloads.get(&ResidentId(3)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(4)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(5)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(6)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(7)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(8)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(9)).unwrap()); } #[rstest] fn test_calculate_max_workloads_hard(hard_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&hard_config.residents, hard_config.total_slots); assert_eq!(5, *bounds.max_workloads.get(&ResidentId(6)).unwrap()); assert_eq!(5, *bounds.max_workloads.get(&ResidentId(7)).unwrap()); assert_eq!(5, *bounds.max_workloads.get(&ResidentId(8)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(1)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(2)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(3)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(4)).unwrap()); assert_eq!(6, *bounds.max_workloads.get(&ResidentId(5)).unwrap()); } #[rstest] fn test_calculate_max_holiday_shifts_complex(complex_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_holiday_shifts( &complex_config.residents, complex_config.total_holiday_slots, ); for i in 1..=9 { assert_eq!(2, *bounds.max_holiday_shifts.get(&ResidentId(i)).unwrap()); } } #[rstest] fn test_calculate_max_holiday_shifts_minimal(minimal_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_holiday_shifts( &minimal_config.residents, minimal_config.total_holiday_slots, ); for i in 1..=5 { assert_eq!(3, *bounds.max_holiday_shifts.get(&ResidentId(i)).unwrap()); } } #[rstest] fn test_calculate_max_holiday_shifts_hard(hard_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds .calculate_max_holiday_shifts(&hard_config.residents, hard_config.total_holiday_slots); for i in 1..=8 { assert_eq!(2, *bounds.max_holiday_shifts.get(&ResidentId(i)).unwrap()); } } #[rstest] fn test_calculate_max_by_shift_type_minimal(minimal_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&minimal_config.residents, minimal_config.total_slots); bounds.calculate_max_by_shift_type(&minimal_config.residents); let m = bounds.max_by_shift_type; assert_eq!(4, *m.get(&(ResidentId(1), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(1), ShiftType::OpenSecond)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(1), ShiftType::Closed)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(2), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(2), ShiftType::OpenSecond)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(2), ShiftType::Closed)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(3), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(3), ShiftType::OpenSecond)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(3), ShiftType::Closed)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(4), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(4), ShiftType::OpenSecond)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(4), ShiftType::Closed)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(5), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(5), ShiftType::OpenSecond)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(5), ShiftType::Closed)).unwrap()); } #[rstest] fn test_calculate_max_by_shift_type_complex(complex_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&complex_config.residents, complex_config.total_slots); bounds.calculate_max_by_shift_type(&complex_config.residents); let m = bounds.max_by_shift_type; assert_eq!(2, *m.get(&(ResidentId(1), ShiftType::OpenFirst)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(1), ShiftType::OpenSecond)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(1), ShiftType::Closed)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(2), ShiftType::OpenFirst)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(2), ShiftType::OpenSecond)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(2), ShiftType::Closed)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(3), ShiftType::OpenFirst)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(3), ShiftType::OpenSecond)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(3), ShiftType::Closed)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(4), ShiftType::OpenFirst)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(4), ShiftType::OpenSecond)).unwrap()); assert_eq!(6, *m.get(&(ResidentId(4), ShiftType::Closed)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(5), ShiftType::OpenFirst)).unwrap()); assert_eq!(4, *m.get(&(ResidentId(5), ShiftType::OpenSecond)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(5), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(6), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(6), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(6), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(7), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(7), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(7), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(8), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(8), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(8), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(9), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(9), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(9), ShiftType::Closed)).unwrap()); } #[rstest] fn test_calculate_max_by_shift_type_hard(hard_config: UserConfig) { let mut bounds = WorkloadBounds::default(); bounds.calculate_max_workloads(&hard_config.residents, hard_config.total_slots); bounds.calculate_max_by_shift_type(&hard_config.residents); let m = bounds.max_by_shift_type; assert_eq!(3, *m.get(&(ResidentId(1), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(1), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(1), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(2), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(2), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(2), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(3), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(3), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(3), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(4), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(4), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(4), ShiftType::Closed)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(5), ShiftType::OpenFirst)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(5), ShiftType::OpenSecond)).unwrap()); assert_eq!(3, *m.get(&(ResidentId(5), ShiftType::Closed)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(6), ShiftType::OpenFirst)).unwrap()); assert_eq!(5, *m.get(&(ResidentId(6), ShiftType::OpenSecond)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(6), ShiftType::Closed)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(7), ShiftType::OpenFirst)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(7), ShiftType::OpenSecond)).unwrap()); assert_eq!(2, *m.get(&(ResidentId(7), ShiftType::Closed)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(8), ShiftType::OpenFirst)).unwrap()); assert_eq!(5, *m.get(&(ResidentId(8), ShiftType::OpenSecond)).unwrap()); assert_eq!(0, *m.get(&(ResidentId(8), ShiftType::Closed)).unwrap()); } // Testing WorkloadTracker #[rstest] fn test_reached_workload_limit(minimal_config: UserConfig) { let mut tracker = WorkloadTracker::default(); let r_id = ResidentId(1); let mut bounds = WorkloadBounds::default(); bounds.max_workloads.insert(r_id, 1); let slot_1 = Slot::new(Day(1), ShiftPosition::First); let slot_2 = Slot::new(Day(2), ShiftPosition::First); assert!(!tracker.reached_workload_limit(&bounds, &r_id)); tracker.insert(r_id, &minimal_config, slot_1); assert!(tracker.reached_workload_limit(&bounds, &r_id)); tracker.insert(r_id, &minimal_config, slot_2); assert!(tracker.reached_workload_limit(&bounds, &r_id)); } #[rstest] fn test_reached_holiday_limit(minimal_config: UserConfig) { let mut tracker = WorkloadTracker::default(); let r_id = ResidentId(1); let mut bounds = WorkloadBounds::default(); bounds.max_holiday_shifts.insert(r_id, 1); let sat = Slot::new(Day(11), ShiftPosition::First); let sun = Slot::new(Day(12), ShiftPosition::First); assert!(!tracker.reached_holiday_limit(&bounds, &r_id)); tracker.insert(r_id, &minimal_config, sat); assert!(tracker.reached_holiday_limit(&bounds, &r_id)); tracker.insert(r_id, &minimal_config, sun); assert!(tracker.reached_holiday_limit(&bounds, &r_id)); } #[rstest] fn test_reached_shift_type_limit(minimal_config: UserConfig) { let mut tracker = WorkloadTracker::default(); let r_id = ResidentId(1); let mut bounds = WorkloadBounds::default(); bounds .max_by_shift_type .insert((r_id, ShiftType::OpenFirst), 1); let slot_1 = Slot::new(Day(1), ShiftPosition::First); let slot_2 = Slot::new(Day(3), ShiftPosition::First); let open_first = ShiftType::OpenFirst; assert!(!tracker.reached_shift_type_limit(&bounds, &r_id, open_first)); tracker.insert(r_id, &minimal_config, slot_1); assert!(tracker.reached_shift_type_limit(&bounds, &r_id, open_first)); tracker.insert(r_id, &minimal_config, slot_2); assert!(tracker.reached_shift_type_limit(&bounds, &r_id, open_first)); } #[rstest] fn test_backtracking_state(minimal_config: UserConfig) { let mut tracker = WorkloadTracker::default(); let r_id = ResidentId(1); let sat = Slot::new(Day(11), ShiftPosition::First); let sun = Slot::new(Day(12), ShiftPosition::First); let open_first = ShiftType::OpenFirst; let closed = ShiftType::Closed; tracker.insert(r_id, &minimal_config, sat); assert_eq!(1, tracker.current_workload(&r_id)); assert_eq!(1, tracker.current_holiday_workload(&r_id)); assert_eq!(1, tracker.current_shift_type_workload(&r_id, open_first)); tracker.insert(r_id, &minimal_config, sun); assert_eq!(2, tracker.current_workload(&r_id)); assert_eq!(2, tracker.current_holiday_workload(&r_id)); assert_eq!(1, tracker.current_shift_type_workload(&r_id, open_first)); assert_eq!(1, tracker.current_shift_type_workload(&r_id, closed)); tracker.remove(r_id, &minimal_config, sun); assert_eq!(1, tracker.current_workload(&r_id)); assert_eq!(1, tracker.current_holiday_workload(&r_id)); assert_eq!(1, tracker.current_shift_type_workload(&r_id, open_first)); assert_eq!(0, tracker.current_shift_type_workload(&r_id, closed)); tracker.remove(r_id, &minimal_config, sat); assert_eq!(0, tracker.current_workload(&r_id)); assert_eq!(0, tracker.current_holiday_workload(&r_id)); assert_eq!(0, tracker.current_shift_type_workload(&r_id, open_first)); assert_eq!(0, tracker.current_shift_type_workload(&r_id, closed)); } }