Integration for Weights and Biases #34

Open · wants to merge 17 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -131,6 +131,9 @@ dmypy.json
.pyre/

wandb/
artifacts/
input/
output/
*.lmdb/
*.pkl
*.pt
39 changes: 9 additions & 30 deletions README.md
@@ -15,17 +15,12 @@


## Description
Official Implementation of Barbershop. **KEEP UPDATING! Please Git Pull the latest version.**

## Updates
`2021/12/27` Add dilation and erosion parameters to smooth the boundary.

#### `2021/12/24` Important Update: Add semantic mask inpainting module to solve the occlusion problem. Please git pull the latest version.

`2021/12/18` Add a rough version of the project.

`2021/06/02` Add project page.
This repository is a fork of the [official implementation of Barbershop](https://github.com/ZPdesu/Barbershop). It builds on the official repository to add the following features:

- Combine [`main.py`](https://github.com/ZPdesu/Barbershop/blob/main/main.py) and [`align_face.py`](https://github.com/ZPdesu/Barbershop/blob/main/align_face.py) into a single command line interface as part of the updated [`main.py`](https://github.com/soumik12345/Barbershop/blob/main/main.py).
- Provide a notebook [`inference.ipynb`](https://github.com/soumik12345/Barbershop/blob/main/inference.ipynb) for performing step-by-step inference and visualization of the result.
- Add an integration with Weights & Biases, which enables the predictions to be visualized as a W&B Table. The integration works with both the script and the notebook.

## Installation
- Clone the repository:
@@ -37,36 +37,20 @@ cd Barbershop
We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
All dependencies for defining the environment are provided in `environment/environment.yml`.


## Download II2S images
Please download the [II2S](https://drive.google.com/drive/folders/15jsR9yy_pfDHiS9aE3HcYDgwtBbAneId?usp=sharing) images
and put them in the `input/face` folder.


## Getting Started
Preprocess your own images by putting the raw images in the `unprocessed` folder and running:
```
python align_face.py
```

## Getting Started
Produce realistic results:
```
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png --sign realistic --smooth 5
python main.py --identity_image 90.png --structure_image 15.png --appearance_image 117.png --sign realistic --smooth 5
```

Produce results faithful to the masks:
```
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png --sign fidelity --smooth 5
python main.py --identity_image 90.png --structure_image 15.png --appearance_image 117.png --sign fidelity --smooth 5
```
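The renamed flags in the commands above (`--identity_image`, `--structure_image`, `--appearance_image`) replace the original `--im_path1/2/3`. A minimal `argparse` sketch of what the updated interface might look like — the flag names come from the commands above, but the defaults and help strings are assumptions:

```python
import argparse

# Hedged sketch of the updated CLI; not the PR's exact parser.
parser = argparse.ArgumentParser(description="Barbershop")
parser.add_argument("--identity_image", type=str, default="90.png",
                    help="image that provides the face identity")
parser.add_argument("--structure_image", type=str, default="15.png",
                    help="image that provides the hair structure")
parser.add_argument("--appearance_image", type=str, default="117.png",
                    help="image that provides the hair appearance")
parser.add_argument("--sign", choices=["realistic", "fidelity"], default="realistic")
parser.add_argument("--smooth", type=int, default=5)

args = parser.parse_args(["--sign", "fidelity", "--smooth", "5"])
print(args.identity_image, args.sign)  # 90.png fidelity
```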

You can also use the [Jupyter Notebook](./inference.ipynb) to produce the results. The results are now logged automatically as a Weights & Biases Table.


## Todo List
* add a detailed readme
* update mask inpainting code
* integrate image encoder
* add preprocessing step
* ...
![](https://i.imgur.com/subthu8.png)

## Acknowledgments
This code borrows heavily from [II2S](https://github.com/ZPdesu/II2S).
56 changes: 38 additions & 18 deletions align_face.py
@@ -6,45 +6,65 @@
from utils.shape_predictor import align_face
import PIL

parser = argparse.ArgumentParser(description='Align_face')
parser = argparse.ArgumentParser(description="Align_face")

parser.add_argument('-unprocessed_dir', type=str, default='unprocessed', help='directory with unprocessed images')
parser.add_argument('-output_dir', type=str, default='input/face', help='output directory')
parser.add_argument(
"-unprocessed_dir",
type=str,
default="unprocessed",
help="directory with unprocessed images",
)
parser.add_argument(
"-output_dir", type=str, default="input/face", help="output directory"
)

parser.add_argument('-output_size', type=int, default=1024, help='size to downscale the input images to, must be power of 2')
parser.add_argument('-seed', type=int, help='manual seed to use')
parser.add_argument('-cache_dir', type=str, default='cache', help='cache directory for model weights')
parser.add_argument(
"-output_size",
type=int,
default=1024,
help="size to downscale the input images to, must be power of 2",
)
parser.add_argument("-seed", type=int, help="manual seed to use")
parser.add_argument(
"-cache_dir", type=str, default="cache", help="cache directory for model weights"
)

###############
parser.add_argument('-inter_method', type=str, default='bicubic')

parser.add_argument("-inter_method", type=str, default="bicubic")


args = parser.parse_args()
print(vars(args))

cache_dir = Path(args.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True,exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)

print("Downloading Shape Predictor")
f=open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
f = open_url(
"https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx",
cache_dir=cache_dir,
return_path=True,
)
predictor = dlib.shape_predictor(f)

for im in Path(args.unprocessed_dir).glob("*.*"):
faces = align_face(str(im),predictor)
faces = align_face(str(im), predictor)

for i,face in enumerate(faces):
if(args.output_size):
factor = 1024//args.output_size
assert args.output_size*factor == 1024
for i, face in enumerate(faces):
if args.output_size:
factor = 1024 // args.output_size
assert args.output_size * factor == 1024
face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
face_tensor_lr = face_tensor[0].cpu().detach().clamp(0, 1)
face = torchvision.transforms.ToPILImage()(face_tensor_lr)
if factor != 1:
face = face.resize((args.output_size, args.output_size), PIL.Image.LANCZOS)
face = face.resize(
(args.output_size, args.output_size), PIL.Image.LANCZOS
)
if len(faces) > 1:
face.save(Path(args.output_dir) / (im.stem+f"_{i}.png"))
face.save(Path(args.output_dir) / (im.stem + f"_{i}.png"))
else:
face.save(Path(args.output_dir) / (im.stem + f".png"))
face.save(Path(args.output_dir) / (im.stem + f".png"))
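The resize logic in the loop above requires `output_size` to evenly divide 1024 (a power of two in practice), which is what the `assert` enforces. A standalone version of that check — the function name is illustrative, not from the PR:

```python
# output_size must evenly divide the 1024px aligned face; the quotient is
# the integer downscale factor used to decide whether to resize.
def downscale_factor(output_size, full_size=1024):
    factor = full_size // output_size
    if output_size * factor != full_size:
        raise ValueError(f"output_size {output_size} must evenly divide {full_size}")
    return factor

print(downscale_factor(256))  # → 4
```

Raising a `ValueError` instead of using a bare `assert` gives the user a readable message when an invalid `-output_size` is passed.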
16 changes: 8 additions & 8 deletions datasets/image_dataset.py
@@ -5,8 +5,8 @@
import torchvision.transforms as transforms
import os

class ImagesDataset(Dataset):

class ImagesDataset(Dataset):
def __init__(self, opts, image_path=None):
if not image_path:
image_root = opts.input_dir
@@ -16,24 +16,24 @@ def __init__(self, opts, image_path=None):
elif type(image_path) == list:
self.image_paths = image_path

self.image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
self.image_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
self.opts = opts

def __len__(self):
return len(self.image_paths)

def __getitem__(self, index):
im_path = self.image_paths[index]
im_H = Image.open(im_path).convert('RGB')
im_H = Image.open(im_path).convert("RGB")
im_L = im_H.resize((256, 256), PIL.Image.LANCZOS)
im_name = os.path.splitext(os.path.basename(im_path))[0]
if self.image_transform:
im_H = self.image_transform(im_H)
im_L = self.image_transform(im_L)

return im_H, im_L, im_name



3 changes: 2 additions & 1 deletion environment/environment.yml
@@ -160,7 +160,7 @@ dependencies:
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- clip==1.0
- clip==0.2.0
- deprecated==1.2.13
- dlib==19.22.1
- et-xmlfile==1.1.0
@@ -206,4 +206,5 @@ dependencies:
- uritemplate==3.0.1
- urllib3==1.26.7
- wrapt==1.13.3
- wandb==0.12.11
prefix: ~/.conda/envs/Barbershop