Files
rota/src-tauri/src/scheduler.rs

199 lines
6.1 KiB
Rust

use std::sync::atomic::{AtomicBool, Ordering};
use crate::{
config::{ToxicPair, UserConfig},
errors::SearchError,
resident::ResidentId,
schedule::MonthlySchedule,
slot::Slot,
timer::Timer,
workload::{WorkloadBounds, WorkloadTracker},
};
use rand::{rngs::SmallRng, seq::SliceRandom, SeedableRng};
use rayon::{
current_thread_index,
iter::{IntoParallelRefIterator, ParallelIterator},
};
/// Backtracking scheduler that fills the empty slots of a `MonthlySchedule`
/// according to a `UserConfig`, tracking per-resident load in a `WorkloadTracker`.
pub struct Scheduler {
/// User configuration read by the search: residents, toxic pairs, slot/day counts
/// and the holiday/weekend lookup.
pub config: UserConfig,
/// Workload limits; `valid_residents` checks them through the tracker's
/// `reached_*_limit` methods.
pub bounds: WorkloadBounds,
/// Time budget for the search: `search` returns `SearchError::Timeout` once
/// `limit_exceeded()` is true.
/// NOTE(review): the constructors create it with `Timer::default()`, so the budget
/// presumably starts at construction rather than at `run` — confirm against `Timer`.
pub timer: Timer,
}
impl Scheduler {
    /// Builds a scheduler from an explicit configuration and pre-computed workload bounds.
    ///
    /// The search [`Timer`] is created here with `Timer::default()`.
    /// NOTE(review): this presumably means the time budget starts counting at construction
    /// rather than when [`Scheduler::run`] is called — confirm against `Timer`.
    pub fn new(config: UserConfig, bounds: WorkloadBounds) -> Self {
        Self {
            config,
            bounds,
            timer: Timer::default(),
        }
    }

    /// Builds a scheduler whose workload bounds are derived from `config`
    /// via `WorkloadBounds::new_with_config`.
    pub fn new_with_config(config: UserConfig) -> Self {
        let bounds = WorkloadBounds::new_with_config(&config);
        Self::new(config, bounds)
    }

    /// Fills every empty slot of `schedule`, exploring candidate assignments in parallel.
    ///
    /// 1. Applies the user's manual assignments with `schedule.prefill` and records every
    ///    resulting assignment in `tracker`.
    /// 2. Finds the first slot index in `0..=config.total_slots` with no assignment.
    /// 3. For each resident valid for that slot, runs [`Scheduler::search`] on a private
    ///    clone of `schedule`/`tracker` using rayon. All branches share one [`AtomicBool`],
    ///    so the others stop early once one branch succeeds.
    /// 4. Copies the first successful branch's schedule and tracker back into the arguments.
    ///
    /// # Returns
    /// `Ok(true)` if a branch produced a complete schedule; `schedule` and `tracker` are then
    /// overwritten with that branch's state. Otherwise it returns `Ok(false)`.
    ///
    /// An error raised inside a branch (for example a timeout, or the "solution already
    /// found" signal) is logged at the error's own level. That branch is then treated as
    /// having found no solution, and the error is not propagated.
    ///
    /// # Errors
    /// [`SearchError::ScheduleFull`] if every slot index in `0..=config.total_slots` is
    /// already assigned after prefilling.
    pub fn run(
        &self,
        schedule: &mut MonthlySchedule,
        tracker: &mut WorkloadTracker,
    ) -> Result<bool, SearchError> {
        schedule.prefill(&self.config);
        for (slot, res_id) in schedule.0.iter() {
            tracker.insert(*res_id, &self.config, *slot);
        }
        //TODO: add validation

        // First slot not filled manually: the parallel fan-out starts here.
        // NOTE(review): the range includes `total_slots` itself — confirm that is the
        // intended last slot index.
        let slot = (0..=self.config.total_slots)
            .find(|&slot_idx| !schedule.0.contains_key(&Slot::from(slot_idx)))
            .map(Slot::from)
            .ok_or(SearchError::ScheduleFull)?;

        let resident_ids = self.valid_residents(slot, schedule, tracker);
        // Shared "some branch finished" flag, so sibling branches can abandon their search.
        let solved_in_thread = AtomicBool::new(false);

        let solved_state = resident_ids.par_iter().find_map_any(|&id| {
            // Each branch mutates its own copy, so branches never interfere with each other.
            let mut local_schedule = schedule.clone();
            let mut local_tracker = tracker.clone();
            local_schedule.insert(slot, id);
            local_tracker.insert(id, &self.config, slot);

            match self.search(
                &mut local_schedule,
                &mut local_tracker,
                slot.next(),
                &solved_in_thread,
            ) {
                Ok(true) => Some((local_schedule, local_tracker)),
                Ok(false) => None,
                Err(e) => {
                    // Reporting an error must never panic. Fall back to "?" if we are
                    // not on a rayon worker thread.
                    let thread_id = current_thread_index()
                        .map_or_else(|| "?".to_owned(), |idx| idx.to_string());
                    log::log!(e.log_level(), "Thread Id: [{}] {}", thread_id, e);
                    None
                }
            }
        });

        match solved_state {
            Some((solved_schedule, solved_tracker)) => {
                *schedule = solved_schedule;
                *tracker = solved_tracker;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Depth-first backtracking search over slots, starting at `slot`.
    ///
    /// Each recursion level handles one slot. Manually assigned slots are skipped, and
    /// every other slot is tried with each valid resident in turn. The search ends
    /// successfully once [`Scheduler::found_solution`] reports that `slot` is past the
    /// last day.
    ///
    /// `solved_in_thread` is shared by all parallel branches started from [`Scheduler::run`].
    /// It is set when this branch completes the schedule.
    ///
    /// # Returns
    /// - `Ok(true)`: `schedule` and `tracker` hold a complete assignment.
    /// - `Ok(false)`: no valid completion exists from this state. Every assignment this
    ///   call made has been undone.
    ///
    /// # Errors
    /// - [`SearchError::SolutionFound`] if another branch has already set `solved_in_thread`.
    /// - [`SearchError::Timeout`] once `self.timer.limit_exceeded()` is true.
    ///
    /// An error unwinds the whole recursion via `?` without backtracking, so `schedule`
    /// and `tracker` are left partially filled.
    pub fn search(
        &self,
        schedule: &mut MonthlySchedule,
        tracker: &mut WorkloadTracker,
        slot: Slot,
        solved_in_thread: &AtomicBool,
    ) -> Result<bool, SearchError> {
        // Relaxed is enough: the flag is only an early-exit hint. The actual result
        // travels back through the return values.
        if solved_in_thread.load(Ordering::Relaxed) {
            return Err(SearchError::SolutionFound);
        }
        if self.timer.limit_exceeded() {
            return Err(SearchError::Timeout);
        }
        // Prune when the previous slot's assignment breaks the consecutive-days rule.
        // That assignment was made by the caller, or is a manual one we just skipped over.
        if schedule.has_resident_in_consecutive_days(&slot.previous()) {
            return Ok(false);
        }
        if self.found_solution(slot) {
            solved_in_thread.store(true, Ordering::Relaxed);
            return Ok(true);
        }
        // Manual assignments are fixed; move straight on to the next slot.
        if schedule.is_slot_manually_assigned(&slot) {
            return self.search(schedule, tracker, slot.next(), solved_in_thread);
        }

        // Shuffle first, then stable-sort by (shift-type count, total workload). The
        // least-loaded residents are tried first, and ties keep their random order,
        // i.e. ties are broken randomly.
        let mut rng = SmallRng::from_rng(&mut rand::rng());
        let mut valid_resident_ids = self.valid_residents(slot, schedule, tracker);
        valid_resident_ids.shuffle(&mut rng);
        valid_resident_ids.sort_by_key(|res_id| {
            let type_count = tracker.current_shift_type_workload(res_id, slot.shift_type());
            let workload = tracker.current_workload(res_id);
            (type_count, workload)
        });

        for id in valid_resident_ids {
            schedule.insert(slot, id);
            tracker.insert(id, &self.config, slot);
            if self.search(schedule, tracker, slot.next(), solved_in_thread)? {
                return Ok(true);
            }
            // Backtrack: undo this candidate before trying the next one.
            schedule.remove(slot);
            tracker.remove(id, &self.config, slot);
        }
        Ok(false)
    }

    /// Returns `true` once `slot` lies beyond the month, i.e.
    /// `slot.greater_than(config.total_days)`.
    /// NOTE(review): presumably `greater_than` compares the slot's day — confirm in `slot.rs`.
    pub fn found_solution(&self, slot: Slot) -> bool {
        slot.greater_than(self.config.total_days)
    }

    /// Returns the ids of all residents that may be assigned to `slot`, in the order
    /// they appear in `config.residents`.
    ///
    /// A resident is excluded if any of the following holds:
    /// - they already occupy the partner position of the same slot (`slot.other_position()`);
    /// - they form a configured toxic pair with that partner;
    /// - `slot.day` is in their `negative_shifts`;
    /// - the slot's shift type is not in their `allowed_types`;
    /// - they reached their overall workload limit;
    /// - they reached their holiday limit, checked only on holiday/weekend slots;
    /// - they reached their limit for this shift type.
    pub fn valid_residents(
        &self,
        slot: Slot,
        schedule: &MonthlySchedule,
        tracker: &WorkloadTracker,
    ) -> Vec<ResidentId> {
        let is_holiday_slot = self.config.is_holiday_or_weekend_slot(slot);
        // Resident already placed in the other position of this slot, if any.
        let other_resident_id = slot
            .other_position()
            .and_then(|partner_slot| schedule.get_resident_id(&partner_slot));

        self.config
            .residents
            .iter()
            .filter(|r| {
                if let Some(other_id) = other_resident_id {
                    if &r.id == other_id {
                        return false;
                    }
                    // Build the candidate pair once per resident, not once per toxic pair.
                    let candidate_pair = ToxicPair::from((r.id, *other_id));
                    if self
                        .config
                        .toxic_pairs
                        .iter()
                        .any(|tp| tp.matches(&candidate_pair))
                    {
                        return false;
                    }
                }
                !r.negative_shifts.contains(&slot.day)
                    && r.allowed_types.contains(&slot.shift_type())
                    && !tracker.reached_workload_limit(&self.bounds, &r.id)
                    && (!is_holiday_slot || !tracker.reached_holiday_limit(&self.bounds, &r.id))
                    && !tracker.reached_shift_type_limit(&self.bounds, &r.id, slot.shift_type())
            })
            .map(|r| r.id)
            .collect()
    }
}