Skip to content

Commit

Permalink
[Fix] Temporarily fix the gemini structured output error (#986)
Browse files Browse the repository at this point in the history
Co-authored-by: Maximilian Schulz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 7, 2025
1 parent 07a12da commit 5e30406
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

### Fixed

- Added a temporary compatibility fix for Gemini models in `_pydantic_output.py`. ([#986](https://github.com/mckinsey/vizro/pull/986))

<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
18 changes: 17 additions & 1 deletion vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@
"# max_retries=2,\n",
"# endpoint= os.environ.get(\"MISTRAL_BASE_URL\"),\n",
"# mistral_api_key = os.environ.get(\"MISTRAL_API_KEY\")\n",
"# )\n",
"\n",
"# import os\n",
"# from langchain_google_genai import ChatGoogleGenerativeAI\n",
"# llm = ChatGoogleGenerativeAI(\n",
"# model=\"gemini-1.5-flash-latest\",\n",
"# # model=\"gemini-1.5-pro-latest\",\n",
"# google_api_key=os.environ.get(\"GOOGLE_API_KEY\"),\n",
"# temperature=0,\n",
"# )"
]
},
Expand Down Expand Up @@ -109,6 +118,13 @@
"source": [
"vizro_ai.plot(df, \"show me the geo distribution of life expectancy and set year as animation \")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -127,7 +143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions vizro-ai/hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"langchain_mistralai",
"langchain-anthropic",
"langchain-aws",
"langchain-google-genai",
"pre-commit"
]
installer = "uv"
Expand Down
34 changes: 30 additions & 4 deletions vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ def _create_message_content(
return message_content


def _handle_google_llm_response(
    llm_model: BaseChatModel, response_model: BaseModel, prompt: ChatPromptTemplate, message_content: dict
) -> BaseModel:
    """Handle the LLM response specifically for Google models.

    Instead of passing the pydantic class directly to ``with_structured_output``,
    the class is first converted to an OpenAI-style function schema, and the raw
    result is then parsed back into ``response_model``.

    Args:
        llm_model: Chat model to query.
        response_model: Pydantic model class describing the expected output
            (it is used as a class here: ``parse_obj`` is called on it).
        prompt: Prompt template piped into the structured-output model.
        message_content: Input passed to ``invoke`` on the composed chain.

    Returns:
        The result of ``response_model.parse_obj`` applied to the (unwrapped) LLM output.

    Raises:
        ValidationError: Presumably raised by ``parse_obj`` when the output does not
            match ``response_model`` (including an empty list response) - the caller's
            retry loop relies on this exception type.

    """
    from langchain_core.utils.function_calling import convert_to_openai_function

    schema = convert_to_openai_function(response_model)

    pydantic_llm = prompt | llm_model.with_structured_output(schema)
    res = pydantic_llm.invoke(message_content)

    # Handle case where response is a list, which is the case for Gemini models.
    # An empty list is deliberately NOT indexed: it is left as-is so that `parse_obj`
    # reports the problem as a validation failure instead of an uncaught IndexError.
    if isinstance(res, list) and res:
        first_item = res[0]
        if isinstance(first_item, dict):
            # Tool-call style entries carry the payload under "args"; otherwise the
            # dict itself is taken as the payload.
            res = first_item.get("args", first_item)
        else:
            res = first_item

    return response_model.parse_obj(res)


def _get_pydantic_model(
query: str,
llm_model: BaseChatModel,
Expand All @@ -79,13 +97,21 @@ def _get_pydantic_model(
message_content = _create_message_content(
query, df_info, str(last_validation_error) if attempt_is_retry else None, retry=attempt_is_retry
)
pydantic_llm = prompt | llm_model.with_structured_output(response_model)

try:
res = pydantic_llm.invoke(message_content)
# Apply the fix for nested structures, following langchain-google-genai implementation
# referred to https://github.com/langchain-ai/langchain/issues/24225
# and https://github.com/langchain-ai/langchain-google/pull/658/files
# TODO: revisit this temporary fix once pydantic v2 is implemented in vizro-ai
if "google" in llm_model.__class__.__module__.lower():
return _handle_google_llm_response(llm_model, response_model, prompt, message_content)

# For other models, use standard structured output
pydantic_llm = prompt | llm_model.with_structured_output(response_model)
return pydantic_llm.invoke(message_content)

except ValidationError as validation_error:
last_validation_error = validation_error
else:
return res # TODO: problem is response is None, then it returns without raising an error. Wrong typing!
# TODO: should this be shifted to logging so that that one can control what output gets shown (e.g. in public demos)
raise last_validation_error

Expand Down

0 comments on commit 5e30406

Please sign in to comment.