Skip to content

Commit

Permalink
Postgres writer init mode (#8049)
Browse files Browse the repository at this point in the history
Co-authored-by: Kamil Piechowiak <[email protected]>
GitOrigin-RevId: db0864baf4e65eb15814ffd2faa13532323d47ea
  • Loading branch information
2 people authored and Manul from Pathway committed Jan 21, 2025
1 parent 8147bc7 commit f72fbd4
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

### Added
- `pw.io.iceberg.read` method for reading Apache Iceberg tables into Pathway.
- Methods `pw.io.postgres.write` and `pw.io.postgres.write_snapshot` now accept an additional argument `init_mode`, which controls how the output table is initialized before writing.

### Changed
- **BREAKING**: `pw.io.deltalake.read` now requires explicit specification of primary key fields.
Expand Down
202 changes: 190 additions & 12 deletions integration_tests/db_connectors/test_postgres.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import datetime
import json

import pytest
from utils import POSTGRES_SETTINGS

import pathway as pw
from pathway.internals import api
from pathway.internals.parse_graph import G
from pathway.tests.utils import run


def test_psql_output_stream(tmp_path, postgres):
Expand All @@ -16,30 +20,31 @@ class InputSchema(pw.Schema):
input_path = tmp_path / "input.txt"
output_table = postgres.create_table(InputSchema, used_for_output=True)

def run(test_items: list[dict]) -> None:
def _run(test_items: list[dict]) -> None:
G.clear()
with open(input_path, "w") as f:
for test_item in test_items:
f.write(json.dumps(test_item) + "\n")
table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
pw.io.postgres.write(table, POSTGRES_SETTINGS, output_table)
pw.run()
run()

test_items = [
{"name": "Milk", "count": 500, "price": 1.5, "available": False},
{"name": "Water", "count": 600, "price": 0.5, "available": True},
]
run(test_items)
_run(test_items)

rows = postgres.get_table_contents(output_table, InputSchema.column_names())
rows.sort(key=lambda item: (item["name"], item["available"]))
assert rows == test_items

new_test_items = [{"name": "Milk", "count": 500, "price": 1.5, "available": True}]
run(new_test_items)
_run(new_test_items)

rows = postgres.get_table_contents(output_table, InputSchema.column_names())
rows.sort(key=lambda item: (item["name"], item["available"]))
rows = postgres.get_table_contents(
output_table, InputSchema.column_names(), ("name", "available")
)
expected_rows = [
{"name": "Milk", "count": 500, "price": 1.5, "available": False},
{"name": "Milk", "count": 500, "price": 1.5, "available": True},
Expand All @@ -58,32 +63,205 @@ class InputSchema(pw.Schema):
input_path = tmp_path / "input.txt"
output_table = postgres.create_table(InputSchema, used_for_output=True)

def run(test_items: list[dict]) -> None:
def _run(test_items: list[dict]) -> None:
G.clear()
with open(input_path, "w") as f:
for test_item in test_items:
f.write(json.dumps(test_item) + "\n")
table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
pw.io.postgres.write_snapshot(table, POSTGRES_SETTINGS, output_table, ["name"])
pw.run()
run()

test_items = [
{"name": "Milk", "count": 500, "price": 1.5, "available": False},
{"name": "Water", "count": 600, "price": 0.5, "available": True},
]
run(test_items)
_run(test_items)

rows = postgres.get_table_contents(output_table, InputSchema.column_names())
rows.sort(key=lambda item: item["name"])
assert rows == test_items

new_test_items = [{"name": "Milk", "count": 500, "price": 1.5, "available": True}]
run(new_test_items)
_run(new_test_items)

rows = postgres.get_table_contents(output_table, InputSchema.column_names(), "name")

rows = postgres.get_table_contents(output_table, InputSchema.column_names())
rows.sort(key=lambda item: item["name"])
expected_rows = [
{"name": "Milk", "count": 500, "price": 1.5, "available": True},
{"name": "Water", "count": 600, "price": 0.5, "available": True},
]
assert rows == expected_rows


def write_snapshot(primary_key: list[str]):
    """Return a writer callable that pins ``primary_key`` for ``pw.io.postgres.write_snapshot``.

    The returned callable mirrors the ``pw.io.postgres.write`` signature
    (table first, everything else by keyword), so both can be used
    interchangeably as a ``write_method`` test parameter.
    """

    def writer(table: pw.Table, /, **kwargs):
        pw.io.postgres.write_snapshot(table, primary_key=primary_key, **kwargs)

    return writer


@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])])
def test_init_wrong_mode(write_method):
    """An unrecognized ``init_mode`` value is rejected when the connector is built."""

    class InputSchema(pw.Schema):
        a: str
        b: int

    input_table = pw.debug.table_from_rows(InputSchema, [("foo", 1), ("bar", 2)])

    # The error is raised eagerly, before any pipeline run is attempted.
    with pytest.raises(ValueError, match="Invalid init_mode: wrong_mode"):
        write_method(
            input_table,
            postgres_settings=POSTGRES_SETTINGS,
            init_mode="wrong_mode",
            table_name="non_existent_table",
        )


@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])])
def test_init_default_table_not_exists(write_method):
    """With the default ``init_mode``, writing to a missing table fails at run time."""

    class InputSchema(pw.Schema):
        a: str
        b: int

    input_table = pw.debug.table_from_rows(InputSchema, [("foo", 1), ("bar", 2)])

    # The default mode does not create the table, so the engine errors on run().
    with pytest.raises(api.EngineError):
        write_method(
            input_table,
            postgres_settings=POSTGRES_SETTINGS,
            table_name="non_existent_table",
        )
        run()


@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["a"])])
def test_init_create_if_not_exists(write_method, postgres):
    """``init_mode="create_if_not_exists"`` creates the table and maps Pathway types to SQL."""
    table_name = postgres.random_table_name()

    class InputSchema(pw.Schema):
        a: str
        b: float
        c: bool
        d: list[int]
        e: tuple[int, int, int]
        f: pw.Json
        g: str
        h: str
        # i: np.ndarray[typing.Any, np.dtype[int]]

    source_row = {
        "a": "foo",
        "b": 1.5,
        "c": False,
        "d": [1, 2, 3],
        "e": (1, 2, 3),
        "f": {"foo": "bar", "baz": 123},
        "g": "2025-03-14T10:13:00",
        "h": "2025-04-23T10:13:00+00:00",
        # "i": np.array([1, 2, 3]),
    }

    # Parse the two string columns into naive / timezone-aware datetimes.
    table = pw.debug.table_from_rows(
        InputSchema,
        [tuple(source_row.values())],
    ).with_columns(
        g=pw.this.g.dt.strptime("%Y-%m-%dT%H:%M:%S", contains_timezone=False),
        h=pw.this.h.dt.strptime("%Y-%m-%dT%H:%M:%S%z", contains_timezone=True),
    )

    write_method(
        table,
        postgres_settings=POSTGRES_SETTINGS,
        table_name=table_name,
        init_mode="create_if_not_exists",
    )
    run()

    result = postgres.get_table_contents(table_name, InputSchema.column_names())

    # Note: the tuple column comes back as a list from Postgres.
    assert result == [
        {
            "a": "foo",
            "b": 1.5,
            "c": False,
            "d": [1, 2, 3],
            "e": [1, 2, 3],
            "f": {"foo": "bar", "baz": 123},
            "g": datetime.datetime(2025, 3, 14, 10, 13),
            "h": datetime.datetime(2025, 4, 23, 10, 13, tzinfo=datetime.timezone.utc),
            # "i": np.array([4, 5, 6], dtype=int),
        },
    ]


@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["i"])])
def test_init_create_if_not_exists_append(write_method, postgres):
    """Repeated runs with ``create_if_not_exists`` append to the already-created table."""
    table_name = postgres.random_table_name()

    class InputSchema(pw.Schema):
        i: int
        data: int

    # Three independent pipeline runs against the same target table.
    for value in range(3):
        G.clear()
        single_row = pw.debug.table_from_rows(InputSchema, [(value, value)])
        write_method(
            single_row,
            postgres_settings=POSTGRES_SETTINGS,
            table_name=table_name,
            init_mode="create_if_not_exists",
        )
        run()

    result = postgres.get_table_contents(table_name, InputSchema.column_names(), "i")

    assert result == [{"i": 0, "data": 0}, {"i": 1, "data": 1}, {"i": 2, "data": 2}]


@pytest.mark.parametrize("write_method", [pw.io.postgres.write, write_snapshot(["i"])])
def test_init_replace(write_method, postgres):
    """``init_mode="replace"`` recreates the table on every run, keeping only the last data."""
    table_name = postgres.random_table_name()

    class InputSchema(pw.Schema):
        i: int
        data: int

    # Each run replaces the table, so only the final run's row should survive.
    for value in range(3):
        G.clear()
        single_row = pw.debug.table_from_rows(InputSchema, [(value, value)])
        write_method(
            single_row,
            postgres_settings=POSTGRES_SETTINGS,
            table_name=table_name,
            init_mode="replace",
        )
        run()

    result = postgres.get_table_contents(table_name, InputSchema.column_names(), "i")

    assert result == [{"i": 2, "data": 2}]
15 changes: 13 additions & 2 deletions integration_tests/db_connectors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def insert_row(
self.cursor.execute(condition)

def create_table(self, schema: type[pw.Schema], *, used_for_output: bool) -> str:
table_name = f'postgres_{str(uuid.uuid4()).replace("-", "")}'
table_name = self.random_table_name()

primary_key_found = False
fields = []
Expand Down Expand Up @@ -106,7 +106,10 @@ def create_table(self, schema: type[pw.Schema], *, used_for_output: bool) -> str
return table_name

def get_table_contents(
self, table_name: str, column_names: list[str]
self,
table_name: str,
column_names: list[str],
sort_by: str | tuple | None = None,
) -> list[dict[str, str | int | bool | float]]:
select_query = f'SELECT {",".join(column_names)} FROM {table_name};'
self.cursor.execute(select_query)
Expand All @@ -117,8 +120,16 @@ def get_table_contents(
for name, value in zip(column_names, row):
row_map[name] = value
result.append(row_map)
if sort_by is not None:
if isinstance(sort_by, tuple):
result.sort(key=lambda item: tuple(item[key] for key in sort_by))
else:
result.sort(key=lambda item: item[sort_by])
return result

def random_table_name(self) -> str:
    """Return a fresh, collision-safe Postgres table name for a test run."""
    # uuid4().hex is the 32-char hex digest, i.e. str(uuid4()) with dashes removed.
    return f"postgres_{uuid.uuid4().hex}"


class MongoDBContext:
client: MongoClient
Expand Down
7 changes: 7 additions & 0 deletions python/pathway/engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,8 @@ class DataStorage:
object_pattern: str
mock_events: dict[tuple[str, int], list[SnapshotEvent]] | None
table_name: str | None
sql_writer_init_mode: SqlWriterInitMode

def __init__(self, *args, **kwargs): ...

class CsvParserSettings:
Expand Down Expand Up @@ -807,6 +809,11 @@ class SessionType(Enum):
NATIVE: SessionType
UPSERT: SessionType

class SqlWriterInitMode(Enum):
DEFAULT: SqlWriterInitMode
CREATE_IF_NOT_EXISTS: SqlWriterInitMode
REPLACE: SqlWriterInitMode

class SnapshotEvent:
@staticmethod
def insert(key: Pointer, values: list[Value]) -> SnapshotEvent: ...
Expand Down
8 changes: 7 additions & 1 deletion python/pathway/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from __future__ import annotations

from pathway.internals import reducers, udfs, universes
from pathway.internals.api import Pointer, PyObjectWrapper, wrap_py_object
from pathway.internals.api import (
Pointer,
PyObjectWrapper,
SqlWriterInitMode,
wrap_py_object,
)
from pathway.internals.common import (
apply,
apply_async,
Expand Down Expand Up @@ -152,4 +157,5 @@
"local_error_log",
"ColumnDefinition",
"load_yaml",
"SqlWriterInitMode",
]
Loading

0 comments on commit f72fbd4

Please sign in to comment.