diff --git a/generic_trainer/tester.py b/generic_trainer/tester.py index 0de86b0..35f87c5 100644 --- a/generic_trainer/tester.py +++ b/generic_trainer/tester.py @@ -7,6 +7,7 @@ import tqdm import generic_trainer.trainer as trainer +from generic_trainer.trainer import LossTracker from generic_trainer.configs import * from generic_trainer.inference_util import * @@ -25,6 +26,7 @@ class Tester(trainer.Trainer): def __init__(self, configs: InferenceConfig): super().__init__(configs, skip_init=True) + self.loss_tracker = None self.rank = 0 self.num_processes = 1 self.gatekeeper = None @@ -51,6 +53,7 @@ def build(self): self.build_ranks() self.build_scalable_parameters() self.build_device() + self.build_loss_tracker() self.build_dataset() self.build_dataloaders() self.build_model()