Skip to content

Commit

Permalink
Tester
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Apr 5, 2024
1 parent 9cfaef2 commit c05a8fd
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 8 deletions.
1 change: 1 addition & 0 deletions generic_trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from generic_trainer.trainer import *
from generic_trainer.tester import *
import generic_trainer.metrics
19 changes: 13 additions & 6 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class Config(OptionContainer):
dataset: Optional[Dataset] = None
"""The dataset object."""

pred_names: Any = ('cs', 'eg', 'sg')
"""Names of the quantities predicted by the model."""

debug: bool = False

cpu_only: bool = False
Expand All @@ -144,15 +147,22 @@ class InferenceConfig(Config):
# ===== PtychoNN configs =====
batch_size: int = 1

model_path: str = None
"""Path to a trained PtychoNN model."""
pretrained_model_path: str = None
"""Path to a trained model."""

prediction_output_path: str = None
"""Path to save PtychoNN prediction results."""
"""Path to save prediction results."""

load_pretrained_encoder_only: bool = False
"""Keep this False for testing."""

batch_size_per_process: int = 64
"""The batch size per process."""


@dataclasses.dataclass
class TrainingConfig(Config):

training_dataset: Optional[Dataset] = None
"""
The training dataset. If this is None, then the whole dataset (including training and validation) must be
Expand Down Expand Up @@ -218,9 +228,6 @@ class TrainingConfig(Config):
load_pretrained_encoder_only: bool = False
"""If True, only the pretrained encoder (backbone) will be loaded if `pretrained_model_path` is not None."""

pred_names: Any = ('cs', 'eg', 'sg')
"""Names of the quantities predicted by the model."""

validation_ratio: float = 0.1
"""Ratio of data to be used as validation set."""

Expand Down
68 changes: 68 additions & 0 deletions generic_trainer/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os

import torch
from torch.utils.data import Dataset, DataLoader

import generic_trainer.trainer as trainer
from generic_trainer.configs import *


class Tester(trainer.Trainer):

def __init__(self, configs: InferenceConfig):
super().__init__(configs, skip_init=True)
self.rank = 0
self.num_processes = 1
self.gatekeeper = None
self.dataset = None
self.sampler = None
self.dataloader = None

def build(self):
self.build_ranks()
self.build_scalable_parameters()
self.build_device()
self.build_dataset()
self.build_dataloaders()
self.build_model()
self.build_dir()

def build_scalable_parameters(self):
self.all_proc_batch_size = self.configs.batch_size_per_process * self.num_processes

def build_dataset(self):
self.dataset = self.configs.dataset

def build_dataloaders(self):
if self.parallelization_type == 'multi_node':
self.sampler = torch.utils.data.distributed.DistributedSampler(self.training_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=False,
shuffle=False)
self.dataloader = DataLoader(self.dataset,
batch_size=self.configs.batch_size_per_process,
sampler=self.sampler,
collate_fn=lambda x: x)
else:
self.dataloader = DataLoader(self.training_dataset, shuffle=False,
batch_size=self.all_proc_batch_size,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
drop_last=False)

def build_dir(self):
if self.gatekeeper.should_proceed(gate_kept=True):
if not os.path.exists(self.configs.prediction_output_path):
os.makedirs(self.configs.prediction_output_path)
self.barrier()

def run(self):
self.model.eval()
for j, data_and_labels in enumerate(self.dataloader):
data, _ = self.process_data_loader_yield(data_and_labels)
preds = self.model(*data)
self.save_predictions(preds)

def save_predictions(self, preds):
pass
7 changes: 5 additions & 2 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import Optional, Any
from typing import Optional, Any, Union
import itertools
import copy
import re
Expand Down Expand Up @@ -231,7 +231,8 @@ def dump(self, path):

class Trainer:

def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args, **kwargs):
def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_processes=None, skip_init=False,
*args, **kwargs):
"""
Trainer constructor.
Expand All @@ -244,6 +245,8 @@ def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args
None unless multi_node is intended and training is run using torch.multiprocessing.spawn.
"""
self.configs = configs
if skip_init:
return
self.parallelization_type = self.configs.parallelization_params.parallelization_type
self.dataset = self.configs.dataset
self.training_dataset = None
Expand Down

0 comments on commit c05a8fd

Please sign in to comment.