diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 840fc30c1..c8ca82208 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -114,26 +114,30 @@ def create( "Please provide either a DataFrame, a Source or a View" ) + parsed_columns = [Column(**column) for column in columns] if columns else None + if df is not None: schema = df.schema schema.name = dataset_name + if ( + parsed_columns + ): # if no columns are passed it automatically parse the columns from the df + schema.columns = parsed_columns parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path) df.to_parquet(parquet_file_path_abs_path, index=False) elif view: _relation = [Relation(**relation) for relation in relations or ()] schema: SemanticLayerSchema = SemanticLayerSchema( - name=dataset_name, relations=_relation, view=True + name=dataset_name, relations=_relation, view=True, columns=parsed_columns ) elif source.get("table"): schema: SemanticLayerSchema = SemanticLayerSchema( - name=dataset_name, source=Source(**source) + name=dataset_name, source=Source(**source), columns=parsed_columns ) else: raise InvalidConfigError("Unable to create schema with the provided params") schema.description = description or schema.description - if columns: - schema.columns = [Column(**column) for column in columns] file_manager.write(schema_path, schema.to_yaml()) diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py index d0ccdfe6e..5d33d796f 100644 --- a/pandasai/data_loader/semantic_layer_schema.py +++ b/pandasai/data_loader/semantic_layer_schema.py @@ -283,6 +283,7 @@ def check_columns_relations(self): # unpack columns info _columns = self.columns + _column_names = [col.name for col in _columns or ()] _tables_names_in_columns = { column_name.split(".")[0] for column_name in _column_names or () @@ -309,8 +310,10 @@ def check_columns_relations(self): for column_name in _column_names_in_relations or () } - if not self.relations: - raise ValueError("At least one relation must be defined for view.") + if not self.relations and not self.columns: + raise ValueError( + "At least a relation or a column must be defined for view." + ) if not all( is_view_column_name(column_name) for column_name in _column_names @@ -327,10 +330,8 @@ def check_columns_relations(self): "All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'." ) - if ( - uncovered_tables := _tables_names_in_columns - - _tables_names_in_relations - ): + uncovered_tables = _tables_names_in_columns - _tables_names_in_relations + if uncovered_tables and len(_tables_names_in_columns) > 1: raise ValueError( f"No relations provided for the following tables {uncovered_tables}." ) diff --git a/pandasai/data_loader/view_loader.py b/pandasai/data_loader/view_loader.py index ae04bfe5b..50788e0f7 100644 --- a/pandasai/data_loader/view_loader.py +++ b/pandasai/data_loader/view_loader.py @@ -44,7 +44,7 @@ def _get_dependencies_datasets(self) -> set[str]: table.split(".")[0] for relation in self.schema.relations for table in (relation.from_, relation.to) - } + } or {self.schema.columns[0].name.split(".")[0]} def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]: dependency_dict = { diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py index 94d695174..eb0bb70fa 100644 --- a/pandasai/helpers/sql_sanitizer.py +++ b/pandasai/helpers/sql_sanitizer.py @@ -4,7 +4,7 @@ import sqlglot -def sanitize_relation_name(relation_name: str) -> str: +def sanitize_view_column_name(relation_name: str) -> str: return ".".join(list(map(sanitize_sql_table_name, relation_name.split(".")))) diff --git a/pandasai/query_builders/view_query_builder.py b/pandasai/query_builders/view_query_builder.py index 86fa70d19..3e66037d4 100644 --- a/pandasai/query_builders/view_query_builder.py +++ b/pandasai/query_builders/view_query_builder.py @@ -6,7 +6,7 @@ from ..data_loader.loader import DatasetLoader from ..data_loader.semantic_layer_schema import SemanticLayerSchema -from ..helpers.sql_sanitizer import sanitize_relation_name +from ..helpers.sql_sanitizer import sanitize_view_column_name from .base_query_builder import BaseQueryBuilder @@ -19,10 +19,20 @@ def __init__( super().__init__(schema) self.schema_dependencies_dict = schema_dependencies_dict + @staticmethod + def normalize_view_column_name(name: str) -> str: + return normalize_identifiers(parse_one(sanitize_view_column_name(name))).sql() + + @staticmethod + def normalize_view_column_alias(name: str) -> str: + return normalize_identifiers( + sanitize_view_column_name(name).replace(".", "_") + ).sql() + def _get_columns(self) -> list[str]: if self.schema.columns: return [ - normalize_identifiers(col.name.replace(".", "_")).sql() + self.normalize_view_column_alias(col.name) for col in self.schema.columns ] else: @@ -34,13 +44,18 @@ def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery: def _get_table_expression(self) -> str: relations = self.schema.relations - first_dataset = relations[0].from_.split(".")[0] + columns = self.schema.columns + first_dataset = ( + relations[0].from_.split(".")[0] + if relations + else columns[0].name.split(".")[0] + ) first_loader = self.schema_dependencies_dict[first_dataset] first_query = self._get_sub_query_from_loader(first_loader) if self.schema.columns: columns = [ - f"{normalize_identifiers(col.name).sql()} AS {normalize_identifiers(col.name.replace('.', '_'))}" + f"{self.normalize_view_column_name(col.name)} AS {self.normalize_view_column_alias(col.name)}" for col in self.schema.columns ] else: @@ -54,7 +69,7 @@ def _get_table_expression(self) -> str: subquery = self._get_sub_query_from_loader(loader) query = query.join( subquery, - on=f"{sanitize_relation_name(relation.from_)} = {sanitize_relation_name(relation.to)}", + on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}", append=True, ) alias = normalize_identifiers(self.schema.name).sql() diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py index f52166723..d35684235 100644 --- a/tests/unit_tests/helpers/test_sql_sanitizer.py +++ b/tests/unit_tests/helpers/test_sql_sanitizer.py @@ -1,7 +1,7 @@ from pandasai.helpers.sql_sanitizer import ( is_sql_query_safe, sanitize_file_name, - sanitize_relation_name, + sanitize_view_column_name, ) @@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self): def test_sanitize_relation_name_valid(self): relation = "dataset-name.column" expected = "dataset_name.column" - assert sanitize_relation_name(relation) == expected + assert sanitize_view_column_name(relation) == expected def test_safe_select_query(self): query = "SELECT * FROM users WHERE username = 'admin';" diff --git a/tests/unit_tests/query_builders/test_view_query_builder.py b/tests/unit_tests/query_builders/test_view_query_builder.py index b2d767c88..7a36ff5e3 100644 --- a/tests/unit_tests/query_builders/test_view_query_builder.py +++ b/tests/unit_tests/query_builders/test_view_query_builder.py @@ -24,9 +24,9 @@ def test_build_query(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -57,9 +57,9 @@ def test_get_table_expression(self, view_query_builder): view_query_builder._get_table_expression() == """( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -85,9 +85,9 @@ def test_table_name_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -108,14 +108,14 @@ def test_column_name_injection(self, view_query_builder): assert ( query == """SELECT - "column; DROP TABLE users;", + column__drop_table_users_, parents_name, children_name FROM ( SELECT - "column; DROP TABLE users;" AS "column; DROP TABLE users;", - "parents.name" AS parents_name, - "children.name" AS children_name + column__drop_table_users_ AS column__drop_table_users_, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -141,9 +141,9 @@ def test_table_name_union_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -166,14 +166,14 @@ def test_column_name_union_injection(self, view_query_builder): assert ( query == """SELECT - "column UNION SELECT username, password FROM users;", + column_union_select_username__password_from_users_, parents_name, children_name FROM ( SELECT - "column UNION SELECT username, password FROM users;" AS "column UNION SELECT username, password FROM users;", - "parents.name" AS parents_name, - "children.name" AS children_name + column_union_select_username__password_from_users_ AS column_union_select_username__password_from_users_, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -199,9 +199,9 @@ def test_table_name_comment_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -222,14 +222,14 @@ def test_column_name_comment_injection(self, view_query_builder): assert ( query == """SELECT - column, + column___, parents_name, children_name FROM ( SELECT - column AS column, - "parents.name" AS parents_name, - "children.name" AS children_name + column___ AS column___, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT *