Skip to content

Commit

Permalink
Loss function docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Apr 12, 2024
1 parent 948c66b commit 51a0f10
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import dataclasses
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import json
import os

Expand Down Expand Up @@ -237,7 +237,7 @@ class TrainingConfig(Config):
post_validation_epoch_hook: Any = None
"""A Callable that can be called after each validation epoch."""

loss_function: Any = torch.nn.CrossEntropyLoss()
loss_function: Union[Callable, list[Callable, ...]] = torch.nn.CrossEntropyLoss()
"""
The loss function. This could be either a Callable (like torch.nn.L1Loss) or a list of Callables.
When it is a list, its length should be at least `len(pred_names)` and the Callables are respectively applied
Expand Down

0 comments on commit 51a0f10

Please sign in to comment.