diff --git a/pyproject.toml b/pyproject.toml index 832b958..95fd230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,8 @@ ignore = [ "ISC001", # Magic numbers "PLR2004", + # __all__ sorted + "RUF022", ] unfixable = [ # Don't touch unused imports diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index d12b0a6..c39d333 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -89,8 +89,9 @@ def _do_completion(self, model_name, caller): Helper callback. """ messages, tools = self._body_to_messages(caller()) + messages_as_dict = [m.model_dump() for m in messages] - response = cast(ModelResponse, completion(model=model_name, messages=messages, tools=tools)) + response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict, tools=tools or None)) choices = cast(list[Choices], response.choices) tool_calls = choices[0].message.tool_calls if not tool_calls: @@ -112,7 +113,8 @@ def _do_completion(self, model_name, caller): ) ) - response = cast(ModelResponse, completion(model=model_name, messages=messages)) + messages_as_dict = [m.model_dump() for m in messages] + response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict)) choices = cast(list[Choices], response.choices) return choices[0].message.content diff --git a/tests/test_completion.py b/tests/test_completion.py index 4bb7816..ea5b885 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -115,7 +115,7 @@ def test__do_completion_no_tools(ext, mocked_choices_no_tools): mocked_completion.return_value.choices = mocked_choices_no_tools ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') mocked_completion.assert_called_with( - model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + model="test-model", messages=[ChatMessage(role="user", content="hello").model_dump()], tools=None ) @@ -131,13 +131,18 @@ async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools): def test__do_completion_with_tools(ext, mocked_choices_with_tools): ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") - ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + ext._body_to_messages = mock.MagicMock( + return_value=( + [ChatMessage(role="user", content="message1"), ChatMessage(role="user", content="message2")], + [mock.MagicMock(), mock.MagicMock()], + ) + ) with mock.patch("banks.extensions.completion.completion") as mocked_completion: mocked_completion.return_value.choices = mocked_choices_with_tools ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') calls = mocked_completion.call_args_list assert len(calls) == 2 # complete query, complete with tool results - assert calls[0].kwargs["tools"] == ["tool1", "tool2"] + assert len(calls[0].kwargs["tools"]) == 2 assert "tools" not in calls[1].kwargs for m in calls[1].kwargs["messages"]: if type(m) is ChatMessage: