Skip to content

Commit

Permalink
feat(SemanticLayerSchema): making column type optional and avoid dump…
Browse files Browse the repository at this point in the history
…ing and parsing null values in the SemanticLayerSchema
  • Loading branch information
scaliseraoul committed Jan 15, 2025
1 parent d370d17 commit f889c86
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 34 deletions.
2 changes: 1 addition & 1 deletion docs/v3/contributing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ make spell_fix
We use `pytest` to test our code. You can run the tests by running the following command:

```bash
make tests
make test_all
```

Make sure that all tests pass before submitting a pull request.
Expand Down
35 changes: 22 additions & 13 deletions docs/v3/semantic-layer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,22 @@ df.save(
path="company/sales-data", # Format: "organization/dataset"
name="sales-data", # Human-readable name
description="Sales data from our retail stores", # Optional description
columns={
"transaction_id": {
columns=[
{
"name": "transaction_id",
"type": "string",
"description": "Unique identifier for each sale"
},
"sale_date": {
{
        "name": "sale_date",
"type": "datetime",
"description": "Date and time of the sale"
}
}
]
)
```

#### name
#### name

The name field identifies your dataset in the save method.

Expand Down Expand Up @@ -84,28 +86,33 @@ df.save(
path="company/sales-data",
name="sales-data",
description="Daily sales transactions from all retail stores",
columns={
"transaction_id": {
columns=[
{
"name": "transaction_id",
"type": "string",
"description": "Unique identifier for each sale"
},
"sale_date": {
{
        "name": "sale_date",
"type": "datetime",
"description": "Date and time of the sale"
},
"quantity": {
{
"name": "quantity",
"type": "integer",
"description": "Number of units sold"
},
"price": {
{
"name": "price",
"type": "float",
"description": "Price per unit in USD"
},
"is_online": {
{
"name": "is_online",
"type": "boolean",
"description": "Whether the sale was made online"
}
}
]
)
```

Expand Down Expand Up @@ -238,7 +245,7 @@ source:
- `connection_string` (str): Connection string for the data source
- `query` (str): Query to retrieve data from the data source


{/* commented as destination and update frequency will be only in the materialized case
#### destination (mandatory)
Specify the destination for your dataset.

Expand All @@ -256,6 +263,7 @@ destination:
path: /path/to/data
```


#### update_frequency
Specify the frequency of updates for your dataset.

Expand All @@ -268,6 +276,7 @@ Specify the frequency of updates for your dataset.
```yaml
update_frequency: daily
```
*/}

#### order_by
Specify the columns to order by.
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _get_abs_dataset_path(self):
return os.path.join(find_project_root(), "datasets", self.dataset_path)

def _load_schema(self):
schema_path = os.path.join(self._get_abs_dataset_path(), "schema.yaml")
schema_path = os.path.join(str(self._get_abs_dataset_path()), "schema.yaml")
if not os.path.exists(schema_path):
raise FileNotFoundError(f"Schema file not found: {schema_path}")

Expand Down
19 changes: 12 additions & 7 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import (
Expand All @@ -19,21 +19,23 @@

class Column(BaseModel):
name: str = Field(..., description="Name of the column.")
type: str = Field(..., description="Data type of the column.")
type: Optional[str] = Field(None, description="Data type of the column.")
description: Optional[str] = Field(None, description="Description of the column")

@field_validator("type")
@classmethod
def is_column_type_supported(cls, type: str) -> str:
if type not in VALID_COLUMN_TYPES:
raise ValueError(f"Unsupported column type: {type}")
raise ValueError(
f"Unsupported column type: {type}. Supported types are: {VALID_COLUMN_TYPES}"
)
return type


class Transformation(BaseModel):
type: str = Field(..., description="Type of transformation to be applied.")
params: Dict[str, str] = Field(
..., description="Parameters for the transformation."
params: Optional[Dict[str, str]] = Field(
None, description="Parameters for the transformation."
)

@field_validator("type")
Expand Down Expand Up @@ -95,13 +97,13 @@ def is_format_supported(cls, format: str) -> str:

class SemanticLayerSchema(BaseModel):
name: str = Field(..., description="Dataset name.")
source: Source = Field(..., description="Data source for your dataset.")
description: Optional[str] = Field(
None, description="Dataset’s contents and purpose description."
)
columns: Optional[List[Column]] = Field(
None, description="Structure and metadata of your dataset’s columns"
)
source: Source = Field(..., description="Data source for your dataset.")
order_by: Optional[List[str]] = Field(
None, description="Ordering criteria for the dataset."
)
Expand All @@ -118,8 +120,11 @@ class SemanticLayerSchema(BaseModel):
None, description="Frequency of dataset updates."
)

def to_dict(self) -> dict[str, Any]:
return self.model_dump(exclude_none=True)

def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), sort_keys=False)
return yaml.dump(self.to_dict(), sort_keys=False)


def is_schema_source_same(
Expand Down
9 changes: 5 additions & 4 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,22 +176,23 @@ def _create_yml_template(
columns_dict: dictionary with info about columns of the dataframe
"""

columns = list(map(lambda column: Column(**column), columns_dict))
if columns_dict:
columns_dict = list(map(lambda column: Column(**column), columns_dict))

schema = SemanticLayerSchema(
name=name,
description=description,
columns=columns,
columns=columns_dict,
source=Source(type="parquet", path="data.parquet"),
destination=Destination(
type="local", format="parquet", path="data.parquet"
),
)

return schema.model_dump()
return schema.to_dict()

def save(
self, path: str, name: str, description: str = None, columns: List[dict] = []
self, path: str, name: str, description: str = None, columns: List[dict] = None
):
self.name = name
self.description = description
Expand Down
10 changes: 2 additions & 8 deletions tests/unit_tests/dataframe/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,18 @@ def test_save_creates_correct_schema(self, sample_df):
"name": name,
"description": description,
"columns": [
{"name": "Name", "type": "string", "description": None},
{"name": "Age", "type": "integer", "description": None},
{"name": "Name", "type": "string"},
{"name": "Age", "type": "integer"},
],
"destination": {
"format": "parquet",
"path": "data.parquet",
"type": "local",
},
"limit": None,
"order_by": None,
"source": {
"connection": None,
"path": "data.parquet",
"table": None,
"type": "parquet",
},
"transformations": None,
"update_frequency": None,
}

mock_yaml_dump.assert_called_once_with(
Expand Down

0 comments on commit f889c86

Please sign in to comment.