Skip to content

Commit

Permalink
fixed issue with binning bool cols and with constant datetime cols
Browse files Browse the repository at this point in the history
  • Loading branch information
mplatzer authored Dec 7, 2024
1 parent fa6dfc9 commit ed6e41c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mostlyai/qa/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ def _clip(col, bins):

if col.nunique() == 1:
# ensure 2 breaks for single-valued columns
val = col.iloc[0]
val = col.dropna().iloc[0]
upper_limit = [val + np.timedelta64(1, "D")] if not pd.isna(val) else []
breaks = [val] + upper_limit
else:
Expand Down Expand Up @@ -1115,6 +1115,7 @@ def bin_non_categorical(


def bin_categorical(col: pd.Series, bins: int | list[str]) -> tuple[pd.Categorical, list[str]]:
col = col.astype("string[pyarrow]")
col = col.fillna(NA_BIN)
col = col.replace("", EMPTY_BIN)
# determine top values, if not provided
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
trim_labels,
calculate_correlations,
plot_store_correlation_matrices,
bin_categorical,
)
from mostlyai.qa.sampling import pull_data_for_accuracy, sample_two_consecutive_rows
from mostlyai.qa.common import (
Expand Down Expand Up @@ -496,6 +497,14 @@ def test_num_col_nans_only(self):
df_counts = df["nans"].value_counts().to_dict()
assert df_counts["(n/a)"] == 10

def test_bin_categorical(self):
    """Check that bin_categorical returns one entry per input row."""
    top_n = 5

    # string column: two frequent values plus a single rare one
    strings = pd.Series(["a", "b"] * 50 + ["x"])
    binned_strings, _ = bin_categorical(strings, top_n)
    assert len(binned_strings) == 101

    # object-dtype column mixing booleans with missing values
    bools_with_nans = pd.Series([True, False] * 50 + [np.nan] * 100, dtype="object")
    binned_bools, _ = bin_categorical(bools_with_nans, top_n)
    assert len(binned_bools) == 200

def test_bin_numeric(self):
# test several edge cases
cases = [
Expand Down Expand Up @@ -534,6 +543,10 @@ def test_bin_datetime(self):
),
["⪰ 2023-01-30 13:00:00.333000"] * 20,
), # two values
(
pd.Series([pd.NaT, "2024-11-20"], dtype="datetime64[ns]"),
["(n/a)", "⪰ 2024-Nov-20"],
), # single value with leading N/A
]

for col, expected in cases:
Expand Down

0 comments on commit ed6e41c

Please sign in to comment.