feature(Views): views for a single dataframe (#1594)
* feature(Views): views for a single dataframe

* fix: regression on sql dataframe, added tests

* fix(VirtualDataframe): postpone head loading for dependency needs
scaliseraoul authored Feb 7, 2025
1 parent be3e158 commit 2455592
Showing 12 changed files with 148 additions and 57 deletions.
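
In short: pai.create now threads an optional columns list through every dataset kind, and a view no longer requires relations, so a view over a single dataframe can be declared with columns alone. A minimal usage sketch, assuming the parameter names visible in this diff (df, view, relations, columns, description), the path-first call shape used elsewhere in the library, and made-up dataset paths:

import pandasai as pai

# Hypothetical base dataset; the org/path and CSV file are placeholders.
df = pai.read_csv("data/parents.csv")
pai.create("my-org/parents", df=df, description="parent records")

# New in this commit: a view over a single dataset, declared with columns only.
# Previously this raised "At least one relation must be defined for view."
pai.create(
    "my-org/parents-view",
    view=True,
    columns=[
        {"name": "parents.id"},
        {"name": "parents.name"},
    ],
)
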
12 changes: 8 additions & 4 deletions pandasai/__init__.py
@@ -115,26 +115,30 @@ def create(
"Please provide either a DataFrame, a Source or a View"
)

parsed_columns = [Column(**column) for column in columns] if columns else None

if df is not None:
schema = df.schema
schema.name = dataset_name
if (
parsed_columns
): # if no columns are passed it automatically parse the columns from the df
schema.columns = parsed_columns
parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
df.to_parquet(parquet_file_path_abs_path, index=False)
elif view:
_relation = [Relation(**relation) for relation in relations or ()]
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, relations=_relation, view=True
name=dataset_name, relations=_relation, view=True, columns=parsed_columns
)
elif source.get("table"):
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, source=Source(**source)
name=dataset_name, source=Source(**source), columns=parsed_columns
)
else:
raise InvalidConfigError("Unable to create schema with the provided params")

schema.description = description or schema.description
if columns:
schema.columns = [Column(**column) for column in columns]

file_manager.write(schema_path, schema.to_yaml())

13 changes: 7 additions & 6 deletions pandasai/data_loader/semantic_layer_schema.py
@@ -283,6 +283,7 @@ def check_columns_relations(self):

# unpack columns info
_columns = self.columns

_column_names = [col.name for col in _columns or ()]
_tables_names_in_columns = {
column_name.split(".")[0] for column_name in _column_names or ()
@@ -309,8 +310,10 @@ def check_columns_relations(self):
for column_name in _column_names_in_relations or ()
}

if not self.relations:
raise ValueError("At least one relation must be defined for view.")
if not self.relations and not self.columns:
raise ValueError(
"At least a relation or a column must be defined for view."
)

if not all(
is_view_column_name(column_name) for column_name in _column_names
@@ -327,10 +330,8 @@ def check_columns_relations(self):
"All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'."
)

if (
uncovered_tables := _tables_names_in_columns
- _tables_names_in_relations
):
uncovered_tables = _tables_names_in_columns - _tables_names_in_relations
if uncovered_tables and len(_tables_names_in_columns) > 1:
raise ValueError(
f"No relations provided for the following tables {uncovered_tables}."
)
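
A hedged sketch of what the relaxed validator accepts, built against the models in this module (field values are invented; a multi-table view still needs relations covering every table named in its columns):

from pandasai.data_loader.semantic_layer_schema import (
    Column,
    Relation,
    SemanticLayerSchema,
)

# Single-table view: columns only, no relations. Before this change the
# validator raised "At least one relation must be defined for view."
single_table_view = SemanticLayerSchema(
    name="parents_view",
    view=True,
    columns=[Column(name="parents.id"), Column(name="parents.name")],
)

# Multi-table view: every table referenced in the columns must still be
# covered by a relation, otherwise the uncovered-tables check raises ValueError.
multi_table_view = SemanticLayerSchema(
    name="parents_children",
    view=True,
    columns=[Column(name="parents.id"), Column(name="children.parent_id")],
    relations=[Relation(**{"from": "parents.id", "to": "children.parent_id"})],
)
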
2 changes: 1 addition & 1 deletion pandasai/data_loader/view_loader.py
@@ -44,7 +44,7 @@ def _get_dependencies_datasets(self) -> set[str]:
table.split(".")[0]
for relation in self.schema.relations
for table in (relation.from_, relation.to)
}
} or {self.schema.columns[0].name.split(".")[0]}

def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
dependency_dict = {
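
The new fallback leans on an empty set being falsy: with no relations the comprehension yields set(), so `or` supplies the dataset named by the first column. A standalone sketch of the idiom with stand-in data (plain dicts here, not the real schema objects):

relations = []  # a single-dataframe view declares no joins
columns = [{"name": "parents.id"}, {"name": "parents.name"}]

dependencies = {
    table.split(".")[0]
    for relation in relations
    for table in (relation["from"], relation["to"])
} or {columns[0]["name"].split(".")[0]}

print(dependencies)  # {'parents'}
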
1 change: 0 additions & 1 deletion pandasai/dataframe/virtual_dataframe.py
@@ -30,7 +30,6 @@ def __init__(self, *args, **kwargs):
self._head = None

super().__init__(
self.get_head(),
*args,
**kwargs,
)
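
Removing self.get_head() from super().__init__ means the head is no longer fetched at construction time; per the commit message, loading is postponed until it is actually needed so dependencies can be resolved first. An illustrative lazy-head sketch (class and attribute names are made up, not the real VirtualDataFrame):

import pandas as pd


class LazyHeadFrame:
    def __init__(self, fetch_head):
        self._fetch_head = fetch_head  # callable that hits the data source
        self._head = None              # nothing is loaded at construction time

    def get_head(self) -> pd.DataFrame:
        if self._head is None:         # first access triggers the load
            self._head = self._fetch_head()
        return self._head


frame = LazyHeadFrame(lambda: pd.DataFrame({"id": [1, 2, 3]}))
print(frame.get_head())  # the DataFrame is materialized only here
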
2 changes: 1 addition & 1 deletion pandasai/helpers/sql_sanitizer.py
@@ -4,7 +4,7 @@
import sqlglot


def sanitize_relation_name(relation_name: str) -> str:
def sanitize_view_column_name(relation_name: str) -> str:
return ".".join(list(map(sanitize_sql_table_name, relation_name.split("."))))


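
The rename reflects what the helper is actually used for: sanitizing '[dataset].[column]' references in view definitions, one dot-separated part at a time. A one-line usage example mirroring the updated unit test below:

from pandasai.helpers.sql_sanitizer import sanitize_view_column_name

print(sanitize_view_column_name("dataset-name.column"))  # dataset_name.column
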
9 changes: 3 additions & 6 deletions pandasai/query_builders/base_query_builder.py
@@ -1,9 +1,6 @@
import re
from typing import Any, List
from typing import List

import sqlglot
from sqlglot import from_, pretty, select
from sqlglot.expressions import Limit, cast
from sqlglot import select
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Source
@@ -42,7 +39,7 @@ def _get_columns(self) -> list[str]:
return ["*"]

def _get_table_expression(self) -> str:
return normalize_identifiers(self.schema.name).sql()
return normalize_identifiers(self.schema.name).sql(pretty=True)

@staticmethod
def check_compatible_sources(sources: List[Source]) -> bool:
9 changes: 5 additions & 4 deletions pandasai/query_builders/sql_parser.py
@@ -35,15 +35,16 @@ def transform_node(node):
if isinstance(node, exp.Table):
original_name = node.name
if original_name in table_mapping:
alias = node.alias or original_name
mapped_value = parsed_mapping[original_name]
if isinstance(mapped_value, exp.Alias):
return exp.Subquery(
this=mapped_value.this.this,
alias=node.alias or original_name,
alias=alias,
)
return exp.Subquery(
this=mapped_value, alias=node.alias or original_name
)
elif isinstance(mapped_value, exp.Column):
return exp.Table(this=mapped_value.this, alias=alias)
return exp.Subquery(this=mapped_value, alias=alias)

return node

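
With the new exp.Column branch, a table mapped to a plain identifier comes back as a Table with an alias instead of being wrapped in a subquery; the first case of the new test_sql_parser.py below exercises this path. A usage sketch mirroring it:

from pandasai.query_builders.sql_parser import SQLParser

sql = SQLParser.replace_table_and_column_names(
    "SELECT * FROM customers",
    {"customers": "clients"},
)
print(sql)
# SELECT
#   *
# FROM "clients" AS customers
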
25 changes: 20 additions & 5 deletions pandasai/query_builders/view_query_builder.py
@@ -6,7 +6,7 @@

from ..data_loader.loader import DatasetLoader
from ..data_loader.semantic_layer_schema import SemanticLayerSchema
from ..helpers.sql_sanitizer import sanitize_relation_name
from ..helpers.sql_sanitizer import sanitize_view_column_name
from .base_query_builder import BaseQueryBuilder


@@ -19,10 +19,20 @@ def __init__(
super().__init__(schema)
self.schema_dependencies_dict = schema_dependencies_dict

@staticmethod
def normalize_view_column_name(name: str) -> str:
return normalize_identifiers(parse_one(sanitize_view_column_name(name))).sql()

@staticmethod
def normalize_view_column_alias(name: str) -> str:
return normalize_identifiers(
sanitize_view_column_name(name).replace(".", "_")
).sql()

def _get_columns(self) -> list[str]:
if self.schema.columns:
return [
normalize_identifiers(col.name.replace(".", "_")).sql()
self.normalize_view_column_alias(col.name)
for col in self.schema.columns
]
else:
@@ -34,13 +44,18 @@ def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery:

def _get_table_expression(self) -> str:
relations = self.schema.relations
first_dataset = relations[0].from_.split(".")[0]
columns = self.schema.columns
first_dataset = (
relations[0].from_.split(".")[0]
if relations
else columns[0].name.split(".")[0]
)
first_loader = self.schema_dependencies_dict[first_dataset]
first_query = self._get_sub_query_from_loader(first_loader)

if self.schema.columns:
columns = [
f"{normalize_identifiers(col.name).sql()} AS {normalize_identifiers(col.name.replace('.', '_'))}"
f"{self.normalize_view_column_name(col.name)} AS {self.normalize_view_column_alias(col.name)}"
for col in self.schema.columns
]
else:
@@ -54,7 +69,7 @@ def _get_table_expression(self) -> str:
subquery = self._get_sub_query_from_loader(loader)
query = query.join(
subquery,
on=f"{sanitize_relation_name(relation.from_)} = {sanitize_relation_name(relation.to)}",
on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}",
append=True,
)
alias = normalize_identifiers(self.schema.name).sql()
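
A hedged sketch of what the two new static helpers return for a '[dataset].[column]' reference, which is how a view column ends up projected as parents.id AS parents_id in the generated subquery (exact identifier quoting may differ; outputs inferred from the sqlglot calls above):

from pandasai.query_builders.view_query_builder import ViewQueryBuilder

col = "parents.id"
print(ViewQueryBuilder.normalize_view_column_name(col))   # parents.id
print(ViewQueryBuilder.normalize_view_column_alias(col))  # parents_id
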
4 changes: 2 additions & 2 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
@@ -1,7 +1,7 @@
from pandasai.helpers.sql_sanitizer import (
is_sql_query_safe,
sanitize_file_name,
sanitize_relation_name,
sanitize_view_column_name,
)


@@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self):
def test_sanitize_relation_name_valid(self):
relation = "dataset-name.column"
expected = "dataset_name.column"
assert sanitize_relation_name(relation) == expected
assert sanitize_view_column_name(relation) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
16 changes: 16 additions & 0 deletions tests/unit_tests/query_builders/test_query_builder.py
@@ -250,3 +250,19 @@ def test_table_name_time_based_injection(self, mysql_schema):
created_at DESC
LIMIT 100"""
)

@pytest.mark.parametrize(
"injection",
[
"users; DROP TABLE users;",
"users UNION SELECT 1,2,3;",
'users"; SELECT * FROM sensitive_data; --',
"users; TRUNCATE users; SELECT * FROM users WHERE 't'='t",
"users' AND (SELECT * FROM (SELECT(SLEEP(5)))test); --",
],
)
def test_order_by_injection(self, injection, mysql_schema):
mysql_schema.order_by = [injection]
query_builder = BaseQueryBuilder(mysql_schema)
with pytest.raises((sqlglot.errors.ParseError, sqlglot.errors.TokenError)):
query_builder.build_query()
58 changes: 58 additions & 0 deletions tests/unit_tests/query_builders/test_sql_parser.py
@@ -0,0 +1,58 @@
import pytest

from pandasai.query_builders.sql_parser import SQLParser


class TestSqlParser:
@staticmethod
@pytest.mark.parametrize(
"query, table_mapping, expected",
[
(
"SELECT * FROM customers",
{"customers": "clients"},
"""SELECT
*
FROM "clients" AS customers""",
),
(
"SELECT * FROM orders",
{"orders": "(SELECT * FROM sales)"},
"""SELECT
*
FROM (
(
SELECT
*
FROM "sales"
)
) AS orders""",
),
(
"SELECT * FROM customers c",
{"customers": "clients"},
"""SELECT
*
FROM "clients" AS c""",
),
(
"SELECT c.id, o.amount FROM customers c JOIN orders o ON c.id = o.customer_id",
{"customers": "clients", "orders": "(SELECT * FROM sales)"},
'''SELECT
"c"."id",
"o"."amount"
FROM "clients" AS c
JOIN (
(
SELECT
*
FROM "sales"
)
) AS o
ON "c"."id" = "o"."customer_id"''',
),
],
)
def test_replace_table_names(query, table_mapping, expected):
result = SQLParser.replace_table_and_column_names(query, table_mapping)
assert result.strip() == expected.strip()