
Commit: no pickle
lukaszkolodziejczyk committed Jan 20, 2025
1 parent 7b9fdcd · commit 471e1c6
Showing 1 changed file with 30 additions and 18 deletions.
mostlyai/qa/_filesystem.py (48 changes: 30 additions & 18 deletions)
@@ -14,17 +14,16 @@

 import hashlib
 import json
-import pickle
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Literal
+import skops.io as sio
 
 import numpy as np
 import pandas as pd
 from plotly import graph_objs as go
 from sklearn.decomposition import PCA
 
 
 _OLD_COL_PREFIX = r"^(tgt|ctx|nxt)(\.|⁝)"
 _NEW_COL_PREFIX = r"\1::"

@@ -96,15 +95,15 @@ def __init__(self, path: str | Path):
         self.path = Path(path)
         self.early_exit_path = self.path / "_EARLY_EXIT"
         self.meta_path = self.path / "meta.json"
-        self.bins_path = self.path / "bins.pickle"
+        self.bins_dir = self.path / "bins"
         self.correlations_path = self.path / "correlations.parquet"
         self.univariate_accuracies_path = self.path / "univariate_accuracies.parquet"
         self.bivariate_accuracies_path = self.path / "bivariate_accuracies.parquet"
-        self.numeric_kdes_uni_path = self.path / "numeric_kdes_uni.pickle"
-        self.categorical_counts_uni_path = self.path / "categorical_counts_uni.pickle"
+        self.numeric_kdes_uni_dir = self.path / "numeric_kdes_uni"
+        self.categorical_counts_uni_dir = self.path / "categorical_counts_uni"
         self.bin_counts_uni_path = self.path / "bin_counts_uni.parquet"
         self.bin_counts_biv_path = self.path / "bin_counts_biv.parquet"
-        self.pca_model_path = self.path / "pca_model.pickle"
+        self.pca_model_path = self.path / "pca_model.skops"
         self.trn_pca_path = self.path / "trn_pca.npy"
         self.hol_pca_path = self.path / "hol_pca.npy"
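
In effect, each pickled artifact becomes either a directory of small parquet files or a skops archive, while the artifacts that were already parquet or npy keep their names. The resulting workspace layout, sketched from the paths in this commit (illustrative, not exhaustive):

    meta.json
    bins/                       00000.parquet, 00001.parquet, ...  (one file per column)
    numeric_kdes_uni/           00000.parquet, ...
    categorical_counts_uni/     00000.parquet, ...
    correlations.parquet
    univariate_accuracies.parquet
    bivariate_accuracies.parquet
    bin_counts_uni.parquet
    bin_counts_biv.parquet
    pca_model.skops             (was pca_model.pickle)
    trn_pca.npy
    hol_pca.npy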

@@ -124,10 +123,14 @@ def load_meta(self) -> dict:

     def store_bins(self, bins: dict[str, list]) -> None:
         df = pd.Series(bins).to_frame("bins").reset_index().rename(columns={"index": "column"})
-        df.to_pickle(self.bins_path)
+        self.bins_dir.mkdir(exist_ok=True, parents=True)
+        for i, row in df.iterrows():
+            row_df = pd.DataFrame([row]).explode("bins")
+            row_df.to_parquet(self.bins_dir / f"{i:05}.parquet")
 
     def load_bins(self) -> dict[str, list]:
-        df = pd.read_pickle(self.bins_path)
+        df = pd.concat([pd.read_parquet(p) for p in sorted(self.bins_dir.glob("*.parquet"))])
+        df = df.groupby("column", sort=False).agg(list).reset_index()
         # harmonise older prefix formats to <prefix>:: for compatibility with older versions
         df["column"] = df["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
         return df.set_index("column")["bins"].to_dict()
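
A minimal, self-contained sketch of that store/load round trip outside the class, following the same explode/groupby pattern; the column names and bin edges below are invented for illustration, and it assumes pandas plus a parquet engine such as pyarrow:

from pathlib import Path
from tempfile import TemporaryDirectory

import pandas as pd

bins = {"tgt::age": [18, 30, 45, 60], "tgt::income": [0, 25_000, 50_000]}  # hypothetical columns

with TemporaryDirectory() as tmp:
    bins_dir = Path(tmp) / "bins"
    bins_dir.mkdir(parents=True)

    # store: one exploded single-column frame per entry -> 00000.parquet, 00001.parquet, ...
    df = pd.Series(bins).to_frame("bins").reset_index().rename(columns={"index": "column"})
    for i, row in df.iterrows():
        pd.DataFrame([row]).explode("bins").to_parquet(bins_dir / f"{i:05}.parquet")

    # load: concatenate the per-column files and re-aggregate each column's bins into a list
    df = pd.concat([pd.read_parquet(p) for p in sorted(bins_dir.glob("*.parquet"))])
    restored = df.groupby("column", sort=False).agg(list)["bins"].to_dict()

# values come back as numpy integers, hence the int() conversion before comparing
assert {col: [int(b) for b in edges] for col, edges in restored.items()} == bins
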
Expand Down Expand Up @@ -166,10 +169,14 @@ def store_numeric_uni_kdes(self, trn_kdes: dict[str, pd.Series]) -> None:
             [(column, list(xy.index), list(xy.values)) for column, xy in trn_kdes.items()],
             columns=["column", "x", "y"],
         )
-        trn_kdes.to_pickle(self.numeric_kdes_uni_path)
+        self.numeric_kdes_uni_dir.mkdir(exist_ok=True, parents=True)
+        for i, row in trn_kdes.iterrows():
+            row_df = pd.DataFrame([row]).explode(["x", "y"])
+            row_df.to_parquet(self.numeric_kdes_uni_dir / f"{i:05}.parquet")
 
     def load_numeric_uni_kdes(self) -> dict[str, pd.Series]:
-        trn_kdes = pd.read_pickle(self.numeric_kdes_uni_path)
+        trn_kdes = pd.concat([pd.read_parquet(p) for p in sorted(self.numeric_kdes_uni_dir.glob("*.parquet"))])
+        trn_kdes = trn_kdes.groupby("column", sort=False).agg(list).reset_index()
         # harmonise older prefix formats to <prefix>:: for compatibility with older versions
         trn_kdes["column"] = trn_kdes["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
         trn_kdes = {
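
Note that explode(["x", "y"]) expands both list columns in lock-step and therefore requires the two lists in a row to have equal length (multi-column explode exists since pandas 1.3); that holds here because x and y are the grid points and density values of the same KDE. A tiny illustration with made-up numbers:

import pandas as pd

row = pd.DataFrame([{"column": "tgt::age", "x": [18.0, 40.0, 65.0], "y": [0.01, 0.03, 0.01]}])
# expands into three rows, one (x, y) pair each, with the column name repeated
print(row.explode(["x", "y"]))
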
@@ -187,10 +194,16 @@ def store_categorical_uni_counts(self, trn_cnts_uni: dict[str, pd.Series]) -> None:
             [(column, list(cat_counts.index), list(cat_counts.values)) for column, cat_counts in trn_cnts_uni.items()],
             columns=["column", "cat", "count"],
         )
-        trn_cnts_uni.to_pickle(self.categorical_counts_uni_path)
+        self.categorical_counts_uni_dir.mkdir(exist_ok=True, parents=True)
+        for i, row in trn_cnts_uni.iterrows():
+            row_df = pd.DataFrame([row]).explode(["cat", "count"])
+            row_df.to_parquet(self.categorical_counts_uni_dir / f"{i:05}.parquet")
 
     def load_categorical_uni_counts(self) -> dict[str, pd.Series]:
-        trn_cnts_uni = pd.read_pickle(self.categorical_counts_uni_path)
+        trn_cnts_uni = pd.concat(
+            [pd.read_parquet(p) for p in sorted(self.categorical_counts_uni_dir.glob("*.parquet"))]
+        )
+        trn_cnts_uni = trn_cnts_uni.groupby("column", sort=False).agg(list).reset_index()
         # harmonise older prefix formats to <prefix>:: for compatibility with older versions
         trn_cnts_uni["column"] = trn_cnts_uni["column"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
         trn_cnts_uni = {
@@ -266,14 +279,13 @@ def biv_multi_index(bin, col1, col2):
         return trn_cnts_uni, trn_cnts_biv
 
     def store_pca_model(self, pca_model: PCA):
-        with self.pca_model_path.open("wb") as file:
-            pickle.dump(pca_model, file)
+        sio.dump(pca_model, self.pca_model_path)
 
     def load_pca_model(self) -> PCA | None:
         if not self.pca_model_path.exists():
             return None
-        with self.pca_model_path.open("rb") as file:
-            return pickle.load(file)
+        unknown_types = sio.get_untrusted_types(file=self.pca_model_path)
+        if unknown_types:
+            raise ValueError(f"Unknown types found in file {self.pca_model_path}: {unknown_types}")
+        return sio.load(self.pca_model_path)
 
     def store_trn_hol_pcas(self, trn_pca: np.ndarray, hol_pca: np.ndarray | None):
         np.save(self.trn_pca_path, trn_pca)
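
For the PCA model, skops takes over from pickle: sio.dump writes a .skops archive, and on load the file is first audited with get_untrusted_types, so a file containing unexpected types is rejected instead of being blindly deserialized. A minimal sketch of that cycle outside the class, fitted on random data purely for illustration (assumes scikit-learn and skops are installed):

import numpy as np
import skops.io as sio
from sklearn.decomposition import PCA

pca_model = PCA(n_components=2).fit(np.random.rand(100, 5))
sio.dump(pca_model, "pca_model.skops")

# get_untrusted_types lists anything outside skops' trusted defaults;
# an empty result means the file can be loaded without passing trusted=[...]
unknown_types = sio.get_untrusted_types(file="pca_model.skops")
if unknown_types:
    raise ValueError(f"Unknown types found in file pca_model.skops: {unknown_types}")
restored = sio.load("pca_model.skops")
print(restored.explained_variance_ratio_)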
