Skip to content

Commit 1e90fe1

Browse files
authored
Merge pull request #68 from UCL-CORU/tidy-model-loading
resolve model load issues
2 parents faf93ff + 97d6663 commit 1e90fe1

5 files changed

Lines changed: 2 additions & 170 deletions

File tree

docs/notebooks/3c_Predict_bed_counts_without_using_patient_snapshots.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,6 @@ The `predict()` method will:
416416

417417
```python
418418
from patientflow.predictors.weighted_poisson_predictor import WeightedPoissonPredictor
419-
from joblib import dump, load
420419

421420
yta_model = WeightedPoissonPredictor(verbose=True)
422421
num_days = (start_validation_set - start_training_set).days

notebooks/3c_Predict_bed_counts_without_using_patient_snapshots.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,6 @@
10751075
],
10761076
"source": [
10771077
"from patientflow.predictors.weighted_poisson_predictor import WeightedPoissonPredictor\n",
1078-
"from joblib import dump, load\n",
10791078
"\n",
10801079
"yta_model = WeightedPoissonPredictor(verbose=True)\n",
10811080
"num_days = (start_validation_set - start_training_set).days\n",

src/patientflow/load.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@
3939
import sys
4040

4141
import pandas as pd
42-
from joblib import load
43-
from patientflow.errors import ModelLoadError
4442

4543
import yaml
4644
from typing import Any, Dict, Tuple, Union, Optional
@@ -549,49 +547,6 @@ def get_model_key(model_name, prediction_time):
549547
return model_name
550548

551549

552-
def load_saved_model(model_file_path, model_name, prediction_time=None):
553-
"""
554-
Load a saved model from a file.
555-
556-
Parameters
557-
----------
558-
model_file_path : Path
559-
The path to the directory where the model is saved.
560-
model_name : str
561-
The base name of the model.
562-
prediction_time : tuple of int, optional
563-
The time of day the model was trained for.
564-
565-
Returns
566-
-------
567-
Any
568-
The loaded model.
569-
570-
Raises
571-
------
572-
ModelLoadError
573-
If the model file cannot be found or loaded.
574-
"""
575-
if prediction_time:
576-
# retrieve model based on the time of day it is trained for
577-
model_name = get_model_key(model_name, prediction_time)
578-
579-
full_path = model_file_path / model_name
580-
full_path = full_path.with_suffix(".joblib")
581-
582-
try:
583-
model = load(full_path)
584-
return model
585-
except FileNotFoundError:
586-
# print(f"Model named {model_name} not found at path: {model_file_path}")
587-
raise ModelLoadError(
588-
f"Model named {model_name} not found at path: {model_file_path}"
589-
)
590-
except Exception as e:
591-
# print(f"Error loading model: {e}")
592-
raise ModelLoadError(f"Error loading model called {model_name}: {e}")
593-
594-
595550
def get_dict_cols(df):
596551
"""
597552
Categorize DataFrame columns into predefined groups.

src/patientflow/prepare.py

Lines changed: 2 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@
88
99
Functions
1010
---------
11-
prepare_for_inference(model_file_path, model_name, prediction_time=None,
12-
model_only=False, df=None, data_path=None,
13-
single_snapshot_per_visit=True, index_column='snapshot_id',
14-
sort_columns=None, eval_columns=None,
15-
exclude_from_training_data=None)
16-
Loads a model and prepares data for inference.
17-
18-
select_one_snapshot_per_visit(df, visit_col, seed=42)
11+
select_one_snapshot_per_visit(df, visit_col, seed=42)
1912
Selects one snapshot per visit based on a random number and returns the filtered DataFrame.
2013
2114
prepare_patient_snapshots(df, prediction_time, exclude_columns, single_snapshot_per_visit=True)
@@ -31,7 +24,7 @@
3124
import pandas as pd
3225
import numpy as np
3326
import random
34-
from patientflow.load import load_saved_model, get_dict_cols, data_from_csv
27+
from patientflow.load import get_dict_cols
3528
from datetime import datetime, date
3629

3730

@@ -517,118 +510,6 @@ def prepare_patient_snapshots(
517510
return df_tod, y
518511

519512

520-
def prepare_for_inference(
521-
model_file_path,
522-
model_name,
523-
prediction_time=None,
524-
model_only=False,
525-
df=None,
526-
data_path=None,
527-
single_snapshot_per_visit=True,
528-
index_column="snapshot_id",
529-
sort_columns=["visit_number", "snapshot_date", "prediction_time"],
530-
eval_columns=["prediction_time", "consultation_sequence", "final_sequence"],
531-
exclude_from_training_data=["visit_number", "snapshot_date", "prediction_time"],
532-
):
533-
"""
534-
Load a trained model and prepare data for making predictions.
535-
536-
This function retrieves a trained model from a specified file path and,
537-
if requested, prepares the data required for inference. The data can be
538-
provided either as a DataFrame or as a file path to a CSV file. The function
539-
allows filtering and processing of the data to match the model's requirements.
540-
If available, it will use the calibrated pipeline instead of the regular pipeline.
541-
542-
Parameters
543-
----------
544-
model_file_path : str
545-
The file path where the trained model is saved.
546-
model_name : str
547-
The name of the model to be loaded.
548-
prediction_time : str, optional
549-
The time at which predictions are to be made. This is used to filter
550-
the data for the relevant time snapshot.
551-
model_only : bool, optional
552-
If True, only the model is returned. If False, both the prepared data
553-
and the model are returned. Default is False.
554-
df : pandas.DataFrame, optional
555-
The DataFrame containing the data to be used for inference. If not
556-
provided, data_path must be specified.
557-
data_path : str, optional
558-
The file path to a CSV file containing the data to be used for inference.
559-
Ignored if `df` is provided.
560-
single_snapshot_per_visit : bool, optional
561-
If True, only a single snapshot per visit is considered. Default is True.
562-
index_column : str, optional
563-
The name of the index column in the data. Default is 'snapshot_id'.
564-
sort_columns : list of str, optional
565-
The columns to sort the data by. Default is ["visit_number", "snapshot_date", "prediction_time"].
566-
eval_columns : list of str, optional
567-
The columns that require literal evaluation of their content when loading from csv.
568-
Default is ["prediction_time", "consultation_sequence", "final_sequence"].
569-
exclude_from_training_data : list of str, optional
570-
The columns to be excluded from the training data. Default is ["visit_number", "snapshot_date", "prediction_time"].
571-
572-
Returns
573-
-------
574-
model : object
575-
The loaded model (calibrated pipeline if available, otherwise regular pipeline).
576-
X_test : pandas.DataFrame, optional
577-
The features prepared for testing, returned only if model_only is False.
578-
y_test : pandas.Series, optional
579-
The labels corresponding to X_test, returned only if model_only is False.
580-
581-
Raises
582-
------
583-
KeyError
584-
If the 'training_validation_test' column is not found in the provided DataFrame.
585-
586-
Notes
587-
-----
588-
- Either `df` or `data_path` must be provided. If neither is provided or if `df`
589-
is empty, the function will print an error message and return None.
590-
- The function will automatically use a calibrated pipeline if one is available
591-
in the model, otherwise it will fall back to the regular pipeline.
592-
"""
593-
594-
# retrieve model trained for this time of day
595-
model = load_saved_model(model_file_path, model_name, prediction_time)
596-
597-
# Use calibrated pipeline if available, otherwise use regular pipeline
598-
if hasattr(model, "calibrated_pipeline") and model.calibrated_pipeline is not None:
599-
pipeline = model.calibrated_pipeline
600-
else:
601-
pipeline = model.pipeline
602-
603-
if model_only:
604-
return pipeline
605-
606-
if data_path:
607-
df = data_from_csv(data_path, index_column, sort_columns, eval_columns)
608-
elif df is None or df.empty:
609-
print("Please supply a dataset if not passing a data path")
610-
return None
611-
612-
try:
613-
test_df = (
614-
df[df.training_validation_test == "test"]
615-
.drop(columns="training_validation_test")
616-
.copy()
617-
)
618-
except KeyError:
619-
print("Column training_validation_test not found in dataframe")
620-
return None
621-
622-
X_test, y_test = prepare_patient_snapshots(
623-
test_df,
624-
prediction_time,
625-
exclude_from_training_data,
626-
single_snapshot_per_visit,
627-
)
628-
629-
return X_test, y_test, pipeline
630-
631-
632513
def prepare_group_snapshot_dict(df, start_dt=None, end_dt=None):
633514
"""
634515
Prepares a dictionary mapping snapshot dates to their corresponding snapshot indices.

tests/test_create_predictions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,6 @@ def create_spec_model(df, apply_special_category_filtering):
183183
admit_col="is_admitted",
184184
)
185185
model.fit(df)
186-
# full_path = self.model_file_path / str("ed_specialty.joblib")
187-
# joblib.dump(model, full_path)
188186

189187
return model
190188

0 commit comments

Comments
 (0)