diff --git a/.github/workflows/test_suite.yml b/.github/workflows/test_suite.yml index da2b748..9acf820 100644 --- a/.github/workflows/test_suite.yml +++ b/.github/workflows/test_suite.yml @@ -17,6 +17,7 @@ env: CONDA_ENV_FILE: 'env.yaml' CONDA_ENV_NAME: 'project-test' COOKIECUTTER_PROJECT_NAME: 'project-test' + HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} jobs: build: diff --git a/README.md b/README.md index 32575bb..cb5889a 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ Avoid writing boilerplate code to integrate: - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research. - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications. +- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets. - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)* - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes. - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator. diff --git a/cookiecutter.json b/cookiecutter.json index eea71a2..1c2d1fe 100644 --- a/cookiecutter.json +++ b/cookiecutter.json @@ -9,5 +9,5 @@ "repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}", "conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}", "python_version": "3.11", - "__version": "0.3.1" + "__version": "0.4.0" } diff --git a/docs/index.md b/docs/index.md index a03a0db..ad8b725 100644 --- a/docs/index.md +++ b/docs/index.md @@ -36,6 +36,7 @@ and to avoid writing boilerplate code to integrate: - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research. - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications. +- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets. - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)* - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes. - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator. diff --git a/{{ cookiecutter.repository_name }}/.env.template b/{{ cookiecutter.repository_name }}/.env.template index 7fdc5ba..713f3ac 100644 --- a/{{ cookiecutter.repository_name }}/.env.template +++ b/{{ cookiecutter.repository_name }}/.env.template @@ -1 +1,10 @@ -# While .env is a local file full of secrets, this can be public and ease the setup of known env variables. +# .env.template is a template for the .env file that can be versioned. + +# Set to 1 to show full stack trace on error, 0 to hide it +HYDRA_FULL_ERROR=1 + +# Configure where huggingface_hub will locally store data.
+HF_HOME="~/.cache/huggingface" + +# Configure the User Access Token to authenticate to the Hub +# HUGGING_FACE_HUB_TOKEN= diff --git a/{{ cookiecutter.repository_name }}/.github/workflows/publish.yml b/{{ cookiecutter.repository_name }}/.github/workflows/publish.yml index 45bf8de..259ff67 100644 --- a/{{ cookiecutter.repository_name }}/.github/workflows/publish.yml +++ b/{{ cookiecutter.repository_name }}/.github/workflows/publish.yml @@ -9,8 +9,9 @@ env: CACHE_NUMBER: 0 # increase to reset cache manually CONDA_ENV_FILE: './env.yaml' CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}' - {% raw %} + HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} + jobs: build: strategy: diff --git a/{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml b/{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml index bbd49aa..36a2803 100644 --- a/{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml +++ b/{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml @@ -16,8 +16,9 @@ env: CACHE_NUMBER: 1 # increase to reset cache manually CONDA_ENV_FILE: './env.yaml' CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}' - {% raw %} + HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} + jobs: build: diff --git a/{{ cookiecutter.repository_name }}/conf/default.yaml b/{{ cookiecutter.repository_name }}/conf/default.yaml index 29770ac..de9bb18 100644 --- a/{{ cookiecutter.repository_name }}/conf/default.yaml +++ b/{{ cookiecutter.repository_name }}/conf/default.yaml @@ -5,6 +5,10 @@ core: version: 0.0.1 tags: null +conventions: + x_key: 'x' + y_key: 'y' + defaults: - hydra: default - nn: default diff --git a/{{ cookiecutter.repository_name }}/conf/nn/data/dataset/vision/mnist.yaml b/{{ cookiecutter.repository_name }}/conf/nn/data/dataset/vision/mnist.yaml new file mode 100644 index 0000000..e9adb05 --- /dev/null +++ b/{{ cookiecutter.repository_name }}/conf/nn/data/dataset/vision/mnist.yaml @@ -0,0 +1,22 @@ +# This class defines which dataset to use, +# and also how to split in train/[val]/test. 
+_target_: {{ cookiecutter.package_name }}.utils.hf_io.load_hf_dataset +name: "mnist" +ref: "mnist" +train_split: train +# val_split: val +val_percentage: 0.1 +test_split: test +label_key: label +data_key: image +num_classes: 10 +input_shape: [1, 28, 28] +standard_x_key: ${conventions.x_key} +standard_y_key: ${conventions.y_key} +transforms: + _target_: {{ cookiecutter.package_name }}.utils.hf_io.HFTransform + key: ${conventions.x_key} + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor diff --git a/{{ cookiecutter.repository_name }}/conf/nn/data/default.yaml b/{{ cookiecutter.repository_name }}/conf/nn/data/default.yaml new file mode 100644 index 0000000..7fb7d6c --- /dev/null +++ b/{{ cookiecutter.repository_name }}/conf/nn/data/default.yaml @@ -0,0 +1,28 @@ +_target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule + +val_images_fixed_idxs: [7371, 3963, 2861, 1701, 3172, + 1749, 7023, 1606, 6481, 1377, + 6003, 3593, 3410, 3399, 7277, + 5337, 968, 8206, 288, 1968, + 5677, 9156, 8139, 7660, 7089, + 1893, 3845, 2084, 1944, 3375, + 4848, 8704, 6038, 2183, 7422, + 2682, 6878, 6127, 2941, 5823, + 9129, 1798, 6477, 9264, 476, + 3007, 4992, 1428, 9901, 5388] + +accelerator: ${train.trainer.accelerator} + +num_workers: + train: 4 + val: 2 + test: 0 + +batch_size: + train: 512 + val: 128 + test: 16 + +defaults: + - _self_ + - dataset: vision/mnist # pick one of the yamls in nn/data/ diff --git a/{{ cookiecutter.repository_name }}/conf/nn/default.yaml b/{{ cookiecutter.repository_name }}/conf/nn/default.yaml index 861964a..50867cb 100644 --- a/{{ cookiecutter.repository_name }}/conf/nn/default.yaml +++ b/{{ cookiecutter.repository_name }}/conf/nn/default.yaml @@ -1,47 +1,23 @@ -data: - _target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule - - datasets: - train: - _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset - -# val: -# - _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset - - test: - - _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset - - accelerator: ${train.trainer.accelerator} - - num_workers: - train: 8 - val: 4 - test: 4 - - batch_size: - train: 32 - val: 16 - test: 16 - - # example - val_percentage: 0.1 +data: ??? 
module: - _target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule - optimizer: - # Adam-oriented deep learning _target_: torch.optim.Adam - # These are all default parameters for the Adam optimizer - lr: 0.001 + lr: 1e-3 betas: [ 0.9, 0.999 ] eps: 1e-08 weight_decay: 0 - lr_scheduler: - _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts - T_0: 10 - T_mult: 2 - eta_min: 0 # min value for the lr - last_epoch: -1 - verbose: False +# lr_scheduler: +# _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts +# T_0: 20 +# T_mult: 1 +# eta_min: 0 +# last_epoch: -1 +# verbose: False + + +defaults: + - _self_ + - data: default + - module: default diff --git a/{{ cookiecutter.repository_name }}/conf/nn/module/default.yaml b/{{ cookiecutter.repository_name }}/conf/nn/module/default.yaml new file mode 100644 index 0000000..a26abef --- /dev/null +++ b/{{ cookiecutter.repository_name }}/conf/nn/module/default.yaml @@ -0,0 +1,7 @@ +_target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule +x_key: ${conventions.x_key} +y_key: ${conventions.y_key} + +defaults: + - _self_ + - model: cnn diff --git a/{{ cookiecutter.repository_name }}/conf/nn/module/model/cnn.yaml b/{{ cookiecutter.repository_name }}/conf/nn/module/model/cnn.yaml new file mode 100644 index 0000000..df22ae6 --- /dev/null +++ b/{{ cookiecutter.repository_name }}/conf/nn/module/model/cnn.yaml @@ -0,0 +1,2 @@ +_target_: {{ cookiecutter.package_name }}.modules.module.CNN +input_shape: ${nn.data.dataset.input_shape} diff --git a/{{ cookiecutter.repository_name }}/conf/train/default.yaml b/{{ cookiecutter.repository_name }}/conf/train/default.yaml index 8547d91..ffc500a 100644 --- a/{{ cookiecutter.repository_name }}/conf/train/default.yaml +++ b/{{ cookiecutter.repository_name }}/conf/train/default.yaml @@ -17,30 +17,30 @@ trainer: restore: ckpt_or_run_path: null - mode: continue # null, finetune, hotstart, continue + mode: null # null, finetune, hotstart, continue monitor: metric: 'loss/val' mode: 'min' callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping + - _target_: lightning.pytorch.callbacks.EarlyStopping patience: 42 verbose: False monitor: ${train.monitor.metric} mode: ${train.monitor.mode} - - _target_: pytorch_lightning.callbacks.ModelCheckpoint + - _target_: lightning.pytorch.callbacks.ModelCheckpoint save_top_k: 1 verbose: False monitor: ${train.monitor.metric} mode: ${train.monitor.mode} - - _target_: pytorch_lightning.callbacks.LearningRateMonitor + - _target_: lightning.pytorch.callbacks.LearningRateMonitor logging_interval: "step" log_momentum: False - - _target_: pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar + - _target_: lightning.pytorch.callbacks.progress.tqdm_progress.TQDMProgressBar refresh_rate: 20 logging: @@ -49,7 +49,7 @@ logging: source: true logger: - _target_: pytorch_lightning.loggers.WandbLogger + _target_: lightning.pytorch.loggers.WandbLogger project: ${core.project_name} entity: null diff --git a/{{ cookiecutter.repository_name }}/setup.cfg b/{{ cookiecutter.repository_name }}/setup.cfg index 7d4928b..5ff7788 100644 --- a/{{ cookiecutter.repository_name }}/setup.cfg +++ b/{{ cookiecutter.repository_name }}/setup.cfg @@ -15,7 +15,8 @@ package_dir= =src packages=find: install_requires = - nn-template-core==0.3.* + nn-template-core==0.4.* + anypy==0.0.* # Add project specific dependencies # Stuff easy to break with updates diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py 
b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py index be2c440..af13868 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py @@ -6,7 +6,7 @@ # thus the logging configuration defined in the __init__.py must be called before # the lightning import otherwise it has no effect. # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503 -lightning_logger = logging.getLogger("pytorch_lightning") +lightning_logger = logging.getLogger("lightning.pytorch") # Remove all handlers associated with the lightning logger. for handler in lightning_logger.handlers[:]: lightning_logger.removeHandler(handler) diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py index 3ca940a..3c7d19b 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py @@ -1,15 +1,15 @@ import logging from functools import cached_property, partial from pathlib import Path -from typing import List, Mapping, Optional, Sequence +from typing import List, Mapping, Optional import hydra +import lightning.pytorch as pl import omegaconf -import pytorch_lightning as pl from omegaconf import DictConfig -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader, Dataset from torch.utils.data.dataloader import default_collate -from torchvision import transforms +from tqdm import tqdm from nn_core.common import PROJECT_ROOT from nn_core.nn_types import Split @@ -80,6 +80,10 @@ def load(src_path: Path) -> "MetaData": class_vocab=class_vocab, ) + def __repr__(self) -> str: + attributes = ",\n ".join([f"{key}={value}" for key, value in self.__dict__.items()]) + return f"{self.__class__.__name__}(\n {attributes}\n)" + def collate_fn(samples: List, split: Split, metadata: MetaData): """Custom collate function for dataloaders with access to split and metadata. 
@@ -98,26 +102,26 @@ def collate_fn(samples: List, split: Split, metadata: MetaData): class MyDataModule(pl.LightningDataModule): def __init__( self, - datasets: DictConfig, + dataset: DictConfig, num_workers: DictConfig, batch_size: DictConfig, accelerator: str, # example - val_percentage: float, + val_images_fixed_idxs: List[int], ): super().__init__() - self.datasets = datasets + self.dataset = dataset self.num_workers = num_workers self.batch_size = batch_size # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu" self.train_dataset: Optional[Dataset] = None - self.val_datasets: Optional[Sequence[Dataset]] = None - self.test_datasets: Optional[Sequence[Dataset]] = None + self.val_dataset: Optional[Dataset] = None + self.test_dataset: Optional[Dataset] = None # example - self.val_percentage: float = val_percentage + self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs @cached_property def metadata(self) -> MetaData: @@ -132,40 +136,25 @@ def metadata(self) -> MetaData: if self.train_dataset is None: self.setup(stage="fit") - return MetaData(class_vocab=self.train_dataset.dataset.class_vocab) + return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)}) def prepare_data(self) -> None: # download only pass def setup(self, stage: Optional[str] = None): - transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) - - # Here you should instantiate your datasets, you may also split the train into train and validation if needed. - if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_datasets is None): - # example - mnist_train = hydra.utils.instantiate( - self.datasets.train, - split="train", - transform=transform, - path=PROJECT_ROOT / "data", - ) - train_length = int(len(mnist_train) * (1 - self.val_percentage)) - val_length = len(mnist_train) - train_length - self.train_dataset, val_dataset = random_split(mnist_train, [train_length, val_length]) - - self.val_datasets = [val_dataset] + self.transform = hydra.utils.instantiate(self.dataset.transforms) + + self.hf_datasets = hydra.utils.instantiate(self.dataset) + self.hf_datasets.set_transform(self.transform) + + # Here you should instantiate your dataset, you may also split the train into train and validation if needed. 
+ if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None): + self.train_dataset = self.hf_datasets["train"] + self.val_dataset = self.hf_datasets["val"] if stage is None or stage == "test": - self.test_datasets = [ - hydra.utils.instantiate( - dataset_cfg, - split="test", - path=PROJECT_ROOT / "data", - transform=transform, - ) - for dataset_cfg in self.datasets.test - ] + self.test_dataset = self.hf_datasets["test"] def train_dataloader(self) -> DataLoader: return DataLoader( @@ -177,34 +166,28 @@ def train_dataloader(self) -> DataLoader: collate_fn=partial(collate_fn, split="train", metadata=self.metadata), ) - def val_dataloader(self) -> Sequence[DataLoader]: - return [ - DataLoader( - dataset, - shuffle=False, - batch_size=self.batch_size.val, - num_workers=self.num_workers.val, - pin_memory=self.pin_memory, - collate_fn=partial(collate_fn, split="val", metadata=self.metadata), - ) - for dataset in self.val_datasets - ] - - def test_dataloader(self) -> Sequence[DataLoader]: - return [ - DataLoader( - dataset, - shuffle=False, - batch_size=self.batch_size.test, - num_workers=self.num_workers.test, - pin_memory=self.pin_memory, - collate_fn=partial(collate_fn, split="test", metadata=self.metadata), - ) - for dataset in self.test_datasets - ] + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + shuffle=False, + batch_size=self.batch_size.val, + num_workers=self.num_workers.val, + pin_memory=self.pin_memory, + collate_fn=partial(collate_fn, split="val", metadata=self.metadata), + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_dataset, + shuffle=False, + batch_size=self.batch_size.test, + num_workers=self.num_workers.test, + pin_memory=self.pin_memory, + collate_fn=partial(collate_fn, split="test", metadata=self.metadata), + ) def __repr__(self) -> str: - return f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " f"{self.batch_size=})" + return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})" @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default") @@ -214,7 +197,12 @@ def main(cfg: omegaconf.DictConfig) -> None: Args: cfg: the hydra configuration """ - _: pl.LightningDataModule = hydra.utils.instantiate(cfg.data.datamodule, _recursive_=False) + m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) + m.metadata + m.setup() + + for _ in tqdm(m.train_dataloader()): + pass if __name__ == "__main__": diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py index a1ccdd7..ed50089 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py @@ -1,39 +1,8 @@ import hydra import omegaconf from torch.utils.data import Dataset -from torchvision.datasets import FashionMNIST from nn_core.common import PROJECT_ROOT -from nn_core.nn_types import Split - - -class MyDataset(Dataset): - def __init__(self, split: Split, **kwargs): - super().__init__() - self.split: Split = split - - # example - self.mnist = FashionMNIST( - kwargs["path"], - train=split == "train", - download=True, - transform=kwargs["transform"], - ) - - @property - def class_vocab(self): - return self.mnist.class_to_idx - - def __len__(self) -> int: - # 
example - return len(self.mnist) - - def __getitem__(self, index: int): - # example - return self.mnist[index] - - def __repr__(self) -> str: - return f"MyDataset({self.split=}, n_instances={len(self)})" @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default") @@ -43,7 +12,7 @@ def main(cfg: omegaconf.DictConfig) -> None: Args: cfg: the hydra configuration """ - _: Dataset = hydra.utils.instantiate(cfg.nn.data.datasets.train, split="train", _recursive_=False) + _: Dataset = hydra.utils.instantiate(cfg.nn.data.dataset, _recursive_=False) if __name__ == "__main__": diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py index a3f4d96..49d81a1 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py @@ -1,13 +1,15 @@ +from typing import Tuple + from torch import nn # https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118 class CNN(nn.Module): - def __init__(self, num_classes: int): + def __init__(self, input_shape: Tuple[int], num_classes: int): super(CNN, self).__init__() self.model = nn.Sequential( nn.Conv2d( - in_channels=1, + in_channels=input_shape[0], out_channels=16, kernel_size=5, stride=1, diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py index 3543817..6c5cf89 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py @@ -1,9 +1,9 @@ import logging -from typing import Any, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union import hydra +import lightning.pytorch as pl import omegaconf -import pytorch_lightning as pl import torch import torch.nn.functional as F import torchmetrics @@ -13,7 +13,6 @@ from nn_core.model_logging import NNLogger from {{ cookiecutter.package_name }}.data.datamodule import MetaData -from {{ cookiecutter.package_name }}.modules.module import CNN pylogger = logging.getLogger(__name__) @@ -21,7 +20,7 @@ class MyLightningModule(pl.LightningModule): logger: NNLogger - def __init__(self, metadata: Optional[MetaData] = None, *args, **kwargs) -> None: + def __init__(self, model, metadata: Optional[MetaData] = None, *args, **kwargs) -> None: super().__init__() # Populate self.hparams with args and kwargs automagically! @@ -36,11 +35,11 @@ def __init__(self, metadata: Optional[MetaData] = None, *args, **kwargs) -> None task="multiclass", num_classes=len(metadata.class_vocab) if metadata is not None else None, ) - self.train_accuracy = metric.clone() - self.val_accuracy = metric.clone() - self.test_accuracy = metric.clone() + self.train_acc = metric.clone() + self.val_acc = metric.clone() + self.test_acc = metric.clone() - self.model = CNN(num_classes=len(metadata.class_vocab)) + self.model = hydra.utils.instantiate(model, num_classes=len(metadata.class_vocab)) def forward(self, x: torch.Tensor) -> torch.Tensor: """Method for the forward pass. 
@@ -54,74 +53,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # example return self.model(x) - def step(self, x, y) -> Mapping[str, Any]: + def _step(self, batch: Dict[str, torch.Tensor], split: str) -> Mapping[str, Any]: + x = batch[self.hparams.x_key] + gt_y = batch[self.hparams.y_key] + # example logits = self(x) - loss = F.cross_entropy(logits, y) - return {"logits": logits.detach(), "loss": loss} + loss = F.cross_entropy(logits, gt_y) + preds = torch.softmax(logits, dim=-1) - def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]: - # example - x, y = batch - step_out = self.step(x, y) + metrics = getattr(self, f"{split}_acc") + metrics.update(preds, gt_y) - self.log_dict( - {"loss/train": step_out["loss"].cpu().detach()}, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - self.train_accuracy(torch.softmax(step_out["logits"], dim=-1), y) self.log_dict( { - "acc/train": self.train_accuracy, + f"acc/{split}": metrics, + f"loss/{split}": loss, }, on_epoch=True, ) - return step_out - - def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]: - # example - x, y = batch - step_out = self.step(x, y) - - self.log_dict( - {"loss/val": step_out["loss"].cpu().detach()}, - on_step=False, - on_epoch=True, - prog_bar=True, - ) + return {"logits": logits.detach(), "loss": loss} - self.val_accuracy(torch.softmax(step_out["logits"], dim=-1), y) - self.log_dict( - { - "acc/val": self.val_accuracy, - }, - on_epoch=True, - ) + def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]: + return self._step(batch=batch, split="train") - return step_out + def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]: + return self._step(batch=batch, split="val") def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]: - # example - x, y = batch - step_out = self.step(x, y) - - self.log_dict( - {"loss/test": step_out["loss"].cpu().detach()}, - ) - - self.test_accuracy(torch.softmax(step_out["logits"], dim=-1), y) - self.log_dict( - { - "acc/test": self.test_accuracy, - }, - on_epoch=True, - ) - - return step_out + return self._step(batch=batch, split="test") def configure_optimizers( self, @@ -154,13 +115,8 @@ def main(cfg: omegaconf.DictConfig) -> None: Args: cfg: the hydra configuration """ - module = cfg.nn.module - _: pl.LightningModule = hydra.utils.instantiate( - module, - optim=module.optimizer, - metadata=MetaData(class_vocab={str(i): i for i in range(10)}), - _recursive_=False, - ) + m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) + _: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False, metadata=m.metadata) if __name__ == "__main__": diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py index bf66984..6558e99 100644 --- a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py @@ -2,10 +2,11 @@ from typing import List, Optional import hydra +import lightning.pytorch as pl import omegaconf -import pytorch_lightning as pl +import torch +from lightning.pytorch import Callback from omegaconf import DictConfig, ListConfig -from pytorch_lightning import Callback from nn_core.callbacks import NNTemplateCore from nn_core.common import PROJECT_ROOT @@ -19,6 +20,8 @@ pylogger = logging.getLogger(__name__) 
+torch.set_float32_matmul_precision("high") + def build_callbacks(cfg: ListConfig, *args: Callback) -> List[Callback]: """Instantiate the callbacks given their configuration. @@ -64,6 +67,7 @@ def run(cfg: DictConfig) -> str: # Instantiate datamodule pylogger.info(f"Instantiating <{cfg.nn.data['_target_']}>") datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) + datamodule.setup(stage=None) metadata: Optional[MetaData] = getattr(datamodule, "metadata", None) if metadata is None: @@ -98,7 +102,7 @@ def run(cfg: DictConfig) -> str: if fast_dev_run: pylogger.info("Skipping testing in 'fast_dev_run' mode!") else: - if "test" in cfg.nn.data.datasets and trainer.checkpoint_callback.best_model_path is not None: + if datamodule.test_dataset is not None and trainer.checkpoint_callback.best_model_path is not None: pylogger.info("Starting testing!") trainer.test(datamodule=datamodule) diff --git a/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/utils/hf_io.py b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/utils/hf_io.py new file mode 100644 index 0000000..4918e39 --- /dev/null +++ b/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/utils/hf_io.py @@ -0,0 +1,158 @@ +from collections import namedtuple +from pathlib import Path +from typing import Any, Callable, Dict, Sequence + +import torch +from anypy.data.metadata_dataset_dict import MetadataDatasetDict +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +from omegaconf import DictConfig + +from nn_core.common import PROJECT_ROOT + +DatasetParams = namedtuple("DatasetParams", ["name", "fine_grained", "train_split", "test_split", "hf_key"]) + + +class HFTransform: + def __init__( + self, + key: str, + transform: Callable[[torch.Tensor], torch.Tensor], + ): + """Apply a row-wise transform to a dataset column. + + Args: + key (str): The key of the column to transform. + transform (Callable[[torch.Tensor], torch.Tensor]): The transform to apply. + """ + self.transform = transform + self.key = key + + def __call__(self, samples: Dict[str, Sequence[Any]]) -> Dict[str, Sequence[Any]]: + """Apply the transform to the samples. + + Args: + samples (Dict[str, Sequence[Any]]): The samples to transform. + + Returns: + Dict[str, Sequence[Any]]: The transformed samples. + """ + samples[self.key] = [self.transform(data) for data in samples[self.key]] + return samples + + def __repr__(self) -> str: + return repr(self.transform) + + +def preprocess_dataset( + dataset: Dataset, + cfg: Dict, +) -> Dataset: + """Preprocess a dataset. + + This function applies the following preprocessing steps: + - Rename the label column to the standard key. + - Rename the data column to the standard key. + + Do not apply transforms here, as the preprocessed dataset will be saved to disk once + and then reused; thus updates on the transforms will not be reflected in the dataset. + + Args: + dataset (Dataset): The dataset to preprocess. + cfg (Dict): The configuration. + + Returns: + Dataset: The preprocessed dataset. + """ + dataset = dataset.rename_column(cfg["label_key"], cfg["standard_y_key"]) + dataset = dataset.rename_column(cfg["data_key"], cfg["standard_x_key"]) + return dataset + + +def save_dataset_to_disk(dataset: MetadataDatasetDict, output_path: Path) -> None: + """Save a dataset to disk. + + Args: + dataset (MetadataDatasetDict): The dataset to save. + output_path (Path): The path to save the dataset to.
+ """ + if not isinstance(output_path, Path): + output_path = Path(output_path) + + output_path.mkdir(parents=True, exist_ok=True) + + dataset.save_to_disk(output_path) + + +def load_hf_dataset(**cfg: DictConfig) -> MetadataDatasetDict: + """Load a dataset from the HuggingFace datasets library. + + The returned dataset is a MetadataDatasetDict, which is a wrapper around a DatasetDict. + It will contain the following splits: + - train + - val + - test + If `val_split` is not specified in the config, it will be created from the train split + according to the `val_percentage` specified in the config. + + The returned dataset will be preprocessed and saved to disk, + if it does not exist yet, and loaded from disk otherwise. + + Args: + cfg: The configuration. + + Returns: + Dataset: The loaded dataset. + """ + dataset_params: DatasetParams = DatasetParams( + cfg["ref"], + None, + cfg["train_split"], + cfg["test_split"], + (cfg["ref"],), + ) + DATASET_KEY = "_".join( + map( + str, + [v for k, v in dataset_params._asdict().items() if k != "hf_key" and v is not None], + ) + ) + DATASET_DIR: Path = PROJECT_ROOT / "data" / "datasets" / DATASET_KEY + + if not DATASET_DIR.exists(): + train_dataset = load_dataset( + dataset_params.name, + split=dataset_params.train_split, + token=True, + ) + if "val_percentage" in cfg: + train_val_dataset = train_dataset.train_test_split(test_size=cfg["val_percentage"], shuffle=True) + train_dataset = train_val_dataset["train"] + val_dataset = train_val_dataset["test"] + elif "val_split" in cfg: + val_dataset = load_dataset( + dataset_params.name, + split=cfg["val_split"], + token=True, + ) + else: + raise RuntimeError("Either val_percentage or val_split must be specified in the config.") + + test_dataset = load_dataset( + dataset_params.name, + split=dataset_params.test_split, + token=True, + ) + + dataset: DatasetDict = MetadataDatasetDict( + train=train_dataset, + val=val_dataset, + test=test_dataset, + ) + + dataset = preprocess_dataset(dataset, cfg) + + save_dataset_to_disk(dataset, DATASET_DIR) + else: + dataset: Dataset = load_from_disk(dataset_path=str(DATASET_DIR)) + + return dataset diff --git a/{{ cookiecutter.repository_name }}/tests/conftest.py b/{{ cookiecutter.repository_name }}/tests/conftest.py index fb4d158..a8d9ae1 100644 --- a/{{ cookiecutter.repository_name }}/tests/conftest.py +++ b/{{ cookiecutter.repository_name }}/tests/conftest.py @@ -7,9 +7,9 @@ import pytest from hydra import compose, initialize from hydra.core.hydra_config import HydraConfig +from lightning.pytorch import seed_everything from omegaconf import DictConfig, OmegaConf, open_dict from pytest import FixtureRequest, TempPathFactory -from pytorch_lightning import seed_everything from nn_core.serialization import NNCheckpointIO diff --git a/{{ cookiecutter.repository_name }}/tests/test_checkpoint.py b/{{ cookiecutter.repository_name }}/tests/test_checkpoint.py index 2f3fdac..74d1f4d 100644 --- a/{{ cookiecutter.repository_name }}/tests/test_checkpoint.py +++ b/{{ cookiecutter.repository_name }}/tests/test_checkpoint.py @@ -2,9 +2,9 @@ from pathlib import Path from typing import Any, Dict +from lightning.pytorch import LightningModule +from lightning.pytorch.core.saving import _load_state from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import _load_state from nn_core.serialization import NNCheckpointIO from tests.conftest import load_checkpoint