[data][Datasink] support passing write results to on_write_completes #49251

Merged · 23 commits · Dec 28, 2024
@@ -20,7 +20,7 @@
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


class BigQueryDatasink(Datasink):
class BigQueryDatasink(Datasink[None]):
def __init__(
self,
project_id: str,
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/mongo_datasink.py
@@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)


class MongoDatasink(Datasink):
class MongoDatasink(Datasink[None]):
def __init__(self, uri: str, database: str, collection: str) -> None:
_check_import(self, module="pymongo", package="pymongo")
_check_import(self, module="pymongoarrow", package="pymongoarrow")
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/sql_datasink.py
@@ -6,7 +6,7 @@
from ray.data.datasource.datasink import Datasink


class SQLDatasink(Datasink):
class SQLDatasink(Datasink[None]):

_MAX_ROWS_PER_WRITE = 128

37 changes: 32 additions & 5 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -1,6 +1,8 @@
import itertools
from typing import Callable, Iterator, List, Union

from pandas import DataFrame

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.interfaces.task_context import TaskContext
@@ -16,19 +18,39 @@
from ray.data.datasource.datasource import Datasource


def gen_data_sink_write_result(
Member:
Nit: Datasink is one word, so this seems more accurate?

Suggested change
def gen_data_sink_write_result(
def gen_datasink_write_result(

data_sink: Datasink,
write_result_blocks: List[Block],
) -> WriteResult:
assert all(
isinstance(block, DataFrame) and len(block) == 1
for block in write_result_blocks
)
total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)

write_task_results = [
result["write_task_result"][0] for result in write_result_blocks
]
return WriteResult(total_num_rows, total_size_bytes, write_task_results)


def generate_write_fn(
datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Writes the blocks to the given datasink or legacy datasource.

Outputs the original blocks to be written."""
# Create a copy of the iterator, so we can return the original blocks.
it1, it2 = itertools.tee(blocks, 2)
if isinstance(datasink_or_legacy_datasource, Datasink):
datasink_or_legacy_datasource.write(it1, ctx)
ctx.kwargs[
"_data_sink_custom_result"
] = datasink_or_legacy_datasource.write(it1, ctx)
else:
datasink_or_legacy_datasource.write(it1, ctx, **write_args)

return it2

return fn
@@ -41,7 +63,7 @@ def generate_collect_write_stats_fn() -> (
# one Block which contain stats/metrics about the write.
# Otherwise, an error will be raised. The Datasource can handle
# execution outcomes with `on_write_complete()`` and `on_write_failed()``.
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Handles stats collection for block writes."""
block_accessors = [BlockAccessor.for_block(block) for block in blocks]
total_num_rows = sum(ba.num_rows() for ba in block_accessors)
Expand All @@ -51,8 +73,13 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
# type.
import pandas as pd

write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes)
block = pd.DataFrame({"write_result": [write_result]})
block = pd.DataFrame(
{
"num_rows": [total_num_rows],
"size_bytes": [total_size_bytes],
"write_task_result": [ctx.kwargs.get("_data_sink_custom_result", None)],
Member:
Rather than passing information through the TaskContext, can we directly yield the write returns and stats in generate_write_fn.fn?

# Pseudocode for `generate_write_fn.fn`
for block in blocks:
    write_return = datasink.write(block)
    yield Block({"write_return": write_return, "num_rows": block.num_rows()})

Member:
Wait, on second thought, do we even still need generate_collect_write_stats_fn?

I think we can return the write return and statistics from generate_write_fn.fn, and then aggregate the statistics and create WriteResult on the driver?

Contributor Author:
having 2 separate TransformFns allows optimization rules to insert certain operations in between them. And to pass data between them, TaskContext is probably the best place.

}
)
return iter([block])

return fn
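
To make the handoff concrete, here is a minimal, self-contained sketch (not from the PR) of what the two transform functions exchange: each write task emits a one-row pandas DataFrame carrying num_rows, size_bytes, and the value Datasink.write stashed in TaskContext.kwargs, and the driver folds those rows into a single aggregated result, mirroring gen_data_sink_write_result above. The file names and the simplified WriteResult stand-in below are illustrative assumptions, not Ray's actual classes.

import pandas as pd
from dataclasses import dataclass
from typing import Any, List


@dataclass
class WriteResult:
    # Simplified stand-in for ray.data.datasource.datasink.WriteResult.
    num_rows: int
    size_bytes: int
    write_task_results: List[Any]


# Hypothetical per-task outputs: each write task emits a one-row DataFrame with
# its stats plus whatever its Datasink.write call returned (here, file paths).
task_blocks = [
    pd.DataFrame(
        {"num_rows": [100], "size_bytes": [2048], "write_task_result": ["part-0.parquet"]}
    ),
    pd.DataFrame(
        {"num_rows": [50], "size_bytes": [1024], "write_task_result": ["part-1.parquet"]}
    ),
]

# Driver-side aggregation, mirroring gen_data_sink_write_result.
aggregated = WriteResult(
    num_rows=sum(int(b["num_rows"].sum()) for b in task_blocks),
    size_bytes=sum(int(b["size_bytes"].sum()) for b in task_blocks),
    write_task_results=[b["write_task_result"][0] for b in task_blocks],
)
print(aggregated)
# WriteResult(num_rows=150, size_bytes=3072,
#             write_task_results=['part-0.parquet', 'part-1.parquet'])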
14 changes: 9 additions & 5 deletions python/ray/data/dataset.py
@@ -51,6 +51,7 @@
from ray.data._internal.execution.interfaces.ref_bundle import (
_ref_bundles_iterator_to_block_refs_list,
)
from ray.data._internal.execution.util import memory_string
from ray.data._internal.iterator.iterator_impl import DataIteratorImpl
from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
from ray.data._internal.logical.operators.all_to_all_operator import (
@@ -78,6 +79,7 @@
from ray.data._internal.pandas_block import PandasBlockBuilder, PandasBlockSchema
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
from ray.data._internal.planner.plan_write_op import gen_data_sink_write_result
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.split import _get_num_rows, _split_at_indices
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager
@@ -3875,19 +3877,21 @@ def write_datasink(
logical_plan = LogicalPlan(write_op, self.context)

try:
import pandas as pd

datasink.on_write_start()

self._write_ds = Dataset(plan, logical_plan).materialize()
# TODO: Get and handle the blocks with an iterator instead of getting
# everything in a blocking way, so some blocks can be freed earlier.
raw_write_results = ray.get(self._write_ds._plan.execute().block_refs)
assert all(
isinstance(block, pd.DataFrame) and len(block) == 1
for block in raw_write_results
write_result = gen_data_sink_write_result(datasink, raw_write_results)
logger.info(
"Data sink %s finished. %d rows and %s data written.",
datasink.get_name(),
write_result.num_rows,
memory_string(write_result.size_bytes),
)
datasink.on_write_complete(raw_write_results)
datasink.on_write_complete(write_result)

except Exception as e:
datasink.on_write_failed(e)
85 changes: 25 additions & 60 deletions python/ray/data/datasource/datasink.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass, fields
from typing import Iterable, List, Optional
from dataclasses import dataclass
from typing import Generic, Iterable, List, Optional, TypeVar

import ray
from ray.data._internal.execution.interfaces import TaskContext
@@ -10,45 +10,24 @@
logger = logging.getLogger(__name__)


@dataclass
@DeveloperAPI
class WriteResult:
    """Result of a write operation, containing stats/metrics
    on the written data.

    Attributes:
        total_num_rows: The total number of rows written.
        total_size_bytes: The total size of the written data in bytes.
    """

    num_rows: int = 0
    size_bytes: int = 0

    @staticmethod
    def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult":
        """Aggregate a list of write results.

        Args:
            write_results: A list of write results.

        Returns:
            A single write result that aggregates the input results.
        """
        total_num_rows = 0
        total_size_bytes = 0

        for write_result in write_results:
            total_num_rows += write_result.num_rows
            total_size_bytes += write_result.size_bytes

        return WriteResult(
            num_rows=total_num_rows,
            size_bytes=total_size_bytes,
        )


# Generic type for the result of a write task.
WriteResultType = TypeVar("WriteResultType")


@dataclass
class WriteResult(Generic[WriteResultType]):
    """Aggregated result of the Datasink write operations."""

    # Total number of written rows.
    num_rows: int
    # Total size in bytes of written data.
    size_bytes: int
    # Results of all `Datasink.write`.
    write_task_results: List[WriteResultType]
Member:
Nit: To avoid confusion between the WriteResult dataclass and the object returned from write tasks, it might be clearer if we rename write_task_results to write_task_returns (and WriteResultType to WriteReturnType)

Suggested change
write_task_results: List[WriteResultType]
write_task_returns: List[WriteReturnType]



@DeveloperAPI
class Datasink:
class Datasink(Generic[WriteResultType]):
"""Interface for defining write-related logic.

If you want to write data to something that isn't built-in, subclass this class
@@ -67,44 +46,32 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> WriteResultType:
Contributor Author:
@bveeramani Unrelated to this PR. But one pitfall of the current Datasink interface is that the Datasink object will be used both on the driver (the on_xxx callbacks) and on the workers (this write function).
Users may mistakenly think that if they update an attribute in the write method, the update will be available on on_write_complete.

We should consider addressing this issue before making the Datasink API public. One solution is to introduce a separate DatasinkWriter class.

Member:
Yeah, I agree, that's janky.

One solution is to introduce a separate DatasinkWriter class.

Sounds reasonable.

We should consider addressing this issue before making the Datasink API public.

Makes sense. There's no urgency to make Datasink public.

"""Write blocks. This is used by a single write task.

Args:
blocks: Generator of data blocks.
ctx: ``TaskContext`` for the write task.

Returns:
Result of this write task. When the entire write operator finishes,
            all write task results will be passed as `WriteResult.write_task_results`
to `Datasink.on_write_complete`.
"""
raise NotImplementedError

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
def on_write_complete(self, write_result: WriteResult[WriteResultType]):
"""Callback for when a write job completes.

This can be used to "commit" a write output. This method must
succeed prior to ``write_datasink()`` returning to the user. If this
method fails, then ``on_write_failed()`` is called.

Args:
write_result_blocks: The blocks resulting from executing
            write_result: Aggregated result of
the Write operator, containing write results and stats.
Returns:
A ``WriteResult`` object containing the aggregated stats of all
the input write results.
"""
write_results = [
result["write_result"].iloc[0] for result in write_result_blocks
]
aggregated_write_results = WriteResult.aggregate_write_results(write_results)

aggregated_results_str = ""
for k in fields(aggregated_write_results.__class__):
v = getattr(aggregated_write_results, k.name)
aggregated_results_str += f"\t- {k.name}: {v}\n"

logger.info(
f"Write operation succeeded. Aggregated write results:\n"
f"{aggregated_results_str}"
)
return aggregated_write_results
pass

def on_write_failed(self, error: Exception) -> None:
"""Callback for when a write job fails.
@@ -144,7 +111,7 @@ def num_rows_per_write(self) -> Optional[int]:


@DeveloperAPI
class DummyOutputDatasink(Datasink):
class DummyOutputDatasink(Datasink[None]):
"""An example implementation of a writable datasource for testing.
Examples:
>>> import ray
@@ -189,10 +156,8 @@ def write(
tasks.append(self.data_sink.write.remote(b))
ray.get(tasks)

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
def on_write_complete(self, write_result: WriteResult[None]):
self.num_ok += 1
aggregated_results = super().on_write_complete(write_result_blocks)
return aggregated_results

def on_write_failed(self, error: Exception) -> None:
self.num_failed += 1
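
For orientation (not part of this PR's diff), here is a minimal sketch of a user-defined sink against the new generic interface: write runs on the workers and returns a per-task value, while on_write_complete receives the aggregated WriteResult on the driver. The CSVPartsDatasink name, the base_path parameter, and the CSV-writing logic are illustrative assumptions; the import paths follow this PR's internal module layout and may differ across Ray versions.

from typing import Iterable, List

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.datasink import Datasink, WriteResult


class CSVPartsDatasink(Datasink[List[str]]):
    """Writes each incoming block to its own CSV file and returns the paths."""

    def __init__(self, base_path: str) -> None:
        self._base_path = base_path

    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> List[str]:
        # Runs on the workers; the return value becomes one entry of
        # WriteResult.write_task_results on the driver.
        paths = []
        for i, block in enumerate(blocks):
            path = f"{self._base_path}/part-{ctx.task_idx}-{i}.csv"
            BlockAccessor.for_block(block).to_pandas().to_csv(path, index=False)
            paths.append(path)
        return paths

    def on_write_complete(self, write_result: WriteResult[List[str]]) -> None:
        # Runs on the driver after all write tasks succeed.
        all_paths = [p for task_paths in write_result.write_task_results for p in task_paths]
        print(f"Wrote {write_result.num_rows} rows across {len(all_paths)} files.")

A call along the lines of ray.data.range(100).write_datasink(CSVPartsDatasink("/tmp/out")) would then exercise this path, assuming the target directory already exists.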
11 changes: 4 additions & 7 deletions python/ray/data/datasource/file_datasink.py
@@ -1,6 +1,6 @@
import logging
import posixpath
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
from urllib.parse import urlparse

from ray._private.utils import _add_creatable_buckets_param_if_s3_uri
@@ -27,7 +27,7 @@
WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32


class _FileDatasink(Datasink):
class _FileDatasink(Datasink[None]):
def __init__(
self,
path: str,
@@ -130,13 +130,10 @@ def write(
def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
raise NotImplementedError

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
aggregated_results = super().on_write_complete(write_result_blocks)

def on_write_complete(self, write_result: WriteResult[None]):
# If no rows were written, we can delete the directory.
if self.has_created_dir and aggregated_results.num_rows == 0:
if self.has_created_dir and write_result.num_rows == 0:
self.filesystem.delete_dir(self.path)
return aggregated_results

@property
def supports_distributed_writes(self) -> bool: