This repository implements a VQVAE for MNIST and a colored version of MNIST, followed by a simple LSTM for generating numbers.
![VQVAE Video](https://private-user-images.githubusercontent.com/144267687/302552988-a411d732-8c99-41fb-b39c-dd2c3fbfa448.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMzAyNTUyOTg4LWE0MTFkNzMyLThjOTktNDFmYi1iMzljLWRkMmMzZmJmYTQ0OC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02NmU2ZjI4ZGM2YWU4YTNlYWE4Yzk2ZDYxMDNlOGFmNWY1YmUzNDk3ZDY2NmQwNjk3YmU4NGRmZDc4NjEzYjMzJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.ObsBhxt-1XhKoUFaxYbPi4CXVfJbWZr8c1B54hCjGQM)
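At its core, the VQVAE replaces the encoder output with the nearest vectors from a learned codebook. Below is a minimal, illustrative sketch of that vector-quantization step (nearest-codebook lookup, straight-through gradients, and the commitment loss from the paper). It is not the exact module used in this repo, and the codebook sizes are placeholder assumptions.

```python
# Minimal sketch of the VQ-VAE quantization step (illustrative only,
# not this repo's actual module). Sizes below are placeholder assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        self.beta = beta  # commitment loss weight

    def forward(self, z_e):
        # z_e: encoder output of shape (B, C, H, W) with C == embedding_dim
        B, C, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, C)             # (B*H*W, C)
        # Squared L2 distance to every codebook vector
        dist = torch.cdist(flat, self.embedding.weight) ** 2      # (B*H*W, K)
        indices = dist.argmin(dim=1)                              # nearest code per spatial position
        z_q = self.embedding(indices).view(B, H, W, C).permute(0, 3, 1, 2)
        # Codebook loss + commitment loss
        loss = F.mse_loss(z_q, z_e.detach()) + self.beta * F.mse_loss(z_e, z_q.detach())
        # Straight-through estimator: copy gradients from z_q back to z_e
        z_q = z_e + (z_q - z_e).detach()
        return z_q, loss, indices.view(B, H, W)
```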
- Create a new conda environment with python 3.8, then run the commands below:
```
git clone https://github.com/explainingai-code/VQVAE-Pytorch.git
cd VQVAE-Pytorch
pip install -r requirements.txt
```
- For running a simple VQVAE with minimal code to understand the basics:
```
python run_simple_vqvae.py
```
- For playing around with the VQVAE and training/inferencing the LSTM, use the commands below, passing the desired configuration file as the config argument:
  - `python -m tools.train_vqvae` for training the VQVAE
  - `python -m tools.infer_vqvae` for generating reconstructions and encoder outputs for LSTM training
  - `python -m tools.train_lstm` for training the minimal LSTM
  - `python -m tools.generate_images` for using the trained LSTM to generate some numbers
- `config/vqvae_mnist.yaml` - VQVAE for training on black and white MNIST images
- `config/vqvae_colored_mnist.yaml` - VQVAE with more embedding vectors for training on colored MNIST images
For setting up the dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation

Verify the data directory has the following structure:
```
VQVAE-Pytorch/data/train/images/{0/1/.../9}
    *.png
VQVAE-Pytorch/data/test/images/{0/1/.../9}
    *.png
```
Outputs will be saved according to the configuration present in the yaml files. For every run, a folder named after the `task_name` key in the config will be created, and `output_train_dir` will be created inside it.
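As a rough illustration (not the repo's actual code), the output locations described above can be resolved from those two config keys along these lines; the exact nesting of the keys inside the yaml files may differ, so adjust the lookups accordingly.

```python
# Illustrative sketch only: deriving the output directory from the yaml config.
# `task_name` and `output_train_dir` are keys mentioned in this README; their
# exact nesting in the real yaml files is an assumption here.
import os
import yaml

with open("config/vqvae_mnist.yaml") as f:
    config = yaml.safe_load(f)

# Assumed flat access for illustration; the real config may nest these keys.
task_name = config.get("task_name", "vqvae_mnist")
output_train_dir = config.get("output_train_dir", "output")

out_dir = os.path.join(task_name, output_train_dir)
os.makedirs(out_dir, exist_ok=True)
print("Reconstructions would be written to", os.path.join(out_dir, "reconstruction.png"))
```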
During training of the VQVAE the following output will be saved:
- Best model checkpoints (VQVAE and LSTM) in the `task_name` directory

During inference the following output will be saved:
- Reconstructions for a sample of the test set in `task_name/output_train_dir/reconstruction.png`
- Encoder outputs on the train set for LSTM training in `task_name/output_train_dir/mnist_encodings.pkl` (see the sketch after this list)
- LSTM generation output in `task_name/output_train_dir/generation_results.png`
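To give a sense of how the saved encodings feed the second stage, here is a rough, hypothetical sketch of next-index prediction with an LSTM over the stored codebook indices. The internal layout of `mnist_encodings.pkl`, the path, and all hyperparameters below are assumptions, not this repo's implementation; see `tools/train_lstm` for the real one.

```python
# Rough sketch (assumptions marked): an autoregressive LSTM over discrete
# codebook indices, the idea behind tools.train_lstm / tools.generate_images.
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F

# Substitute the actual task_name / output_train_dir values from your config.
with open("task_name/output_train_dir/mnist_encodings.pkl", "rb") as f:
    encodings = pickle.load(f)  # assumed: per-image grids of codebook indices, shape (N, H, W)

seqs = torch.as_tensor(encodings).long().reshape(len(encodings), -1)  # flatten each code grid

num_codes = int(seqs.max()) + 1
embed = nn.Embedding(num_codes, 128)          # assumed embedding size
lstm = nn.LSTM(128, 256, batch_first=True)    # assumed hidden size
head = nn.Linear(256, num_codes)
params = list(embed.parameters()) + list(lstm.parameters()) + list(head.parameters())
opt = torch.optim.Adam(params, lr=1e-3)

for epoch in range(10):
    for i in range(0, len(seqs), 64):
        batch = seqs[i:i + 64]
        inp, target = batch[:, :-1], batch[:, 1:]         # next-index prediction
        out, _ = lstm(embed(inp))
        loss = F.cross_entropy(head(out).reshape(-1, num_codes), target.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
```

Generation then samples indices one at a time from the trained LSTM and passes the resulting code grid through the VQVAE decoder to produce an image.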
Running `run_simple_vqvae` should be very quick (as it is a very simple model) and should give you the reconstructions below (input on a black background and reconstruction on a white background).
![](https://private-user-images.githubusercontent.com/144267687/273195458-607fb5a8-b880-4af5-8ce0-5d7127aa66a7.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NDU4LTYwN2ZiNWE4LWI4ODAtNGFmNS04Y2UwLTVkNzEyN2FhNjZhNy5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT00NTcwMTg4NTEwOGQ2NzdjYmFjYWRjNmNiZDdmZWIwZTEzZTQ2ZWQ1YzQ2NzlhY2FlZjE4NjA3N2RlNTVmZDFiJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.dsZGrazVI0mpiknaggrqy5-JYVquUa2IbDiQo0PDFnU)
Running the default config VQVAE for MNIST should give you the reconstructions below for both versions.
![](https://private-user-images.githubusercontent.com/144267687/273195513-939f8f22-0145-467f-8cd6-4b6c6e6f315f.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NTEzLTkzOWY4ZjIyLTAxNDUtNDY3Zi04Y2Q2LTRiNmM2ZTZmMzE1Zi5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT04NDFmMmY4MWM0ZDdjZGE5MjdkYTEyYjllOTA5ZjRhNjNjNGU1MTY0Y2RhMTc2YTFmZDU4Nzk5Y2JmM2MzMTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.f5ylfl88BmnXTK6eF2YUJkSGKIET8wysiWzb19ESUlQ)
![](https://private-user-images.githubusercontent.com/144267687/273195627-0e28286a-bc4c-44e3-a385-84d1ae99492c.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NjI3LTBlMjgyODZhLWJjNGMtNDRlMy1hMzg1LTg0ZDFhZTk5NDkyYy5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1iMjI1NmU0NzU1YTczYjM0MmQ2M2FmMDVmZGEzMDFjZTFjODUzNGUzYzk1NzNhMTg1ZDIyMGVmYTRmNmM3NDM0JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.xSe_bgrTSdQ5IEbf818HFycR-30ewahbcHO9-qnZmbA)
Sample generation output after just 10 epochs of training. Training the VQVAE and LSTM for longer and with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.) will give better results.
![](https://private-user-images.githubusercontent.com/144267687/273196494-688a6631-df34-4fde-9508-a05ae3c2ae91.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk2NDk0LTY4OGE2NjMxLWRmMzQtNGZkZS05NTA4LWEwNWFlM2MyYWU5MS5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1jMTMzNDAxYzlhNDdhNDZjMTg3ODQ5NjRiZTM4NDQyMTdkNzgxZTBiZTE0OGM0YjY0MjQ0MDJjZTY5YmE2YmYxJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.030HJbmIPtQoppJcouTJxPPZc4t8IPc_Sp-ak6_5NrU)
![](https://private-user-images.githubusercontent.com/144267687/273203035-187fa630-a7ef-4f0b-aef7-5c6b53019b38.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2MTQyMDIsIm5iZiI6MTczOTYxMzkwMiwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMjAzMDM1LTE4N2ZhNjMwLWE3ZWYtNGYwYi1hZWY3LTVjNmI1MzAxOWIzOC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE1JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNVQxMDA1MDJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT04MmJhN2QwNjcxMDZkOTViZGYyOTY0YWU4MzQwMGQwNzkyZTdlY2QzNmQ2NTVjYjhlMjY0MDYxYThhODk4Mjc0JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.eeuVNGVXEWBh0obfiVqLGc5lUafZV7auG_s0H9pBYqk)
```
@misc{oord2018neural,
      title={Neural Discrete Representation Learning},
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```