From 5f1d24d3b5ea36052fce2b594717808d6ed2bc73 Mon Sep 17 00:00:00 2001 From: Sergey Kulik <104143901+zxqfd555-pw@users.noreply.github.com> Date: Wed, 22 Jan 2025 22:01:30 +0100 Subject: [PATCH] serialize all Pathway types to Delta Lake (#7988) Co-authored-by: Mateusz Lewandowski Co-authored-by: Kamil Piechowiak <32928185+KamilPiechowiak@users.noreply.github.com> GitOrigin-RevId: 74fcfa5c09a09085a50cca169eaabfc67013690c --- CHANGELOG.md | 2 + integration_tests/iceberg/test_iceberg.py | 106 ++++- python/pathway/io/_utils.py | 15 +- python/pathway/io/python/__init__.py | 6 +- python/pathway/tests/test_io.py | 109 ++++- src/connectors/data_lake/delta.rs | 101 ++++- src/connectors/data_lake/iceberg.rs | 60 ++- src/connectors/data_lake/mod.rs | 492 ++++++++++++++++------ src/connectors/data_lake/writer.rs | 255 +++++++++-- src/engine/time.rs | 103 +++-- src/engine/value.rs | 5 + tests/integration/test_arrow.rs | 79 +++- tests/integration/test_deltalake.rs | 367 +++++++++++++--- 13 files changed, 1399 insertions(+), 301 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70372077..7ea53dbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added - `pw.io.iceberg.read` method for reading Apache Iceberg tables into Pathway. - methods `pw.io.postgres.write` and `pw.io.postgres.write_snapshot` now accept an additional argument `init_mode`, which allows initializing the table before writing. +- `pw.io.deltalake.read` now supports serialization and deserialization for all Pathway data types. ### Changed - **BREAKING**: `pw.io.deltalake.read` now requires explicit specification of primary key fields. @@ -15,6 +16,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - **BREAKING**: `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` now returns a dictionary from `pw_ai_answer` endpoint. - `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` allows optionally returning context documents from `pw_ai_answer` endpoint. - **BREAKING**: When using delay in temporal behavior, current time is updated immediately, not in the next batch. +- **BREAKING**: The `Pointer` type is now serialized to Delta Tables as raw bytes. - `pw.io.kafka.write` now allows to specify `key` and `headers` for JSON and CSV data formats. 
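For illustration, a minimal sketch of the new round-trip behavior, modeled on the tests added in this patch; the paths, column names, and schema below are assumptions, not part of the change itself:

```python
import pathway as pw

# Write a table containing several of the newly supported types to a Delta Lake.
table = pw.io.plaintext.read("./input.jsonl", mode="static")  # assumed input file
table = table.select(
    data=pw.this.data,
    binary_data=b"fedcba",
    datetime_naive=pw.DateTimeNaive(year=2025, month=1, day=17),
    duration=pw.Duration(days=5),
    json_data=pw.Json.parse('{"a": 15, "b": "hello"}'),
)
pw.io.deltalake.write(table, "./delta-lake")  # assumed lake location
pw.run()

# In a separate run (the tests clear the computation graph between phases),
# the same values can be read back; primary key fields must now be given explicitly.
class InputSchema(pw.Schema):
    data: str = pw.column_definition(primary_key=True)
    binary_data: bytes
    datetime_naive: pw.DateTimeNaive
    duration: pw.Duration
    json_data: pw.Json

result = pw.io.deltalake.read("./delta-lake", schema=InputSchema, mode="static")
pw.io.jsonlines.write(result, "./output.jsonl")
pw.run()
```

Pointer and `PyObjectWrapper` columns are stored as raw binary (see the breaking change above), so they round-trip through the same API.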
### Fixed diff --git a/integration_tests/iceberg/test_iceberg.py b/integration_tests/iceberg/test_iceberg.py index fc5fd5c4..e2741ce7 100644 --- a/integration_tests/iceberg/test_iceberg.py +++ b/integration_tests/iceberg/test_iceberg.py @@ -1,9 +1,12 @@ import json +import os +import pickle import threading import time import uuid import pandas as pd +from dateutil import tz from pyiceberg.catalog import load_catalog import pathway as pw @@ -23,7 +26,7 @@ {"user_id": 11, "name": "Steve"} {"user_id": 12, "name": "Sarah"}""" -CATALOG_URI = "http://iceberg:8181" +CATALOG_URI = os.environ.get("ICEBERG_CATALOG_URI", "http://iceberg:8181") INPUT_CONTENTS = { 1: INPUT_CONTENTS_1, 2: INPUT_CONTENTS_2, @@ -205,3 +208,104 @@ class InputSchema(pw.Schema): assert set(pandas_table["name"]) == all_names assert set(pandas_table["diff"]) == {1} assert len(set(pandas_table["time"])) == len(INPUT_CONTENTS) + + +def test_py_object_wrapper_in_iceberg(tmp_path): + input_path = tmp_path / "input.jsonl" + output_path = tmp_path / "output.jsonl" + iceberg_table_name = str(uuid.uuid4()) + input_path.write_text("test") + + table = pw.io.plaintext.read(input_path, mode="static") + table = table.select( + data=pw.this.data, + fun=pw.wrap_py_object(len, serializer=pickle), # type: ignore + ) + pw.io.iceberg.write( + table, + catalog_uri=CATALOG_URI, + namespace=["my_database"], + table_name=iceberg_table_name, + ) + run() + G.clear() + + class InputSchema(pw.Schema): + data: str = pw.column_definition(primary_key=True) + fun: pw.PyObjectWrapper + + @pw.udf + def use_python_object(a: pw.PyObjectWrapper, x: str) -> int: + return a.value(x) + + table = pw.io.iceberg.read( + catalog_uri=CATALOG_URI, + namespace=["my_database"], + table_name=iceberg_table_name, + mode="static", + schema=InputSchema, + ) + table = table.select(len=use_python_object(pw.this.fun, pw.this.data)) + pw.io.jsonlines.write(table, output_path) + run() + + with open(output_path, "r") as f: + data = json.load(f) + assert data["len"] == 4 + + +def test_iceberg_different_types_serialization(tmp_path): + input_path = tmp_path / "input.jsonl" + iceberg_table_name = str(uuid.uuid4()) + input_path.write_text("test") + + column_values = { + "boolean": True, + "integer": 123, + "double": -5.6, + "string": "abcdef", + "binary_data": b"fedcba", + "datetime_naive": pw.DateTimeNaive(year=2025, month=1, day=17), + "datetime_utc_aware": pw.DateTimeUtc(year=2025, month=1, day=17, tz=tz.UTC), + "duration": pw.Duration(days=5), + "json_data": pw.Json.parse('{"a": 15, "b": "hello"}'), + } + table = pw.io.plaintext.read(input_path, mode="static") + table = table.select( + data=pw.this.data, + **column_values, + ) + pw.io.iceberg.write( + table, + catalog_uri=CATALOG_URI, + namespace=["my_database"], + table_name=iceberg_table_name, + ) + run() + G.clear() + + class InputSchema(pw.Schema): + data: str = pw.column_definition(primary_key=True) + boolean: bool + integer: int + double: float + string: str + binary_data: bytes + datetime_naive: pw.DateTimeNaive + datetime_utc_aware: pw.DateTimeUtc + duration: pw.Duration + json_data: pw.Json + + def on_change(key, row, time, is_addition): + for field, expected_value in column_values.items(): + assert row[field] == expected_value + + table = pw.io.iceberg.read( + catalog_uri=CATALOG_URI, + namespace=["my_database"], + table_name=iceberg_table_name, + mode="static", + schema=InputSchema, + ) + pw.io.subscribe(table, on_change=on_change) + run() diff --git a/python/pathway/io/_utils.py b/python/pathway/io/_utils.py index 
39eebbce..46bc3e3b 100644 --- a/python/pathway/io/_utils.py +++ b/python/pathway/io/_utils.py @@ -73,7 +73,11 @@ class RawDataSchema(pw.Schema): - data: Any + data: bytes + + +class PlaintextDataSchema(pw.Schema): + data: str class MetadataSchema(Schema): @@ -294,7 +298,12 @@ def construct_schema_and_data_format( if param in kwargs and kwargs[param] is not None: raise ValueError(f"Unexpected argument for plaintext format: {param}") - schema = RawDataSchema + parse_utf8 = format != "binary" + if parse_utf8: + schema = PlaintextDataSchema + else: + schema = RawDataSchema + if with_metadata: schema |= MetadataSchema schema, api_schema = read_schema( @@ -308,7 +317,7 @@ def construct_schema_and_data_format( return schema, api.DataFormat( format_type=data_format_type, **api_schema, - parse_utf8=(format != "binary"), + parse_utf8=parse_utf8, key_generation_policy=( api.KeyGenerationPolicy.ALWAYS_AUTOGENERATE if autogenerate_key diff --git a/python/pathway/io/python/__init__.py b/python/pathway/io/python/__init__.py index b8f9c45a..09f479e8 100644 --- a/python/pathway/io/python/__init__.py +++ b/python/pathway/io/python/__init__.py @@ -25,6 +25,7 @@ from pathway.internals.trace import trace_user_frame from pathway.io._utils import ( MetadataSchema, + PlaintextDataSchema, RawDataSchema, assert_schema_or_value_columns_not_none, get_data_format_type, @@ -418,7 +419,10 @@ def read( raise ValueError("raw format must not be used with primary_key property") if value_columns: raise ValueError("raw format must not be used with value_columns property") - schema = RawDataSchema + if format == "binary": + schema = RawDataSchema + else: + schema = PlaintextDataSchema if subject._with_metadata is True: schema |= MetadataSchema assert_schema_or_value_columns_not_none(schema, value_columns, data_format_type) diff --git a/python/pathway/tests/test_io.py b/python/pathway/tests/test_io.py index 18f3504f..2b8166b9 100644 --- a/python/pathway/tests/test_io.py +++ b/python/pathway/tests/test_io.py @@ -3,6 +3,7 @@ import json import os import pathlib +import pickle import socket import sqlite3 import sys @@ -11,9 +12,11 @@ from typing import Any, Optional from unittest import mock +import numpy as np import pandas as pd import pytest import yaml +from dateutil import tz from deltalake import DeltaTable, write_deltalake from fs import open_fs @@ -30,7 +33,6 @@ T, assert_table_equality, assert_table_equality_wo_index, - assert_table_equality_wo_index_types, deprecated_call_here, needs_multiprocessing_fork, run, @@ -214,7 +216,7 @@ def run(self): table = pw.io.python.read(TestSubject(), format="raw") - assert_table_equality_wo_index_types( + assert_table_equality_wo_index( table, T( """ @@ -2358,7 +2360,7 @@ class OutputSchema(pw.Schema): 3 | baz | 1701283942 """, ).update_types( - data=Any, + data=str, createdAt=Optional[int], ), result, @@ -3418,7 +3420,7 @@ class InputSchema(pw.Schema): pw.io.deltalake.read(tmp_path / "lake", schema=InputSchema) -def test_iceberg_no_primary_key(tmp_path: pathlib.Path): +def test_iceberg_no_primary_key(): class InputSchema(pw.Schema): k: int v: str @@ -3433,3 +3435,102 @@ class InputSchema(pw.Schema): table_name="test", schema=InputSchema, ) + + +def test_py_object_wrapper_in_deltalake(tmp_path: pathlib.Path): + input_path = tmp_path / "input.jsonl" + lake_path = tmp_path / "delta-lake" + output_path = tmp_path / "output.jsonl" + input_path.write_text("test") + + table = pw.io.plaintext.read(input_path, mode="static") + table = table.select( + data=pw.this.data, 
fun=pw.wrap_py_object(len, serializer=pickle) # type: ignore + ) + pw.io.deltalake.write(table, lake_path) + run_all() + G.clear() + + class InputSchema(pw.Schema): + data: str = pw.column_definition(primary_key=True) + fun: pw.PyObjectWrapper + + @pw.udf + def use_python_object(a: pw.PyObjectWrapper, x: str) -> int: + return a.value(x) + + table = pw.io.deltalake.read(lake_path, schema=InputSchema, mode="static") + table = table.select(len=use_python_object(pw.this.fun, pw.this.data)) + pw.io.jsonlines.write(table, output_path) + run_all() + + with open(output_path, "r") as f: + data = json.load(f) + assert data["len"] == 4 + + +def test_deltalake_different_types_serialization(tmp_path: pathlib.Path): + input_path = tmp_path / "input.jsonl" + lake_path = tmp_path / "delta-lake" + input_path.write_text("test") + + column_values = { + "boolean": True, + "integer": 123, + "double": -5.6, + "string": "abcdef", + "binary_data": b"fedcba", + "datetime_naive": pw.DateTimeNaive(year=2025, month=1, day=17), + "datetime_utc_aware": pw.DateTimeUtc(year=2025, month=1, day=17, tz=tz.UTC), + "duration": pw.Duration(days=5), + "ints": np.array([9, 9, 9], dtype=int), + "floats": np.array([1.1, 2.2, 3.3], dtype=float), + "json_data": pw.Json.parse('{"a": 15, "b": "hello"}'), + "tuple_data": (b"world", True), + "list_data": ["lorem", None, "ipsum"], + "fun": pw.wrap_py_object(len, serializer=pickle), + } + table = pw.io.plaintext.read(input_path, mode="static") + table = table.select( + data=pw.this.data, + **column_values, + ) + table = table.update_types( + ints=np.ndarray[None, int], + floats=np.ndarray[None, float], + tuple_data=tuple[bytes, bool], + list_data=list[str | None], + ) + pw.io.deltalake.write(table, lake_path) + run_all() + G.clear() + + class InputSchema(pw.Schema): + data: str = pw.column_definition(primary_key=True) + boolean: bool + integer: int + double: float + string: str + binary_data: bytes + datetime_naive: pw.DateTimeNaive + datetime_utc_aware: pw.DateTimeUtc + duration: pw.Duration + ints: np.ndarray[None, int] + floats: np.ndarray[None, float] + json_data: pw.Json + tuple_data: tuple[bytes, bool] + list_data: list[str | None] + fun: pw.PyObjectWrapper + + def on_change(key, row, time, is_addition): + for field, expected_value in column_values.items(): + if isinstance(expected_value, np.ndarray): + assert row[field].shape == expected_value.shape + assert (row[field] == expected_value).all() + else: + assert row[field] == expected_value + + table = pw.io.deltalake.read(lake_path, schema=InputSchema, mode="static") + pw.io.subscribe(table, on_change=on_change) + run_all() diff --git a/src/connectors/data_lake/delta.rs b/src/connectors/data_lake/delta.rs index 0c489439..61f8f270 100644 --- a/src/connectors/data_lake/delta.rs +++ b/src/connectors/data_lake/delta.rs @@ -9,9 +9,11 @@ use std::time::Duration; use deltalake::arrow::array::RecordBatch as ArrowRecordBatch; use deltalake::datafusion::parquet::file::reader::SerializedFileReader as DeltaLakeParquetReader; use deltalake::kernel::Action as DeltaLakeAction; +use deltalake::kernel::ArrayType as DeltaTableArrayType; use deltalake::kernel::DataType as DeltaTableKernelType; use deltalake::kernel::PrimitiveType as DeltaTablePrimitiveType; use deltalake::kernel::StructField as DeltaTableStructField; +use deltalake::kernel::StructType as DeltaTableStructType; use deltalake::operations::create::CreateBuilder as DeltaTableCreateBuilder; use deltalake::parquet::file::reader::FileReader as DeltaLakeParquetFileReader;
use deltalake::parquet::record::reader::RowIter as ParquetRowIterator; @@ -23,7 +25,9 @@ use deltalake::{open_table_with_storage_options as open_delta_table, DeltaTable, use s3::bucket::Bucket as S3Bucket; use tempfile::tempfile; -use super::{parquet_row_into_values_map, LakeBatchWriter, SPECIAL_OUTPUT_FIELDS}; +use super::{ + parquet_row_into_values_map, LakeBatchWriter, LakeWriterSettings, SPECIAL_OUTPUT_FIELDS, +}; use crate::async_runtime::create_async_tokio_runtime; use crate::connectors::data_storage::ConnectorMode; use crate::connectors::scanner::S3Scanner; @@ -62,14 +66,14 @@ impl DeltaBatchWriter { for field in schema_fields { struct_fields.push(DeltaTableStructField::new( field.name.clone(), - Self::delta_table_primitive_type(&field.type_)?, + Self::delta_table_type(&field.type_)?, field.type_.can_be_none(), )); } for (field, type_) in SPECIAL_OUTPUT_FIELDS { struct_fields.push(DeltaTableStructField::new( field, - Self::delta_table_primitive_type(&type_)?, + Self::delta_table_type(&type_)?, false, )); } @@ -98,23 +102,73 @@ impl DeltaBatchWriter { Ok(table) } - fn delta_table_primitive_type(type_: &Type) -> Result { - Ok(DeltaTableKernelType::Primitive(match type_ { - Type::Bool => DeltaTablePrimitiveType::Boolean, - Type::Float => DeltaTablePrimitiveType::Double, - Type::String | Type::Json => DeltaTablePrimitiveType::String, - Type::Bytes => DeltaTablePrimitiveType::Binary, - Type::DateTimeNaive => DeltaTablePrimitiveType::TimestampNtz, - Type::DateTimeUtc => DeltaTablePrimitiveType::Timestamp, - Type::Int | Type::Duration => DeltaTablePrimitiveType::Long, - Type::Optional(wrapped) => return Self::delta_table_primitive_type(wrapped), - Type::Any - | Type::Array(_, _) - | Type::Tuple(_) - | Type::List(_) - | Type::PyObjectWrapper - | Type::Pointer => return Err(WriteError::UnsupportedType(type_.clone())), - })) + fn delta_table_type(type_: &Type) -> Result { + let delta_type = match type_ { + Type::Bool => DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Boolean), + Type::Float => DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Double), + Type::String | Type::Json => { + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::String) + } + Type::PyObjectWrapper | Type::Pointer | Type::Bytes => { + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Binary) + } + Type::DateTimeNaive => { + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::TimestampNtz) + } + Type::DateTimeUtc => { + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Timestamp) + } + Type::Int | Type::Duration => { + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Long) + } + Type::List(element_type) => { + let element_type_is_optional = element_type.is_optional(); + let nested_element_type = Self::delta_table_type(element_type.unoptionalize())?; + let array_type = + DeltaTableArrayType::new(nested_element_type, element_type_is_optional); + DeltaTableKernelType::Array(array_type.into()) + } + Type::Array(_, nested_type) => { + let wrapped_type = nested_type.as_ref(); + let elements_kernel_type = match wrapped_type { + Type::Int => DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Long), + Type::Float => DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Double), + _ => panic!("Type::Array can't contain elements of the type {wrapped_type:?}"), + }; + let shape_data_type = DeltaTableKernelType::Array( + DeltaTableArrayType::new( + DeltaTableKernelType::Primitive(DeltaTablePrimitiveType::Long), + true, + ) + .into(), + ); + let elements_data_type = 
DeltaTableKernelType::Array( + DeltaTableArrayType::new(elements_kernel_type, true).into(), + ); + let struct_descriptor = DeltaTableStructType::new(vec![ + DeltaTableStructField::new("shape", shape_data_type, false), + DeltaTableStructField::new("elements", elements_data_type, false), + ]); + DeltaTableKernelType::Struct(struct_descriptor.into()) + } + Type::Tuple(nested_types) => { + let mut struct_fields = Vec::new(); + for (index, nested_type) in nested_types.iter().enumerate() { + let nested_type_is_optional = nested_type.is_optional(); + let nested_delta_type = Self::delta_table_type(nested_type)?; + struct_fields.push(DeltaTableStructField::new( + format!("[{index}]"), + nested_delta_type, + nested_type_is_optional, + )); + } + let struct_descriptor = DeltaTableStructType::new(struct_fields); + DeltaTableKernelType::Struct(struct_descriptor.into()) + } + Type::Optional(wrapped) => return Self::delta_table_type(wrapped), + Type::Any => return Err(WriteError::UnsupportedType(type_.clone())), + }; + Ok(delta_type) } } @@ -126,6 +180,13 @@ impl LakeBatchWriter for DeltaBatchWriter { Ok::<(), WriteError>(()) }) } + + fn settings(&self) -> LakeWriterSettings { + LakeWriterSettings { + use_64bit_size_type: false, + utc_timezone_name: "UTC".into(), + } + } } pub enum ObjectDownloader { diff --git a/src/connectors/data_lake/iceberg.rs b/src/connectors/data_lake/iceberg.rs index ce7895b0..1af7acd9 100644 --- a/src/connectors/data_lake/iceberg.rs +++ b/src/connectors/data_lake/iceberg.rs @@ -9,8 +9,8 @@ use deltalake::parquet::file::properties::WriterProperties; use futures::{stream, StreamExt, TryStreamExt}; use iceberg::scan::{FileScanTask, FileScanTaskStream}; use iceberg::spec::{ - NestedField, PrimitiveType as IcebergPrimitiveType, Schema as IcebergSchema, - Type as IcebergType, + ListType as IcebergListType, NestedField, NestedField as IcebergNestedField, + PrimitiveType as IcebergPrimitiveType, Schema as IcebergSchema, Type as IcebergType, }; use iceberg::table::Table as IcebergTable; use iceberg::transaction::Transaction; @@ -25,7 +25,9 @@ use iceberg::{Catalog, Namespace, NamespaceIdent, TableCreation, TableIdent}; use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig}; use tokio::runtime::Runtime as TokioRuntime; -use super::{columns_into_pathway_values, LakeBatchWriter, SPECIAL_OUTPUT_FIELDS}; +use super::{ + columns_into_pathway_values, LakeBatchWriter, LakeWriterSettings, SPECIAL_OUTPUT_FIELDS, +}; use crate::async_runtime::create_async_tokio_runtime; use crate::connectors::data_storage::ConnectorMode; use crate::connectors::metadata::IcebergMetadata; @@ -159,22 +161,34 @@ impl IcebergTableParams { } fn iceberg_type(type_: &Type) -> Result { - Ok(IcebergType::Primitive(match type_ { - Type::Bool => IcebergPrimitiveType::Boolean, - Type::Float => IcebergPrimitiveType::Double, - Type::String | Type::Json => IcebergPrimitiveType::String, - Type::Bytes => IcebergPrimitiveType::Binary, - Type::DateTimeNaive => IcebergPrimitiveType::Timestamp, - Type::DateTimeUtc => IcebergPrimitiveType::Timestamptz, - Type::Int | Type::Duration => IcebergPrimitiveType::Long, - Type::Optional(wrapped) => return Self::iceberg_type(wrapped), - Type::Any - | Type::Array(_, _) - | Type::Tuple(_) - | Type::List(_) // TODO: it is possible to support lists with the usage of IcebergType::List - | Type::PyObjectWrapper - | Type::Pointer => return Err(WriteError::UnsupportedType(type_.clone())), - })) + let iceberg_type = match type_ { + Type::Bool => IcebergType::Primitive(IcebergPrimitiveType::Boolean), + 
Type::Float => IcebergType::Primitive(IcebergPrimitiveType::Double), + Type::String | Type::Json => IcebergType::Primitive(IcebergPrimitiveType::String), + Type::Bytes | Type::PyObjectWrapper | Type::Pointer => { + IcebergType::Primitive(IcebergPrimitiveType::Binary) + } + Type::DateTimeNaive => IcebergType::Primitive(IcebergPrimitiveType::Timestamp), + Type::DateTimeUtc => IcebergType::Primitive(IcebergPrimitiveType::Timestamptz), + Type::Int | Type::Duration => IcebergType::Primitive(IcebergPrimitiveType::Long), + Type::Optional(wrapped) => Self::iceberg_type(wrapped)?, + Type::List(element_type) => { + let element_type_is_optional = element_type.is_optional(); + let nested_element_type = Self::iceberg_type(element_type.unoptionalize())?; + let nested_type = IcebergNestedField::new( + 0, + "element", + nested_element_type, + !element_type_is_optional, + ); + let array_type = IcebergListType::new(nested_type.into()); + IcebergType::List(array_type) + } + Type::Any | Type::Array(_, _) | Type::Tuple(_) => { + return Err(WriteError::UnsupportedType(type_.clone())) + } + }; + Ok(iceberg_type) } } @@ -200,6 +214,7 @@ impl IcebergBatchWriter { &namespace, db_params.warehouse.as_ref(), )?; + Ok(Self { runtime, catalog, @@ -254,6 +269,13 @@ impl LakeBatchWriter for IcebergBatchWriter { Ok::<(), WriteError>(()) }) } + + fn settings(&self) -> LakeWriterSettings { + LakeWriterSettings { + use_64bit_size_type: true, + utc_timezone_name: "+00:00".into(), + } + } } /// Wrapper for `FileScanTask` that allows to compare them. diff --git a/src/connectors/data_lake/mod.rs b/src/connectors/data_lake/mod.rs index 3b0cbf0e..871ba6d6 100644 --- a/src/connectors/data_lake/mod.rs +++ b/src/connectors/data_lake/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use arcstr::ArcStr; use deltalake::arrow::array::types::{ DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, @@ -8,18 +9,23 @@ use deltalake::arrow::array::types::{ TimestampSecondType, UInt16Type, UInt32Type, UInt8Type, }; use deltalake::arrow::array::{ - Array as ArrowArray, ArrowPrimitiveType, AsArray, RecordBatch as ArrowRecordBatch, + Array as ArrowArray, ArrowPrimitiveType, AsArray, OffsetSizeTrait, + RecordBatch as ArrowRecordBatch, }; use deltalake::arrow::datatypes::{DataType as ArrowDataType, TimeUnit as ArrowTimeUnit}; use deltalake::datafusion::parquet::record::Field as ParquetValue; +use deltalake::datafusion::parquet::record::List as ParquetList; use deltalake::parquet::record::Row as ParquetRow; use half::f16; +use ndarray::ArrayD; use crate::connectors::data_storage::ConversionError; use crate::connectors::data_storage::ValuesMap; use crate::connectors::WriteError; use crate::engine::error::{limit_length, STANDARD_OBJECT_LENGTH_LIMIT}; -use crate::engine::{DateTimeNaive, DateTimeUtc, Duration as EngineDuration, Type, Value}; +use crate::engine::{ + value::Kind, DateTimeNaive, DateTimeUtc, Duration as EngineDuration, Type, Value, +}; pub mod delta; pub mod iceberg; @@ -31,10 +37,19 @@ pub use writer::LakeWriter; const SPECIAL_OUTPUT_FIELDS: [(&str, Type); 2] = [("time", Type::Int), ("diff", Type::Int)]; +pub struct LakeWriterSettings { + pub use_64bit_size_type: bool, + pub utc_timezone_name: ArcStr, +} + pub trait LakeBatchWriter: Send { fn write_batch(&mut self, batch: ArrowRecordBatch) -> Result<(), WriteError>; + + fn settings(&self) -> LakeWriterSettings; } +type ParsedValue = Result>; + // 
Commonly used routines for converting Parquet and Arrow data into Pathway values. pub fn parquet_row_into_values_map( @@ -47,45 +62,170 @@ pub fn parquet_row_into_values_map( // Column outside of the user-provided schema continue; }; + let value = parquet_value_into_pathway_value(parquet_value, expected_type, name); + row_map.insert(name.clone(), value); + } - let value = match (parquet_value, expected_type) { - (ParquetValue::Null, _) => Some(Value::None), - (ParquetValue::Bool(b), Type::Bool | Type::Any) => Some(Value::from(*b)), - (ParquetValue::Long(i), Type::Int | Type::Any) => Some(Value::from(*i)), - (ParquetValue::Long(i), Type::Duration) => Some(Value::from( - EngineDuration::new_with_unit(*i, "us").unwrap(), - )), - (ParquetValue::Double(f), Type::Float | Type::Any) => Some(Value::Float((*f).into())), - (ParquetValue::Str(s), Type::String | Type::Any) => Some(Value::String(s.into())), - (ParquetValue::Str(s), Type::Json) => serde_json::from_str::(s) - .ok() - .map(Value::from), - (ParquetValue::TimestampMicros(us), Type::DateTimeNaive | Type::Any) => Some( - Value::from(DateTimeNaive::from_timestamp(*us, "us").unwrap()), - ), - (ParquetValue::TimestampMicros(us), Type::DateTimeUtc) => { - Some(Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap())) + row_map.into() +} + +pub fn parquet_value_into_pathway_value( + parquet_value: &ParquetValue, + expected_type: &Type, + name: &str, +) -> ParsedValue { + let expected_type_unopt = expected_type.unoptionalize(); + let unchecked_value = match (parquet_value, expected_type_unopt) { + (ParquetValue::Null, _) => Some(Value::None), + (ParquetValue::Bool(b), Type::Bool | Type::Any) => Some(Value::from(*b)), + (ParquetValue::Long(i), Type::Int | Type::Any) => Some(Value::from(*i)), + (ParquetValue::Long(i), Type::Duration) => Some(Value::from( + EngineDuration::new_with_unit(*i, "us").unwrap(), + )), + (ParquetValue::Double(f), Type::Float | Type::Any) => Some(Value::Float((*f).into())), + (ParquetValue::Str(s), Type::String | Type::Any) => Some(Value::String(s.into())), + (ParquetValue::Str(s), Type::Json) => serde_json::from_str::(s) + .ok() + .map(Value::from), + (ParquetValue::TimestampMicros(us), Type::DateTimeNaive | Type::Any) => Some(Value::from( + DateTimeNaive::from_timestamp(*us, "us").unwrap(), + )), + (ParquetValue::TimestampMicros(us), Type::DateTimeUtc) => { + Some(Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap())) + } + (ParquetValue::Bytes(b), Type::Bytes | Type::Any) => Some(Value::Bytes(b.data().into())), + (ParquetValue::Bytes(b), Type::Pointer | Type::PyObjectWrapper) => { + if let Ok(value) = bincode::deserialize::(b.data()) { + match (value.kind(), expected_type_unopt) { + (Kind::Pointer, Type::Pointer) + | (Kind::PyObjectWrapper, Type::PyObjectWrapper) => Some(value), + _ => None, + } + } else { + None } - (ParquetValue::Bytes(b), Type::Bytes | Type::Any) => { - Some(Value::Bytes(b.data().into())) + } + (ParquetValue::ListInternal(parquet_list), Type::List(nested_type)) => { + let mut values = Vec::new(); + for element in parquet_list.elements() { + values.push(parquet_value_into_pathway_value( + element, + nested_type, + name, + )?); } - _ => None, - }; - let value = if let Some(value) = value { - Ok(value) + Some(Value::Tuple(values.into())) + } + (ParquetValue::Group(row), Type::Array(_, array_type)) => { + parse_pathway_array_from_parquet_row(row, array_type) + } + (ParquetValue::Group(row), Type::Tuple(nested_types)) => { + parse_pathway_tuple_from_row(row, nested_types) + } + _ => None, + }; + + 
let expected_type_is_optional = expected_type.is_optional(); + let unchecked_value_is_none = unchecked_value == Some(Value::None); + let value = if unchecked_value_is_none && !expected_type_is_optional { + None + } else { + unchecked_value + }; + + if let Some(value) = value { + Ok(value) + } else { + let value_repr = limit_length(format!("{parquet_value:?}"), STANDARD_OBJECT_LENGTH_LIMIT); + Err(Box::new(ConversionError { + value_repr, + field_name: name.to_string(), + type_: expected_type.clone(), + })) + } +} + +pub fn parse_pathway_tuple_from_row(row: &ParquetRow, nested_types: &[Type]) -> Option { + let mut tuple_contents: Vec> = vec![None; nested_types.len()]; + for (column_name, parquet_value) in row.get_column_iter() { + // Column name has format [index], so we need to skip the first and the last + // character to obtain the sequential index + let str_index = &column_name[1..(column_name.len() - 1)]; + let index: usize = str_index.parse().ok()?; + if index >= nested_types.len() { + return None; + } + tuple_contents[index] = + parquet_value_into_pathway_value(parquet_value, &nested_types[index], "").ok(); + } + let mut tuple_values = Vec::new(); + for tuple_value in tuple_contents { + tuple_values.push(tuple_value?); + } + Some(Value::Tuple(tuple_values.into())) +} + +pub fn parse_pathway_array_from_parquet_row(row: &ParquetRow, array_type: &Type) -> Option { + let shape_i64 = parse_int_array_from_parquet_row(row, "shape")?; + let mut shape: Vec = Vec::new(); + for element in shape_i64 { + shape.push(element.try_into().ok()?); + } + match array_type { + Type::Int => { + let values = parse_int_array_from_parquet_row(row, "elements")?; + let array_impl = ArrayD::::from_shape_vec(shape, values).ok()?; + Some(Value::from(array_impl)) + } + Type::Float => { + let values = parse_float_array_from_parquet_row(row, "elements")?; + let array_impl = ArrayD::::from_shape_vec(shape, values).ok()?; + Some(Value::from(array_impl)) + } + _ => panic!("this method should not be used for types other than Int or Float"), + } +} + +fn parse_int_array_from_parquet_row(row: &ParquetRow, name: &str) -> Option> { + let mut result = Vec::new(); + let list_field = parse_list_field_from_parquet_row(row, name)?; + for element in list_field.elements() { + if let ParquetValue::Long(v) = element { + result.push(*v); } else { - let value_repr = - limit_length(format!("{parquet_value:?}"), STANDARD_OBJECT_LENGTH_LIMIT); - Err(Box::new(ConversionError { - value_repr, - field_name: name.clone(), - type_: expected_type.clone(), - })) - }; - row_map.insert(name.clone(), value); + return None; + } } + Some(result) +} - row_map.into() +fn parse_float_array_from_parquet_row(row: &ParquetRow, name: &str) -> Option> { + let mut result = Vec::new(); + let list_field = parse_list_field_from_parquet_row(row, name)?; + for element in list_field.elements() { + if let ParquetValue::Double(v) = element { + result.push(*v); + } else { + return None; + } + } + Some(result) +} + +fn parse_list_field_from_parquet_row<'a>( + row: &'a ParquetRow, + name: &str, +) -> Option<&'a ParquetList> { + for (column_name, parquet_value) in row.get_column_iter() { + if column_name != name { + continue; + } + if let ParquetValue::ListInternal(list_field) = parquet_value { + return Some(list_field); + } + break; + } + None } pub fn columns_into_pathway_values( @@ -99,100 +239,169 @@ pub fn columns_into_pathway_values( let Some(column) = entry.column_by_name(column_name) else { continue; }; - let arrow_type = column.data_type(); - let values_vector = 
match (arrow_type, expected_type.unoptionalize()) { - (ArrowDataType::Null, _) => vec![Ok(Value::None); rows_count], - (ArrowDataType::Int64, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v))) - } - (ArrowDataType::Int32, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::Int16, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::Int8, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::UInt32, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::UInt16, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::UInt8, Type::Int | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) - } - (ArrowDataType::Float64, Type::Float | Type::Any) => { - convert_arrow_array::(column, |v| Ok(Value::Float(v.into()))) - } - (ArrowDataType::Float32, Type::Float | Type::Any) => { - convert_arrow_array::(column, |v| { - Ok(Value::Float(Into::::into(v).into())) - }) - } - (ArrowDataType::Float16, Type::Float | Type::Any) => { - convert_arrow_array::(column, |v| { - Ok(Value::Float(Into::::into(v).into())) - }) - } - (ArrowDataType::Boolean, Type::Bool | Type::Any) => convert_arrow_boolean_array(column), - (ArrowDataType::Utf8, Type::String | Type::Any) => convert_arrow_string_array(column), - (ArrowDataType::Utf8, Type::Json) => { - convert_arrow_json_array(column, column_name, expected_type) - } - (ArrowDataType::Binary, Type::Bytes | Type::Any) => convert_arrow_bytes_array(column), - (ArrowDataType::Duration(time_unit), Type::Duration | Type::Any) => { - convert_arrow_duration_array(column, *time_unit) - } - (ArrowDataType::Int64, Type::Duration) => { - // Compatibility clause: there is no duration type in Delta Lake, - // so int64 is used to store duration. - // Since the timestamp types in DeltaLake are stored in microseconds, - // we need to convert the duration to microseconds. 
- convert_arrow_array::(column, |v| { - Ok(Value::Duration( - EngineDuration::new_with_unit(v, "us").unwrap(), - )) - }) - } - (ArrowDataType::Timestamp(time_unit, None), Type::DateTimeNaive | Type::Any) => { - convert_arrow_timestamp_array_naive(column, *time_unit) - } - ( - ArrowDataType::Timestamp(time_unit, Some(timezone)), - Type::DateTimeUtc | Type::Any, - ) => convert_arrow_timestamp_array_utc(column, *time_unit, timezone.as_ref()), - (arrow_type, expected_type) => { - vec![ - Err(Box::new(ConversionError { - value_repr: format!("{arrow_type:?}"), - field_name: column_name.clone(), - type_: expected_type.clone(), - })); - rows_count - ] - } - }; - let is_optional = matches!(expected_type, Type::Optional(_)); + let values_vector = + column_into_pathway_values(column, expected_type, column_name, rows_count); for (index, value) in values_vector.into_iter().enumerate() { - let prepared_value = if value == Ok(Value::None) && !is_optional { + result[index].insert(column_name.clone(), value); + } + } + + result.into_iter().map(std::convert::Into::into).collect() +} + +#[allow(clippy::too_many_lines)] +fn column_into_pathway_values( + column: &Arc, + expected_type: &Type, + column_name: &str, + rows_count: usize, +) -> Vec { + let arrow_type = column.data_type(); + let expected_type_unopt = expected_type.unoptionalize(); + let mut values_vector = match (arrow_type, expected_type_unopt) { + (ArrowDataType::Null, _) => vec![Ok(Value::None); rows_count], + (ArrowDataType::Int64, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v))) + } + (ArrowDataType::Int32, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::Int16, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::Int8, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::UInt32, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::UInt16, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::UInt8, Type::Int | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Int(v.into()))) + } + (ArrowDataType::Float64, Type::Float | Type::Any) => { + convert_arrow_array::(column, |v| Ok(Value::Float(v.into()))) + } + (ArrowDataType::Float32, Type::Float | Type::Any) => { + convert_arrow_array::(column, |v| { + Ok(Value::Float(Into::::into(v).into())) + }) + } + (ArrowDataType::Float16, Type::Float | Type::Any) => { + convert_arrow_array::(column, |v| { + Ok(Value::Float(Into::::into(v).into())) + }) + } + (ArrowDataType::Boolean, Type::Bool | Type::Any) => convert_arrow_boolean_array(column), + (ArrowDataType::Utf8, Type::String | Type::Any) => { + convert_arrow_string_array::(column) + } + (ArrowDataType::Utf8, Type::Json) => { + convert_arrow_json_array::(column, column_name, expected_type) + } + (ArrowDataType::LargeUtf8, Type::String | Type::Any) => { + convert_arrow_string_array::(column) + } + (ArrowDataType::LargeUtf8, Type::Json) => { + convert_arrow_json_array::(column, column_name, expected_type) + } + ( + ArrowDataType::Binary, + Type::Bytes | Type::Pointer | Type::PyObjectWrapper | Type::Any, + ) => convert_arrow_bytes_array::(column, column_name, expected_type_unopt), + ( + ArrowDataType::LargeBinary, + Type::Bytes | Type::Pointer | Type::PyObjectWrapper | Type::Any, + ) => 
convert_arrow_bytes_array::(column, column_name, expected_type_unopt), + (ArrowDataType::Duration(time_unit), Type::Duration | Type::Any) => { + convert_arrow_duration_array(column, *time_unit) + } + (ArrowDataType::Int64, Type::Duration) => { + // Compatibility clause: there is no duration type in Delta Lake, + // so int64 is used to store duration. + // Since the timestamp types in DeltaLake are stored in microseconds, + // we need to convert the duration to microseconds. + convert_arrow_array::(column, |v| { + Ok(Value::Duration( + EngineDuration::new_with_unit(v, "us").unwrap(), + )) + }) + } + (ArrowDataType::Timestamp(time_unit, None), Type::DateTimeNaive | Type::Any) => { + convert_arrow_timestamp_array_naive(column, *time_unit) + } + (ArrowDataType::Timestamp(time_unit, Some(timezone)), Type::DateTimeUtc | Type::Any) => { + convert_arrow_timestamp_array_utc(column, *time_unit, timezone.as_ref()) + } + (ArrowDataType::List(_), Type::List(_) | Type::Any) => { + convert_arrow_list_array::(column, expected_type, column_name, column.len()) + } + (ArrowDataType::LargeList(_), Type::List(_) | Type::Any) => { + convert_arrow_list_array::(column, expected_type, column_name, column.len()) + } + (arrow_type, expected_type) => { + vec![ Err(Box::new(ConversionError { + value_repr: format!("{arrow_type:?}"), + field_name: column_name.to_string(), + type_: expected_type.clone(), + })); + rows_count + ] + } + }; + + let is_optional = expected_type.is_optional(); + if !is_optional { + for value in &mut values_vector { + if value == &Ok(Value::None) { + *value = Err(Box::new(ConversionError { value_repr: "null".to_string(), - field_name: column_name.clone(), + field_name: column_name.to_string(), type_: expected_type.clone(), - })) - } else { - value - }; - result[index].insert(column_name.clone(), prepared_value); + })); + } } } - result.into_iter().map(std::convert::Into::into).collect() + values_vector } -type ParsedValue = Result>; +fn pathway_tuple_from_parsed_values(nested_list_contents: Vec) -> ParsedValue { + let mut prepared_values = Vec::new(); + for value in nested_list_contents { + prepared_values.push(value?); + } + Ok(Value::Tuple(prepared_values.into())) +} + +fn convert_arrow_list_array( + column: &Arc, + expected_type: &Type, + column_name: &str, + rows_count: usize, +) -> Vec { + let nested_type = match expected_type.unoptionalize() { + Type::Any => Type::Any, + Type::List(nested_type) => nested_type.as_ref().clone(), + _ => unreachable!(), + }; + let mut result = Vec::new(); + for element in column.as_list::().iter() { + let parsed_value = match element { + Some(element) => { + let nested_list_contents = + column_into_pathway_values(&element, &nested_type, column_name, rows_count); + pathway_tuple_from_parsed_values(nested_list_contents) + } + None => Ok(Value::None), + }; + result.push(parsed_value); + } + result +} fn convert_arrow_array>( column: &Arc, @@ -208,13 +417,13 @@ fn convert_arrow_array>( .collect() } -fn convert_arrow_json_array( +fn convert_arrow_json_array( column: &Arc, name: &str, expected_type: &Type, ) -> Vec { column - .as_string::() + .as_string::() .into_iter() .map(|v| match v { Some(v) => serde_json::from_str::(v) @@ -231,9 +440,11 @@ fn convert_arrow_json_array( .collect() } -fn convert_arrow_string_array(column: &Arc) -> Vec { +fn convert_arrow_string_array( + column: &Arc, +) -> Vec { column - .as_string::() + .as_string::() .into_iter() .map(|v| match v { Some(v) => Ok(Value::String(v.into())), @@ -253,12 +464,39 @@ fn 
convert_arrow_boolean_array(column: &Arc) -> Vec .collect() } -fn convert_arrow_bytes_array(column: &Arc) -> Vec { +fn convert_arrow_bytes_array( + column: &Arc, + field_name: &str, + expected_type: &Type, +) -> Vec { column - .as_binary::() + .as_binary::() .into_iter() .map(|v| match v { - Some(v) => Ok(Value::Bytes(v.into())), + Some(v) => { + if expected_type == &Type::Bytes { + Ok(Value::Bytes(v.into())) + } else { + let maybe_value = bincode::deserialize::(v); + if let Ok(value) = maybe_value { + match (value.kind(), expected_type) { + (Kind::Pointer, Type::Pointer) + | (Kind::PyObjectWrapper, Type::PyObjectWrapper) => Ok(value), + _ => Err(Box::new(ConversionError { + value_repr: format!("{value}"), + field_name: field_name.to_string(), + type_: expected_type.clone(), + })), + } + } else { + Err(Box::new(ConversionError { + value_repr: format!("{maybe_value:?}"), + field_name: field_name.to_string(), + type_: expected_type.clone(), + })) + } + } + } None => Ok(Value::None), }) .collect() diff --git a/src/connectors/data_lake/writer.rs b/src/connectors/data_lake/writer.rs index 827430e3..3d5f3bc9 100644 --- a/src/connectors/data_lake/writer.rs +++ b/src/connectors/data_lake/writer.rs @@ -4,19 +4,25 @@ use std::time::{Duration, Instant}; use deltalake::arrow::array::Array as ArrowArray; use deltalake::arrow::array::RecordBatch as ArrowRecordBatch; use deltalake::arrow::array::{ - BinaryArray as ArrowBinaryArray, BooleanArray as ArrowBooleanArray, + BinaryArray as ArrowBinaryArray, BooleanArray as ArrowBooleanArray, BooleanBufferBuilder, Float64Array as ArrowFloat64Array, Int64Array as ArrowInt64Array, - StringArray as ArrowStringArray, TimestampMicrosecondArray as ArrowTimestampArray, + LargeBinaryArray as ArrowLargeBinaryArray, LargeListArray as ArrowLargeListArray, + ListArray as ArrowListArray, StringArray as ArrowStringArray, StructArray as ArrowStructArray, + TimestampMicrosecondArray as ArrowTimestampArray, }; +use deltalake::arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use deltalake::arrow::datatypes::{ - DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema, + DataType as ArrowDataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, TimeUnit as ArrowTimeUnit, }; +use ndarray::ArrayD; +use super::LakeWriterSettings; use crate::connectors::data_format::FormatterContext; use crate::connectors::data_lake::LakeBatchWriter; use crate::connectors::{WriteError, Writer}; use crate::engine::time::DateTime as EngineDateTime; +use crate::engine::value::Handle; use crate::engine::{Type, Value}; use crate::python_api::ValueField; @@ -34,12 +40,15 @@ pub struct LakeWriter { impl LakeWriter { pub fn new( batch_writer: Box, - value_fields: &Vec, + value_fields: &[ValueField], min_commit_frequency: Option, ) -> Result { - let schema = Arc::new(Self::construct_schema(value_fields)?); + let schema = Arc::new(Self::construct_schema( + value_fields, + &batch_writer.settings(), + )?); let mut empty_buffered_columns = Vec::new(); - empty_buffered_columns.resize_with(schema.flattened_fields().len(), Vec::new); + empty_buffered_columns.resize_with(schema.fields().len(), Vec::new); Ok(Self { batch_writer, schema, @@ -53,7 +62,7 @@ impl LakeWriter { } fn array_of_target_type( - values: &Vec, + values: &[Value], mut to_simple_type: impl FnMut(&Value) -> Result, ) -> Result>, WriteError> { let mut values_vec: Vec> = Vec::new(); @@ -69,7 +78,7 @@ impl LakeWriter { fn arrow_array_for_type( type_: &ArrowDataType, - values: &Vec, + values: &[Value], ) -> Result, 
WriteError> { match type_ { ArrowDataType::Boolean => { @@ -103,16 +112,23 @@ impl LakeWriter { })?; Ok(Arc::new(ArrowStringArray::from(v))) } - ArrowDataType::Binary => { + ArrowDataType::Binary | ArrowDataType::LargeBinary => { let mut vec_owned = Self::array_of_target_type::>(values, |v| match v { Value::Bytes(b) => Ok(b.to_vec()), + Value::PyObjectWrapper(_) | Value::Pointer(_) => { + Ok(bincode::serialize(v).map_err(|e| *e)?) + } _ => Err(WriteError::TypeMismatchWithSchema(v.clone(), type_.clone())), })?; let mut vec_refs = Vec::new(); for item in &mut vec_owned { vec_refs.push(item.as_mut().map(|v| v.as_slice())); } - Ok(Arc::new(ArrowBinaryArray::from(vec_refs))) + if *type_ == ArrowDataType::Binary { + Ok(Arc::new(ArrowBinaryArray::from(vec_refs))) + } else { + Ok(Arc::new(ArrowLargeBinaryArray::from(vec_refs))) + } } ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, None) => { let v = Self::array_of_target_type::(values, |v| match v { @@ -130,17 +146,142 @@ impl LakeWriter { })?; Ok(Arc::new(ArrowTimestampArray::from(v).with_timezone(&**tz))) } + ArrowDataType::List(nested_type) => { + Self::arrow_array_of_lists(values, nested_type, false) + } + ArrowDataType::LargeList(nested_type) => { + Self::arrow_array_of_lists(values, nested_type, true) + } + ArrowDataType::Struct(nested_struct) => { + Self::arrow_array_of_structs(values, nested_struct.as_ref()) + } _ => panic!("provided type {type_} is unknown to the engine"), } } + fn arrow_array_of_structs( + values: &[Value], + nested_types: &[Arc], + ) -> Result, WriteError> { + // Step 1. Decompose struct into separate columns + let mut struct_columns: Vec> = vec![Vec::new(); nested_types.len()]; + let mut defined_fields_map = BooleanBufferBuilder::new(values.len()); + defined_fields_map.resize(values.len()); + for (index, value) in values.iter().enumerate() { + defined_fields_map.set_bit(index, value != &Value::None); + match value { + Value::None => { + for item in &mut struct_columns { + item.push(Value::None); + } + } + Value::IntArray(a) => { + struct_columns[0].push(Self::convert_shape_to_pathway_tuple(a.shape())); + struct_columns[1].push(Self::convert_contents_to_pathway_tuple(a)); + } + Value::FloatArray(a) => { + struct_columns[0].push(Self::convert_shape_to_pathway_tuple(a.shape())); + struct_columns[1].push(Self::convert_contents_to_pathway_tuple(a)); + } + Value::Tuple(tuple_elements) => { + for (index, field) in tuple_elements.iter().enumerate() { + struct_columns[index].push(field.clone()); + } + } + _ => panic!("Pathway type {value} is not serializable as an arrow tuple"), + } + } + + // Step 2. Create Arrow arrays for the separate columns + let mut arrow_arrays = Vec::new(); + for (struct_column, arrow_field) in struct_columns.iter().zip(nested_types) { + let arrow_array = Self::arrow_array_for_type(arrow_field.data_type(), struct_column)?; + arrow_arrays.push(arrow_array); + } + + // Step 3. 
Create a struct array + let struct_array: Arc = Arc::new(ArrowStructArray::new( + nested_types.into(), + arrow_arrays, + Some(NullBuffer::new(defined_fields_map.finish())), + )); + Ok(struct_array) + } + + fn convert_shape_to_pathway_tuple(shape: &[usize]) -> Value { + let tuple_contents: Vec<_> = shape + .iter() + .map(|v| Value::Int((*v).try_into().unwrap())) + .collect(); + Value::Tuple(tuple_contents.into()) + } + + fn convert_contents_to_pathway_tuple + Clone>( + contents: &Handle>, + ) -> Value + where + Value: std::convert::From, + { + let tuple_contents: Vec<_> = contents.iter().map(|v| Value::from((*v).clone())).collect(); + Value::Tuple(tuple_contents.into()) + } + + fn arrow_array_of_lists( + values: &[Value], + nested_type: &Arc, + use_64bit_size_type: bool, + ) -> Result, WriteError> { + let mut flat_values = Vec::new(); + let mut offsets = Vec::new(); + + let mut defined_fields_map = BooleanBufferBuilder::new(values.len()); + defined_fields_map.resize(values.len()); + for (index, value) in values.iter().enumerate() { + offsets.push(flat_values.len()); + let Value::Tuple(list) = value else { + defined_fields_map.set_bit(index, false); + continue; + }; + defined_fields_map.set_bit(index, true); + for nested_value in list.as_ref() { + flat_values.push(nested_value.clone()); + } + } + offsets.push(flat_values.len()); + + let flat_values = Self::arrow_array_for_type(nested_type.data_type(), &flat_values)?; + + let list_array: Arc = if use_64bit_size_type { + let offsets: Vec = offsets.into_iter().map(|v| v.try_into().unwrap()).collect(); + let scalar_buffer = ScalarBuffer::from(offsets); + let offset_buffer = OffsetBuffer::new(scalar_buffer); + Arc::new(ArrowLargeListArray::new( + nested_type.clone(), + offset_buffer, + flat_values, + Some(NullBuffer::new(defined_fields_map.finish())), + )) + } else { + let offsets: Vec = offsets.into_iter().map(|v| v.try_into().unwrap()).collect(); + let scalar_buffer = ScalarBuffer::from(offsets); + let offset_buffer = OffsetBuffer::new(scalar_buffer); + Arc::new(ArrowListArray::new( + nested_type.clone(), + offset_buffer, + flat_values, + Some(NullBuffer::new(defined_fields_map.finish())), + )) + }; + + Ok(list_array) + } + fn prepare_arrow_batch(&self) -> Result { let mut data_columns = Vec::new(); for (index, column) in self.buffered_columns.iter().enumerate() { - data_columns.push(Self::arrow_array_for_type( - self.schema.field(index).data_type(), - column, - )?); + let arrow_array = + Self::arrow_array_for_type(self.schema.field(index).data_type(), column)?; + data_columns.push(arrow_array); } Ok(ArrowRecordBatch::try_new( self.schema.clone(), @@ -148,41 +289,97 @@ impl LakeWriter { )?) 
} - fn arrow_data_type(type_: &Type) -> Result { + fn arrow_data_type( + type_: &Type, + settings: &LakeWriterSettings, + ) -> Result { Ok(match type_ { Type::Bool => ArrowDataType::Boolean, Type::Int | Type::Duration => ArrowDataType::Int64, Type::Float => ArrowDataType::Float64, - Type::Pointer | Type::String | Type::Json => ArrowDataType::Utf8, - Type::Bytes => ArrowDataType::Binary, + Type::String | Type::Json => ArrowDataType::Utf8, + Type::Bytes | Type::Pointer | Type::PyObjectWrapper => { + if settings.use_64bit_size_type { + ArrowDataType::LargeBinary + } else { + ArrowDataType::Binary + } + } // DeltaLake timestamps are stored in microseconds: // https://docs.rs/deltalake/latest/deltalake/kernel/enum.PrimitiveType.html#variant.Timestamp Type::DateTimeNaive => ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, None), - Type::DateTimeUtc => { - ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, Some("UTC".into())) - } - Type::Optional(wrapped) => return Self::arrow_data_type(wrapped), - Type::Any - | Type::Array(_, _) - | Type::Tuple(_) - | Type::List(_) - | Type::PyObjectWrapper => return Err(WriteError::UnsupportedType(type_.clone())), + Type::DateTimeUtc => ArrowDataType::Timestamp( + ArrowTimeUnit::Microsecond, + Some(settings.utc_timezone_name.clone().into()), + ), + Type::Optional(wrapped) => return Self::arrow_data_type(wrapped, settings), + Type::List(wrapped_type) => { + let wrapped_type_is_optional = wrapped_type.is_optional(); + let wrapped_arrow_type = Self::arrow_data_type(wrapped_type, settings)?; + let list_field = + ArrowField::new("element", wrapped_arrow_type, wrapped_type_is_optional); + ArrowDataType::List(list_field.into()) + } + Type::Array(_, wrapped_type) => { + let wrapped_type = wrapped_type.as_ref(); + let elements_arrow_type = match wrapped_type { + Type::Int => ArrowDataType::Int64, + Type::Float => ArrowDataType::Float64, + _ => panic!("Type::Array can't contain elements of the type {wrapped_type:?}"), + }; + let struct_fields_vector = vec![ + ArrowField::new( + "shape", + ArrowDataType::List( + ArrowField::new("element", ArrowDataType::Int64, true).into(), + ), + false, + ), + ArrowField::new( + "elements", + ArrowDataType::List( + ArrowField::new("element", elements_arrow_type, true).into(), + ), + false, + ), + ]; + let struct_fields = ArrowFields::from(struct_fields_vector); + ArrowDataType::Struct(struct_fields) + } + Type::Tuple(wrapped_types) => { + let mut struct_fields = Vec::new(); + for (index, wrapped_type) in wrapped_types.iter().enumerate() { + let nested_arrow_type = Self::arrow_data_type(wrapped_type, settings)?; + let nested_type_is_optional = wrapped_type.is_optional(); + struct_fields.push(ArrowField::new( + format!("[{index}]"), + nested_arrow_type, + nested_type_is_optional, + )); + } + let struct_descriptor = ArrowFields::from(struct_fields); + ArrowDataType::Struct(struct_descriptor) + } + Type::Any => return Err(WriteError::UnsupportedType(type_.clone())), }) } - pub fn construct_schema(value_fields: &Vec) -> Result { + pub fn construct_schema( + value_fields: &[ValueField], + settings: &LakeWriterSettings, + ) -> Result { let mut schema_fields: Vec = Vec::new(); for field in value_fields { schema_fields.push(ArrowField::new( field.name.clone(), - Self::arrow_data_type(&field.type_)?, + Self::arrow_data_type(&field.type_, settings)?, field.type_.can_be_none(), )); } for (field, type_) in SPECIAL_OUTPUT_FIELDS { schema_fields.push(ArrowField::new( field, - Self::arrow_data_type(&type_)?, + Self::arrow_data_type(&type_, 
settings)?, false, )); } diff --git a/src/engine/time.rs b/src/engine/time.rs index 71ad059f..f08c0d65 100644 --- a/src/engine/time.rs +++ b/src/engine/time.rs @@ -1,7 +1,9 @@ // Copyright © 2024 Pathway use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; +use std::str::FromStr; +use chrono::offset::FixedOffset; use chrono::{self, DurationRound, LocalResult, TimeZone}; use chrono::{Datelike, Timelike}; use chrono_tz::Tz; @@ -163,34 +165,9 @@ impl DateTimeNaive { } pub fn to_utc_from_timezone(&self, timezone: &str) -> DataResult { - match timezone.parse::() { - Ok(tz) => { - let naive_local = self.as_chrono_datetime(); - let localized = tz.from_local_datetime(&naive_local); - match localized { - LocalResult::Single(localized) | LocalResult::Ambiguous(_, localized) => { - Ok(localized.into()) - } - LocalResult::None => { - // This NaiveDateTime doesn't exist in a given timezone. - // We try getting a first date after this. - let moved = naive_local + chrono::Duration::try_minutes(30).unwrap(); - let rounded = moved - .duration_round(chrono::Duration::try_hours(1).unwrap()) - .unwrap(); - let localized = tz.from_local_datetime(&rounded); - if let LocalResult::Single(localized) = localized { - Ok(localized.into()) - } else { - Err(DataError::DateTimeConversionError) - } - } - } - } - Err(e) => Err(DataError::ParseError(format!( - "cannot parse time zone {timezone:?}: {e}" - ))), - } + let naive_local = self.as_chrono_datetime(); + to_utc_from_timezone::(naive_local, timezone) + .or_else(|_err| to_utc_from_timezone::(naive_local, timezone)) } #[must_use] @@ -291,17 +268,9 @@ impl DateTimeUtc { } pub fn to_naive_in_timezone(&self, timezone: &str) -> DataResult { - match timezone.parse::() { - Ok(tz) => { - let naive_utc = self.as_chrono_datetime(); - let localized = tz.from_utc_datetime(&naive_utc); - let naive_local = localized.naive_local(); - Ok(naive_local.into()) - } - Err(e) => Err(DataError::ParseError(format!( - "cannot parse time zone {timezone:?}: {e}" - ))), - } + let naive_utc = self.as_chrono_datetime(); + to_naive_in_timezone::(naive_utc, timezone) + .or_else(|_err| to_naive_in_timezone::(naive_utc, timezone)) } #[must_use] @@ -579,3 +548,59 @@ impl Display for Duration { write!(fmt, "{}", output.join(" ")) } } + +fn to_utc_from_timezone( + naive_local: chrono::NaiveDateTime, + timezone: &str, +) -> DataResult +where + Tz: TimeZone + FromStr, + ::Err: std::fmt::Display, +{ + let tz = match timezone.parse::() { + Ok(tz) => Ok(tz), + Err(e) => Err(DataError::ParseError(format!( + "cannot parse time zone {timezone:?}: {e}" + ))), + }?; + let localized = tz.from_local_datetime(&naive_local); + match localized { + LocalResult::Single(localized) | LocalResult::Ambiguous(_, localized) => { + Ok(localized.into()) + } + LocalResult::None => { + // This NaiveDateTime doesn't exist in a given timezone. + // We try getting a first date after this. 
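+ // For example, with a 02:00-03:00 spring-forward gap, a naive 02:30 is nudged forward and rounded to 03:00 local time before localizing.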
+ let moved = naive_local + chrono::Duration::try_minutes(30).unwrap(); + let rounded = moved + .duration_round(chrono::Duration::try_hours(1).unwrap()) + .unwrap(); + let localized = tz.from_local_datetime(&rounded); + if let LocalResult::Single(localized) = localized { + Ok(localized.into()) + } else { + Err(DataError::DateTimeConversionError) + } + } + } +} + +pub fn to_naive_in_timezone( + naive_utc: chrono::NaiveDateTime, + timezone: &str, +) -> DataResult +where + Tz: TimeZone + FromStr, + ::Err: std::fmt::Display, +{ + match timezone.parse::() { + Ok(tz) => { + let localized = tz.from_utc_datetime(&naive_utc); + let naive_local = localized.naive_local(); + Ok(naive_local.into()) + } + Err(e) => Err(DataError::ParseError(format!( + "cannot parse time zone {timezone:?}: {e}" + ))), + } +} diff --git a/src/engine/value.rs b/src/engine/value.rs index 922d3f5d..22a7b39f 100644 --- a/src/engine/value.rs +++ b/src/engine/value.rs @@ -527,12 +527,17 @@ impl Type { pub fn can_be_none(&self) -> bool { matches!(self, Self::Optional(_) | Self::Any) } + pub fn unoptionalize(&self) -> &Self { match self { Self::Optional(arg) => arg, type_ => type_, } } + + pub fn is_optional(&self) -> bool { + matches!(self, Self::Optional(_)) + } } impl Display for Type { diff --git a/tests/integration/test_arrow.rs b/tests/integration/test_arrow.rs index 74193d5d..d037af47 100644 --- a/tests/integration/test_arrow.rs +++ b/tests/integration/test_arrow.rs @@ -1,3 +1,4 @@ +use assert_matches::assert_matches; use std::collections::HashMap; use std::sync::mpsc; @@ -6,7 +7,7 @@ use serde_json::json; use pathway_engine::connectors::data_format::FormatterContext; use pathway_engine::connectors::data_lake::columns_into_pathway_values; -use pathway_engine::connectors::data_lake::LakeBatchWriter; +use pathway_engine::connectors::data_lake::{LakeBatchWriter, LakeWriterSettings}; use pathway_engine::connectors::data_storage::{LakeWriter, WriteError, Writer}; use pathway_engine::engine::{ DateTimeNaive, DateTimeUtc, Duration as EngineDuration, Key, Timestamp, Type, Value, @@ -28,6 +29,13 @@ impl LakeBatchWriter for ArrowBatchWriter { self.sender.send(batch).unwrap(); Ok(()) } + + fn settings(&self) -> LakeWriterSettings { + LakeWriterSettings { + use_64bit_size_type: false, + utc_timezone_name: "UTC".into(), + } + } } fn run_arrow_roadtrip(type_: Type, values: Vec) -> eyre::Result<()> { @@ -66,7 +74,10 @@ fn run_arrow_roadtrip(type_: Type, values: Vec) -> eyre::Result<()> { assert_eq!(values_roundtrip, values); if !matches!(type_, Type::Optional(_)) { - // Check that the same process works when the type is optional. + // If the type isn't optional, we run a test for its optional version. + // To do that, we create an optional version of the type, append a null-value + // to the end of the tested values vector, and run test on the parameters + // modified this way. 
diff --git a/tests/integration/test_arrow.rs b/tests/integration/test_arrow.rs index 74193d5d..d037af47 100644 --- a/tests/integration/test_arrow.rs +++ b/tests/integration/test_arrow.rs @@ -1,3 +1,4 @@ +use assert_matches::assert_matches; use std::collections::HashMap; use std::sync::mpsc; @@ -6,7 +7,7 @@ use serde_json::json; use pathway_engine::connectors::data_format::FormatterContext; use pathway_engine::connectors::data_lake::columns_into_pathway_values; -use pathway_engine::connectors::data_lake::LakeBatchWriter; +use pathway_engine::connectors::data_lake::{LakeBatchWriter, LakeWriterSettings}; use pathway_engine::connectors::data_storage::{LakeWriter, WriteError, Writer}; use pathway_engine::engine::{ DateTimeNaive, DateTimeUtc, Duration as EngineDuration, Key, Timestamp, Type, Value, }; @@ -28,6 +29,13 @@ impl LakeBatchWriter for ArrowBatchWriter { self.sender.send(batch).unwrap(); Ok(()) } + + fn settings(&self) -> LakeWriterSettings { + LakeWriterSettings { + use_64bit_size_type: false, + utc_timezone_name: "UTC".into(), + } + } } fn run_arrow_roadtrip(type_: Type, values: Vec<Value>) -> eyre::Result<()> { @@ -66,7 +74,10 @@ fn run_arrow_roadtrip(type_: Type, values: Vec<Value>) -> eyre::Result<()> { assert_eq!(values_roundtrip, values); if !matches!(type_, Type::Optional(_)) { - // Check that the same process works when the type is optional. + // If the type isn't optional, we run a test for its optional version. + // To do that, we create an optional version of the type, append a null value + // to the end of the tested values vector, and run the test on the parameters + // modified this way. let mut values_with_nulls = values.clone(); values_with_nulls.push(Value::None); run_arrow_roadtrip(Type::Optional(type_.clone().into()), values_with_nulls)?; @@ -167,3 +178,67 @@ fn test_save_datetimeutc() -> eyre::Result<()> { fn test_save_json() -> eyre::Result<()> { run_arrow_roadtrip(Type::Json, vec![Value::from(json!({"A": 100}))]) } + +#[test] +fn test_save_pointer() -> eyre::Result<()> { + run_arrow_roadtrip(Type::Pointer, vec![Value::Pointer(Key::random())]) +} + +#[test] +fn test_save_array() -> eyre::Result<()> { + run_arrow_roadtrip(Type::Pointer, vec![Value::Pointer(Key::random())]) +} + +#[test] +fn test_save_list() -> eyre::Result<()> { + let value_list_1 = vec![ + Value::Duration(EngineDuration::new_with_unit(-1, "s")?), + Value::Duration(EngineDuration::new_with_unit(2, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + ]; + let value_list_2 = vec![ + Value::Duration(EngineDuration::new_with_unit(-10, "s")?), + Value::Duration(EngineDuration::new_with_unit(20, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + ]; + run_arrow_roadtrip( + Type::List(Type::Duration.into()), + vec![ + Value::Tuple(value_list_1.into()), + Value::Tuple(value_list_2.into()), + ], + ) +} + +#[test] +fn test_save_optionals_list() -> eyre::Result<()> { + let value_list_1 = vec![ + Value::Duration(EngineDuration::new_with_unit(-1, "s")?), + Value::Duration(EngineDuration::new_with_unit(2, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + Value::None, + ]; + let value_list_2 = vec![ + Value::Duration(EngineDuration::new_with_unit(-10, "s")?), + Value::None, + Value::Duration(EngineDuration::new_with_unit(20, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + ]; + run_arrow_roadtrip( + Type::List(Type::Optional(Type::Duration.into()).into()), + vec![ + Value::Tuple(value_list_1.into()), + Value::Tuple(value_list_2.into()), + ], + ) +} + +#[test] +fn test_save_any_is_unsupported() -> eyre::Result<()> { + let save_result = run_arrow_roadtrip(Type::Any, vec![Value::from(json!({"A": 100}))]); + assert_matches!( + save_result.err().unwrap().downcast::<WriteError>(), + Ok(WriteError::UnsupportedType(_)) + ); + Ok(()) +}
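As an illustrative aside (not part of the patch), the new `settings()` requirement on `LakeBatchWriter` amounts to returning a `LakeWriterSettings` value like the one the test writer above uses; the comments on the fields are inferences from their names and usage, not documented semantics:

use pathway_engine::connectors::data_lake::LakeWriterSettings;

fn arrow_test_settings() -> LakeWriterSettings {
    LakeWriterSettings {
        // Assumption: chooses between 32-bit and 64-bit Arrow offset/size types
        // for variable-length columns.
        use_64bit_size_type: false,
        // Assumption: the zone name attached to UTC timestamp columns in the
        // produced Arrow schema.
        utc_timezone_name: "UTC".into(),
    }
}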
diff --git a/tests/integration/test_deltalake.rs b/tests/integration/test_deltalake.rs index d802f43f..b7319c77 100644 --- a/tests/integration/test_deltalake.rs +++ b/tests/integration/test_deltalake.rs @@ -5,20 +5,22 @@ use std::path::Path; use assert_matches::assert_matches; use deltalake::datafusion::parquet::file::reader::SerializedFileReader; -use deltalake::datafusion::parquet::record::Field as ParquetField; +use ndarray::ArrayD; use serde_json::json; use tempfile::tempdir; use pathway_engine::connectors::data_format::{ Formatter, IdentityFormatter, InnerSchemaField, ParsedEvent, TransparentParser, }; -use pathway_engine::connectors::data_lake::DeltaBatchWriter; +use pathway_engine::connectors::data_lake::{parquet_row_into_values_map, DeltaBatchWriter}; use pathway_engine::connectors::data_storage::{ - ConnectorMode, DeltaTableReader, LakeWriter, ObjectDownloader, WriteError, Writer, + ConnectorMode, ConversionError, DeltaTableReader, LakeWriter, ObjectDownloader, WriteError, + Writer, }; use pathway_engine::connectors::SessionType; use pathway_engine::engine::{ - DateTimeNaive, DateTimeUtc, Duration, Key, Result, Timestamp, Type, Value, + DateTimeNaive, DateTimeUtc, Duration, Duration as EngineDuration, Key, Result, Timestamp, Type, + Value, }; use pathway_engine::python_api::ValueField; @@ -49,23 +51,53 @@ fn run_single_column_save(type_: Type, values: &[Value]) { writer.write(context)?; } writer.flush(true)?; - let rows_present = read_from_deltalake(test_storage_path.to_str().unwrap(), &type_); + let rows_present: Vec<_> = read_from_deltalake(test_storage_path.to_str().unwrap(), &type_) + .into_iter() + .map(|item| item.unwrap()) + .collect(); assert_eq!(rows_present, values); - let rows_roundtrip = read_with_connector(test_storage_path.to_str().unwrap(), type_)?; + let rows_roundtrip = read_with_connector(test_storage_path.to_str().unwrap(), &type_)?; assert_eq!(rows_roundtrip, values); + let is_optional = matches!(type_, Type::Optional(_)); + if !is_optional { + // If the type isn't optional, we run a test for its optional version. + // To do that, we create an optional version of the type, append a null value + // to the end of the tested values vector, and run the test on the parameters + // modified this way. + let mut values_with_nulls = values.to_vec(); + values_with_nulls.push(Value::None); + run_single_column_save(Type::Optional(type_.into()), &values_with_nulls)?; + } else { + // If the type is optional, we've previously added None to the end of the vector. + // Now we need to check that reading it back without optionality fails. + let mut rows_roundtrip = + read_from_deltalake(test_storage_path.to_str().unwrap(), type_.unoptionalize()); + assert!(rows_roundtrip.pop().unwrap().is_err()); + let rows_roundtrip: Vec<_> = rows_roundtrip + .into_iter() + .map(|item| item.unwrap()) + .collect(); + assert_eq!(rows_roundtrip, values[..rows_roundtrip.len()]); + + let mut rows_roundtrip = + read_with_connector(test_storage_path.to_str().unwrap(), type_.unoptionalize())?; + assert_eq!(rows_roundtrip.pop().unwrap(), Value::Error); + assert_eq!(rows_roundtrip, values[..rows_roundtrip.len()]); + } + Ok(()) } -fn read_with_connector(path: &str, type_: Type) -> Result<Vec<Value>> { +fn read_with_connector(path: &str, type_: &Type) -> Result<Vec<Value>> { let mut schema = HashMap::new(); schema.insert( "field".to_string(), - InnerSchemaField::new(Type::Optional(type_.clone().into()), None), + InnerSchemaField::new(type_.clone(), None), ); let mut type_map = HashMap::new(); - type_map.insert("field".to_string(), type_); + type_map.insert("field".to_string(), type_.clone()); let reader = DeltaTableReader::new( path, ObjectDownloader::Local, @@ -90,8 +122,12 @@ fn read_with_connector(path: &str, type_: Type) -> Result<Vec<Value>> { Ok(result) } -fn read_from_deltalake(path: &str, type_: &Type) -> Vec<Value> { +fn read_from_deltalake(path: &str, type_: &Type) -> Vec<Result<Value, Box<ConversionError>>> { let mut reread_values = Vec::new(); + let mut column_types = HashMap::new(); + column_types.insert("field".to_string(), type_.clone()); + column_types.insert("time".to_string(), Type::Int); + column_types.insert("diff".to_string(), Type::Int); tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -108,37 +144,11 @@ fn read_from_deltalake(path: &str, type_: &Type) -> Vec<Value> { .map(|p| SerializedFileReader::try_from(Path::new(p)).unwrap()) .flat_map(|r| r.into_iter()); for row in rows { - let mut has_time_column = false; - let mut has_diff_column = false; - for (name, field) in row.expect("row reading failed").get_column_iter() { - if name == "time" { - has_time_column = true; - } - if name == "diff" { - has_diff_column = true; - } - if name != "field" { - continue; - } - let parsed_value = match (field, type_) { - (ParquetField::Null, _) => Value::None, - (ParquetField::Bool(b), Type::Bool) => Value::from(*b), - (ParquetField::Long(i), Type::Int) =>
Value::from(*i), - (ParquetField::Long(i), Type::Duration) => Value::from(Duration::new_with_unit(*i, "us").unwrap()), - (ParquetField::Double(f), Type::Float) => Value::Float((*f).into()), - (ParquetField::Str(s), Type::String) => Value::String(s.into()), - (ParquetField::Str(s), Type::Json) => { - let json: serde_json::Value = serde_json::from_str(s).unwrap(); - Value::from(json) - }, - (ParquetField::TimestampMicros(us), Type::DateTimeNaive) => Value::from(DateTimeNaive::from_timestamp(*us, "us").unwrap()), - (ParquetField::TimestampMicros(us), Type::DateTimeUtc) => Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap()), - (ParquetField::Bytes(b), Type::Bytes) => Value::Bytes(b.data().into()), - (field, type_) => panic!("Pathway shouldn't have serialized field of type {type_:?} as {field:?}"), - }; - reread_values.push(parsed_value); - } - assert!(has_time_column && has_diff_column); + let row = row.expect("row reading failed"); + let values_map = parquet_row_into_values_map(&row, &column_types); + reread_values.push(values_map.get("field").unwrap().clone()); + assert!(values_map.get("time").is_some()); + assert!(values_map.get("diff").is_some()); } }); @@ -219,21 +229,266 @@ fn test_save_json() -> eyre::Result<()> { } #[test] -fn test_unsupported_types_fail_as_expected() -> eyre::Result<()> { - let unsupported_types = &[ - Type::Any, - Type::Array(Some(2), Type::Int.into()), - Type::PyObjectWrapper, - Type::Tuple([].into()), - Type::Pointer, +fn test_save_list() -> eyre::Result<()> { + let value_list_1 = vec![ + Value::Duration(EngineDuration::new_with_unit(-1, "s")?), + Value::Duration(EngineDuration::new_with_unit(2, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), ]; - for t in unsupported_types { - let save_result = run_single_column_save(t.clone(), &[]); - assert!(save_result.is_err()); - assert_matches!( - save_result.err().unwrap().downcast::<WriteError>(), - Ok(WriteError::UnsupportedType(_)) - ); - } + let value_list_2 = vec![ + Value::Duration(EngineDuration::new_with_unit(-10, "s")?), + Value::Duration(EngineDuration::new_with_unit(20, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + ]; + run_single_column_save( + Type::List(Type::Duration.into()), + &[ + Value::Tuple(value_list_1.into()), + Value::Tuple(value_list_2.into()), + ], + ) +} + +#[test] +fn test_save_optionals_list() -> eyre::Result<()> { + let value_list_1 = vec![ + Value::Duration(EngineDuration::new_with_unit(-1, "s")?), + Value::Duration(EngineDuration::new_with_unit(2, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + Value::None, + ]; + let value_list_2 = vec![ + Value::Duration(EngineDuration::new_with_unit(-10, "s")?), + Value::None, + Value::Duration(EngineDuration::new_with_unit(20, "ms")?), + Value::Duration(EngineDuration::new_with_unit(0, "ns")?), + ]; + run_single_column_save( + Type::List(Type::Optional(Type::Duration.into()).into()), + &[ + Value::Tuple(value_list_1.into()), + Value::Tuple(value_list_2.into()), + ], + ) +} + +#[test] +fn test_save_pointer() -> eyre::Result<()> { + run_single_column_save(Type::Pointer, &[Value::Pointer(Key::random())]) +} + +#[test] +fn test_save_int_array() -> eyre::Result<()> { + let array1 = ArrayD::<i64>::from_shape_vec(vec![2, 3], vec![0, 1, 2, 3, 4, 5]).unwrap(); + let array2 = + ArrayD::<i64>::from_shape_vec(vec![2, 2, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + run_single_column_save( + Type::Array(None, Type::Int.into()), + &[Value::from(array1), Value::from(array2)], + ) +} + +#[test] +fn
test_save_float_array() -> eyre::Result<()> { + let array1 = + ArrayD::<f64>::from_shape_vec(vec![2, 3], vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(); + let array2 = + ArrayD::<f64>::from_shape_vec(vec![2, 2, 2], vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) + .unwrap(); + run_single_column_save( + Type::Array(None, Type::Float.into()), + &[Value::from(array1), Value::from(array2)], + ) +} + +#[test] +fn test_save_tuple() -> eyre::Result<()> { + let tuple_contents = vec![Type::String, Type::Int]; + let tuple_type = Type::Tuple(tuple_contents.into()); + + run_single_column_save( + tuple_type, + &[ + Value::Tuple(vec![Value::String("hello".into()), Value::Int(10)].into()), + Value::Tuple(vec![Value::String("world".into()), Value::Int(20)].into()), + ], + ) +} + +#[test] +fn test_save_tuple_with_optionals() -> eyre::Result<()> { + let tuple_contents = vec![ + Type::String, + Type::Optional(Type::Int.into()), + Type::Optional(Type::Bool.into()), + ]; + let tuple_type = Type::Tuple(tuple_contents.into()); + + run_single_column_save( + tuple_type, + &[ + Value::Tuple(vec![Value::String("lorem".into()), Value::Int(10), Value::None].into()), + Value::Tuple( + vec![ + Value::String("ipsum".into()), + Value::Int(20), + Value::Bool(true), + ] + .into(), + ), + Value::Tuple( + vec![ + Value::String("dolor".into()), + Value::None, + Value::Bool(false), + ] + .into(), + ), + Value::Tuple(vec![Value::String("sit".into()), Value::None, Value::None].into()), + ], + ) +} + +#[test] +fn test_save_tuple_nested_tuples() -> eyre::Result<()> { + // (String, (Int, (Bool, Bytes))) + let tuple_contents = vec![ + Type::String, + Type::Tuple(vec![Type::Int, Type::Tuple(vec![Type::Bool, Type::Bytes].into())].into()), + ]; + let tuple_type = Type::Tuple(tuple_contents.into()); + + run_single_column_save( + tuple_type, + &[ + Value::Tuple( + vec![ + Value::String("lorem".into()), + Value::Tuple( + vec![ + Value::Int(10), + Value::Tuple( + vec![Value::Bool(true), Value::Bytes(b"lorem".to_vec().into())] + .into(), + ), + ] + .into(), + ), + ] + .into(), + ), + Value::Tuple( + vec![ + Value::String("ipsum".into()), + Value::Tuple( + vec![ + Value::Int(20), + Value::Tuple( + vec![Value::Bool(false), Value::Bytes(b"ipsum".to_vec().into())] + .into(), + ), + ] + .into(), + ), + ] + .into(), + ), + ], + ) +} + +#[test] +fn test_save_tuple_nested_tuples_with_arrays() -> eyre::Result<()> { + // (String, (Int, (Optional<IntArray>, FloatArray))) + let tuple_contents = vec![ + Type::String, + Type::Tuple( + vec![ + Type::Int, + Type::Tuple( + vec![ + Type::Optional(Type::Array(None, Type::Int.into()).into()), + Type::Array(None, Type::Float.into()), + ] + .into(), + ), + ] + .into(), + ), + ]; + let tuple_type = Type::Tuple(tuple_contents.into()); + + let int_array_1 = ArrayD::<i64>::from_shape_vec(vec![2, 3], vec![0, 1, 2, 3, 4, 5]).unwrap(); + let int_array_2 = + ArrayD::<i64>::from_shape_vec(vec![2, 2, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + let float_array_1 = + ArrayD::<f64>::from_shape_vec(vec![2, 3], vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(); + let float_array_2 = + ArrayD::<f64>::from_shape_vec(vec![2, 2, 2], vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) + .unwrap(); + let float_array_3 = ArrayD::<f64>::from_shape_vec(vec![2, 1], vec![-1.1, 1.2]).unwrap(); + + run_single_column_save( + tuple_type, + &[ + Value::Tuple( + vec![ + Value::String("lorem".into()), + Value::Tuple( + vec![ + Value::Int(10), + Value::Tuple( + vec![Value::from(int_array_1), Value::from(float_array_1)].into(), + ), + ] + .into(), + ), + ] + .into(), + ), + Value::Tuple( + vec![
Value::String("ipsum".into()), + Value::Tuple( + vec![ + Value::Int(20), + Value::Tuple( + vec![Value::from(int_array_2), Value::from(float_array_2)].into(), + ), + ] + .into(), + ), + ] + .into(), + ), + Value::Tuple( + vec![ + Value::String("dolor".into()), + Value::Tuple( + vec![ + Value::Int(30), + Value::Tuple(vec![Value::None, Value::from(float_array_3)].into()), + ] + .into(), + ), + ] + .into(), + ), + ], + ) +} + +#[test] +fn test_py_object_wrapper_makes_no_error() -> eyre::Result<()> { + run_single_column_save(Type::PyObjectWrapper, &[]) +} + +#[test] +fn test_save_any_is_unsupported() -> eyre::Result<()> { + let save_result = run_single_column_save(Type::Any, &[Value::from(json!({"A": 100}))]); + assert_matches!( + save_result.err().unwrap().downcast::(), + Ok(WriteError::UnsupportedType(_)) + ); Ok(()) }