diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 0599b8b16..bca62101e 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -121,7 +121,7 @@ def clear_cache(filename: str = None): cache.clear() -def chat(query: str, *dataframes: List[DataFrame]): +def chat(query: str, *dataframes: DataFrame): """ Start a new chat interaction with the assistant on Dataframe(s). diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py index 5350810c9..04a18d148 100644 --- a/pandasai/agent/base.py +++ b/pandasai/agent/base.py @@ -2,6 +2,9 @@ import warnings from typing import Any, List, Optional, Union +import duckdb +import pandas as pd + from pandasai.core.cache import Cache from pandasai.core.code_execution.code_executor import CodeExecutor from pandasai.core.code_generation.base import CodeGenerator @@ -23,6 +26,7 @@ from pandasai.vectorstores.vectorstore import VectorStore from ..config import Config +from ..constants import LOCAL_SOURCE_TYPES from .state import AgentState @@ -102,12 +106,43 @@ def execute_code(self, code: str) -> dict: """Execute the generated code.""" self._state.logger.log(f"Executing code: {code}") code_executor = CodeExecutor(self._state.config) - code_executor.add_to_env( - "execute_sql_query", self._state.dfs[0].execute_sql_query - ) + code_executor.add_to_env("execute_sql_query", self.execute_sql_query) return code_executor.execute_and_return_result(code) + def _execute_local_sql_query(self, query: str) -> pd.DataFrame: + try: + # Use a context manager to ensure the connection is closed + with duckdb.connect() as con: + # Register all DataFrames in the state + for df in self._state.dfs: + con.register(df.name, df) + + # Execute the query and fetch the result as a pandas DataFrame + result = con.sql(query).df() + + return result + except duckdb.Error as e: + raise RuntimeError(f"SQL execution failed: {e}") from e + + def execute_sql_query(self, query: str) -> pd.DataFrame: + """ + Executes an SQL query on registered DataFrames. + + Args: + query (str): The SQL query to execute. + + Returns: + pd.DataFrame: The result of the SQL query as a pandas DataFrame. + """ + if not self._state.dfs: + raise ValueError("No DataFrames available to register for query execution.") + + if self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES: + return self._execute_local_sql_query(query) + else: + return self._state.dfs[0].execute_sql_query(query) + def execute_with_retries(self, code: str) -> Any: """Execute the code with retry logic.""" max_retries = self._state.config.max_retries diff --git a/pandasai/agent/state.py b/pandasai/agent/state.py index b653c01c4..58049db30 100644 --- a/pandasai/agent/state.py +++ b/pandasai/agent/state.py @@ -9,7 +9,7 @@ from pandasai.config import Config, ConfigManager from pandasai.constants import DEFAULT_CACHE_DIRECTORY, DEFAULT_CHART_DIRECTORY from pandasai.core.cache import Cache -from pandasai.data_loader.schema_validator import is_schema_source_same +from pandasai.data_loader.semantic_layer_schema import is_schema_source_same from pandasai.exceptions import InvalidConfigError from pandasai.helpers.folder import Folder from pandasai.helpers.logger import Logger diff --git a/pandasai/data_loader/schema_validator.py b/pandasai/data_loader/schema_validator.py deleted file mode 100644 index 9cb3ac2f9..000000000 --- a/pandasai/data_loader/schema_validator.py +++ /dev/null @@ -1,9 +0,0 @@ -import json - - -def is_schema_source_same(schema1: dict, schema2: dict) -> bool: - return schema1.get("source").get("type") == schema2.get("source").get( - "type" - ) and json.dumps( - schema1.get("source").get("connection"), sort_keys=True - ) == json.dumps(schema2.get("source").get("connection"), sort_keys=True) diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py index 3492ed371..345225419 100644 --- a/pandasai/dataframe/base.py +++ b/pandasai/dataframe/base.py @@ -229,13 +229,6 @@ def pull(self): print(f"Dataset pulled successfully from path: {self.path}") - def execute_sql_query(self, query: str) -> pd.DataFrame: - import duckdb - - db = duckdb.connect(":memory:") - db.register(self.name, self) - return db.query(query).df() - @staticmethod def get_column_type(column_dtype) -> Optional[str]: """ diff --git a/tests/unit_tests/helpers/test_session.py b/tests/unit_tests/helpers/test_session.py index ba99f9420..2baeae5c0 100644 --- a/tests/unit_tests/helpers/test_session.py +++ b/tests/unit_tests/helpers/test_session.py @@ -8,6 +8,7 @@ from pandasai.helpers.session import Session, get_pandaai_session +@patch("pandasai.os.environ", {}) def test_session_init_without_api_key(): """Test that Session initialization raises PandaAIApiKeyError when no API key is provided""" with pytest.raises(PandaAIApiKeyError) as exc_info: @@ -18,6 +19,7 @@ def test_session_init_without_api_key(): ) +@patch("pandasai.os.environ", {}) def test_session_init_with_none_api_key(): """Test that Session initialization raises PandaAIApiKeyError when API key is None""" with pytest.raises(PandaAIApiKeyError) as exc_info: @@ -28,18 +30,21 @@ def test_session_init_with_none_api_key(): ) +@patch("pandasai.os.environ", {}) def test_session_init_with_api_key(): """Test that Session initialization works with a valid API key""" session = Session(api_key="test-key") assert session._api_key == "test-key" +@patch("pandasai.os.environ", {}) def test_session_init_with_default_api_url(): """Test that Session initialization uses DEFAULT_API_URL when no URL is provided""" session = Session(api_key="test-key") assert session._endpoint_url == DEFAULT_API_URL +@patch("pandasai.os.environ", {}) def test_session_init_with_custom_api_url(): """Test that Session initialization uses provided URL""" custom_url = "https://custom.api.url" @@ -64,6 +69,7 @@ def test_session_init_with_env_api_url(): assert session._endpoint_url == "https://env.api.url" +@patch("pandasai.os.environ", {}) def test_get_pandaai_session_without_credentials(): """Test that get_pandaai_session raises PandaAIApiKeyError when no credentials are provided""" with pytest.raises(PandaAIApiKeyError) as exc_info: diff --git a/tests/unit_tests/test_pandasai_init.py b/tests/unit_tests/test_pandasai_init.py index 9854eb077..5ab1128b8 100644 --- a/tests/unit_tests/test_pandasai_init.py +++ b/tests/unit_tests/test_pandasai_init.py @@ -190,7 +190,10 @@ def test_load_successful_zip_extraction( mock_zip_file.return_value.__enter__.return_value.extractall.assert_called_once() assert isinstance(result, MagicMock) - def test_load_without_api_credentials(self): + @patch("pandasai.os.environ", {}) + def test_load_without_api_credentials( + self, + ): """Test that load raises PandaAIApiKeyError when no API credentials are provided""" with pytest.raises(PandaAIApiKeyError) as exc_info: pandasai.load("test/dataset")