Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ The `predict()` method will:

```python
from patientflow.predictors.weighted_poisson_predictor import WeightedPoissonPredictor
from joblib import dump, load

yta_model = WeightedPoissonPredictor(verbose=True)
num_days = (start_validation_set - start_training_set).days
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,6 @@
],
"source": [
"from patientflow.predictors.weighted_poisson_predictor import WeightedPoissonPredictor\n",
"from joblib import dump, load\n",
"\n",
"yta_model = WeightedPoissonPredictor(verbose=True)\n",
"num_days = (start_validation_set - start_training_set).days\n",
Expand Down
45 changes: 0 additions & 45 deletions src/patientflow/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@
import sys

import pandas as pd
from joblib import load
from patientflow.errors import ModelLoadError

import yaml
from typing import Any, Dict, Tuple, Union, Optional
Expand Down Expand Up @@ -549,49 +547,6 @@ def get_model_key(model_name, prediction_time):
return model_name


def load_saved_model(model_file_path, model_name, prediction_time=None):
    """
    Load a saved model from a file.

    Parameters
    ----------
    model_file_path : Path
        The path to the directory where the model is saved.
    model_name : str
        The base name of the model.
    prediction_time : tuple of int, optional
        The time of day the model was trained for. When given, it is combined
        with ``model_name`` (via ``get_model_key``) to locate the
        time-specific model file.

    Returns
    -------
    Any
        The loaded model.

    Raises
    ------
    ModelLoadError
        If the model file cannot be found or loaded.
    """
    if prediction_time:
        # Retrieve the model trained for this specific time of day.
        model_name = get_model_key(model_name, prediction_time)

    # Models are persisted as "<model_name>.joblib" inside the model directory.
    full_path = (model_file_path / model_name).with_suffix(".joblib")

    try:
        return load(full_path)
    except FileNotFoundError as e:
        # Chain the original error so the real filesystem failure stays
        # visible in the traceback instead of being silently replaced.
        raise ModelLoadError(
            f"Model named {model_name} not found at path: {model_file_path}"
        ) from e
    except Exception as e:
        raise ModelLoadError(f"Error loading model called {model_name}: {e}") from e


def get_dict_cols(df):
"""
Categorize DataFrame columns into predefined groups.
Expand Down
123 changes: 2 additions & 121 deletions src/patientflow/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@

Functions
---------
prepare_for_inference(model_file_path, model_name, prediction_time=None,
model_only=False, df=None, data_path=None,
single_snapshot_per_visit=True, index_column='snapshot_id',
sort_columns=None, eval_columns=None,
exclude_from_training_data=None)
Loads a model and prepares data for inference.

select_one_snapshot_per_visit(df, visit_col, seed=42)
select_one_snapshot_per_visit(df, visit_col, seed=42)
Selects one snapshot per visit based on a random number and returns the filtered DataFrame.

prepare_patient_snapshots(df, prediction_time, exclude_columns, single_snapshot_per_visit=True)
Expand All @@ -31,7 +24,7 @@
import pandas as pd
import numpy as np
import random
from patientflow.load import load_saved_model, get_dict_cols, data_from_csv
from patientflow.load import get_dict_cols
from datetime import datetime, date


Expand Down Expand Up @@ -517,118 +510,6 @@ def prepare_patient_snapshots(
return df_tod, y


def prepare_for_inference(
    model_file_path,
    model_name,
    prediction_time=None,
    model_only=False,
    df=None,
    data_path=None,
    single_snapshot_per_visit=True,
    index_column="snapshot_id",
    sort_columns=None,
    eval_columns=None,
    exclude_from_training_data=None,
):
    """
    Load a trained model and prepare data for making predictions.

    This function retrieves a trained model from a specified file path and,
    if requested, prepares the data required for inference. The data can be
    provided either as a DataFrame or as a file path to a CSV file. The function
    allows filtering and processing of the data to match the model's requirements.
    If available, it will use the calibrated pipeline instead of the regular pipeline.

    Parameters
    ----------
    model_file_path : str
        The file path where the trained model is saved.
    model_name : str
        The name of the model to be loaded.
    prediction_time : str, optional
        The time at which predictions are to be made. This is used to filter
        the data for the relevant time snapshot.
    model_only : bool, optional
        If True, only the model is returned. If False, both the prepared data
        and the model are returned. Default is False.
    df : pandas.DataFrame, optional
        The DataFrame containing the data to be used for inference. If not
        provided, data_path must be specified.
    data_path : str, optional
        The file path to a CSV file containing the data to be used for inference.
        Ignored if `df` is provided.
    single_snapshot_per_visit : bool, optional
        If True, only a single snapshot per visit is considered. Default is True.
    index_column : str, optional
        The name of the index column in the data. Default is 'snapshot_id'.
    sort_columns : list of str, optional
        The columns to sort the data by. Defaults to
        ["visit_number", "snapshot_date", "prediction_time"] when None.
    eval_columns : list of str, optional
        The columns that require literal evaluation of their content when loading from csv.
        Defaults to ["prediction_time", "consultation_sequence", "final_sequence"] when None.
    exclude_from_training_data : list of str, optional
        The columns to be excluded from the training data. Defaults to
        ["visit_number", "snapshot_date", "prediction_time"] when None.

    Returns
    -------
    model : object
        The loaded model (calibrated pipeline if available, otherwise regular pipeline).
    X_test : pandas.DataFrame, optional
        The features prepared for testing, returned only if model_only is False.
    y_test : pandas.Series, optional
        The labels corresponding to X_test, returned only if model_only is False.

    Raises
    ------
    KeyError
        If the 'training_validation_test' column is not found in the provided DataFrame.

    Notes
    -----
    - Either `df` or `data_path` must be provided. If neither is provided or if `df`
      is empty, the function will print an error message and return None.
    - The function will automatically use a calibrated pipeline if one is available
      in the model, otherwise it will fall back to the regular pipeline.
    """
    # Avoid mutable default arguments (shared across calls): resolve the
    # documented list defaults here instead of in the signature.
    if sort_columns is None:
        sort_columns = ["visit_number", "snapshot_date", "prediction_time"]
    if eval_columns is None:
        eval_columns = ["prediction_time", "consultation_sequence", "final_sequence"]
    if exclude_from_training_data is None:
        exclude_from_training_data = [
            "visit_number",
            "snapshot_date",
            "prediction_time",
        ]

    # retrieve model trained for this time of day
    model = load_saved_model(model_file_path, model_name, prediction_time)

    # Use calibrated pipeline if available, otherwise use regular pipeline
    if hasattr(model, "calibrated_pipeline") and model.calibrated_pipeline is not None:
        pipeline = model.calibrated_pipeline
    else:
        pipeline = model.pipeline

    if model_only:
        return pipeline

    if data_path:
        df = data_from_csv(data_path, index_column, sort_columns, eval_columns)
    elif df is None or df.empty:
        print("Please supply a dataset if not passing a data path")
        return None

    try:
        # Inference runs only on the held-out test split.
        test_df = (
            df[df.training_validation_test == "test"]
            .drop(columns="training_validation_test")
            .copy()
        )
    except KeyError:
        print("Column training_validation_test not found in dataframe")
        return None

    X_test, y_test = prepare_patient_snapshots(
        test_df,
        prediction_time,
        exclude_from_training_data,
        single_snapshot_per_visit,
    )

    return X_test, y_test, pipeline


def prepare_group_snapshot_dict(df, start_dt=None, end_dt=None):
"""
Prepares a dictionary mapping snapshot dates to their corresponding snapshot indices.
Expand Down
2 changes: 0 additions & 2 deletions tests/test_create_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def create_spec_model(df, apply_special_category_filtering):
admit_col="is_admitted",
)
model.fit(df)
# full_path = self.model_file_path / str("ed_specialty.joblib")
# joblib.dump(model, full_path)

return model

Expand Down