Skip to content

Commit

Permalink
implement code review comments
Browse files Browse the repository at this point in the history
return the output as a list instead of printing

Signed-off-by: tarilabs <[email protected]>
  • Loading branch information
tarilabs committed Feb 13, 2025
1 parent 052661c commit 278de46
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
9 changes: 7 additions & 2 deletions olot/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,19 @@ def crawl_ocilayout_indexes(ocilayout: Path, ocilayout_root_index: OCIImageIndex

def crawl_ocilayout_blobs_to_extract(ocilayout: Path,
output_path: Path,
tar_filter_dir: str = "/models"):
tar_filter_dir: str = "/models") -> List[str]:
"""
Extract from OCI Image/ModelCar only the contents from a specific directory.
Args:
ocilayout: The directory containing the oci-layout of the OCI Image/ModelCar.
output_path: The directory where to extract the ML model assets from the ModelCar to.
tar_filter_dir: The subdirectory in the ModelCar to extract, defaults to `"/models"`.
Returns:
The list of extracted ML contents from the OCI Image/ModelCar.
"""
extracted: List[str] = []
tar_filter_dir= tar_filter_dir.lstrip("/")
blobs_path = ocilayout / "blobs" / "sha256"
if not os.path.exists(output_path):
Expand All @@ -174,7 +178,8 @@ def crawl_ocilayout_blobs_to_extract(ocilayout: Path,
for member in tar.getmembers():
if member.isfile() and member.name.startswith(tar_filter_dir):
tar.extract(member, path=output_path)
print(f"Extracted: {member.name}")
extracted.append(member.name)
return extracted


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion tests/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def test_crawl_ocilayout_blobs_to_extract(tmp_path: Path):
Verify extraction from ModelCar produces those 2 assets.
"""
ocilayout4_path = Path(__file__).parent / "data" / "ocilayout4"
crawl_ocilayout_blobs_to_extract(ocilayout4_path, tmp_path)
mut = crawl_ocilayout_blobs_to_extract(ocilayout4_path, tmp_path)

assert len(mut) == 2
assert "models/README.md" in mut
assert "models/model.joblib" in mut

assert len([x for x in tmp_path.rglob("*") if x.is_file()]) == 2
modelcard = tmp_path / "models" / "README.md"
Expand Down

0 comments on commit 278de46

Please sign in to comment.