From 198d2c3b2c941fd347aaef6a0c4b890af84e6d6d Mon Sep 17 00:00:00 2001
From: Nina Andrejevic
Date: Wed, 29 Jan 2025 13:50:02 -0600
Subject: [PATCH] Fix prefetch factor.

---
 generic_trainer/trainer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index b2b05dd..0ab6dd1 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -96,6 +96,12 @@ def __init__(self, pred_names_and_types=(('cs', 'cls'), ('eg', 'cls'), ('sg', 'c
             self['test_acc_{}'.format(pred_name)] = []
             self['classification_preds_{}'.format(pred_name)] = []
             self['classification_labels_{}'.format(pred_name)] = []
+        for pred_name in self.regr_pred_names:
+            self['train_acc_{}'.format(pred_name)] = []
+            self['val_acc_{}'.format(pred_name)] = []
+            self['test_acc_{}'.format(pred_name)] = []
+            self['regression_preds_{}'.format(pred_name)] = []
+            self['regression_labels_{}'.format(pred_name)] = []
 
     def categorize_predictions(self):
         assert len(self.pred_names_and_types[0]) > 1, 'Prediction names and types should be both given.'
@@ -359,7 +365,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.rank = rank
         self.device = self.get_device()
         self.num_workers = self.configs.num_workers if self.configs.num_workers is not None else 0
-        self.prefetch_factor = 10 if self.num_workers > 0 else None
+        self.prefetch_factor = None #10 if self.num_workers > 0 else None
         self.all_proc_batch_size = self.configs.batch_size_per_process
         self.learning_rate = self.configs.learning_rate_per_process
         self.num_epochs = self.configs.num_epochs