From 45ff69e648bf005a2aa4051622bbf4ff0bafdb8b Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Thu, 6 Jun 2024 16:00:32 -0500
Subject: [PATCH] Allow different LRs for parameter groups

---
 generic_trainer/configs.py | 15 ++++++++++++---
 generic_trainer/trainer.py | 29 +++++++++++++++++++++++------
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index cb03216..3bc2908 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -1,7 +1,7 @@
 import copy
 import collections
 import dataclasses
-from typing import Any, Callable, Optional, Union, Tuple
+from typing import *
 import json
 import os
 import re
@@ -322,12 +322,21 @@ class TrainingConfig(Config):
     or processes.
     """
 
-    optimizer: Any = torch.optim.Adam
-    """String of optimizer name or the handle of a subclass of torch.optim.Optimizer"""
+    optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+    """The optimizer class. Should be given as the handle of a subclass of torch.optim.Optimizer."""
 
     optimizer_params: dict = dataclasses.field(default_factory=dict)
     """Optimizer parameters."""
 
+    multi_optimizer_param_dicts: Optional[Sequence[Dict]] = None
+    """
+    The optimizer uses different learning rates for different parameters if this is provided.
+    It should be a list of dictionaries as described in
+    https://pytorch.org/docs/stable/optim.html#per-parameter-options.
+    However, the code to get trainable parameters in the "params" keys should be given
+    as a string, where the model object should be referenced as `self.get_model_object()`.
+    """
+
     model_save_dir: str = '.'
     """Directory to save trained models."""
 
diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index 8a4bb69..f0b25e1 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -814,19 +814,36 @@ def build_optimizer(self):
                              "one.".format(self.model.__class__))
         else:
             trainable_params = self.model.parameters()
-        if isinstance(self.configs.optimizer, str):
-            if self.configs.optimizer == 'adam':
-                self.optimizer = torch.optim.Adam(trainable_params, lr=self.learning_rate)
-        else:
+        if self.configs.multi_optimizer_param_dicts is None:
             self.optimizer = self.configs.optimizer(trainable_params, lr=self.learning_rate,
                                                     **self.configs.optimizer_params)
+        else:
+            # Construct per-parameter dicts
+            perparam_dicts = []
+            for i, d in enumerate(self.configs.multi_optimizer_param_dicts):
+                d_copy = d.copy()
+                d_copy['params'] = eval(d['params'])
+                if 'lr' in d_copy.keys():
+                    d_copy['lr'] = d_copy['lr'] * self.num_processes
+                perparam_dicts.append(d_copy)
+            self.optimizer = self.configs.optimizer(perparam_dicts, lr=self.learning_rate,
+                                                    **self.configs.optimizer_params)
 
     def build_scheduler(self):
         self.iterations_per_epoch = len(self.training_dataset) / self.all_proc_batch_size
         self.iterations_per_epoch = np.ceil(self.iterations_per_epoch)
         step_size = 6 * self.iterations_per_epoch
-        self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.learning_rate / 10,
-                                                           max_lr=self.learning_rate, step_size_up=step_size,
+        if self.configs.multi_optimizer_param_dicts is None:
+            base_lr = self.learning_rate * 0.1
+            max_lr = self.learning_rate
+        else:
+            base_lr = []
+            max_lr = []
+            for d in self.optimizer.param_groups:
+                base_lr.append(d['lr'] * 0.1)
+                max_lr.append(self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=base_lr,
+                                                           max_lr=max_lr, step_size_up=step_size,
                                                            cycle_momentum=False, mode='triangular2')
 
     def build_amp(self):
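
Usage note (not part of the patch): a minimal sketch of how the new multi_optimizer_param_dicts option might be configured. The submodule names backbone and head are hypothetical, all other TrainingConfig fields are assumed to keep their defaults, and the "params" strings are evaluated inside the trainer, where self.get_model_object() refers to the model being trained (per the docstring above).

    import torch
    from generic_trainer.configs import TrainingConfig

    config = TrainingConfig(
        optimizer=torch.optim.Adam,
        multi_optimizer_param_dicts=[
            # Hypothetical submodule: train the backbone with a smaller learning rate.
            {'params': 'self.get_model_object().backbone.parameters()', 'lr': 1e-5},
            # Hypothetical submodule: the head falls back to the global learning rate
            # passed to the optimizer constructor.
            {'params': 'self.get_model_object().head.parameters()'},
        ],
    )

With such a config, build_optimizer evaluates each "params" string with eval(), scales any per-group 'lr' by the number of processes, and hands the resulting dicts to the optimizer class; build_scheduler then gives CyclicLR a per-group base_lr of one tenth of each group's learning rate, with max_lr set to the global learning rate for every group.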