Skip to content

Commit

Permalink
feat: flesh out the upload_and_register method
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Dobroveanu <[email protected]>
  • Loading branch information
Crazyglue committed Feb 26, 2025
1 parent 545419d commit 6cfbeba
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 89 deletions.
54 changes: 21 additions & 33 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import logging
import os
from collections.abc import Mapping
from dataclasses import asdict
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .utils import is_oci_uri, is_s3_uri

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import (
Expand All @@ -21,6 +20,7 @@
RegisteredModel,
SupportedTypes,
)
from .utils import OCIParams, S3Params, save_to_oci_registry

ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)
Expand Down Expand Up @@ -166,73 +166,61 @@ async def _register_model_artifact(
return await self._api.upsert_model_version_artifact(
ModelArtifact(name=name, uri=uri, **kwargs), mv.id
)

def _upload_to_s3(
    self,
    artifact_local_path: str,
    destination: str,
    region_name: str | None = None,
) -> str:
    """Upload a file to an S3 bucket (not supported yet).

    The parameters are accepted so the signature is in place, but none of
    them is read: the method unconditionally raises.

    Args:
        artifact_local_path: Local path of the artifact to upload (unused).
        destination: Target location of the upload (unused).
        region_name: Optional region name (unused).

    Raises:
        StoreError: Always, because S3 upload is not supported.
    """
    msg = "S3 Upload is not supported"
    raise StoreError(msg)

def _upload_to_oci(
self,
artifact_local_path: str,
destination: str,
) -> str:
"""
Uploads an artifact to an OCI registry.
"""
pass

def upload_artifact_and_register_model(
    self,
    name: str,
    model_files: list[os.PathLike],
    *,
    # Artifact/Model Params
    version: str,
    model_format_name: str,
    model_format_version: str,
    storage_path: str | None = None,
    storage_key: str | None = None,
    service_account_name: str | None = None,
    author: str | None = None,
    owner: str | None = None,
    description: str | None = None,
    metadata: Mapping[str, SupportedTypes] | None = None,
    # Upload/client Params
    upload_params: OCIParams | S3Params,
) -> RegisteredModel:
    """Upload model files to a remote store, then register the model.

    The type of ``upload_params`` selects the upload mechanism. The URI
    returned by the upload step is then passed to ``register_model`` together
    with the remaining artifact/model parameters.

    Args:
        name: Name of the model to register.
        model_files: Files to upload.
        version: Version forwarded to ``register_model``.
        model_format_name: Model format name forwarded to ``register_model``.
        model_format_version: Model format version forwarded to ``register_model``.
        storage_path: Forwarded to ``register_model``.
        storage_key: Forwarded to ``register_model``.
        service_account_name: Forwarded to ``register_model``.
        author: Forwarded to ``register_model``.
        owner: Forwarded to ``register_model``.
        description: Forwarded to ``register_model``.
        metadata: Forwarded to ``register_model``.
        upload_params: ``OCIParams`` to upload through ``save_to_oci_registry``;
            ``S3Params`` is recognised but S3 upload is not supported yet.

    Returns:
        The value returned by ``register_model``.

    Raises:
        StoreError: If ``upload_params`` is an ``S3Params`` instance.
        ValueError: If ``upload_params`` is neither ``OCIParams`` nor ``S3Params``.
    """
    # Function-scope import so the module's import block does not need to change.
    from dataclasses import fields

    if isinstance(upload_params, S3Params):
        # TODO: Implement the S3 upload. S3Params carries no destination yet, so
        # fail explicitly instead of reading an unbound ``destination_uri``.
        msg = "S3 Upload is not supported"
        raise StoreError(msg)

    if not isinstance(upload_params, OCIParams):
        msg = 'Param "upload_params" is required to perform an upload. Please ensure the value provided is valid'
        raise ValueError(msg)

    # Shallow copy of the dataclass fields: ``asdict`` would deep-copy the values
    # inside ``custom_oci_backend``, which are callables supplied by the caller.
    oci_kwargs = {f.name: getattr(upload_params, f.name) for f in fields(upload_params)}
    destination_uri = save_to_oci_registry(**oci_kwargs, model_files=model_files)

    return self.register_model(
        name,
        destination_uri,
        model_format_name=model_format_name,
        model_format_version=model_format_version,
        version=version,
        storage_key=storage_key,
        storage_path=storage_path,
        service_account_name=service_account_name,
        author=author,
        owner=owner,
        description=description,
        metadata=metadata,
    )

def register_model(
self,
Expand Down
62 changes: 40 additions & 22 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from __future__ import annotations

import os
<<<<<<< HEAD
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, TypedDict
=======
import re
>>>>>>> 6ebe337 (chore: add oci and s3 helper methods)

from typing_extensions import overload

Expand Down Expand Up @@ -138,6 +136,25 @@ def _get_oras_backend() -> BackendDefinition:
"push": oras_push,
}

@dataclass
class OCIParams:
    """Parameters for the OCI client to perform the upload.

    Allows for some customization of how to perform the upload step when
    uploading via OCI. The field names mirror the keyword arguments of
    ``save_to_oci_registry``.
    """

    # Image the model files are appended to.
    base_image: str
    # Reference of the resulting image; the upload reports it as ``oci://<oci_ref>``.
    oci_ref: str
    # Working directory handed to the backend; presumably a temporary directory
    # is used when left as None - TODO confirm against save_to_oci_registry.
    dest_dir: str | os.PathLike | None = None
    # CLI tool used for the image pull/push. One of: "skopeo", "oras".
    backend: str = "skopeo"
    # Optional path to a modelcard to additionally include as a layer.
    modelcard: os.PathLike | None = None
    # Optional backend definition used instead of the named default backend.
    custom_oci_backend: BackendDefinition | None = None

@dataclass
class S3Params:
    """Placeholder for the parameters of an S3 upload; it has no fields yet."""

    # TODO: Implement s3 params dataclass

# A dict mapping backend names to their definitions
BackendDict = dict[str, Callable[[], BackendDefinition]]

Expand All @@ -153,7 +170,7 @@ def save_to_oci_registry(
dest_dir: str | os.PathLike = None,
backend: str = "skopeo",
modelcard: os.PathLike | None = None,
backend_registry: BackendDict | None = DEFAULT_BACKENDS,
custom_oci_backend: BackendDefinition | None = None,
) -> str:
"""Appends a list of files to an OCI-based image.
Expand All @@ -164,7 +181,6 @@ def save_to_oci_registry(
model_files: List of files to add to the base_image as layers
backend: The CLI tool to use to perform the oci image pull/push. One of: "skopeo", "oras"
modelcard: Optional, path to the modelcard to additionally include as a layer
backend_registry: Optional, a dict of backends available to be used to perform the OCI image download/upload
Raises:
ValueError: If the chosen backend is not installed on the host
ValueError: If the chosen backend is an invalid option
Expand All @@ -189,12 +205,16 @@ def save_to_oci_registry(
raise StoreError(msg) from e


if backend not in backend_registry:
msg = f"'{backend}' is not an available backend to use. Available backends: {backend_registry.keys()}"
# If a custom backend is provided, use it, else fetch the backend out of the registry
if custom_oci_backend:
backend_def = custom_oci_backend
elif backend in DEFAULT_BACKENDS:
# Fetching the backend definition can throw an error, but it should bubble up as it has the appropriate messaging
backend_def = DEFAULT_BACKENDS[backend]()
else:
msg = f"'{backend}' is not an available backend to use. Available backends: {DEFAULT_BACKENDS.keys()}"
raise ValueError(msg)

# Fetching the backend definition can throw an error, but it should bubble up as it has the appropriate messaging
backend_def = backend_registry[backend]()

if not backend_def["is_available"]():
msg = f"Backend '{backend}' is selected, but not available on the system. Ensure the dependencies for '{backend}' are installed in your environment."
Expand All @@ -214,40 +234,38 @@ def save_to_oci_registry(
s3_prefix = "s3://"


def is_s3_uri(uri: str) -> bool:
    """Check whether a string is a valid S3 URI.

    The string must start with the S3 prefix (``s3://``) and must contain both
    a non-empty bucket and a non-empty key, separated by ``/``.

    Args:
        uri: The URI to check.

    Returns:
        ``True`` if ``uri`` is a valid S3 URI, ``False`` otherwise.
    """
    if not uri.startswith(s3_prefix):
        return False
    # Everything after the prefix is "<bucket>/<key>"; only the first "/" splits
    # the two, so the key may itself contain further "/" characters.
    bucket, _, key = uri[len(s3_prefix):].partition("/")
    return bool(bucket and key)

oci_pattern = r'^oci://(?P<host>[^/]+)/(?P<repository>[A-Za-z0-9_\-/]+)(:(?P<tag>[A-Za-z0-9_.-]+))?$'
oci_pattern = r"^oci://(?P<host>[^/]+)/(?P<repository>[A-Za-z0-9_\-/]+)(:(?P<tag>[A-Za-z0-9_.-]+))?$"

def is_oci_uri(uri: str):
"""Checks whether a string is a valid OCI URI
"""Checks whether a string is a valid OCI URI.
The expected format is:
oci://<host>/<repository>[:<tag>]
Examples of valid URIs:
oci://registry.example.com/my-namespace/my-repo:latest
oci://localhost:5000/my-repo
Args:
uri: The URI to check
Returns:
Boolean indicating whether it is a valid OCI URI
"""
Expand Down
1 change: 1 addition & 0 deletions clients/python/tests/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
data/
tmp/
38 changes: 38 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib
from itertools import islice

import pytest
Expand Down Expand Up @@ -649,3 +650,40 @@ def test_hf_import_default_env(client: ModelRegistry):

for k in env_values:
os.environ.pop(k)

@pytest.mark.e2e
def test_upload_artifact_and_register_model_with_default_oci(client: ModelRegistry) -> None:
    """Upload a sample file with default OCI parameters and check the registered URI."""
    # olot is required to run this test
    pytest.importorskip("olot")
    name = "oci-test/defaults"
    version = "0.0.1"
    author = "Tester McTesterson"
    oci_ref = "localhost:5001/foo/bar:latest"
    local_path = pathlib.Path("tests/data")

    upload_params = utils.OCIParams(
        "quay.io/mmortari/hello-world-wait:latest",
        oci_ref,
    )

    # Create an empty sample file named README.md to be added to the registry
    local_path.mkdir(parents=True, exist_ok=True)
    readme_file_path = local_path / "README.md"
    readme_file_path.write_text("")

    assert client.upload_artifact_and_register_model(
        name,
        model_files=[readme_file_path],
        author=author,
        version=version,
        model_format_name="test format",
        model_format_version="test version",
        upload_params=upload_params,
    )

    # The artifact URI must be the OCI reference the files were pushed to
    assert (ma := client.get_model_artifact(name, version))
    assert ma.uri == f"oci://{oci_ref}"
46 changes: 12 additions & 34 deletions clients/python/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import pytest

from model_registry.exceptions import MissingMetadata
<<<<<<< HEAD
from model_registry.utils import s3_uri_from, save_to_oci_registry
=======
from model_registry.utils import s3_uri_from, is_s3_uri, is_oci_uri
>>>>>>> 6ebe337 (chore: add oci and s3 helper methods)
from model_registry.utils import (
is_oci_uri,
is_s3_uri,
s3_uri_from,
save_to_oci_registry,
)


def test_s3_uri_builder():
Expand Down Expand Up @@ -121,12 +122,10 @@ def pull_mock_imple(base_image, dest_dir):


backend = "something_custom"
backend_registry = {
"something_custom": lambda: {
custom_oci_backend = {
"is_available": is_available_mock,
"pull": pull_mock,
"push": push_mock,
}
}

# similar to other test
Expand All @@ -142,34 +141,13 @@ def pull_mock_imple(base_image, dest_dir):

model_files = [readme_file_path]

uri = save_to_oci_registry(base_image, oci_ref, model_files, dest_dir, backend, None, backend_registry)
uri = save_to_oci_registry(base_image, oci_ref, model_files, dest_dir, backend, None, custom_oci_backend)
# Ensure our mocked backend was called
is_available_mock.assert_called_once()
pull_mock.assert_called_once()
push_mock.assert_called_once()
assert uri == f"oci://{oci_ref}"

def test_save_to_oci_registry_with_custom_backend_unavailable():
    """A custom backend that reports itself unavailable must be rejected with ValueError."""
    is_available_mock = Mock(return_value=False)  # Backend is unavailable, expect an error
    pull_mock = Mock()
    push_mock = Mock()

    backend = "something_custom"
    # The backend definition is passed directly via ``custom_oci_backend``;
    # ``save_to_oci_registry`` no longer accepts a ``backend_registry`` mapping.
    custom_oci_backend = {
        "is_available": is_available_mock,
        "pull": pull_mock,
        "push": push_mock,
    }

    with pytest.raises(ValueError, match=f"Backend '{backend}' is selected, but not available on the system. Ensure the dependencies for '{backend}' are installed in your environment.") as e:
        save_to_oci_registry("", "", [], "", backend, custom_oci_backend=custom_oci_backend)

    assert f"Backend '{backend}' is selected, but not available on the system." in str(e.value)

def test_save_to_oci_registry_backend_not_found():
backend = "non-existent"
with pytest.raises(ValueError, match=f"'{backend}' is not an available backend to use.") as e:
Expand All @@ -184,7 +162,7 @@ def test_is_s3_uri_with_valid_uris():
"s3://my-bucket/my-folder/my-sub-folder/my-file.sh",
]
for test in test_cases:
assert is_s3_uri(test) == True
assert is_s3_uri(test) is True

def test_is_s3_uri_with_invalid_uris():
test_cases = [
Expand All @@ -194,7 +172,7 @@ def test_is_s3_uri_with_invalid_uris():
"my-bucket/my-file.sh",
]
for test in test_cases:
assert is_s3_uri(test) == False
assert is_s3_uri(test) is False

def test_is_oci_uri_with_valid_uris():
test_cases = [
Expand All @@ -205,7 +183,7 @@ def test_is_oci_uri_with_valid_uris():
]

for test in test_cases:
assert is_oci_uri(test) == True
assert is_oci_uri(test) is True

def test_is_oci_uri_with_invalid_uris():
test_cases = [
Expand All @@ -219,5 +197,5 @@ def test_is_oci_uri_with_invalid_uris():
]

for test in test_cases:
assert is_oci_uri(test) == False
assert is_oci_uri(test) is False

0 comments on commit 6cfbeba

Please sign in to comment.