4343get_prob_dist_using_survival_curve : function
4444 Calculate probability distributions for each snapshot date based on given model predictions, using a survival curve to predict the probability of each patient being admitted within a given prediction window.
4545
46- get_prob_dist_by_service : function
47- Evaluate composed service-level predictions across the test set for one or more
48- services, producing probability distributions and observed values in the standard
49- evaluation format.
50-
5146"""
5247
5348import pandas as pd
@@ -453,14 +448,28 @@ def get_prob_dist(
453448 )
454449
455450 prob_dist_dict = {}
451+ if verbose :
452+ print (
453+ f"Calculating probability distributions for { len (snapshots_dict )} snapshot dates"
454+ )
455+
456+ if len (snapshots_dict ) > 10 :
457+ print (
458+ "Using efficient generating function approach - much faster than before!"
459+ )
460+
461+ # Initialize a counter for notifying the user every 10 snapshot dates processed
462+ count = 0
456463
457464 for dt , snapshots_to_include in snapshots_dict .items ():
458465 if len (snapshots_to_include ) == 0 :
466+ # Create an empty dictionary for the current snapshot date
459467 prob_dist_dict [dt ] = {
460468 "agg_predicted" : pd .DataFrame ({"agg_proba" : [1 ]}, index = [0 ]),
461469 "agg_observed" : 0 ,
462470 }
463471 else :
472+ # Ensure the lengths of test features and outcomes are equal
464473 assert len (X_test .loc [snapshots_to_include ]) == len (
465474 y_test .loc [snapshots_to_include ]
466475 ), "Mismatch in lengths of X_test and y_test snapshots."
@@ -470,13 +479,15 @@ def get_prob_dist(
470479 else :
471480 prediction_moment_weights = weights .loc [snapshots_to_include ].values
472481
482+ # Apply category filter
473483 if category_filter is None :
474484 prediction_moment_category_filter = None
475485 else :
476486 prediction_moment_category_filter = category_filter .loc [
477487 snapshots_to_include
478488 ]
479489
490+ # Use the refactored generating function approach
480491 prob_dist_dict [dt ] = get_prob_dist_for_prediction_moment (
481492 X_test = X_test .loc [snapshots_to_include ],
482493 y_test = y_test .loc [snapshots_to_include ],
@@ -486,6 +497,11 @@ def get_prob_dist(
486497 normal_approx_threshold = normal_approx_threshold ,
487498 )
488499
500+ # Increment the counter and notify the user every 10 snapshot dates processed
501+ count += 1
502+ if verbose and count % 10 == 0 and count != len (snapshots_dict ):
503+ print (f"Processed { count } snapshot dates" )
504+
489505 if verbose :
490506 print (f"Processed { len (snapshots_dict )} snapshot dates" )
491507
@@ -562,7 +578,12 @@ def get_prob_dist_using_survival_curve(
562578 )
563579
564580 prob_dist_dict = {}
581+ if verbose :
582+ print (
583+ f"Calculating probability distributions for { len (snapshot_dates )} snapshot dates"
584+ )
565585
586+ # Create prediction context that will be the same for all dates
566587 prediction_context = {category : {"prediction_time" : prediction_time }}
567588
568589 for dt in snapshot_dates :
@@ -600,276 +621,3 @@ def get_prob_dist_using_survival_curve(
600621 print (f"Processed { len (snapshot_dates )} snapshot dates" )
601622
602623 return prob_dist_dict
603-
604-
def prediction_to_eval_dict(
    probabilities: "np.ndarray",
    observed: int,
) -> Dict[str, Any]:
    """Package a PMF array and an observed count in the evaluation format.

    Parameters
    ----------
    probabilities : np.ndarray
        Probability mass function over counts, where ``probabilities[k]`` is
        ``P(count = k)``. Typically ``DemandPrediction.probabilities``.
    observed : int
        The observed count for this prediction moment.

    Returns
    -------
    Dict[str, Any]
        ``{'agg_predicted': <DataFrame with column 'agg_proba'>,
        'agg_observed': observed}`` — the structure consumed by
        ``plot_epudd``, ``plot_randomised_pit``, ``qq_plot``, etc.
    """
    pmf_frame = pd.DataFrame(
        data={"agg_proba": probabilities},
        index=range(len(probabilities)),
    )
    return {"agg_predicted": pmf_frame, "agg_observed": observed}
632-
633-
def _count_observed_admissions(
    ed_visits: pd.DataFrame,
    snapshot_date: date,
    prediction_time: Tuple[int, int],
    prediction_window: timedelta,
    specialty: Optional[str] = None,
) -> int:
    """Count admitted patients for one snapshot date and prediction time.

    Parameters
    ----------
    ed_visits : pd.DataFrame
        Full ED visits dataframe. Must contain columns ``snapshot_date``,
        ``prediction_time``, ``is_admitted``, and — when *specialty* is
        given — ``specialty``.
    snapshot_date : date
        The date of the snapshot.
    prediction_time : Tuple[int, int]
        ``(hour, minute)`` of the prediction moment.
    prediction_window : timedelta
        Currently unused (admissions are identified by the ``is_admitted``
        flag on the snapshot); reserved for future time-windowed counting.
    specialty : str, optional
        If provided, restrict the count to admissions to this specialty.

    Returns
    -------
    int
        Number of admitted patients matching the criteria.
    """
    # Narrow to the snapshot moment, then to admitted rows; pandas treats
    # the (hour, minute) tuple as a scalar in the elementwise comparison.
    at_moment = ed_visits[
        (ed_visits["snapshot_date"] == snapshot_date)
        & (ed_visits["prediction_time"] == prediction_time)
    ]
    admitted = at_moment[at_moment["is_admitted"].astype(bool)]
    if specialty is not None:
        admitted = admitted[admitted["specialty"] == specialty]
    return len(admitted)
674-
675-
def get_prob_dist_by_service(
    ed_visits: pd.DataFrame,
    snapshot_dates: List[date],
    prediction_time: Tuple[int, int],
    models: tuple,
    specialties: List[str],
    prediction_window: timedelta,
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    services: Optional[List[str]] = None,
    inpatient_visits: Optional[pd.DataFrame] = None,
    flow_selection: Optional[Any] = None,
    prediction_component: str = "arrivals",
    verbose: bool = False,
) -> Dict[str, Dict[date, Dict[str, Any]]]:
    """Evaluate composed service-level predictions across a set of test dates.

    Whereas ``get_prob_dist`` and ``get_prob_dist_using_survival_curve``
    evaluate a single model component at a time, this function evaluates the
    composed prediction for one or more services — multiple flows (ED
    current, yet-to-arrive, transfers, departures) convolved together via
    ``DemandPredictor``. ``build_service_data`` runs once per snapshot date
    and prepares ``ServicePredictionInputs`` for *all* specialties at once,
    so asking for several services adds negligible cost.

    Per snapshot date this function: extracts the ED (and, optionally,
    inpatient) snapshots for ``prediction_time``; builds the service inputs;
    runs ``DemandPredictor.predict_service`` with ``flow_selection`` per
    requested service; counts observed admissions from the data; and packs
    the predicted PMF plus observed count into the standard evaluation
    dictionary consumed by ``plot_epudd``, ``plot_randomised_pit``,
    ``qq_plot``, etc.

    Parameters
    ----------
    ed_visits : pd.DataFrame
        Full ED visits dataframe (all dates and prediction times). Must
        contain ``snapshot_date``, ``prediction_time``, ``is_admitted``,
        ``specialty``, and ``elapsed_los`` (timedelta or seconds).
    snapshot_dates : List[date]
        Dates in the test set to evaluate.
    prediction_time : Tuple[int, int]
        ``(hour, minute)`` of the prediction moment.
    models : tuple
        Seven-element tuple of trained models (or ``None``), in the order
        expected by ``build_service_data``: ``(ed_classifier,
        inpatient_classifier, spec_model, yta_model, non_ed_yta_model,
        elective_yta_model, transfer_model)``.
    specialties : List[str]
        All specialties passed to ``build_service_data``; determines the
        full set of ``ServicePredictionInputs`` prepared per date.
    prediction_window : timedelta
        Prediction horizon.
    x1, y1, x2, y2 : float
        Parameters for the admission-in-window probability curve.
    services : List[str], optional
        Services to evaluate and return. Each must appear in *specialties*;
        defaults to all of *specialties*.
    inpatient_visits : pd.DataFrame, optional
        Full inpatient visits dataframe; must contain ``snapshot_date``,
        ``prediction_time``, and ``elapsed_los`` if provided. Supplies the
        inpatient snapshots for departure predictions; when ``None``,
        departures are predicted as zero.
    flow_selection : FlowSelection, optional
        Flows to include; ``FlowSelection.default()`` when ``None``.
    prediction_component : str, default ``"arrivals"``
        Which ``PredictionBundle`` attribute to evaluate: ``"arrivals"``,
        ``"departures"``, or ``"net_flow"``.
    verbose : bool, default ``False``
        If ``True``, print a one-line summary on completion.

    Returns
    -------
    Dict[str, Dict[date, Dict[str, Any]]]
        Maps each service name to ``{date: {'agg_predicted': DataFrame,
        'agg_observed': int}}`` — the standard evaluation format.

    Raises
    ------
    ValueError
        If ``prediction_component`` is unrecognised, or any entry in
        *services* is absent from *specialties*.
    """
    from patientflow.predict.service import build_service_data
    from patientflow.predict.demand import DemandPredictor, FlowSelection

    valid_components = ("arrivals", "departures", "net_flow")
    if prediction_component not in valid_components:
        raise ValueError(
            f"prediction_component must be one of {valid_components}, "
            f"got '{prediction_component}'"
        )

    if services is None:
        services = list(specialties)
    else:
        unknown = [s for s in services if s not in specialties]
        if unknown:
            raise ValueError(
                f"services {unknown} not found in specialties {specialties}"
            )

    if flow_selection is None:
        flow_selection = FlowSelection.default()

    def _with_timedelta_los(snapshot: pd.DataFrame) -> pd.DataFrame:
        # Deep-copy and coerce elapsed_los (assumed seconds when numeric
        # — TODO confirm against upstream loaders) into timedelta dtype.
        prepared = snapshot.copy(deep=True)
        if not pd.api.types.is_timedelta64_dtype(prepared["elapsed_los"]):
            prepared["elapsed_los"] = pd.to_timedelta(
                prepared["elapsed_los"], unit="s"
            )
        return prepared

    predictor = DemandPredictor(k_sigma=8.0)
    result: Dict[str, Dict[date, Dict[str, Any]]] = {
        svc: {} for svc in services
    }

    for dt in snapshot_dates:
        at_moment = (ed_visits["snapshot_date"] == dt) & (
            ed_visits["prediction_time"] == prediction_time
        )
        ed_snapshot = ed_visits[at_moment]

        if ed_snapshot.empty:
            # No ED patients at this moment: degenerate PMF P(0)=1, zero observed.
            for svc in services:
                result[svc][dt] = prediction_to_eval_dict(
                    np.array([1.0]), observed=0
                )
            continue

        ed_snapshot_processed = _with_timedelta_los(ed_snapshot)

        inpatient_snapshot = None
        if inpatient_visits is not None:
            ip_filtered = inpatient_visits[
                (inpatient_visits["snapshot_date"] == dt)
                & (inpatient_visits["prediction_time"] == prediction_time)
            ]
            if not ip_filtered.empty:
                inpatient_snapshot = _with_timedelta_los(ip_filtered)

        # One build covers every specialty for this date.
        service_data = build_service_data(
            models=models,
            prediction_time=prediction_time,
            ed_snapshots=ed_snapshot_processed,
            inpatient_snapshots=inpatient_snapshot,
            specialties=specialties,
            prediction_window=prediction_window,
            x1=x1,
            y1=y1,
            x2=x2,
            y2=y2,
        )

        for svc in services:
            bundle = predictor.predict_service(
                inputs=service_data[svc],
                flow_selection=flow_selection,
            )
            # Pick the requested component (arrivals/departures/net_flow).
            demand_prediction = getattr(bundle, prediction_component)
            observed = _count_observed_admissions(
                ed_visits,
                dt,
                prediction_time,
                prediction_window,
                specialty=svc,
            )
            result[svc][dt] = prediction_to_eval_dict(
                demand_prediction.probabilities, observed
            )

    if verbose:
        print(
            f"prediction_time={prediction_time}: "
            f"{len(services)} services × {len(snapshot_dates)} dates"
        )

    return result
0 commit comments