From 16ac05b3698c2ca3e6af6426a900cf0a1178263b Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Wed, 13 Mar 2024 10:05:30 -0500
Subject: [PATCH] move_to_device wrapper; multirank gatekeeper

---
 generic_trainer/trainer.py | 96 ++++++++++++++++++++++++--------------
 1 file changed, 62 insertions(+), 34 deletions(-)

diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index faabf93..cf80211 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -26,6 +26,24 @@ from .message_logger import logger
 
 
+class MultirankGateKeeper:
+    """
+    A gatekeeper class that determines whether a routine should be executed on the current rank.
+    """
+    def __init__(self, rank, num_ranks):
+        self.rank = rank
+        self.num_ranks = num_ranks
+
+    def should_proceed(self, gate_kept=True):
+        if not gate_kept:
+            return True
+        else:
+            if self.rank == 0:
+                return True
+            else:
+                return False
+
+
 class LossTracker(dict):
 
     def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
@@ -245,6 +263,7 @@ def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args
         self.loss_criterion = self.configs.loss_function
         self.iterations_per_epoch = 0
         self.current_epoch = 0
+        self.gatekeeper = MultirankGateKeeper(0, 1)
 
         self.debug = self.configs.debug
         self.verbose = True
@@ -371,7 +390,7 @@ def build_loss_tracker(self):
                                              **self.configs.loss_tracker_params.__dict__)
 
     def build_dir(self):
-        if self.rank == 0:
+        if self.gatekeeper.should_proceed(gate_kept=True):
             if not os.path.exists(self.configs.model_save_dir):
                 os.makedirs(self.configs.model_save_dir)
         self.barrier()
@@ -383,6 +402,7 @@ def build_device(self):
     def build_ranks(self):
         self.rank = self.get_rank()
         self.num_processes = self.get_num_processes()
+        self.gatekeeper = MultirankGateKeeper(self.rank, self.num_processes)
 
     def build_scalable_parameters(self):
         self.all_proc_batch_size = self.configs.batch_size_per_process * self.num_processes
@@ -422,8 +442,7 @@ def run_training(self):
             if self.verbose and self.rank == 0:
                 self.loss_tracker.print_losses()
 
-        if self.rank == 0:
-            self.update_saved_model(filename='final_model.pth')
+        self.update_saved_model(filename='final_model.pth')
 
     def compute_losses(self, loss_records, preds, labels):
         """
@@ -492,22 +511,23 @@ def process_data_loader_yield(self, data_and_labels):
                 data = []
                 for i in range(len(data_and_labels)):
                     data.append(data_and_labels[i][0])
-                data = torch.concat(data, dim=0).to(self.device)
+                data = self.move_to_device(torch.concat(data, dim=0))
                 n_labels = len(data_and_labels[0]) - 1
                 label_list = [[] for i in range(n_labels)]
                 for i_item in range(n_labels):
                     for i_sample in range(len(data_and_labels)):
                         label_list[i_item].append(data_and_labels[i_sample][i_item + 1])
-                labels = [torch.concat(label_list[i]).to(self.device) for i in range(len(label_list))]
+                labels = [self.move_to_device(torch.concat(label_list[i])) for i in range(len(label_list))]
             else:
                 # In this case, data_and_labels is organized in an item-then-sample order:
                 # it is a tuple of items. Each element of the tuple is a tensor of
                 # [n_total_batch_size, ...].
-                data = data_and_labels[0][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device)
+                data = data_and_labels[0][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank]
+                data = self.move_to_device(data)
                 labels = []
                 for i in range(1, len(data_and_labels)):
                     labels.append(
-                        data_and_labels[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device)
+                        self.move_to_device(data_and_labels[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank])
                     )
         else:
             if isinstance(data_and_labels[0], tuple):
@@ -515,17 +535,17 @@ def process_data_loader_yield(self, data_and_labels):
                 data = []
                 for i in range(len(data_and_labels)):
                     data.append(data_and_labels[i][0])
-                data = torch.concat(data, dim=0).to(self.device)
+                data = self.move_to_device(torch.concat(data, dim=0))
                 n_labels = len(data_and_labels[0]) - 1
                 label_list = [[] for i in range(n_labels)]
                 for i_item in range(n_labels):
                     for i_sample in range(len(data_and_labels)):
                         label_list[i_item].append(data_and_labels[i_sample][i_item + 1])
-                labels = [torch.concat(label_list[i]).to(self.device) for i in range(len(label_list))]
+                labels = [self.move_to_device(torch.concat(label_list[i])) for i in range(len(label_list))]
             else:
                 # In this case, data_and_labels is in item-then-sample order.
-                data = data_and_labels[0].to(self.device)
-                labels = [data_and_labels[i].to(self.device) for i in range(1, len(data_and_labels))]
+                data = self.move_to_device(data_and_labels[0])
+                labels = [self.move_to_device(data_and_labels[i]) for i in range(1, len(data_and_labels))]
         return data, labels
 
     def get_epoch_loss_buffer(self):
@@ -566,8 +586,7 @@ def run_training_epoch(self):
             acc_dict = self.loss_tracker.calculate_classification_accuracy()
             self.loss_tracker.update_accuracy_history(acc_dict, 'train')
 
-        if self.rank == 0:
-            self.save_model_and_states_checkpoint()
+        self.save_model_and_states_checkpoint()
 
         if self.configs.post_training_epoch_hook is not None:
             self.configs.post_training_epoch_hook()
@@ -591,15 +610,13 @@ def run_validation(self):
         losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
         is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss')
 
-        if self.rank == 0:
-            self.write_training_info()
+        self.write_training_info()
 
         # Update saved model if val loss is lower
         if is_best:
             logger.info("Saving improved model after Val Loss improved from %.5f to %.5f" % (
                 last_best_val_loss, self.loss_tracker['best_val_loss']))
-            if self.rank == 0:
-                self.update_saved_model(filename='best_model.pth')
+            self.update_saved_model(filename='best_model.pth')
 
         if self.configs.task_type == 'classification':
             self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
@@ -715,10 +732,12 @@ def save_model(self, path, subcomponent=None):
             m = getattr(m, subcomponent)
         torch.save(m.state_dict(), path)
 
-    def update_saved_model(self, filename='best_model.pth', save_configs=True, subcomponent=None):
+    def update_saved_model(self, filename='best_model.pth', save_configs=True, subcomponent=None, run_with_only_rank_0=True):
         """
         Updates saved model if validation loss is minimum.
""" + if not self.gatekeeper.should_proceed(gate_kept=True): + return path = self.configs.model_save_dir dest_path = os.path.join(path, filename) if not os.path.isdir(path): @@ -742,10 +761,11 @@ def generate_state_dict(self): return state def save_model_and_states_checkpoint(self): + if not self.gatekeeper.should_proceed(gate_kept=True): + return state_dict = self.generate_state_dict() torch.save(state_dict, os.path.join(self.configs.model_save_dir, 'checkpoint.state')) - if self.rank == 0: - self.update_saved_model('checkpoint_model.pth') + self.update_saved_model('checkpoint_model.pth') def load_state_checkpoint(self): if self.configs.checkpoint_dir is None: @@ -761,6 +781,8 @@ def load_state_checkpoint(self): self.loss_tracker = state_dict['loss_tracker'] def write_training_info(self): + if not self.gatekeeper.should_proceed(gate_kept=True): + return f = open(os.path.join(self.configs.model_save_dir, 'training_info.txt'), 'w') for key in self.loss_tracker: f.write('{} = {}\n'.format(key, self.loss_tracker[key])) @@ -853,6 +875,9 @@ def communicate_value_across_ranks(self, var, mode='average'): var = comm.allgather(var) return var + def move_to_device(self, var): + return var.to(self.device) + def cleanup_memory(self): self.model = None if torch.cuda.is_available(): @@ -887,7 +912,7 @@ def process_data_loader_yield(self, data): each element being a tensor of [batch_size_per_process, ...]. With the collate_fn defined, the yields of dataloader are different between PyTorch 1.x and 2.x. This - function automatically detects the format and treat the data accordingly. + function automatically detects the format and processes the data accordingly. """ if self.parallelization_type == 'multi_node': bsize_per_rank = self.configs.batch_size_per_process @@ -899,9 +924,9 @@ def process_data_loader_yield(self, data): n_data = len(data[0]) data_proc = [[] for i in range(n_data)] for i_item in range(n_data): - for i in range(len(data_chunk)): - data_proc[i_item].append(data_chunk[i][i_item]) - data = [torch.concat(data_proc[i]).to(self.device) for i in range(len(data_proc))] + for i_sample in range(len(data_chunk)): + data_proc[i_item].append(data_chunk[i_sample][i_item]) + data = [self.move_to_device(torch.concat(data_proc[i])) for i in range(len(data_proc))] else: # In this case, data_and_labels is organized in a item-then-sample order: # it is a tuple of items. Each element of the tuple is a tensor of @@ -909,7 +934,7 @@ def process_data_loader_yield(self, data): data_list = [] for i in range(len(data)): data_list.append( - data[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device) + self.move_to_device(data[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank]) ) data = data_list else: @@ -920,10 +945,10 @@ def process_data_loader_yield(self, data): for i_item in range(n_data): for i_sample in range(len(data)): data_list[i_item].append(data[i_sample][i_item]) - data = [torch.concat(data_list[i]).to(self.device) for i in range(len(data_list))] + data = [self.move_to_device(torch.concat(data_list[i])) for i in range(len(data_list))] else: # In this case, data_and_labels is in item-then-sample order. 
-                data = [data[i].to(self.device) for i in range(len(data))]
+                data = [self.move_to_device(data[i]) for i in range(len(data))]
         return data
 
     def compute_losses(self, loss_records, preds, *args, **kwargs):
@@ -963,7 +988,7 @@ def run_training_epoch(self):
 
             # Zero current grads and do backprop
             self.optimizer.zero_grad()
-            total_loss_tensor.backward()
+            self.run_backprop(total_loss_tensor)
             self.optimizer.step()
 
             # Update the LR according to the schedule -- CyclicLR updates each batch
@@ -974,8 +999,7 @@ def run_training_epoch(self):
         losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
         self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch)
 
-        if self.rank == 0:
-            self.save_model_and_states_checkpoint()
+        self.save_model_and_states_checkpoint()
 
         if self.configs.post_training_epoch_hook is not None:
             self.configs.post_training_epoch_hook()
@@ -995,15 +1019,13 @@ def run_validation(self):
         losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
        is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss')
 
-        if self.rank == 0:
-            self.write_training_info()
+        self.write_training_info()
 
         # Update saved model if val loss is lower
         if is_best:
             logger.info("Saving improved model after Val Loss improved from %.5f to %.5f" % (
                 last_best_val_loss, self.loss_tracker['best_val_loss']))
-            if self.rank == 0:
-                self.update_saved_model(filename='best_model.pth')
+            self.update_saved_model(filename='best_model.pth')
 
         if self.configs.post_validation_epoch_hook is not None:
             self.configs.post_validation_epoch_hook()
@@ -1012,6 +1034,8 @@ def update_saved_model(self, filename='best_model.pth'):
         """
         Updates saved model if validation loss is minimum.
         """
+        if not self.gatekeeper.should_proceed(gate_kept=True):
+            return
         super().update_saved_model(filename)
         encoder_filename = os.path.splitext(filename)[0] + '_encoder.pth'
         super().update_saved_model(encoder_filename, save_configs=False, subcomponent='encoder')
@@ -1063,3 +1087,7 @@ def build_accelerate(self):
 
     def run_backprop(self, loss_node):
         self.accelerator.backward(loss_node)
+
+    def move_to_device(self, var):
+        # HuggingFace Accelerate handles device placement itself, so the data is returned unchanged.
+        return var
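
To illustrate the gatekeeper pattern introduced above: MultirankGateKeeper centralizes the `if self.rank == 0:` checks that were previously scattered through the trainer. A minimal sketch of the intended behavior follows; the save_checkpoint helper and the four-rank loop are hypothetical, for illustration only, and not part of the patch.

    from generic_trainer.trainer import MultirankGateKeeper  # assumed import path

    def save_checkpoint(gatekeeper):
        # Gate-kept I/O routine: only rank 0 should touch the filesystem.
        if not gatekeeper.should_proceed(gate_kept=True):
            return
        print('rank %d writes checkpoint.state' % gatekeeper.rank)

    for rank in range(4):
        save_checkpoint(MultirankGateKeeper(rank, 4))  # only rank 0 prints

Passing gate_kept=False lets every rank proceed, so collective operations that all ranks must join can share the same guard.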
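
The move_to_device indirection works the same way: the base trainer keeps the explicit .to(self.device) transfer, while the Accelerate-based trainer overrides it as a no-op because tensors yielded by Accelerate-prepared dataloaders already live on the right device. A condensed sketch, with BaseTrainer and AccelerateTrainer as stand-in names for the classes touched in trainer.py:

    import torch

    class BaseTrainer:
        def __init__(self, device):
            self.device = device

        def move_to_device(self, var):
            # Default path: explicit transfer to the trainer's device.
            return var.to(self.device)

    class AccelerateTrainer(BaseTrainer):
        def move_to_device(self, var):
            # Accelerate has already placed the batch; return it unchanged.
            return var

    batch = torch.zeros(4, 3)
    print(BaseTrainer(torch.device('cpu')).move_to_device(batch).device)        # cpu
    print(AccelerateTrainer(torch.device('cpu')).move_to_device(batch).device)  # unchanged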
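
For the two dataloader yield layouts that process_data_loader_yield distinguishes, here is a toy example (tensor shapes are made up; only the nesting structure matters):

    import torch

    # Sample-then-item (PyTorch 1.x with the custom collate_fn): a sequence of
    # per-sample tuples, each holding one data item and one label.
    sample_then_item = [(torch.zeros(1, 4), torch.zeros(1)) for _ in range(8)]
    data = torch.concat([s[0] for s in sample_then_item], dim=0)    # shape [8, 4]
    labels = torch.concat([s[1] for s in sample_then_item], dim=0)  # shape [8]

    # Item-then-sample (PyTorch 2.x): a tuple of already-batched tensors.
    item_then_sample = (torch.zeros(8, 4), torch.zeros(8))
    data, labels = item_then_sample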