Use maps to track workload progress instead of recalculating them at every step of the search, calculate total days/slots once, add integration tests, add log folder

This commit is contained in:
2026-01-17 18:41:43 +02:00
parent 908f114e54
commit 5bad63e8a7
10 changed files with 819 additions and 574 deletions

View File

@@ -14,6 +14,12 @@ lint:
cd {{tauri_path}} && cargo clippy
test:
cd {{tauri_path}} && cargo test --lib --release
test-integration:
cd {{tauri_path}} && cargo test --test integration --release
test-all:
cd {{tauri_path}} && cargo test --release -- --nocapture
# profile:

View File

@@ -1,245 +0,0 @@
use std::collections::HashMap;
use crate::{
config::UserConfig,
resident::{Resident, ResidentId},
schedule::ShiftType,
slot::Day,
};
pub struct WorkloadBounds {
pub max_workloads: HashMap<ResidentId, u8>,
pub max_holiday_shifts: HashMap<ResidentId, u8>,
pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
pub min_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
}
impl WorkloadBounds {
pub fn new() -> Self {
Self {
max_workloads: HashMap::new(),
max_holiday_shifts: HashMap::new(),
max_by_shift_type: HashMap::new(),
min_by_shift_type: HashMap::new(),
}
}
pub fn new_with_config(config: &UserConfig) -> Self {
let mut bounds = Self::new();
bounds.calculate_max_workloads(config);
bounds.calculate_max_holiday_shifts(config);
bounds.calculate_max_by_shift_type(config);
bounds
}
/// get map with total amount of slots in a month for each type of shift
pub fn get_initial_supply(&self, config: &UserConfig) -> HashMap<ShiftType, u8> {
let mut supply = HashMap::new();
let total_days = config.total_days();
for d in 1..=total_days {
if Day(d).is_open_shift() {
*supply.entry(ShiftType::OpenFirst).or_insert(0) += 1;
*supply.entry(ShiftType::OpenSecond).or_insert(0) += 1;
} else {
*supply.entry(ShiftType::Closed).or_insert(0) += 1;
}
}
supply
}
/// this is called after the user config params have been initialized, can be done with the builder (lite) pattern
/// initialize a hashmap for O(1) search calls for the residents' max workload
pub fn calculate_max_workloads(&mut self, config: &UserConfig) {
let total_slots = config.total_slots();
let max_shifts_sum: usize = config
.residents
.iter()
.map(|r| r.max_shifts.unwrap_or(0))
.sum();
let residents_without_max_shifts: Vec<_> = config
.residents
.iter()
.filter(|r| r.max_shifts.is_none())
.collect();
let residents_without_max_shifts_size = residents_without_max_shifts.len();
if residents_without_max_shifts_size == 0 {
for r in &config.residents {
self.max_workloads
.insert(r.id.clone(), r.max_shifts.unwrap_or(0) as u8);
}
return;
}
// Untested scenario: Resident has manual max_shifts and also reduced workload flag
let total_reduced_loads: usize = residents_without_max_shifts
.iter()
.filter(|r| r.reduced_load)
.count();
let max_shifts_ceiling = (total_slots - max_shifts_sum as u8 + total_reduced_loads as u8)
.div_ceil(residents_without_max_shifts_size as u8);
for r in &config.residents {
let max_shifts = if let Some(manual_max_shifts) = r.max_shifts {
manual_max_shifts as u8
} else if r.reduced_load {
max_shifts_ceiling - 1
} else {
max_shifts_ceiling
};
self.max_workloads.insert(r.id.clone(), max_shifts);
}
}
///
pub fn calculate_max_holiday_shifts(&mut self, config: &UserConfig) {
let total_slots = config.total_slots();
let total_holiday_slots = config.total_holiday_slots();
for r in &config.residents {
let workload_limit = *self.max_workloads.get(&r.id).unwrap_or(&0);
let share = (workload_limit as f32 / total_slots as f32) * total_holiday_slots as f32;
let holiday_limit = share.ceil() as u8;
self.max_holiday_shifts.insert(r.id.clone(), holiday_limit);
}
}
///
pub fn calculate_max_by_shift_type(&mut self, config: &UserConfig) {
let mut global_supply = self.get_initial_supply(config);
let mut local_limits = HashMap::new();
let mut local_thresholds = HashMap::new();
let all_shift_types = [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
];
// residents with 1 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 1)
{
let shift_type = &res.allowed_types[0];
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
local_limits.insert((res.id.clone(), shift_type.clone()), total_limit);
local_thresholds.insert((res.id.clone(), shift_type.clone()), total_limit - 1);
for other_type in &all_shift_types {
if other_type != shift_type {
local_limits.insert((res.id.clone(), other_type.clone()), 0);
local_thresholds.insert((res.id.clone(), other_type.clone()), 0);
}
}
if let Some(s) = global_supply.get_mut(shift_type) {
*s = s.saturating_sub(total_limit)
}
}
// residents with 2 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 2)
{
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0) as f32;
let per_type = (total_limit / 2.0).ceil() as u8;
for shift_type in &all_shift_types {
if res.allowed_types.contains(shift_type) {
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
local_thresholds.insert((res.id.clone(), shift_type.clone()), per_type - 1);
if let Some(s) = global_supply.get_mut(shift_type) {
*s = s.saturating_sub(per_type)
}
} else {
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
}
}
}
// residents with 3 available shift types
let res: Vec<&Resident> = config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 3)
.collect();
if !res.is_empty() {
for shift_type in &all_shift_types {
let remaining = *global_supply.get(shift_type).unwrap_or(&0);
let fair_slice = (remaining as f32 / res.len() as f32).ceil().max(0.0) as u8;
for res in &res {
local_limits.insert((res.id.clone(), shift_type.clone()), fair_slice);
local_thresholds.insert((res.id.clone(), shift_type.clone()), fair_slice - 1);
}
}
}
self.max_by_shift_type = local_limits;
self.min_by_shift_type = local_thresholds;
}
}
#[cfg(test)]
mod tests {
use rstest::{fixture, rstest};
use crate::{
bounds::WorkloadBounds,
config::UserConfig,
resident::{Resident, ResidentId},
};
#[fixture]
fn config() -> UserConfig {
UserConfig::default().with_residents(vec![
Resident::new("1", "Stefanos").with_max_shifts(2),
Resident::new("2", "Iordanis").with_max_shifts(2),
Resident::new("3", "Maria").with_reduced_load(),
Resident::new("4", "Veatriki"),
Resident::new("5", "Takis"),
])
}
#[rstest]
fn test_max_workloads(config: UserConfig) {
let bounds = WorkloadBounds::new_with_config(&config);
assert_eq!(bounds.max_workloads[&ResidentId("1".to_string())], 2);
assert_eq!(bounds.max_workloads[&ResidentId("2".to_string())], 2);
assert_eq!(bounds.max_workloads[&ResidentId("3".to_string())], 12);
assert_eq!(bounds.max_workloads[&ResidentId("4".to_string())], 13);
assert_eq!(bounds.max_workloads[&ResidentId("5".to_string())], 13);
}
#[rstest]
fn test_calculate_max_holiday_shifts(config: UserConfig) {
let bounds = WorkloadBounds::new_with_config(&config);
let stefanos_limit = *bounds
.max_holiday_shifts
.get(&ResidentId("1".to_string()))
.unwrap();
let iordanis_limit = *bounds
.max_holiday_shifts
.get(&ResidentId("2".to_string()))
.unwrap();
assert_eq!(stefanos_limit, 1);
assert_eq!(iordanis_limit, 1);
}
}

