Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View
Xuanchi Ren*, Tao Yang*, Yuwang Wang and Wenjun Zeng
ICLR 2022
* indicates equal contribution
✅ Update StyleGAN2
✅ Update SNGAN
🔲 Update VAE
🔲 Update Glow
✅ Evaluation
In this repo, we propose an unsupervised and model-agnostic method: Disentanglement via Contrast (DisCo) in the Variation Space. The code discovers disentangled directions in the latent space and extracts disentangled representations from images with Contrastive Learning. DisCo achieves state-of-the-art disentanglement given pretrained non-disentangled generative models, including GANs, VAEs, and Flows.
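For intuition, below is a minimal, self-contained PyTorch sketch of the idea, not the actual implementation in this repo: a frozen pretrained generator G, a learnable matrix A of latent-space directions, and an encoder E that maps image changes into the Variation Space, where an InfoNCE-style loss pulls together variations produced by the same direction and pushes apart variations produced by different directions. All modules, dimensions, and the temperature below are illustrative placeholders.

```python
import torch
import torch.nn.functional as F

# Illustrative placeholders: in the repo, G is a pretrained generator (e.g. StyleGAN2)
# and E is the DisCo encoder; here they are tiny linear maps so the sketch runs as-is.
latent_dim, n_dirs, feat_dim, img_dim, batch = 64, 8, 32, 256, 16
G = torch.nn.Linear(latent_dim, img_dim)                  # frozen pretrained generator (stand-in)
E = torch.nn.Linear(img_dim, feat_dim)                    # encoder into the Variation Space
A = torch.nn.Parameter(torch.randn(n_dirs, latent_dim))   # learnable latent-space directions

z = torch.randn(batch, latent_dim)                        # random latent codes
k = torch.randint(0, n_dirs, (batch,))                    # which direction each sample traverses
eps = torch.randn(batch, 1)                               # traversal magnitude
z_shift = z + eps * A[k]                                  # edit each code along its chosen direction

# A "variation" is the feature difference between the edited and the original image.
v = F.normalize(E(G(z_shift)) - E(G(z)), dim=1)

# InfoNCE-style contrast: variations from the same direction are positives,
# variations from different directions are negatives.
sim = v @ v.t() / 0.1                                     # temperature-scaled similarities
pos = (k[:, None] == k[None, :]).float() - torch.eye(batch)
loss = -((F.log_softmax(sim, dim=1) * pos).sum(1) / pos.sum(1).clamp(min=1)).mean()
loss.backward()                                           # updates A and E (G stays frozen in practice)
```

In the full method, the learned directions in A play the role of the discovered disentangled directions, and the encoder E provides the disentangled representation of an image.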
NOTE: The following results are obtained in a completely unsupervised manner. More results (including VAE and Flow) are presented in the Appendix of the paper.
Discovered directions (traversal image grids omitted here; see the paper for visualizations):
- FFHQ StyleGAN2: Pose, Smile, Race, Oldness, Overexpose, Hair
- Shapes3D StyleGAN2: Wall Color, Floor Color, Object Color, Pose
- Car3D StyleGAN2: Azimuth, Yaw
- Anime SNGAN: Pose, Natureness, Glass, Tone
NOTE: DisCo achieves state-of-the-art disentanglement in terms of MIG and DCI on Shapes3D, Car3D, and MPI3D (comparison figures omitted here; see the paper for the full numbers).
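The scores are computed with Google's disentanglement_lib (linked at the bottom of this README). For reference, MIG (Mutual Information Gap) is, per ground-truth factor, the gap in mutual information between the two most informative latent dimensions, normalized by the factor's entropy. Below is a rough NumPy/scikit-learn sketch assuming discretized latents; it is not the implementation used for the reported scores.

```python
import numpy as np
from sklearn.metrics import mutual_info_score

def mig(latents, factors, n_bins=20):
    """Rough MIG sketch: latents (N, D) continuous codes, factors (N, K) discrete factors."""
    # Discretize each latent dimension into equal-width bins.
    binned = np.stack(
        [np.digitize(latents[:, d], np.histogram_bin_edges(latents[:, d], n_bins)[:-1])
         for d in range(latents.shape[1])], axis=1)
    gaps = []
    for k in range(factors.shape[1]):
        mi = np.array([mutual_info_score(factors[:, k], binned[:, d])
                       for d in range(binned.shape[1])])
        h = mutual_info_score(factors[:, k], factors[:, k])   # entropy of the factor
        top2 = np.sort(mi)[-2:]                               # two most informative latents
        gaps.append((top2[1] - top2[0]) / max(h, 1e-12))      # normalized gap
    return float(np.mean(gaps))
```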
- NVIDIA GPU + CUDA CuDNN
- Python 3
- Clone the repository:
git clone https://github.com/xrenaa/DisCo.git
cd DisCo
- Dependencies (To Do):
We recommend running this repository using Anaconda.
- Docker:
Alternatively, you can use Docker to run the code. We provide the image thomasyt/gan-disc for easy use.
Please download the pre-trained models from the following links and put them in the corresponding paths.
| Path | Description |
|---|---|
| shapes3d_StyleGAN | StyleGAN2 models pretrained on Shapes3D (checkpoints 0.pt to 4.pt). Put them under ./pretrained_weights/shapes3d/. |
| cars3d_StyleGAN | StyleGAN2 models pretrained on Cars3D (checkpoints 0.pt to 4.pt). Put them under ./pretrained_weights/cars3d/. |
| mpi3d_StyleGAN | StyleGAN2 models pretrained on MPI3D (checkpoints 0.pt to 4.pt). Put them under ./pretrained_weights/mpi3d/. |
| shapes3d_VAE | VAE models pretrained on Shapes3D (checkpoints VAE_0 to VAE_4). Put them under ./pretrained_weights/shapes3d/. |
| cars3d_VAE | VAE models pretrained on Cars3D (checkpoints VAE_0 to VAE_4). Put them under ./pretrained_weights/cars3d/. |
| mpi3d_VAE | VAE models pretrained on MPI3D (checkpoints VAE_0 to VAE_4). Put them under ./pretrained_weights/mpi3d/. |
For SNGAN, you can run the following code to download the weights for MNIST and Anime:
python ./pretrained_weights/download.py
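Optionally, you can sanity-check that the StyleGAN2 checkpoints are where the table above expects them before training. This snippet only verifies that the files exist and can be loaded; it does not inspect their contents.

```python
import os
import torch

# Optional sanity check: confirm the StyleGAN2 checkpoints listed above are in place.
for dataset in ["shapes3d", "cars3d", "mpi3d"]:
    for i in range(5):                                   # checkpoints 0.pt ... 4.pt
        path = f"./pretrained_weights/{dataset}/{i}.pt"
        if os.path.exists(path):
            ckpt = torch.load(path, map_location="cpu")  # load on CPU, no GPU needed
            print(path, "OK", type(ckpt).__name__)
        else:
            print(path, "missing")
```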
To train the models, make sure you have downloaded the required models and put them in the correct paths.
python train.py \
--G stylegan \
--dataset 0 \
--exp_name your_name \
--B 32 \
--N 32 \
--K 64
For --dataset, choose 0 for Shapes3D, 1 for MPI3D, or 2 for Cars3D.
python train.py \
--G sngan \
--dataset 5 \
--exp_name your_name \
--B 32 \
--N 32 \
--K 64
For --dataset, choose 5 for MNIST or 6 for Anime.
- Dependencies: For evaluation, you will need tensorflow and gin-config.
- Download the datasets (except for Shapes3D):
cd data
./dlib_download_data.sh
For Shapes3D, you will first need to download the file 3dshapes.h5 from the dataset's Google Cloud Storage page, and then put it under the data directory.
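Optionally, you can verify the download with h5py; the 'images' and 'labels' keys and shapes in the comments follow the dataset's documented layout.

```python
import h5py

# Quick sanity check that data/3dshapes.h5 downloaded correctly.
with h5py.File("data/3dshapes.h5", "r") as f:
    print(list(f.keys()))        # expected: ['images', 'labels']
    print(f["images"].shape)     # expected: (480000, 64, 64, 3) uint8 images
    print(f["labels"].shape)     # expected: (480000, 6) ground-truth factors
```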
- Run the evaluation:
python evaluate.py --dataset 0 --exp_name your_name
For --dataset, choose 0 for Shapes3D, 1 for MPI3D, or 2 for Cars3D (evaluation is only supported for these datasets). The results will be written to the same directory as the checkpoint.
ProgGAN and BigGAN are based on: https://github.com/anvoynov/GANLatentDiscovery.
StyleGAN2 is based on: https://github.com/rosinality/stylegan2-pytorch.
Disentanglement metrics are based on: https://github.com/google-research/disentanglement_lib.
@inproceedings{ren2022DisCo,
title = {Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View},
  author = {Xuanchi Ren and Tao Yang and Yuwang Wang and Wenjun Zeng},
booktitle = {ICLR},
year = {2022}
}