Skip to content

Commit

Permalink
trim down ProgressCallbackWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszkolodziejczyk committed Jan 20, 2025
1 parent 7b9fdcd commit a72c192
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 27 deletions.
27 changes: 2 additions & 25 deletions mostlyai/qa/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# limitations under the License.

import logging
from functools import partial
from typing import Protocol
from collections.abc import Callable

import pandas as pd
from rich.progress import Progress

from mostlyai.qa._filesystem import Statistics

Expand Down Expand Up @@ -79,30 +76,11 @@ def __call__(self, total: float | None = None, completed: float | None = None, *


class ProgressCallbackWrapper:
@staticmethod
def _wrap_progress_callback(
update_progress: ProgressCallback | None = None, **kwargs
) -> tuple[ProgressCallback, Callable]:
if not update_progress:
rich_progress = Progress()
rich_progress.start()
task_id = rich_progress.add_task(**kwargs)
update_progress = partial(rich_progress.update, task_id=task_id)
else:
rich_progress = None

def teardown_progress():
if rich_progress:
rich_progress.refresh()
rich_progress.stop()

return update_progress, teardown_progress

def update(self, total: float | None = None, completed: float | None = None, **kwargs) -> None:
self._update_progress(total=total, completed=completed, **kwargs)

def __init__(self, update_progress: ProgressCallback | None = None, **kwargs):
self._update_progress, self._teardown_progress = self._wrap_progress_callback(update_progress, **kwargs)
def __init__(self, update_progress: ProgressCallback | None = None):
self._update_progress = update_progress if update_progress is not None else (lambda *args, **kwargs: None)

def __enter__(self):
self._update_progress(completed=0, total=1)
Expand All @@ -111,7 +89,6 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self._update_progress(completed=1, total=1)
self._teardown_progress()


def check_min_sample_size(size: int, min: int, type: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion mostlyai/qa/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def report(

with (
TemporaryWorkspace() as workspace,
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
ProgressCallbackWrapper(update_progress) as progress,
):
# ensure all columns are present and in the same order as training data
syn_tgt_data = syn_tgt_data[trn_tgt_data.columns]
Expand Down
2 changes: 1 addition & 1 deletion mostlyai/qa/reporting_from_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def report_from_statistics(

with (
TemporaryWorkspace() as workspace,
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
ProgressCallbackWrapper(update_progress) as progress,
):
# prepare report_path
if report_path is None:
Expand Down

0 comments on commit a72c192

Please sign in to comment.