Skip to content

Commit

Permalink
test: add more tests in the agent (#1572)
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri authored Jan 31, 2025
1 parent 4ca228f commit d2350a1
Show file tree
Hide file tree
Showing 5 changed files with 468 additions and 3 deletions.
75 changes: 75 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pandasai.helpers.memory import Memory


def test_to_json_empty_memory():
    """A Memory with no messages added serializes to an empty list."""
    serialized = Memory().to_json()
    assert serialized == []


def test_to_json_with_messages():
    """to_json() yields one role/message dict for every added message."""
    # (text, is_user) pairs fed into the memory, in order
    conversation = [
        ("Hello", True),
        ("Hi there!", False),
        ("How are you?", True),
    ]

    memory = Memory()
    for text, from_user in conversation:
        memory.add(text, is_user=from_user)

    assert memory.to_json() == [
        {"role": "user", "message": "Hello"},
        {"role": "assistant", "message": "Hi there!"},
        {"role": "user", "message": "How are you?"},
    ]


def test_to_json_message_order():
    """to_json() returns messages in the order they were added."""
    memory = Memory()
    memory.add("Message 1", is_user=True)
    memory.add("Message 2", is_user=False)
    memory.add("Message 3", is_user=True)

    serialized = memory.to_json()

    # Same count, and the texts appear in insertion order
    assert len(serialized) == 3
    texts = [entry["message"] for entry in serialized]
    assert texts == ["Message 1", "Message 2", "Message 3"]


def test_to_openai_messages_empty():
    """A Memory with no messages and no description yields no OpenAI messages."""
    openai_messages = Memory().to_openai_messages()
    assert openai_messages == []


def test_to_openai_messages_with_agent_description():
    """The agent description is emitted as a leading system message."""
    description = "I am a helpful assistant"
    memory = Memory(agent_description=description)
    for text, from_user in (("Hello", True), ("Hi there!", False)):
        memory.add(text, is_user=from_user)

    assert memory.to_openai_messages() == [
        {"role": "system", "content": "I am a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]


def test_to_openai_messages_without_agent_description():
    """Without an agent description only user/assistant messages are emitted."""
    conversation = (
        ("Hello", True),
        ("Hi there!", False),
        ("How are you?", True),
    )

    memory = Memory()
    for text, from_user in conversation:
        memory.add(text, is_user=from_user)

    assert memory.to_openai_messages() == [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
109 changes: 107 additions & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
from typing import Optional
from unittest.mock import MagicMock, Mock, mock_open, patch
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch

import pandas as pd
import pytest

from pandasai import DatasetLoader, VirtualDataFrame
from pandasai.agent.base import Agent
from pandasai.config import Config, ConfigManager
from pandasai.core.response.error import ErrorResponse
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):

with pytest.raises(ValueError, match="No DataFrames available"):
agent._execute_sql_query(query)

def test_process_query(self, agent, config):
    """_process_query returns the execution result and stores it in the cache."""
    generated_code = "result = df['age'].mean()"

    # Stub code generation and execution, and enable a mock cache
    agent.generate_code = Mock(return_value=generated_code)
    agent.execute_with_retries = Mock(return_value=30.5)
    agent._state.config.enable_cache = True
    agent._state.cache = Mock()

    outcome = agent._process_query("What is the average age?", "number")

    # The executed value is returned unchanged
    assert outcome == 30.5

    # Each collaborator was used exactly once
    agent.generate_code.assert_called_once()
    agent.execute_with_retries.assert_called_once_with(generated_code)
    agent._state.cache.set.assert_called_once()

def test_process_query_execution_error(self, agent, config):
    """_process_query delegates to _handle_exception when execution fails."""
    failing_code = "invalid_code"

    # Generated code whose execution raises, plus a stubbed exception handler
    agent.generate_code = Mock(return_value=failing_code)
    agent.execute_with_retries = Mock(
        side_effect=CodeExecutionError("Execution failed")
    )
    agent._handle_exception = Mock(return_value="Error handled")

    outcome = agent._process_query("What is the invalid operation?")

    # The handler's return value is propagated and it received the failing code
    assert outcome == "Error handled"
    agent._handle_exception.assert_called_once_with(failing_code)

def test_regenerate_code_after_invalid_llm_output_error(self, agent):
    """Test code regeneration with InvalidLLMOutputType error.

    The output-type correction prompt builder is patched, and the test
    checks that its return value is what gets passed to the code generator.
    """
    # InvalidLLMOutputType comes from the module-level import; the redundant
    # function-local import was removed.
    code = "test code"
    error = InvalidLLMOutputType("Invalid output type")

    with patch(
        "pandasai.agent.base.get_correct_output_type_error_prompt"
    ) as mock_prompt:
        mock_prompt.return_value = "corrected prompt"
        agent._code_generator.generate_code = MagicMock(return_value="new code")

        result = agent._regenerate_code_after_error(code, error)

        # Third positional argument is matched loosely with ANY
        mock_prompt.assert_called_once_with(agent._state, code, ANY)
        agent._code_generator.generate_code.assert_called_once_with(
            "corrected prompt"
        )
        assert result == "new code"

def test_regenerate_code_after_other_error(self, agent):
    """Errors other than InvalidLLMOutputType go through the SQL error prompt."""
    original_code = "test code"
    failure = ValueError("Some other error")
    prompt_target = "pandasai.agent.base.get_correct_error_prompt_for_sql"

    with patch(prompt_target) as build_prompt:
        build_prompt.return_value = "sql error prompt"
        agent._code_generator.generate_code = MagicMock(return_value="new code")

        regenerated = agent._regenerate_code_after_error(original_code, failure)

        # Third positional argument is matched loosely with ANY
        build_prompt.assert_called_once_with(agent._state, original_code, ANY)
        agent._code_generator.generate_code.assert_called_once_with(
            "sql error prompt"
        )
        assert regenerated == "new code"

def test_handle_exception(self, agent):
    """Test that _handle_exception properly formats and logs exceptions"""
    test_code = "print(1/0)"  # Code that will raise a ZeroDivisionError

    # Mock the logger to verify it's called
    mock_logger = MagicMock()
    agent._state.logger = mock_logger

    # Raise a real exception and call the handler while it is being handled.
    # Only the expected ZeroDivisionError is caught, so unrelated failures
    # (or KeyboardInterrupt/SystemExit) are not swallowed by the test.
    try:
        exec(test_code)  # fixed literal defined above, not external input
    except ZeroDivisionError:
        result = agent._handle_exception(test_code)
    else:
        pytest.fail("test_code was expected to raise ZeroDivisionError")

    # Verify the result is an ErrorResponse
    assert isinstance(result, ErrorResponse)
    assert result.last_code_executed == test_code
    assert "ZeroDivisionError" in result.error

    # Verify the error was logged
    mock_logger.log.assert_called_once()
    assert "Processing failed with error" in mock_logger.log.call_args[0][0]
136 changes: 136 additions & 0 deletions tests/unit_tests/dataframe/test_pull.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import os
from io import BytesIO
from unittest.mock import Mock, mock_open, patch
from zipfile import ZipFile

import pandas as pd
import pytest

from pandasai.data_loader.semantic_layer_schema import (
Column,
SemanticLayerSchema,
Source,
)
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError


@pytest.fixture
def mock_env(monkeypatch):
    """Set a dummy PANDABI_API_KEY environment variable for the test."""
    variable, value = "PANDABI_API_KEY", "test_api_key"
    monkeypatch.setenv(variable, value)


@pytest.fixture
def sample_df():
    """Return a small pandas DataFrame with an integer and a string column."""
    columns = {
        "col1": [1, 2, 3],
        "col2": ["a", "b", "c"],
    }
    return pd.DataFrame(columns)


@pytest.fixture
def mock_zip_content():
    """Return the raw bytes of an in-memory zip archive containing test.csv."""
    csv_text = "col1,col2\n1,a\n2,b\n3,c"
    archive_bytes = BytesIO()
    with ZipFile(archive_bytes, "w") as archive:
        archive.writestr("test.csv", csv_text)
    # Read the buffer only after the archive has been closed
    return archive_bytes.getvalue()


@pytest.fixture
def mock_schema():
    """Build a SemanticLayerSchema with a parquet source and two columns."""
    parquet_source = Source(type="parquet", path="data.parquet", table="test_table")
    column_specs = (("col1", "integer"), ("col2", "string"))
    return SemanticLayerSchema(
        name="test_schema",
        source=parquet_source,
        columns=[
            Column(name=col_name, type=col_type)
            for col_name, col_type in column_specs
        ],
    )


def test_pull_success(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
    """pull() issues the expected API request and opens at least one file."""
    session_patch = patch("pandasai.dataframe.base.get_pandaai_session")
    root_patch = patch("pandasai.dataframe.base.find_project_root")
    loader_patch = patch("pandasai.DatasetLoader.create_loader_from_path")
    open_patch = patch("builtins.open", mock_open())

    with session_patch as get_session, \
            root_patch as find_root, \
            loader_patch as make_loader, \
            open_patch as fake_open:
        # API answers 200 with the zipped dataset as its body
        response = Mock()
        response.status_code = 200
        response.content = mock_zip_content
        get_session.return_value.get.return_value = response
        find_root.return_value = str(tmp_path)

        # The loader created after the download hands back a DataFrame
        loader = Mock()
        loader.load.return_value = DataFrame(sample_df, schema=mock_schema)
        make_loader.return_value = loader

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        frame.pull()

        # Exactly one GET with the API key header and the dataset path
        expected_headers = {
            "accept": "application/json",
            "x-authorization": "Bearer test_api_key",
        }
        get_session.return_value.get.assert_called_once_with(
            "/datasets/pull",
            headers=expected_headers,
            params={"path": "test/path"},
        )

        # At least one file was opened
        assert fake_open.call_count > 0


def test_pull_missing_api_key(sample_df, mock_schema, monkeypatch):
    """pull() raises PandaAIApiKeyError when PANDABI_API_KEY is not set."""
    # Remove only the API key. Replacing os.environ.get wholesale would make
    # every environment lookup in the process return None, ignoring defaults.
    monkeypatch.delenv("PANDABI_API_KEY", raising=False)

    # Build the DataFrame outside pytest.raises so that only pull() can
    # satisfy the expectation.
    df = DataFrame(sample_df, path="test/path", schema=mock_schema)
    with pytest.raises(PandaAIApiKeyError):
        df.pull()


def test_pull_api_error(mock_env, sample_df, mock_schema):
    """pull() raises DatasetNotFound when the API answers with status 404."""
    session_target = "pandasai.dataframe.base.get_pandaai_session"

    with patch(session_target) as get_session:
        not_found = Mock()
        not_found.status_code = 404
        get_session.return_value.get.return_value = not_found

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        with pytest.raises(DatasetNotFound, match="Remote dataset not found to pull!"):
            frame.pull()


def test_pull_file_exists(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
    """pull() creates the target directory for the extracted file."""
    session_patch = patch("pandasai.dataframe.base.get_pandaai_session")
    root_patch = patch("pandasai.dataframe.base.find_project_root")
    loader_patch = patch("pandasai.DatasetLoader.create_loader_from_path")
    open_patch = patch("builtins.open", mock_open())
    exists_patch = patch("os.path.exists")
    makedirs_patch = patch("os.makedirs")

    with session_patch as get_session, \
            root_patch as find_root, \
            loader_patch as make_loader, \
            open_patch, \
            exists_patch as path_exists, \
            makedirs_patch as make_dirs:
        # API answers 200 with the zipped dataset; every path reports as existing
        response = Mock()
        response.status_code = 200
        response.content = mock_zip_content
        get_session.return_value.get.return_value = response
        find_root.return_value = str(tmp_path)
        path_exists.return_value = True

        # The loader created after the download hands back a DataFrame
        loader = Mock()
        loader.load.return_value = DataFrame(sample_df, schema=mock_schema)
        make_loader.return_value = loader

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        frame.pull()

        # Most recent makedirs call targets the extracted file's directory
        extracted_file = os.path.join(
            str(tmp_path), "datasets", "test/path", "test.csv"
        )
        make_dirs.assert_called_with(
            os.path.dirname(extracted_file),
            exist_ok=True,
        )
Loading

0 comments on commit d2350a1

Please sign in to comment.