Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add image tag filter to support images within a prompt for vision models #22

Merged
merged 5 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ provided by Jinja, Banks supports the following ones, specific for prompt engine
show_signature_annotations: false
heading_level: 3

::: banks.filters.image.image
options:
show_root_full_path: false
show_symbol_type_heading: false
show_signature_annotations: false
heading_level: 3

::: banks.filters.lemmatize.lemmatize
options:
show_root_full_path: false
Expand Down
6 changes: 4 additions & 2 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape

from .config import config
from .filters import cache_control, lemmatize, tool
from .filters import cache_control, image, lemmatize, tool


def _add_extensions(_env):
Expand Down Expand Up @@ -38,7 +38,9 @@ def _add_extensions(_env):


# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
env.filters["cache_control"] = cache_control
env.filters["image"] = image
env.filters["lemmatize"] = lemmatize
env.filters["tool"] = tool

_add_extensions(env)
6 changes: 3 additions & 3 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def content(self) -> ChatMessageContent:

return self._content_blocks

def handle_starttag(self, tag, _):
if tag == "content_block_txt":
def handle_starttag(self, tag, attrs): # noqa
if tag == "content_block":
self._parse_block_content = True

def handle_endtag(self, tag):
if tag == "content_block_txt":
if tag == "content_block":
self._parse_block_content = False

def handle_data(self, data):
Expand Down
2 changes: 2 additions & 0 deletions src/banks/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: MIT
from .cache_control import cache_control
from .image import image
from .lemmatize import lemmatize
from .tool import tool

__all__ = (
"cache_control",
"image",
"lemmatize",
"tool",
)
6 changes: 3 additions & 3 deletions src/banks/filters/cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def cache_control(value: str, cache_type: str = "ephemeral") -> str:
```

Important:
this filter marks the content to cache by surrounding it with `<content_block_txt>` and
`</content_block_txt>`, so it's only useful when used within a `{% chat %}` block.
this filter marks the content to cache by surrounding it with `<content_block>` and
`</content_block>`, so it's only useful when used within a `{% chat %}` block.
"""
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}})
return f"<content_block_txt>{block.model_dump_json()}</content_block_txt>"
return f"<content_block>{block.model_dump_json()}</content_block>"
37 changes: 37 additions & 0 deletions src/banks/filters/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from pathlib import Path
from urllib.parse import urlparse

from banks.types import ContentBlock, ImageUrl


def _is_url(string: str) -> bool:
result = urlparse(string)
return all([result.scheme, result.netloc])


def image(value: str) -> str:
"""Wrap the filtered value into a ContentBlock of type image.

The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects.

Example:
```jinja
Describe what you see

