Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/patientflow/train/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@
)


def _is_string_like_column(series: Series) -> bool:
    """Check whether a Series holds string-like values that need encoding.

    Uses an exclusion-based approach so that any current or future
    string-like dtype (``object``, ``StringDtype``, ``ArrowDtype("string")``,
    ``CategoricalDtype`` with string categories, etc.) is detected without
    needing to enumerate each one explicitly.

    Parameters
    ----------
    series : Series
        Column to inspect. A categorical column is judged by the dtype of
        its categories rather than by the categorical wrapper itself.

    Returns
    -------
    bool
        ``False`` when the values (or categories) are numeric, boolean,
        datetime or timedelta; ``True`` for everything else.
    """
    values = series
    if isinstance(series.dtype, pd.CategoricalDtype):
        # Look through the categorical wrapper at what the categories hold.
        values = series.cat.categories.to_series()

    dtype_checks = pd.api.types
    return not (
        dtype_checks.is_numeric_dtype(values)
        or dtype_checks.is_bool_dtype(values)
        or dtype_checks.is_datetime64_any_dtype(values)
        # pandas does not count durations as numeric or datetime, so they
        # must be excluded explicitly or they would be reported as strings.
        or dtype_checks.is_timedelta64_dtype(values)
    )


class FeatureColumnTransformer(BaseEstimator, TransformerMixin):
"""
Ensure that input data has exactly the columns seen during training.
Expand Down Expand Up @@ -118,7 +135,7 @@ def fit(
default = 0.0
elif pd.api.types.is_datetime64_any_dtype(series):
default = pd.NaT
elif series.dtype == "object":
elif _is_string_like_column(series):
mode = series.mode(dropna=True)
default = mode.iloc[0] if not mode.empty else "Unknown"
else:
Expand Down Expand Up @@ -328,8 +345,7 @@ def create_column_transformer(
# Keep boolean columns as single 0/1 features rather than one-hot encoding
elif df[col].dtype == "bool":
transformers.append((col, "passthrough", [col]))
# One-hot encode string/object categoricals
elif df[col].dtype == "object":
elif _is_string_like_column(df[col]):
transformers.append((col, OneHotEncoder(handle_unknown="ignore"), [col]))
# Keep non-boolean binary (0/1) columns as single features
elif df[col].nunique() == 2:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,42 @@ def test_invalid_parameters(self):
single_snapshot_per_visit=True,
)

def test_string_dtype_columns(self):
    """Training succeeds when categorical features are stored as StringDtype."""
    string_visits = self.train_visits.copy()
    # Re-cast the two categorical feature columns to the pandas string dtype.
    for column in ("sex", "arrival_method"):
        string_visits[column] = string_visits[column].astype("string")

    trained = train_classifier(
        train_visits=string_visits,
        valid_visits=string_visits.copy(),
        prediction_time=self.prediction_time,
        exclude_from_training_data=self.exclude_from_training_data,
        grid=self.grid,
        ordinal_mappings=self.ordinal_mappings,
        visit_col="visit_number",
    )

    # A fitted result with a pipeline means the dtype was handled end to end.
    self.assertIsInstance(trained, TrainedClassifier)
    self.assertIsNotNone(trained.pipeline)

def test_categorical_dtype_columns(self):
    """Training succeeds when categorical features are stored as CategoricalDtype."""
    categorical_visits = self.train_visits.copy()
    # Re-cast the two categorical feature columns to the pandas category dtype.
    for column in ("sex", "arrival_method"):
        categorical_visits[column] = categorical_visits[column].astype("category")

    trained = train_classifier(
        train_visits=categorical_visits,
        valid_visits=categorical_visits.copy(),
        prediction_time=self.prediction_time,
        exclude_from_training_data=self.exclude_from_training_data,
        grid=self.grid,
        ordinal_mappings=self.ordinal_mappings,
        visit_col="visit_number",
    )

    # A fitted result with a pipeline means the dtype was handled end to end.
    self.assertIsInstance(trained, TrainedClassifier)
    self.assertIsNotNone(trained.pipeline)

def test_feature_importance(self):
"""Test that feature importance is captured when available."""
model = train_classifier(
Expand Down
Loading