Removed hyperparam scanner
mdw771 committed Feb 29, 2024
1 parent 4131649 commit 7d6ae24
Showing 1 changed file with 0 additions and 149 deletions.
generic_trainer/trainer.py: 149 changes (0 additions, 149 deletions)
@@ -1013,152 +1013,3 @@ def update_saved_model(self, filename='best_model.pth'):
super().update_saved_model(filename)
encoder_filename = os.path.splitext(filename)[0] + '_encoder.pth'
super().update_saved_model(encoder_filename, save_configs=False, subcomponent='encoder')


class AlphaDiffractHyperparameterScanner:
def __init__(self, scan_params_dict: dict, base_config_dict: TrainingConfig, keep_models_in_memory=False):
"""
Hyperparameter scanner.
:param scan_params_dict: dict. A dictionary of the parameters to be scanned. The keys of the dictionary
should be from `TrainingConfig`, and the value should be a list of values to test.
:param base_config_dict: TrainingConficDict. A baseline config dictionary.
"""
self.scan_params_dict = scan_params_dict
self.result_table = None
self.n_params = len(self.scan_params_dict)
self.param_comb_list = None
self.base_config_dict = base_config_dict
# Might need to be a copy. Deepcopy is currently not done because of the H5py object in `dataset`.
self.config_dict = self.base_config_dict
self.trainer_list = []
self.model_save_dir_prefix = self.base_config_dict['model_save_dir']
self.keep_models_in_memory = keep_models_in_memory
self.dummy_loss_tracker = LossTracker(self.config_dict['pred_names'])
self.metric_names = self.dummy_loss_tracker.get_metric_names_for_hyperparam_tuning()
self.verbose = True

def build_result_table(self):

self.param_comb_list = list(itertools.product(*self.scan_params_dict.values()))
dframe_dict = {}
for i_param, param in enumerate(self.scan_params_dict.keys()):
dframe_dict[param] = []
for i_comb in range(len(self.param_comb_list)):
v = self.param_comb_list[i_comb][i_param]
v = self.convert_item_to_be_dataframe_compatible(v)
dframe_dict[param].append(v)
for metric in self.metric_names:
dframe_dict[metric] = [0.0] * len(self.param_comb_list)
self.result_table = pd.DataFrame(dframe_dict)

def build(self, seed=123):
if seed is not None:
set_all_random_seeds(seed)
self.build_result_table()

def convert_item_to_be_dataframe_compatible(self, v):
if isinstance(v, nn.Module):
nv = v._get_name()
elif isinstance(v, (tuple, list)) and isinstance(v[0], type) and issubclass(v[0], nn.Module):
nv = v[0].__name__
if len(v[1]) > 0:
nv += '_' + self.convert_dict_to_string(v[1])
else:
nv = v
return nv
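# Illustrative conversions (hypothetical class name): a (SomeModel, {'depth': 4}) pair becomes
# 'SomeModel_depth_4', an nn.Module instance is reduced to its class name, and plain values such
# as 0.001 pass through unchanged.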

def modify_config_dict(self, param_dict):
for param in param_dict.keys():
# if param == 'model':
# # For testing different models, the input in `scan_params_dict['model']` is supposed to be a list of
# # 2-tuples, where the first element is the class handle of the model, and the second element is a
# # dictionary of keyword arguments in the constructor of that class.
# self.config_dict[param] = param_dict[param][0](**param_dict[param][1])
# else:
self.config_dict[param] = param_dict[param]
# Update save path.
appendix = self.convert_dict_to_string(param_dict)
self.config_dict['model_save_dir'] = self.model_save_dir_prefix + '_' + appendix

@staticmethod
def convert_string_to_camel_case(s):
s = re.sub(r"(_|-)+", " ", s).title().replace(" ", "")
return ''.join([s[0].lower(), s[1:]])
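# Example conversions: 'learning_rate' -> 'learningRate', 'batch-size' -> 'batchSize'.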

def convert_dict_to_string(self, d):
s = ''
for i, (k, v) in enumerate(d.items()):
s += self.convert_string_to_camel_case(k)
s += '_'
s += str(self.convert_item_to_be_dataframe_compatible(v))
if i < len(d) - 1:
s += '_'
return s
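# Illustrative output (hypothetical keys): {'learning_rate': 0.001, 'batch_size': 32} becomes
# 'learningRate_0.001_batchSize_32'. These strings are appended to `model_save_dir` to name each run.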

def create_param_dict(self, config_val_list: list):
d = {}
for i in range(len(config_val_list)):
param_name = list(self.scan_params_dict.keys())[i]
d[param_name] = config_val_list[i]
return d

def run(self):
for i_comb in tqdm(range(len(self.param_comb_list))):
param_dict = self.create_param_dict(self.param_comb_list[i_comb])
self.modify_config_dict(param_dict)
trainer = AlphaDiffractTrainer(self.config_dict)
self.run_trainer(trainer)
self.trainer_list.append(trainer)
self.update_result_table(i_comb, trainer)
self.cleanup()

def run_trainer(self, trainer):
trainer.verbose = False
trainer.build()
trainer.run_training()

def plot_all_training_history(self):
for i_comb in range(len(self.param_comb_list)):
print('Training history for the following config - ')
print(self.result_table.iloc[i_comb])
trainer = self.trainer_list[i_comb]
trainer.plot_training_history()

def load_model_for_trainer(self, trainer):
# Reinitialize with a brand new model object
assert isinstance(trainer.configs['model'], (tuple, list)), \
'`config_dict["model"]` should be a tuple of class handle and kwargs.'
trainer.build_model()
trainer.configs['model_path'] = os.path.join(trainer.configs['model_save_dir'], 'best_model.pth')
trainer.load_state_checkpoint()

def run_testing_for_all(self, indices, dataset='train'):
"""
Run testing for all trained models on the selected samples and plot the results.
:param indices: tuple. Indices of the samples to test on.
:param dataset: str. Can be 'train' or 'validation'.
"""
for i_comb in range(len(self.param_comb_list)):
print('Testing results for the following config - ')
print(self.result_table.iloc[i_comb])
trainer = self.trainer_list[i_comb]
if not self.keep_models_in_memory:
# If the trainers' models were not kept in memory, load them back from disk.
self.load_model_for_trainer(trainer)
trainer.run_testing(indices, dataset=dataset)

def update_result_table(self, i_comb, trainer):
for metric_name in self.metric_names:
self.result_table.at[i_comb, metric_name] = trainer.loss_tracker[metric_name]

def cleanup(self):
if self.verbose:
get_gpu_memory(show=True)
if not self.keep_models_in_memory:
# Destroy model to save memory.
del self.trainer_list[-1].model
self.trainer_list[-1].model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
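# A minimal usage sketch of the scanner removed in this commit, assuming a populated `base_config`
# dict with the fields the scanner reads ('model_save_dir', 'pred_names') plus whatever parameters
# are being scanned; the scanned field names below are hypothetical:
#
#     scan_params = {'learning_rate': [1e-3, 1e-4], 'batch_size': [32, 64]}
#     scanner = AlphaDiffractHyperparameterScanner(scan_params, base_config, keep_models_in_memory=False)
#     scanner.build(seed=123)
#     scanner.run()
#     print(scanner.result_table)  # one row of metrics per parameter combination
#     scanner.plot_all_training_history()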
