-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 92e467e
Showing
12 changed files
with
1,908 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python | ||
|
||
name: Python application | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.10" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pytest | ||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi | ||
pip install -e . | ||
- name: Test with pytest | ||
run: | | ||
cd tests | ||
pytest -s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import dataclasses | ||
from typing import Any | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from generic_trainer.trainer import Trainer | ||
from generic_trainer.configs import * | ||
|
||
from dataset_handle import DummyClassificationDataset | ||
|
||
|
||
class ClassificationModel(nn.Module): | ||
def __init__(self, dim_input=128, dim_hidden=256, num_classes=(5, 7), *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.feature_extractor = nn.Sequential( | ||
nn.Linear(dim_input, dim_hidden), | ||
nn.ReLU() | ||
) | ||
self.head1 = nn.Sequential( | ||
nn.Linear(dim_hidden, num_classes[0]), | ||
nn.Softmax(dim=1) | ||
) | ||
self.head2 = nn.Sequential( | ||
nn.Linear(dim_hidden, num_classes[1]), | ||
nn.Softmax(dim=1) | ||
) | ||
|
||
def forward(self, x): | ||
x = self.feature_extractor(x) | ||
y1 = self.head1(x) | ||
y2 = self.head2(x) | ||
return y1, y2 | ||
|
||
|
||
@dataclasses.dataclass | ||
class ClassificationModelParameters(ModelParameters): | ||
dim_input: int = 128, | ||
dim_hidden: int = 512, | ||
num_classes: Any = (5, 7) | ||
|
||
|
||
if __name__ == '__main__': | ||
dataset = DummyClassificationDataset(assumed_array_shape=(40, 128), label_dims=(5, 7), add_channel_dim=False) | ||
|
||
configs = TrainingConfig( | ||
model_class=ClassificationModel, | ||
model_params=ClassificationModelParameters( | ||
dim_input=128, | ||
dim_hidden=256, | ||
num_classes=(5, 7) | ||
), | ||
parallelization_params=ParallelizationConfig( | ||
parallelization_type='single_node' | ||
), | ||
pred_names=('pred1', 'pred2'), | ||
dataset=dataset, | ||
batch_size_per_process=2, | ||
learning_rate_per_process=1e-2, | ||
loss_function=nn.CrossEntropyLoss(), | ||
optimizer=torch.optim.AdamW, | ||
optimizer_params={'weight_decay': 0.01}, | ||
num_epochs=5, | ||
model_save_dir='temp', | ||
task_type='classification' | ||
) | ||
|
||
trainer = Trainer(configs) | ||
trainer.build() | ||
trainer.run_training() | ||
|
||
trainer.cleanup() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class DummyDataset(Dataset): | ||
""" | ||
A dummy dataset that generates random data for debugging. | ||
""" | ||
|
||
def __init__(self, assumed_array_shape, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.array_shape = assumed_array_shape | ||
|
||
def __len__(self): | ||
return self.array_shape[0] | ||
|
||
def get_labels(self, size, *args, **kwargs): | ||
return [torch.randint(0, 10, size)] | ||
|
||
def __getitem__(self, idx): | ||
x = torch.rand([1, *self.array_shape[1:]]) | ||
labels = self.get_labels(1) | ||
return x, *labels | ||
|
||
def __getitems__(self, idx_list): | ||
n = len(idx_list) | ||
x = torch.rand([n, *self.array_shape[1:]]) | ||
labels = self.get_labels(n) | ||
return x, *labels | ||
|
||
|
||
|
||
class DummyClassificationDataset(DummyDataset): | ||
""" | ||
A dummy dataset that generates random data for debugging. | ||
""" | ||
|
||
def __init__(self, assumed_array_shape, label_dims=(7, 101, 230), add_channel_dim=False, *args, **kwargs): | ||
""" | ||
The constructor. | ||
:param assumed_array_shape: list or tuple. The assumed array size that the dataset contains. For 1D data, | ||
this should be a 2D vector of (n_samples, n_features). | ||
:param label_dims: list or tuple. The lengths of one-hot encoded label vectors. | ||
""" | ||
super().__init__(assumed_array_shape, *args, **kwargs) | ||
self.label_dims = label_dims | ||
self.add_channel_dim = add_channel_dim | ||
|
||
def get_labels(self, n, *args, **kwargs): | ||
labels = [] | ||
for d in self.label_dims: | ||
label = torch.zeros([n, d]) | ||
inds = torch.randint(0, d, (n,)) | ||
label[tuple(range(n)), inds] = 1.0 | ||
labels.append(label) | ||
return labels | ||
|
||
def __getitem__(self, idx): | ||
if self.add_channel_dim: | ||
x = torch.rand([1, 1, self.array_shape[-1]]) | ||
else: | ||
x = torch.rand([1, self.array_shape[-1]]) | ||
labels = self.get_labels(1) | ||
return x, *labels | ||
|
||
def __getitems__(self, idx_list): | ||
n = len(idx_list) | ||
if self.add_channel_dim: | ||
x = torch.rand([n, 1, self.array_shape[-1]]) | ||
else: | ||
x = torch.rand([n, self.array_shape[-1]]) | ||
labels = self.get_labels(n) | ||
return x, *labels | ||
|
||
|
||
class DummyImageDataset(DummyDataset): | ||
|
||
def __init__(self, assumed_array_shape, label_shapes=((1, 32, 32),), *args, **kwargs): | ||
super().__init__(assumed_array_shape, *args, **kwargs) | ||
self.label_shapes = label_shapes | ||
|
||
def get_labels(self, size, *args, **kwargs): | ||
labels = [] | ||
for label_shape in self.label_shapes: | ||
lab = torch.rand(size, *label_shape) | ||
labels.append(lab) | ||
return labels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import dataclasses | ||
from typing import Any | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from generic_trainer.trainer import Trainer | ||
from generic_trainer.configs import * | ||
from generic_trainer.metrics import * | ||
|
||
from dataset_handle import DummyImageDataset | ||
|
||
|
||
class ImageRegressionModel(nn.Module): | ||
def __init__(self, num_channels_list=(4, 8, 16), kernel_size_list=(3, 3, 3), *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.num_channels_list = num_channels_list | ||
self.kernel_size_list = kernel_size_list | ||
self.encoder = self.get_encoder() | ||
self.decoder1 = self.get_decoder() | ||
self.decoder2 = self.get_decoder() | ||
|
||
def get_encoder(self): | ||
net = [] | ||
last_nc = 1 | ||
for i, n_c in enumerate(self.num_channels_list): | ||
net.append(nn.Conv2d(last_nc, n_c, kernel_size=self.kernel_size_list[i], | ||
padding=self.kernel_size_list[i] // 2)) | ||
net.append(nn.ReLU()) | ||
net.append(nn.MaxPool2d(2, 2)) | ||
last_nc = n_c | ||
return nn.Sequential(*net) | ||
|
||
def get_decoder(self): | ||
net = [] | ||
last_nc = self.num_channels_list[-1] | ||
for i, n_c in enumerate(list(self.num_channels_list[::-1][1:]) + [1]): | ||
net.append(nn.Conv2d(last_nc, n_c, kernel_size=self.kernel_size_list[::-1][i], | ||
padding=self.kernel_size_list[::-1][i] // 2)) | ||
net.append(nn.ReLU()) | ||
net.append(nn.Upsample(scale_factor=2)) | ||
last_nc = n_c | ||
return nn.Sequential(*net) | ||
|
||
def forward(self, x): | ||
x = self.encoder(x) | ||
y1 = self.decoder1(x) | ||
y2 = self.decoder2(x) | ||
return y1, y2 | ||
|
||
|
||
@dataclasses.dataclass | ||
class ImageRegressionModelParameters(ModelParameters): | ||
num_channels_list: Any = (4, 8, 16) | ||
kernel_size_list: Any = (3, 3, 3) | ||
|
||
|
||
if __name__ == '__main__': | ||
dataset = DummyImageDataset(assumed_array_shape=(40, 1, 64, 64), label_shapes=((1, 64, 64), (1, 64, 64))) | ||
|
||
configs = TrainingConfig( | ||
model_class=ImageRegressionModel, | ||
model_params=ImageRegressionModelParameters( | ||
num_channels_list=(4, 8, 16), | ||
kernel_size_list=(5, 3, 3), | ||
), | ||
parallelization_params=ParallelizationConfig( | ||
parallelization_type='single_node' | ||
), | ||
pred_names=('mag', 'phase'), | ||
dataset=dataset, | ||
batch_size_per_process=2, | ||
learning_rate_per_process=1e-2, | ||
loss_function=(nn.L1Loss(), nn.L1Loss(), TotalVariationLoss(weight=1000)), | ||
optimizer=torch.optim.AdamW, | ||
optimizer_params={'weight_decay': 0.01}, | ||
num_epochs=5, | ||
model_save_dir='temp', | ||
task_type='regression' | ||
) | ||
|
||
trainer = Trainer(configs) | ||
trainer.build() | ||
trainer.run_training() | ||
|
||
trainer.cleanup() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import torch | ||
import numpy | ||
|
||
|
||
def set_default_device(dev): | ||
try: | ||
torch.set_default_device(dev) | ||
except: | ||
pass |
Oops, something went wrong.