View File

@@ -8,7 +8,7 @@ use crate::{
const YEAR: i32 = 2026;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ToxicPair((ResidentId, ResidentId));
impl ToxicPair {
@@ -42,28 +42,47 @@ pub struct UserConfigDTO {
toxic_pairs: Vec<(String, String)>,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UserConfig {
pub month: Month,
pub year: i32,
pub holidays: Vec<usize>,
pub residents: Vec<Resident>,
pub toxic_pairs: Vec<ToxicPair>,
pub total_days: u8,
pub total_slots: u8,
pub total_holiday_slots: u8,
}
impl UserConfig {
pub fn new(month: usize) -> Self {
pub fn new(month: u8) -> Self {
let month = Month::try_from(month).unwrap();
let total_days = month.num_days(YEAR).unwrap();
let total_slots = (1..=total_days)
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
let total_holiday_slots = (1..=total_days)
.filter(|&d| Day(d).is_weekend(month.number_from_month(), YEAR))
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
Self {
month: Month::try_from(month as u8).unwrap(),
month,
year: YEAR,
holidays: vec![],
residents: vec![],
toxic_pairs: vec![],
total_days,
total_slots,
total_holiday_slots,
}
}
pub fn with_holidays(mut self, holidays: Vec<usize>) -> Self {
self.holidays = holidays;
self.total_holiday_slots = self.total_holiday_slots();
self
}
@@ -81,18 +100,8 @@ impl UserConfig {
self
}
pub fn total_days(&self) -> u8 {
self.month.num_days(self.year).unwrap()
}
pub fn total_slots(&self) -> u8 {
(1..=self.total_days())
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum()
}
pub fn total_holiday_slots(&self) -> u8 {
(1..=self.total_days())
fn total_holiday_slots(&self) -> u8 {
(1..=self.total_days)
.filter(|&d| self.is_holiday_or_weekend_slot(d))
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum()
@@ -107,20 +116,52 @@ impl UserConfig {
impl Default for UserConfig {
fn default() -> Self {
let month = Month::try_from(2).unwrap();
let total_days = month.num_days(YEAR).unwrap();
let total_slots = (1..=total_days)
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
let total_holiday_slots = (1..=total_days)
.filter(|&d| Day(d).is_weekend(month.number_from_month(), YEAR))
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
Self {
month: Month::try_from(2).unwrap(),
month,
year: YEAR,
holidays: vec![],
residents: vec![],
toxic_pairs: vec![],
total_days,
total_slots,
total_holiday_slots,
}
}
}
impl From<UserConfigDTO> for UserConfig {
fn from(value: UserConfigDTO) -> Self {
let month = Month::try_from(value.month as u8).unwrap();
let total_days = month.num_days(YEAR).unwrap();
let total_slots = (1..=total_days)
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
let total_holiday_slots = (1..=total_days)
.filter(|&d| {
Day(d).is_weekend(month.number_from_month(), value.year)
|| value.holidays.contains(&(d as usize))
})
.map(|d| if Day(d).is_open_shift() { 2 } else { 1 })
.sum();
Self {
month: Month::try_from(value.month as u8).unwrap(),
month,
year: value.year,
holidays: value.holidays,
residents: value.residents.into_iter().map(Resident::from).collect(),
@@ -129,32 +170,9 @@ impl From<UserConfigDTO> for UserConfig {
.into_iter()
.map(|p| ToxicPair::new(&p.0, &p.1))
.collect(),
total_days,
total_slots,
total_holiday_slots,
}
}
}
#[cfg(test)]
mod tests {
use rstest::{fixture, rstest};
use crate::{config::UserConfig, resident::Resident, schedule::MonthlySchedule};
#[fixture]
fn setup() -> (UserConfig, MonthlySchedule) {
let mut config = UserConfig::default();
let res_a = Resident::new("1", "Stefanos");
let res_b = Resident::new("2", "Iordanis");
config.add(res_a);
config.add(res_b);
let schedule = MonthlySchedule::new();
(config, schedule)
}
#[rstest]
fn test_total_holiday_slots() {
let config = UserConfig::default().with_holidays(vec![2, 3, 4]);
assert_eq!(16, config.total_holiday_slots());
}
}

View File

@@ -1,12 +1,12 @@
use std::{fs::File, io::Write};
use docx_rs::{Docx, Paragraph, Run, RunFonts, Table, TableCell, TableRow};
use log::info;
use crate::{
config::UserConfig,
schedule::MonthlySchedule,
slot::{month_to_greek, weekday_to_greek, Day, ShiftPosition, Slot},
workload::WorkloadTracker,
};
#[derive(Debug)]
@@ -16,13 +16,13 @@ pub enum FileType {
}
pub trait Export {
fn export(&self, file_type: FileType, config: &UserConfig);
fn export(&self, file_type: FileType, config: &UserConfig, tracker: &WorkloadTracker);
}
impl Export for MonthlySchedule {
fn export(&self, file_type: FileType, config: &UserConfig) {
fn export(&self, file_type: FileType, config: &UserConfig, tracker: &WorkloadTracker) {
match file_type {
FileType::Txt => self.export_as_txt(config),
FileType::Txt => self.export_as_txt(config, tracker),
FileType::Docx => self.export_as_doc(config),
};
@@ -36,19 +36,19 @@ impl Export for MonthlySchedule {
}
impl MonthlySchedule {
pub fn export_as_txt(&self, config: &UserConfig) -> String {
pub fn export_as_txt(&self, config: &UserConfig, tracker: &WorkloadTracker) -> String {
let file = File::create("schedule.txt").unwrap();
let mut writer = std::io::BufWriter::new(file);
writer
.write_all(self.pretty_print(config).as_bytes())
.expect("Failed to write to buffer");
.expect("Failed to write schedule");
writer
.write_all(self.report(config).as_bytes())
.expect("Failed to write to buffer");
.write_all(self.report(config, tracker).as_bytes())
.expect("Failed to write report");
writer.flush().expect("Failed to flush buffer");
info!("im here");
"ok".to_string()
}
@@ -91,7 +91,7 @@ impl MonthlySchedule {
let mut residents_table = Table::new(vec![]);
for d in 1..=config.total_days() {
for d in 1..=config.total_days {
let day = Day(d);
let is_weekend = day.is_weekend(config.month.number_from_month(), config.year);
let slot_first = Slot::new(Day(d), ShiftPosition::First);
@@ -164,8 +164,11 @@ mod tests {
use rstest::{fixture, rstest};
use crate::{
bounds::WorkloadBounds, config::UserConfig, resident::Resident, schedule::MonthlySchedule,
config::UserConfig,
resident::Resident,
schedule::MonthlySchedule,
scheduler::Scheduler,
workload::{WorkloadBounds, WorkloadTracker},
};
#[fixture]
@@ -195,9 +198,18 @@ mod tests {
Scheduler::new(config, bounds)
}
#[fixture]
fn tracker() -> WorkloadTracker {
WorkloadTracker::default()
}
#[rstest]
pub fn test_export_as_doc(mut schedule: MonthlySchedule, scheduler: Scheduler) {
scheduler.run(&mut schedule);
pub fn test_export_as_doc(
mut schedule: MonthlySchedule,
mut tracker: WorkloadTracker,
scheduler: Scheduler,
) {
scheduler.run(&mut schedule, &mut tracker);
schedule.export_as_doc(&scheduler.config);
}
}

View File

@@ -1,23 +1,24 @@
use std::sync::Mutex;
use std::{env::home_dir, sync::Mutex};
use crate::{
bounds::WorkloadBounds,
config::{UserConfig, UserConfigDTO},
export::{Export, FileType},
schedule::MonthlySchedule,
scheduler::Scheduler,
workload::{WorkloadBounds, WorkloadTracker},
};
mod bounds;
mod config;
mod export;
mod resident;
mod schedule;
mod scheduler;
mod slot;
pub mod config;
pub mod export;
pub mod resident;
pub mod schedule;
pub mod scheduler;
pub mod slot;
pub mod workload;
struct AppState {
schedule: Mutex<MonthlySchedule>,
tracker: Mutex<WorkloadTracker>,
}
/// argument to this must be the rota state including all
@@ -27,33 +28,52 @@ struct AppState {
fn generate(config: UserConfigDTO, state: tauri::State<'_, AppState>) -> MonthlySchedule {
let config = UserConfig::from(config);
let mut schedule = MonthlySchedule::new();
let mut tracker = WorkloadTracker::default();
let bounds = WorkloadBounds::new_with_config(&config);
let scheduler = Scheduler::new(config, bounds);
scheduler.run(&mut schedule);
scheduler.run(&mut schedule, &mut tracker);
let mut internal_schedule = state.schedule.lock().unwrap();
*internal_schedule = schedule.clone();
let mut internal_tracker = state.tracker.lock().unwrap();
*internal_tracker = tracker.clone();
schedule
}
/// export into docx
#[tauri::command]
fn export(config: UserConfigDTO, state: tauri::State<'_, AppState>) {
let config = UserConfig::from(config);
let schedule = state.schedule.lock().unwrap();
schedule.export(FileType::Docx, &config);
let tracker = state.tracker.lock().unwrap();
schedule.export(FileType::Docx, &config, &tracker);
schedule.export(FileType::Txt, &config, &tracker);
}
#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
let log_dir = home_dir().unwrap().join(".rota_logs");
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Cannot create log folder: {}", e);
}
tauri::Builder::default()
.manage(AppState {
schedule: Mutex::new(MonthlySchedule::new()),
tracker: Mutex::new(WorkloadTracker::default()),
})
.plugin(
tauri_plugin_log::Builder::new()
.level(tauri_plugin_log::log::LevelFilter::Info)
.targets([tauri_plugin_log::Target::new(
tauri_plugin_log::TargetKind::Folder {
path: log_dir,
file_name: Some("rota".to_string()), // Note: Plugin adds .log automatically
},
)])
.level(log::LevelFilter::Info)
.build(),
)
.plugin(tauri_plugin_opener::init())

View File

@@ -2,22 +2,22 @@ use serde::{ser::SerializeMap, Deserialize, Serialize};
use std::collections::HashMap;
use crate::{
bounds::WorkloadBounds,
config::{ToxicPair, UserConfig},
resident::ResidentId,
slot::{weekday_to_greek, Day, ShiftPosition, Slot},
workload::{WorkloadBounds, WorkloadTracker},
};
use serde::Serializer;
/// each slot has one resident
/// a day can span between 1 or 2 slots depending on if it is open(odd) or closed(even)
#[derive(Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone, Default)]
pub struct MonthlySchedule(pub HashMap<Slot, ResidentId>);
impl MonthlySchedule {
pub fn new() -> Self {
Self(HashMap::new())
Self::default()
}
pub fn prefill(&mut self, config: &UserConfig) {
@@ -32,66 +32,6 @@ impl MonthlySchedule {
self.0.get(slot)
}
pub fn current_workload(&self, resident_id: &ResidentId) -> usize {
self.0
.values()
.filter(|res_id| res_id == &resident_id)
.count()
}
pub fn current_holiday_workload(&self, resident_id: &ResidentId, config: &UserConfig) -> usize {
self.0
.iter()
.filter(|(slot, res_id)| {
res_id == &resident_id && config.is_holiday_or_weekend_slot(slot.day.0)
})
.count()
}
pub fn count_shifts(&self, resident_id: &ResidentId, shift_type: Option<ShiftType>) -> usize {
self.0
.iter()
.filter(|(&slot, id)| {
if id != &resident_id {
return false;
}
match &shift_type {
None => true,
Some(target) => {
let actual_type = if slot.is_open_shift() {
match slot.position {
ShiftPosition::First => ShiftType::OpenFirst,
ShiftPosition::Second => ShiftType::OpenSecond,
}
} else {
ShiftType::Closed
};
actual_type == *target
}
}
})
.count()
}
pub fn is_per_shift_threshold_met(&self, config: &UserConfig, bounds: &WorkloadBounds) -> bool {
for res in &config.residents {
for stype in [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
] {
let count = self.count_shifts(&res.id, Some(stype.clone()));
if let Some(&threshold) = bounds.min_by_shift_type.get(&(res.id.clone(), stype)) {
if count < threshold as usize {
return false;
}
}
}
}
true
}
pub fn insert(&mut self, slot: Slot, resident_id: &ResidentId) {
self.0.insert(slot, resident_id.clone());
}
@@ -116,16 +56,21 @@ impl MonthlySchedule {
slot: &Slot,
config: &UserConfig,
bounds: &WorkloadBounds,
tracker: &WorkloadTracker,
) -> bool {
self.same_resident_in_consecutive_days(slot)
let resident_id = match self.get_resident_id(slot) {
Some(id) => id,
None => return false,
};
self.has_resident_in_consecutive_days(slot)
|| self.has_toxic_pair(slot, config)
|| self.is_workload_unbalanced(slot, config, bounds)
|| self.is_holiday_workload_imbalanced(slot, config, bounds)
|| self.is_shift_type_distribution_unfair(slot, bounds)
|| tracker.is_total_workload_exceeded(bounds, resident_id)
|| tracker.is_holiday_workload_exceeded(bounds, resident_id)
|| tracker.is_max_shift_type_exceeded(bounds, resident_id, slot)
}
/// same_resident_in_consecutive_days
pub fn same_resident_in_consecutive_days(&self, slot: &Slot) -> bool {
pub fn has_resident_in_consecutive_days(&self, slot: &Slot) -> bool {
if slot.day == Day(1) {
return false;
}
@@ -144,7 +89,6 @@ impl MonthlySchedule {
.any(|s| self.get_resident_id(s) == self.get_resident_id(slot))
}
/// has_toxic_pair
pub fn has_toxic_pair(&self, slot: &Slot, config: &UserConfig) -> bool {
// can only have caused a toxic pair violation if we just added a 2nd resident in an open shift
if !slot.is_open_second() {
@@ -164,93 +108,6 @@ impl MonthlySchedule {
false
}
/// is_workload_unbalanced
pub fn is_workload_unbalanced(
&self,
slot: &Slot,
config: &UserConfig,
bounds: &WorkloadBounds,
) -> bool {
let res_id = match self.get_resident_id(slot) {
Some(id) => id,
None => return false,
};
if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) {
let current_workload = self.current_workload(&resident.id);
if let Some(&limit) = bounds.max_workloads.get(res_id) {
let mut workload_limit = limit;
if resident.reduced_load {
workload_limit -= 1;
}
if current_workload > workload_limit as usize {
return true;
}
}
}
false
}
/// is_holiday_workload_imbalanced
pub fn is_holiday_workload_imbalanced(
&self,
slot: &Slot,
config: &UserConfig,
bounds: &WorkloadBounds,
) -> bool {
if !config.is_holiday_or_weekend_slot(slot.day.0) {
return false;
}
let res_id = match self.get_resident_id(slot) {
Some(id) => id,
None => return false,
};
// if let Some(resident) = config.residents.iter().find(|r| &r.id == res_id) {
let current_holiday_workload = self.current_holiday_workload(res_id, config);
if let Some(&holiday_limit) = bounds.max_holiday_shifts.get(res_id) {
if current_holiday_workload > holiday_limit as usize {
return true;
}
}
// }
false
}
/// is_shift_type_distribution_unfair
pub fn is_shift_type_distribution_unfair(&self, slot: &Slot, bounds: &WorkloadBounds) -> bool {
let resident_id = match self.get_resident_id(slot) {
Some(id) => id,
None => return false,
};
let current_shift_type = if slot.is_open_shift() {
match slot.position {
ShiftPosition::First => ShiftType::OpenFirst,
ShiftPosition::Second => ShiftType::OpenSecond,
}
} else {
ShiftType::Closed
};
let current_count = self.count_shifts(resident_id, Some(current_shift_type.clone()));
if let Some(&limit) = bounds
.max_by_shift_type
.get(&(resident_id.clone(), current_shift_type.clone()))
{
return current_count > limit as usize;
}
false
}
pub fn pretty_print(&self, config: &UserConfig) -> String {
let mut sorted: Vec<_> = self.0.iter().collect();
sorted.sort_by_key(|(slot, _)| (slot.day, slot.position));
@@ -278,33 +135,32 @@ impl MonthlySchedule {
output
}
pub fn report(&self, config: &UserConfig) -> String {
pub fn report(&self, config: &UserConfig, tracker: &WorkloadTracker) -> String {
let mut output = String::new();
output.push_str("\n--- Αναφορά ---\n");
// Using standard widths for Greek characters and alignment
output.push_str(&format!(
"{:<15} | {:<6} | {:<10} | {:<10} | {:<7} | {:<10}\n",
"Ειδικευόμενος", "Σύνολο", "Ανοιχτή(1)", "Ανοιχτή(2)", "Κλειστή", "ΣΚ/Αργίες"
));
output.push_str("-".repeat(75).as_str());
output.push_str("-".repeat(85).as_str());
output.push('\n');
let mut residents: Vec<_> = config.residents.iter().collect();
residents.sort_by_key(|r| &r.name);
for res in residents {
let total = self.current_workload(&res.id);
let o1 = self.count_shifts(&res.id, Some(ShiftType::OpenFirst));
let o2 = self.count_shifts(&res.id, Some(ShiftType::OpenSecond));
let cl = self.count_shifts(&res.id, Some(ShiftType::Closed));
let sun = self.current_holiday_workload(&res.id, config);
let total = tracker.current_workload(&res.id);
let o1 = tracker.get_type_count(&res.id, ShiftType::OpenFirst);
let o2 = tracker.get_type_count(&res.id, ShiftType::OpenSecond);
let cl = tracker.get_type_count(&res.id, ShiftType::Closed);
let holiday = tracker.current_holiday_workload(&res.id);
output.push_str(&format!(
"{:<15} | {:<6} | {:<10} | {:<10} | {:<7} | {:<10}\n",
res.name, total, o1, o2, cl, sun
res.name, total, o1, o2, cl, holiday
));
}
output.push_str("-".repeat(75).as_str());
output.push_str("-".repeat(85).as_str());
output.push('\n');
output
}
@@ -340,7 +196,6 @@ mod tests {
use rstest::{fixture, rstest};
use crate::{
bounds::WorkloadBounds,
config::{ToxicPair, UserConfig},
resident::{Resident, ResidentId},
schedule::{Day, MonthlySchedule, Slot},
@@ -386,7 +241,6 @@ mod tests {
schedule.get_resident_id(&slot_1),
Some(&ResidentId("1".to_string()))
);
assert_eq!(schedule.current_workload(&resident.id), 1);
assert_eq!(schedule.get_resident_id(&slot_2), None);
}
@@ -395,11 +249,9 @@ mod tests {
let slot_1 = Slot::new(Day(1), ShiftPosition::First);
schedule.insert(slot_1, &resident.id);
assert_eq!(schedule.current_workload(&resident.id), 1);
schedule.remove(slot_1);
assert_eq!(schedule.get_resident_id(&slot_1), None);
assert_eq!(schedule.current_workload(&resident.id), 0);
}
#[rstest]
@@ -412,9 +264,9 @@ mod tests {
schedule.insert(slot_2, &resident.id);
schedule.insert(slot_3, &resident.id);
assert!(!schedule.same_resident_in_consecutive_days(&slot_1));
assert!(!schedule.same_resident_in_consecutive_days(&slot_2));
assert!(schedule.same_resident_in_consecutive_days(&slot_3));
assert!(!schedule.has_resident_in_consecutive_days(&slot_1));
assert!(!schedule.has_resident_in_consecutive_days(&slot_2));
assert!(schedule.has_resident_in_consecutive_days(&slot_3));
}
#[rstest]
@@ -430,56 +282,4 @@ mod tests {
assert!(schedule.has_toxic_pair(&slot_2, &toxic_config))
}
#[rstest]
fn test_is_workload_unbalanced(mut schedule: MonthlySchedule, config: UserConfig) {
let slot_1 = Slot::new(Day(1), ShiftPosition::First);
let slot_2 = Slot::new(Day(1), ShiftPosition::Second);
let slot_3 = Slot::new(Day(2), ShiftPosition::First);
let stefanos = &config.residents[0];
let iordanis = &config.residents[1];
let mut bounds = WorkloadBounds::new();
bounds.max_workloads.insert(ResidentId("1".to_string()), 1);
bounds.max_workloads.insert(ResidentId("2".to_string()), 2);
schedule.insert(slot_1, &stefanos.id);
assert!(!schedule.is_workload_unbalanced(&slot_1, &config, &bounds));
schedule.insert(slot_2, &iordanis.id);
assert!(!schedule.is_workload_unbalanced(&slot_2, &config, &bounds));
schedule.insert(slot_3, &stefanos.id);
assert!(schedule.is_workload_unbalanced(&slot_3, &config, &bounds));
}
#[rstest]
fn test_is_holiday_workload_imbalanced(mut schedule: MonthlySchedule, config: UserConfig) {
let slot_1 = Slot::new(Day(1), ShiftPosition::First);
let slot_2 = Slot::new(Day(1), ShiftPosition::Second);
let slot_7 = Slot::new(Day(7), ShiftPosition::First);
let stefanos = &config.residents[0];
let iordanis = &config.residents[1];
let mut bounds = WorkloadBounds::new();
bounds
.max_holiday_shifts
.insert(ResidentId("1".to_string()), 1);
bounds
.max_holiday_shifts
.insert(ResidentId("2".to_string()), 1);
schedule.insert(slot_1, &stefanos.id);
assert!(!schedule.is_holiday_workload_imbalanced(&slot_1, &config, &bounds));
schedule.insert(slot_2, &iordanis.id);
assert!(!schedule.is_holiday_workload_imbalanced(&slot_2, &config, &bounds));
schedule.insert(slot_7, &stefanos.id);
assert!(schedule.is_holiday_workload_imbalanced(&slot_7, &config, &bounds));
}
}

View File

@@ -1,6 +1,9 @@
use crate::{
bounds::WorkloadBounds, config::UserConfig, resident::ResidentId, schedule::MonthlySchedule,
config::UserConfig,
resident::ResidentId,
schedule::MonthlySchedule,
slot::Slot,
workload::{WorkloadBounds, WorkloadTracker},
};
use rand::Rng;
@@ -15,52 +18,62 @@ impl Scheduler {
Self { config, bounds }
}
pub fn run(&self, schedule: &mut MonthlySchedule) -> bool {
pub fn new_with_config(config: UserConfig) -> Self {
let bounds = WorkloadBounds::new_with_config(&config);
Self { config, bounds }
}
pub fn run(&self, schedule: &mut MonthlySchedule, tracker: &mut WorkloadTracker) -> bool {
schedule.prefill(&self.config);
self.search(schedule, Slot::default())
for (slot, res_id) in schedule.0.iter() {
tracker.insert(res_id, &self.config, *slot);
}
self.search(schedule, tracker, Slot::default())
}
/// DFS where maximum depth is calculated by total_days_of_month + odd_days_of_month each node is called a slot
/// Starts with schedule partially completed from the user interface
/// Ends with a full schedule following restrictions and fairness
pub fn search(&self, schedule: &mut MonthlySchedule, slot: Slot) -> bool {
pub fn search(
&self,
schedule: &mut MonthlySchedule,
tracker: &mut WorkloadTracker,
slot: Slot,
) -> bool {
if !slot.is_first()
&& schedule.restrictions_violated(&slot.previous(), &self.config, &self.bounds)
&& schedule.restrictions_violated(&slot.previous(), &self.config, &self.bounds, tracker)
{
log::trace!("Cutting branch due to restriction violation");
return false;
}
if slot.greater_than(self.config.total_days()) {
if !schedule.is_per_shift_threshold_met(&self.config, &self.bounds) {
return false;
}
log::trace!("Solution found, exiting recursive algorithm");
return true;
if slot.greater_than(self.config.total_days) {
return tracker.are_all_thresholds_met(&self.config, &self.bounds);
}
if schedule.is_slot_manually_assigned(&slot) {
return self.search(schedule, slot.next());
return self.search(schedule, tracker, slot.next());
}
// sort candidates by current workload, add rng for tie breakers
let mut valid_resident_ids = self.valid_residents(slot, schedule);
valid_resident_ids.sort_unstable_by_key(|res_id| {
let workload = schedule.current_workload(res_id);
let workload = tracker.current_workload(res_id);
let tie_breaker: f64 = rand::rng().random();
(workload, (tie_breaker * 1000.0) as usize)
});
for id in &valid_resident_ids {
schedule.insert(slot, id);
tracker.insert(id, &self.config, slot);
if self.search(schedule, slot.next()) {
log::trace!("Solution found, exiting recursive algorithm");
if self.search(schedule, tracker, slot.next()) {
return true;
}
schedule.remove(slot);
tracker.remove(id, &self.config, slot);
}
false
@@ -68,15 +81,19 @@ impl Scheduler {
/// Return all valid residents for the current slot
pub fn valid_residents(&self, slot: Slot, schedule: &MonthlySchedule) -> Vec<&ResidentId> {
let required_type = slot.shift_type(); // Calculate once here
let other_resident = schedule.get_resident_id(&slot.other_position());
let required_type = slot.shift_type();
let other_resident = slot
.other_position()
.and_then(|partner_slot| schedule.get_resident_id(&partner_slot));
self.config
.residents
.iter()
.filter(|r| Some(&r.id) != other_resident)
.filter(|r| !r.negative_shifts.contains(&slot.day))
.filter(|r| r.allowed_types.contains(&required_type))
.filter(|r| {
Some(&r.id) != other_resident
&& !r.negative_shifts.contains(&slot.day)
&& r.allowed_types.contains(&required_type)
})
.map(|r| &r.id)
.collect()
}
@@ -87,12 +104,12 @@ mod tests {
use rstest::{fixture, rstest};
use crate::{
bounds::WorkloadBounds,
config::UserConfig,
resident::Resident,
schedule::MonthlySchedule,
scheduler::Scheduler,
slot::{Day, ShiftPosition, Slot},
workload::{WorkloadBounds, WorkloadTracker},
};
#[fixture]
@@ -122,11 +139,20 @@ mod tests {
Scheduler::new(config, bounds)
}
#[rstest]
fn test_search(mut schedule: MonthlySchedule, scheduler: Scheduler) {
assert!(scheduler.search(&mut schedule, Slot::default()));
#[fixture]
fn tracker() -> WorkloadTracker {
WorkloadTracker::default()
}
for d in 1..=scheduler.config.total_days() {
#[rstest]
fn test_search(
mut schedule: MonthlySchedule,
mut tracker: WorkloadTracker,
scheduler: Scheduler,
) {
assert!(scheduler.run(&mut schedule, &mut tracker));
for d in 1..=scheduler.config.total_days {
let day = Day(d);
if day.is_open_shift() {
let slot_first = Slot::new(day, ShiftPosition::First);
@@ -140,9 +166,9 @@ mod tests {
}
for r in &scheduler.config.residents {
let workload = schedule.current_workload(&r.id);
let workload = tracker.current_workload(&r.id);
let limit = *scheduler.bounds.max_workloads.get(&r.id).unwrap();
assert!(workload <= limit as usize);
assert!(workload <= limit);
}
println!("{}", schedule.pretty_print(&scheduler.config));

View File

@@ -79,16 +79,20 @@ impl Slot {
self.day.greater_than(&Day(limit))
}
pub fn other_position(&self) -> Self {
pub fn other_position(&self) -> Option<Self> {
if !self.day.is_open_shift() {
return None;
}
let other_pos = match self.position {
ShiftPosition::First => ShiftPosition::Second,
ShiftPosition::Second => ShiftPosition::First,
};
Self {
Some(Self {
day: self.day,
position: other_pos,
}
})
}
pub fn shift_type(&self) -> ShiftType {
@@ -186,7 +190,10 @@ pub fn month_to_greek(month: u32) -> &'static str {
mod tests {
use rstest::rstest;
use crate::slot::{Day, ShiftPosition, Slot};
use crate::{
schedule::ShiftType,
slot::{Day, ShiftPosition, Slot},
};
#[rstest]
fn test_slot() {
@@ -215,6 +222,14 @@ mod tests {
assert!(!slot_1.greater_than(1));
assert!(!slot_2.greater_than(1));
assert!(slot_3.greater_than(1));
assert_eq!(slot_1.other_position(), Some(slot_1.next()));
assert_eq!(slot_2.other_position(), Some(slot_2.previous()));
assert_eq!(slot_3.other_position(), None);
assert_eq!(slot_1.shift_type(), ShiftType::OpenFirst);
assert_eq!(slot_2.shift_type(), ShiftType::OpenSecond);
assert_eq!(slot_3.shift_type(), ShiftType::Closed);
}
#[rstest]

407
src-tauri/src/workload.rs Normal file
View File

@@ -0,0 +1,407 @@
use std::collections::HashMap;
use crate::{
config::UserConfig,
resident::ResidentId,
schedule::ShiftType,
slot::{Day, Slot},
};
#[derive(Default)]
pub struct WorkloadBounds {
pub max_workloads: HashMap<ResidentId, u8>,
pub max_holiday_shifts: HashMap<ResidentId, u8>,
pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
pub min_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
}
impl WorkloadBounds {
pub fn new_with_config(config: &UserConfig) -> Self {
let mut bounds = Self::default();
bounds.calculate_max_workloads(config);
bounds.calculate_max_holiday_shifts(config);
bounds.calculate_max_by_shift_type(config);
bounds
}
/// get map with total amount of slots in a month for each type of shift
pub fn get_initial_supply(&self, config: &UserConfig) -> HashMap<ShiftType, u8> {
let mut supply = HashMap::new();
let total_days = config.total_days;
for d in 1..=total_days {
if Day(d).is_open_shift() {
*supply.entry(ShiftType::OpenFirst).or_insert(0) += 1;
*supply.entry(ShiftType::OpenSecond).or_insert(0) += 1;
} else {
*supply.entry(ShiftType::Closed).or_insert(0) += 1;
}
}
supply
}
/// this is called after the user config params have been initialized, can be done with the builder (lite) pattern
/// initialize a hashmap for O(1) search calls for the residents' max workload
pub fn calculate_max_workloads(&mut self, config: &UserConfig) {
let auto_computed_residents: Vec<_> = config
.residents
.iter()
.filter(|r| r.max_shifts.is_none())
.collect();
// if all residents have a manually set max shifts size, just use those values for the max workload
if auto_computed_residents.is_empty() {
for r in &config.residents {
self.max_workloads
.insert(r.id.clone(), r.max_shifts.unwrap_or(0) as u8);
}
return;
}
// Untested scenario: Resident has manual max_shifts and also reduced workload flag
// Probably should forbid using both options from GUI
let manual_max_shifts_sum: usize = config
.residents
.iter()
.map(|r| r.max_shifts.unwrap_or(0))
.sum();
let max_shifts_ceiling = ((config.total_slots as usize - manual_max_shifts_sum) as f32
/ auto_computed_residents.len() as f32)
.ceil() as u8;
for r in &config.residents {
let max_shifts = match r.max_shifts {
Some(shifts) => shifts as u8,
None if r.reduced_load => max_shifts_ceiling - 1,
None => max_shifts_ceiling,
};
self.max_workloads.insert(r.id.clone(), max_shifts);
}
}
pub fn calculate_max_holiday_shifts(&mut self, config: &UserConfig) {
let total_slots = config.total_slots;
let total_holiday_slots = config.total_holiday_slots;
for r in &config.residents {
let workload_limit = *self.max_workloads.get(&r.id).unwrap_or(&0);
let share = (workload_limit as f32 / total_slots as f32) * total_holiday_slots as f32;
let holiday_limit = share.ceil() as u8;
self.max_holiday_shifts.insert(r.id.clone(), holiday_limit);
}
}
pub fn calculate_max_by_shift_type(&mut self, config: &UserConfig) {
let mut supply_by_shift_type = self.get_initial_supply(config);
let mut local_limits = HashMap::new();
let mut local_thresholds = HashMap::new();
let all_shift_types = [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
];
// residents with 1 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 1)
{
let shift_type = &res.allowed_types[0];
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
local_limits.insert((res.id.clone(), shift_type.clone()), total_limit);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
total_limit.saturating_sub(2),
);
for other_type in &all_shift_types {
if other_type != shift_type {
local_limits.insert((res.id.clone(), other_type.clone()), 0);
local_thresholds.insert((res.id.clone(), other_type.clone()), 0);
}
}
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(total_limit)
}
}
// residents with 2 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 2)
{
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
let per_type = ((total_limit as f32) / 2.0).ceil() as u8;
let deduct_amount = (total_limit as f32 / 2.0) as u8;
for shift_type in &all_shift_types {
if res.allowed_types.contains(shift_type) {
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
per_type.saturating_sub(2),
);
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(deduct_amount);
}
} else {
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
}
}
}
// residents with 3 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 3)
{
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
let per_type = ((total_limit as f32) / 3.0).ceil() as u8;
let deduct_amount = (total_limit as f32 / 3.0) as u8;
for shift_type in &all_shift_types {
if res.allowed_types.contains(shift_type) {
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
per_type.saturating_sub(2),
);
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(deduct_amount);
}
} else {
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
}
}
}
self.max_by_shift_type = local_limits;
self.min_by_shift_type = local_thresholds;
}
}
#[derive(Default, Clone, Debug)]
pub struct WorkloadTracker {
total_counts: HashMap<ResidentId, u8>,
type_counts: HashMap<(ResidentId, ShiftType), u8>,
holidays: HashMap<ResidentId, u8>,
}
impl WorkloadTracker {
pub fn insert(&mut self, res_id: &ResidentId, config: &UserConfig, slot: Slot) {
*self.total_counts.entry(res_id.clone()).or_insert(0) += 1;
*self
.type_counts
.entry((res_id.clone(), slot.shift_type()))
.or_insert(0) += 1;
if config.is_holiday_or_weekend_slot(slot.day.0) {
*self.holidays.entry(res_id.clone()).or_insert(0) += 1;
}
}
pub fn remove(&mut self, resident_id: &ResidentId, config: &UserConfig, slot: Slot) {
if let Some(count) = self.total_counts.get_mut(resident_id) {
*count = count.saturating_sub(1);
}
if let Some(count) = self
.type_counts
.get_mut(&(resident_id.clone(), slot.shift_type()))
{
*count = count.saturating_sub(1);
}
if config.is_holiday_or_weekend_slot(slot.day.0) {
if let Some(count) = self.holidays.get_mut(resident_id) {
*count = count.saturating_sub(1);
}
}
}
pub fn current_workload(&self, res_id: &ResidentId) -> u8 {
*self.total_counts.get(res_id).unwrap_or(&0)
}
pub fn current_holiday_workload(&self, resident_id: &ResidentId) -> u8 {
*self.holidays.get(resident_id).unwrap_or(&0)
}
pub fn are_all_thresholds_met(&self, config: &UserConfig, bounds: &WorkloadBounds) -> bool {
const SHIFT_TYPES: [ShiftType; 3] = [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
];
for r in &config.residents {
for shift_type in SHIFT_TYPES {
let current_load = self
.type_counts
.get(&(r.id.clone(), shift_type.clone()))
.unwrap_or(&0);
if let Some(&min) = bounds
.min_by_shift_type
.get(&(r.id.clone(), shift_type.clone()))
{
if *current_load < min {
return false;
}
}
}
}
true
}
pub fn is_total_workload_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
) -> bool {
let current_load = self.current_workload(resident_id);
if let Some(&max) = bounds.max_workloads.get(resident_id) {
if current_load > max {
return true;
}
}
false
}
pub fn is_holiday_workload_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
) -> bool {
let current_load = self.current_holiday_workload(resident_id);
if let Some(&max) = bounds.max_holiday_shifts.get(resident_id) {
if current_load > max {
return true;
}
}
false
}
pub fn is_max_shift_type_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
slot: &Slot,
) -> bool {
let shift_type = slot.shift_type();
let current_load = self
.type_counts
.get(&(resident_id.clone(), shift_type.clone()))
.unwrap_or(&0);
if let Some(&max) = bounds
.max_by_shift_type
.get(&(resident_id.clone(), shift_type.clone()))
{
return *current_load > max;
}
false
}
pub fn get_type_count(&self, res_id: &ResidentId, stype: ShiftType) -> u8 {
*self.type_counts.get(&(res_id.clone(), stype)).unwrap_or(&0)
}
}
#[cfg(test)]
mod tests {
use crate::{
config::UserConfig,
resident::{Resident, ResidentId},
slot::{Day, ShiftPosition, Slot},
workload::{WorkloadBounds, WorkloadTracker},
};
use rstest::{fixture, rstest};
#[fixture]
fn config() -> UserConfig {
UserConfig::default().with_residents(vec![
Resident::new("1", "Stefanos").with_max_shifts(2),
Resident::new("2", "Iordanis").with_max_shifts(2),
Resident::new("3", "Maria").with_reduced_load(),
Resident::new("4", "Veatriki"),
Resident::new("5", "Takis"),
])
}
#[fixture]
fn tracker() -> WorkloadTracker {
WorkloadTracker::default()
}
#[rstest]
fn test_max_workloads(config: UserConfig) {
let bounds = WorkloadBounds::new_with_config(&config);
assert_eq!(bounds.max_workloads[&ResidentId("1".to_string())], 2);
assert_eq!(bounds.max_workloads[&ResidentId("2".to_string())], 2);
assert!(bounds.max_workloads[&ResidentId("3".to_string())] > 0);
}
#[rstest]
fn test_is_total_workload_exceeded(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let mut bounds = WorkloadBounds::default();
bounds.max_workloads.insert(res_id.clone(), 1);
let slot_1 = Slot::new(Day(1), ShiftPosition::First);
let slot_2 = Slot::new(Day(2), ShiftPosition::First);
tracker.insert(&res_id, &config, slot_1);
assert!(!tracker.is_total_workload_exceeded(&bounds, &res_id,));
tracker.insert(&res_id, &config, slot_2);
assert!(tracker.is_total_workload_exceeded(&bounds, &res_id,));
}
#[rstest]
fn test_is_holiday_workload_exceeded(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let mut bounds = WorkloadBounds::default();
bounds.max_holiday_shifts.insert(res_id.clone(), 1);
let sat = Slot::new(Day(7), ShiftPosition::First);
let sun = Slot::new(Day(8), ShiftPosition::First);
tracker.insert(&res_id, &config, sat);
assert!(!tracker.is_holiday_workload_exceeded(&bounds, &res_id));
tracker.insert(&res_id, &config, sun);
assert!(tracker.is_holiday_workload_exceeded(&bounds, &res_id));
}
#[rstest]
fn test_backtracking_accuracy(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let slot = Slot::new(Day(1), ShiftPosition::First);
tracker.insert(&res_id, &config, slot);
assert_eq!(tracker.current_workload(&res_id), 1);
tracker.remove(&res_id, &config, slot);
assert_eq!(tracker.current_workload(&res_id), 0);
}
}

View File

@@ -0,0 +1,186 @@
#[cfg(test)]
mod integration_tests {
use rota_lib::{
config::{ToxicPair, UserConfig},
resident::Resident,
schedule::{MonthlySchedule, ShiftType},
scheduler::Scheduler,
slot::{Day, ShiftPosition, Slot},
workload::{WorkloadBounds, WorkloadTracker},
};
use rstest::{fixture, rstest};
#[fixture]
fn minimal_config() -> UserConfig {
UserConfig::new(2).with_residents(vec![
Resident::new("1", "R1"),
Resident::new("2", "R2"),
Resident::new("3", "R3"),
Resident::new("4", "R4"),
])
}
#[fixture]
fn maximal_config() -> UserConfig {
UserConfig::new(2)
.with_holidays(vec![2, 3, 10, 11, 12, 25])
.with_residents(vec![
Resident::new("1", "R1").with_max_shifts(3),
Resident::new("2", "R2").with_max_shifts(4),
Resident::new("3", "R3").with_reduced_load(),
Resident::new("4", "R4").with_allowed_types(vec![ShiftType::Closed]),
Resident::new("5", "R5")
.with_allowed_types(vec![ShiftType::OpenFirst, ShiftType::OpenSecond]),
Resident::new("6", "R6").with_negative_shifts(vec![Day(5), Day(15), Day(25)]),
Resident::new("7", "R7"),
Resident::new("8", "R8"),
Resident::new("9", "R9"),
Resident::new("10", "R10"),
])
.with_toxic_pairs(vec![
ToxicPair::new("1", "2"),
ToxicPair::new("3", "4"),
ToxicPair::new("7", "8"),
])
}
#[fixture]
fn manual_shifts_heavy_config() -> UserConfig {
UserConfig::new(2).with_residents(vec![
Resident::new("1", "R1").with_manual_shifts(vec![
Slot::new(Day(1), ShiftPosition::First),
Slot::new(Day(3), ShiftPosition::First),
Slot::new(Day(5), ShiftPosition::Second),
]),
Resident::new("2", "R2").with_manual_shifts(vec![
Slot::new(Day(2), ShiftPosition::First),
Slot::new(Day(4), ShiftPosition::First),
]),
Resident::new("3", "R3"),
Resident::new("4", "R4"),
Resident::new("5", "R5"),
Resident::new("6", "R6"),
])
}
#[fixture]
fn complex_config() -> UserConfig {
UserConfig::new(2)
.with_holidays(vec![5, 12, 19, 26])
.with_residents(vec![
Resident::new("1", "R1")
.with_max_shifts(3)
.with_negative_shifts(vec![Day(1), Day(2), Day(3)]),
Resident::new("2", "R2")
.with_max_shifts(3)
.with_negative_shifts(vec![Day(4), Day(5), Day(6)]),
Resident::new("3", "R3")
.with_max_shifts(3)
.with_negative_shifts(vec![Day(7), Day(8), Day(9)]),
Resident::new("4", "R4").with_allowed_types(vec![ShiftType::Closed]),
Resident::new("5", "R5")
.with_allowed_types(vec![ShiftType::OpenFirst, ShiftType::OpenSecond]),
Resident::new("6", "R6"),
Resident::new("7", "R7"),
Resident::new("8", "R8"),
])
.with_toxic_pairs(vec![
ToxicPair::new("1", "2"),
ToxicPair::new("2", "3"),
ToxicPair::new("5", "6"),
ToxicPair::new("6", "7"),
])
}
#[rstest]
fn test_minimal_config(minimal_config: UserConfig) {
let mut schedule = MonthlySchedule::new();
let mut tracker = WorkloadTracker::default();
let scheduler = Scheduler::new_with_config(minimal_config.clone());
assert!(scheduler.run(&mut schedule, &mut tracker));
validate_all_constraints(&schedule, &tracker, &minimal_config);
}
#[rstest]
fn test_maximal_config(maximal_config: UserConfig) {
let mut schedule = MonthlySchedule::new();
let mut tracker = WorkloadTracker::default();
let scheduler = Scheduler::new_with_config(maximal_config.clone());
assert!(scheduler.run(&mut schedule, &mut tracker));
validate_all_constraints(&schedule, &tracker, &maximal_config);
}
#[rstest]
fn test_manual_shifts_heavy_config(manual_shifts_heavy_config: UserConfig) {
let mut schedule = MonthlySchedule::new();
let mut tracker = WorkloadTracker::default();
let scheduler = Scheduler::new_with_config(manual_shifts_heavy_config.clone());
assert!(scheduler.run(&mut schedule, &mut tracker));
validate_all_constraints(&schedule, &tracker, &manual_shifts_heavy_config);
}
#[rstest]
fn test_complex_config(complex_config: UserConfig) {
let mut schedule = MonthlySchedule::new();
let mut tracker = WorkloadTracker::default();
let scheduler = Scheduler::new_with_config(complex_config.clone());
assert!(scheduler.run(&mut schedule, &mut tracker));
validate_all_constraints(&schedule, &tracker, &complex_config);
}
fn validate_all_constraints(
schedule: &MonthlySchedule,
tracker: &WorkloadTracker,
config: &UserConfig,
) {
assert_eq!(schedule.0.len() as u8, config.total_slots);
for d in 2..=config.total_days {
let current: Vec<_> = [ShiftPosition::First, ShiftPosition::Second]
.iter()
.filter_map(|&p| schedule.get_resident_id(&Slot::new(Day(d), p)))
.collect();
let previous: Vec<_> = [ShiftPosition::First, ShiftPosition::Second]
.iter()
.filter_map(|&p| schedule.get_resident_id(&Slot::new(Day(d - 1), p)))
.collect();
for res in current {
assert!(!previous.contains(&res));
}
}
for d in 1..=config.total_days {
let day = Day(d);
if day.is_open_shift() {
let r1 = schedule.get_resident_id(&Slot::new(day, ShiftPosition::First));
let r2 = schedule.get_resident_id(&Slot::new(day, ShiftPosition::Second));
assert_ne!(r1, r2);
if let (Some(id1), Some(id2)) = (r1, r2) {
let pair = ToxicPair::from((id1.clone(), id2.clone()));
assert!(config.toxic_pairs.iter().all(|t| !t.matches(&pair)));
}
}
}
let bounds = WorkloadBounds::new_with_config(config);
for (slot, res_id) in &schedule.0 {
let res = config
.residents
.iter()
.find(|r| &r.id == res_id)
.expect("Resident not found");
assert!(res.allowed_types.contains(&slot.shift_type()));
assert!(!res.negative_shifts.contains(&slot.day));
}
for resident in &config.residents {
let workload = tracker.current_workload(&resident.id);
let max = *bounds.max_workloads.get(&resident.id).unwrap();
assert!(workload <= max);
}
}
}