From 3fd2a7705641aa362d2b008f344d494c0d3a5f89 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sat, 5 Oct 2024 21:09:27 +0200 Subject: [PATCH] more unit tests --- src/banks/extensions/chat.py | 14 +++++++--- src/banks/filters/cache_control.py | 8 ++++-- src/banks/filters/lemmatize.py | 2 +- tests/test_cache_control.py | 8 ++++++ tests/test_chat.py | 42 ++++++++++++++++++++++++++++++ 5 files changed, 67 insertions(+), 7 deletions(-) create mode 100644 tests/test_cache_control.py diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index 32a0fae..c530d68 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -31,6 +31,8 @@ def chat(role: str): # pylint: disable=W0613 class _ContentBlockParser(HTMLParser): + """A parser used to extract text surrounded by `<content_block_txt>` and `</content_block_txt>` tags.""" + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._parse_block_content = False @@ -38,15 +40,19 @@ def __init__(self, *args, **kwargs) -> None: @property def content(self) -> ChatMessageContent: + """Returns ChatMessageContent data that can be directly assigned to ChatMessage.content. + + If only one block is present, this block is of type text and has no cache control set, we just + return it as plain text for simplicity. 
+ """ if len(self._content_blocks) == 1: block = self._content_blocks[0] - if type(block) is ContentBlock: - if block.type == "text" and block.cache_control is None: - return block.text or "" + if block.type == "text" and block.cache_control is None: + return block.text or "" return self._content_blocks - def handle_starttag(self, tag, attrs): + def handle_starttag(self, tag, _): if tag == "content_block_txt": self._parse_block_content = True diff --git a/src/banks/filters/cache_control.py b/src/banks/filters/cache_control.py index 16c148a..b5d2fb9 100644 --- a/src/banks/filters/cache_control.py +++ b/src/banks/filters/cache_control.py @@ -5,15 +5,19 @@ def cache_control(value: str, cache_type: str = "ephemeral") -> str: - """ + """Wrap the filtered value into a ContentBlock with the proper cache_control field set. + + The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects. Example: ``` - {{ "This is a long, long text" | cache_control "ephemeral" }} + {{ "This is a long, long text" | cache_control("ephemeral") }} This is short and won't be cached. ``` + Important: this filter marks the content to cache by surrounding it with `<content_block_txt>` and + `</content_block_txt>`, so it's only useful when used within a `{% chat %}` block. 
""" block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}}) return f"{block.model_dump_json()}" diff --git a/src/banks/filters/lemmatize.py b/src/banks/filters/lemmatize.py index 185b86e..44e0e5c 100644 --- a/src/banks/filters/lemmatize.py +++ b/src/banks/filters/lemmatize.py @@ -7,7 +7,7 @@ from simplemma import text_lemmatizer # type: ignore SIMPLEMMA_AVAIL = True -except ImportError: +except ImportError: # pragma: no cover SIMPLEMMA_AVAIL = False diff --git a/tests/test_cache_control.py b/tests/test_cache_control.py new file mode 100644 index 0000000..0a6e378 --- /dev/null +++ b/tests/test_cache_control.py @@ -0,0 +1,8 @@ +from banks.filters.cache_control import cache_control + + +def test_cache_control(): + res = cache_control("foo", "ephemeral") + res = res.replace("", "") + res = res.replace("", "") + assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' diff --git a/tests/test_chat.py b/tests/test_chat.py index 2941d75..71e3be1 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,6 +2,8 @@ from jinja2 import TemplateSyntaxError from banks import Prompt +from banks.extensions.chat import _ContentBlockParser +from banks.types import CacheControl, ContentBlock, ContentBlockType def test_wrong_tag(): @@ -17,3 +19,43 @@ def test_wrong_tag_params(): def test_wrong_role_type(): with pytest.raises(TemplateSyntaxError): Prompt('{% chat role="does not exist" %}{% endchat %}') + + +def test_content_block_parser_init(): + p = _ContentBlockParser() + assert p._parse_block_content is False + assert p._content_blocks == [] + + +def test_content_block_parser_single_with_cache_control(): + p = _ContentBlockParser() + p.feed( + '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' + ) + assert p.content == [ + ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None) + ] + + +def 
test_content_block_parser_single_no_cache_control(): + p = _ContentBlockParser() + p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}') + assert p.content == "foo" + + +def test_content_block_parser_multiple(): + p = _ContentBlockParser() + p.feed( + '{"type":"text","cache_control":null,"text":"foo","source":null}' + '{"type":"text","cache_control":null,"text":"bar","source":null}' + ) + assert p.content == [ + ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None), + ContentBlock(type=ContentBlockType.text, cache_control=None, text="bar", source=None), + ] + + +def test_content_block_parser_other_tags(): + p = _ContentBlockParser() + p.feed("FOO") + assert p.content == "FOO"