diff --git a/docs/prompt.md b/docs/prompt.md
index a628693..52899eb 100644
--- a/docs/prompt.md
+++ b/docs/prompt.md
@@ -34,6 +34,13 @@ provided by Jinja, Banks supports the following ones, specific for prompt engine
show_signature_annotations: false
heading_level: 3
+::: banks.filters.image.image
+ options:
+ show_root_full_path: false
+ show_symbol_type_heading: false
+ show_signature_annotations: false
+ heading_level: 3
+
::: banks.filters.lemmatize.lemmatize
options:
show_root_full_path: false
diff --git a/src/banks/env.py b/src/banks/env.py
index 247dd34..7d495a8 100644
--- a/src/banks/env.py
+++ b/src/banks/env.py
@@ -4,7 +4,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape
from .config import config
-from .filters import cache_control, lemmatize, tool
+from .filters import cache_control, image, lemmatize, tool
def _add_extensions(_env):
@@ -38,7 +38,9 @@ def _add_extensions(_env):
# Setup custom filters and defaults
-env.filters["lemmatize"] = lemmatize
env.filters["cache_control"] = cache_control
+env.filters["image"] = image
+env.filters["lemmatize"] = lemmatize
env.filters["tool"] = tool
+
_add_extensions(env)
diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py
index ee75fa2..db6cb15 100644
--- a/src/banks/extensions/chat.py
+++ b/src/banks/extensions/chat.py
@@ -33,12 +33,12 @@ def content(self) -> ChatMessageContent:
return self._content_blocks
- def handle_starttag(self, tag, _):
- if tag == "content_block_txt":
+ def handle_starttag(self, tag, attrs): # noqa
+ if tag == "content_block":
self._parse_block_content = True
def handle_endtag(self, tag):
- if tag == "content_block_txt":
+ if tag == "content_block":
self._parse_block_content = False
def handle_data(self, data):
diff --git a/src/banks/filters/__init__.py b/src/banks/filters/__init__.py
index 27645ad..60ad899 100644
--- a/src/banks/filters/__init__.py
+++ b/src/banks/filters/__init__.py
@@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: MIT
from .cache_control import cache_control
+from .image import image
from .lemmatize import lemmatize
from .tool import tool
__all__ = (
"cache_control",
+ "image",
"lemmatize",
"tool",
)
diff --git a/src/banks/filters/cache_control.py b/src/banks/filters/cache_control.py
index 65f4640..0e0284c 100644
--- a/src/banks/filters/cache_control.py
+++ b/src/banks/filters/cache_control.py
@@ -17,8 +17,8 @@ def cache_control(value: str, cache_type: str = "ephemeral") -> str:
```
Important:
- this filter marks the content to cache by surrounding it with `` and
- ``, so it's only useful when used within a `{% chat %}` block.
+ this filter marks the content to cache by surrounding it with `` and
+ ``, so it's only useful when used within a `{% chat %}` block.
"""
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}})
- return f"{block.model_dump_json()}"
+ return f"{block.model_dump_json()}"
diff --git a/src/banks/filters/image.py b/src/banks/filters/image.py
new file mode 100644
index 0000000..29d15d1
--- /dev/null
+++ b/src/banks/filters/image.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi
+#
+# SPDX-License-Identifier: MIT
+from pathlib import Path
+from urllib.parse import urlparse
+
+from banks.types import ContentBlock, ImageUrl
+
+
+def _is_url(string: str) -> bool:
+ result = urlparse(string)
+ return all([result.scheme, result.netloc])
+
+
+def image(value: str) -> str:
+ """Wrap the filtered value into a ContentBlock of type image.
+
+ The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects.
+
+ Example:
+ ```jinja
+ Describe what you see
+
+ {{ "path/to/image/file" | image }}
+ ```
+
+ Important:
+ this filter marks the content to cache by surrounding it with `` and
+ ``, so it's only useful when used within a `{% chat %}` block.
+ """
+ if _is_url(value):
+ image_url = ImageUrl(url=value)
+ else:
+ image_url = ImageUrl.from_path(Path(value))
+
+ block = ContentBlock.model_validate({"type": "image_url", "image_url": image_url})
+ return f"{block.model_dump_json()}"
diff --git a/src/banks/types.py b/src/banks/types.py
index fb011b9..7650034 100644
--- a/src/banks/types.py
+++ b/src/banks/types.py
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi
#
# SPDX-License-Identifier: MIT
+import base64
from enum import Enum
from inspect import Parameter, getdoc, signature
+from pathlib import Path
from typing import Callable
from pydantic import BaseModel
@@ -26,9 +28,14 @@ class ImageUrl(BaseModel):
url: str
@classmethod
- def from_base64(cls, media_type: str, base64_str: str):
+ def from_base64(cls, media_type: str, base64_str: str) -> Self:
return cls(url=f"data:{media_type};base64,{base64_str}")
+ @classmethod
+ def from_path(cls, file_path: Path) -> Self:
+ with open(file_path, "rb") as image_file:
+ return cls.from_base64("image/jpeg", base64.b64encode(image_file.read()).decode("utf-8"))
+
class ContentBlock(BaseModel):
type: ContentBlockType
diff --git a/tests/test_cache_control.py b/tests/test_cache_control.py
index 100153b..7be8b3b 100644
--- a/tests/test_cache_control.py
+++ b/tests/test_cache_control.py
@@ -3,6 +3,6 @@
def test_cache_control():
res = cache_control("foo", "ephemeral")
- res = res.replace("", "")
- res = res.replace("", "")
+ res = res.replace("", "")
+ res = res.replace("", "")
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","image_url":null}'
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 71e3be1..38a90c0 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -30,7 +30,7 @@ def test_content_block_parser_init():
def test_content_block_parser_single_with_cache_control():
p = _ContentBlockParser()
p.feed(
- '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}'
+ '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None)
@@ -39,15 +39,15 @@ def test_content_block_parser_single_with_cache_control():
def test_content_block_parser_single_no_cache_control():
p = _ContentBlockParser()
- p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}')
+ p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}')
assert p.content == "foo"
def test_content_block_parser_multiple():
p = _ContentBlockParser()
p.feed(
- '{"type":"text","cache_control":null,"text":"foo","source":null}'
- '{"type":"text","cache_control":null,"text":"bar","source":null}'
+ '{"type":"text","cache_control":null,"text":"foo","source":null}'
+ '{"type":"text","cache_control":null,"text":"bar","source":null}'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None),
diff --git a/tests/test_image.py b/tests/test_image.py
new file mode 100644
index 0000000..010c422
--- /dev/null
+++ b/tests/test_image.py
@@ -0,0 +1,76 @@
+import json
+
+import pytest
+
+from banks.filters.image import _is_url, image
+
+
+def test_is_url():
+ """Test the internal URL validation function"""
+ assert _is_url("https://example.com/image.jpg") is True
+ assert _is_url("http://example.com/image.jpg") is True
+ assert _is_url("ftp://example.com/image.jpg") is True
+ assert _is_url("not_a_url.jpg") is False
+ assert _is_url("/path/to/image.jpg") is False
+ assert _is_url("relative/path/image.jpg") is False
+ assert _is_url("") is False
+ assert _is_url("https:\\example.com/image.jpg") is False
+
+
+def test_image_with_url():
+ """Test image filter with a URL input"""
+ url = "https://example.com/image.jpg"
+ result = image(url)
+
+ # Verify the content block wrapper
+ assert result.startswith("")
+ assert result.endswith("")
+
+ # Parse the JSON content
+ json_content = result[15:-16] # Remove wrapper tags
+ content_block = json.loads(json_content)
+
+ assert content_block["type"] == "image_url"
+ assert content_block["image_url"]["url"] == url
+
+
+def test_image_with_file_path(tmp_path):
+ """Test image filter with a file path input"""
+ # Create a temporary test image file
+ test_image = tmp_path / "test_image.jpg"
+ test_content = b"fake image content"
+ test_image.write_bytes(test_content)
+
+ result = image(str(test_image))
+
+ # Verify the content block wrapper
+ assert result.startswith("")
+ assert result.endswith("")
+
+ # Parse the JSON content
+ json_content = result[15:-16] # Remove wrapper tags
+ content_block = json.loads(json_content)
+
+ assert content_block["type"] == "image_url"
+ assert content_block["image_url"]["url"].startswith("data:image/jpeg;base64,")
+
+
+def test_image_with_nonexistent_file():
+ """Test image filter with a nonexistent file path"""
+ with pytest.raises(FileNotFoundError):
+ image("nonexistent/image.jpg")
+
+
+def test_image_content_block_structure():
+ """Test the structure of the generated content block"""
+ url = "https://example.com/image.jpg"
+ result = image(url)
+
+ json_content = result[15:-16] # Remove wrapper tags
+ content_block = json.loads(json_content)
+
+ # Verify the content block has all expected fields
+ assert set(content_block.keys()) >= {"type", "image_url"}
+ assert content_block["type"] == "image_url"
+ assert isinstance(content_block["image_url"], dict)
+ assert "url" in content_block["image_url"]
diff --git a/tests/test_types.py b/tests/test_types.py
new file mode 100644
index 0000000..71a9a4e
--- /dev/null
+++ b/tests/test_types.py
@@ -0,0 +1,41 @@
+import base64
+from pathlib import Path
+
+import pytest
+
+from banks.types import ImageUrl
+
+
+def test_image_url_from_base64():
+ """Test creating ImageUrl from base64 encoded data"""
+ test_data = "Hello, World!"
+ base64_data = base64.b64encode(test_data.encode()).decode("utf-8")
+ media_type = "image/jpeg"
+
+ image_url = ImageUrl.from_base64(media_type, base64_data)
+ expected_url = f"data:{media_type};base64,{base64_data}"
+ assert image_url.url == expected_url
+
+
+def test_image_url_from_path(tmp_path):
+ """Test creating ImageUrl from a file path"""
+ # Create a temporary test image file
+ test_image = tmp_path / "test_image.jpg"
+ test_content = b"fake image content"
+ test_image.write_bytes(test_content)
+
+ image_url = ImageUrl.from_path(test_image)
+
+ # Verify the URL starts with the expected data URI prefix
+ assert image_url.url.startswith("data:image/jpeg;base64,")
+
+ # Decode the base64 part and verify the content matches
+ base64_part = image_url.url.split(",")[1]
+ decoded_content = base64.b64decode(base64_part)
+ assert decoded_content == test_content
+
+
+def test_image_url_from_path_nonexistent():
+ """Test creating ImageUrl from a nonexistent file path"""
+ with pytest.raises(FileNotFoundError):
+ ImageUrl.from_path(Path("nonexistent.jpg"))