[train] Add some fixtures for a pluggable set of data + train release tests #50019
base: master
Conversation
Ray Data stats:
ds_output_summary = stats_summary.parents[0]
ds_throughput = (
    ds_output_summary.operators_stats[-1].output_num_rows["sum"]
    / ds_output_summary.get_total_wall_time()
)

iter_stats = stats_summary.iter_stats

# TODO: Make this raw data dict easier to access from the iterator.
# The only way to access this through public API is a string representation.
The dataset stats are a bit hard to export as a dict right now. Can we add an xxxSummary.to_dict() instead of just to_string?
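A rough sketch of what that could look like; the field names below are assumptions based on the attributes used in this diff (parents, iter_stats, get_total_wall_time), not Ray Data's actual API:

# Hypothetical sketch of a to_dict() on the stats summary object; the exact
# fields are assumptions taken from the attributes referenced in this PR.
def to_dict(summary) -> dict:
    return {
        "iter_stats": {
            "block_time_avg_s": summary.iter_stats.block_time.avg(),
            "block_time_min_s": summary.iter_stats.block_time.min(),
        },
        "total_wall_time_s": summary.get_total_wall_time(),
        "parents": [to_dict(parent) for parent in summary.parents],
    }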
I wanted to track the global dataset pipeline throughput, but that's not tracked on the dataset shard summary. So, I needed to pull it from the parent (aka the last stage before the streaming split).
SG. @srinathk10 can you improve this after this PR?
Okay sure, tracking this one.
# Training blocked time
"time_spent_blocked-avg": iter_stats.block_time.avg(),
"time_spent_blocked-min": iter_stats.block_time.min(),
We should distinguish "first batch blocked time" vs. all the rest of the batches, since the first batch blocked time includes the dataset pipeline "warmup" time. This warmup time drags the average up a lot.
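One way to split that out in the benchmark loop (a sketch with illustrative names; dataloader and train_step are stand-ins, not code from this PR):

import time

# Sketch: record the first batch's blocked time separately, since it includes
# the dataset pipeline warmup, and average only the remaining batches.
first_batch_blocked_s = None
blocked_times_s = []

start = time.perf_counter()
for i, batch in enumerate(dataloader):  # dataloader is a stand-in
    blocked = time.perf_counter() - start
    if i == 0:
        first_batch_blocked_s = blocked  # includes pipeline warmup
    else:
        blocked_times_s.append(blocked)
    train_step(batch)  # stand-in for the actual training step
    start = time.perf_counter()

metrics = {
    "time_spent_blocked-first_batch": first_batch_blocked_s,
    "time_spent_blocked-avg_after_first": (
        sum(blocked_times_s) / len(blocked_times_s) if blocked_times_s else 0.0
    ),
}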
import enum

import pydantic


class DataloaderType(enum.Enum):
Looks like most components in this file and factory.py are general. They shouldn't be put under the torch directory.
# Model
task: str = "image_classification"
model_name: str = "resnet50"
Better to remove these default values here, since this config will be used by various different benchmarks.
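For example (a sketch, assuming the config stays a pydantic model):

import pydantic


class BenchmarkConfig(pydantic.BaseModel):
    # No defaults: each benchmark must set these explicitly.
    task: str
    model_name: str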
@@ -0,0 +1,189 @@
import torch
Nit: maybe put this file and imagenet.py under the same subdirectory, so it's clearer that they belong to the same benchmark.
benchmark_config = parse_cli_args()
print(benchmark_config.model_dump_json(indent=2))

if benchmark_config.task == "image_classification":
Nit: each benchmark may have different config items, so it may be better to move task out of BenchmarkConfig. We can first detect task to determine which BenchmarkConfig subclass to use, and then parse the other configs.
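A sketch of that two-stage parsing; the subclass names and fields are illustrative, not from this PR:

import argparse

import pydantic


class BenchmarkConfig(pydantic.BaseModel):
    model_name: str


class ImageClassificationConfig(BenchmarkConfig):
    image_size: int = 224


class LanguageModelingConfig(BenchmarkConfig):
    sequence_length: int = 2048


TASK_TO_CONFIG_CLS = {
    "image_classification": ImageClassificationConfig,
    "language_modeling": LanguageModelingConfig,
}


def parse_cli_args() -> BenchmarkConfig:
    # Stage 1: detect the task only.
    task_parser = argparse.ArgumentParser(add_help=False)
    task_parser.add_argument("--task", required=True, choices=TASK_TO_CONFIG_CLS)
    task_args, remaining = task_parser.parse_known_args()

    # Stage 2: parse the remaining flags with the task-specific subclass.
    config_cls = TASK_TO_CONFIG_CLS[task_args.task]
    parser = argparse.ArgumentParser()
    for name, field in config_cls.model_fields.items():
        parser.add_argument(
            f"--{name}",
            required=field.is_required(),
            default=None if field.is_required() else field.default,
        )
    return config_cls(**vars(parser.parse_args(remaining)))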
local_throughput = batch_size / step_elapsed
# TODO: This calculation assumes that all workers have the same throughput.
global_throughput = local_throughput * ray.train.get_context().get_world_size()
self._train_throughput.add(global_throughput)
Why are we adding the "throughput" for each step?
I think we need to add the total # of rows and, at the end, compute the throughput as "total_rows / total_time", right?
Also, is this calculation only needed for rank 0?
OK, I found that you are collecting the average at the end. But this is not accurate.
For example, suppose the first step takes 10s to consume 10 rows, and the second step takes 1s to consume 10 rows.
With the current approach, the average throughput is (10/10 + 10/1) / 2 = 5.5 rows/s.
But the actual average throughput is (10 + 10) / (10 + 1) ≈ 1.82 rows/s.
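A sketch of the suggested accumulation (names are illustrative; only the total-rows/total-time idea is from the comment above):

import time

# Sketch: accumulate total rows and total time, then divide once at the end,
# instead of averaging per-step throughputs.
total_rows = 0
total_time_s = 0.0

for batch in dataloader:  # dataloader and train_step are stand-ins
    step_start = time.perf_counter()
    train_step(batch)
    total_time_s += time.perf_counter() - step_start
    total_rows += len(batch)

local_throughput = total_rows / total_time_s
# With the example above: (10 + 10) rows / (10 + 1) s ≈ 1.82 rows/s,
# versus (10/10 + 10/1) / 2 = 5.5 rows/s when averaging per-step values.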
with self._timers["validation/epoch"].timer():
    validation_metrics = self.validate()

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
Why save the checkpoints to a local temp dir? This would break on spot instances.
I think we can either use a network-mounted dir or S3.
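A sketch of the network-mounted-directory option; the /mnt/cluster_storage path is an assumption, and model, epoch, and validation_metrics are assumed to already be in scope, while Checkpoint.from_directory and ray.train.report are existing Ray Train APIs:

import os

import torch

import ray.train
from ray.train import Checkpoint

# Sketch: write the checkpoint to a shared mount instead of a local temp dir
# so it survives spot instance preemption. The path below is an assumption.
checkpoint_dir = os.path.join("/mnt/cluster_storage", "checkpoints", f"epoch_{epoch}")
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pt"))
ray.train.report(
    metrics=validation_metrics,
    checkpoint=Checkpoint.from_directory(checkpoint_dir),
)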