diff --git a/sentence_transformers/losses/AdaptiveLayerLoss.py b/sentence_transformers/losses/AdaptiveLayerLoss.py index 8eb9cf4ed..82931aede 100644 --- a/sentence_transformers/losses/AdaptiveLayerLoss.py +++ b/sentence_transformers/losses/AdaptiveLayerLoss.py @@ -12,6 +12,9 @@ from sentence_transformers import SentenceTransformer from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss +from sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss import ( + CachedMultipleNegativesSymmetricRankingLoss, +) from sentence_transformers.models import Transformer @@ -149,7 +152,8 @@ def __init__( - `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_ Requirements: - 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`. + 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`, + :class:`CachedMultipleNegativesSymmetricRankingLoss`, or :class:`CachedGISTEmbedLoss`. 
Inputs: +---------------------------------------+--------+ @@ -192,10 +196,11 @@ def __init__( self.kl_div_weight = kl_div_weight self.kl_temperature = kl_temperature assert isinstance(self.model[0], Transformer) - if isinstance(loss, CachedMultipleNegativesRankingLoss): - warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2) - if isinstance(loss, CachedGISTEmbedLoss): - warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2) + if isinstance( + loss, + (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss), + ): + warnings.warn(f"AdaptiveLayerLoss is not compatible with {loss.__class__.__name__}.", stacklevel=2) def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: # Decorate the forward function of the transformer to cache the embeddings of all layers diff --git a/sentence_transformers/losses/Matryoshka2dLoss.py b/sentence_transformers/losses/Matryoshka2dLoss.py index 7c85884d5..dc77fe052 100644 --- a/sentence_transformers/losses/Matryoshka2dLoss.py +++ b/sentence_transformers/losses/Matryoshka2dLoss.py @@ -79,7 +79,8 @@ def __init__( - `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_ Requirements: - 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`. + 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`, + :class:`CachedMultipleNegativesSymmetricRankingLoss`, or :class:`CachedGISTEmbedLoss`. Inputs: +---------------------------------------+--------+ diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py index 4dfe0acb4..7f11400e4 100644 --- a/sentence_transformers/losses/MatryoshkaLoss.py +++ b/sentence_transformers/losses/MatryoshkaLoss.py @@ -124,6 +124,9 @@ def __init__( different embedding dimensions. 
This is useful for when you want to train a model where users have the option to lower the embedding dimension to improve their embedding comparison speed and costs. + This loss is also compatible with the Cached... losses, which are in-batch negative losses that allow for + higher batch sizes. The higher batch sizes allow for more negatives, and often result in a stronger model. + Args: model: SentenceTransformer model loss: The loss function to be used, e.g. @@ -143,9 +146,6 @@ def __init__( - The concept was introduced in this paper: https://arxiv.org/abs/2205.13147 - `Matryoshka Embeddings <../../examples/training/matryoshka/README.html>`_ - Requirements: - 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`. - Inputs: +---------------------------------------+--------+ | Texts | Labels |