move_to_device wrapper; multirank gatekeeper
mdw771 committed Mar 13, 2024
1 parent 18a9f6c commit 16ac05b
Showing 1 changed file with 62 additions and 34 deletions.
96 changes: 62 additions & 34 deletions generic_trainer/trainer.py
@@ -26,6 +26,24 @@
from .message_logger import logger


class MultirankGateKeeper:
"""
A gatekeeper class that determines whether a routine should be executed on the current rank.
"""
def __init__(self, rank, num_ranks):
self.rank = rank
self.num_ranks = num_ranks

def should_proceed(self, gate_kept=True):
if not gate_kept:
return True
else:
if self.rank == 0:
return True
else:
return False
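
A minimal usage sketch of the gatekeeper introduced above (the rank and world-size values are illustrative, not taken from the commit): a gate-kept routine runs only on rank 0, while passing gate_kept=False lets every rank proceed.

# Hypothetical usage, not part of trainer.py:
gatekeeper = MultirankGateKeeper(rank=0, num_ranks=4)
if gatekeeper.should_proceed(gate_kept=True):
    # Only rank 0 reaches this branch, e.g. for writing checkpoints.
    print('rank 0 writes the checkpoint')
if gatekeeper.should_proceed(gate_kept=False):
    # gate_kept=False disables the gate, so every rank proceeds.
    print('all ranks run this routine')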


class LossTracker(dict):

def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
@@ -245,6 +263,7 @@ def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args
self.loss_criterion = self.configs.loss_function
self.iterations_per_epoch = 0
self.current_epoch = 0
self.gatekeeper = MultirankGateKeeper(0, 1)

self.debug = self.configs.debug
self.verbose = True
@@ -371,7 +390,7 @@ def build_loss_tracker(self):
**self.configs.loss_tracker_params.__dict__)

def build_dir(self):
if self.rank == 0:
if self.gatekeeper.should_proceed(gate_kept=True):
if not os.path.exists(self.configs.model_save_dir):
os.makedirs(self.configs.model_save_dir)
self.barrier()
@@ -383,6 +402,7 @@ def build_device(self):
def build_ranks(self):
self.rank = self.get_rank()
self.num_processes = self.get_num_processes()
self.gatekeeper = MultirankGateKeeper(self.rank, self.num_processes)

def build_scalable_parameters(self):
self.all_proc_batch_size = self.configs.batch_size_per_process * self.num_processes
@@ -422,8 +442,7 @@ def run_training(self):

if self.verbose and self.rank == 0:
self.loss_tracker.print_losses()
if self.rank == 0:
self.update_saved_model(filename='final_model.pth')
self.update_saved_model(filename='final_model.pth')

def compute_losses(self, loss_records, preds, labels):
"""
@@ -492,40 +511,41 @@ def process_data_loader_yield(self, data_and_labels):
data = []
for i in range(len(data_and_labels)):
data.append(data_and_labels[i][0])
data = torch.concat(data, dim=0).to(self.device)
data = self.move_to_device(torch.concat(data, dim=0))
n_labels = len(data_and_labels[0]) - 1
label_list = [[] for i in range(n_labels)]
for i_item in range(n_labels):
for i_sample in range(len(data_and_labels)):
label_list[i_item].append(data_and_labels[i_sample][i_item + 1])
labels = [torch.concat(label_list[i]).to(self.device) for i in range(len(label_list))]
labels = [self.move_to_device(torch.concat(label_list[i])) for i in range(len(label_list))]
else:
# In this case, data_and_labels is organized in an item-then-sample order:
# it is a tuple of items. Each element of the tuple is a tensor of
# [n_total_batch_size, ...].
data = data_and_labels[0][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device)
data = data_and_labels[0][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank]
data = self.move_to_device(data)
labels = []
for i in range(1, len(data_and_labels)):
labels.append(
data_and_labels[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device)
self.move_to_device(data_and_labels[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank])
)
else:
if isinstance(data_and_labels[0], tuple):
# In this case, data_and_labels is in sample-then-item order.
data = []
for i in range(len(data_and_labels)):
data.append(data_and_labels[i][0])
data = torch.concat(data, dim=0).to(self.device)
data = self.move_to_device(torch.concat(data, dim=0))
n_labels = len(data_and_labels[0]) - 1
label_list = [[] for i in range(n_labels)]
for i_item in range(n_labels):
for i_sample in range(len(data_and_labels)):
label_list[i_item].append(data_and_labels[i_sample][i_item + 1])
labels = [torch.concat(label_list[i]).to(self.device) for i in range(len(label_list))]
labels = [self.move_to_device(torch.concat(label_list[i])) for i in range(len(label_list))]
else:
# In this case, data_and_labels is in item-then-sample order.
data = data_and_labels[0].to(self.device)
labels = [data_and_labels[i].to(self.device) for i in range(1, len(data_and_labels))]
data = self.move_to_device(data_and_labels[0])
labels = [self.move_to_device(data_and_labels[i]) for i in range(1, len(data_and_labels))]
return data, labels
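
As a hedged illustration of the two dataloader yield layouts that process_data_loader_yield handles, the sketch below uses invented shapes and sizes that are not taken from the repository.

import torch

# Sample-then-item: one tuple per sample, each holding (data, label).
sample_then_item = [(torch.randn(1, 3), torch.randn(1, 1)) for _ in range(4)]
# Item-then-sample: one tensor per item, each spanning the whole batch.
item_then_sample = (torch.randn(4, 3), torch.randn(4, 1))
# The first layout is concatenated along dim 0; the second is sliced by
# rank (multi-node) or moved to the device as-is (single-node).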

def get_epoch_loss_buffer(self):
@@ -566,8 +586,7 @@ def run_training_epoch(self):
acc_dict = self.loss_tracker.calculate_classification_accuracy()
self.loss_tracker.update_accuracy_history(acc_dict, 'train')

if self.rank == 0:
self.save_model_and_states_checkpoint()
self.save_model_and_states_checkpoint()

if self.configs.post_training_epoch_hook is not None:
self.configs.post_training_epoch_hook()
@@ -591,15 +610,13 @@ def run_validation(self):

losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss')
if self.rank == 0:
self.write_training_info()
self.write_training_info()

# Update saved model if val loss is lower
if is_best:
logger.info("Saving improved model after Val Loss improved from %.5f to %.5f" % (
last_best_val_loss, self.loss_tracker['best_val_loss']))
if self.rank == 0:
self.update_saved_model(filename='best_model.pth')
self.update_saved_model(filename='best_model.pth')

if self.configs.task_type == 'classification':
self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
@@ -715,10 +732,12 @@ def save_model(self, path, subcomponent=None):
m = getattr(m, subcomponent)
torch.save(m.state_dict(), path)

def update_saved_model(self, filename='best_model.pth', save_configs=True, subcomponent=None):
def update_saved_model(self, filename='best_model.pth', save_configs=True, subcomponent=None, run_with_only_rank_0=True):
"""
Updates saved model if validation loss is minimum.
"""
if not self.gatekeeper.should_proceed(gate_kept=True):
return
path = self.configs.model_save_dir
dest_path = os.path.join(path, filename)
if not os.path.isdir(path):
@@ -742,10 +761,11 @@ def generate_state_dict(self):
return state

def save_model_and_states_checkpoint(self):
if not self.gatekeeper.should_proceed(gate_kept=True):
return
state_dict = self.generate_state_dict()
torch.save(state_dict, os.path.join(self.configs.model_save_dir, 'checkpoint.state'))
if self.rank == 0:
self.update_saved_model('checkpoint_model.pth')
self.update_saved_model('checkpoint_model.pth')

def load_state_checkpoint(self):
if self.configs.checkpoint_dir is None:
@@ -761,6 +781,8 @@ def load_state_checkpoint(self):
self.loss_tracker = state_dict['loss_tracker']

def write_training_info(self):
if not self.gatekeeper.should_proceed(gate_kept=True):
return
f = open(os.path.join(self.configs.model_save_dir, 'training_info.txt'), 'w')
for key in self.loss_tracker:
f.write('{} = {}\n'.format(key, self.loss_tracker[key]))
@@ -853,6 +875,9 @@ def communicate_value_across_ranks(self, var, mode='average'):
var = comm.allgather(var)
return var

def move_to_device(self, var):
return var.to(self.device)
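
The wrapper above centralizes the .to(self.device) calls that were previously scattered through process_data_loader_yield, so a subclass can change the placement policy in one place. A minimal sketch of such an override follows; the class names are assumptions, since this excerpt does not show the enclosing class name.

# Hypothetical subclass, mirroring the no-op override the Accelerate-based
# trainer defines later in this diff:
class HostOnlyTrainer(Trainer):  # 'Trainer' is an assumed base-class name
    def move_to_device(self, var):
        # Keep tensors on the host, e.g. for CPU-only debugging runs.
        return var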

def cleanup_memory(self):
self.model = None
if torch.cuda.is_available():
@@ -887,7 +912,7 @@ def process_data_loader_yield(self, data):
each element being a tensor of [batch_size_per_process, ...].
With the collate_fn defined, the yields of dataloader are different between PyTorch 1.x and 2.x. This
function automatically detects the format and treat the data accordingly.
function automatically detects the format and processes the data accordingly.
"""
if self.parallelization_type == 'multi_node':
bsize_per_rank = self.configs.batch_size_per_process
@@ -899,17 +924,17 @@ def process_data_loader_yield(self, data):
n_data = len(data[0])
data_proc = [[] for i in range(n_data)]
for i_item in range(n_data):
for i in range(len(data_chunk)):
data_proc[i_item].append(data_chunk[i][i_item])
data = [torch.concat(data_proc[i]).to(self.device) for i in range(len(data_proc))]
for i_sample in range(len(data_chunk)):
data_proc[i_item].append(data_chunk[i_sample][i_item])
data = [self.move_to_device(torch.concat(data_proc[i])) for i in range(len(data_proc))]
else:
# In this case, data_and_labels is organized in an item-then-sample order:
# it is a tuple of items. Each element of the tuple is a tensor of
# [n_total_batch_size, ...].
data_list = []
for i in range(len(data)):
data_list.append(
data[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank].to(self.device)
self.move_to_device(data[i][self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank])
)
data = data_list
else:
@@ -920,10 +945,10 @@ def process_data_loader_yield(self, data):
for i_item in range(n_data):
for i_sample in range(len(data)):
data_list[i_item].append(data[i_sample][i_item])
data = [torch.concat(data_list[i]).to(self.device) for i in range(len(data_list))]
data = [self.move_to_device(torch.concat(data_list[i])) for i in range(len(data_list))]
else:
# In this case, data_and_labels is in item-then-sample order.
data = [data[i].to(self.device) for i in range(len(data))]
data = [self.move_to_device(data[i]) for i in range(len(data))]
return data

def compute_losses(self, loss_records, preds, *args, **kwargs):
@@ -963,7 +988,7 @@ def run_training_epoch(self):

# Zero current grads and do backprop
self.optimizer.zero_grad()
total_loss_tensor.backward()
self.run_backprop(total_loss_tensor)
self.optimizer.step()

# Update the LR according to the schedule -- CyclicLR updates each batch
@@ -974,8 +999,7 @@ def run_training_epoch(self):
losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch)

if self.rank == 0:
self.save_model_and_states_checkpoint()
self.save_model_and_states_checkpoint()

if self.configs.post_training_epoch_hook is not None:
self.configs.post_training_epoch_hook()
Expand All @@ -995,15 +1019,13 @@ def run_validation(self):

losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss')
if self.rank == 0:
self.write_training_info()
self.write_training_info()

# Update saved model if val loss is lower
if is_best:
logger.info("Saving improved model after Val Loss improved from %.5f to %.5f" % (
last_best_val_loss, self.loss_tracker['best_val_loss']))
if self.rank == 0:
self.update_saved_model(filename='best_model.pth')
self.update_saved_model(filename='best_model.pth')

if self.configs.post_validation_epoch_hook is not None:
self.configs.post_validation_epoch_hook()
@@ -1012,6 +1034,8 @@ def update_saved_model(self, filename='best_model.pth'):
"""
Updates saved model if validation loss is minimum.
"""
if not self.gatekeeper.should_proceed(gate_kept=True):
return
super().update_saved_model(filename)
encoder_filename = os.path.splitext(filename)[0] + '_encoder.pth'
super().update_saved_model(encoder_filename, save_configs=False, subcomponent='encoder')
@@ -1063,3 +1087,7 @@ def build_accelerate(self):

def run_backprop(self, loss_node):
self.accelerator.backward(loss_node)

def move_to_device(self, var):
# HuggingFace Accelerate should not need manual data offloading.
return var
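
A hedged sketch of why this override can return its input unchanged: once a dataloader has gone through accelerator.prepare, the batches it yields are already placed on the accelerator's device, so no manual transfer is needed. The model, data, and shapes below are illustrative only.

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
for data, label in loader:
    # Batches from the prepared loader already live on accelerator.device,
    # so a move_to_device override can simply return them.
    loss = ((model(data) - label) ** 2).mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()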
