Skip to content

Commit

Permalink
removing nested-asyncio from Model Registry client
Browse files Browse the repository at this point in the history
Signed-off-by: blublinsky <[email protected]>
  • Loading branch information
blublinsky committed Feb 14, 2025
1 parent 118d0f6 commit ec997c6
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 30 deletions.
102 changes: 102 additions & 0 deletions clients/python/src/model_registry/_async_task_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# from https://gist.github.com/blink1073/969aeba85f32c285235750626f2eadd8

"""
Copyright (c) 2022 Steven Silvester
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import asyncio
import atexit
from threading import Lock, Thread, current_thread
from typing import Any, Coroutine, Optional


class AsyncTaskRunner:
    """
    A singleton task runner that runs an asyncio event loop on a background thread.

    The loop and its daemon thread are created lazily on the first call to
    ``run``. Synchronous callers hand coroutines to that loop and block until
    the result is available.
    """

    __instance = None
    # Serialises lazy creation of the singleton so that two threads calling
    # get_instance() at the same time cannot both try to construct it.
    __instance_lock = Lock()

    @staticmethod
    def get_instance():
        """
        Get an AsyncTaskRunner (singleton)

        Returns:
            The process-wide AsyncTaskRunner, created on first use.
        """
        with AsyncTaskRunner.__instance_lock:
            if AsyncTaskRunner.__instance is None:
                # __init__ stores the new object in __instance
                AsyncTaskRunner()
        assert AsyncTaskRunner.__instance is not None
        return AsyncTaskRunner.__instance

    def __init__(self):
        """
        Initialize

        Raises:
            RuntimeError: if an instance already exists (use ``get_instance``).
        """
        # make sure it is a singleton
        if AsyncTaskRunner.__instance is not None:
            raise RuntimeError("This class is a singleton!")
        AsyncTaskRunner.__instance = self
        # initialize variables
        self.__io_loop: Optional[asyncio.AbstractEventLoop] = None
        self.__runner_thread: Optional[Thread] = None
        self.__lock = Lock()
        # register exit handler
        atexit.register(self._close)

    def _close(self):
        """
        Clean up. Ask the background loop to stop if it exists and is still open.

        Safe to call more than once and from any thread.
        """
        loop = self.__io_loop
        if loop is None or loop.is_closed():
            return
        try:
            # The loop runs on another thread: loop.stop() is not thread-safe
            # and would not wake a loop that is blocked waiting for I/O, so the
            # stop request is scheduled through call_soon_threadsafe instead.
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was closed between the check above and the call.
            pass

    def _runner(self) -> None:
        """
        Function to run in a thread: drive the loop until stopped, then close it.
        """
        loop = self.__io_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def run(self, coro: Coroutine) -> Any:
        """
        Synchronously run a coroutine on a background thread.

        Args:
            coro: The coroutine object to execute on the background loop.

        Returns:
            Whatever the coroutine returns.

        Raises:
            RuntimeError: if called from the background loop thread itself,
                which would otherwise block the loop and deadlock.
            Exception: any exception raised by the coroutine is re-raised here.
        """
        with self.__lock:
            if self.__io_loop is None:
                # If the asyncio loop does not exist, create it and start the
                # daemon thread that drives it
                self.__io_loop = asyncio.new_event_loop()
                self.__runner_thread = Thread(target=self._runner, daemon=True)
                self.__runner_thread.start()
            loop = self.__io_loop
            runner_thread = self.__runner_thread

        if runner_thread is current_thread():
            # Waiting here would block the only thread able to run the coroutine.
            coro.close()  # avoid a "coroutine was never awaited" warning
            raise RuntimeError(
                "AsyncTaskRunner.run() cannot be called from the runner's own "
                "event loop thread; await the coroutine instead"
            )

        # Schedule the coroutine on the background loop in a thread-safe way.
        # This returns a concurrent.futures.Future.
        fut = asyncio.run_coroutine_threadsafe(coro, loop)
        # Block until the coroutine finishes and return (or re-raise) its outcome
        return fut.result()
52 changes: 22 additions & 30 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from typing import TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
Expand All @@ -19,6 +19,8 @@
RegisteredModel,
SupportedTypes,
)
from ._async_task_runner import AsyncTaskRunner


ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)
Expand Down Expand Up @@ -74,17 +76,16 @@ def __init__(
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a string.
user_token_envvar: Environment variable to read the user token from if it's not passed as an arg. Defaults to KF_PIPELINES_SA_TOKEN_PATH.
user_token_envvar: Environment variable to read the user token from if it's not passed as an arg.
Defaults to KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string.
custom_ca_envvar: Environment variable to read the custom CA from if it's not passed as an arg.
log_level: Log level. Defaults to logging.WARNING.
"""
logger.setLevel(log_level)

