Skip to content

Commit

Permalink
add test_pix2pix_components
Browse files Browse the repository at this point in the history
  • Loading branch information
Jungwon-Lee committed Sep 18, 2022
1 parent 2059bb0 commit c2d9426
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pl_bolts/models/gans/pix2pix/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,23 @@ def forward(self, x: Tensor) -> Tensor:

class DownSampleConv(nn.Module):
def __init__(
<<<<<<< HEAD
self,
in_channels: int,
out_channels: int,
kernel: int = 4,
strides: int = 2,
padding: int = 1,
batchnorm: bool = True,
=======
self,
in_channels: int,
out_channels: int,
kernel: int = 4,
strides: int = 2,
padding: int = 1,
batchnorm: bool = True
>>>>>>> 9a8ce2f (add test_pix2pix_components)
) -> None:
super().__init__()
layers = [nn.Conv2d(in_channels, out_channels, kernel, strides, padding)]
Expand Down
41 changes: 41 additions & 0 deletions tests/models/gans/unit/test_pix2pix_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import torch
from pytorch_lightning import seed_everything

from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN


@pytest.mark.parametrize(
"in_shape, out_shape",
[
pytest.param((3, 128, 128), (3, 128, 128), id="multichannel"),
pytest.param((1, 128, 128), (3, 128, 128), id="singlechannel"),
],
)
def test_generator(catch_warnings, in_shape, out_shape):
batch_dim = 10
in_channels = in_shape.size(0)
out_channels = out_shape.size(0)
seed_everything(1234)
generator = Generator(in_channels=in_channels, out_channels=out_channels)
conditional_image = torch.randn(batch_dim, in_shape)
samples = generator(conditional_image)
assert samples.shape == (batch_dim, *out_shape)


@pytest.mark.parametrize(
"img_shape",
[
pytest.param((3, 128, 128), id="discriminator-multichannel"),
pytest.param((1, 128, 128), id="discriminator-singlechannel"),
],
)
def test_discriminator(catch_warnings, img_shape):
batch_dim = 10
in_channels = img_shape.size(0)
seed_everything(1234)
discriminator = PatchGAN(input_channels=in_channels)
samples = torch.randn(batch_dim, *img_shape)
real_or_fake = discriminator(samples)
assert real_or_fake.shape == (batch_dim, 1)
assert (torch.clamp(real_or_fake.clone(), 0, 1) == real_or_fake).all()

0 comments on commit c2d9426

Please sign in to comment.