[data][Datasink] support passing write results to on_write_complete #49251

Merged
merged 23 commits on Dec 28, 2024
6 changes: 5 additions & 1 deletion doc/source/conf.py
@@ -118,7 +118,11 @@
# arising from type annotations. See https://github.com/ray-project/ray/pull/46103
# for additional context.
nitpicky = True
nitpick_ignore_regex = [("py:class", ".*")]
nitpick_ignore_regex = [
("py:class", ".*"),
# Workaround for https://github.com/sphinx-doc/sphinx/issues/10974
("py:obj", "ray\.data\.datasource\.datasink\.WriteReturnType"),
]

# Cache notebook outputs in _build/.jupyter_cache
# To prevent notebook execution, set this to "off". To force re-execution, set this to
2 changes: 2 additions & 0 deletions doc/source/data/api/input_output.rst
@@ -324,6 +324,8 @@ Datasink API
datasource.RowBasedFileDatasink
datasource.BlockBasedFileDatasink
datasource.FileBasedDatasource
datasource.WriteResult
datasource.WriteReturnType

Partitioning API
----------------
python/ray/data/_internal/datasource/bigquery_datasink.py
@@ -20,7 +20,7 @@
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


class BigQueryDatasink(Datasink):
class BigQueryDatasink(Datasink[None]):
def __init__(
self,
project_id: str,
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/mongo_datasink.py
@@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)


class MongoDatasink(Datasink):
class MongoDatasink(Datasink[None]):
def __init__(self, uri: str, database: str, collection: str) -> None:
_check_import(self, module="pymongo", package="pymongo")
_check_import(self, module="pymongoarrow", package="pymongoarrow")
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/sql_datasink.py
@@ -6,7 +6,7 @@
from ray.data.datasource.datasink import Datasink


class SQLDatasink(Datasink):
class SQLDatasink(Datasink[None]):

_MAX_ROWS_PER_WRITE = 128

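Side note (not part of the diff): these built-in sinks are parametrized with `None` because their `write` returns nothing. A minimal sketch of what that parametrization means for a hypothetical sink — the class name and body are made up for illustration:

```python
from typing import Iterable

from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasink, WriteResult


# Hypothetical sink used only to illustrate the Datasink[None] parametrization.
class PrintRowCountDatasink(Datasink[None]):
    def write(self, blocks: Iterable[Block], ctx) -> None:
        # write() returns nothing, so each task contributes None to write_returns.
        for block in blocks:
            print(BlockAccessor.for_block(block).num_rows())

    def on_write_complete(self, write_result: WriteResult[None]):
        # write_result.write_returns is a list of None values, one per write task.
        print(f"wrote {write_result.num_rows} rows in total")
```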
34 changes: 29 additions & 5 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -1,6 +1,8 @@
import itertools
from typing import Callable, Iterator, List, Union

from pandas import DataFrame

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.interfaces.task_context import TaskContext
@@ -16,19 +18,36 @@
from ray.data.datasource.datasource import Datasource


def gen_datasink_write_result(
write_result_blocks: List[Block],
) -> WriteResult:
assert all(
isinstance(block, DataFrame) and len(block) == 1
for block in write_result_blocks
)
total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)

write_returns = [result["write_return"][0] for result in write_result_blocks]
return WriteResult(total_num_rows, total_size_bytes, write_returns)


def generate_write_fn(
datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Writes the blocks to the given datasink or legacy datasource.

Outputs the original blocks to be written."""
# Create a copy of the iterator, so we can return the original blocks.
it1, it2 = itertools.tee(blocks, 2)
if isinstance(datasink_or_legacy_datasource, Datasink):
datasink_or_legacy_datasource.write(it1, ctx)
ctx.kwargs["_datasink_write_return"] = datasink_or_legacy_datasource.write(
it1, ctx
)
else:
datasink_or_legacy_datasource.write(it1, ctx, **write_args)

return it2

return fn
Expand All @@ -41,7 +60,7 @@ def generate_collect_write_stats_fn() -> (
# one Block which contains stats/metrics about the write.
# Otherwise, an error will be raised. The Datasource can handle
# execution outcomes with `on_write_complete()` and `on_write_failed()`.
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Handles stats collection for block writes."""
block_accessors = [BlockAccessor.for_block(block) for block in blocks]
total_num_rows = sum(ba.num_rows() for ba in block_accessors)
@@ -51,8 +70,13 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
# type.
import pandas as pd

write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes)
block = pd.DataFrame({"write_result": [write_result]})
block = pd.DataFrame(
{
"num_rows": [total_num_rows],
"size_bytes": [total_size_bytes],
"write_return": [ctx.kwargs.get("_datasink_write_return", None)],
}
)
return iter([block])

return fn
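For context (not part of the diff), here is a minimal sketch of the single-row stats blocks each write task now emits and the reduction `gen_datasink_write_result` applies to them; the `write_return` values are made-up file names:

```python
import pandas as pd

# Hypothetical output of two write tasks: one single-row block per task.
task_blocks = [
    pd.DataFrame({"num_rows": [100], "size_bytes": [2048], "write_return": ["part-0.parquet"]}),
    pd.DataFrame({"num_rows": [50], "size_bytes": [1024], "write_return": ["part-1.parquet"]}),
]

# The same reduction gen_datasink_write_result performs.
total_num_rows = sum(b["num_rows"].sum() for b in task_blocks)      # 150
total_size_bytes = sum(b["size_bytes"].sum() for b in task_blocks)  # 3072
write_returns = [b["write_return"][0] for b in task_blocks]         # ['part-0.parquet', 'part-1.parquet']
```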
14 changes: 9 additions & 5 deletions python/ray/data/dataset.py
@@ -50,6 +50,7 @@
from ray.data._internal.execution.interfaces.ref_bundle import (
_ref_bundles_iterator_to_block_refs_list,
)
from ray.data._internal.execution.util import memory_string
from ray.data._internal.iterator.iterator_impl import DataIteratorImpl
from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
from ray.data._internal.logical.operators.all_to_all_operator import (
@@ -77,6 +78,7 @@
from ray.data._internal.pandas_block import PandasBlockBuilder, PandasBlockSchema
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
from ray.data._internal.planner.plan_write_op import gen_datasink_write_result
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.split import _get_num_rows, _split_at_indices
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager
@@ -3917,19 +3919,21 @@ def write_datasink(
logical_plan = LogicalPlan(write_op, self.context)

try:
import pandas as pd

datasink.on_write_start()

self._write_ds = Dataset(plan, logical_plan).materialize()
# TODO: Get and handle the blocks with an iterator instead of getting
# everything in a blocking way, so some blocks can be freed earlier.
raw_write_results = ray.get(self._write_ds._plan.execute().block_refs)
assert all(
isinstance(block, pd.DataFrame) and len(block) == 1
for block in raw_write_results
write_result = gen_datasink_write_result(raw_write_results)
logger.info(
"Data sink %s finished. %d rows and %s data written.",
datasink.get_name(),
write_result.num_rows,
memory_string(write_result.size_bytes),
)
datasink.on_write_complete(raw_write_results)
datasink.on_write_complete(write_result)

except Exception as e:
datasink.on_write_failed(e)
9 changes: 8 additions & 1 deletion python/ray/data/datasource/__init__.py
@@ -1,5 +1,10 @@
from ray.data._internal.datasource.sql_datasource import Connection
from ray.data.datasource.datasink import Datasink, DummyOutputDatasink
from ray.data.datasource.datasink import (
Datasink,
DummyOutputDatasink,
WriteResult,
WriteReturnType,
)
from ray.data.datasource.datasource import (
Datasource,
RandomIntRowDatasource,
@@ -57,4 +62,6 @@
"Reader",
"RowBasedFileDatasink",
"_S3FileSystemWrapper",
"WriteResult",
"WriteReturnType",
]
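With these exports in place, user code should be able to import the new names directly from `ray.data.datasource` (a usage sketch, not part of the diff):

```python
from ray.data.datasource import Datasink, WriteResult, WriteReturnType
```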
86 changes: 26 additions & 60 deletions python/ray/data/datasource/datasink.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass, fields
from typing import Iterable, List, Optional
from dataclasses import dataclass
from typing import Generic, Iterable, List, Optional, TypeVar

import ray
from ray.data._internal.execution.interfaces import TaskContext
@@ -10,45 +10,25 @@
logger = logging.getLogger(__name__)


@dataclass
@DeveloperAPI
class WriteResult:
"""Result of a write operation, containing stats/metrics
on the written data.

Attributes:
total_num_rows: The total number of rows written.
total_size_bytes: The total size of the written data in bytes.
"""

num_rows: int = 0
size_bytes: int = 0
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""

@staticmethod
def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult":
"""Aggregate a list of write results.

Args:
write_results: A list of write results.

Returns:
A single write result that aggregates the input results.
"""
total_num_rows = 0
total_size_bytes = 0

for write_result in write_results:
total_num_rows += write_result.num_rows
total_size_bytes += write_result.size_bytes
@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
"""Aggregated result of the Datasink write operations."""

return WriteResult(
num_rows=total_num_rows,
size_bytes=total_size_bytes,
)
# Total number of written rows.
num_rows: int
# Total size in bytes of written data.
size_bytes: int
# All returned values of `Datasink.write`.
write_returns: List[WriteReturnType]


@DeveloperAPI
class Datasink:
class Datasink(Generic[WriteReturnType]):
"""Interface for defining write-related logic.

If you want to write data to something that isn't built-in, subclass this class
@@ -67,44 +47,32 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> WriteReturnType:
"""Write blocks. This is used by a single write task.

Args:
blocks: Generator of data blocks.
ctx: ``TaskContext`` for the write task.

Returns:
Result of this write task. When the entire write operator finishes,
all returned values will be passed as `WriteResult.write_returns`
to `Datasink.on_write_complete`.
"""
raise NotImplementedError

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
"""Callback for when a write job completes.

This can be used to "commit" a write output. This method must
succeed prior to ``write_datasink()`` returning to the user. If this
method fails, then ``on_write_failed()`` is called.

Args:
write_result_blocks: The blocks resulting from executing
write_result: Aggregated result of
the Write operator, containing write results and stats.
Returns:
A ``WriteResult`` object containing the aggregated stats of all
the input write results.
"""
write_results = [
result["write_result"].iloc[0] for result in write_result_blocks
]
aggregated_write_results = WriteResult.aggregate_write_results(write_results)

aggregated_results_str = ""
for k in fields(aggregated_write_results.__class__):
v = getattr(aggregated_write_results, k.name)
aggregated_results_str += f"\t- {k.name}: {v}\n"

logger.info(
f"Write operation succeeded. Aggregated write results:\n"
f"{aggregated_results_str}"
)
return aggregated_write_results
pass

def on_write_failed(self, error: Exception) -> None:
"""Callback for when a write job fails.
@@ -144,7 +112,7 @@ def num_rows_per_write(self) -> Optional[int]:


@DeveloperAPI
class DummyOutputDatasink(Datasink):
class DummyOutputDatasink(Datasink[None]):
"""An example implementation of a writable datasource for testing.
Examples:
>>> import ray
@@ -189,10 +157,8 @@ def write(
tasks.append(self.data_sink.write.remote(b))
ray.get(tasks)

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
def on_write_complete(self, write_result: WriteResult[None]):
self.num_ok += 1
aggregated_results = super().on_write_complete(write_result_blocks)
return aggregated_results

def on_write_failed(self, error: Exception) -> None:
self.num_failed += 1
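Putting the new interface together, a rough sketch (not part of this PR) of a datasink whose write tasks each return the path they wrote, which `on_write_complete` then receives via `WriteResult.write_returns`. The class name, path handling, and Parquet choice are assumptions made for illustration:

```python
import os
import uuid
from typing import Iterable

import pyarrow as pa
import pyarrow.parquet as pq

from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasink, WriteResult


class ParquetPathDatasink(Datasink[str]):
    """Hypothetical sink: each write task writes one Parquet file and returns its path."""

    def __init__(self, path: str):
        self._path = path

    def on_write_start(self) -> None:
        os.makedirs(self._path, exist_ok=True)

    def write(self, blocks: Iterable[Block], ctx) -> str:
        tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
        out_path = os.path.join(self._path, f"{uuid.uuid4().hex}.parquet")
        pq.write_table(pa.concat_tables(tables), out_path)
        # This value becomes one entry in WriteResult.write_returns.
        return out_path

    def on_write_complete(self, write_result: WriteResult[str]):
        # All per-task return values arrive here once the write operator finishes.
        print(f"wrote {write_result.num_rows} rows to {write_result.write_returns}")


# Usage sketch:
# import ray
# ray.data.range(100).write_datasink(ParquetPathDatasink("/tmp/out"))
```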
17 changes: 8 additions & 9 deletions python/ray/data/datasource/file_datasink.py
@@ -1,6 +1,6 @@
import logging
import posixpath
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
from urllib.parse import urlparse

from ray._private.utils import _add_creatable_buckets_param_if_s3_uri
@@ -27,7 +27,7 @@
WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32


class _FileDatasink(Datasink):
class _FileDatasink(Datasink[None]):
def __init__(
self,
path: str,
@@ -135,13 +135,10 @@ def write(
def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
raise NotImplementedError

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
aggregated_results = super().on_write_complete(write_result_blocks)

def on_write_complete(self, write_result: WriteResult[None]):
# If no rows were written, we can delete the directory.
if self.has_created_dir and aggregated_results.num_rows == 0:
if self.has_created_dir and write_result.num_rows == 0:
self.filesystem.delete_dir(self.path)
return aggregated_results

@property
def supports_distributed_writes(self) -> bool:
@@ -194,13 +191,15 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
)
write_path = posixpath.join(self.path, filename)

def write_row_to_path():
def write_row_to_path(row, write_path):
with self.open_output_stream(write_path) as file:
self.write_row_to_file(row, file)

logger.debug(f"Writing {write_path} file.")
call_with_retry(
write_row_to_path,
lambda row=row, write_path=write_path: write_row_to_path(
row, write_path
),
Review comment (Contributor Author): unrelated to this PR, but fixing the following lint error:
python/ray/data/datasource/file_datasink.py:190:46: B023 Function definition does not bind loop variable 'write_path'.
description=f"write '{write_path}'",
match=DataContext.get_current().retried_io_errors,
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
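The lint fix above works around Python's late binding of closures created in a loop (flake8-bugbear B023): a bare closure would read `row` and `write_path` only when `call_with_retry` eventually invokes it, by which point the loop variables may have moved on. Binding the current values as lambda default arguments freezes them per iteration. A standalone sketch of the pattern:

```python
# Late binding: every closure sees the loop variable's final value.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Binding the current value as a default argument freezes it per iteration,
# the same trick the diff applies to `row` and `write_path`.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])  # [0, 1, 2]
```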