{{ "path/to/image/file" | image }}
```

Important:
this filter marks the content to cache by surrounding it with `<content_block>` and
`</content_block>`, so it's only useful when used within a `{% chat %}` block.
"""
if _is_url(value):
image_url = ImageUrl(url=value)
else:
image_url = ImageUrl.from_path(Path(value))

block = ContentBlock.model_validate({"type": "image_url", "image_url": image_url})
return f"<content_block>{block.model_dump_json()}</content_block>"
9 changes: 8 additions & 1 deletion src/banks/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import base64
from enum import Enum
from inspect import Parameter, getdoc, signature
from pathlib import Path
from typing import Callable

from pydantic import BaseModel
Expand All @@ -26,9 +28,14 @@ class ImageUrl(BaseModel):
url: str

@classmethod
def from_base64(cls, media_type: str, base64_str: str):
def from_base64(cls, media_type: str, base64_str: str) -> Self:
return cls(url=f"data:{media_type};base64,{base64_str}")

@classmethod
def from_path(cls, file_path: Path) -> Self:
with open(file_path, "rb") as image_file:
return cls.from_base64("image/jpeg", base64.b64encode(image_file.read()).decode("utf-8"))


class ContentBlock(BaseModel):
type: ContentBlockType
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

def test_cache_control():
res = cache_control("foo", "ephemeral")
res = res.replace("<content_block_txt>", "")
res = res.replace("</content_block_txt>", "")
res = res.replace("<content_block>", "")
res = res.replace("</content_block>", "")
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","image_url":null}'
8 changes: 4 additions & 4 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_content_block_parser_init():
def test_content_block_parser_single_with_cache_control():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block_txt>'
'<content_block>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None)
Expand All @@ -39,15 +39,15 @@ def test_content_block_parser_single_with_cache_control():

def test_content_block_parser_single_no_cache_control():
p = _ContentBlockParser()
p.feed('<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>')
p.feed('<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>')
assert p.content == "foo"


def test_content_block_parser_multiple():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>'
'<content_block_txt>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block_txt>'
'<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>'
'<content_block>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None),
Expand Down
76 changes: 76 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import json

import pytest

from banks.filters.image import _is_url, image


def test_is_url():
"""Test the internal URL validation function"""
assert _is_url("https://example.com/image.jpg") is True
assert _is_url("http://example.com/image.jpg") is True
assert _is_url("ftp://example.com/image.jpg") is True
assert _is_url("not_a_url.jpg") is False
assert _is_url("/path/to/image.jpg") is False
assert _is_url("relative/path/image.jpg") is False
assert _is_url("") is False
assert _is_url("https:\\example.com/image.jpg") is False


def test_image_with_url():
"""Test image filter with a URL input"""
url = "https://example.com/image.jpg"
result = image(url)

# Verify the content block wrapper
assert result.startswith("<content_block>")
assert result.endswith("</content_block>")

# Parse the JSON content
json_content = result[15:-16] # Remove wrapper tags
content_block = json.loads(json_content)

assert content_block["type"] == "image_url"
assert content_block["image_url"]["url"] == url


def test_image_with_file_path(tmp_path):
"""Test image filter with a file path input"""
# Create a temporary test image file
test_image = tmp_path / "test_image.jpg"
test_content = b"fake image content"
test_image.write_bytes(test_content)

result = image(str(test_image))

# Verify the content block wrapper
assert result.startswith("<content_block>")
assert result.endswith("</content_block>")

# Parse the JSON content
json_content = result[15:-16] # Remove wrapper tags
content_block = json.loads(json_content)

assert content_block["type"] == "image_url"
assert content_block["image_url"]["url"].startswith("data:image/jpeg;base64,")


def test_image_with_nonexistent_file():
"""Test image filter with a nonexistent file path"""
with pytest.raises(FileNotFoundError):
image("nonexistent/image.jpg")


def test_image_content_block_structure():
"""Test the structure of the generated content block"""
url = "https://example.com/image.jpg"
result = image(url)

json_content = result[15:-16] # Remove wrapper tags
content_block = json.loads(json_content)

# Verify the content block has all expected fields
assert set(content_block.keys()) >= {"type", "image_url"}
assert content_block["type"] == "image_url"
assert isinstance(content_block["image_url"], dict)
assert "url" in content_block["image_url"]
41 changes: 41 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import base64
from pathlib import Path

import pytest

from banks.types import ImageUrl


def test_image_url_from_base64():
"""Test creating ImageUrl from base64 encoded data"""
test_data = "Hello, World!"
base64_data = base64.b64encode(test_data.encode()).decode("utf-8")
media_type = "image/jpeg"

image_url = ImageUrl.from_base64(media_type, base64_data)
expected_url = f"data:{media_type};base64,{base64_data}"
assert image_url.url == expected_url


def test_image_url_from_path(tmp_path):
"""Test creating ImageUrl from a file path"""
# Create a temporary test image file
test_image = tmp_path / "test_image.jpg"
test_content = b"fake image content"
test_image.write_bytes(test_content)

image_url = ImageUrl.from_path(test_image)

# Verify the URL starts with the expected data URI prefix
assert image_url.url.startswith("data:image/jpeg;base64,")

# Decode the base64 part and verify the content matches
base64_part = image_url.url.split(",")[1]
decoded_content = base64.b64decode(base64_part)
assert decoded_content == test_content


def test_image_url_from_path_nonexistent():
"""Test creating ImageUrl from a nonexistent file path"""
with pytest.raises(FileNotFoundError):
ImageUrl.from_path(Path("nonexistent.jpg"))