
[train] Add some fixtures for a pluggable set of data + train release tests #50019

Draft: wants to merge 10 commits into base: master
Conversation

justinvyu (Contributor) commented Jan 23, 2025

Summary

justinvyu (Contributor, Author) commented:
Ray Data stats:

{'train': {'iter_stats': {'block_to_batch-avg': 0.42843372929091855,
                          'block_to_batch-max': 3.7064356099999713,
                          'block_to_batch-min': 4.692699985753279e-05,
                          'block_to_batch-total': 23.56385511100052,
                          'collate-avg': 0.09726231129088773,
                          'collate-max': 0.8624489379999432,
                          'collate-min': 0.007031553000160784,
                          'collate-total': 5.349427120998826,
                          'fetch_block-avg': 0.006737741999990653,
                          'fetch_block-max': 0.01227974199991877,
                          'fetch_block-min': 0.002948295999885886,
                          'fetch_block-total': 0.060639677999915875,
                          'finalize-avg': inf,
                          'finalize-max': 0,
                          'finalize-min': inf,
                          'finalize-total': 0,
                          'format_batch-avg': 0.015120474363605106,
                          'format_batch-max': 0.06806337499983783,
                          'format_batch-min': 0.006484282999736024,
                          'format_batch-total': 0.8316260899982808,
                          'prefetch_block-avg': inf,
                          'prefetch_block-max': 0,
                          'prefetch_block-min': inf,
                          'prefetch_block-total': 0,
                          'time_spent_blocked-avg': 0.7261243590391863,
                          'time_spent_blocked-max': 23.501931334000346,
                          'time_spent_blocked-min': 2.3814000087440945e-05,
                          'time_spent_blocked-total': 37.0323423109985,
                          'time_spent_training-avg': 0.35339732419997744,
                          'time_spent_training-max': 1.7951285889998871,
                          'time_spent_training-min': 0.14404996000030224,
                          'time_spent_training-total': 17.669866209998872},
           'throughput': 1558.5289483425217}}

Comment on lines +116 to +125
ds_output_summary = stats_summary.parents[0]
ds_throughput = (
    ds_output_summary.operators_stats[-1].output_num_rows["sum"]
    / ds_output_summary.get_total_wall_time()
)

iter_stats = stats_summary.iter_stats

# TODO: Make this raw data dict easier to access from the iterator.
# The only way to access this through public API is a string representation.
justinvyu (Contributor, Author):
The dataset stats are a bit hard to export as a dict right now. Can we add an xxxSummary.to_dict() instead of just to_string()?
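
A minimal sketch of what such a to_dict() could look like, assuming the summary objects are dataclasses (the helper name is illustrative; no such method exists in Ray Data yet):

```python
from dataclasses import fields, is_dataclass
from typing import Any


def summary_to_dict(summary: Any) -> Any:
    """Recursively convert a dataclass-based stats summary into plain
    dicts/lists so callers don't have to parse the string representation."""
    if is_dataclass(summary):
        return {
            f.name: summary_to_dict(getattr(summary, f.name))
            for f in fields(summary)
        }
    if isinstance(summary, (list, tuple)):
        return [summary_to_dict(v) for v in summary]
    return summary
```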

justinvyu (Contributor, Author):
I wanted to track the global dataset pipeline throughput, but that's not tracked on the dataset shard summary. So, I needed to pull it from the parent (aka the last stage before the streaming split).

Contributor:
SG. @srinathk10 can you improve this after this PR?

Contributor:
Okay sure tracking this one.

Comment on lines +167 to +169
# Training blocked time
"time_spent_blocked-avg": iter_stats.block_time.avg(),
"time_spent_blocked-min": iter_stats.block_time.min(),
justinvyu (Contributor, Author):

We should distinguish the "first batch blocked time" from the rest of the batches, since the first batch's blocked time includes the dataset pipeline "warmup" time. This warmup time drags the average up a lot.
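
A minimal sketch of tracking the two separately (the tracker class is illustrative, not part of this PR):

```python
from typing import List, Optional


class BlockedTimeTracker:
    """Split the first batch's blocked time (pipeline warmup) from the rest."""

    def __init__(self) -> None:
        self.first_batch_blocked_s: Optional[float] = None
        self.steady_state_blocked_s: List[float] = []

    def record(self, blocked_s: float) -> None:
        if self.first_batch_blocked_s is None:
            # The first batch includes dataset pipeline warmup.
            self.first_batch_blocked_s = blocked_s
        else:
            self.steady_state_blocked_s.append(blocked_s)

    def steady_state_avg(self) -> float:
        if not self.steady_state_blocked_s:
            return 0.0
        return sum(self.steady_state_blocked_s) / len(self.steady_state_blocked_s)
```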

import enum

import pydantic


class DataloaderType(enum.Enum):
Contributor:

Looks like most of the components in this file and factory.py are general, so they shouldn't be put under the torch directory.


# Model
task: str = "image_classification"
model_name: str = "resnet50"
Contributor:

Better to remove these default values here, since this config will be used by various different benchmarks.

@@ -0,0 +1,189 @@
import torch
Contributor:

Nit: maybe put this file and imagenet.py under a subdirectory, so it's clearer that they belong to the same benchmark.

benchmark_config = parse_cli_args()
print(benchmark_config.model_dump_json(indent=2))

if benchmark_config.task == "image_classification":
Contributor:

Nit: each benchmark may have different config items, so it may be better to move task out of BenchmarkConfig. We can first detect the task to decide which BenchmarkConfig subclass to use, then parse the remaining configs.
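
A rough sketch of that two-pass parsing, assuming pydantic v2 config classes (all class, field, and task names here are illustrative):

```python
import argparse

import pydantic


class BenchmarkConfig(pydantic.BaseModel):
    num_workers: int = 1


class ImageClassificationConfig(BenchmarkConfig):
    # Per the comment above: no defaults for benchmark-specific fields.
    model_name: str
    dataset_path: str


TASK_TO_CONFIG = {"image_classification": ImageClassificationConfig}


def parse_cli_args() -> BenchmarkConfig:
    # Pass 1: detect the task only.
    task_parser = argparse.ArgumentParser(add_help=False)
    task_parser.add_argument("--task", required=True, choices=list(TASK_TO_CONFIG))
    task_args, remaining = task_parser.parse_known_args()

    # Pass 2: parse the fields of the task-specific config subclass.
    config_cls = TASK_TO_CONFIG[task_args.task]
    parser = argparse.ArgumentParser()
    for name, field in config_cls.model_fields.items():
        parser.add_argument(f"--{name}", required=field.is_required())
    values = {
        k: v for k, v in vars(parser.parse_args(remaining)).items() if v is not None
    }
    return config_cls(**values)
```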

local_throughput = batch_size / step_elapsed
# TODO: This calculation assumes that all workers have the same throughput.
global_throughput = local_throughput * ray.train.get_context().get_world_size()
self._train_throughput.add(global_throughput)
Contributor:

Why are we adding the "throughput" for each step?
I think we need to add up the total number of rows and, at the end, compute the throughput as total_rows / total_time, right?
Also, is this calculation only needed on rank 0?

Contributor:

Ok, I found that you are collecting the average at the end, but that's not accurate.
For example, suppose the first step takes 10s to consume 10 rows, and the second step takes 1s to consume another 10 rows.
With the current approach, the average throughput is (10/10 + 10/1) / 2 = 5.5 rows/s,
but the actual average throughput is (10 + 10) / (10 + 1) ≈ 1.818 rows/s.
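A sketch of the total-rows / total-time accounting described above (the class and method names are illustrative):

```python
import time


class ThroughputTracker:
    def __init__(self) -> None:
        self.total_rows = 0
        self.start_s = time.perf_counter()

    def add_rows(self, num_rows: int) -> None:
        self.total_rows += num_rows

    def throughput(self) -> float:
        # total_rows / total_time avoids the bias of averaging per-step
        # throughputs (5.5 vs. ~1.818 rows/s in the example above).
        elapsed_s = time.perf_counter() - self.start_s
        return self.total_rows / elapsed_s if elapsed_s > 0 else 0.0
```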

with self._timers["validation/epoch"].timer():
    validation_metrics = self.validate()

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
Contributor:

Why save the checkpoints to a local temp dir? This would break on spot instances.
I think we can use either a network-mounted dir or S3.
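
One option, sketched with Ray Train's usual pattern: stage the checkpoint in a local temp dir, report it, and let a configured cloud storage_path handle persistence (the bucket path below is illustrative, not from this PR):

```python
import os
import tempfile

import torch

import ray.train


def report_checkpoint(model: torch.nn.Module, metrics: dict) -> None:
    # Stage locally, then report; Ray Train uploads the checkpoint to the
    # RunConfig storage_path, so it survives spot-instance preemption.
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        torch.save(model.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt"))
        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )


# Illustrative bucket; reported checkpoints are persisted here.
run_config = ray.train.RunConfig(storage_path="s3://my-bucket/train-benchmarks")
```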
