Skip to content

Commit 4b3fc8a

Browse files
committed
Revert "implement get_prob_dist_by_service to support evaluation"
This reverts commit 5079976.
1 parent 5079976 commit 4b3fc8a

2 files changed

Lines changed: 28 additions & 769 deletions

File tree

notebooks/4d_Evaluate_demand_predictions.ipynb

Lines changed: 2 additions & 491 deletions
Large diffs are not rendered by default.

src/patientflow/aggregate.py

Lines changed: 26 additions & 278 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@
4343
get_prob_dist_using_survival_curve : function
4444
Calculate probability distributions for each snapshot date based on given model predictions, using a survival curve to predict the probability of each patient being admitted within a given prediction window.
4545
46-
get_prob_dist_by_service : function
47-
Evaluate composed service-level predictions across the test set for one or more
48-
services, producing probability distributions and observed values in the standard
49-
evaluation format.
50-
5146
"""
5247

5348
import pandas as pd
@@ -453,14 +448,28 @@ def get_prob_dist(
453448
)
454449

455450
prob_dist_dict = {}
451+
if verbose:
452+
print(
453+
f"Calculating probability distributions for {len(snapshots_dict)} snapshot dates"
454+
)
455+
456+
if len(snapshots_dict) > 10:
457+
print(
458+
"Using efficient generating function approach - much faster than before!"
459+
)
460+
461+
# Initialize a counter for notifying the user every 10 snapshot dates processed
462+
count = 0
456463

457464
for dt, snapshots_to_include in snapshots_dict.items():
458465
if len(snapshots_to_include) == 0:
466+
# Create an empty dictionary for the current snapshot date
459467
prob_dist_dict[dt] = {
460468
"agg_predicted": pd.DataFrame({"agg_proba": [1]}, index=[0]),
461469
"agg_observed": 0,
462470
}
463471
else:
472+
# Ensure the lengths of test features and outcomes are equal
464473
assert len(X_test.loc[snapshots_to_include]) == len(
465474
y_test.loc[snapshots_to_include]
466475
), "Mismatch in lengths of X_test and y_test snapshots."
@@ -470,13 +479,15 @@ def get_prob_dist(
470479
else:
471480
prediction_moment_weights = weights.loc[snapshots_to_include].values
472481

482+
# Apply category filter
473483
if category_filter is None:
474484
prediction_moment_category_filter = None
475485
else:
476486
prediction_moment_category_filter = category_filter.loc[
477487
snapshots_to_include
478488
]
479489

490+
# Use the refactored generating function approach
480491
prob_dist_dict[dt] = get_prob_dist_for_prediction_moment(
481492
X_test=X_test.loc[snapshots_to_include],
482493
y_test=y_test.loc[snapshots_to_include],
@@ -486,6 +497,11 @@ def get_prob_dist(
486497
normal_approx_threshold=normal_approx_threshold,
487498
)
488499

500+
# Increment the counter and notify the user every 10 snapshot dates processed
501+
count += 1
502+
if verbose and count % 10 == 0 and count != len(snapshots_dict):
503+
print(f"Processed {count} snapshot dates")
504+
489505
if verbose:
490506
print(f"Processed {len(snapshots_dict)} snapshot dates")
491507

@@ -562,7 +578,12 @@ def get_prob_dist_using_survival_curve(
562578
)
563579

564580
prob_dist_dict = {}
581+
if verbose:
582+
print(
583+
f"Calculating probability distributions for {len(snapshot_dates)} snapshot dates"
584+
)
565585

586+
# Create prediction context that will be the same for all dates
566587
prediction_context = {category: {"prediction_time": prediction_time}}
567588

568589
for dt in snapshot_dates:
@@ -600,276 +621,3 @@ def get_prob_dist_using_survival_curve(
600621
print(f"Processed {len(snapshot_dates)} snapshot dates")
601622

602623
return prob_dist_dict
603-
604-
605-
def prediction_to_eval_dict(
606-
probabilities: "np.ndarray",
607-
observed: int,
608-
) -> Dict[str, Any]:
609-
"""Convert a production-format probability array and observed count to the
610-
evaluation dictionary format used by the visualisation functions.
611-
612-
Parameters
613-
----------
614-
probabilities : np.ndarray
615-
Probability mass function array where ``probabilities[k] = P(count = k)``.
616-
Typically obtained from ``DemandPrediction.probabilities``.
617-
observed : int
618-
The observed count for this prediction moment.
619-
620-
Returns
621-
-------
622-
Dict[str, Any]
623-
Dictionary with keys ``'agg_predicted'`` (a DataFrame with column
624-
``'agg_proba'``) and ``'agg_observed'`` (int), compatible with
625-
``plot_epudd``, ``plot_randomised_pit``, ``qq_plot``, etc.
626-
"""
627-
agg_predicted = pd.DataFrame(
628-
{"agg_proba": probabilities},
629-
index=range(len(probabilities)),
630-
)
631-
return {"agg_predicted": agg_predicted, "agg_observed": observed}
632-
633-
634-
def _count_observed_admissions(
635-
ed_visits: pd.DataFrame,
636-
snapshot_date: date,
637-
prediction_time: Tuple[int, int],
638-
prediction_window: timedelta,
639-
specialty: Optional[str] = None,
640-
) -> int:
641-
"""Count actual admissions for a given snapshot date, prediction time and
642-
(optionally) specialty.
643-
644-
Parameters
645-
----------
646-
ed_visits : pd.DataFrame
647-
Full ED visits dataframe. Must contain columns ``snapshot_date``,
648-
``prediction_time``, ``is_admitted``, and (when *specialty* is given)
649-
``specialty``.
650-
snapshot_date : date
651-
The date of the snapshot.
652-
prediction_time : Tuple[int, int]
653-
``(hour, minute)`` of the prediction moment.
654-
prediction_window : timedelta
655-
Not used for counting (admissions are identified by the ``is_admitted``
656-
flag on the snapshot), but reserved for future use with time-windowed
657-
counting.
658-
specialty : str, optional
659-
If provided, count only admissions to this specialty.
660-
661-
Returns
662-
-------
663-
int
664-
Number of admitted patients matching the criteria.
665-
"""
666-
mask = (
667-
(ed_visits["snapshot_date"] == snapshot_date)
668-
& (ed_visits["prediction_time"] == prediction_time)
669-
& (ed_visits["is_admitted"].astype(bool))
670-
)
671-
if specialty is not None:
672-
mask = mask & (ed_visits["specialty"] == specialty)
673-
return int(mask.sum())
674-
675-
676-
def get_prob_dist_by_service(
677-
ed_visits: pd.DataFrame,
678-
snapshot_dates: List[date],
679-
prediction_time: Tuple[int, int],
680-
models: tuple,
681-
specialties: List[str],
682-
prediction_window: timedelta,
683-
x1: float,
684-
y1: float,
685-
x2: float,
686-
y2: float,
687-
services: Optional[List[str]] = None,
688-
inpatient_visits: Optional[pd.DataFrame] = None,
689-
flow_selection: Optional[Any] = None,
690-
prediction_component: str = "arrivals",
691-
verbose: bool = False,
692-
) -> Dict[str, Dict[date, Dict[str, Any]]]:
693-
"""Evaluate composed service-level predictions across a set of test dates.
694-
695-
Unlike ``get_prob_dist`` and ``get_prob_dist_using_survival_curve``, which
696-
evaluate a single model component at a time, this function evaluates the
697-
composed prediction for one or more services — multiple flows (ED current,
698-
yet-to-arrive, transfers, departures) convolved together via
699-
``DemandPredictor``.
700-
701-
``build_service_data`` is called once per snapshot date and produces
702-
``ServicePredictionInputs`` for *all* specialties simultaneously, so
703-
requesting multiple services adds negligible cost.
704-
705-
For each snapshot date, this function:
706-
707-
1. Extracts the ED and (optionally) inpatient snapshots for the given
708-
``prediction_time``.
709-
2. Calls ``build_service_data`` to produce ``ServicePredictionInputs``
710-
for all specialties.
711-
3. Runs ``DemandPredictor.predict_service`` with the given
712-
``flow_selection`` for each requested service.
713-
4. Counts the observed admissions from the data for each service.
714-
5. Packages the predicted PMF and observed count into the standard
715-
evaluation dictionary format consumed by ``plot_epudd``,
716-
``plot_randomised_pit``, ``qq_plot``, etc.
717-
718-
Parameters
719-
----------
720-
ed_visits : pd.DataFrame
721-
Full ED visits dataframe (all dates, all prediction times). Must
722-
contain columns ``snapshot_date``, ``prediction_time``,
723-
``is_admitted``, ``specialty``, and ``elapsed_los`` (as timedelta).
724-
snapshot_dates : List[date]
725-
Dates in the test set to evaluate.
726-
prediction_time : Tuple[int, int]
727-
``(hour, minute)`` of the prediction moment.
728-
models : tuple
729-
Seven-element tuple of trained models (or ``None``), as expected by
730-
``build_service_data``: ``(ed_classifier, inpatient_classifier,
731-
spec_model, yta_model, non_ed_yta_model, elective_yta_model,
732-
transfer_model)``.
733-
specialties : List[str]
734-
All specialties to pass to ``build_service_data``. This determines
735-
the full set of ``ServicePredictionInputs`` that are prepared.
736-
prediction_window : timedelta
737-
Prediction horizon.
738-
x1, y1, x2, y2 : float
739-
Parameters for the admission-in-window probability curve.
740-
services : List[str], optional
741-
Which services to evaluate and return results for. Each must be
742-
present in *specialties*. If ``None``, all *specialties* are
743-
evaluated.
744-
inpatient_visits : pd.DataFrame, optional
745-
Full inpatient visits dataframe (all dates, all prediction times).
746-
If provided, must contain columns ``snapshot_date``,
747-
``prediction_time``, and ``elapsed_los`` (as timedelta). Used
748-
to supply inpatient snapshots for departure predictions. If
749-
``None``, departures are predicted as zero.
750-
flow_selection : FlowSelection, optional
751-
Which flows to include in the prediction. If ``None``,
752-
``FlowSelection.default()`` is used.
753-
prediction_component : str, default ``"arrivals"``
754-
Which component of the ``PredictionBundle`` to extract for
755-
evaluation. One of ``"arrivals"``, ``"departures"``, or
756-
``"net_flow"``.
757-
verbose : bool, default ``False``
758-
If ``True``, print a one-line summary on completion.
759-
760-
Returns
761-
-------
762-
Dict[str, Dict[date, Dict[str, Any]]]
763-
Dictionary mapping each service name to a dict mapping each
764-
snapshot date to a dict with keys ``'agg_predicted'`` (DataFrame
765-
with ``'agg_proba'`` column) and ``'agg_observed'`` (int). The
766-
inner dict is the standard format expected by the evaluation
767-
visualisation functions.
768-
769-
Raises
770-
------
771-
ValueError
772-
If ``prediction_component`` is not one of the recognised values, or
773-
if any entry in *services* is not found in *specialties*.
774-
"""
775-
from patientflow.predict.service import build_service_data
776-
from patientflow.predict.demand import DemandPredictor, FlowSelection
777-
778-
valid_components = ("arrivals", "departures", "net_flow")
779-
if prediction_component not in valid_components:
780-
raise ValueError(
781-
f"prediction_component must be one of {valid_components}, "
782-
f"got '{prediction_component}'"
783-
)
784-
785-
if services is None:
786-
services = list(specialties)
787-
else:
788-
unknown = [s for s in services if s not in specialties]
789-
if unknown:
790-
raise ValueError(
791-
f"services {unknown} not found in specialties {specialties}"
792-
)
793-
794-
if flow_selection is None:
795-
flow_selection = FlowSelection.default()
796-
797-
predictor = DemandPredictor(k_sigma=8.0)
798-
result: Dict[str, Dict[date, Dict[str, Any]]] = {
799-
svc: {} for svc in services
800-
}
801-
802-
for dt in snapshot_dates:
803-
ed_snapshot = ed_visits[
804-
(ed_visits["snapshot_date"] == dt)
805-
& (ed_visits["prediction_time"] == prediction_time)
806-
]
807-
808-
if ed_snapshot.empty:
809-
for svc in services:
810-
result[svc][dt] = prediction_to_eval_dict(
811-
np.array([1.0]), observed=0
812-
)
813-
continue
814-
815-
ed_snapshot_processed = ed_snapshot.copy(deep=True)
816-
if not pd.api.types.is_timedelta64_dtype(
817-
ed_snapshot_processed["elapsed_los"]
818-
):
819-
ed_snapshot_processed["elapsed_los"] = pd.to_timedelta(
820-
ed_snapshot_processed["elapsed_los"], unit="s"
821-
)
822-
823-
inpatient_snapshot = None
824-
if inpatient_visits is not None:
825-
ip_mask = (
826-
(inpatient_visits["snapshot_date"] == dt)
827-
& (inpatient_visits["prediction_time"] == prediction_time)
828-
)
829-
ip_filtered = inpatient_visits[ip_mask]
830-
if not ip_filtered.empty:
831-
inpatient_snapshot = ip_filtered.copy(deep=True)
832-
if not pd.api.types.is_timedelta64_dtype(
833-
inpatient_snapshot["elapsed_los"]
834-
):
835-
inpatient_snapshot["elapsed_los"] = pd.to_timedelta(
836-
inpatient_snapshot["elapsed_los"], unit="s"
837-
)
838-
839-
service_data = build_service_data(
840-
models=models,
841-
prediction_time=prediction_time,
842-
ed_snapshots=ed_snapshot_processed,
843-
inpatient_snapshots=inpatient_snapshot,
844-
specialties=specialties,
845-
prediction_window=prediction_window,
846-
x1=x1,
847-
y1=y1,
848-
x2=x2,
849-
y2=y2,
850-
)
851-
852-
for svc in services:
853-
bundle = predictor.predict_service(
854-
inputs=service_data[svc],
855-
flow_selection=flow_selection,
856-
)
857-
858-
demand_prediction = getattr(bundle, prediction_component)
859-
860-
observed = _count_observed_admissions(
861-
ed_visits, dt, prediction_time, prediction_window,
862-
specialty=svc,
863-
)
864-
865-
result[svc][dt] = prediction_to_eval_dict(
866-
demand_prediction.probabilities, observed
867-
)
868-
869-
if verbose:
870-
print(
871-
f"prediction_time={prediction_time}: "
872-
f"{len(services)} services × {len(snapshot_dates)} dates"
873-
)
874-
875-
return result

0 commit comments

Comments (0)