From 51a0f1032bcf382e0a194bdaab9477e6a00387bc Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 12 Apr 2024 12:01:31 -0500 Subject: [PATCH] Loss function docstring --- generic_trainer/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index 1f4711a..aaf9f55 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -1,6 +1,6 @@ import collections import dataclasses -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import json import os @@ -237,7 +237,7 @@ class TrainingConfig(Config): post_validation_epoch_hook: Any = None """A Callable that can be called after each validation epoch.""" - loss_function: Any = torch.nn.CrossEntropyLoss() + loss_function: Union[Callable, list[Callable, ...]] = torch.nn.CrossEntropyLoss() """ The loss function. This could be either a Callable (like torch.nn.L1Loss) or a list of Callables. When it is a list, its length should be at least `len(pred_names)` and the Callables are respectively applied