loading a StatefulDataLoader state_dict creates a slightly different one in the dataloader, potentially dropping an epoch #1437

Open
gailweiss opened this issue Feb 3, 2025 · 3 comments · May be fixed by #1439

gailweiss commented Feb 3, 2025

🐛 Describe the bug

Loading a state_dict taken from a StatefulDataLoader that has just completed an epoch yields what I would describe as a "just finishing" rather than a "just finished" state: the next iteration over that dataloader yields nothing (instead of a full new epoch), and only the iteration after that continues as expected.

Reproduction: I run a dataloader for two epochs, save the state dict, remake the dataloader, load the state dict, and try to run for two more epochs. I find that only 3 epochs have run. They do align with the first 3 epochs of a separate 4-epoch run, but there are 3 of them instead of 4.

(Side comment: the fact that they align here despite my not setting a random seed reveals that the shuffling of the StatefulDataLoader ignores the current state of the global random number generators. This behaviour could be made clearer in the documentation, as it is not equivalent to torch.utils.data.DataLoader in this regard.)
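
A minimal sketch of pinning the shuffle order explicitly, under the assumption that StatefulDataLoader accepts the same generator argument as torch.utils.data.DataLoader, as its drop-in design suggests (the full reproduction follows below):

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

# With an explicit, seeded generator the shuffle order does not depend on the
# global RNG state, so the first batch is the same on every run.
g = torch.Generator()
g.manual_seed(0)
dl = StatefulDataLoader(list(range(100)), batch_size=1, shuffle=True, generator=g)
print(next(iter(dl)))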

from torchdata.stateful_dataloader import StatefulDataLoader

def get_dl():
    d = list(range(100))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)

def run_through(dl):
    # iterate one full epoch, printing only the first batch
    for i, b in enumerate(dl):
        if i == 0:
            print(b)

def run_for_goes(goes):
    sd = None
    c = 0
    for n in goes:
        dl = get_dl()

        if sd is not None:
            print("loading state dict:", sd)
            dl.load_state_dict(sd)
            print("recall: loaded:", sd)
            print("state dict is now:", dl.state_dict())

        for j in range(n):
            print(c, j)
            run_through(dl)
            c += 1

        sd = dl.state_dict()

print("===")
run_for_goes([2,2])
print("===")
run_for_goes([4])

expected output:

===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
recall: loaded: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
state dict is now: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
2 0
tensor([62])
3 1
tensor([19])
===
0 0
tensor([45])
1 1
tensor([33])
2 2
tensor([62])
3 3
tensor([19])

obtained output:

===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
recall: loaded: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
state dict is now: {'_index_sampler_state': {'samples_yielded': 0, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 0}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
2 0
3 1
tensor([62])
===
0 0
tensor([45])
1 1
tensor([33])
2 2
tensor([62])
3 3
tensor([19])

Versions

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==2.0.3
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.10.1
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.15.2a0
[conda] msgpack-numpy 0.4.8 pypi_0 pypi
[conda] numpy 1.26.3 py311he598dae_0
[conda] numpy-base 1.26.3 py311hfbfe69c_0
[conda] pytorch 2.2.0 py3.11_0 pytorch
[conda] pytorch-lightning 2.0.3 py311hca03da5_0
[conda] torchaudio 2.2.0 py311_cpu pytorch
[conda] torchdata 0.10.1 pypi_0 pypi
[conda] torchmetrics 1.1.2 py311hca03da5_0
[conda] torchvision 0.15.2 cpu_py311he74fb5d_0

andrewkho (Contributor) commented Feb 3, 2025

Thanks for reporting the issue @gailweiss! I've reproduced this in Colab with the latest stable release. The multiprocessing use case (i.e. num_workers > 0) seems to work correctly. We should have parity between in-memory and multiprocess, so this is a bug, and a gap in our testing.

For context: requesting a state_dict at the end of an epoch is slightly ambiguous, as you mentioned: is it the state at the end of the epoch, or at the beginning of the next one? IMO, for this common use case, loading it should start from the next epoch.
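
To make the intended semantics concrete, a sketch of the invariant (not the actual test in the repo, just an illustration using the same setup as the reproduction above):

from torchdata.stateful_dataloader import StatefulDataLoader

dl = StatefulDataLoader(list(range(100)), batch_size=1, shuffle=True)
list(dl)              # run one full epoch
sd = dl.state_dict()  # snapshot taken after the epoch has finished

dl2 = StatefulDataLoader(list(range(100)), batch_size=1, shuffle=True)
dl2.load_state_dict(sd)
# Expected: resuming from an end-of-epoch snapshot yields a full new epoch,
# not an empty pass. This is what currently fails when num_workers=0.
assert len(list(dl2)) == 100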

andrewkho (Contributor) commented:

Here is the check I used:

from torchdata.stateful_dataloader import StatefulDataLoader

def get_dl(num_workers):
    d = list(range(100))
    return StatefulDataLoader(d, batch_size=1, shuffle=True, num_workers=num_workers)

def run_through(dl):
    for i, b in enumerate(dl):
        if i == 0:
            print(b)

def run_for_goes(goes, num_workers: int):
    sd = None
    c = 0
    for n in goes:
        dl = get_dl(num_workers)
        if sd is not None:
            print("loading state dict:", sd)
            dl.load_state_dict(sd)
            print("recall: loaded:", sd)
            print("state dict is now:", dl.state_dict())

        for j in range(n):
            print(c, j)
            run_through(dl)
            c += 1
    
        sd = dl.state_dict()

print("===")
run_for_goes([2,2], 0)
print("===")
run_for_goes([2,2], 1)

Output:

===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
recall: loaded: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
state dict is now: {'_index_sampler_state': {'samples_yielded': 0, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 0}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
2 0
3 1
tensor([62])
===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_snapshot': {'_snapshot_step': 100, '_last_yielded_worker_id': 0, '_main_snapshot': {'_num_workers': 1, '_sampler_iter_state': None, '_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, '_base_seed': 3018434745330552193}, '_worker_snapshots': {'worker_0': {'worker_id': 0, 'dataset_state': None, 'fetcher_state': None}}}, '_steps_since_snapshot': 0, '_iterator_finished': True}
recall: loaded: {'_snapshot': {'_snapshot_step': 100, '_last_yielded_worker_id': 0, '_main_snapshot': {'_num_workers': 1, '_sampler_iter_state': None, '_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, '_base_seed': 3018434745330552193}, '_worker_snapshots': {'worker_0': {'worker_id': 0, 'dataset_state': None, 'fetcher_state': None}}}, '_steps_since_snapshot': 0, '_iterator_finished': True}
state dict is now: {'_snapshot': {'_snapshot_step': 100, '_last_yielded_worker_id': 0, '_main_snapshot': {'_num_workers': 1, '_sampler_iter_state': None, '_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, '_base_seed': 3018434745330552193}, '_worker_snapshots': {'worker_0': {'worker_id': 0, 'dataset_state': None, 'fetcher_state': None}}}, '_steps_since_snapshot': 0, '_iterator_finished': True}
2 0
tensor([62])
3 1
tensor([19])

gailweiss (Author) commented:

Thank you both for working on this!

Regarding the ambiguity at the end of an epoch, I think it should be resolved by whether the state dict is taken from inside or outside the finished loop over the dataloader. Unfortunately I am on my phone right now, so I can only present code that I am not sure will actually run:

dl = get_dl()
for i, b in enumerate(dl):
    sd_in = dl.state_dict()
# sd_in is now dl's last state dict before the iteration ended

sd_out = dl.state_dict()

# IMO sd_out and sd_in should be different now

dl.load_state_dict(sd_in)
for i, b in enumerate(dl):
    print("not supposed to be here - dl was finishing")

dl.load_state_dict(sd_out)
print("successful continuation after finished epoch:", len([b.item() for b in dl]) == 100)

Is it possible that extra code is now being written to force loading sd_in and sd_out to result in the same state? To me as a user, that would not be expected behaviour.
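
Until the fix lands, one possible user-side pattern, sketched under the assumption that the state_dict was saved right after an epoch completed (get_dl as in the original reproduction; sd is such an end-of-epoch state_dict), is to tolerate the spurious empty pass when resuming:

dl = get_dl()
dl.load_state_dict(sd)  # sd was taken just after an epoch finished

# The "just finishing" state produces one empty pass before resuming normally
# (as described above), so iterate again if the first pass yields nothing.
batches = list(dl)
if not batches:
    batches = list(dl)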
