
Model diverges or struggles to converge with complex-valued tensors in DDP #20480

Open
ouioui199 opened this issue Dec 9, 2024 · 3 comments
Labels: 3rd party, bug, ver: 2.4.x


@ouioui199

Bug description

Hello,

I am using Lightning to train complex-valued neural networks with complex-valued tensors. With single-GPU training there is no issue. When I train on multiple GPUs with DDP, training diverges. I also tried training on only one GPU while still setting strategy='ddp' in the Trainer, and training diverges there too.

I've tried to reproduce the issue with the code sample below. The MNIST dataset and the model defined in this sample are simpler than those in my actual work, so here the model does not diverge but struggles to converge. To check whether the issue occurs, just comment out the strategy line in the Trainer (the sample uses strategy='ddp_find_unused_parameters_true').
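Instead of commenting the strategy line in and out, the two runs could be selected from the command line. A minimal sketch, assuming a hypothetical `--ddp` flag and a hypothetical helper `build_trainer_kwargs` (neither is part of the original sample) that returns extra keyword arguments for `L.Trainer`. Pinning `devices=1` in the non-DDP branch is also an assumption, meant to stop Lightning's automatic strategy selection from picking a multi-device strategy on its own on a multi-GPU machine.

```python
import argparse
from typing import Any, Dict, List, Optional


def build_trainer_kwargs(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: extra L.Trainer kwargs for the DDP vs. single-device comparison."""
    parser = argparse.ArgumentParser(description="complex-valued MNIST DDP reproduction")
    parser.add_argument("--ddp", action="store_true", help="train with the DDP strategy")
    args = parser.parse_args(argv)

    if args.ddp:
        # Same strategy string as used in the reproduction script.
        return {"strategy": "ddp_find_unused_parameters_true"}
    # Assumption: restrict the baseline run to a single device.
    return {"devices": 1}
```

In the script, the Trainer would then be built as `L.Trainer(max_epochs=epochs, **build_trainer_kwargs(), ...)`, and the two runs launched with and without `--ddp`.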

This seems to be related to #55375 and #60931

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from typing import List

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2_transforms
import lightning as L
import torchcvnn.nn as c_nn
from torchmetrics.classification import Accuracy
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
from lightning.pytorch.utilities import rank_zero_only


def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    return [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    ]


class TBLogger(TensorBoardLogger):
    @rank_zero_only
    def log_metrics(self, metrics, step):
        metrics.pop('epoch', None)
        metrics = {k: v for k, v in metrics.items() if ('step' not in k) and ('val' not in k)}
        return super().log_metrics(metrics, step)
    
    
class CustomProgressBar(TQDMProgressBar):
    
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items
    
    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        bar = super().init_train_tqdm()
        bar.ascii = ' >'
        return bar
    
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.ascii = ' >'
        return bar


class cMNISTModel(L.LightningModule):

    def __init__(self):
        super().__init__()

        self.ce_loss = nn.CrossEntropyLoss()
        self.model = self.configure_model()
        self.accuracy = Accuracy(task='multiclass', num_classes=10)
        
        self.train_step_outputs = {}
        self.valid_step_outputs = {}

    def configure_model(self):
        conv_model = nn.Sequential(
            *conv_block(1, 16, torch.complex64),
            *conv_block(16, 16, torch.complex64),
            *conv_block(16, 32, torch.complex64),
            *conv_block(32, 32, torch.complex64),
            nn.Flatten(),
        )

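        with torch.no_grad():
            conv_model.eval()
            dummy_input = torch.zeros((64, 1, 28, 28), dtype=torch.complex64, requires_grad=False)
            out_conv = conv_model(dummy_input).view(64, -1)
        # The dummy pass only probes the flattened size; put the block back in training mode.
        conv_model.train()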
        lin_model = nn.Sequential(
            nn.Linear(out_conv.shape[-1], 124, dtype=torch.complex64),
            c_nn.Cardioid(),
            nn.Linear(124, 10, dtype=torch.complex64),
            c_nn.Mod(),
        )

        return nn.Sequential(conv_model, lin_model)
    
    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.parameters(), lr=3e-4)
    
    def training_step(self, batch, batch_idx):
        data, label = batch
        logits = self(data)

        loss = self.ce_loss(logits, label)
        acc = self.accuracy(logits, label)

        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
        
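        # Store detached values so each step's autograd graph is not kept alive until the epoch ends.
        if not self.train_step_outputs:
            self.train_step_outputs = {
                'step_loss': [loss.detach()],
                'step_metrics': [acc.detach()]
            }
        else:
            self.train_step_outputs['step_loss'].append(loss.detach())
            self.train_step_outputs['step_metrics'].append(acc.detach())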

        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        images, labels = batch
        logits = self(images)

        loss = self.ce_loss(logits, labels)
        acc = self.accuracy(logits, labels)
        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
        
        if not self.valid_step_outputs:
            self.valid_step_outputs = {
                'step_loss': [loss],
                'step_metrics': [acc]
            }
        else:
            self.valid_step_outputs['step_loss'].append(loss)
            self.valid_step_outputs['step_metrics'].append(acc)

    def on_train_epoch_end(self) -> None:
        _log_dict = {
            'Loss/loss': torch.stack(self.train_step_outputs['step_loss']).mean(),
            'Metrics/accuracy': torch.stack(self.train_step_outputs['step_metrics']).mean()
        }
        
        self.loggers[0].log_metrics(_log_dict, self.current_epoch)
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self) -> None:
        mean_loss_value = torch.stack(self.valid_step_outputs['step_loss']).mean()
        mean_metrics_value = torch.stack(self.valid_step_outputs['step_metrics']).mean()
        
        _log_dict = {
            'Loss/loss': mean_loss_value,
            'Metrics/accuracy': mean_metrics_value
        }
        
        self.loggers[1].log_metrics(_log_dict, self.current_epoch)
        
        self.log('val_loss', mean_loss_value, sync_dist=True)
        self.log('val_Accuracy', mean_metrics_value, sync_dist=True)
        self.valid_step_outputs.clear()


def train():
    batch_size = 64
    epochs = 10
    torch.set_float32_matmul_precision('high')

    # Dataloading
    train_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )
    valid_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )

    # Train dataloader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True
    )

    # Valid dataloader
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True
    )

    model = cMNISTModel()
    trainer = L.Trainer(
        max_epochs=epochs,
        strategy='ddp_find_unused_parameters_true',
        num_sanity_val_steps=0,
        benchmark=True,
        enable_checkpointing=True,
        callbacks=[
            CustomProgressBar(),
            EarlyStopping(
                monitor='val_loss', 
                verbose=True,
                patience=5,
                min_delta=0.005
            ),
            LearningRateMonitor(logging_interval='epoch'),
            ModelCheckpoint(
                dirpath='weights_storage_/',
                monitor='val_Accuracy', 
                verbose=True, 
                mode='max'
            )
        ],
        logger=[
            TBLogger('training_logs_', name=None, sub_dir='train'),
            TBLogger('training_logs_', name=None, sub_dir='valid')
        ]
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
    

if __name__ == "__main__":
    train()
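The configure_model method in the sample infers the input size of the first nn.Linear with a dummy forward pass. As a cross-check, that size can also be computed by hand. A minimal sketch, assuming `c_nn.AvgPool2d` uses the same floor-based output-size formula as `torch.nn.AvgPool2d` with `ceil_mode=False`; the helper names are hypothetical.

```python
from typing import List


def conv_out(n: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    # Conv2d output-size formula (dilation 1); kernel 3, stride 1, padding 1 keeps the size.
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n: int, kernel: int = 2, stride: int = 2, padding: int = 0) -> int:
    # Floor-mode pooling output size (assumed to match c_nn.AvgPool2d).
    return (n + 2 * padding - kernel) // stride + 1


def flattened_size(side: int, channels: List[int]) -> int:
    """Flattened feature count after one conv_block per entry in `channels`."""
    for _ in channels:
        side = conv_out(conv_out(side))  # two 3x3 convolutions per block
        side = pool_out(side)            # one 2x2 average pooling per block
    return channels[-1] * side * side
```

For a 28x28 input, the spatial side goes 28 → 14 → 7 → 3 → 1 through the four blocks, so `flattened_size(28, [16, 16, 32, 32])` gives 32 * 1 * 1 = 32 features, which is the value `out_conv.shape[-1]` should report.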

Error messages and logs

No response

Environment

Current environment
- PyTorch Lightning version: 2.4.0
- PyTorch version: 2.5.1
- Python version: 3.12.7
- OS: Linux (Ubuntu 24.04.1, or a Slurm cluster)
- CUDA/cuDNN version: 12.4
- GPU models and configuration: RTX 4090 (Ubuntu PC), NVIDIA A100 40G (Slurm)
- How you installed Lightning: pip

More info

@jeremyfix @QuentinGABOT might also be interested in this issue

ouioui199 added the bug and needs triage labels on Dec 9, 2024
@lantiga
Collaborator

lantiga commented Dec 11, 2024

Thank you for the report. I don't have first-hand experience with complex numbers combined with distributed training, but judging from pytorch/pytorch#55375, it does look like there is an upstream issue with DDP not behaving correctly with complex-valued tensors.

lantiga added the 3rd party label and removed the needs triage label on Dec 11, 2024
@ouioui199
Author

I debugged the code. It seems that DDP does not convert the tensors with view_as_real as described in pytorch/pytorch#55375: all tensors and model parameters keep dtype torch.complex64. But I don't know what happens during the distributed process that causes this.
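Whether or not DDP goes through a view_as_real path, the averaging itself should not change the result: DDP synchronizes gradients by averaging them across ranks, and averaging complex numbers element-wise is the same as averaging their real and imaginary parts separately. A plain-Python sketch of that identity, assuming nothing about PyTorch's actual reducer (no torch.distributed is used; `view_as_real` below is a hypothetical stand-in for `torch.view_as_real`, and the helpers are illustrative only). One consequence: with a world size of 1 the average is the identity, so in the single-GPU strategy='ddp' run the averaging step alone should leave the gradients unchanged.

```python
from typing import List, Sequence, Tuple


def view_as_real(grad: Sequence[complex]) -> List[Tuple[float, float]]:
    """Plain-Python stand-in for torch.view_as_real: each complex value -> (real, imag)."""
    return [(z.real, z.imag) for z in grad]


def mean_complex(grads_per_rank: Sequence[Sequence[complex]]) -> List[complex]:
    """Element-wise average of complex gradients across simulated ranks."""
    world_size = len(grads_per_rank)
    return [sum(values) / world_size for values in zip(*grads_per_rank)]


def mean_via_real_view(grads_per_rank: Sequence[Sequence[complex]]) -> List[complex]:
    """Average the (real, imag) views separately, then rebuild complex values."""
    world_size = len(grads_per_rank)
    real_views = [view_as_real(g) for g in grads_per_rank]
    averaged = []
    for pairs in zip(*real_views):
        re = sum(p[0] for p in pairs) / world_size
        im = sum(p[1] for p in pairs) / world_size
        averaged.append(complex(re, im))
    return averaged
```

Comparing `mean_complex` and `mean_via_real_view` on the same inputs should give matching values up to floating-point tolerance.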

@ouioui199
Author

Hello, is there any update on this?
