# demo_cmld.py: stylized text-to-motion demo (forked from neu-vi/SMooDi)
import logging
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader
# from torchsummary import summary
from tqdm import tqdm
from moviepy.editor import VideoFileClip

from mld.config import parse_args
# from mld.datasets.get_dataset import get_datasets
from mld.data.get_data import get_datasets
from mld.data.sampling import subsample, upsample
from mld.models.get_model import get_model
from mld.utils.logger import create_logger
from mld.models.architectures.mld_style_encoder import StyleClassification
from mld.utils.demo_utils import load_example_input
from mld.data.humanml.utils.plot_script import plot_3d_motion

# HumanML3D / T2M 22-joint kinematic chains (right leg, left leg, spine and head, arms)
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]

def build_dict_from_txt(filename):
    """Map a class label (third column) to a style name (stem of the second column)."""
    result_dict = {}
    with open(filename, 'r') as f:
        for line in f:
            parts = line.strip().split(" ")
            if len(parts) >= 3:
                key = parts[2]
                value = parts[1].split("_")[0]
                result_dict[key] = value
    return result_dict
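
# A sketch of the expected 100STYLE_name_dict.txt line format, inferred from
# the parsing above (an assumption; consult the 100STYLE dataset for the real
# file). A line such as
#   001 Aeroplane_BR 0
# would map label "0" to the style name "Aeroplane".
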
def convert_mp4_to_gif(input_file, output_file, resize=None):
    clip = VideoFileClip(input_file)
    if resize is not None:
        clip = clip.resize(resize)  # optional scale factor, e.g. 0.5
    clip.write_gif(output_file, fps=20)
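
# Hypothetical usage (resize handling is an addition here; the original left
# the parameter unused):
#   convert_mp4_to_gif("sample.mp4", "sample.gif", resize=0.5)
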
def main():
    # parse options
    cfg = parse_args(phase="demo")
    cfg.FOLDER = cfg.TEST.FOLDER
    cfg.Name = "demo--" + cfg.NAME
    logger = create_logger(cfg, phase="demo")

    text, length = load_example_input(cfg.DEMO.EXAMPLE)
    task = "Stylized Text2Motion"

    # loading checkpoints
    logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
    state_dict = torch.load(cfg.TEST.CHECKPOINTS,
                            map_location="cpu")["state_dict"]
    # pretrained MLD / VAE weights (loaded here but not applied below)
    mld_state_dict = torch.load(cfg.TRAIN.PRETRAINED_MLD,
                                map_location="cpu")["state_dict"]
    vae_state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
                                map_location="cpu")["state_dict"]

    # cuda options
    if cfg.ACCELERATOR == "gpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(x) for x in cfg.DEVICE)
        device = torch.device("cuda")

    # load dataset to extract nfeats dim of model
    dataset = get_datasets(cfg, logger=logger, phase="test")[0]

    model = get_model(cfg, dataset)
    model.load_state_dict(state_dict, strict=False)

    # style adaptor weights for the SMooDi model
    style_dict = torch.load(cfg.TRAIN.PRETRAINED_STYLE)
    model.style_function.load_state_dict(style_dict, strict=True)

    # 100-class style classifier used to label the generated motions
    style_dict = torch.load("./save/style_encoder.pt")
    style_class = StyleClassification(nclasses=100)  # .cuda()
    style_class.load_state_dict(style_dict, strict=True)

    dict_path = "./datasets/100STYLE_name_dict.txt"
    label_to_motion = build_dict_from_txt(dict_path)

    logger.info("model {} loaded".format(cfg.model.model_type))
    model.sample_mean = cfg.TEST.MEAN
    model.fact = cfg.TEST.FACT
    model.to(device)
    model.eval()

    mld_time = time.time()

    # reference style motions to transfer
    motion_path = "./test_motion"
    motion_list = os.listdir(motion_path)

    # dataset statistics for feature normalization
    mean = torch.tensor(dataset.hparams.mean).cuda()
    std = torch.tensor(dataset.hparams.std).cuda()
    for motion_file in motion_list:
        full_name = motion_path + "/" + motion_file
        # style name is the filename stem before the first "_"
        base_name = os.path.basename(full_name).split("_")[0]

        reference_motions = np.load(full_name)
        m_length, _ = reference_motions.shape
        # zero-pad short clips to the fixed length of 196 frames
        if m_length < 196:
            reference_motions = np.concatenate(
                [reference_motions,
                 np.zeros((196 - m_length, reference_motions.shape[1]))],
                axis=0)

        reference_motions = torch.from_numpy(reference_motions).cuda().double()
        reference_motions = reference_motions.unsqueeze(0)

        output_dir = Path(
            os.path.join(cfg.FOLDER, str(cfg.model.model_type), str(cfg.NAME),
                         "samples_" + cfg.TIME))
        output_dir.mkdir(parents=True, exist_ok=True)

        # normalize the reference motion with the dataset statistics
        reference_motions = (reference_motions - mean) / std

        total_time = time.time()

        # ToDo
        # 1 choose task, input motion reference, text, lengths
        # 2 print task, input, output path
        if not text:
            logger.info(f"Begin specific task {task}")
        # sample
        with torch.no_grad():
            # task: input or Example
            if text:
                # prepare batch data
                batch = {"length": length, "text": text, "motion": reference_motions}

                for rep in range(cfg.DEMO.REPLICATION):
                    # text-driven motion generation with style transfer
                    joints, feats = model(batch, feature="True")

                    # classify the style of each generated motion
                    predict_label = []
                    for data in feats:
                        logits = style_class(data.unsqueeze(0))
                        probabilities = F.softmax(logits, dim=1)
                        predicted = torch.argmax(probabilities).item()
                        motion_name = label_to_motion[str(predicted)]
                        predict_label.append(motion_name)

                    # cal inference time
                    infer_time = time.time() - mld_time
                    num_batch = 1
                    num_all_frame = sum(batch["length"])
                    num_ave_frame = sum(batch["length"]) / len(batch["length"])

                    nsample = len(joints)
                    batch_id = 0
                    for i in range(nsample):
                        npypath = str(output_dir /
                                      f"{base_name}_{length[i]}_batch{batch_id}_{rep}.npy")
                        np.save(npypath, joints[i].detach().cpu().numpy())
                        logger.info(f"Motions are generated here:\n{npypath}")

                        # render a stick-figure video, then convert it to a gif
                        fig_path = Path(str(npypath).replace(".npy", ".mp4"))
                        gif_path = Path(str(npypath).replace(".npy", ".gif"))
                        plot_3d_motion(fig_path, t2m_kinematic_chain,
                                       joints[i].detach().cpu().numpy(),
                                       title=batch["text"][i] + " " + predict_label[i],
                                       dataset='humanml', fps=20)
                        convert_mp4_to_gif(str(fig_path), str(gif_path))
    # ToDo fix time counting
    total_time = time.time() - total_time
    print(f'SMooDi Infer time - This/Ave batch: {infer_time/num_batch:.2f}')
    print(f'SMooDi Infer FPS - Total batch: {num_all_frame/infer_time:.2f}')
    print(
        f'SMooDi Infer FPS - Running Poses Per Second: {num_ave_frame*infer_time/num_batch:.2f}')
    print(
        f'SMooDi Infer time - time for 100 Poses: {infer_time/(num_batch*num_ave_frame)*100:.2f}'
    )
    print(
        f'Total time spent: {total_time:.2f} seconds (including model loading time and exporting time).'
    )

if __name__ == "__main__":
main()
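
# Expected demo inputs, inferred from main() above (all paths are the
# script's hard-coded assumptions):
#   ./test_motion/                       reference motions as .npy arrays of
#                                        shape (n_frames, n_feats); clips
#                                        shorter than 196 frames are
#                                        zero-padded, and the style name is
#                                        read from the filename before the
#                                        first "_" (e.g. "Aeroplane_walk.npy").
#   ./save/style_encoder.pt              weights for the 100-class style
#                                        classifier.
#   ./datasets/100STYLE_name_dict.txt    label-to-style-name mapping.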