|
8 | 8 |
|
9 | 9 | Functions |
10 | 10 | --------- |
11 | | -prepare_for_inference(model_file_path, model_name, prediction_time=None, |
12 | | - model_only=False, df=None, data_path=None, |
13 | | - single_snapshot_per_visit=True, index_column='snapshot_id', |
14 | | - sort_columns=None, eval_columns=None, |
15 | | - exclude_from_training_data=None) |
16 | | - Loads a model and prepares data for inference. |
17 | | -
|
18 | | -select_one_snapshot_per_visit(df, visit_col, seed=42) |
| 11 | +git select_one_snapshot_per_visit(df, visit_col, seed=42) |
19 | 12 | Selects one snapshot per visit based on a random number and returns the filtered DataFrame. |
20 | 13 |
|
21 | 14 | prepare_patient_snapshots(df, prediction_time, exclude_columns, single_snapshot_per_visit=True) |
|
31 | 24 | import pandas as pd |
32 | 25 | import numpy as np |
33 | 26 | import random |
34 | | -from patientflow.load import load_saved_model, get_dict_cols, data_from_csv |
| 27 | +from patientflow.load import get_dict_cols |
35 | 28 | from datetime import datetime, date |
36 | 29 |
|
37 | 30 |
|
@@ -517,118 +510,6 @@ def prepare_patient_snapshots( |
517 | 510 | return df_tod, y |
518 | 511 |
|
519 | 512 |
|
520 | | -def prepare_for_inference( |
521 | | - model_file_path, |
522 | | - model_name, |
523 | | - prediction_time=None, |
524 | | - model_only=False, |
525 | | - df=None, |
526 | | - data_path=None, |
527 | | - single_snapshot_per_visit=True, |
528 | | - index_column="snapshot_id", |
529 | | - sort_columns=["visit_number", "snapshot_date", "prediction_time"], |
530 | | - eval_columns=["prediction_time", "consultation_sequence", "final_sequence"], |
531 | | - exclude_from_training_data=["visit_number", "snapshot_date", "prediction_time"], |
532 | | -): |
533 | | - """ |
534 | | - Load a trained model and prepare data for making predictions. |
535 | | -
|
536 | | - This function retrieves a trained model from a specified file path and, |
537 | | - if requested, prepares the data required for inference. The data can be |
538 | | - provided either as a DataFrame or as a file path to a CSV file. The function |
539 | | - allows filtering and processing of the data to match the model's requirements. |
540 | | - If available, it will use the calibrated pipeline instead of the regular pipeline. |
541 | | -
|
542 | | - Parameters |
543 | | - ---------- |
544 | | - model_file_path : str |
545 | | - The file path where the trained model is saved. |
546 | | - model_name : str |
547 | | - The name of the model to be loaded. |
548 | | - prediction_time : str, optional |
549 | | - The time at which predictions are to be made. This is used to filter |
550 | | - the data for the relevant time snapshot. |
551 | | - model_only : bool, optional |
552 | | - If True, only the model is returned. If False, both the prepared data |
553 | | - and the model are returned. Default is False. |
554 | | - df : pandas.DataFrame, optional |
555 | | - The DataFrame containing the data to be used for inference. If not |
556 | | - provided, data_path must be specified. |
557 | | - data_path : str, optional |
558 | | - The file path to a CSV file containing the data to be used for inference. |
559 | | - Ignored if `df` is provided. |
560 | | - single_snapshot_per_visit : bool, optional |
561 | | - If True, only a single snapshot per visit is considered. Default is True. |
562 | | - index_column : str, optional |
563 | | - The name of the index column in the data. Default is 'snapshot_id'. |
564 | | - sort_columns : list of str, optional |
565 | | - The columns to sort the data by. Default is ["visit_number", "snapshot_date", "prediction_time"]. |
566 | | - eval_columns : list of str, optional |
567 | | - The columns that require literal evaluation of their content when loading from csv. |
568 | | - Default is ["prediction_time", "consultation_sequence", "final_sequence"]. |
569 | | - exclude_from_training_data : list of str, optional |
570 | | - The columns to be excluded from the training data. Default is ["visit_number", "snapshot_date", "prediction_time"]. |
571 | | -
|
572 | | - Returns |
573 | | - ------- |
574 | | - model : object |
575 | | - The loaded model (calibrated pipeline if available, otherwise regular pipeline). |
576 | | - X_test : pandas.DataFrame, optional |
577 | | - The features prepared for testing, returned only if model_only is False. |
578 | | - y_test : pandas.Series, optional |
579 | | - The labels corresponding to X_test, returned only if model_only is False. |
580 | | -
|
581 | | - Raises |
582 | | - ------ |
583 | | - KeyError |
584 | | - If the 'training_validation_test' column is not found in the provided DataFrame. |
585 | | -
|
586 | | - Notes |
587 | | - ----- |
588 | | - - Either `df` or `data_path` must be provided. If neither is provided or if `df` |
589 | | - is empty, the function will print an error message and return None. |
590 | | - - The function will automatically use a calibrated pipeline if one is available |
591 | | - in the model, otherwise it will fall back to the regular pipeline. |
592 | | - """ |
593 | | - |
594 | | - # retrieve model trained for this time of day |
595 | | - model = load_saved_model(model_file_path, model_name, prediction_time) |
596 | | - |
597 | | - # Use calibrated pipeline if available, otherwise use regular pipeline |
598 | | - if hasattr(model, "calibrated_pipeline") and model.calibrated_pipeline is not None: |
599 | | - pipeline = model.calibrated_pipeline |
600 | | - else: |
601 | | - pipeline = model.pipeline |
602 | | - |
603 | | - if model_only: |
604 | | - return pipeline |
605 | | - |
606 | | - if data_path: |
607 | | - df = data_from_csv(data_path, index_column, sort_columns, eval_columns) |
608 | | - elif df is None or df.empty: |
609 | | - print("Please supply a dataset if not passing a data path") |
610 | | - return None |
611 | | - |
612 | | - try: |
613 | | - test_df = ( |
614 | | - df[df.training_validation_test == "test"] |
615 | | - .drop(columns="training_validation_test") |
616 | | - .copy() |
617 | | - ) |
618 | | - except KeyError: |
619 | | - print("Column training_validation_test not found in dataframe") |
620 | | - return None |
621 | | - |
622 | | - X_test, y_test = prepare_patient_snapshots( |
623 | | - test_df, |
624 | | - prediction_time, |
625 | | - exclude_from_training_data, |
626 | | - single_snapshot_per_visit, |
627 | | - ) |
628 | | - |
629 | | - return X_test, y_test, pipeline |
630 | | - |
631 | | - |
632 | 513 | def prepare_group_snapshot_dict(df, start_dt=None, end_dt=None): |
633 | 514 | """ |
634 | 515 | Prepares a dictionary mapping snapshot dates to their corresponding snapshot indices. |
|
0 commit comments