Decouple prediction_window from training in IncomingAdmissionPredictor #136

@zmek

Description

Problem

Both prediction_window and prediction_times are currently required at training time (fit()) for all three IncomingAdmissionPredictor subclasses (DirectAdmissionPredictor, ParametricIncomingAdmissionPredictor, EmpiricalIncomingAdmissionPredictor). However, neither fundamentally affects what is learned from the data — they only control which subset of the full 24-hour arrival-rate dictionary is stored.

time_varying_arrival_rates() already computes rates for every time-of-day interval across the full 24-hour cycle. The two parameters then truncate that complete dictionary along two axes:

  • prediction_times truncates which starting times get stored
  • prediction_window truncates how many intervals are stored per starting time

This coupling has several downsides:

  1. A model must be retrained to use a different prediction window or set of prediction times, even though the underlying arrival rates and survival curves are independent of both.
  2. Prediction-time validation is forced to reject mismatches — both predict/emergency_demand.py and predict/service.py raise errors when the requested prediction_window differs from yet_to_arrive_model.prediction_window, which would be unnecessary if the window were a predict-time concern.
  3. Prediction times must fall back to the nearest trained time: _iter_prediction_inputs() uses find_nearest_previous_prediction_time() to snap to a trained time, losing precision. With the full rate dictionary, any time could be served exactly (snapping to the nearest yta_time_interval boundary instead).
  4. Model naming is tied to the window — e.g. yet_to_arrive_{hours}_hours in train/emergency_demand.py — suggesting separate models for separate windows, when one model could serve them all.

Root cause

In _calculate_parameters() (incoming_admission_predictors.py, line 496), the full 24-hour arrival-rate dictionary is computed by time_varying_arrival_rates(), but then only a subset is kept — Ntimes intervals for each time in prediction_times:

Ntimes = int(prediction_window / yta_time_interval)

arrival_rates_dict = time_varying_arrival_rates(
    df, yta_time_interval, num_days, verbose=self.verbose
)

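for prediction_time_ in prediction_times:
    # Assumed here: each prediction time is an (hour, minute) tuple.
    prediction_time_hr, prediction_time_min = prediction_time_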
    arrival_rates = [
        arrival_rates_dict[
            (datetime(1970, 1, 1, prediction_time_hr, prediction_time_min)
             + i * yta_time_interval).time()
        ]
        for i in range(Ntimes)
    ]
    prediction_time_dict[(prediction_time_hr, prediction_time_min)] = {
        "arrival_rates": arrival_rates
    }

Everything downstream — self.weights, self.NTimes, self.prediction_window_hours, self.prediction_times — is then locked to those choices.

Proposed change

Defer both prediction_window and prediction_times to predict() / predict_mean() so that a single trained model can serve any window length (up to 24 hours) and any prediction time, given the same yta_time_interval.

After this change, fit() would only need from the user:

  • train_df — the training data
  • yta_time_interval — the granularity of arrival-rate buckets
  • num_days — the number of days the training data spans
  • epsilon — error tolerance for distribution support
  • filters — (already on __init__) data categorization filters

Training side — simplify fit()

| File / location | What changes |
| --- | --- |
| IncomingAdmissionPredictor.fit() | Remove prediction_window and prediction_times from signature. |
| _calculate_parameters() | Store the full arrival_rates_dict (keyed by time-of-day) instead of pre-slicing to Ntimes entries per prediction time. Remove prediction_window and prediction_times parameters. |
| Cached metadata | Remove self.prediction_window, self.prediction_window_hours, self.NTimes, self.prediction_times. Keep self.yta_time_interval and self.yta_time_interval_hours. |
| EmpiricalIncomingAdmissionPredictor.fit() | Drop prediction_window and prediction_times from the super().fit() call. The survival-curve calculation is already independent of both. |
| train/incoming_admission_predictor.py | Remove prediction_window and prediction_times from train_parametric_admission_predictor() and the yta_model.fit() call. |
| train/emergency_demand.py | Remove prediction_window and prediction_times from training paths. Model naming no longer needs the window suffix. |

Prediction side — accept both in predict() / predict_mean()

| File / location | What changes |
| --- | --- |
| _iter_prediction_inputs() | Accept prediction_window. Derive Ntimes on the fly and slice arrival rates from the stored full-cycle dict using datetime wrap-around arithmetic (currently in _calculate_parameters). The prediction_time is already passed via prediction_context; it now resolves against the full dict rather than the trained subset. |
| _get_window_and_interval_hours() | Accept prediction_window as an argument instead of reading self.prediction_window_hours and self.NTimes. |
| find_nearest_previous_prediction_time() | No longer needed for snapping to trained times. Could be replaced by snapping to the nearest yta_time_interval boundary, which is more precise. |
| DirectAdmissionPredictor.predict() | Accept prediction_window kwarg; pass through to iterator. |
| ParametricIncomingAdmissionPredictor.predict() | Accept prediction_window kwarg; use it to compute time_remaining_before_end_of_window and theta. |
| EmpiricalIncomingAdmissionPredictor.predict() | Accept prediction_window kwarg; pass to _calculate_survival_probabilities(). |
| predict_mean() (base class) | Accept and forward prediction_window. |
| _get_admission_probabilities() (all subclasses) | Thread prediction_window through. |
| predict/emergency_demand.py | Pass prediction_window to .predict() / .predict_mean(). Remove the mismatch validation (prediction_window != yet_to_arrive_model.prediction_window). |
| predict/service.py | Same as above: pass window at predict time, remove mismatch checks for yet_to_arrive_model, non_ed_yta_model, and elective_yta_model. |
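
One way the replacement for find_nearest_previous_prediction_time() could look is sketched below. The function name snap_to_interval_boundary is hypothetical, and the sketch assumes that "snapping" means rounding down to the start of the interval bucket that contains the requested time (mirroring the "previous" behaviour of the current helper) rather than rounding to the closest boundary.

```python
from datetime import datetime, time, timedelta


def snap_to_interval_boundary(prediction_time: time, yta_time_interval: timedelta) -> time:
    """Round a time of day down to the start of its yta_time_interval bucket.

    Hypothetical helper, not part of the current codebase.
    """
    interval_seconds = int(yta_time_interval.total_seconds())
    seconds_since_midnight = (
        prediction_time.hour * 3600
        + prediction_time.minute * 60
        + prediction_time.second
    )
    # Integer division drops the partial interval, giving the bucket start.
    snapped_seconds = (seconds_since_midnight // interval_seconds) * interval_seconds
    return (datetime(1970, 1, 1) + timedelta(seconds=snapped_seconds)).time()


# With 30-minute buckets, 09:47 resolves to the 09:30 bucket.
snapped = snap_to_interval_boundary(time(9, 47), timedelta(minutes=30))
```

Because the result is always a key of the full 24-hour dictionary, the lookup cannot miss as long as the same yta_time_interval is used at training and prediction.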

Weights format change

The structure of self.weights changes from:

# Current: pre-sliced per prediction time
weights[filter_key][(hour, minute)] = {
    "arrival_rates": [rate_0, rate_1, ..., rate_{Ntimes-1}]
}

to:

# Proposed: full 24-hour arrival-rate dictionary
weights[filter_key] = {
    "arrival_rates_dict": {
        time(0, 0): rate,
        time(0, 30): rate,
        ...
        time(23, 30): rate,
    }
}

The per-prediction-time slicing moves into _iter_prediction_inputs() at predict time.
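
A minimal sketch of that predict-time slicing follows, reusing the wrap-around arithmetic from the current _calculate_parameters() excerpt. The function name slice_arrival_rates and the toy rate values are hypothetical; the real change would live inside _iter_prediction_inputs().

```python
from datetime import datetime, time, timedelta


def slice_arrival_rates(arrival_rates_dict, prediction_time, prediction_window, yta_time_interval):
    """Return the rates for each interval in the window, starting at prediction_time.

    Hypothetical helper: prediction_time is assumed to already sit on an
    interval boundary, so every computed key exists in arrival_rates_dict.
    """
    n_times = int(prediction_window / yta_time_interval)
    start = datetime(1970, 1, 1, prediction_time.hour, prediction_time.minute)
    # Adding timedeltas to a datetime rolls past midnight; .time() discards
    # the date, which gives the 24-hour wrap-around.
    return [
        arrival_rates_dict[(start + i * yta_time_interval).time()]
        for i in range(n_times)
    ]


# Toy full-cycle dict: 48 half-hour buckets whose "rate" is just the bucket index.
interval = timedelta(minutes=30)
toy_rates = {
    (datetime(1970, 1, 1) + i * interval).time(): float(i) for i in range(48)
}

# A 2-hour window starting at 23:00 wraps into the next day.
rates = slice_arrival_rates(toy_rates, time(23, 0), timedelta(hours=2), interval)
```

The same stored dictionary serves a 2-hour, 8-hour, or 24-hour window; only the slice length changes.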

Note: This is a breaking change to the serialized model format. Any pickled models trained under the old format will need to be retrained after this change.
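
Since old pickles would otherwise fail with an opaque KeyError at predict time, a loader could check the weights layout up front. The sketch below is an illustration only: is_legacy_weights and the "medical" filter key are hypothetical, and the check relies solely on the two layouts shown above.

```python
from datetime import time


def is_legacy_weights(weights: dict) -> bool:
    """Return True if any filter entry lacks the proposed "arrival_rates_dict" key.

    In the old layout each filter maps (hour, minute) tuples to rate lists,
    so the new top-level key is absent.
    """
    return any(
        "arrival_rates_dict" not in per_filter for per_filter in weights.values()
    )


old_style = {"medical": {(9, 30): {"arrival_rates": [0.5, 0.4]}}}
new_style = {"medical": {"arrival_rates_dict": {time(0, 0): 0.5, time(0, 30): 0.4}}}

legacy_flags = (is_legacy_weights(old_style), is_legacy_weights(new_style))
```

A caller could raise a clear "retrain required" error when the check returns True.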

Backward compatibility

Consider a deprecation path:

  1. Accept prediction_window and prediction_times in fit() as optional parameters with a deprecation warning.
  2. If provided at training, store them as defaults that predict() uses when not passed explicitly — preserving existing call sites until they're migrated.
  3. Remove the deprecated path in a future release.
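
The deprecation path could be sketched roughly as follows. SketchPredictor and resolve_prediction_window are hypothetical stand-ins, the training logic is elided, and the deprecated arguments are made keyword-only here as an assumption, which would still require migrating call sites that pass them positionally.

```python
import warnings
from datetime import timedelta


class SketchPredictor:
    """Hypothetical stand-in for IncomingAdmissionPredictor; training is elided."""

    def __init__(self):
        self._default_prediction_window = None
        self._default_prediction_times = None

    def fit(self, train_df, yta_time_interval, num_days, epsilon,
            *, prediction_window=None, prediction_times=None):
        if prediction_window is not None or prediction_times is not None:
            warnings.warn(
                "Passing prediction_window / prediction_times to fit() is "
                "deprecated; pass them at predict time instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        # Step 2: keep training-time values only as predict-time defaults.
        self._default_prediction_window = prediction_window
        self._default_prediction_times = prediction_times
        self.yta_time_interval = yta_time_interval
        return self

    def resolve_prediction_window(self, prediction_window=None):
        """An explicit predict-time value wins; otherwise use the stored default."""
        if prediction_window is not None:
            return prediction_window
        if self._default_prediction_window is not None:
            return self._default_prediction_window
        raise ValueError("prediction_window must be passed at predict time")
```

predict() and predict_mean() would call the resolver first, so existing call sites keep working with a warning while migrated ones pass the window explicitly.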

Files affected

  • src/patientflow/predictors/incoming_admission_predictors.py — core predictor classes
  • src/patientflow/train/incoming_admission_predictor.py — training utility
  • src/patientflow/train/emergency_demand.py — training orchestration
  • src/patientflow/predict/emergency_demand.py — prediction with validation
  • src/patientflow/predict/service.py — service-level prediction with validation
  • src/patientflow/viz/pipeline_plots.py — if it accesses model.prediction_window
  • Notebooks that call fit() with prediction_window / prediction_times

Benefits

  • One model, any window, any time — a single trained model works for any prediction window and any prediction time without retraining.
  • Simpler training API: fit() reduces to train_df, yta_time_interval, num_days, and epsilon.
  • More precise prediction times — any time-of-day can be served exactly (snapping to the nearest interval boundary) rather than falling back to a trained subset.
  • No mismatch errors — the predict-time validation that currently rejects window mismatches becomes unnecessary.
  • Smaller stored models — the full arrival-rate dict is the same size regardless of how many prediction windows or times you plan to use.
