diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index aecac62..d6d78ea 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -17,8 +17,8 @@ const YEAR: i32 = 2026; pub struct ToxicPair((ResidentId, ResidentId)); impl ToxicPair { - pub fn new(res_id_1: u8, res_id_2: u8) -> Self { - Self((ResidentId(res_id_1), ResidentId(res_id_2))) + pub fn new(r_id_1: u8, r_id_2: u8) -> Self { + Self((ResidentId(r_id_1), ResidentId(r_id_2))) } pub fn matches(&self, other: &ToxicPair) -> bool { @@ -134,6 +134,14 @@ impl UserConfig { } supply } + + pub fn flexibility_map(&self) -> HashMap { + let mut map = HashMap::new(); + for r in &self.residents { + map.insert(r.id, r.allowed_types.len() as u8); + } + map + } } impl Default for UserConfig { diff --git a/src-tauri/src/scheduler.rs b/src-tauri/src/scheduler.rs index b107942..2591237 100644 --- a/src-tauri/src/scheduler.rs +++ b/src-tauri/src/scheduler.rs @@ -47,8 +47,8 @@ impl Scheduler { tracker: &mut WorkloadTracker, ) -> Result { schedule.prefill(&self.config); - for (slot, res_id) in schedule.0.iter() { - tracker.insert(*res_id, &self.config, *slot); + for (slot, r_id) in schedule.0.iter() { + tracker.insert(*r_id, &self.config, *slot); } //TODO: add validation @@ -59,10 +59,13 @@ impl Scheduler { .map(Slot::from) .ok_or(SearchError::ScheduleFull)?; - let resident_ids = self.valid_residents(slot, schedule, tracker); + let mut valid_resident_ids = self.valid_residents(slot, schedule, tracker); + + self.sort_residents(&mut valid_resident_ids, tracker, slot); + let solved_in_thread = AtomicBool::new(false); - let sovled_state = resident_ids.par_iter().find_map_any(|&id| { + let sovled_state = valid_resident_ids.par_iter().find_map_any(|&id| { let mut local_schedule = schedule.clone(); let mut local_tracker = tracker.clone(); @@ -127,14 +130,9 @@ impl Scheduler { return self.search(schedule, tracker, slot.next(), solved_in_thread); } - let mut rng = SmallRng::from_rng(&mut rand::rng()); let mut valid_resident_ids = self.valid_residents(slot, schedule, tracker); - valid_resident_ids.shuffle(&mut rng); - valid_resident_ids.sort_by_key(|res_id| { - let type_count = tracker.current_shift_type_workload(res_id, slot.shift_type()); - let workload = tracker.current_workload(res_id); - (type_count, workload) - }); + + self.sort_residents(&mut valid_resident_ids, tracker, slot); for id in valid_resident_ids { schedule.insert(slot, id); @@ -151,12 +149,12 @@ impl Scheduler { Ok(false) } - pub fn found_solution(&self, slot: Slot) -> bool { + fn found_solution(&self, slot: Slot) -> bool { slot.greater_than(self.config.total_days) } /// Return all valid residents for the current slot - pub fn valid_residents( + fn valid_residents( &self, slot: Slot, schedule: &MonthlySchedule, @@ -195,4 +193,22 @@ impl Scheduler { .map(|r| r.id) .collect() } + + fn sort_residents( + &self, + resident_ids: &mut Vec, + tracker: &WorkloadTracker, + slot: Slot, + ) { + let flex_map = self.config.flexibility_map(); + let mut rng = SmallRng::from_rng(&mut rand::rng()); + resident_ids.shuffle(&mut rng); + resident_ids.sort_by_key(|r_id| { + let type_workload = tracker.current_shift_type_workload(r_id, slot.shift_type()); + let holiday_workload = tracker.current_holiday_workload(r_id); + let workload = tracker.current_workload(r_id); + let flex = flex_map.get(r_id).unwrap(); + (flex, type_workload, workload, holiday_workload) + }); + } }