From f72fbd462d87d9b24c9358970c2998ab943bf91e Mon Sep 17 00:00:00 2001 From: Jakub Kowalski Date: Tue, 21 Jan 2025 23:29:33 +0100 Subject: [PATCH] Postgres writer init mode (#8049) Co-authored-by: Kamil Piechowiak <32928185+KamilPiechowiak@users.noreply.github.com> GitOrigin-RevId: db0864baf4e65eb15814ffd2faa13532323d47ea --- CHANGELOG.md | 1 + .../db_connectors/test_postgres.py | 202 ++++++++++++++++-- integration_tests/db_connectors/utils.py | 15 +- python/pathway/engine.pyi | 7 + python/pathway/internals/__init__.py | 8 +- python/pathway/io/postgres/__init__.py | 36 +++- src/connectors/data_storage.rs | 124 ++++++++++- src/python_api.rs | 49 ++++- 8 files changed, 417 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 165efab2..14053c2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ### Added - `pw.io.iceberg.read` method for reading Apache Iceberg tables into Pathway. +- methods `pw.io.postgres.write` and `pw.io.postgres.write_snapshot` now accept an additional argument `init_mode`, which allows initializing the table before writing. ### Changed - **BREAKING**: `pw.io.deltalake.read` now requires explicit specification of primary key fields. diff --git a/integration_tests/db_connectors/test_postgres.py b/integration_tests/db_connectors/test_postgres.py index 82aa2cbf..563366f6 100644 --- a/integration_tests/db_connectors/test_postgres.py +++ b/integration_tests/db_connectors/test_postgres.py @@ -1,9 +1,13 @@ +import datetime import json +import pytest from utils import POSTGRES_SETTINGS import pathway as pw +from pathway.internals import api from pathway.internals.parse_graph import G +from pathway.tests.utils import run def test_psql_output_stream(tmp_path, postgres): @@ -16,30 +20,31 @@ class InputSchema(pw.Schema): input_path = tmp_path / "input.txt" output_table = postgres.create_table(InputSchema, used_for_output=True) - def run(test_items: list[dict]) -> None: + def _run(test_items: list[dict]) -> None: G.clear() with open(input_path, "w") as f: for test_item in test_items: f.write(json.dumps(test_item) + "\n") table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static") pw.io.postgres.write(table, POSTGRES_SETTINGS, output_table) - pw.run() + run() test_items = [ {"name": "Milk", "count": 500, "price": 1.5, "available": False}, {"name": "Water", "count": 600, "price": 0.5, "available": True}, ] - run(test_items) + _run(test_items) rows = postgres.get_table_contents(output_table, InputSchema.column_names()) rows.sort(key=lambda item: (item["name"], item["available"])) assert rows == test_items new_test_items = [{"name": "Milk", "count": 500, "price": 1.5, "available": True}] - run(new_test_items) + _run(new_test_items) - rows = postgres.get_table_contents(output_table, InputSchema.column_names()) - rows.sort(key=lambda item: (item["name"], item["available"])) + rows = postgres.get_table_contents( + output_table, InputSchema.column_names(), ("name", "available") + ) expected_rows = [ {"name": "Milk", "count": 500, "price": 1.5, "available": False}, {"name": "Milk", "count": 500, "price": 1.5, "available": True}, {"name": "Water", "count": 600, "price": 0.5, "available": True}, ] assert rows == expected_rows @@ -58,32 +63,205 @@ class InputSchema(pw.Schema): input_path = tmp_path / "input.txt" output_table = postgres.create_table(InputSchema, used_for_output=True) - def run(test_items: list[dict]) -> None: + def _run(test_items: list[dict]) -> None: G.clear() with open(input_path, "w") as f: for test_item in test_items:
f.write(json.dumps(test_item) + "\n") table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static") pw.io.postgres.write_snapshot(table, POSTGRES_SETTINGS, output_table, ["name"]) - pw.run() + run() test_items = [ {"name": "Milk", "count": 500, "price": 1.5, "available": False}, {"name": "Water", "count": 600, "price": 0.5, "available": True}, ] - run(test_items) + _run(test_items) rows = postgres.get_table_contents(output_table, InputSchema.column_names()) rows.sort(key=lambda item: item["name"]) assert rows == test_items new_test_items = [{"name": "Milk", "count": 500, "price": 1.5, "available": True}] - run(new_test_items) + _run(new_test_items) + + rows = postgres.get_table_contents(output_table, InputSchema.column_names(), "name") - rows = postgres.get_table_contents(output_table, InputSchema.column_names()) - rows.sort(key=lambda item: item["name"]) expected_rows = [ {"name": "Milk", "count": 500, "price": 1.5, "available": True}, {"name": "Water", "count": 600, "price": 0.5, "available": True}, ] assert rows == expected_rows + + +def write_snapshot(primary_key: list[str]): + def _write_snapshot(table: pw.Table, /, **kwargs): + pw.io.postgres.write_snapshot(table, **kwargs, primary_key=primary_key) + + return _write_snapshot + + +@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])]) +def test_init_wrong_mode(write_method): + class InputSchema(pw.Schema): + a: str + b: int + + rows = [ + ("foo", 1), + ("bar", 2), + ] + + table = pw.debug.table_from_rows( + InputSchema, + rows, + ) + + with pytest.raises(ValueError, match="Invalid init_mode: wrong_mode"): + write_method( + table, + postgres_settings=POSTGRES_SETTINGS, + init_mode="wrong_mode", + table_name="non_existent_table", + ) + + +@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])]) +def test_init_default_table_not_exists(write_method): + class InputSchema(pw.Schema): + a: str + b: int + + rows = [ + ("foo", 1), + ("bar", 2), + ] + + table = pw.debug.table_from_rows( + InputSchema, + rows, + ) + + with pytest.raises(api.EngineError): + write_method( + table, + postgres_settings=POSTGRES_SETTINGS, + table_name="non_existent_table", + ) + run() + + +@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])]) +def test_init_create_if_not_exists(write_method, postgres): + table_name = postgres.random_table_name() + + class InputSchema(pw.Schema): + a: str + b: float + c: bool + d: list[int] + e: tuple[int, int, int] + f: pw.Json + g: str + h: str + # i: np.ndarray[typing.Any, np.dtype[int]] + + rows = [ + { + "a": "foo", + "b": 1.5, + "c": False, + "d": [1, 2, 3], + "e": (1, 2, 3), + "f": {"foo": "bar", "baz": 123}, + "g": "2025-03-14T10:13:00", + "h": "2025-04-23T10:13:00+00:00", + # "i": np.array([1, 2, 3]), + } + ] + + table = pw.debug.table_from_rows( + InputSchema, + [tuple(row.values()) for row in rows], + ).with_columns( + g=pw.this.g.dt.strptime("%Y-%m-%dT%H:%M:%S", contains_timezone=False), + h=pw.this.h.dt.strptime("%Y-%m-%dT%H:%M:%S%z", contains_timezone=True), + ) + + write_method( + table, + postgres_settings=POSTGRES_SETTINGS, + table_name=table_name, + init_mode="create_if_not_exists", + ) + run() + + result = postgres.get_table_contents(table_name, InputSchema.column_names()) + + assert result == [ + { + "a": "foo", + "b": 1.5, + "c": False, + "d": [1, 2, 3], + "e": [1, 2, 3], + "f": {"foo": "bar", "baz": 123}, + "g": datetime.datetime(2025, 3, 14, 10, 13), + "h": datetime.datetime(2025, 4, 23, 10, 13, 
tzinfo=datetime.timezone.utc), + # "i": np.array([1, 2, 3], dtype=int), + }, + ] + + +@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["i"])]) +def test_init_create_if_not_exists_append(write_method, postgres): + table_name = postgres.random_table_name() + + class InputSchema(pw.Schema): + i: int + data: int + + for i in range(3): + G.clear() + table = pw.debug.table_from_rows( + InputSchema, + [(i, i)], + ) + write_method( + table, + postgres_settings=POSTGRES_SETTINGS, + table_name=table_name, + init_mode="create_if_not_exists", + ) + run() + + result = postgres.get_table_contents(table_name, InputSchema.column_names(), "i") + + assert result == [{"i": 0, "data": 0}, {"i": 1, "data": 1}, {"i": 2, "data": 2}] + + +@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["i"])]) +def test_init_replace(write_method, postgres): + table_name = postgres.random_table_name() + + class InputSchema(pw.Schema): + i: int + data: int + + for i in range(3): + G.clear() + table = pw.debug.table_from_rows( + InputSchema, + [(i, i)], + ) + write_method( + table, + postgres_settings=POSTGRES_SETTINGS, + table_name=table_name, + init_mode="replace", + ) + run() + + result = postgres.get_table_contents(table_name, InputSchema.column_names(), "i") + + assert result == [{"i": 2, "data": 2}] diff --git a/integration_tests/db_connectors/utils.py b/integration_tests/db_connectors/utils.py index 87a8ae27..1a11761e 100644 --- a/integration_tests/db_connectors/utils.py +++ b/integration_tests/db_connectors/utils.py @@ -70,7 +70,7 @@ def insert_row( self.cursor.execute(condition) def create_table(self, schema: type[pw.Schema], *, used_for_output: bool) -> str: - table_name = f'postgres_{str(uuid.uuid4()).replace("-", "")}' + table_name = self.random_table_name() primary_key_found = False fields = [] @@ -106,7 +106,10 @@ def create_table(self, schema: type[pw.Schema], *, used_for_output: bool) -> str return table_name def get_table_contents( - self, table_name: str, column_names: list[str] + self, + table_name: str, + column_names: list[str], + sort_by: str | tuple | None = None, ) -> list[dict[str, str | int | bool | float]]: select_query = f'SELECT {",".join(column_names)} FROM {table_name};' self.cursor.execute(select_query) @@ -117,8 +120,16 @@ def get_table_contents( for name, value in zip(column_names, row): row_map[name] = value result.append(row_map) + if sort_by is not None: + if isinstance(sort_by, tuple): + result.sort(key=lambda item: tuple(item[key] for key in sort_by)) + else: + result.sort(key=lambda item: item[sort_by]) return result + def random_table_name(self) -> str: + return f'postgres_{str(uuid.uuid4()).replace("-", "")}' + class MongoDBContext: client: MongoClient diff --git a/python/pathway/engine.pyi b/python/pathway/engine.pyi index 899e4a77..a93616af 100644 --- a/python/pathway/engine.pyi +++ b/python/pathway/engine.pyi @@ -757,6 +757,8 @@ class DataStorage: object_pattern: str mock_events: dict[tuple[str, int], list[SnapshotEvent]] | None table_name: str | None + sql_writer_init_mode: SqlWriterInitMode + def __init__(self, *args, **kwargs): ... class CsvParserSettings: @@ -807,6 +809,11 @@ class SessionType(Enum): NATIVE: SessionType UPSERT: SessionType +class SqlWriterInitMode(Enum): + DEFAULT: SqlWriterInitMode + CREATE_IF_NOT_EXISTS: SqlWriterInitMode + REPLACE: SqlWriterInitMode + class SnapshotEvent: @staticmethod def insert(key: Pointer, values: list[Value]) -> SnapshotEvent: ...
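The engine stub above exposes the new SqlWriterInitMode enum (DEFAULT, CREATE_IF_NOT_EXISTS, REPLACE); the user-facing connector in the next file accepts the mode as a plain string and maps it onto that enum. A minimal usage sketch of the string-level API, assuming a locally running Postgres instance (the connection settings and table name below are illustrative, not part of this patch):

import pathway as pw

class InputSchema(pw.Schema):
    name: str
    count: int

table = pw.debug.table_from_rows(InputSchema, [("Milk", 500), ("Water", 600)])

# Hypothetical connection settings; adjust to your environment.
postgres_settings = {
    "host": "localhost",
    "port": "5432",
    "dbname": "tests",
    "user": "postgres",
    "password": "postgres",
}

# With init_mode="create_if_not_exists" the writer creates the target table
# on the first run (including the bookkeeping `time` and `diff` columns it
# maintains) and appends to the existing table on subsequent runs.
pw.io.postgres.write(
    table,
    postgres_settings,
    table_name="items",
    init_mode="create_if_not_exists",
)
pw.run()

Passing any other string raises ValueError("Invalid init_mode: ..."), as exercised by test_init_wrong_mode above.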
diff --git a/python/pathway/internals/__init__.py b/python/pathway/internals/__init__.py index 5a78c9a9..1cbbd673 100644 --- a/python/pathway/internals/__init__.py +++ b/python/pathway/internals/__init__.py @@ -3,7 +3,12 @@ from __future__ import annotations from pathway.internals import reducers, udfs, universes -from pathway.internals.api import Pointer, PyObjectWrapper, wrap_py_object +from pathway.internals.api import ( + Pointer, + PyObjectWrapper, + SqlWriterInitMode, + wrap_py_object, +) from pathway.internals.common import ( apply, apply_async, @@ -152,4 +157,5 @@ "local_error_log", "ColumnDefinition", "load_yaml", + "SqlWriterInitMode", ] diff --git a/python/pathway/io/postgres/__init__.py b/python/pathway/io/postgres/__init__.py index 2d5ba5be..201fb3fe 100644 --- a/python/pathway/io/postgres/__init__.py +++ b/python/pathway/io/postgres/__init__.py @@ -13,6 +13,18 @@ def _connection_string_from_settings(settings: dict): return " ".join(k + "=" + v for (k, v) in settings.items()) +def _init_mode_from_str(init_mode: str) -> api.SqlWriterInitMode: + match init_mode: + case "default": + return api.SqlWriterInitMode.DEFAULT + case "create_if_not_exists": + return api.SqlWriterInitMode.CREATE_IF_NOT_EXISTS + case "replace": + return api.SqlWriterInitMode.REPLACE + case _: + raise ValueError(f"Invalid init_mode: {init_mode}") + + @check_arg_types @trace_user_frame def write( @@ -20,6 +32,7 @@ def write( postgres_settings: dict, table_name: str, max_batch_size: int | None = None, + init_mode: str = "default", ) -> None: """Writes ``table``'s stream of updates to a postgres table. @@ -27,10 +40,15 @@ and ``diff`` columns of the integer type. Args: + table: Table to be written. postgres_settings: Components for the connection string for Postgres. table_name: Name of the target table. - max_batch_size: Maximum number of entries allowed to be committed within a \ -single transaction. + max_batch_size: Maximum number of entries allowed to be committed within a + single transaction. + init_mode: "default": the default initialization mode, where the table must already exist; + "create_if_not_exists": initializes the SQL writer by creating the necessary table + if it does not already exist; + "replace": initializes the SQL writer by replacing any existing table. Returns: None @@ -94,6 +112,8 @@ def write( storage_type="postgres", connection_string=_connection_string_from_settings(postgres_settings), max_batch_size=max_batch_size, + table_name=table_name, + sql_writer_init_mode=_init_mode_from_str(init_mode), ) data_format = api.DataFormat( format_type="sql", @@ -116,6 +136,7 @@ def write_snapshot( table_name: str, primary_key: list[str], max_batch_size: int | None = None, + init_mode: str = "default", ) -> None: """Maintains a snapshot of a table within a Postgres table. @@ -126,8 +147,13 @@ postgres_settings: Components of the connection string for Postgres. table_name: Name of the target table. primary_key: Names of the fields which serve as a primary key in the Postgres table. - max_batch_size: Maximum number of entries allowed to be committed within a \ -single transaction. + max_batch_size: Maximum number of entries allowed to be committed within a + single transaction. + init_mode: "default": the default initialization mode, where the table must already exist; + "create_if_not_exists": initializes the SQL writer by creating the necessary table + if it does not already exist; + "replace": initializes the SQL writer by replacing any existing table.
+ Returns: None @@ -177,6 +203,8 @@ def write_snapshot( connection_string=_connection_string_from_settings(postgres_settings), max_batch_size=max_batch_size, snapshot_maintenance_on_output=True, + table_name=table_name, + sql_writer_init_mode=_init_mode_from_str(init_mode), ) data_format = api.DataFormat( format_type="sql_snapshot", diff --git a/src/connectors/data_storage.rs b/src/connectors/data_storage.rs index d92e7bc7..a8b4727c 100644 --- a/src/connectors/data_storage.rs +++ b/src/connectors/data_storage.rs @@ -1,5 +1,6 @@ // Copyright © 2024 Pathway +use postgres::Transaction as PsqlTransaction; use pyo3::exceptions::PyValueError; use pyo3::types::PyBytes; use rdkafka::util::Timeout; @@ -1070,14 +1071,133 @@ impl PsqlWriter { client: PsqlClient, max_batch_size: Option<usize>, snapshot_mode: bool, - ) -> PsqlWriter { - PsqlWriter { + table_name: &str, + schema: &HashMap<String, Type>, + key_field_names: &Option<Vec<String>>, + mode: SqlWriterInitMode, + ) -> Result<PsqlWriter, WriteError> { + let mut writer = PsqlWriter { client, max_batch_size, buffer: Vec::new(), snapshot_mode, + }; + writer.initialize(mode, table_name, schema, key_field_names)?; + Ok(writer) + } + + pub fn initialize( + &mut self, + mode: SqlWriterInitMode, + table_name: &str, + schema: &HashMap<String, Type>, + key_field_names: &Option<Vec<String>>, + ) -> Result<(), WriteError> { + match mode { + SqlWriterInitMode::Default => return Ok(()), + SqlWriterInitMode::Replace | SqlWriterInitMode::CreateIfNotExists => { + let mut transaction = self.client.transaction()?; + + if mode == SqlWriterInitMode::Replace { + Self::drop_table_if_exists(&mut transaction, table_name)?; + } + Self::create_table_if_not_exists( + &mut transaction, + table_name, + schema, + key_field_names, + )?; + + transaction.commit()?; + } } + + Ok(()) + } + + fn create_table_if_not_exists( + transaction: &mut PsqlTransaction, + table_name: &str, + schema: &HashMap<String, Type>, + key_field_names: &Option<Vec<String>>, + ) -> Result<(), WriteError> { + let columns: Vec<String> = schema + .iter() + .map(|(name, dtype)| { + Self::postgres_data_type(dtype).map(|dtype_str| format!("{name} {dtype_str}")) + }) + .collect::<Result<Vec<_>, _>>()?; + + let primary_key = key_field_names + .as_ref() + .filter(|keys| !keys.is_empty()) + .map_or(String::new(), |keys| { + format!(", PRIMARY KEY ({})", keys.join(", ")) + }); + + transaction.execute( + &format!( + "CREATE TABLE IF NOT EXISTS {} ({}, time BIGINT, diff BIGINT{})", + table_name, + columns.join(", "), + primary_key + ), + &[], + )?; + + Ok(()) } + + fn drop_table_if_exists( + transaction: &mut PsqlTransaction, + table_name: &str, + ) -> Result<(), WriteError> { + let query = format!("DROP TABLE IF EXISTS {table_name}"); + transaction.execute(&query, &[])?; + Ok(()) + } + + fn postgres_data_type(type_: &Type) -> Result<String, WriteError> { + Ok(match type_ { + Type::Bool => "BOOLEAN".to_string(), + Type::Int | Type::Duration => "BIGINT".to_string(), + Type::Float => "DOUBLE PRECISION".to_string(), + Type::Pointer | Type::String => "TEXT".to_string(), + Type::Bytes => "BYTEA".to_string(), + Type::Json => "JSONB".to_string(), + Type::DateTimeNaive => "TIMESTAMP".to_string(), + Type::DateTimeUtc => "TIMESTAMPTZ".to_string(), + Type::Optional(wrapped) | Type::List(wrapped) => { + if let Type::Any = **wrapped { + return Err(WriteError::UnsupportedType(type_.clone())); + } + + let wrapped = Self::postgres_data_type(wrapped)?; + if let Type::Optional(_) = type_ { + return Ok(wrapped); + } + format!("{wrapped}[]") + } + Type::Tuple(fields) => { + let mut iter = fields.iter(); + if !fields.is_empty() && iter.all(|field| field == &fields[0]) { + let first =
Self::postgres_data_type(&fields[0])?; + return Ok(format!("{first}[]")); + } + return Err(WriteError::UnsupportedType(type_.clone())); + } + Type::Any | Type::Array(_, _) | Type::PyObjectWrapper => { + return Err(WriteError::UnsupportedType(type_.clone())) + } + }) + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum SqlWriterInitMode { + Default, + CreateIfNotExists, + Replace, } mod to_sql { diff --git a/src/python_api.rs b/src/python_api.rs index 10e3eed7..85ea2690 100644 --- a/src/python_api.rs +++ b/src/python_api.rs @@ -84,7 +84,7 @@ use crate::connectors::data_storage::{ ConnectorMode, DeltaTableReader, ElasticSearchWriter, FileWriter, IcebergReader, KafkaReader, KafkaWriter, LakeWriter, MongoWriter, NatsReader, NatsWriter, NullWriter, ObjectDownloader, PsqlWriter, PythonConnectorEventType, PythonReaderBuilder, ReadError, ReadMethod, - ReaderBuilder, SqliteReader, Writer, + ReaderBuilder, SqlWriterInitMode, SqliteReader, Writer, }; use crate::connectors::scanner::S3Scanner; use crate::connectors::{PersistenceMode, SessionType, SnapshotAccess}; @@ -657,6 +657,18 @@ impl IntoPy<PyObject> for MonitoringLevel { } } +impl<'py> FromPyObject<'py> for SqlWriterInitMode { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> { + Ok(ob.extract::<PyRef<PySqlWriterInitMode>>()?.0) + } +} + +impl IntoPy<PyObject> for SqlWriterInitMode { + fn into_py(self, py: Python<'_>) -> PyObject { + PySqlWriterInitMode(self).into_py(py) + } +} + impl From<EngineError> for PyErr { fn from(mut error: EngineError) -> Self { match error.downcast::<PyErr>() { @@ -1761,6 +1773,19 @@ impl PyMonitoringLevel { pub const ALL: MonitoringLevel = MonitoringLevel::All; } +#[pyclass(module = "pathway.engine", frozen, name = "SqlWriterInitMode")] +pub struct PySqlWriterInitMode(SqlWriterInitMode); + +#[pymethods] +impl PySqlWriterInitMode { + #[classattr] + pub const DEFAULT: SqlWriterInitMode = SqlWriterInitMode::Default; + #[classattr] + pub const CREATE_IF_NOT_EXISTS: SqlWriterInitMode = SqlWriterInitMode::CreateIfNotExists; + #[classattr] + pub const REPLACE: SqlWriterInitMode = SqlWriterInitMode::Replace; +} + #[pyclass(module = "pathway.engine", frozen)] pub struct Universe { scope: Py<Scope>, @@ -3709,6 +3734,7 @@ pub struct DataStorage { database: Option<String>, start_from_timestamp_ms: Option<i64>, namespace: Option<Vec<String>>, + sql_writer_init_mode: SqlWriterInitMode, } #[pyclass(module = "pathway.engine", frozen, name = "PersistenceMode")] @@ -4024,6 +4050,7 @@ impl DataStorage { database = None, start_from_timestamp_ms = None, namespace = None, + sql_writer_init_mode = SqlWriterInitMode::Default, ))] #[allow(clippy::too_many_arguments)] fn new( @@ -4052,6 +4079,7 @@ impl DataStorage { database: Option<String>, start_from_timestamp_ms: Option<i64>, namespace: Option<Vec<String>>, + sql_writer_init_mode: SqlWriterInitMode, ) -> Self { DataStorage { storage_type, @@ -4079,6 +4107,7 @@ impl DataStorage { database, start_from_timestamp_ms, namespace, + sql_writer_init_mode, } } } @@ -4804,14 +4833,25 @@ impl DataStorage { Ok(Box::new(writer)) } - fn construct_postgres_writer(&self) -> PyResult<Box<dyn Writer>> { + fn construct_postgres_writer( + &self, + py: pyo3::Python, + data_format: &DataFormat, + ) -> PyResult<Box<dyn Writer>> { let connection_string = self.connection_string()?; let storage = match Client::connect(connection_string, NoTls) { Ok(client) => PsqlWriter::new( client, self.max_batch_size, self.snapshot_maintenance_on_output, - ), + self.table_name()?, + &data_format.value_fields_type_map(py), + &data_format.key_field_names, + self.sql_writer_init_mode, + ) + .map_err(|e| { + PyIOError::new_err(format!("Unable to initialize PostgreSQL
table: {e}")) + })?, Err(e) => { return Err(PyIOError::new_err(format!( "Failed to establish PostgreSQL connection: {e:?}" @@ -4922,7 +4962,7 @@ impl DataStorage { match self.storage_type.as_ref() { "fs" => self.construct_fs_writer(), "kafka" => self.construct_kafka_writer(), - "postgres" => self.construct_postgres_writer(), + "postgres" => self.construct_postgres_writer(py, data_format), "elasticsearch" => self.construct_elasticsearch_writer(py), "deltalake" => self.construct_deltalake_writer(py, data_format), "mongodb" => self.construct_mongodb_writer(), @@ -5498,6 +5538,7 @@ fn engine(_py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;