Skip to content

Commit

Permalink
feat: flesh out the upload_and_register method
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Dobroveanu <[email protected]>
  • Loading branch information
Crazyglue committed Feb 26, 2025
1 parent 545419d commit 5b44e51
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 65 deletions.
44 changes: 18 additions & 26 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from __future__ import annotations

from dataclasses import asdict
import logging
import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .utils import is_oci_uri, is_s3_uri
from .utils import OCIParams, S3Params, is_oci_uri, is_s3_uri, save_to_oci_registry

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
Expand Down Expand Up @@ -176,44 +177,37 @@ def _upload_to_s3(
"""
Uploads a file to an S3 bucket.
"""
pass

def _upload_to_oci(
self,
artifact_local_path: str,
destination: str,
) -> str:
"""
Uploads an artifact to an OCI registry.
"""
pass
msg = f"S3 Upload is not supported"
raise StoreError(msg)

def upload_artifact_and_register_model(
self,
name: str,
artifact_local_path: str,
destination_uri: str,
model_files: list[os.PathLike],
*,
# Artifact/Model Params
version: str,
model_format_name: str,
model_format_version: str,
storage_path: str | None = None,
storage_key: str | None = None,
service_account_name: str | None = None,
author: str | None = None,
owner: str | None = None,
description: str | None = None,
metadata: Mapping[str, SupportedTypes] | None = None,
upload_client_params: Mapping[str, str] | None = None,
# Upload/client Params
upload_params: OCIParams | S3Params,
) -> RegisteredModel:
if is_s3_uri(destination_uri):
self._upload_to_s3(artifact_local_path, destination_uri, upload_client_params['region_name'])
elif is_oci_uri(destination_uri):
self._upload_to_oci(artifact_local_path, destination_uri)
if isinstance(upload_params, S3Params):
# TODO: Dont use a mock function here
destination_uri = self._upload_to_s3(model_files, destination_uri)
elif isinstance(upload_params, OCIParams):
print(asdict(upload_params))
destination_uri = save_to_oci_registry(**asdict(upload_params), model_files=model_files)
else:
msg = "Invalid destination URI. Must start with 's3://' or 'oci://'"
raise StoreError(msg)

# TODO: Perform the upload(s)
msg = f'Param "upload_params" is required to perform an upload. Please ensure the value provided is valid'
raise ValueError(msg)

registered_model = self.register_model(
name,
Expand All @@ -222,16 +216,14 @@ def upload_artifact_and_register_model(
model_format_version=model_format_version,
version=version,
storage_key=storage_key,
storage_path=artifact_local_path,
storage_path=storage_path,
service_account_name=service_account_name,
author=author,
owner=owner,
description=description,
metadata=metadata,
)

# TODO: Do something with the model? Metdata?

return registered_model

def register_model(
Expand Down
38 changes: 29 additions & 9 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from __future__ import annotations

from dataclasses import dataclass
import os
<<<<<<< HEAD
import tempfile
from pathlib import Path
from typing import Callable, TypedDict
=======
import re
>>>>>>> 6ebe337 (chore: add oci and s3 helper methods)

from typing_extensions import overload

Expand Down Expand Up @@ -138,6 +136,25 @@ def _get_oras_backend() -> BackendDefinition:
"push": oras_push,
}

@dataclass
class OCIParams:
'''Parameters for the OCI client to perform the upload
Allows for some customization of how to perform the upload step when uploading via OCI
'''
base_image: str
oci_ref: str
dest_dir: str | os.PathLike = None
backend: str = "skopeo"
modelcard: os.PathLike | None = None
custom_oci_backend: BackendDefinition = None

@dataclass
class S3Params:
# TODO: Implement s3 params dataclass
pass

# A dict mapping backend names to their definitions
BackendDict = dict[str, Callable[[], BackendDefinition]]

Expand All @@ -153,7 +170,7 @@ def save_to_oci_registry(
dest_dir: str | os.PathLike = None,
backend: str = "skopeo",
modelcard: os.PathLike | None = None,
backend_registry: BackendDict | None = DEFAULT_BACKENDS,
custom_oci_backend: BackendDefinition | None = None,
) -> str:
"""Appends a list of files to an OCI-based image.
Expand All @@ -164,7 +181,6 @@ def save_to_oci_registry(
model_files: List of files to add to the base_image as layers
backend: The CLI tool to use to perform the oci image pull/push. One of: "skopeo", "oras"
modelcard: Optional, path to the modelcard to additionally include as a layer
backend_registry: Optional, a dict of backends available to be used to perform the OCI image download/upload
Raises:
ValueError: If the chosen backend is not installed on the host
ValueError: If the chosen backend is an invalid option
Expand All @@ -189,12 +205,16 @@ def save_to_oci_registry(
raise StoreError(msg) from e


if backend not in backend_registry:
msg = f"'{backend}' is not an available backend to use. Available backends: {backend_registry.keys()}"
# If a custom backend is provided, use it, else fetch the backend out of the registry
if custom_oci_backend:
backend_def = custom_oci_backend
elif backend in DEFAULT_BACKENDS:
# Fetching the backend definition can throw an error, but it should bubble up as it has the appropriate messaging
backend_def = DEFAULT_BACKENDS[backend]()
else:
msg = f"'{backend}' is not an available backend to use. Available backends: {DEFAULT_BACKENDS.keys()}"
raise ValueError(msg)

# Fetching the backend definition can throw an error, but it should bubble up as it has the appropriate messaging
backend_def = backend_registry[backend]()

if not backend_def["is_available"]():
msg = f"Backend '{backend}' is selected, but not available on the system. Ensure the dependencies for '{backend}' are installed in your environment."
Expand Down
1 change: 1 addition & 0 deletions clients/python/tests/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
data/
tmp/
38 changes: 38 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from itertools import islice
import pathlib

import pytest
import requests
Expand Down Expand Up @@ -649,3 +650,40 @@ def test_hf_import_default_env(client: ModelRegistry):

for k in env_values:
os.environ.pop(k)

@pytest.mark.e2e
def test_upload_artifact_and_register_model_with_default_oci(client: ModelRegistry) -> None:
# olot is required to run this test
pytest.importorskip("olot")
name = "oci-test/defaults"
version = "0.0.1"
author = "Tester McTesterson"
oci_ref = "localhost:5001/foo/bar:latest"
local_path = pathlib.Path("tests/data")

upload_params = utils.OCIParams(
"quay.io/mmortari/hello-world-wait:latest",
oci_ref,
)

print(upload_params)

# Create a sample file named README.md to be added to the registry
pathlib.Path(local_path).mkdir(parents=True, exist_ok=True)
readme_file_path = os.path.join(local_path, "README.md")
with open(readme_file_path, "w") as f:
f.write("")


assert client.upload_artifact_and_register_model(
name,
model_files=[pathlib.Path(readme_file_path)],
author=author,
version=version,
model_format_name="test format",
model_format_version="test version",
upload_params=upload_params,
)

assert (ma := client.get_model_artifact(name, version))
assert ma.uri == f"oci://{oci_ref}"
33 changes: 3 additions & 30 deletions clients/python/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
import pytest

from model_registry.exceptions import MissingMetadata
<<<<<<< HEAD
from model_registry.utils import s3_uri_from, save_to_oci_registry
=======
from model_registry.utils import s3_uri_from, is_s3_uri, is_oci_uri
>>>>>>> 6ebe337 (chore: add oci and s3 helper methods)
from model_registry.utils import s3_uri_from, save_to_oci_registry, is_s3_uri, is_oci_uri


def test_s3_uri_builder():
Expand Down Expand Up @@ -121,12 +117,10 @@ def pull_mock_imple(base_image, dest_dir):


backend = "something_custom"
backend_registry = {
"something_custom": lambda: {
custom_oci_backend = {
"is_available": is_available_mock,
"pull": pull_mock,
"push": push_mock,
}
}

# similar to other test
Expand All @@ -142,34 +136,13 @@ def pull_mock_imple(base_image, dest_dir):

model_files = [readme_file_path]

uri = save_to_oci_registry(base_image, oci_ref, model_files, dest_dir, backend, None, backend_registry)
uri = save_to_oci_registry(base_image, oci_ref, model_files, dest_dir, backend, None, custom_oci_backend)
# Ensure our mocked backend was called
is_available_mock.assert_called_once()
pull_mock.assert_called_once()
push_mock.assert_called_once()
assert uri == f"oci://{oci_ref}"

def test_save_to_oci_registry_with_custom_backend_unavailable():
is_available_mock = Mock()
is_available_mock.return_value = False # Backend is unavailable, expect an error
pull_mock = Mock()
push_mock = Mock()

backend = "something_custom"
backend_registry = {
"something_custom": lambda: {
"is_available": is_available_mock,
"pull": pull_mock,
"push": push_mock,
}
}


with pytest.raises(ValueError, match=f"Backend '{backend}' is selected, but not available on the system. Ensure the dependencies for '{backend}' are installed in your environment.") as e:
save_to_oci_registry("", "", [], "", backend, backend_registry=backend_registry)

assert f"Backend '{backend}' is selected, but not available on the system." in str(e.value)

def test_save_to_oci_registry_backend_not_found():
backend = "non-existent"
with pytest.raises(ValueError, match=f"'{backend}' is not an available backend to use.") as e:
Expand Down

0 comments on commit 5b44e51

Please sign in to comment.