diff --git a/src/patientflow/train/classifiers.py b/src/patientflow/train/classifiers.py index 2d0120d..ef50387 100644 --- a/src/patientflow/train/classifiers.py +++ b/src/patientflow/train/classifiers.py @@ -60,6 +60,23 @@ ) +def _is_string_like_column(series: Series) -> bool: + """Check whether a Series holds string-like values that need encoding. + + Uses an exclusion-based approach so that any current or future + string-like dtype (``object``, ``StringDtype``, ``ArrowDtype("string")``, + ``CategoricalDtype`` with string categories, etc.) is detected without + needing to enumerate each one explicitly. + """ + if isinstance(series.dtype, pd.CategoricalDtype): + series = series.cat.categories.to_series() + return not ( + pd.api.types.is_numeric_dtype(series) + or pd.api.types.is_bool_dtype(series) + or pd.api.types.is_datetime64_any_dtype(series) + ) + + class FeatureColumnTransformer(BaseEstimator, TransformerMixin): """ Ensure that input data has exactly the columns seen during training. @@ -118,7 +135,7 @@ def fit( default = 0.0 elif pd.api.types.is_datetime64_any_dtype(series): default = pd.NaT - elif series.dtype == "object": + elif _is_string_like_column(series): mode = series.mode(dropna=True) default = mode.iloc[0] if not mode.empty else "Unknown" else: @@ -328,8 +345,7 @@ def create_column_transformer( # Keep boolean columns as single 0/1 features rather than one-hot encoding elif df[col].dtype == "bool": transformers.append((col, "passthrough", [col])) - # One-hot encode string/object categoricals - elif df[col].dtype == "object": + elif _is_string_like_column(df[col]): transformers.append((col, OneHotEncoder(handle_unknown="ignore"), [col])) # Keep non-boolean binary (0/1) columns as single features elif df[col].nunique() == 2: diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py index 881bb50..36083f6 100644 --- a/tests/test_classifiers.py +++ b/tests/test_classifiers.py @@ -214,6 +214,42 @@ def test_invalid_parameters(self): single_snapshot_per_visit=True, ) + def test_string_dtype_columns(self): + """Test training when categorical columns use pandas StringDtype.""" + visits = self.train_visits.copy() + visits["sex"] = visits["sex"].astype("string") + visits["arrival_method"] = visits["arrival_method"].astype("string") + + model = train_classifier( + train_visits=visits, + valid_visits=visits.copy(), + prediction_time=self.prediction_time, + exclude_from_training_data=self.exclude_from_training_data, + grid=self.grid, + ordinal_mappings=self.ordinal_mappings, + visit_col="visit_number", + ) + self.assertIsInstance(model, TrainedClassifier) + self.assertIsNotNone(model.pipeline) + + def test_categorical_dtype_columns(self): + """Test training when categorical columns use CategoricalDtype.""" + visits = self.train_visits.copy() + visits["sex"] = visits["sex"].astype("category") + visits["arrival_method"] = visits["arrival_method"].astype("category") + + model = train_classifier( + train_visits=visits, + valid_visits=visits.copy(), + prediction_time=self.prediction_time, + exclude_from_training_data=self.exclude_from_training_data, + grid=self.grid, + ordinal_mappings=self.ordinal_mappings, + visit_col="visit_number", + ) + self.assertIsInstance(model, TrainedClassifier) + self.assertIsNotNone(model.pipeline) + def test_feature_importance(self): """Test that feature importance is captured when available.""" model = train_classifier(