Skip to content

Commit

Permalink
more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Oct 5, 2024
1 parent 5e00876 commit 3fd2a77
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 7 deletions.
14 changes: 10 additions & 4 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,28 @@ def chat(role: str): # pylint: disable=W0613


class _ContentBlockParser(HTMLParser):
"""A parser used to extract text surrounded by `<content_block_txt>` and `</content_block_txt>` tags."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._parse_block_content = False
self._content_blocks: list[ContentBlock] = []

@property
def content(self) -> ChatMessageContent:
"""Returns ChatMessageContent data that can be directly assigned to ChatMessage.content.
If only one block is present, this block is of type text and has no cache control set, we just
return it as plain text for simplicity.
"""
if len(self._content_blocks) == 1:
block = self._content_blocks[0]
if type(block) is ContentBlock:
if block.type == "text" and block.cache_control is None:
return block.text or ""
if block.type == "text" and block.cache_control is None:
return block.text or ""

return self._content_blocks

def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag, _):
if tag == "content_block_txt":
self._parse_block_content = True

Expand Down
8 changes: 6 additions & 2 deletions src/banks/filters/cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@


def cache_control(value: str, cache_type: str = "ephemeral") -> str:
"""
"""Wrap the filtered value into a ContentBlock with the proper cache_control field set.
The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects.
Example:
```
{{ "This is a long, long text" | cache_control "ephemeral" }}
{{ "This is a long, long text" | cache_control("ephemeral") }}
This is short and won't be cached.
```
Important: this filter marks the content to cache by surrounding it with `<content_block_txt>` and
`</content_block_txt>`, so it's only useful when used within a `{% chat %}` block.
"""
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}})
return f"<content_block_txt>{block.model_dump_json()}</content_block_txt>"
2 changes: 1 addition & 1 deletion src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from simplemma import text_lemmatizer # type: ignore

SIMPLEMMA_AVAIL = True
except ImportError:
except ImportError: # pragma: no cover
SIMPLEMMA_AVAIL = False


Expand Down
8 changes: 8 additions & 0 deletions tests/test_cache_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from banks.filters.cache_control import cache_control


def test_cache_control():
res = cache_control("foo", "ephemeral")
res = res.replace("<content_block_txt>", "")
res = res.replace("</content_block_txt>", "")
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}'
42 changes: 42 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from jinja2 import TemplateSyntaxError

from banks import Prompt
from banks.extensions.chat import _ContentBlockParser
from banks.types import CacheControl, ContentBlock, ContentBlockType


def test_wrong_tag():
Expand All @@ -17,3 +19,43 @@ def test_wrong_tag_params():
def test_wrong_role_type():
with pytest.raises(TemplateSyntaxError):
Prompt('{% chat role="does not exist" %}{% endchat %}')


def test_content_block_parser_init():
p = _ContentBlockParser()
assert p._parse_block_content is False
assert p._content_blocks == []


def test_content_block_parser_single_with_cache_control():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block_txt>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None)
]


def test_content_block_parser_single_no_cache_control():
p = _ContentBlockParser()
p.feed('<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>')
assert p.content == "foo"


def test_content_block_parser_multiple():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>'
'<content_block_txt>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block_txt>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None),
ContentBlock(type=ContentBlockType.text, cache_control=None, text="bar", source=None),
]


def test_content_block_parser_other_tags():
p = _ContentBlockParser()
p.feed("<some_tag>FOO</some_tag>")
assert p.content == "FOO"

0 comments on commit 3fd2a77

Please sign in to comment.