Skip to content

Commit

Permalink
address inconsistent ordering of fake/real logits going into losses for dual contrastive loss, #289
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 23, 2024
1 parent e47fafa commit 5ff7b57
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def slerp(val, low, high):
def gen_hinge_loss(fake, real):
    """Generator-side hinge loss.

    Only the fake logits contribute; ``real`` is accepted so the
    signature matches the discriminator loss functions used by the
    trainer (which call every loss as ``loss_fn(fake, real)``).
    """
    loss = fake.mean()
    return loss

def hinge_loss(fake, real):
    """Discriminator-side hinge loss.

    Argument order is (fake, real) so all loss functions share one
    calling convention (see gen_hinge_loss). Per the expression below,
    the loss is zero once real logits fall below -1 and fake logits
    rise above +1 — note this sign convention is specific to this
    codebase.

    NOTE(review): the scraped diff left both the old ``(real, fake)``
    and new ``(fake, real)`` signatures; only the committed version is
    kept here.
    """
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

def dual_contrastive_loss(real_logits, fake_logits):
def dual_contrastive_loss(fake_logits, real_logits):
device = real_logits.device
real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

def loss_half(t1, t2):
t1 = rearrange(t1, 'i -> i ()')
t1 = rearrange(t1, 'i -> i 1')
t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
t = torch.cat((t1, t2), dim = -1)
return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
Expand Down Expand Up @@ -1043,7 +1043,7 @@ def train(self):
real_output_loss = real_output_loss - fake_output.mean()
fake_output_loss = fake_output_loss - real_output.mean()

divergence = D_loss_fn(real_output_loss, fake_output_loss)
divergence = D_loss_fn(fake_output_loss, real_output_loss)
disc_loss = divergence

if self.has_fq:
Expand Down
2 changes: 1 addition & 1 deletion stylegan2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.8.10'
__version__ = '1.8.11'

0 comments on commit 5ff7b57

Please sign in to comment.