-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from jaideepr97/create-manifest
Create OCI image manifest
- Loading branch information
Showing
5 changed files
with
181 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from pathlib import Path | ||
import os | ||
import datetime | ||
import json | ||
import argparse | ||
from typing import List | ||
|
||
from olot.oci.oci_image_manifest import create_oci_image_manifest, create_manifest_layers | ||
from olot.oci.oci_common import Keys | ||
from olot.utils.files import MIMETypes, tarball_from_file, targz_from_file | ||
from olot.utils.types import compute_hash_of_str | ||
|
||
def create_oci_artifact_from_model(source_dir: Path, dest_dir: Path):
    """
    Create an OCI artifact on disk from the files of a model directory.

    Every regular file directly inside ``source_dir`` is turned into a blob
    under ``<dest_dir>/blobs/sha256`` (via ``create_blobs``), and an OCI image
    manifest describing those layers is written next to them, named after the
    value returned by ``compute_hash_of_str`` for the serialized manifest.

    Args:
        source_dir: Directory holding the model files. Only its top level is
            read; subdirectories are ignored for now.
        dest_dir: Directory that receives the artifact layout. When ``None``,
            ``<source_dir>/oci`` is used. It is created if missing.

    Raises:
        NotADirectoryError: If ``source_dir`` does not exist or is not a
            directory.
    """
    # Check is_dir() rather than exists(): a regular file would otherwise pass
    # the guard and only fail later, after a confusing makedirs attempt.
    if not source_dir.is_dir():
        raise NotADirectoryError(
            f"Input directory '{source_dir}' does not exist or is not a directory."
        )

    if dest_dir is None:
        dest_dir = source_dir / "oci"
    os.makedirs(dest_dir, exist_ok=True)

    sha256_path = dest_dir / "blobs" / "sha256"
    os.makedirs(sha256_path, exist_ok=True)

    # assume flat structure for source_dir for now
    # TODO: handle subdirectories appropriately
    # Directory listing order is arbitrary; sort so the resulting manifest
    # lists its layers in a reproducible order.
    model_files = sorted(entry for entry in source_dir.iterdir() if entry.is_file())

    # Populate blobs directory
    layers = create_blobs(model_files, dest_dir)

    # Create the OCI image manifest
    manifest_layers = create_manifest_layers(model_files, layers)
    annotations = {
        # Use an aware UTC timestamp: a naive isoformat() has no offset and is
        # therefore not a valid RFC 3339 date-time.
        Keys.image_created_annotation: datetime.datetime.now(
            datetime.timezone.utc
        ).isoformat()
    }
    artifactType = MIMETypes.mlmodel
    manifest = create_oci_image_manifest(
        artifactType=artifactType,
        layers=manifest_layers,
        annotations=annotations,
    )
    manifest_json = json.dumps(manifest.dict(), indent=4, sort_keys=True)
    manifest_SHA = compute_hash_of_str(manifest_json)
    # Write the text verbatim (fixed encoding, no newline translation) so the
    # bytes on disk match the string that was hashed on every platform.
    # NOTE(review): presumes compute_hash_of_str hashes the UTF-8 text - confirm.
    with open(sha256_path / manifest_SHA, "w", encoding="utf-8", newline="") as f:
        f.write(manifest_json)
|
||
|
||
def create_blobs(model_files: List[Path], dest_dir: Path):
    """
    Create the blobs directory for an OCI artifact.

    Each file in ``model_files`` is packed into ``<dest_dir>/blobs/sha256``.
    The file named exactly ``README.md`` is treated as the model card and goes
    through ``targz_from_file``; every other file goes through
    ``tarball_from_file``.

    Args:
        model_files: Paths of the files to pack.
        dest_dir: Root of the artifact layout; ``blobs/sha256`` is created
            beneath it when missing.

    Returns:
        A dict keyed by file name. Values are ``(precomp, postcomp)`` tuples
        of the checksums reported by the packing helpers; for files that are
        not the model card the second element is an empty string.
    """
    layers = {}  # file name : (precomp, postcomp)
    sha256_path = dest_dir / "blobs" / "sha256"
    # Do not rely on the caller having prepared the layout: this function is
    # also invoked on its own (e.g. from tests).
    os.makedirs(sha256_path, exist_ok=True)

    for model_file in model_files:
        file_name = os.path.basename(os.path.normpath(model_file))
        # handle model card file if encountered - assume README.md is the modelcard
        # Compare the whole base name: a suffix match would also catch names
        # such as "OLD_README.md".
        if file_name == "README.md":
            postcomp_chksum, precomp_chksum = targz_from_file(model_file, sha256_path)
            layers[file_name] = (precomp_chksum, postcomp_chksum)
        else:
            checksum = tarball_from_file(model_file, sha256_path)
            layers[file_name] = (checksum, "")
    return layers
|
||
# command-line entry point used to try out create_oci_artifact_from_model
def main():
    """
    Parse the command line and build an OCI artifact for the given directory.

    Takes a single positional argument, ``source_dir``. The destination is
    passed as ``None``, leaving its choice to create_oci_artifact_from_model.
    """
    cli = argparse.ArgumentParser(description="Create OCI artifact from model")
    cli.add_argument("source_dir", type=str, help="Path to the source directory")
    parsed = cli.parse_args()

    create_oci_artifact_from_model(Path(parsed.source_dir), None)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 9 additions & 5 deletions
14
tests/oci/oci_common_test.py → tests/oci/oci_artifact_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,26 @@ | ||
from tests.common import sample_model_path, file_checksums_with_compression, file_checksums_without_compression | ||
from olot.oci.oci_common import create_blobs | ||
from olot.oci_artifact import create_blobs | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
|
||
def test_create_blobs(tmp_path):
    """
    Check that create_blobs reports one layer per file of the sample model.

    Only the post-change lines of the diff are kept: the superseded call
    ``create_blobs(source_dir, dest_dir)`` passed a directory where a list of
    files is expected, and the digest-keyed expectations no longer match a
    result that is keyed by file name.
    """
    source_dir = sample_model_path()
    dest_dir = tmp_path

    # create_blobs takes the individual files, not the directory itself.
    model_files = [source_dir / Path(f) for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]

    layers = create_blobs(model_files, dest_dir)

    expected_layers = {}
    result = file_checksums_with_compression(source_dir / "README.md", dest_dir)
    expected_layers["README.md"] = result[1]

    result = file_checksums_without_compression(source_dir / "hello.md", dest_dir)
    expected_layers["hello.md"] = result

    result = file_checksums_without_compression(source_dir / "model.joblib", dest_dir)
    expected_layers["model.joblib"] = result

    # sorted() over a dict yields its keys, so this compares file names only.
    # NOTE(review): the checksum values are not verified here - consider
    # asserting them once the helper return shapes are confirmed.
    assert sorted(layers) == sorted(expected_layers)