Skip to content

Commit

Permalink
feature(Views): adding view for local sources (#1586)
Browse files Browse the repository at this point in the history
* fix(FileManager): updated path in create and pull methods

* fix(Views): updated views to reflect refactors

* fix(Views): removes commented code and minor changes

* feature(Views): adding view for local sources

* feature(Views): ruff resolution

* fix: update load in cli and pull method

* feature(Views): add SQLParser method to parse generated SQL into validated SQL; refactor query builders to use sqlglot

* feature(QueryBuilder): add checks against SQL injection

* test: enforce further sql injection tests

---------

Co-authored-by: scaliseraoul-sinaptik <[email protected]>
Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2025
1 parent 3c612d2 commit 8a0123c
Show file tree
Hide file tree
Showing 35 changed files with 923 additions and 422 deletions.
14 changes: 7 additions & 7 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
from pandasai.exceptions import DatasetNotFound, InvalidConfigError, PandaAIApiKeyError
from pandasai.helpers.path import find_project_root, get_validated_dataset_path
from pandasai.helpers.session import get_pandaai_session
from pandasai.query_builders import SqlQueryBuilder

from .agent import Agent
from .constants import LOCAL_SOURCE_TYPES, SQL_SOURCE_TYPES
from .core.cache import Cache
from .data_loader.loader import DatasetLoader
from .data_loader.query_builder import QueryBuilder
from .data_loader.semantic_layer_schema import (
Column,
)
from .dataframe import DataFrame, VirtualDataFrame
from .helpers.sql_sanitizer import sanitize_sql_table_name
from .helpers.sql_sanitizer import sanitize_file_name, sanitize_sql_table_name
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake

Expand Down Expand Up @@ -97,7 +97,6 @@ def create(
raise ValueError("df must be a PandaAI DataFrame")

org_name, dataset_name = get_validated_dataset_path(path)

dataset_directory = str(os.path.join(org_name, dataset_name))

schema_path = os.path.join(dataset_directory, "schema.yaml")
Expand All @@ -117,17 +116,17 @@ def create(

if df is not None:
schema = df.schema
schema.name = sanitize_sql_table_name(dataset_name)
schema.name = dataset_name
parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
df.to_parquet(parquet_file_path_abs_path, index=False)
elif view:
_relation = [Relation(**relation) for relation in relations or ()]
schema: SemanticLayerSchema = SemanticLayerSchema(
name=sanitize_sql_table_name(dataset_name), relations=_relation, view=True
name=dataset_name, relations=_relation, view=True
)
elif source.get("table"):
schema: SemanticLayerSchema = SemanticLayerSchema(
name=sanitize_sql_table_name(dataset_name), source=Source(**source)
name=dataset_name, source=Source(**source)
)
else:
raise InvalidConfigError("Unable to create schema with the provided params")
Expand All @@ -140,6 +139,7 @@ def create(

print(f"Dataset saved successfully to path: {dataset_directory}")

schema.name = sanitize_sql_table_name(schema.name)
loader = DatasetLoader.create_loader_from_schema(schema, path)
return loader.load()

Expand Down Expand Up @@ -252,7 +252,7 @@ def load(dataset_path: str) -> DataFrame:

def read_csv(filepath: str) -> DataFrame:
data = pd.read_csv(filepath)
table = f"table_{sanitize_sql_table_name(filepath)}"
table = f"table_{sanitize_file_name(filepath)}"
return DataFrame(data, _table_name=table)


Expand Down
34 changes: 24 additions & 10 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@
from pandasai.sandbox import Sandbox
from pandasai.vectorstores.vectorstore import VectorStore

from .. import SqlQueryBuilder
from ..config import Config
from ..constants import LOCAL_SOURCE_TYPES
from ..data_loader.duck_db_connection_manager import DuckDBConnectionManager
from ..query_builders.base_query_builder import BaseQueryBuilder
from ..query_builders.sql_parser import SQLParser
from .state import AgentState


Expand Down Expand Up @@ -65,6 +69,13 @@ def __init__(
stacklevel=2,
)

if isinstance(dfs, list):
sources = [df.schema.source for df in dfs]
if not BaseQueryBuilder.check_compatible_sources(sources):
raise ValueError(
f"The sources of these datasets: {dfs} are not compatibles"
)

self.description = description
self._state = AgentState()
self._state.initialize(dfs, config, memory_size, vectorstore, description)
Expand Down Expand Up @@ -117,18 +128,20 @@ def execute_code(self, code: str) -> dict:

return code_executor.execute_and_return_result(code)

def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
try:
# Use a context manager to ensure the connection is closed
with duckdb.connect() as con:
# Register all DataFrames in the state
for df in self._state.dfs:
con.register(df.schema.name, df)
@staticmethod
def _parse_correct_table_name(query: str, dfs: List[VirtualDataFrame]) -> str:
    """Rewrite *query* so schema table names become real table expressions.

    Each dataframe's schema name is mapped to the table expression its
    query builder renders for the underlying source, and the query text
    is rewritten accordingly.
    """
    mapping = {
        frame.schema.name: frame.query_builder._get_table_expression()
        for frame in dfs
    }
    return SQLParser.replace_table_and_column_names(query, mapping)

return result
def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
    """Run *query* on the shared in-process DuckDB connection.

    Every dataframe held in the agent state is registered under its schema
    name first, so the query can reference any of them by that name.

    Raises:
        RuntimeError: If DuckDB reports an execution error.
    """
    try:
        manager = DuckDBConnectionManager()
        for frame in self._state.dfs:
            manager.register(frame.schema.name, frame)
        result = manager.sql(query).df()
        return result
    except duckdb.Error as err:
        raise RuntimeError(f"SQL execution failed: {err}") from err

Expand All @@ -151,6 +164,7 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
if source and source.type in LOCAL_SOURCE_TYPES:
return self._execute_local_sql_query(query)
else:
query = self._parse_correct_table_name(query, self._state.dfs)
return df0.execute_sql_query(query)

def execute_with_retries(self, code: str) -> Any:
Expand Down
8 changes: 4 additions & 4 deletions pandasai/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def pull(dataset_path):
"""📥 Pull a dataset from a remote source"""
try:
click.echo(f"🔄 Pulling dataset from: {dataset_path}")
dataset_loader = DatasetLoader()
df = dataset_loader.load(dataset_path)
dataset_loader = DatasetLoader.create_loader_from_path(dataset_path)
df = dataset_loader.load()
df.pull()
click.echo(f"\n✨ Dataset successfully pulled from path: {dataset_path}")
except Exception as e:
Expand All @@ -150,8 +150,8 @@ def push(dataset_path):
"""📤 Push a dataset to a remote source"""
try:
click.echo(f"🔄 Pushing dataset to: {dataset_path}")
dataset_loader = DatasetLoader()
df = dataset_loader.load(dataset_path)
dataset_loader = DatasetLoader.create_loader_from_path(dataset_path)
df = dataset_loader.load()
df.push()
click.echo(f"\n✨ Dataset successfully pushed to path: {dataset_path}")
except Exception as e:
Expand Down
39 changes: 39 additions & 0 deletions pandasai/data_loader/duck_db_connection_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import weakref

import duckdb


class DuckDBConnectionManager:
    """Process-wide singleton wrapping a single DuckDB connection.

    The first instantiation opens the connection; every later call to the
    constructor returns the same object. A ``weakref.finalize`` hook closes
    the connection when the singleton is finalized (in practice, at
    interpreter shutdown, since the class keeps a strong reference to it),
    and ``close()`` allows closing it explicitly.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_connection()
            # Ensure the underlying connection is released on finalization.
            weakref.finalize(cls._instance, cls._close_connection)
        return cls._instance

    def _init_connection(self):
        """Open an in-memory DuckDB connection and reset the table registry."""
        self.connection = duckdb.connect()
        self._registered_tables = set()

    @classmethod
    def _close_connection(cls):
        """Close the shared connection (if any) and drop the singleton."""
        if cls._instance and hasattr(cls._instance, "connection"):
            cls._instance.connection.close()
            cls._instance = None

    def register(self, name: str, df):
        """Expose *df* to SQL under *name* and record the registration."""
        self.connection.register(name, df)
        self._registered_tables.add(name)

    def sql(self, query: str):
        """Execute *query* on the shared connection and return the DuckDB relation."""
        return self.connection.sql(query)

    def close(self):
        """Explicitly close the shared connection."""
        self._close_connection()
18 changes: 15 additions & 3 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from abc import ABC, abstractmethod

import pandas as pd
import yaml
Expand All @@ -12,15 +13,26 @@
from ..constants import (
LOCAL_SOURCE_TYPES,
)
from ..query_builders.base_query_builder import BaseQueryBuilder
from .semantic_layer_schema import SemanticLayerSchema
from .transformation_manager import TransformationManager


class DatasetLoader:
class DatasetLoader(ABC):
def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
self.schema = schema
self.dataset_path = dataset_path
self.org_name, self.dataset_name = get_validated_dataset_path(self.dataset_path)
self.org_name, self.dataset_name = get_validated_dataset_path(dataset_path)
self.dataset_path = f"{self.org_name}/{self.dataset_name}"

@property
@abstractmethod
def query_builder(self) -> BaseQueryBuilder:
    """Query builder that renders SQL for this loader's source.

    Abstract property that must be implemented by subclasses.
    """
    pass

@abstractmethod
def execute_query(self, query: str):
    """Execute *query* against this loader's source and return the result."""
    pass

@classmethod
def create_loader_from_schema(
Expand Down
33 changes: 32 additions & 1 deletion pandasai/data_loader/local_loader.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,40 @@
import os

import duckdb
import pandas as pd

from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
from pandasai.query_builders import LocalQueryBuilder

from ..config import ConfigManager
from ..constants import (
LOCAL_SOURCE_TYPES,
)
from ..helpers.sql_sanitizer import is_sql_query_safe
from .duck_db_connection_manager import DuckDBConnectionManager
from .loader import DatasetLoader
from .semantic_layer_schema import SemanticLayerSchema


class LocalDatasetLoader(DatasetLoader):
"""
Loader for local datasets (CSV, Parquet).
"""

def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
    """Initialize the loader and build the query builder for *schema*."""
    super().__init__(schema, dataset_path)
    # Renders SQL for the local (CSV/Parquet) source described by the schema.
    self._query_builder: LocalQueryBuilder = LocalQueryBuilder(schema)

@property
def query_builder(self) -> LocalQueryBuilder:
    """The LocalQueryBuilder constructed from this loader's schema."""
    return self._query_builder

def register_table(self):
    """Load the local dataset and register it on the shared DuckDB connection.

    The table is registered under the schema's name, making it queryable
    by subsequent SQL executed through the connection manager.
    """
    frame = self.load()
    manager = DuckDBConnectionManager()
    manager.register(self.schema.name, frame)

def load(self) -> DataFrame:
df: pd.DataFrame = self._load_from_local_source()
df = self._filter_columns(df)
Expand Down Expand Up @@ -70,3 +88,16 @@ def _filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
df_columns = df.columns.tolist()
columns_to_keep = [col for col in df_columns if col in schema_columns]
return df[columns_to_keep]

def execute_query(self, query: str) -> pd.DataFrame:
    """Validate *query* and execute it on the shared DuckDB connection.

    Args:
        query: SQL text referencing tables previously registered via
            ``register_table``.

    Returns:
        The query result as a pandas DataFrame.

    Raises:
        MaliciousQueryError: If the query fails the safety screen.
        RuntimeError: If DuckDB reports an execution error.
    """
    # Screen for unsafe SQL before touching the connection; raised outside
    # the try block so it can never be confused with a duckdb.Error.
    if not is_sql_query_safe(query):
        raise MaliciousQueryError(
            "The SQL query is deemed unsafe and will not be executed."
        )

    db_manager = DuckDBConnectionManager()
    try:
        # Keep the try body minimal: only the actual execution can raise
        # duckdb.Error.
        return db_manager.sql(query).df()
    except duckdb.Error as e:
        raise RuntimeError(f"SQL execution failed: {e}") from e
61 changes: 0 additions & 61 deletions pandasai/data_loader/query_builder.py

This file was deleted.

6 changes: 1 addition & 5 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,6 @@ def validate_type_and_fields(cls, values):
raise ValueError(
f"For local source type '{_type}', 'path' must be defined."
)
if not table:
raise ValueError(
f"For local source type '{_type}', 'table' must be defined."
)

elif _type in REMOTE_SOURCE_TYPES:
if not connection:
Expand Down Expand Up @@ -282,7 +278,7 @@ class SemanticLayerSchema(BaseModel):

@model_validator(mode="after")
def check_columns_relations(self):
column_re_check = r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$"
column_re_check = r"^[a-zA-Z0-9-]+\.[a-zA-Z0-9_]+$"
is_view_column_name = partial(re.match, column_re_check)

# unpack columns info
Expand Down
Loading

0 comments on commit 8a0123c

Please sign in to comment.