diff --git a/pandasai/__init__.py b/pandasai/__init__.py index d3265fa5d..731401257 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -115,26 +115,30 @@ def create( "Please provide either a DataFrame, a Source or a View" ) + parsed_columns = [Column(**column) for column in columns] if columns else None + if df is not None: schema = df.schema schema.name = dataset_name + if ( + parsed_columns + ): # if no columns are passed it automatically parse the columns from the df + schema.columns = parsed_columns parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path) df.to_parquet(parquet_file_path_abs_path, index=False) elif view: _relation = [Relation(**relation) for relation in relations or ()] schema: SemanticLayerSchema = SemanticLayerSchema( - name=dataset_name, relations=_relation, view=True + name=dataset_name, relations=_relation, view=True, columns=parsed_columns ) elif source.get("table"): schema: SemanticLayerSchema = SemanticLayerSchema( - name=dataset_name, source=Source(**source) + name=dataset_name, source=Source(**source), columns=parsed_columns ) else: raise InvalidConfigError("Unable to create schema with the provided params") schema.description = description or schema.description - if columns: - schema.columns = [Column(**column) for column in columns] file_manager.write(schema_path, schema.to_yaml()) diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py index d0ccdfe6e..5d33d796f 100644 --- a/pandasai/data_loader/semantic_layer_schema.py +++ b/pandasai/data_loader/semantic_layer_schema.py @@ -283,6 +283,7 @@ def check_columns_relations(self): # unpack columns info _columns = self.columns + _column_names = [col.name for col in _columns or ()] _tables_names_in_columns = { column_name.split(".")[0] for column_name in _column_names or () @@ -309,8 +310,10 @@ def check_columns_relations(self): for column_name in _column_names_in_relations or () } - if not self.relations: - raise ValueError("At least one relation must be defined for view.") + if not self.relations and not self.columns: + raise ValueError( + "At least a relation or a column must be defined for view." + ) if not all( is_view_column_name(column_name) for column_name in _column_names @@ -327,10 +330,8 @@ def check_columns_relations(self): "All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'." ) - if ( - uncovered_tables := _tables_names_in_columns - - _tables_names_in_relations - ): + uncovered_tables = _tables_names_in_columns - _tables_names_in_relations + if uncovered_tables and len(_tables_names_in_columns) > 1: raise ValueError( f"No relations provided for the following tables {uncovered_tables}." ) diff --git a/pandasai/data_loader/view_loader.py b/pandasai/data_loader/view_loader.py index ae04bfe5b..50788e0f7 100644 --- a/pandasai/data_loader/view_loader.py +++ b/pandasai/data_loader/view_loader.py @@ -44,7 +44,7 @@ def _get_dependencies_datasets(self) -> set[str]: table.split(".")[0] for relation in self.schema.relations for table in (relation.from_, relation.to) - } + } or {self.schema.columns[0].name.split(".")[0]} def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]: dependency_dict = { diff --git a/pandasai/dataframe/virtual_dataframe.py b/pandasai/dataframe/virtual_dataframe.py index 948b66832..48f2dc3d7 100644 --- a/pandasai/dataframe/virtual_dataframe.py +++ b/pandasai/dataframe/virtual_dataframe.py @@ -30,7 +30,6 @@ def __init__(self, *args, **kwargs): self._head = None super().__init__( - self.get_head(), *args, **kwargs, ) diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py index 94d695174..eb0bb70fa 100644 --- a/pandasai/helpers/sql_sanitizer.py +++ b/pandasai/helpers/sql_sanitizer.py @@ -4,7 +4,7 @@ import sqlglot -def sanitize_relation_name(relation_name: str) -> str: +def sanitize_view_column_name(relation_name: str) -> str: return ".".join(list(map(sanitize_sql_table_name, relation_name.split(".")))) diff --git a/pandasai/query_builders/base_query_builder.py b/pandasai/query_builders/base_query_builder.py index 36ee7f02e..b8909bb40 100644 --- a/pandasai/query_builders/base_query_builder.py +++ b/pandasai/query_builders/base_query_builder.py @@ -1,9 +1,6 @@ -import re -from typing import Any, List +from typing import List -import sqlglot -from sqlglot import from_, pretty, select -from sqlglot.expressions import Limit, cast +from sqlglot import select from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Source @@ -42,7 +39,7 @@ def _get_columns(self) -> list[str]: return ["*"] def _get_table_expression(self) -> str: - return normalize_identifiers(self.schema.name).sql() + return normalize_identifiers(self.schema.name).sql(pretty=True) @staticmethod def check_compatible_sources(sources: List[Source]) -> bool: diff --git a/pandasai/query_builders/sql_parser.py b/pandasai/query_builders/sql_parser.py index e1604fc83..86873be54 100644 --- a/pandasai/query_builders/sql_parser.py +++ b/pandasai/query_builders/sql_parser.py @@ -35,15 +35,16 @@ def transform_node(node): if isinstance(node, exp.Table): original_name = node.name if original_name in table_mapping: + alias = node.alias or original_name mapped_value = parsed_mapping[original_name] if isinstance(mapped_value, exp.Alias): return exp.Subquery( this=mapped_value.this.this, - alias=node.alias or original_name, + alias=alias, ) - return exp.Subquery( - this=mapped_value, alias=node.alias or original_name - ) + elif isinstance(mapped_value, exp.Column): + return exp.Table(this=mapped_value.this, alias=alias) + return exp.Subquery(this=mapped_value, alias=alias) return node diff --git a/pandasai/query_builders/view_query_builder.py b/pandasai/query_builders/view_query_builder.py index 86fa70d19..3e66037d4 100644 --- a/pandasai/query_builders/view_query_builder.py +++ b/pandasai/query_builders/view_query_builder.py @@ -6,7 +6,7 @@ from ..data_loader.loader import DatasetLoader from ..data_loader.semantic_layer_schema import SemanticLayerSchema -from ..helpers.sql_sanitizer import sanitize_relation_name +from ..helpers.sql_sanitizer import sanitize_view_column_name from .base_query_builder import BaseQueryBuilder @@ -19,10 +19,20 @@ def __init__( super().__init__(schema) self.schema_dependencies_dict = schema_dependencies_dict + @staticmethod + def normalize_view_column_name(name: str) -> str: + return normalize_identifiers(parse_one(sanitize_view_column_name(name))).sql() + + @staticmethod + def normalize_view_column_alias(name: str) -> str: + return normalize_identifiers( + sanitize_view_column_name(name).replace(".", "_") + ).sql() + def _get_columns(self) -> list[str]: if self.schema.columns: return [ - normalize_identifiers(col.name.replace(".", "_")).sql() + self.normalize_view_column_alias(col.name) for col in self.schema.columns ] else: @@ -34,13 +44,18 @@ def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery: def _get_table_expression(self) -> str: relations = self.schema.relations - first_dataset = relations[0].from_.split(".")[0] + columns = self.schema.columns + first_dataset = ( + relations[0].from_.split(".")[0] + if relations + else columns[0].name.split(".")[0] + ) first_loader = self.schema_dependencies_dict[first_dataset] first_query = self._get_sub_query_from_loader(first_loader) if self.schema.columns: columns = [ - f"{normalize_identifiers(col.name).sql()} AS {normalize_identifiers(col.name.replace('.', '_'))}" + f"{self.normalize_view_column_name(col.name)} AS {self.normalize_view_column_alias(col.name)}" for col in self.schema.columns ] else: @@ -54,7 +69,7 @@ def _get_table_expression(self) -> str: subquery = self._get_sub_query_from_loader(loader) query = query.join( subquery, - on=f"{sanitize_relation_name(relation.from_)} = {sanitize_relation_name(relation.to)}", + on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}", append=True, ) alias = normalize_identifiers(self.schema.name).sql() diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py index f52166723..d35684235 100644 --- a/tests/unit_tests/helpers/test_sql_sanitizer.py +++ b/tests/unit_tests/helpers/test_sql_sanitizer.py @@ -1,7 +1,7 @@ from pandasai.helpers.sql_sanitizer import ( is_sql_query_safe, sanitize_file_name, - sanitize_relation_name, + sanitize_view_column_name, ) @@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self): def test_sanitize_relation_name_valid(self): relation = "dataset-name.column" expected = "dataset_name.column" - assert sanitize_relation_name(relation) == expected + assert sanitize_view_column_name(relation) == expected def test_safe_select_query(self): query = "SELECT * FROM users WHERE username = 'admin';" diff --git a/tests/unit_tests/query_builders/test_query_builder.py b/tests/unit_tests/query_builders/test_query_builder.py index 46aac973a..ae5becb44 100644 --- a/tests/unit_tests/query_builders/test_query_builder.py +++ b/tests/unit_tests/query_builders/test_query_builder.py @@ -250,3 +250,19 @@ def test_table_name_time_based_injection(self, mysql_schema): created_at DESC LIMIT 100""" ) + + @pytest.mark.parametrize( + "injection", + [ + "users; DROP TABLE users;", + "users UNION SELECT 1,2,3;", + 'users"; SELECT * FROM sensitive_data; --', + "users; TRUNCATE users; SELECT * FROM users WHERE 't'='t", + "users' AND (SELECT * FROM (SELECT(SLEEP(5)))test); --", + ], + ) + def test_order_by_injection(self, injection, mysql_schema): + mysql_schema.order_by = [injection] + query_builder = BaseQueryBuilder(mysql_schema) + with pytest.raises((sqlglot.errors.ParseError, sqlglot.errors.TokenError)): + query_builder.build_query() diff --git a/tests/unit_tests/query_builders/test_sql_parser.py b/tests/unit_tests/query_builders/test_sql_parser.py new file mode 100644 index 000000000..5a3d55ac8 --- /dev/null +++ b/tests/unit_tests/query_builders/test_sql_parser.py @@ -0,0 +1,58 @@ +import pytest + +from pandasai.query_builders.sql_parser import SQLParser + + +class TestSqlParser: + @staticmethod + @pytest.mark.parametrize( + "query, table_mapping, expected", + [ + ( + "SELECT * FROM customers", + {"customers": "clients"}, + """SELECT + * +FROM "clients" AS customers""", + ), + ( + "SELECT * FROM orders", + {"orders": "(SELECT * FROM sales)"}, + """SELECT + * +FROM ( + ( + SELECT + * + FROM "sales" + ) +) AS orders""", + ), + ( + "SELECT * FROM customers c", + {"customers": "clients"}, + """SELECT + * +FROM "clients" AS c""", + ), + ( + "SELECT c.id, o.amount FROM customers c JOIN orders o ON c.id = o.customer_id", + {"customers": "clients", "orders": "(SELECT * FROM sales)"}, + '''SELECT + "c"."id", + "o"."amount" +FROM "clients" AS c +JOIN ( + ( + SELECT + * + FROM "sales" + ) +) AS o + ON "c"."id" = "o"."customer_id"''', + ), + ], + ) + def test_replace_table_names(query, table_mapping, expected): + result = SQLParser.replace_table_and_column_names(query, table_mapping) + assert result.strip() == expected.strip() diff --git a/tests/unit_tests/query_builders/test_view_query_builder.py b/tests/unit_tests/query_builders/test_view_query_builder.py index b2d767c88..7a36ff5e3 100644 --- a/tests/unit_tests/query_builders/test_view_query_builder.py +++ b/tests/unit_tests/query_builders/test_view_query_builder.py @@ -24,9 +24,9 @@ def test_build_query(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -57,9 +57,9 @@ def test_get_table_expression(self, view_query_builder): view_query_builder._get_table_expression() == """( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -85,9 +85,9 @@ def test_table_name_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -108,14 +108,14 @@ def test_column_name_injection(self, view_query_builder): assert ( query == """SELECT - "column; DROP TABLE users;", + column__drop_table_users_, parents_name, children_name FROM ( SELECT - "column; DROP TABLE users;" AS "column; DROP TABLE users;", - "parents.name" AS parents_name, - "children.name" AS children_name + column__drop_table_users_ AS column__drop_table_users_, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -141,9 +141,9 @@ def test_table_name_union_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -166,14 +166,14 @@ def test_column_name_union_injection(self, view_query_builder): assert ( query == """SELECT - "column UNION SELECT username, password FROM users;", + column_union_select_username__password_from_users_, parents_name, children_name FROM ( SELECT - "column UNION SELECT username, password FROM users;" AS "column UNION SELECT username, password FROM users;", - "parents.name" AS parents_name, - "children.name" AS children_name + column_union_select_username__password_from_users_ AS column_union_select_username__password_from_users_, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -199,9 +199,9 @@ def test_table_name_comment_injection(self, view_query_builder): children_name FROM ( SELECT - "parents.id" AS parents_id, - "parents.name" AS parents_name, - "children.name" AS children_name + parents.id AS parents_id, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT * @@ -222,14 +222,14 @@ def test_column_name_comment_injection(self, view_query_builder): assert ( query == """SELECT - column, + column___, parents_name, children_name FROM ( SELECT - column AS column, - "parents.name" AS parents_name, - "children.name" AS children_name + column___ AS column___, + parents.name AS parents_name, + children.name AS children_name FROM ( SELECT *