use std::sync::atomic::{AtomicBool, Ordering}; use crate::{ config::{ToxicPair, UserConfig}, errors::SearchError, resident::ResidentId, schedule::MonthlySchedule, slot::Slot, timer::Timer, workload::{WorkloadBounds, WorkloadTracker}, }; use rand::{rngs::SmallRng, seq::SliceRandom, SeedableRng}; use rayon::{ current_thread_index, iter::{IntoParallelRefIterator, ParallelIterator}, }; pub struct Scheduler { pub config: UserConfig, pub bounds: WorkloadBounds, pub timer: Timer, } impl Scheduler { pub fn new(config: UserConfig, bounds: WorkloadBounds) -> Self { Self { config, bounds, timer: Timer::default(), } } pub fn new_with_config(config: UserConfig) -> Self { let bounds = WorkloadBounds::new_with_config(&config); Self { config, bounds, timer: Timer::default(), } } pub fn run( &self, schedule: &mut MonthlySchedule, tracker: &mut WorkloadTracker, ) -> Result { schedule.prefill(&self.config); for (slot, res_id) in schedule.0.iter() { tracker.insert(*res_id, &self.config, *slot); } //TODO: add validation // find first non-manually-filled slot let slot = (0..=self.config.total_slots) .find(|&slot_idx| !schedule.0.contains_key(&Slot::from(slot_idx))) .map(Slot::from) .ok_or(SearchError::ScheduleFull)?; let resident_ids = self.valid_residents(slot, schedule, tracker); let solved_in_thread = AtomicBool::new(false); let sovled_state = resident_ids.par_iter().find_map_any(|&id| { let mut local_schedule = schedule.clone(); let mut local_tracker = tracker.clone(); local_schedule.insert(slot, id); local_tracker.insert(id, &self.config, slot); let solved = self.search( &mut local_schedule, &mut local_tracker, slot.next(), &solved_in_thread, ); match solved { Ok(true) => Some((local_schedule, local_tracker)), Ok(false) => None, Err(e) => { let thread_id = current_thread_index().unwrap(); log::log!(e.log_level(), "Thread Id: [{}] {}", thread_id, e); None } } }); if let Some((solved_schedule, solved_tracker)) = sovled_state { *schedule = solved_schedule; *tracker = solved_tracker; return Ok(true); } Ok(false) } /// DFS where maximum depth is calculated by total_days_of_month + odd_days_of_month each node is called a slot /// Starts with schedule partially completed from the user interface /// Ends with a full schedule following restrictions and fairness pub fn search( &self, schedule: &mut MonthlySchedule, tracker: &mut WorkloadTracker, slot: Slot, solved_in_thread: &AtomicBool, ) -> Result { if solved_in_thread.load(Ordering::Relaxed) { return Err(SearchError::SolutionFound); } if self.timer.limit_exceeded() { return Err(SearchError::Timeout); } if schedule.has_resident_in_consecutive_days(&slot.previous()) { return Ok(false); } if self.found_solution(slot) { solved_in_thread.store(true, Ordering::Relaxed); return Ok(true); } if schedule.is_slot_manually_assigned(&slot) { return self.search(schedule, tracker, slot.next(), solved_in_thread); } let mut rng = SmallRng::from_rng(&mut rand::rng()); let mut valid_resident_ids = self.valid_residents(slot, schedule, tracker); valid_resident_ids.shuffle(&mut rng); valid_resident_ids.sort_by_key(|res_id| { let type_count = tracker.current_shift_type_workload(res_id, slot.shift_type()); let workload = tracker.current_workload(res_id); (type_count, workload) }); for id in valid_resident_ids { schedule.insert(slot, id); tracker.insert(id, &self.config, slot); if self.search(schedule, tracker, slot.next(), solved_in_thread)? { return Ok(true); } schedule.remove(slot); tracker.remove(id, &self.config, slot); } Ok(false) } pub fn found_solution(&self, slot: Slot) -> bool { slot.greater_than(self.config.total_days) } /// Return all valid residents for the current slot pub fn valid_residents( &self, slot: Slot, schedule: &MonthlySchedule, tracker: &WorkloadTracker, ) -> Vec { let is_holiday_slot = self.config.is_holiday_or_weekend_slot(slot); let other_resident_id = slot .other_position() .and_then(|partner_slot| schedule.get_resident_id(&partner_slot)); self.config .residents .iter() .filter(|r| { if let Some(other_id) = other_resident_id { if &r.id == other_id { return false; } if self .config .toxic_pairs .iter() .any(|tp| tp.matches(&ToxicPair::from((r.id, *other_id)))) { return false; } } !r.negative_shifts.contains(&slot.day) && r.allowed_types.contains(&slot.shift_type()) && !tracker.reached_workload_limit(&self.bounds, &r.id) && (!is_holiday_slot || !tracker.reached_holiday_limit(&self.bounds, &r.id)) && !tracker.reached_shift_type_limit(&self.bounds, &r.id, slot.shift_type()) }) .map(|r| r.id) .collect() } }