import nest_asyncio

logger.debug("Setting up reentrant async event loop")
nest_asyncio.apply()

self.runner = AsyncTaskRunner.get_instance()

# TODO: get remaining args from env
self._author = author
Expand Down Expand Up @@ -127,16 +128,6 @@ def __init__(
)
self.get_registered_models().page_size(1)._next_page()

def async_runner(self, coro: Any) -> Any:
import asyncio

try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)

async def _register_model(self, name: str, **kwargs) -> RegisteredModel:
if rm := await self._api.get_registered_model_by_params(name):
return rm
Expand Down Expand Up @@ -210,8 +201,8 @@ def register_model(
Returns:
Registered model.
"""
rm = self.async_runner(self._register_model(name, owner=owner or self._author))
mv = self.async_runner(
rm = self.runner.run(self._register_model(name, owner=owner or self._author))
mv = self.runner.run(
self._register_new_version(
rm,
version,
Expand All @@ -220,7 +211,7 @@ def register_model(
custom_properties=metadata or {},
)
)
self.async_runner(
self.runner.run(
self._register_model_artifact(
mv,
name,
Expand All @@ -244,10 +235,10 @@ def update(self, model: TModel) -> TModel:
msg = f"Model must be one of {get_args(ModelTypes)}"
raise StoreError(msg)
if isinstance(model, RegisteredModel):
return self.async_runner(self._api.upsert_registered_model(model))
return self.runner.run(self._api.upsert_registered_model(model))
if isinstance(model, ModelVersion):
return self.async_runner(self._api.upsert_model_version(model, None))
return self.async_runner(self._api.upsert_model_artifact(model))
return self.runner.run(self._api.upsert_model_version(model, None))
return self.runner.run(self._api.upsert_model_artifact(model))

def register_hf_model(
self,
Expand Down Expand Up @@ -289,8 +280,8 @@ def register_hf_model(
from huggingface_hub import HfApi, hf_hub_url, utils
except ImportError as e:
msg = """package `huggingface-hub` is not installed.
To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
extra (available as `model-registry[hf]`), e.g.:
To import models from Hugging Face Hub, start by installing the `huggingface-hub` package,
either directly or as an extra (available as `model-registry[hf]`), e.g.:
```sh
!pip install --pre model-registry[hf]
```
Expand Down Expand Up @@ -363,7 +354,7 @@ def get_registered_model(self, name: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
return self.async_runner(self._api.get_registered_model_by_params(name))
return self.runner.run(self._api.get_registered_model_by_params(name))

def get_model_version(self, name: str, version: str) -> ModelVersion | None:
"""Get a model version.
Expand All @@ -382,7 +373,7 @@ def get_model_version(self, name: str, version: str) -> ModelVersion | None:
msg = f"Model {name} does not exist"
raise StoreError(msg)
assert rm.id
return self.async_runner(self._api.get_model_version_by_params(rm.id, version))
return self.runner.run(self._api.get_model_version_by_params(rm.id, version))

def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
"""Get a model artifact.
Expand All @@ -401,7 +392,7 @@ def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
msg = f"Version {version} does not exist"
raise StoreError(msg)
assert mv.id
return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id))
return self.runner.run(self._api.get_model_artifact_by_params(name, mv.id))

def get_registered_models(self) -> Pager[RegisteredModel]:
"""Get a pager for registered models.
Expand All @@ -411,7 +402,7 @@ def get_registered_models(self) -> Pager[RegisteredModel]:
"""

def rm_list(options: ListOptions) -> list[RegisteredModel]:
return self.async_runner(self._api.get_registered_models(options))
return self.runner.run(self._api.get_registered_models(options))

return Pager[RegisteredModel](rm_list)

Expand All @@ -432,8 +423,9 @@ def get_model_versions(self, name: str) -> Pager[ModelVersion]:
raise StoreError(msg)

def rm_versions(options: ListOptions) -> list[ModelVersion]:
# type checkers can't restrict the type inside a nested function: https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
# type checkers can't restrict the type inside a nested function:
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert rm.id
return self.async_runner(self._api.get_model_versions(rm.id, options))
return self.runner.run(self._api.get_model_versions(rm.id, options))

return Pager[ModelVersion](rm_versions)

0 comments on commit ec997c6

Please sign in to comment.