loading a StatefulDataLoader state_dict creates a slightly different one in the dataloader, potentially dropping an epoch #1437
Comments
Thanks for reporting the issue @gailweiss! I've repro'd this in colab with the latest stable release. The multi-processing use case (i.e.

For context: requesting a state_dict at the end of an epoch is slightly ambiguous, as you mentioned: is it the state at the end of the epoch, or at the beginning of the next one? IMO, for this common use case, loading it should start from the next epoch.
Here is the check I used:
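(The exact snippet was not captured here; the following is a hedged sketch of that kind of check, where the toy dataset, batch size, and num_workers=0 are assumptions.)

```python
# Hedged sketch, not the exact check from this comment: save a state_dict right
# after an epoch finishes, load it into a fresh StatefulDataLoader, and count
# how many batches the next iteration yields.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))  # toy map-style dataset (an assumption)

dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
for _ in dl:               # run one full epoch
    pass
sd = dl.state_dict()       # state taken right after the epoch completed

dl2 = StatefulDataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
dl2.load_state_dict(sd)

n_batches = sum(1 for _ in dl2)
print(f"batches yielded after loading the end-of-epoch state: {n_batches}")
# 5 would mean a fresh epoch starts; 0 reproduces the reported "empty epoch".
```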
Output:
Thank you both for working on this! Regarding the ambiguity at the end of an epoch, I think it should be resolved by whether the state dict is taken from inside or outside the finished loop over the dataloader. Unfortunately I am on my phone right now, so I can only present code that I am not sure will actually run:
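(The code from this comment was not captured here; below is an untested sketch of the distinction being described, where `sd_in` / `sd_out` and the toy dataset are illustrative.)

```python
# Untested sketch: sd_in is taken *inside* the loop, on the last batch;
# sd_out is taken *outside*, after the loop has finished.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))
dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True)

last = len(dl) - 1
for i, batch in enumerate(dl):
    if i == last:
        sd_in = dl.state_dict()   # "just finishing" the epoch
sd_out = dl.state_dict()          # "just finished" the epoch

# The expectation: loading sd_in resumes with nothing left to yield in the
# current epoch, while loading sd_out starts a fresh epoch.
```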
Is it possible that extra code is being written now to force loading sd_in and sd_out to result in the same state? To me as a user, that would not be expected behaviour.
🐛 Describe the bug
Loading a state_dict taken from a StatefulDataLoader that has just completed an epoch yields what I would describe as a "just finishing" rather than "just finished" state: the next iteration over that dataloader does nothing (as opposed to running a full new epoch), and it only continues as expected on the epoch after that.
Reproduction: I run a dataloader for two epochs, save the state dict, remake the dataloader, load the state dict, and try to run for two more epochs. I find that only 3 epochs have run in total. They do align with the first 3 epochs of a separate 4-epoch run, but there are 3 of them instead of 4.
(Side comment: the fact that the epochs align here despite my not setting the random seed suggests that the shuffling of the StatefulDataLoader ignores the current state of the random number generators. This behaviour could maybe be made clearer in the documentation, as it is not equivalent to torch.utils.data.DataLoader in this regard.)
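(The original reproduction script was not captured here; the following is a minimal sketch of the steps described above, where the toy dataset, batch size, and the `run_epochs` helper are illustrative assumptions.)

```python
# Hedged sketch of the reproduction: run 2 epochs, save and reload the state,
# then try to run 2 more epochs and see how many batches each one yields.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(8))  # toy dataset (an assumption)

def run_epochs(dl, n):
    """Collect the batches of n consecutive epochs over dl."""
    return [[batch.tolist() for batch in dl] for _ in range(n)]

dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
first = run_epochs(dl, 2)          # epochs 1-2
sd = dl.state_dict()               # state saved right after epoch 2 finished

dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
dl.load_state_dict(sd)
rest = run_epochs(dl, 2)           # expected: 2 more full epochs

for i, epoch in enumerate(first + rest, 1):
    print(f"epoch {i}: {epoch}")
# The reported bug: one of the resumed "epochs" yields no batches, so only
# 3 full epochs of data are produced instead of 4.
```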
expected output:
obtained output:
Versions
PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==2.0.3
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.10.1
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.15.2a0
[conda] msgpack-numpy 0.4.8 pypi_0 pypi
[conda] numpy 1.26.3 py311he598dae_0
[conda] numpy-base 1.26.3 py311hfbfe69c_0
[conda] pytorch 2.2.0 py3.11_0 pytorch
[conda] pytorch-lightning 2.0.3 py311hca03da5_0
[conda] torchaudio 2.2.0 py311_cpu pytorch
[conda] torchdata 0.10.1 pypi_0 pypi
[conda] torchmetrics 1.1.2 py311hca03da5_0
[conda] torchvision 0.15.2 cpu_py311he74fb5d_0