-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[data][Datasink] support passing write results to on_write_completes #49251
Changes from 6 commits
87b04de
04814ce
10ffd4d
af508d2
b34feee
073972c
d479c4e
71d8ab4
71591ab
e5e6bc6
bede1b8
a94c4d6
d738409
052f884
974f9e5
7246c6b
20bf620
3c654bc
01f4b24
da5de5d
aa1970c
af8afe2
8f6baf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import itertools | ||
from typing import Callable, Iterator, List, Union | ||
|
||
from pandas import DataFrame | ||
|
||
from ray.data._internal.compute import TaskPoolStrategy | ||
from ray.data._internal.execution.interfaces import PhysicalOperator | ||
from ray.data._internal.execution.interfaces.task_context import TaskContext | ||
|
@@ -16,19 +18,39 @@ | |
from ray.data.datasource.datasource import Datasource | ||
|
||
|
||
def gen_data_sink_write_result(
    data_sink: Datasink,
    write_result_blocks: List[Block],
) -> WriteResult:
    """Aggregate per-task write-result blocks into a single ``WriteResult``.

    Args:
        data_sink: The datasink the results belong to (currently unused here,
            kept for interface stability).
        write_result_blocks: One single-row pandas ``DataFrame`` per write
            task, each with "num_rows", "size_bytes", and "write_task_result"
            columns.

    Returns:
        A ``WriteResult`` with the totals and the list of per-task results.
    """
    assert all(
        isinstance(block, DataFrame) and len(block) == 1
        for block in write_result_blocks
    )
    total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
    total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)

    write_task_results = [
        result["write_task_result"][0] for result in write_result_blocks
    ]
    # BUG FIX: the second positional field of WriteResult is `size_bytes`;
    # the original passed `total_num_rows` twice, silently dropping the
    # computed `total_size_bytes`.
    return WriteResult(total_num_rows, total_size_bytes, write_task_results)
|
||
|
||
def generate_write_fn(
    datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
    """Build the transform fn that performs the actual write for a write op.

    Args:
        datasink_or_legacy_datasource: The ``Datasink`` (new API) or legacy
            ``Datasource`` to write to.
        **write_args: Extra keyword args forwarded only to the legacy
            ``Datasource.write`` path.

    Returns:
        A transform fn that writes the incoming blocks and yields the
        original blocks unchanged.
    """

    def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
        """Writes the blocks to the given datasink or legacy datasource.

        Outputs the original blocks to be written."""
        # Create a copy of the iterator, so we can return the original blocks.
        it1, it2 = itertools.tee(blocks, 2)
        if isinstance(datasink_or_legacy_datasource, Datasink):
            # Stash the per-task write result in the TaskContext so the
            # downstream stats-collection transform fn can pick it up and
            # forward it to `Datasink.on_write_complete`.
            ctx.kwargs[
                "_data_sink_custom_result"
            ] = datasink_or_legacy_datasource.write(it1, ctx)
        else:
            # Legacy datasources have no custom write result.
            datasink_or_legacy_datasource.write(it1, ctx, **write_args)

        return it2

    return fn
|
@@ -41,7 +63,7 @@ def generate_collect_write_stats_fn() -> ( | |
# one Block which contain stats/metrics about the write. | ||
# Otherwise, an error will be raised. The Datasource can handle | ||
# execution outcomes with `on_write_complete()`` and `on_write_failed()``. | ||
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: | ||
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: | ||
"""Handles stats collection for block writes.""" | ||
block_accessors = [BlockAccessor.for_block(block) for block in blocks] | ||
total_num_rows = sum(ba.num_rows() for ba in block_accessors) | ||
|
@@ -51,8 +73,13 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: | |
# type. | ||
import pandas as pd | ||
|
||
write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes) | ||
block = pd.DataFrame({"write_result": [write_result]}) | ||
block = pd.DataFrame( | ||
{ | ||
"num_rows": [total_num_rows], | ||
"size_bytes": [total_size_bytes], | ||
"write_task_result": [ctx.kwargs.get("_data_sink_custom_result", None)], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than passing information through the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait, on second thought, do we even still need I think we can return the write return and statistics from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. having 2 separate TransformFns allows optimization rules to insert certain operations in between them. And to pass data between them, TaskContext is probably the best place. |
||
} | ||
) | ||
return iter([block]) | ||
|
||
return fn | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,6 @@ | ||||||
import logging | ||||||
from dataclasses import dataclass, fields | ||||||
from typing import Iterable, List, Optional | ||||||
from dataclasses import dataclass | ||||||
from typing import Generic, Iterable, List, Optional, TypeVar | ||||||
|
||||||
import ray | ||||||
from ray.data._internal.execution.interfaces import TaskContext | ||||||
|
@@ -10,45 +10,24 @@ | |||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
||||||
@dataclass | ||||||
@DeveloperAPI | ||||||
class WriteResult: | ||||||
"""Result of a write operation, containing stats/metrics | ||||||
on the written data. | ||||||
|
||||||
Attributes: | ||||||
total_num_rows: The total number of rows written. | ||||||
total_size_bytes: The total size of the written data in bytes. | ||||||
""" | ||||||
|
||||||
num_rows: int = 0 | ||||||
size_bytes: int = 0 | ||||||
|
||||||
@staticmethod | ||||||
def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult": | ||||||
"""Aggregate a list of write results. | ||||||
|
||||||
Args: | ||||||
write_results: A list of write results. | ||||||
# Generic type for the result of a write task. | ||||||
WriteResultType = TypeVar("WriteResultType") | ||||||
|
||||||
Returns: | ||||||
A single write result that aggregates the input results. | ||||||
""" | ||||||
total_num_rows = 0 | ||||||
total_size_bytes = 0 | ||||||
|
||||||
for write_result in write_results: | ||||||
total_num_rows += write_result.num_rows | ||||||
total_size_bytes += write_result.size_bytes | ||||||
@dataclass | ||||||
class WriteResult(Generic[WriteResultType]): | ||||||
"""Aggregated result of the Datasink write operations.""" | ||||||
|
||||||
return WriteResult( | ||||||
num_rows=total_num_rows, | ||||||
size_bytes=total_size_bytes, | ||||||
) | ||||||
# Total number of written rows. | ||||||
num_rows: int | ||||||
# Total size in bytes of written data. | ||||||
size_bytes: int | ||||||
# Results of all `Datasink.write`. | ||||||
write_task_results: List[WriteResultType] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: To avoid confusion between the
Suggested change
|
||||||
|
||||||
|
||||||
@DeveloperAPI | ||||||
class Datasink: | ||||||
class Datasink(Generic[WriteResultType]): | ||||||
"""Interface for defining write-related logic. | ||||||
|
||||||
If you want to write data to something that isn't built-in, subclass this class | ||||||
|
@@ -67,44 +46,32 @@ def write( | |||||
self, | ||||||
blocks: Iterable[Block], | ||||||
ctx: TaskContext, | ||||||
) -> None: | ||||||
) -> WriteResultType: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @bveeramani Unrelated to this PR. But one pitfall about the current Datasink interface is that, the Datasink object will be used both on the driver (the We should consider addressing this issue before making the Datasink API public. One solution is to introduce a separate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I agree, that's janky.
Sounds reasonable.
Makes sense. There's no urgency to make |
||||||
"""Write blocks. This is used by a single write task. | ||||||
|
||||||
Args: | ||||||
blocks: Generator of data blocks. | ||||||
ctx: ``TaskContext`` for the write task. | ||||||
|
||||||
Returns: | ||||||
Result of this write task. When the entire write operator finishes, | ||
all write task results will be passed as `WriteResult.write_task_results` | ||
to `Datasink.on_write_complete`. | ||||||
""" | ||||||
raise NotImplementedError | ||||||
|
||||||
def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: | ||||||
def on_write_complete(self, write_result: WriteResult[WriteResultType]): | ||||||
"""Callback for when a write job completes. | ||||||
|
||||||
This can be used to "commit" a write output. This method must | ||||||
succeed prior to ``write_datasink()`` returning to the user. If this | ||||||
method fails, then ``on_write_failed()`` is called. | ||||||
|
||||||
Args: | ||||||
write_result_blocks: The blocks resulting from executing | ||||||
write_result: Aggregated result of the ||||||
write operator, containing write results and stats. | ||||||
Returns: | ||||||
A ``WriteResult`` object containing the aggregated stats of all | ||||||
the input write results. | ||||||
""" | ||||||
write_results = [ | ||||||
result["write_result"].iloc[0] for result in write_result_blocks | ||||||
] | ||||||
aggregated_write_results = WriteResult.aggregate_write_results(write_results) | ||||||
|
||||||
aggregated_results_str = "" | ||||||
for k in fields(aggregated_write_results.__class__): | ||||||
v = getattr(aggregated_write_results, k.name) | ||||||
aggregated_results_str += f"\t- {k.name}: {v}\n" | ||||||
|
||||||
logger.info( | ||||||
f"Write operation succeeded. Aggregated write results:\n" | ||||||
f"{aggregated_results_str}" | ||||||
) | ||||||
return aggregated_write_results | ||||||
pass | ||||||
|
||||||
def on_write_failed(self, error: Exception) -> None: | ||||||
"""Callback for when a write job fails. | ||||||
|
@@ -144,7 +111,7 @@ def num_rows_per_write(self) -> Optional[int]: | |||||
|
||||||
|
||||||
@DeveloperAPI | ||||||
class DummyOutputDatasink(Datasink): | ||||||
class DummyOutputDatasink(Datasink[None]): | ||||||
"""An example implementation of a writable datasource for testing. | ||||||
Examples: | ||||||
>>> import ray | ||||||
|
@@ -189,10 +156,8 @@ def write( | |||||
tasks.append(self.data_sink.write.remote(b)) | ||||||
ray.get(tasks) | ||||||
|
||||||
def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: | ||||||
def on_write_complete(self, write_result: WriteResult[None]): | ||||||
self.num_ok += 1 | ||||||
aggregated_results = super().on_write_complete(write_result_blocks) | ||||||
return aggregated_results | ||||||
|
||||||
def on_write_failed(self, error: Exception) -> None: | ||||||
self.num_failed += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
Datasink
is one word, so this seems more accurate?