# data_loader.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch import autograd
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
import torch.utils.model_zoo as model_zoo
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import transformers
from sentence_transformers import SentenceTransformer
import pdb
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import os
import json
import time
import pprint
import importlib
import textwrap
import PIL
import io
import sys
import requests
import argparse
import easydict
from IPython.display import display, display_markdown
dataset = 'CelebA'
DATASET_PATH = ""  # Path to the dataset directory containing the images and the caption file (caps.txt)
if dataset == 'CelebA':
"""
images.shape = [192010 , 64, 64, 3]
captions_ids = [192010 , any]
"""
data = pd.read_csv(DATASET_PATH + '/caps.txt', sep="\t", names=['img_path', 'desc'])
data['desc'] = data['desc'].apply(lambda t: t if isinstance(t, str) else None)
data = data.dropna()
data['desc'] = data['desc'].apply(lambda t: t if len(t) > 16 else None)
data = data.dropna()
data = data.drop_duplicates()
data['desc'] = data['desc'].apply(lambda t: t.replace('|', ' '))
data['img_path'] = data['img_path'].apply(lambda t: os.path.join(DATASET_PATH, 'img_align_celeba', t))
data = data[['desc', 'img_path']]
data = data.astype('str')
data = data.reset_index(drop=True)
data.head()
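
# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original pipeline) of how the cleaned
# DataFrame could feed a torch DataLoader. The CelebACaptionDataset class,
# its transform, and the 64x64 target size are assumptions inferred from the
# imports and the shape note above, not confirmed by this file.
# ---------------------------------------------------------------------------
class CelebACaptionDataset(torch.utils.data.Dataset):
    """Yields (image_tensor, caption) pairs from the (desc, img_path) frame."""

    def __init__(self, frame, image_size=64):
        self.frame = frame
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        image = Image.open(row['img_path']).convert('RGB')
        return self.transform(image), row['desc']

# Hypothetical usage:
# loader = torch.utils.data.DataLoader(CelebACaptionDataset(data),
#                                      batch_size=64, shuffle=True)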