forked from IAHispano/Applio
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathextract_feature_print.py
126 lines (111 loc) · 3.99 KB
/
extract_feature_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os, sys, traceback
import tqdm
# Runtime environment knobs. These are set before the `import torch` below so
# that they are already in place when torch is imported.
os.environ.update(
    {
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "PYTORCH_MPS_HIGH_WATERMARK_RATIO": "0.0",
        "TF_CPP_MIN_LOG_LEVEL": "3",
    }
)

# Positional arguments:
#   argv[1]  device (ignored; the device is auto-detected after the imports)
#   argv[2]  n_part  - total number of worker shards
#   argv[3]  i_part  - index of this worker's shard
#   then either  <exp_dir> <version>            (6 args in total)
#   or           <i_gpu> <exp_dir> <version>    (GPU index form)
n_part = int(sys.argv[2])
i_part = int(sys.argv[3])
if len(sys.argv) != 6:
    # GPU index form: restrict this worker to the requested CUDA device.
    i_gpu = sys.argv[4]
    exp_dir = sys.argv[5]
    os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
    version = sys.argv[6]
else:
    exp_dir = sys.argv[4]
    version = sys.argv[5]
import torch
import torch.nn.functional as F
import soundfile as sf
import numpy as np
from fairseq import checkpoint_utils
# Pick the compute device: prefer CUDA, then Apple MPS, else fall back to CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# Shared extraction log, appended to by every worker process.
f = open("%s/extract_f0_feature.log" % exp_dir, "a+")


def printt(strr):
    """Print *strr* to stdout and append it as one line to the extraction log.

    The log is flushed after every write so progress is visible while the
    script is still running.
    """
    print(strr)
    # Wrap in a 1-tuple so that tuple arguments are formatted instead of being
    # treated as a %-format argument list (which raises TypeError).
    f.write("%s\n" % (strr,))
    f.flush()
# HuBERT checkpoint, resolved relative to the current working directory.
model_path = "hubert_base.pt"
printt(sys.argv)
printt(exp_dir)

# Input: 16 kHz wavs. Output: v1 features go to 3_feature256, otherwise 3_feature768.
wavPath = "%s/1_16k_wavs" % exp_dir
feature_dir_name = "3_feature256" if version == "v1" else "3_feature768"
outPath = "%s/%s" % (exp_dir, feature_dir_name)
os.makedirs(outPath, exist_ok=True)
# Input audio must be sampled at 16 kHz (hop_size=320).
def readwave(wav_path, normalize=False):
    """Load a 16 kHz audio file as a float32 tensor of shape ``(1, num_samples)``.

    Args:
        wav_path: Path of the audio file, read with soundfile.
        normalize: If true, apply ``layer_norm`` over the whole waveform.

    Returns:
        A float32 tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: If the sample rate is not 16000 Hz, or the audio is not
            one-dimensional after averaging channels.
    """
    wav, sr = sf.read(wav_path)
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O`` (asserts are stripped there).
    if sr != 16000:
        raise ValueError(
            "%s: expected a 16000 Hz sample rate, got %s" % (wav_path, sr)
        )
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # multi-channel audio: average over the channel (last) axis
        feats = feats.mean(-1)
    if feats.dim() != 1:
        raise ValueError(
            "%s: expected 1-D audio, got %d dimensions" % (wav_path, feats.dim())
        )
    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    # Add a leading batch dimension of size 1.
    return feats.view(1, -1)
# Load the HuBERT model used for feature extraction.
printt("load model(s) from {}".format(model_path))
# Abort early, with a failure exit status, when the checkpoint is missing.
if not os.path.exists(model_path):
    printt(
        "Error: Extracting is shut down because %s does not exist, you may download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
        % model_path
    )
    # sys.exit rather than the `site`-provided exit(); a non-zero status so the
    # launching process can tell this run failed.
    sys.exit(1)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)
model = models[0].to(device)
printt("move model to %s" % device)
# Half precision only on devices other than MPS/CPU (i.e. CUDA).
if device not in ["mps", "cpu"]:
    model = model.half()
model.eval()
# This worker handles every n_part-th file (in sorted order), starting at i_part.
todo = sorted(os.listdir(wavPath))[i_part::n_part]


def _feature_out_path(file):
    """Return the ``.npy`` output path for the wav file named *file*.

    Only the extension is replaced, so names that merely contain "wav"
    (e.g. ``wave_1.wav``) keep their stem: ``wave_1.npy``.
    """
    stem, _ = os.path.splitext(file)
    return "%s/%s.npy" % (outPath, stem)


def _extract_features(wav_path):
    """Run the HuBERT model on one wav file and return its features as a numpy array."""
    feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
    # No padding: a single un-batched utterance.
    padding_mask = torch.zeros_like(feats, dtype=torch.bool)
    # Match the model's half-precision weights on devices other than MPS/CPU.
    source = feats.half() if device not in ["mps", "cpu"] else feats
    inputs = {
        "source": source.to(device),
        "padding_mask": padding_mask.to(device),
        # v1: layer 9 followed by final_proj; otherwise the raw layer-12 output.
        "output_layer": 9 if version == "v1" else 12,
    }
    with torch.no_grad():
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
    return feats.squeeze(0).float().cpu().numpy()


if len(todo) == 0:
    printt("no-feature-todo")
else:
    printt("all-feature-%s" % len(todo))
    with tqdm.tqdm(total=len(todo)) as pbar:
        for file in todo:
            try:
                if not file.endswith(".wav"):
                    continue
                out_path = _feature_out_path(file)
                # Already extracted by a previous run: skip.
                if os.path.exists(out_path):
                    continue
                feats = _extract_features("%s/%s" % (wavPath, file))
                if not np.isnan(feats).any():
                    np.save(out_path, feats, allow_pickle=False)
                else:
                    printt("%s-contains nan" % file)
                pbar.set_description("file %s, shape %s" % (file, feats.shape))
            except Exception:
                # Log and move on to the next file. Exception (not a bare
                # except) so Ctrl+C / SystemExit still stop the script.
                printt(traceback.format_exc())
            finally:
                # Runs for skipped files too, so the bar always reaches len(todo).
                pbar.update(1)
    printt("all-feature-done")