Use maps to track workload progress instead of recalculating them at every step of the search, calculate total days/slots once, add integration tests, add log folder

This commit is contained in:
2026-01-17 18:41:43 +02:00
parent 908f114e54
commit 5bad63e8a7
10 changed files with 819 additions and 574 deletions

407
src-tauri/src/workload.rs Normal file
View File

@@ -0,0 +1,407 @@
use std::collections::HashMap;
use crate::{
config::UserConfig,
resident::ResidentId,
schedule::ShiftType,
slot::{Day, Slot},
};
#[derive(Default)]
pub struct WorkloadBounds {
pub max_workloads: HashMap<ResidentId, u8>,
pub max_holiday_shifts: HashMap<ResidentId, u8>,
pub max_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
pub min_by_shift_type: HashMap<(ResidentId, ShiftType), u8>,
}
impl WorkloadBounds {
pub fn new_with_config(config: &UserConfig) -> Self {
let mut bounds = Self::default();
bounds.calculate_max_workloads(config);
bounds.calculate_max_holiday_shifts(config);
bounds.calculate_max_by_shift_type(config);
bounds
}
/// get map with total amount of slots in a month for each type of shift
pub fn get_initial_supply(&self, config: &UserConfig) -> HashMap<ShiftType, u8> {
let mut supply = HashMap::new();
let total_days = config.total_days;
for d in 1..=total_days {
if Day(d).is_open_shift() {
*supply.entry(ShiftType::OpenFirst).or_insert(0) += 1;
*supply.entry(ShiftType::OpenSecond).or_insert(0) += 1;
} else {
*supply.entry(ShiftType::Closed).or_insert(0) += 1;
}
}
supply
}
/// this is called after the user config params have been initialized, can be done with the builder (lite) pattern
/// initialize a hashmap for O(1) search calls for the residents' max workload
pub fn calculate_max_workloads(&mut self, config: &UserConfig) {
let auto_computed_residents: Vec<_> = config
.residents
.iter()
.filter(|r| r.max_shifts.is_none())
.collect();
// if all residents have a manually set max shifts size, just use those values for the max workload
if auto_computed_residents.is_empty() {
for r in &config.residents {
self.max_workloads
.insert(r.id.clone(), r.max_shifts.unwrap_or(0) as u8);
}
return;
}
// Untested scenario: Resident has manual max_shifts and also reduced workload flag
// Probably should forbid using both options from GUI
let manual_max_shifts_sum: usize = config
.residents
.iter()
.map(|r| r.max_shifts.unwrap_or(0))
.sum();
let max_shifts_ceiling = ((config.total_slots as usize - manual_max_shifts_sum) as f32
/ auto_computed_residents.len() as f32)
.ceil() as u8;
for r in &config.residents {
let max_shifts = match r.max_shifts {
Some(shifts) => shifts as u8,
None if r.reduced_load => max_shifts_ceiling - 1,
None => max_shifts_ceiling,
};
self.max_workloads.insert(r.id.clone(), max_shifts);
}
}
pub fn calculate_max_holiday_shifts(&mut self, config: &UserConfig) {
let total_slots = config.total_slots;
let total_holiday_slots = config.total_holiday_slots;
for r in &config.residents {
let workload_limit = *self.max_workloads.get(&r.id).unwrap_or(&0);
let share = (workload_limit as f32 / total_slots as f32) * total_holiday_slots as f32;
let holiday_limit = share.ceil() as u8;
self.max_holiday_shifts.insert(r.id.clone(), holiday_limit);
}
}
pub fn calculate_max_by_shift_type(&mut self, config: &UserConfig) {
let mut supply_by_shift_type = self.get_initial_supply(config);
let mut local_limits = HashMap::new();
let mut local_thresholds = HashMap::new();
let all_shift_types = [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
];
// residents with 1 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 1)
{
let shift_type = &res.allowed_types[0];
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
local_limits.insert((res.id.clone(), shift_type.clone()), total_limit);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
total_limit.saturating_sub(2),
);
for other_type in &all_shift_types {
if other_type != shift_type {
local_limits.insert((res.id.clone(), other_type.clone()), 0);
local_thresholds.insert((res.id.clone(), other_type.clone()), 0);
}
}
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(total_limit)
}
}
// residents with 2 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 2)
{
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
let per_type = ((total_limit as f32) / 2.0).ceil() as u8;
let deduct_amount = (total_limit as f32 / 2.0) as u8;
for shift_type in &all_shift_types {
if res.allowed_types.contains(shift_type) {
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
per_type.saturating_sub(2),
);
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(deduct_amount);
}
} else {
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
}
}
}
// residents with 3 available shift types
for res in config
.residents
.iter()
.filter(|r| r.allowed_types.len() == 3)
{
let total_limit = *self.max_workloads.get(&res.id).unwrap_or(&0);
let per_type = ((total_limit as f32) / 3.0).ceil() as u8;
let deduct_amount = (total_limit as f32 / 3.0) as u8;
for shift_type in &all_shift_types {
if res.allowed_types.contains(shift_type) {
local_limits.insert((res.id.clone(), shift_type.clone()), per_type);
local_thresholds.insert(
(res.id.clone(), shift_type.clone()),
per_type.saturating_sub(2),
);
if let Some(s) = supply_by_shift_type.get_mut(shift_type) {
*s = s.saturating_sub(deduct_amount);
}
} else {
local_limits.insert((res.id.clone(), shift_type.clone()), 0);
local_thresholds.insert((res.id.clone(), shift_type.clone()), 0);
}
}
}
self.max_by_shift_type = local_limits;
self.min_by_shift_type = local_thresholds;
}
}
#[derive(Default, Clone, Debug)]
pub struct WorkloadTracker {
total_counts: HashMap<ResidentId, u8>,
type_counts: HashMap<(ResidentId, ShiftType), u8>,
holidays: HashMap<ResidentId, u8>,
}
impl WorkloadTracker {
pub fn insert(&mut self, res_id: &ResidentId, config: &UserConfig, slot: Slot) {
*self.total_counts.entry(res_id.clone()).or_insert(0) += 1;
*self
.type_counts
.entry((res_id.clone(), slot.shift_type()))
.or_insert(0) += 1;
if config.is_holiday_or_weekend_slot(slot.day.0) {
*self.holidays.entry(res_id.clone()).or_insert(0) += 1;
}
}
pub fn remove(&mut self, resident_id: &ResidentId, config: &UserConfig, slot: Slot) {
if let Some(count) = self.total_counts.get_mut(resident_id) {
*count = count.saturating_sub(1);
}
if let Some(count) = self
.type_counts
.get_mut(&(resident_id.clone(), slot.shift_type()))
{
*count = count.saturating_sub(1);
}
if config.is_holiday_or_weekend_slot(slot.day.0) {
if let Some(count) = self.holidays.get_mut(resident_id) {
*count = count.saturating_sub(1);
}
}
}
pub fn current_workload(&self, res_id: &ResidentId) -> u8 {
*self.total_counts.get(res_id).unwrap_or(&0)
}
pub fn current_holiday_workload(&self, resident_id: &ResidentId) -> u8 {
*self.holidays.get(resident_id).unwrap_or(&0)
}
pub fn are_all_thresholds_met(&self, config: &UserConfig, bounds: &WorkloadBounds) -> bool {
const SHIFT_TYPES: [ShiftType; 3] = [
ShiftType::OpenFirst,
ShiftType::OpenSecond,
ShiftType::Closed,
];
for r in &config.residents {
for shift_type in SHIFT_TYPES {
let current_load = self
.type_counts
.get(&(r.id.clone(), shift_type.clone()))
.unwrap_or(&0);
if let Some(&min) = bounds
.min_by_shift_type
.get(&(r.id.clone(), shift_type.clone()))
{
if *current_load < min {
return false;
}
}
}
}
true
}
pub fn is_total_workload_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
) -> bool {
let current_load = self.current_workload(resident_id);
if let Some(&max) = bounds.max_workloads.get(resident_id) {
if current_load > max {
return true;
}
}
false
}
pub fn is_holiday_workload_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
) -> bool {
let current_load = self.current_holiday_workload(resident_id);
if let Some(&max) = bounds.max_holiday_shifts.get(resident_id) {
if current_load > max {
return true;
}
}
false
}
pub fn is_max_shift_type_exceeded(
&self,
bounds: &WorkloadBounds,
resident_id: &ResidentId,
slot: &Slot,
) -> bool {
let shift_type = slot.shift_type();
let current_load = self
.type_counts
.get(&(resident_id.clone(), shift_type.clone()))
.unwrap_or(&0);
if let Some(&max) = bounds
.max_by_shift_type
.get(&(resident_id.clone(), shift_type.clone()))
{
return *current_load > max;
}
false
}
pub fn get_type_count(&self, res_id: &ResidentId, stype: ShiftType) -> u8 {
*self.type_counts.get(&(res_id.clone(), stype)).unwrap_or(&0)
}
}
#[cfg(test)]
mod tests {
use crate::{
config::UserConfig,
resident::{Resident, ResidentId},
slot::{Day, ShiftPosition, Slot},
workload::{WorkloadBounds, WorkloadTracker},
};
use rstest::{fixture, rstest};
#[fixture]
fn config() -> UserConfig {
UserConfig::default().with_residents(vec![
Resident::new("1", "Stefanos").with_max_shifts(2),
Resident::new("2", "Iordanis").with_max_shifts(2),
Resident::new("3", "Maria").with_reduced_load(),
Resident::new("4", "Veatriki"),
Resident::new("5", "Takis"),
])
}
#[fixture]
fn tracker() -> WorkloadTracker {
WorkloadTracker::default()
}
#[rstest]
fn test_max_workloads(config: UserConfig) {
let bounds = WorkloadBounds::new_with_config(&config);
assert_eq!(bounds.max_workloads[&ResidentId("1".to_string())], 2);
assert_eq!(bounds.max_workloads[&ResidentId("2".to_string())], 2);
assert!(bounds.max_workloads[&ResidentId("3".to_string())] > 0);
}
#[rstest]
fn test_is_total_workload_exceeded(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let mut bounds = WorkloadBounds::default();
bounds.max_workloads.insert(res_id.clone(), 1);
let slot_1 = Slot::new(Day(1), ShiftPosition::First);
let slot_2 = Slot::new(Day(2), ShiftPosition::First);
tracker.insert(&res_id, &config, slot_1);
assert!(!tracker.is_total_workload_exceeded(&bounds, &res_id,));
tracker.insert(&res_id, &config, slot_2);
assert!(tracker.is_total_workload_exceeded(&bounds, &res_id,));
}
#[rstest]
fn test_is_holiday_workload_exceeded(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let mut bounds = WorkloadBounds::default();
bounds.max_holiday_shifts.insert(res_id.clone(), 1);
let sat = Slot::new(Day(7), ShiftPosition::First);
let sun = Slot::new(Day(8), ShiftPosition::First);
tracker.insert(&res_id, &config, sat);
assert!(!tracker.is_holiday_workload_exceeded(&bounds, &res_id));
tracker.insert(&res_id, &config, sun);
assert!(tracker.is_holiday_workload_exceeded(&bounds, &res_id));
}
#[rstest]
fn test_backtracking_accuracy(mut tracker: WorkloadTracker, config: UserConfig) {
let res_id = ResidentId("1".to_string());
let slot = Slot::new(Day(1), ShiftPosition::First);
tracker.insert(&res_id, &config, slot);
assert_eq!(tracker.current_workload(&res_id), 1);
tracker.remove(&res_id, &config, slot);
assert_eq!(tracker.current_workload(&res_id), 0);
}
}