-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_trainer.py
128 lines (109 loc) · 6.36 KB
/
run_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import signal
import subprocess
import psutil
from config import Config
class RunTrainer:
    """Launch and stop a finetrainers training run via ``accelerate launch``.

    The training subprocess is started in a new session (its own process
    group), so that :meth:`stop` can signal the launcher and every process it
    spawned in one go.
    """

    def __init__(self):
        # True while a training subprocess launched by ``run`` is in progress.
        self.running = False
        # ``subprocess.Popen`` handle of the current/last training process, or None.
        self.process = None

    def run(self, config: "Config", finetrainers_path: str, log_file: str):
        """Build the training command from *config* and run it to completion.

        Args:
            config: settings object queried through ``config.get(key)``.
            finetrainers_path: directory containing ``train.py`` and
                ``accelerate_configs/``.
            log_file: path that receives the subprocess's stdout and stderr
                (truncated on open).

        Returns:
            The finished ``subprocess.Popen`` object.

        Raises:
            AssertionError: if the finetrainers path, ``data_root`` or
                ``pretrained_model_name_or_path`` is missing.
        """
        assert finetrainers_path, "Path to finetrainers is required"
        assert config.get('data_root'), "Data root required"
        assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"

        cmd = self._build_command(config, finetrainers_path)
        print(' '.join(cmd))

        self.running = True
        try:
            with open(log_file, "w") as output_file:
                # start_new_session=True performs setsid() in the child (same as the
                # former preexec_fn=os.setsid) but is safe when the caller has threads.
                self.process = subprocess.Popen(cmd, shell=False, stdout=output_file, stderr=output_file,
                                                text=True, start_new_session=True)
                # Blocks until training ends; output goes straight to log_file.
                self.process.communicate()
        finally:
            # Clear the flag whether training finished, was stopped, or failed to start.
            self.running = False
        return self.process

    @staticmethod
    def _build_command(config, finetrainers_path):
        """Return the full ``accelerate launch .../train.py ...`` argv list for *config*."""
        model_cmd = ["--model_name", config.get('model_name'),
                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
                     "--text_encoder_dtype", config.get('text_encoder_dtype'),
                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                     "--transformer_dtype", config.get('transformer_dtype'),
                     "--vae_dtype", config.get('vae_dtype')]
        if config.get('layerwise_upcasting_modules') != 'none':
            model_cmd += ["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
                          "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
                          "--layerwise_upcasting_skip_modules_pattern", config.get('layerwise_upcasting_skip_modules_pattern')]

        dataset_cmd = ["--data_root", config.get('data_root'),
                       "--video_column", config.get('video_column'),
                       "--caption_column", config.get('caption_column'),
                       "--id_token", config.get('id_token'),
                       "--video_resolution_buckets"]
        # Bucket settings are space-separated lists; each bucket becomes its own argument.
        dataset_cmd += config.get('video_resolution_buckets').split()
        dataset_cmd += ["--image_resolution_buckets"]
        dataset_cmd += config.get('image_resolution_buckets').split()
        dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
                        "--caption_dropout_technique", config.get('caption_dropout_technique'),
                        '--precompute_conditions' if config.get('precompute_conditions') else '']
        if config.get('dataset_file'):
            dataset_cmd += ["--dataset_file", f"{config.get('data_root')}/{config.get('dataset_file')}"]

        dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]

        # diffusion_options is presumably a free-form string that may hold several
        # "--flag value" tokens; with shell=False nothing word-splits it, so split
        # here. A missing value contributes no arguments.
        diffusion_cmd = (config.get('diffusion_options') or '').split()

        training_cmd = ["--training_type", config.get('training_type'),
                        "--seed", config.get('seed'),
                        "--batch_size", config.get('batch_size'),
                        "--train_steps", config.get('train_steps'),
                        "--rank", config.get('rank'),
                        "--lora_alpha", config.get('lora_alpha'),
                        "--target_modules"]
        training_cmd += config.get('target_modules').split()
        training_cmd += ["--gradient_accumulation_steps", config.get('gradient_accumulation_steps'),
                         '--gradient_checkpointing' if config.get('gradient_checkpointing') else '',
                         "--checkpointing_steps", config.get('checkpointing_steps'),
                         "--checkpointing_limit", config.get('checkpointing_limit'),
                         '--enable_slicing' if config.get('enable_slicing') else '',
                         '--enable_tiling' if config.get('enable_tiling') else '']
        if config.get('enable_model_cpu_offload'):
            training_cmd += ["--enable_model_cpu_offload"]
        if config.get('resume_from_checkpoint'):
            training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]

        optimizer_cmd = ["--optimizer", config.get('optimizer'),
                         "--lr", config.get('lr'),
                         "--lr_scheduler", config.get('lr_scheduler'),
                         "--lr_warmup_steps", config.get('lr_warmup_steps'),
                         "--lr_num_cycles", config.get('lr_num_cycles'),
                         "--beta1", config.get('beta1'),
                         "--beta2", config.get('beta2'),
                         "--weight_decay", config.get('weight_decay'),
                         "--epsilon", config.get('epsilon'),
                         "--max_grad_norm", config.get('max_grad_norm'),
                         '--use_8bit_bnb' if config.get('use_8bit_bnb') else '']

        validation_cmd = ["--validation_prompts" if config.get('validation_prompts') else '',
                          config.get('validation_prompts') or '',
                          "--num_validation_videos", config.get('num_validation_videos'),
                          "--validation_steps", config.get('validation_steps')]

        miscellaneous_cmd = ["--tracker_name", config.get('tracker_name'),
                             "--output_dir", config.get('output_dir'),
                             "--nccl_timeout", config.get('nccl_timeout'),
                             "--report_to", config.get('report_to')]

        accelerate_cmd = ["accelerate", "launch",
                          "--config_file", f"{finetrainers_path}/accelerate_configs/{config.get('accelerate_config')}",
                          "--gpu_ids", config.get('gpu_ids')]

        cmd = (accelerate_cmd + [f"{finetrainers_path}/train.py"] + model_cmd + dataset_cmd + dataloader_cmd
               + diffusion_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd)
        # Disabled optional flags are encoded as '' above and are dropped here. The
        # remaining values (which may be numbers) are stringified because Popen
        # requires str arguments.
        return [str(arg) for arg in cmd if arg != '']

    def stop(self):
        """Terminate the training process and its descendants, then wait for it to exit.

        Returns:
            A human-readable status message:
            - "Training forcibly stopped" on success.
            - "Error stopping training: ..." if signalling failed.
            - "No training process to stop" when there is no live process.
        """
        self.running = False
        process = self.process
        # Nothing to do before run() was ever called, or after the process already
        # exited. Signalling a reaped pid would only produce a spurious error.
        if process is None or process.poll() is not None:
            return "No training process to stop"

        message = "Training forcibly stopped"
        try:
            # The child called setsid(), so its process group covers the launcher and
            # every worker that did not move to another group.
            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            # Also signal descendants individually, in case any left the group.
            self.terminate_process_tree(process.pid)
        except (OSError, psutil.Error) as e:
            message = f"Error stopping training: {e}"
        process.wait()
        return message

    @staticmethod
    def terminate_process_tree(pid):
        """Send SIGTERM (via psutil) to process *pid* and all of its descendants.

        Processes that have already exited are skipped silently.
        """
        try:
            parent = psutil.Process(pid)
            children = parent.children(recursive=True)
        except psutil.NoSuchProcess:
            return
        for proc in children + [parent]:
            try:
                proc.terminate()
            except psutil.NoSuchProcess:
                # One process vanishing must not stop the others from being signalled.
                pass