
Commit

Merge branch 'main' into msd-xxx-no-rich
lukaszkolodziejczyk authored Jan 20, 2025
2 parents 3adc0ff + b4e9959 commit d314119
Showing 3 changed files with 63 additions and 26 deletions.
52 changes: 36 additions & 16 deletions mostlyai/qa/_filesystem.py
@@ -14,17 +14,16 @@

import hashlib
import json
-import pickle
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Literal
+import skops.io as sio

import numpy as np
import pandas as pd
from plotly import graph_objs as go
from sklearn.decomposition import PCA


_OLD_COL_PREFIX = r"^(tgt|ctx|nxt)(\.|⁝)"
_NEW_COL_PREFIX = r"\1::"

@@ -96,15 +95,15 @@ def __init__(self, path: str | Path):
self.path = Path(path)
self.early_exit_path = self.path / "_EARLY_EXIT"
self.meta_path = self.path / "meta.json"
-        self.bins_path = self.path / "bins.pickle"
+        self.bins_dir = self.path / "bins"
self.correlations_path = self.path / "correlations.parquet"
self.univariate_accuracies_path = self.path / "univariate_accuracies.parquet"
self.bivariate_accuracies_path = self.path / "bivariate_accuracies.parquet"
-        self.numeric_kdes_uni_path = self.path / "numeric_kdes_uni.pickle"
-        self.categorical_counts_uni_path = self.path / "categorical_counts_uni.pickle"
+        self.numeric_kdes_uni_dir = self.path / "numeric_kdes_uni"
+        self.categorical_counts_uni_dir = self.path / "categorical_counts_uni"
self.bin_counts_uni_path = self.path / "bin_counts_uni.parquet"
self.bin_counts_biv_path = self.path / "bin_counts_biv.parquet"
-        self.pca_model_path = self.path / "pca_model.pickle"
+        self.pca_model_path = self.path / "pca_model.skops"
self.trn_pca_path = self.path / "trn_pca.npy"
self.hol_pca_path = self.path / "hol_pca.npy"

@@ -124,10 +123,16 @@ def load_meta(self) -> dict:

def store_bins(self, bins: dict[str, list]) -> None:
df = pd.Series(bins).to_frame("bins").reset_index().rename(columns={"index": "column"})
-        df.to_pickle(self.bins_path)
+        self.bins_dir.mkdir(exist_ok=True, parents=True)
+        empty_df = pd.DataFrame(columns=["column", "bins"])
+        empty_df.to_parquet(self.bins_dir / "empty.parquet")
+        for i, row in df.iterrows():
+            row_df = pd.DataFrame([row]).explode("bins")
+            row_df.to_parquet(self.bins_dir / f"{i:05}.parquet")

def load_bins(self) -> dict[str, list]:
-        df = pd.read_pickle(self.bins_path)
+        df = pd.concat([pd.read_parquet(p) for p in sorted(self.bins_dir.glob("*.parquet"))])
+        df = df.groupby("column", sort=False).agg(list).reset_index()
# harmonise older prefix formats to <prefix>:: for compatibility with older versions
df["column"] = df["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
return df.set_index("column")["bins"].to_dict()
@@ -166,10 +171,16 @@ def store_numeric_uni_kdes(self, trn_kdes: dict[str, pd.Series]) -> None:
[(column, list(xy.index), list(xy.values)) for column, xy in trn_kdes.items()],
columns=["column", "x", "y"],
)
-        trn_kdes.to_pickle(self.numeric_kdes_uni_path)
+        self.numeric_kdes_uni_dir.mkdir(exist_ok=True, parents=True)
+        empty_df = pd.DataFrame(columns=["column", "x", "y"])
+        empty_df.to_parquet(self.numeric_kdes_uni_dir / "empty.parquet")
+        for i, row in trn_kdes.iterrows():
+            row_df = pd.DataFrame([row]).explode(["x", "y"])
+            row_df.to_parquet(self.numeric_kdes_uni_dir / f"{i:05}.parquet")

def load_numeric_uni_kdes(self) -> dict[str, pd.Series]:
-        trn_kdes = pd.read_pickle(self.numeric_kdes_uni_path)
+        trn_kdes = pd.concat([pd.read_parquet(p) for p in sorted(self.numeric_kdes_uni_dir.glob("*.parquet"))])
+        trn_kdes = trn_kdes.groupby("column", sort=False).agg(list).reset_index()
# harmonise older prefix formats to <prefix>:: for compatibility with older versions
trn_kdes["column"] = trn_kdes["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
trn_kdes = {
Expand All @@ -187,10 +198,18 @@ def store_categorical_uni_counts(self, trn_cnts_uni: dict[str, pd.Series]) -> No
[(column, list(cat_counts.index), list(cat_counts.values)) for column, cat_counts in trn_cnts_uni.items()],
columns=["column", "cat", "count"],
)
-        trn_cnts_uni.to_pickle(self.categorical_counts_uni_path)
+        self.categorical_counts_uni_dir.mkdir(exist_ok=True, parents=True)
+        empty_df = pd.DataFrame(columns=["column", "cat", "count"])
+        empty_df.to_parquet(self.categorical_counts_uni_dir / "empty.parquet")
+        for i, row in trn_cnts_uni.iterrows():
+            row_df = pd.DataFrame([row]).explode(["cat", "count"])
+            row_df.to_parquet(self.categorical_counts_uni_dir / f"{i:05}.parquet")

def load_categorical_uni_counts(self) -> dict[str, pd.Series]:
-        trn_cnts_uni = pd.read_pickle(self.categorical_counts_uni_path)
+        trn_cnts_uni = pd.concat(
+            [pd.read_parquet(p) for p in sorted(self.categorical_counts_uni_dir.glob("*.parquet"))]
+        )
+        trn_cnts_uni = trn_cnts_uni.groupby("column", sort=False).agg(list).reset_index()
# harmonise older prefix formats to <prefix>:: for compatibility with older versions
trn_cnts_uni["column"] = trn_cnts_uni["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
trn_cnts_uni = {
@@ -266,14 +285,15 @@ def biv_multi_index(bin, col1, col2):
return trn_cnts_uni, trn_cnts_biv

def store_pca_model(self, pca_model: PCA):
-        with self.pca_model_path.open("wb") as file:
-            pickle.dump(pca_model, file)
+        sio.dump(pca_model, self.pca_model_path)

def load_pca_model(self) -> PCA | None:
if not self.pca_model_path.exists():
return None
-        with self.pca_model_path.open("rb") as file:
-            return pickle.load(file)
+        unknown_types = sio.get_untrusted_types(file=self.pca_model_path)
+        if unknown_types:
+            raise ValueError(f"Unknown types found in file {self.pca_model_path}: {unknown_types}")
+        return sio.load(self.pca_model_path)

def store_trn_hol_pcas(self, trn_pca: np.ndarray, hol_pca: np.ndarray | None):
np.save(self.trn_pca_path, trn_pca)
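
Taken together, this file swaps one pickle per statistic for a directory of per-column Parquet part files (plus an `empty.parquet` that pins the schema), exploded on write and re-aggregated on read. Below is a minimal sketch of that round-trip under stated assumptions: the column names and the `qa_stats/bins` directory are hypothetical, and pandas with a Parquet engine such as pyarrow is assumed to be installed.

```python
from pathlib import Path

import pandas as pd

bins_dir = Path("qa_stats/bins")  # hypothetical output directory
bins_dir.mkdir(parents=True, exist_ok=True)

# store: explode each column's bin list into rows, one small parquet file per column
bins = {"tgt::age": [0, 18, 35, 65], "tgt::income": [0.0, 30_000.0, 60_000.0]}
df = pd.Series(bins).to_frame("bins").reset_index().rename(columns={"index": "column"})
for i, row in df.iterrows():
    pd.DataFrame([row]).explode("bins").to_parquet(bins_dir / f"{i:05}.parquet")

# load: concatenate the part files and fold the exploded rows back into one list per column
parts = [pd.read_parquet(p) for p in sorted(bins_dir.glob("*.parquet"))]
restored = (
    pd.concat(parts)
    .groupby("column", sort=False)
    .agg(list)
    .reset_index()
    .set_index("column")["bins"]
    .to_dict()
)
assert list(restored) == list(bins)  # columns come back in the original order
```

The `empty.parquet` written by the new code presumably keeps the load path well defined even when a statistic has no columns at all, since `pd.concat` would otherwise receive an empty list.
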
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"scikit-learn>=1.4.0",
"sentence-transformers>=3.1.0",
"rich>=13.9.4,<14",
"skops>=0.11.0",
]

[project.urls]
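
The new skops dependency backs the PCA-model change in `_filesystem.py` above: `skops.io` is a pickle alternative for scikit-learn estimators that lets the loader inspect untrusted types before deserializing. A small illustrative round-trip, assuming scikit-learn and skops>=0.11.0 are installed; the file name here is a stand-in for the report's `pca_model.skops`.

```python
import numpy as np
import skops.io as sio
from sklearn.decomposition import PCA

# fit and persist a small PCA model
pca = PCA(n_components=2).fit(np.random.rand(100, 5))
sio.dump(pca, "pca_model.skops")

# unlike pickle.load, loading is gated: inspect unknown types first, then load
untrusted = sio.get_untrusted_types(file="pca_model.skops")
if untrusted:
    raise ValueError(f"Unknown types found in file pca_model.skops: {untrusted}")
restored = sio.load("pca_model.skops")
```
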
36 changes: 26 additions & 10 deletions uv.lock

Some generated files are not rendered by default.
