Skip to content

Commit

Permalink
feature(Views): views for a single dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
scaliseraoul committed Feb 6, 2025
1 parent 8a0123c commit 7480ad9
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 46 deletions.
12 changes: 8 additions & 4 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,30 @@ def create(
"Please provide either a DataFrame, a Source or a View"
)

parsed_columns = [Column(**column) for column in columns] if columns else None

if df is not None:
schema = df.schema
schema.name = dataset_name
if (
parsed_columns
): # if no columns are passed, the columns are parsed automatically from the df
schema.columns = parsed_columns
parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
df.to_parquet(parquet_file_path_abs_path, index=False)
elif view:
_relation = [Relation(**relation) for relation in relations or ()]
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, relations=_relation, view=True
name=dataset_name, relations=_relation, view=True, columns=parsed_columns
)
elif source.get("table"):
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, source=Source(**source)
name=dataset_name, source=Source(**source), columns=parsed_columns
)
else:
raise InvalidConfigError("Unable to create schema with the provided params")

schema.description = description or schema.description
if columns:
schema.columns = [Column(**column) for column in columns]

file_manager.write(schema_path, schema.to_yaml())

Expand Down
13 changes: 7 additions & 6 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def check_columns_relations(self):

# unpack columns info
_columns = self.columns

_column_names = [col.name for col in _columns or ()]
_tables_names_in_columns = {
column_name.split(".")[0] for column_name in _column_names or ()
Expand All @@ -309,8 +310,10 @@ def check_columns_relations(self):
for column_name in _column_names_in_relations or ()
}

if not self.relations:
raise ValueError("At least one relation must be defined for view.")
if not self.relations and not self.columns:
raise ValueError(
"At least a relation or a column must be defined for view."
)

if not all(
is_view_column_name(column_name) for column_name in _column_names
Expand All @@ -327,10 +330,8 @@ def check_columns_relations(self):
"All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'."
)

if (
uncovered_tables := _tables_names_in_columns
- _tables_names_in_relations
):
uncovered_tables = _tables_names_in_columns - _tables_names_in_relations
if uncovered_tables and len(_tables_names_in_columns) > 1:
raise ValueError(
f"No relations provided for the following tables {uncovered_tables}."
)
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/view_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_dependencies_datasets(self) -> set[str]:
table.split(".")[0]
for relation in self.schema.relations
for table in (relation.from_, relation.to)
}
} or {self.schema.columns[0].name.split(".")[0]}

def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
dependency_dict = {
Expand Down
2 changes: 1 addition & 1 deletion pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlglot


def sanitize_relation_name(relation_name: str) -> str:
def sanitize_view_column_name(relation_name: str) -> str:
return ".".join(list(map(sanitize_sql_table_name, relation_name.split("."))))


Expand Down
25 changes: 20 additions & 5 deletions pandasai/query_builders/view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..data_loader.loader import DatasetLoader
from ..data_loader.semantic_layer_schema import SemanticLayerSchema
from ..helpers.sql_sanitizer import sanitize_relation_name
from ..helpers.sql_sanitizer import sanitize_view_column_name
from .base_query_builder import BaseQueryBuilder


Expand All @@ -19,10 +19,20 @@ def __init__(
super().__init__(schema)
self.schema_dependencies_dict = schema_dependencies_dict

@staticmethod
def normalize_view_column_name(name: str) -> str:
return normalize_identifiers(parse_one(sanitize_view_column_name(name))).sql()

@staticmethod
def normalize_view_column_alias(name: str) -> str:
return normalize_identifiers(
sanitize_view_column_name(name).replace(".", "_")
).sql()

def _get_columns(self) -> list[str]:
if self.schema.columns:
return [
normalize_identifiers(col.name.replace(".", "_")).sql()
self.normalize_view_column_alias(col.name)
for col in self.schema.columns
]
else:
Expand All @@ -34,13 +44,18 @@ def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery:

def _get_table_expression(self) -> str:
relations = self.schema.relations
first_dataset = relations[0].from_.split(".")[0]
columns = self.schema.columns
first_dataset = (
relations[0].from_.split(".")[0]
if relations
else columns[0].name.split(".")[0]
)
first_loader = self.schema_dependencies_dict[first_dataset]
first_query = self._get_sub_query_from_loader(first_loader)

if self.schema.columns:
columns = [
f"{normalize_identifiers(col.name).sql()} AS {normalize_identifiers(col.name.replace('.', '_'))}"
f"{self.normalize_view_column_name(col.name)} AS {self.normalize_view_column_alias(col.name)}"
for col in self.schema.columns
]
else:
Expand All @@ -54,7 +69,7 @@ def _get_table_expression(self) -> str:
subquery = self._get_sub_query_from_loader(loader)
query = query.join(
subquery,
on=f"{sanitize_relation_name(relation.from_)} = {sanitize_relation_name(relation.to)}",
on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}",
append=True,
)
alias = normalize_identifiers(self.schema.name).sql()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pandasai.helpers.sql_sanitizer import (
is_sql_query_safe,
sanitize_file_name,
sanitize_relation_name,
sanitize_view_column_name,
)


Expand All @@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self):
def test_sanitize_relation_name_valid(self):
relation = "dataset-name.column"
expected = "dataset_name.column"
assert sanitize_relation_name(relation) == expected
assert sanitize_view_column_name(relation) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
Expand Down
54 changes: 27 additions & 27 deletions tests/unit_tests/query_builders/test_view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def test_build_query(self, view_query_builder):
children_name
FROM (
SELECT
"parents.id" AS parents_id,
"parents.name" AS parents_name,
"children.name" AS children_name
parents.id AS parents_id,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand Down Expand Up @@ -57,9 +57,9 @@ def test_get_table_expression(self, view_query_builder):
view_query_builder._get_table_expression()
== """(
SELECT
"parents.id" AS parents_id,
"parents.name" AS parents_name,
"children.name" AS children_name
parents.id AS parents_id,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -85,9 +85,9 @@ def test_table_name_injection(self, view_query_builder):
children_name
FROM (
SELECT
"parents.id" AS parents_id,
"parents.name" AS parents_name,
"children.name" AS children_name
parents.id AS parents_id,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -108,14 +108,14 @@ def test_column_name_injection(self, view_query_builder):
assert (
query
== """SELECT
"column; DROP TABLE users;",
column__drop_table_users_,
parents_name,
children_name
FROM (
SELECT
"column; DROP TABLE users;" AS "column; DROP TABLE users;",
"parents.name" AS parents_name,
"children.name" AS children_name
column__drop_table_users_ AS column__drop_table_users_,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -141,9 +141,9 @@ def test_table_name_union_injection(self, view_query_builder):
children_name
FROM (
SELECT
"parents.id" AS parents_id,
"parents.name" AS parents_name,
"children.name" AS children_name
parents.id AS parents_id,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -166,14 +166,14 @@ def test_column_name_union_injection(self, view_query_builder):
assert (
query
== """SELECT
"column UNION SELECT username, password FROM users;",
column_union_select_username__password_from_users_,
parents_name,
children_name
FROM (
SELECT
"column UNION SELECT username, password FROM users;" AS "column UNION SELECT username, password FROM users;",
"parents.name" AS parents_name,
"children.name" AS children_name
column_union_select_username__password_from_users_ AS column_union_select_username__password_from_users_,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -199,9 +199,9 @@ def test_table_name_comment_injection(self, view_query_builder):
children_name
FROM (
SELECT
"parents.id" AS parents_id,
"parents.name" AS parents_name,
"children.name" AS children_name
parents.id AS parents_id,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand All @@ -222,14 +222,14 @@ def test_column_name_comment_injection(self, view_query_builder):
assert (
query
== """SELECT
column,
column___,
parents_name,
children_name
FROM (
SELECT
column AS column,
"parents.name" AS parents_name,
"children.name" AS children_name
column___ AS column___,
parents.name AS parents_name,
children.name AS children_name
FROM (
SELECT
*
Expand Down

0 comments on commit 7480ad9

Please sign in to comment.