From 7d6ae24516485138e30b1eb256f0b1c67000c2cf Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Thu, 29 Feb 2024 12:11:02 -0600
Subject: [PATCH] Removed hyperparam scanner

---
 generic_trainer/trainer.py | 149 -------------------------------------
 1 file changed, 149 deletions(-)

diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index c18a64d..3506c08 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -1013,152 +1013,3 @@ def update_saved_model(self, filename='best_model.pth'):
         super().update_saved_model(filename)
         encoder_filename = os.path.splitext(filename)[0] + '_encoder.pth'
         super().update_saved_model(encoder_filename, save_configs=False, subcomponent='encoder')
-
-
-class AlphaDiffractHyperparameterScanner:
-    def __init__(self, scan_params_dict: dict, base_config_dict: TrainingConfig, keep_models_in_memory=False):
-        """
-        Hyperparameter scanner.
-
-        :param scan_params_dict: dict. A dictionary of the parameters to be scanned. The keys of the dictionary
-            should be from `TrainingConfig`, and the value should be a list of values to test.
-        :param base_config_dict: TrainingConficDict. A baseline config dictionary.
-        """
-        self.scan_params_dict = scan_params_dict
-        self.result_table = None
-        self.n_params = len(self.scan_params_dict)
-        self.param_comb_list = None
-        self.base_config_dict = base_config_dict
-        # Might need to be a copy. Deepcopy is currently not done because of the H5py object in `dataset`.
-        self.config_dict = self.base_config_dict
-        self.trainer_list = []
-        self.model_save_dir_prefix = self.base_config_dict['model_save_dir']
-        self.keep_models_in_memory = keep_models_in_memory
-        self.dummy_loss_tracker = LossTracker(self.config_dict['pred_names'])
-        self.metric_names = self.dummy_loss_tracker.get_metric_names_for_hyperparam_tuning()
-        self.verbose = True
-
-    def build_result_table(self):
-
-        self.param_comb_list = list(itertools.product(*self.scan_params_dict.values()))
-        dframe_dict = {}
-        for i_param, param in enumerate(self.scan_params_dict.keys()):
-            dframe_dict[param] = []
-            for i_comb in range(len(self.param_comb_list)):
-                v = self.param_comb_list[i_comb][i_param]
-                v = self.convert_item_to_be_dataframe_compatible(v)
-                dframe_dict[param].append(v)
-        for metric in self.metric_names:
-            dframe_dict[metric] = [0.0] * len(self.param_comb_list)
-        self.result_table = pd.DataFrame(dframe_dict)
-
-    def build(self, seed=123):
-        if seed is not None:
-            set_all_random_seeds(seed)
-        self.build_result_table()
-
-    def convert_item_to_be_dataframe_compatible(self, v):
-        if isinstance(v, nn.Module):
-            nv = v._get_name()
-        elif isinstance(v, (tuple, list)) and issubclass(v[0], nn.Module):
-            nv = v[0].__name__
-            if len(v[1]) > 0:
-                nv += '_' + self.convert_dict_to_string(v[1])
-        else:
-            nv = v
-        return nv
-
-    def modify_condig_dict(self, param_dict):
-        for i, param in enumerate(param_dict.keys()):
-            # if param == 'model':
-            #     # For testing different models, the input in `scan_params_dict['model']` is supposed to be a list of
-            #     # 2-tuples, where the first element is the class handle of the model, and the second element is a
-            #     # dictionary of keyword arguments in the constructor of that class.
-            #     self.config_dict[param] = param_dict[param][0](**param_dict[param][1])
-            # else:
-            self.config_dict[param] = param_dict[param]
-        # Update save path.
-        appendix = self.convert_dict_to_string(param_dict)
-        self.config_dict['model_save_dir'] = self.model_save_dir_prefix + '_' + appendix
-
-    @staticmethod
-    def convert_string_to_camel_case(s):
-        s = re.sub(r"(_|-)+", " ", s).title().replace(" ", "")
-        return ''.join([s[0].lower(), s[1:]])
-
-    def convert_dict_to_string(self, d):
-        s = ''
-        for i, (k, v) in enumerate(d.items()):
-            s += self.convert_string_to_camel_case(k)
-            s += '_'
-            s += str(self.convert_item_to_be_dataframe_compatible(v))
-            if i < len(d) - 1:
-                s += '_'
-        return s
-
-    def create_param_dict(self, config_val_list: list):
-        d = {}
-        for i in range(len(config_val_list)):
-            param_name = list(self.scan_params_dict.keys())[i]
-            d[param_name] = config_val_list[i]
-        return d
-
-    def run(self):
-        for i_comb in tqdm(range(len(self.param_comb_list))):
-            param_dict = self.create_param_dict(self.param_comb_list[i_comb])
-            self.modify_condig_dict(param_dict)
-            trainer = AlphaDiffractTrainer(self.config_dict)
-            self.run_trainer(trainer)
-            self.trainer_list.append(trainer)
-            self.update_result_table(i_comb, trainer)
-            self.cleanup()
-
-    def run_trainer(self, trainer):
-        trainer.verbose = False
-        trainer.build()
-        trainer.run_training()
-
-    def plot_all_training_history(self):
-        for i_comb in range(len(self.param_comb_list)):
-            print('Training history for the following config - ')
-            print(self.result_table.iloc[i_comb])
-            trainer = self.trainer_list[i_comb]
-            trainer.plot_training_history()
-
-    def load_model_for_trainer(self, trainer):
-        # Reinitialize with a brand new model object
-        assert isinstance(trainer.configs['model'], (tuple, list)), \
-            '`config_dict["model"]` should be a tuple of class handle and kwargs.'
-        trainer.build_model()
-        trainer.configs['model_path'] = os.path.join(trainer.configs['model_save_dir'], 'best_model.pth')
-        trainer.load_state_checkpoint()
-
-    def run_testing_for_all(self, indices, dataset='train'):
-        """
-        Run test for all trained models with selected samples and plot results.
-
-        :param indices: tuple.
-        :param dataset: str. Can be 'train' or 'validation'.
-        """
-        for i_comb in range(len(self.param_comb_list)):
-            print('Testing results for the following config - ')
-            print(self.result_table.iloc[i_comb])
-            trainer = self.trainer_list[i_comb]
-            if not self.keep_models_in_memory:
-                # If the models of trainers were not kept in memory, load them back from hard drive.
-                self.load_model_for_trainer(trainer)
-            trainer.run_testing(indices, dataset=dataset)
-
-    def update_result_table(self, i_comb, trainer):
-        for metric_name in self.metric_names:
-            self.result_table.at[i_comb, metric_name] = trainer.loss_tracker[metric_name]
-
-    def cleanup(self):
-        if self.verbose:
-            get_gpu_memory(show=True)
-        if not self.keep_models_in_memory:
-            # Destroy model to save memory.
-            del self.trainer_list[-1].model
-            self.trainer_list[-1].model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
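For reference, a minimal sketch of how the removed AlphaDiffractHyperparameterScanner was driven, reconstructed from the signatures in the deleted code above (constructor, build(), run(), result_table). The scan keys 'learning_rate' and 'batch_size_per_process' are illustrative assumptions, not confirmed TrainingConfig fields:

    # Hypothetical usage of the scanner as it existed before this patch.
    base_config = {
        'model_save_dir': 'outputs/scan',   # prefix; each combination gets its own suffixed dir
        'pred_names': ['cls'],              # consumed by LossTracker in __init__
        # ... remaining TrainingConfig fields ...
    }
    scan_params = {
        'learning_rate': [1e-3, 1e-4],      # hypothetical config key
        'batch_size_per_process': [16, 32], # hypothetical config key
    }
    scanner = AlphaDiffractHyperparameterScanner(scan_params, base_config,
                                                 keep_models_in_memory=False)
    scanner.build(seed=123)      # seeds RNGs, builds the result table via itertools.product
    scanner.run()                # trains one AlphaDiffractTrainer per parameter combination
    print(scanner.result_table)  # pandas DataFrame: one row per combination, one column per metric

With keep_models_in_memory=False (the default), cleanup() frees each model after training, and run_testing_for_all() reloads the best checkpoint from each combination's model_save_dir before testing.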