Skip to content

Commit

Permalink
add more config options
Browse files Browse the repository at this point in the history
  • Loading branch information
neph1 committed Jan 9, 2025
1 parent e1cd838 commit 31fa03c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion config/config_categories.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p
Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
Expand Down
3 changes: 3 additions & 0 deletions config/config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ checkpointing_limit: 102
checkpointing_steps: 500
data_root: ''
dataloader_num_workers: 0
dataset_file: ''
diffusion_options: ''
enable_model_cpu_offload: false
enable_slicing: true
enable_tiling: true
epsilon: 1e-8
gpu_ids: '0'
gradient_accumulation_steps: 4
gradient_checkpointing: true
id_token: afkx
image_resolution_buckets: 512x768
lora_alpha: 128
lr: 0.0001
lr_num_cycles: 1
Expand Down
6 changes: 6 additions & 0 deletions run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
"--id_token", config.get('id_token'),
"--video_resolution_buckets"]
dataset_cmd += config.get('video_resolution_buckets').split(' ')
dataset_cmd += ["--image_resolution_buckets"]
dataset_cmd += config.get('image_resolution_buckets').split(' ')
dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
"--caption_dropout_technique", config.get('caption_dropout_technique'),
"--text_encoder_dtype", config.get('text_encoder_dtype'),
"--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
"--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
"--vae_dtype", config.get('vae_dtype'),
'--precompute_conditions' if config.get('precompute_conditions') else '']
if config.get('dataset_file'):
dataset_cmd += ["--dataset_file", config.get('dataset_file')]

dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]

Expand All @@ -56,6 +60,8 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
"--checkpointing_limit", config.get('checkpointing_limit'),
'--enable_slicing' if config.get('enable_slicing') else '',
'--enable_tiling' if config.get('enable_tiling') else '']
if config.get('enable_model_cpu_offload'):
training_cmd += ["--enable_model_cpu_offload"]

if config.get('resume_from_checkpoint'):
training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]
Expand Down

0 comments on commit 31fa03c

Please sign in to comment.