diff --git a/generic_trainer/__init__.py b/generic_trainer/__init__.py
index a0013e9..bc5a08b 100644
--- a/generic_trainer/__init__.py
+++ b/generic_trainer/__init__.py
@@ -1,2 +1,3 @@
 from generic_trainer.trainer import *
+from generic_trainer.tester import *
 import generic_trainer.metrics
diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index 71501ca..d69e79a 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -125,6 +125,9 @@ class Config(OptionContainer):
     dataset: Optional[Dataset] = None
     """The dataset object."""
 
+    pred_names: Any = ('cs', 'eg', 'sg')
+    """Names of the quantities predicted by the model."""
+
     debug: bool = False
 
     cpu_only: bool = False
@@ -144,15 +147,22 @@ class InferenceConfig(Config):
     # ===== PtychoNN configs =====
     batch_size: int = 1
 
-    model_path: str = None
-    """Path to a trained PtychoNN model."""
+    pretrained_model_path: str = None
+    """Path to a trained model."""
 
     prediction_output_path: str = None
-    """Path to save PtychoNN prediction results."""
+    """Path to save prediction results."""
+
+    load_pretrained_encoder_only: bool = False
+    """If True, only the pretrained encoder (backbone) is loaded. Keep this False for testing."""
+
+    batch_size_per_process: int = 64
+    """The batch size per process."""
 
 
 @dataclasses.dataclass
 class TrainingConfig(Config):
+
     training_dataset: Optional[Dataset] = None
     """
     The training dataset. If this is None, then the whole dataset (including training and validation) must be
@@ -218,9 +228,6 @@ class TrainingConfig(Config):
     load_pretrained_encoder_only: bool = False
     """If True, only the pretrained encoder (backbone) will be loaded if `pretrained_model_path` is not None."""
 
-    pred_names: Any = ('cs', 'eg', 'sg')
-    """Names of the quantities predicted by the model."""
-
     validation_ratio: float = 0.1
     """Ratio of data to be used as validation set."""
diff --git a/generic_trainer/tester.py b/generic_trainer/tester.py
new file mode 100644
index 0000000..3fd36d0
--- /dev/null
+++ b/generic_trainer/tester.py
@@ -0,0 +1,71 @@
+import os
+
+import torch
+from torch.utils.data import DataLoader
+
+import generic_trainer.trainer as trainer
+from generic_trainer.configs import *
+
+
+class Tester(trainer.Trainer):
+
+    def __init__(self, configs: InferenceConfig):
+        super().__init__(configs, skip_init=True)
+        # skip_init=True bypasses Trainer's full initialization, so set the
+        # attributes the Tester relies on here.
+        self.parallelization_type = configs.parallelization_params.parallelization_type
+        self.rank = 0
+        self.num_processes = 1
+        self.gatekeeper = None
+        self.dataset = None
+        self.sampler = None
+        self.dataloader = None
+
+    def build(self):
+        self.build_ranks()
+        self.build_scalable_parameters()
+        self.build_device()
+        self.build_dataset()
+        self.build_dataloaders()
+        self.build_model()
+        self.build_dir()
+
+    def build_scalable_parameters(self):
+        self.all_proc_batch_size = self.configs.batch_size_per_process * self.num_processes
+
+    def build_dataset(self):
+        self.dataset = self.configs.dataset
+
+    def build_dataloaders(self):
+        if self.parallelization_type == 'multi_node':
+            self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset,
+                                                                           num_replicas=self.num_processes,
+                                                                           rank=self.rank,
+                                                                           drop_last=False,
+                                                                           shuffle=False)
+            self.dataloader = DataLoader(self.dataset,
+                                         batch_size=self.configs.batch_size_per_process,
+                                         sampler=self.sampler,
+                                         collate_fn=lambda x: x)
+        else:
+            self.dataloader = DataLoader(self.dataset, shuffle=False,
+                                         batch_size=self.all_proc_batch_size,
+                                         collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
+                                         generator=self.get_dataloader_generator(), num_workers=0,
+                                         drop_last=False)
+
+    def build_dir(self):
+        if self.gatekeeper.should_proceed(gate_kept=True):
+            if not os.path.exists(self.configs.prediction_output_path):
+                os.makedirs(self.configs.prediction_output_path)
+        self.barrier()
+
+    def run(self):
+        self.model.eval()
+        for j, data_and_labels in enumerate(self.dataloader):
+            data, _ = self.process_data_loader_yield(data_and_labels)
+            preds = self.model(*data)
+            self.save_predictions(preds)
+
+    def save_predictions(self, preds):
+        # No-op by default; subclasses override this to write predictions to disk.
+        pass
diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index a32dd7d..35339ae 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Optional, Any
+from typing import Optional, Any, Union
 import itertools
 import copy
 import re
@@ -231,7 +231,8 @@ def dump(self, path):
 
 class Trainer:
 
-    def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args, **kwargs):
+    def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_processes=None, skip_init=False,
+                 *args, **kwargs):
         """
         Trainer constructor.
 
@@ -244,6 +245,8 @@ def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args
                  None unless multi_node is intended and training is run using torch.multiprocessing.spawn.
         """
         self.configs = configs
+        if skip_init:
+            return
         self.parallelization_type = self.configs.parallelization_params.parallelization_type
         self.dataset = self.configs.dataset
         self.training_dataset = None
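
Reviewer note: a minimal usage sketch of the new `Tester`, following the build-then-run flow this diff introduces. The `NpyTester` subclass, the `.npy` file-naming scheme, and the assumption that the model returns one tensor per entry in `pred_names` are hypothetical illustrations, not part of this change; the base class deliberately leaves `save_predictions` as a no-op hook.

```python
# Hypothetical usage sketch (not part of this diff). NpyTester and the
# output naming scheme are placeholders for illustration.
import os

import numpy as np

from generic_trainer.configs import InferenceConfig
from generic_trainer.tester import Tester


class NpyTester(Tester):
    """Writes each predicted quantity of each batch to a .npy file."""

    def __init__(self, configs: InferenceConfig):
        super().__init__(configs)
        self.batch_counter = 0

    def save_predictions(self, preds):
        # Assumes preds is a sequence of tensors, one per entry in
        # configs.pred_names (('cs', 'eg', 'sg') by default).
        for name, pred in zip(self.configs.pred_names, preds):
            fname = '{}_batch_{:05d}.npy'.format(name, self.batch_counter)
            np.save(os.path.join(self.configs.prediction_output_path, fname),
                    pred.detach().cpu().numpy())
        self.batch_counter += 1


configs = InferenceConfig(
    dataset=...,                         # your Dataset object; structure is project-specific
    pretrained_model_path='path/to/model.pth',
    prediction_output_path='path/to/outputs',
    batch_size_per_process=64,
    load_pretrained_encoder_only=False,  # keep False for testing
)
tester = NpyTester(configs)
tester.build()
tester.run()
```

Keeping `save_predictions` empty in the base class keeps output formatting out of the generic inference loop; subclassing, as sketched above, preserves that separation.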