From 1aaecfbdad530a3ee03c4e4115af8dbe5be1696c Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 10 Nov 2024 17:48:48 +0100 Subject: [PATCH 1/5] first draft --- src/banks/env.py | 6 +++-- src/banks/extensions/chat.py | 6 ++--- src/banks/filters/__init__.py | 2 ++ src/banks/filters/cache_control.py | 6 ++--- src/banks/filters/image.py | 40 ++++++++++++++++++++++++++++++ src/banks/types.py | 9 ++++++- tests/test_cache_control.py | 4 +-- tests/test_chat.py | 8 +++--- 8 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 src/banks/filters/image.py diff --git a/src/banks/env.py b/src/banks/env.py index 247dd34..7d495a8 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -4,7 +4,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape from .config import config -from .filters import cache_control, lemmatize, tool +from .filters import cache_control, image, lemmatize, tool def _add_extensions(_env): @@ -38,7 +38,9 @@ def _add_extensions(_env): # Setup custom filters and defaults -env.filters["lemmatize"] = lemmatize env.filters["cache_control"] = cache_control +env.filters["image"] = image +env.filters["lemmatize"] = lemmatize env.filters["tool"] = tool + _add_extensions(env) diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index ee75fa2..db6cb15 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -33,12 +33,12 @@ def content(self) -> ChatMessageContent: return self._content_blocks - def handle_starttag(self, tag, _): - if tag == "content_block_txt": + def handle_starttag(self, tag, attrs): # noqa + if tag == "content_block": self._parse_block_content = True def handle_endtag(self, tag): - if tag == "content_block_txt": + if tag == "content_block": self._parse_block_content = False def handle_data(self, data): diff --git a/src/banks/filters/__init__.py b/src/banks/filters/__init__.py index 27645ad..60ad899 100644 --- a/src/banks/filters/__init__.py +++ b/src/banks/filters/__init__.py @@ -2,11 +2,13 @@ # # SPDX-License-Identifier: MIT from .cache_control import cache_control +from .image import image from .lemmatize import lemmatize from .tool import tool __all__ = ( "cache_control", + "image", "lemmatize", "tool", ) diff --git a/src/banks/filters/cache_control.py b/src/banks/filters/cache_control.py index 65f4640..0e0284c 100644 --- a/src/banks/filters/cache_control.py +++ b/src/banks/filters/cache_control.py @@ -17,8 +17,8 @@ def cache_control(value: str, cache_type: str = "ephemeral") -> str: ``` Important: - this filter marks the content to cache by surrounding it with `` and - ``, so it's only useful when used within a `{% chat %}` block. + this filter marks the content to cache by surrounding it with `` and + ``, so it's only useful when used within a `{% chat %}` block. """ block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}}) - return f"{block.model_dump_json()}" + return f"{block.model_dump_json()}" diff --git a/src/banks/filters/image.py b/src/banks/filters/image.py new file mode 100644 index 0000000..81ea7cd --- /dev/null +++ b/src/banks/filters/image.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from pathlib import Path +from urllib.parse import urlparse + +from banks.types import ContentBlock, ImageUrl + + +def _is_url(string: str) -> bool: + try: + result = urlparse(string) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + +def image(value: str) -> str: + """Wrap the filtered value into a ContentBlock with the proper cache_control field set. + + The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects. + + Example: + ```jinja + Describe what you see + + {{ "path/to/image/file" | image }} + ``` + + Important: + this filter marks the content to cache by surrounding it with `` and + ``, so it's only useful when used within a `{% chat %}` block. + """ + if _is_url(value): + image_url = ImageUrl(url=value) + else: + image_url = ImageUrl.from_path(Path(value)) + + block = ContentBlock.model_validate({"type": "image_url", "image_url": image_url}) + return f"{block.model_dump_json()}" diff --git a/src/banks/types.py b/src/banks/types.py index fb011b9..7650034 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import base64 from enum import Enum from inspect import Parameter, getdoc, signature +from pathlib import Path from typing import Callable from pydantic import BaseModel @@ -26,9 +28,14 @@ class ImageUrl(BaseModel): url: str @classmethod - def from_base64(cls, media_type: str, base64_str: str): + def from_base64(cls, media_type: str, base64_str: str) -> Self: return cls(url=f"data:{media_type};base64,{base64_str}") + @classmethod + def from_path(cls, file_path: Path) -> Self: + with open(file_path, "rb") as image_file: + return cls.from_base64("image/jpeg", base64.b64encode(image_file.read()).decode("utf-8")) + class ContentBlock(BaseModel): type: ContentBlockType diff --git a/tests/test_cache_control.py b/tests/test_cache_control.py index 100153b..7be8b3b 100644 --- a/tests/test_cache_control.py +++ b/tests/test_cache_control.py @@ -3,6 +3,6 @@ def test_cache_control(): res = cache_control("foo", "ephemeral") - res = res.replace("", "") - res = res.replace("", "") + res = res.replace("", "") + res = res.replace("", "") assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","image_url":null}' diff --git a/tests/test_chat.py b/tests/test_chat.py index 71e3be1..38a90c0 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -30,7 +30,7 @@ def test_content_block_parser_init(): def test_content_block_parser_single_with_cache_control(): p = _ContentBlockParser() p.feed( - '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' + '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' ) assert p.content == [ ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None) @@ -39,15 +39,15 @@ def test_content_block_parser_single_with_cache_control(): def test_content_block_parser_single_no_cache_control(): p = _ContentBlockParser() - p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}') + p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}') assert p.content == "foo" def test_content_block_parser_multiple(): p = _ContentBlockParser() p.feed( - '{"type":"text","cache_control":null,"text":"foo","source":null}' - '{"type":"text","cache_control":null,"text":"bar","source":null}' + '{"type":"text","cache_control":null,"text":"foo","source":null}' + '{"type":"text","cache_control":null,"text":"bar","source":null}' ) assert p.content == [ ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None), From bff235f08d6e82ffeb4636f316be4411d251da9c Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 10 Nov 2024 18:15:48 +0100 Subject: [PATCH 2/5] add unit tests --- src/banks/filters/image.py | 7 +-- tests/test_image.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/test_types.py | 41 +++++++++++++++++ 3 files changed, 135 insertions(+), 5 deletions(-) create mode 100644 tests/test_image.py create mode 100644 tests/test_types.py diff --git a/src/banks/filters/image.py b/src/banks/filters/image.py index 81ea7cd..6482bdb 100644 --- a/src/banks/filters/image.py +++ b/src/banks/filters/image.py @@ -8,11 +8,8 @@ def _is_url(string: str) -> bool: - try: - result = urlparse(string) - return all([result.scheme, result.netloc]) - except ValueError: - return False + result = urlparse(string) + return all([result.scheme, result.netloc]) def image(value: str) -> str: diff --git a/tests/test_image.py b/tests/test_image.py new file mode 100644 index 0000000..29a8151 --- /dev/null +++ b/tests/test_image.py @@ -0,0 +1,92 @@ +import json +from pathlib import Path + +import pytest + +from banks.filters.image import _is_url, image + + +def test_is_url(): + """Test the internal URL validation function""" + assert _is_url("https://example.com/image.jpg") is True + assert _is_url("http://example.com/image.jpg") is True + assert _is_url("ftp://example.com/image.jpg") is True + assert _is_url("not_a_url.jpg") is False + assert _is_url("/path/to/image.jpg") is False + assert _is_url("relative/path/image.jpg") is False + assert _is_url("") is False + assert _is_url("https:\\example.com/image.jpg") is False + + +def test_image_with_url(): + """Test image filter with a URL input""" + url = "https://example.com/image.jpg" + result = image(url) + + # Verify the content block wrapper + assert result.startswith("") + assert result.endswith("") + + # Parse the JSON content + json_content = result[15:-16] # Remove wrapper tags + content_block = json.loads(json_content) + + assert content_block["type"] == "image_url" + assert content_block["image_url"]["url"] == url + + +def test_image_with_file_path(tmp_path): + """Test image filter with a file path input""" + # Create a temporary test image file + test_image = tmp_path / "test_image.jpg" + test_content = b"fake image content" + test_image.write_bytes(test_content) + + result = image(str(test_image)) + + # Verify the content block wrapper + assert result.startswith("") + assert result.endswith("") + + # Parse the JSON content + json_content = result[15:-16] # Remove wrapper tags + content_block = json.loads(json_content) + + assert content_block["type"] == "image_url" + assert content_block["image_url"]["url"].startswith("data:image/jpeg;base64,") + + +def test_image_with_nonexistent_file(): + """Test image filter with a nonexistent file path""" + with pytest.raises(FileNotFoundError): + image("nonexistent/image.jpg") + + +def test_image_content_block_structure(): + """Test the structure of the generated content block""" + url = "https://example.com/image.jpg" + result = image(url) + + json_content = result[15:-16] # Remove wrapper tags + content_block = json.loads(json_content) + + # Verify the content block has all expected fields + assert set(content_block.keys()) >= {"type", "image_url"} + assert content_block["type"] == "image_url" + assert isinstance(content_block["image_url"], dict) + assert "url" in content_block["image_url"] + + +@pytest.mark.parametrize( + "invalid_input", + [ + "", # empty string + None, # None value + 123, # non-string number + True, # boolean + ], +) +def test_image_with_invalid_input(invalid_input): + """Test image filter with various invalid inputs""" + with pytest.raises((IsADirectoryError, ValueError, AttributeError, TypeError)): + image(invalid_input) diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..71a9a4e --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,41 @@ +import base64 +from pathlib import Path + +import pytest + +from banks.types import ImageUrl + + +def test_image_url_from_base64(): + """Test creating ImageUrl from base64 encoded data""" + test_data = "Hello, World!" + base64_data = base64.b64encode(test_data.encode()).decode("utf-8") + media_type = "image/jpeg" + + image_url = ImageUrl.from_base64(media_type, base64_data) + expected_url = f"data:{media_type};base64,{base64_data}" + assert image_url.url == expected_url + + +def test_image_url_from_path(tmp_path): + """Test creating ImageUrl from a file path""" + # Create a temporary test image file + test_image = tmp_path / "test_image.jpg" + test_content = b"fake image content" + test_image.write_bytes(test_content) + + image_url = ImageUrl.from_path(test_image) + + # Verify the URL starts with the expected data URI prefix + assert image_url.url.startswith("data:image/jpeg;base64,") + + # Decode the base64 part and verify the content matches + base64_part = image_url.url.split(",")[1] + decoded_content = base64.b64decode(base64_part) + assert decoded_content == test_content + + +def test_image_url_from_path_nonexistent(): + """Test creating ImageUrl from a nonexistent file path""" + with pytest.raises(FileNotFoundError): + ImageUrl.from_path(Path("nonexistent.jpg")) From f9f4bf6c9ebb6275ba5b499c88a2f73963879f2e Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 10 Nov 2024 18:21:34 +0100 Subject: [PATCH 3/5] removed unused import --- tests/test_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_image.py b/tests/test_image.py index 29a8151..bbb0917 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -1,5 +1,4 @@ import json -from pathlib import Path import pytest From 64b657035ad14bc8d40a0c9bfb6097a8b7ff6e03 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 10 Nov 2024 18:26:50 +0100 Subject: [PATCH 4/5] remove useless test --- tests/test_image.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index bbb0917..010c422 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -74,18 +74,3 @@ def test_image_content_block_structure(): assert content_block["type"] == "image_url" assert isinstance(content_block["image_url"], dict) assert "url" in content_block["image_url"] - - -@pytest.mark.parametrize( - "invalid_input", - [ - "", # empty string - None, # None value - 123, # non-string number - True, # boolean - ], -) -def test_image_with_invalid_input(invalid_input): - """Test image filter with various invalid inputs""" - with pytest.raises((IsADirectoryError, ValueError, AttributeError, TypeError)): - image(invalid_input) From 484824f0a7a1ae6b3d2722fede115a54d31b7878 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 10 Nov 2024 18:29:50 +0100 Subject: [PATCH 5/5] docs --- docs/prompt.md | 7 +++++++ src/banks/filters/image.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/prompt.md b/docs/prompt.md index a628693..52899eb 100644 --- a/docs/prompt.md +++ b/docs/prompt.md @@ -34,6 +34,13 @@ provided by Jinja, Banks supports the following ones, specific for prompt engine show_signature_annotations: false heading_level: 3 +::: banks.filters.image.image + options: + show_root_full_path: false + show_symbol_type_heading: false + show_signature_annotations: false + heading_level: 3 + ::: banks.filters.lemmatize.lemmatize options: show_root_full_path: false diff --git a/src/banks/filters/image.py b/src/banks/filters/image.py index 6482bdb..29d15d1 100644 --- a/src/banks/filters/image.py +++ b/src/banks/filters/image.py @@ -13,7 +13,7 @@ def _is_url(string: str) -> bool: def image(value: str) -> str: - """Wrap the filtered value into a ContentBlock with the proper cache_control field set. + """Wrap the filtered value into a ContentBlock of type image. The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects.