diff --git a/olot/basics.py b/olot/basics.py index 0894ec4..e452534 100644 --- a/olot/basics.py +++ b/olot/basics.py @@ -1,9 +1,10 @@ +import logging import os from pathlib import Path from pprint import pprint import tarfile -from typing import Dict, List +from typing import Dict, List, Sequence import typing import click @@ -14,12 +15,30 @@ from olot.oci.oci_image_layout import verify_ocilayout from olot.oci.oci_common import MediaTypes -from olot.utils.files import tarball_from_file, targz_from_file +from olot.utils.files import handle_remove, tarball_from_file, targz_from_file from olot.utils.types import compute_hash_of_str -def oci_layers_on_top(ocilayout: typing.Union[str, os.PathLike], model_files: List[os.PathLike], modelcard: typing.Union[os.PathLike, None] = None): +logger = logging.getLogger(__name__) + +def oci_layers_on_top( + ocilayout: typing.Union[str, os.PathLike], + model_files: Sequence[os.PathLike], + modelcard: typing.Union[os.PathLike, None] = None, + *, + remove_originals: bool = False): + """ + Add contents to an oci-layout directory as new blob layers + + Args: + ocilayout: The oci-layout directory of the base image. + model_files: PathLike array to be added as new blob layers. + modelcard: PathLike of the README.md of the ModelCarD, will be added as the last layer with compression and annotations. + remove_originals: whether to remove the original content files after having added the layers, default: False. 
+ """ if not isinstance(ocilayout, Path): ocilayout = Path(ocilayout) + if remove_originals: + logger.info("Invoked with 'remove' to delete original contents after adding as a blob layer.") verify_ocilayout(ocilayout) ocilayout_root_index = read_ocilayout_root_index(ocilayout) @@ -32,9 +51,13 @@ def oci_layers_on_top(ocilayout: typing.Union[str, os.PathLike], model_files: Li model = Path(model) new_layer = tarball_from_file(model, sha256_path) new_layers[new_layer] = new_layer + if remove_originals: + handle_remove(model) if modelcard is not None: modelcard_layer_diffid = targz_from_file(Path(modelcard), sha256_path) new_layers[modelcard_layer_diffid[0]] = modelcard_layer_diffid[1] + if remove_originals: + handle_remove(modelcard) new_ocilayout_manifests: Dict[str, str] = {} for manifest_hash, manifest in ocilayout_manifests.items(): diff --git a/olot/cli.py b/olot/cli.py index 58e751a..f06f852 100644 --- a/olot/cli.py +++ b/olot/cli.py @@ -1,8 +1,6 @@ from os import PathLike -from pathlib import Path import click - - +import logging from .basics import oci_layers_on_top @@ -11,5 +9,7 @@ @click.option("-m", "--modelcard", type=click.Path(exists=True, file_okay=True, dir_okay=False)) @click.argument('ocilayout', type=click.Path(exists=True, file_okay=False, dir_okay=True)) @click.argument('model_files', nargs=-1) -def cli(ocilayout: str, modelcard: PathLike, model_files): - oci_layers_on_top(Path(ocilayout), model_files, modelcard) +@click.option('-r', '--remove-originals', is_flag=True) +def cli(ocilayout: str, modelcard: PathLike, model_files, remove_originals: bool): + logging.basicConfig(level=logging.INFO) + oci_layers_on_top(ocilayout, model_files, modelcard, remove_originals=remove_originals) diff --git a/olot/utils/files.py b/olot/utils/files.py index ac40439..8797f5d 100644 --- a/olot/utils/files.py +++ b/olot/utils/files.py @@ -1,9 +1,13 @@ import hashlib +import logging +import shutil import tarfile from pathlib import Path import gzip import os +logger = 
logging.getLogger(__name__) + class HashingWriter: def __init__(self, base_writer, hash_func=None): self.base_writer = base_writer @@ -129,3 +133,14 @@ def targz_from_file(file_path: Path, dest: Path) -> tuple[str, str]: raise tarfile.TarError(f"Error creating tarball: {e}") from e except OSError as e: raise OSError(f"File operation failed: {e}") from e + + +def handle_remove(path: os.PathLike): + if not isinstance(path, Path): + path = Path(path) + if path.is_symlink(): + logger.warning("removing %s which is a symlink", path) + if path.is_dir(): + shutil.rmtree(path) + else: + os.remove(path) diff --git a/tests/basic_test.py b/tests/basic_test.py index 619e790..9c9ca64 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -1,10 +1,13 @@ +import os from pathlib import Path +import shutil from typing import Dict -from olot.basics import crawl_ocilayout_blobs_to_extract, crawl_ocilayout_indexes, crawl_ocilayout_manifests +from olot.basics import crawl_ocilayout_blobs_to_extract, crawl_ocilayout_indexes, crawl_ocilayout_manifests, oci_layers_on_top from olot.oci.oci_image_index import OCIImageIndex, read_ocilayout_root_index from olot.oci.oci_image_manifest import OCIImageManifest +from tests.common import sample_model_path, test_data_path def test_crawl_ocilayout_indexes(): @@ -63,3 +66,30 @@ def test_crawl_ocilayout_blobs_to_extract(tmp_path: Path): assert modelcard.exists() modelfile = tmp_path / "models" / "model.joblib" assert modelfile.exists() + + +def test_oci_layers_on_top_with_remove(tmp_path: Path): + """put oci_layers_on_top under test with 'remove' option + """ + test_sample_model = sample_model_path() + test_ocilayout2 = test_data_path() / "ocilayout2" + target_ocilayout = tmp_path / "myocilayout" + shutil.copytree(test_ocilayout2, target_ocilayout) + target_model = tmp_path / "models" + shutil.copytree(test_sample_model, target_model) + print(os.listdir(target_model)) + + models = [ + target_model / "model.joblib", + target_model / "hello.md" + 
] + for model in models: + assert model.exists() + modelcard = target_model / "README.md" + assert modelcard.exists() + + oci_layers_on_top(target_ocilayout, models, modelcard, remove_originals=True) + + for model in models: + assert not model.exists() + assert not modelcard.exists() diff --git a/tests/conftest.py b/tests/conftest.py index bf74803..343d6b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import logging import pytest +logging.basicConfig(level=logging.INFO) def pytest_collection_modifyitems(config, items): for item in items: