Skip to content

Commit

Permalink
fix(test): fixing tests for when PANDABI_API_KEY is present in .env (#…
Browse files Browse the repository at this point in the history
…1527)

* fix(test): fixing tests for when PANDABI_API_KEY is present in .env

* fix(Agent): fixing chatting with multiple local dataframes
  • Loading branch information
scaliseraoul authored Jan 17, 2025
1 parent c5a2a13 commit 2d54457
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def clear_cache(filename: str = None):
cache.clear()


def chat(query: str, *dataframes: List[DataFrame]):
def chat(query: str, *dataframes: DataFrame):
"""
Start a new chat interaction with the assistant on Dataframe(s).
Expand Down
41 changes: 38 additions & 3 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import warnings
from typing import Any, List, Optional, Union

import duckdb
import pandas as pd

from pandasai.core.cache import Cache
from pandasai.core.code_execution.code_executor import CodeExecutor
from pandasai.core.code_generation.base import CodeGenerator
Expand All @@ -23,6 +26,7 @@
from pandasai.vectorstores.vectorstore import VectorStore

from ..config import Config
from ..constants import LOCAL_SOURCE_TYPES
from .state import AgentState


Expand Down Expand Up @@ -102,12 +106,43 @@ def execute_code(self, code: str) -> dict:
"""Execute the generated code."""
self._state.logger.log(f"Executing code: {code}")
code_executor = CodeExecutor(self._state.config)
code_executor.add_to_env(
"execute_sql_query", self._state.dfs[0].execute_sql_query
)
code_executor.add_to_env("execute_sql_query", self.execute_sql_query)

return code_executor.execute_and_return_result(code)

def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
try:
# Use a context manager to ensure the connection is closed
with duckdb.connect() as con:
# Register all DataFrames in the state
for df in self._state.dfs:
con.register(df.name, df)

# Execute the query and fetch the result as a pandas DataFrame
result = con.sql(query).df()

return result
except duckdb.Error as e:
raise RuntimeError(f"SQL execution failed: {e}") from e

def execute_sql_query(self, query: str) -> pd.DataFrame:
"""
Executes an SQL query on registered DataFrames.
Args:
query (str): The SQL query to execute.
Returns:
pd.DataFrame: The result of the SQL query as a pandas DataFrame.
"""
if not self._state.dfs:
raise ValueError("No DataFrames available to register for query execution.")

if self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES:
return self._execute_local_sql_query(query)
else:
return self._state.dfs[0].execute_sql_query(query)

def execute_with_retries(self, code: str) -> Any:
"""Execute the code with retry logic."""
max_retries = self._state.config.max_retries
Expand Down
2 changes: 1 addition & 1 deletion pandasai/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pandasai.config import Config, ConfigManager
from pandasai.constants import DEFAULT_CACHE_DIRECTORY, DEFAULT_CHART_DIRECTORY
from pandasai.core.cache import Cache
from pandasai.data_loader.schema_validator import is_schema_source_same
from pandasai.data_loader.semantic_layer_schema import is_schema_source_same
from pandasai.exceptions import InvalidConfigError
from pandasai.helpers.folder import Folder
from pandasai.helpers.logger import Logger
Expand Down
9 changes: 0 additions & 9 deletions pandasai/data_loader/schema_validator.py

This file was deleted.

7 changes: 0 additions & 7 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,6 @@ def pull(self):

print(f"Dataset pulled successfully from path: {self.path}")

def execute_sql_query(self, query: str) -> pd.DataFrame:
import duckdb

db = duckdb.connect(":memory:")
db.register(self.name, self)
return db.query(query).df()

@staticmethod
def get_column_type(column_dtype) -> Optional[str]:
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/helpers/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandasai.helpers.session import Session, get_pandaai_session


@patch("pandasai.os.environ", {})
def test_session_init_without_api_key():
"""Test that Session initialization raises PandaAIApiKeyError when no API key is provided"""
with pytest.raises(PandaAIApiKeyError) as exc_info:
Expand All @@ -18,6 +19,7 @@ def test_session_init_without_api_key():
)


@patch("pandasai.os.environ", {})
def test_session_init_with_none_api_key():
"""Test that Session initialization raises PandaAIApiKeyError when API key is None"""
with pytest.raises(PandaAIApiKeyError) as exc_info:
Expand All @@ -28,18 +30,21 @@ def test_session_init_with_none_api_key():
)


@patch("pandasai.os.environ", {})
def test_session_init_with_api_key():
"""Test that Session initialization works with a valid API key"""
session = Session(api_key="test-key")
assert session._api_key == "test-key"


@patch("pandasai.os.environ", {})
def test_session_init_with_default_api_url():
"""Test that Session initialization uses DEFAULT_API_URL when no URL is provided"""
session = Session(api_key="test-key")
assert session._endpoint_url == DEFAULT_API_URL


@patch("pandasai.os.environ", {})
def test_session_init_with_custom_api_url():
"""Test that Session initialization uses provided URL"""
custom_url = "https://custom.api.url"
Expand All @@ -64,6 +69,7 @@ def test_session_init_with_env_api_url():
assert session._endpoint_url == "https://env.api.url"


@patch("pandasai.os.environ", {})
def test_get_pandaai_session_without_credentials():
"""Test that get_pandaai_session raises PandaAIApiKeyError when no credentials are provided"""
with pytest.raises(PandaAIApiKeyError) as exc_info:
Expand Down
5 changes: 4 additions & 1 deletion tests/unit_tests/test_pandasai_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def test_load_successful_zip_extraction(
mock_zip_file.return_value.__enter__.return_value.extractall.assert_called_once()
assert isinstance(result, MagicMock)

def test_load_without_api_credentials(self):
@patch("pandasai.os.environ", {})
def test_load_without_api_credentials(
self,
):
"""Test that load raises PandaAIApiKeyError when no API credentials are provided"""
with pytest.raises(PandaAIApiKeyError) as exc_info:
pandasai.load("test/dataset")
Expand Down

0 comments on commit 2d54457

Please sign in to comment.