Skip to content

Commit

Permalink
add ability to save individual frames of interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 28, 2020
1 parent ad5d07a commit e40eb6f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ To generate a video of a interpolation through two random points in latent space
$ stylegan2_pytorch --generate-interpolation
```

To save each individual frame of the interpolation

```bash
$ stylegan2_pytorch --generate-interpolation --save-frames
```

If a previous checkpoint contained a better generator, (which often happens as generators start degrading towards the end of training), you can load from a previous checkpoint with another flag

```bash
Expand Down
3 changes: 2 additions & 1 deletion bin/stylegan2_pytorch
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def train_from_folder(
save_every = 1000,
generate = False,
generate_interpolation = False,
save_frames = False,
num_image_tiles = 8,
trunc_psi = 0.75,
fp16 = False,
Expand Down Expand Up @@ -72,7 +73,7 @@ def train_from_folder(
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
samples_name = f'generated-{timestamp}'
model.generate_interpolation(samples_name, num_image_tiles)
model.generate_interpolation(samples_name, num_image_tiles, save_frames = save_frames)
print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
return

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'stylegan2_pytorch',
packages = find_packages(),
scripts=['bin/stylegan2_pytorch'],
version = '0.17.1',
version = '0.17.2',
license='GPLv3+',
description = 'StyleGan2 in Pytorch',
author = 'Phil Wang',
Expand Down
10 changes: 8 additions & 2 deletions stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles
return generated_images.clamp_(0., 1.)

@torch.no_grad()
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0):
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save_frames = False):
self.GAN.eval()
ext = 'jpg' if not self.transparent else 'png'
num_rows = num_image_tiles
Expand All @@ -907,11 +907,17 @@ def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0):
latents = [(interp_latents, num_layers)]
generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
pil_image = transforms.ToPILImage()(images_grid.cpu()).convert('RGB')
pil_image = transforms.ToPILImage()(images_grid.cpu())
frames.append(pil_image)

frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

if save_frames:
folder_path = (self.results_dir / self.name / f'{str(num)}')
folder_path.mkdir(parents=True, exist_ok=True)
for ind, frame in enumerate(frames):
frame.save(str(folder_path / f'{str(ind)}.{ext}'))

def print_log(self):
print(f'G: {self.g_loss:.2f} | D: {self.d_loss:.2f} | GP: {self.last_gp_loss:.2f} | PL: {self.pl_mean:.2f} | CR: {self.last_cr_loss:.2f} | Q: {self.q_loss:.2f}')

Expand Down

0 comments on commit e40eb6f

Please sign in to comment.