Skip to content

Commit

Permalink
Fix prefetch factor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Nina Andrejevic committed Jan 29, 2025
1 parent bf2cc1a commit 198d2c3
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def __init__(self, pred_names_and_types=(('cs', 'cls'), ('eg', 'cls'), ('sg', 'c
self['test_acc_{}'.format(pred_name)] = []
self['classification_preds_{}'.format(pred_name)] = []
self['classification_labels_{}'.format(pred_name)] = []
for pred_name in self.regr_pred_names:
self['train_acc_{}'.format(pred_name)] = []
self['val_acc_{}'.format(pred_name)] = []
self['test_acc_{}'.format(pred_name)] = []
self['regression_preds_{}'.format(pred_name)] = []
self['regression_labels_{}'.format(pred_name)] = []

def categorize_predictions(self):
assert len(self.pred_names_and_types[0]) > 1, 'Prediction names and types should be both given.'
Expand Down Expand Up @@ -359,7 +365,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
self.rank = rank
self.device = self.get_device()
self.num_workers = self.configs.num_workers if self.configs.num_workers is not None else 0
self.prefetch_factor = 10 if self.num_workers > 0 else None
self.prefetch_factor = None #10 if self.num_workers > 0 else None
self.all_proc_batch_size = self.configs.batch_size_per_process
self.learning_rate = self.configs.learning_rate_per_process
self.num_epochs = self.configs.num_epochs
Expand Down

0 comments on commit 198d2c3

Please sign in to comment.