Skip to content

Commit

Permalink
fix: transpiling sql to specific dialect (#1596)
Browse files Browse the repository at this point in the history
* fix: transpiling sql to specific dialect

* fix: remove default mysql dialect
  • Loading branch information
scaliseraoul authored Feb 7, 2025
1 parent 2455592 commit 5e025ed
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 8 deletions.
3 changes: 3 additions & 0 deletions pandasai/data_loader/duck_db_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import duckdb

from pandasai.query_builders.sql_parser import SQLParser


class DuckDBConnectionManager:
_instance = None
Expand Down Expand Up @@ -32,6 +34,7 @@ def register(self, name: str, df):

def sql(self, query: str):
    """
    Transpile an SQL query to the DuckDB dialect and execute it.

    Args:
        query: SQL text; it is parsed by `SQLParser.transpile_sql_dialect`
            and re-rendered with `to_dialect="duckdb"` before execution.

    Returns:
        Whatever `self.connection.sql` returns for the transpiled query.
        NOTE(review): no DataFrame conversion happens here; if
        `self.connection` is a DuckDB connection the result is presumably a
        relation object that callers must materialize themselves - confirm.
    """
    duckdb_query = SQLParser.transpile_sql_dialect(query, to_dialect="duckdb")
    return self.connection.sql(duckdb_query)

def close(self):
Expand Down
2 changes: 2 additions & 0 deletions pandasai/data_loader/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..constants import (
SUPPORTED_SOURCE_CONNECTORS,
)
from ..query_builders.sql_parser import SQLParser
from .loader import DatasetLoader
from .semantic_layer_schema import SemanticLayerSchema

Expand Down Expand Up @@ -40,6 +41,7 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
connection_info = self.schema.source.connection

load_function = self._get_loader_function(source_type)
query = SQLParser.transpile_sql_dialect(query, to_dialect=source_type)

if not is_sql_query_safe(query):
raise MaliciousQueryError(
Expand Down
2 changes: 2 additions & 0 deletions pandasai/data_loader/view_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..exceptions import MaliciousQueryError
from ..helpers.sql_sanitizer import is_sql_query_safe
from ..query_builders.base_query_builder import BaseQueryBuilder
from ..query_builders.sql_parser import SQLParser
from .duck_db_connection_manager import DuckDBConnectionManager
from .loader import DatasetLoader
from .local_loader import LocalDatasetLoader
Expand Down Expand Up @@ -89,6 +90,7 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
if source_type in LOCAL_SOURCE_TYPES:
return self.execute_local_query(query)
load_function = self._get_loader_function(source_type)
query = SQLParser.transpile_sql_dialect(query, to_dialect=source_type)

if not is_sql_query_safe(query):
raise MaliciousQueryError(
Expand Down
14 changes: 7 additions & 7 deletions pandasai/query_builders/sql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@


class SQLParser:
@staticmethod
def extract_table_names(query: str):
    """
    Collect the table identifiers referenced by a SQL query.

    The query is parsed with `sqlglot.parse_one`; for every `Table` node in
    the resulting tree its `alias_or_name` value is gathered.

    Returns:
        A set with one entry per distinct `alias_or_name` value.
    """
    tree = sqlglot.parse_one(query)
    table_nodes = tree.find_all(sqlglot.expressions.Table)
    return {node.alias_or_name for node in table_nodes}

@staticmethod
def replace_table_and_column_names(query, table_mapping):
"""
Expand Down Expand Up @@ -57,3 +50,10 @@ def transform_node(node):

# Convert back to SQL string
return transformed.sql(pretty=True)

@staticmethod
def transpile_sql_dialect(query, to_dialect, from_dialect=None):
    """
    Re-render a SQL string in the syntax of another dialect.

    Args:
        query: SQL text to convert.
        to_dialect: Dialect identifier handed to the expression's `sql()`
            generator.
        from_dialect: Dialect identifier used as `read` when parsing. When
            falsy, `parse_one` is called without a `read` argument.

    Returns:
        The pretty-printed SQL string generated for `to_dialect`.
    """
    if from_dialect:
        expression = parse_one(query, read=from_dialect)
    else:
        expression = parse_one(query)
    return expression.sql(dialect=to_dialect, pretty=True)
2 changes: 1 addition & 1 deletion tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_mysql_safe_query(self, mysql_schema):
result = loader.execute_query("SELECT * FROM users")

assert isinstance(result, DataFrame)
mock_sql_query.assert_called_once_with("SELECT * FROM users")
mock_sql_query.assert_called_once_with("SELECT\n *\nFROM users")

def test_mysql_malicious_with_no_import(self, mysql_schema):
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/query_builders/test_sql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ class TestSqlParser:
def test_replace_table_names(query, table_mapping, expected):
result = SQLParser.replace_table_and_column_names(query, table_mapping)
assert result.strip() == expected.strip()

def test_mysql_transpilation(self):
    """A double-quoted alias should be rendered with backticks for MySQL."""
    source_sql = '''SELECT COUNT(*) AS "total_rows"'''
    mysql_sql = SQLParser.transpile_sql_dialect(source_sql, to_dialect="mysql")
    wanted = """SELECT\n COUNT(*) AS `total_rows`"""
    assert mysql_sql.strip() == wanted.strip()

0 comments on commit 5e025ed

Please sign in to comment.