diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f803de40a..12c619375 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,6 +63,10 @@ jobs: python -m pip install --upgrade pip python -m pip install '.[train, onnx, openvino, dev]' + - name: Install model2vec + run: python -m pip install model2vec + if: ${{ contains(fromJSON('["3.10", "3.11", "3.12"]'), matrix.python-version) }} + - name: Run unit tests run: | python -m pytest --durations 20 -sv tests/ diff --git a/examples/applications/embedding-quantization/semantic_search_usearch.py b/examples/applications/embedding-quantization/semantic_search_usearch.py index 03883a330..9af0e49f3 100644 --- a/examples/applications/embedding-quantization/semantic_search_usearch.py +++ b/examples/applications/embedding-quantization/semantic_search_usearch.py @@ -6,7 +6,7 @@ from sentence_transformers.quantization import quantize_embeddings, semantic_search_usearch # 1. Load the quora corpus with questions -dataset = load_dataset("quora", split="train").map( +dataset = load_dataset("quora", split="train", trust_remote_code=True).map( lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]}, batched=True, remove_columns=["questions", "is_duplicate"], @@ -26,7 +26,7 @@ # 4. Choose a target precision for the corpus embeddings corpus_precision = "binary" # Valid options are: "float32", "uint8", "int8", "ubinary", and "binary" -# But usearch only supports "float32", "int8", and "binary" +# But usearch only supports "float32", "int8", "binary" and "ubinary" # 5. Encode the corpus full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True) diff --git a/pyproject.toml b/pyproject.toml index 139510dcd..a8c3c56fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sentence-transformers" -version = "3.2.0" +version = "3.2.1" description = "State-of-the-Art Text Embeddings" license = { text = "Apache 2.0" } readme = "README.md" @@ -49,8 +49,8 @@ Repository = "https://github.com/UKPLab/sentence-transformers/" [project.optional-dependencies] train = ["datasets", "accelerate>=0.20.3"] -onnx = ["optimum[onnxruntime]>=1.23.0"] -onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.0"] +onnx = ["optimum[onnxruntime]>=1.23.1"] +onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.1"] openvino = ["optimum-intel[openvino]>=1.20.0"] dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov"] @@ -100,4 +100,4 @@ testpaths = [ addopts = "--strict-markers -m 'not slow'" markers = [ "slow: marks tests as slow" -] \ No newline at end of file +] diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 1a8cb2efb..4fd069f7d 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1718,10 +1718,10 @@ def _load_sbert_model( # Try to initialize the module with a lot of kwargs, but only if the module supports them # Otherwise we fall back to the load method - # try: - module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs) - # except TypeError: - # module = module_class.load(model_name_or_path) + try: + module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs) + except TypeError: + module = module_class.load(model_name_or_path) else: # Normalize does not require any files to be loaded if module_class == Normalize: diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py index d3d8b0741..084b36bb0 100644 --- a/sentence_transformers/__init__.py +++ b/sentence_transformers/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__version__ = "3.2.0" +__version__ = "3.2.1" __MODEL_HUB_ORGANIZATION__ = "sentence-transformers" import importlib diff --git a/sentence_transformers/backend.py b/sentence_transformers/backend.py index eef76352e..355f40d83 100644 --- a/sentence_transformers/backend.py +++ b/sentence_transformers/backend.py @@ -78,7 +78,9 @@ def export_optimized_onnx_model( or not isinstance(model[0], Transformer) or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction) ): - raise ValueError('The model must be a SentenceTransformer model loaded with `backend="onnx"`.') + raise ValueError( + 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.' + ) ort_model: ORTModelForFeatureExtraction = model[0].auto_model optimizer = ORTOptimizer.from_pretrained(ort_model) @@ -158,7 +160,9 @@ def export_dynamic_quantized_onnx_model( or not isinstance(model[0], Transformer) or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction) ): - raise ValueError('The model must be a SentenceTransformer model loaded with `backend="onnx"`.') + raise ValueError( + 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.' + ) ort_model: ORTModelForFeatureExtraction = model[0].auto_model quantizer = ORTQuantizer.from_pretrained(ort_model) diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index 5a99fa419..aa83c59e8 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -10,7 +10,7 @@ from torch.utils.checkpoint import get_device_states, set_device_states from sentence_transformers import SentenceTransformer -from sentence_transformers.models import Transformer +from sentence_transformers.models import StaticEmbedding, Transformer class RandContext: @@ -139,6 +139,11 @@ def __init__( trainer.train() """ super().__init__() + if isinstance(model[0], StaticEmbedding): + raise ValueError( + "CachedGISTEmbedLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. " + "Consider using GISTEmbedLoss instead." + ) self.model = model self.guide = guide self.temperature = temperature diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py index c1e7d67c1..9c787fe8b 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py @@ -10,6 +10,7 @@ from torch.utils.checkpoint import get_device_states, set_device_states from sentence_transformers import SentenceTransformer, util +from sentence_transformers.models import StaticEmbedding class RandContext: @@ -145,6 +146,12 @@ def __init__( trainer.train() """ super().__init__() + if isinstance(model[0], StaticEmbedding): + raise ValueError( + "CachedMultipleNegativesRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. " + "Consider using MultipleNegativesRankingLoss instead." + ) + self.model = model self.scale = scale self.similarity_fct = similarity_fct diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py index 83fe1e06f..ac82d133f 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py @@ -10,6 +10,7 @@ from sentence_transformers import SentenceTransformer, util from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext +from sentence_transformers.models import StaticEmbedding def _backward_hook( @@ -114,6 +115,12 @@ def __init__( - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf """ super().__init__() + if isinstance(model[0], StaticEmbedding): + raise ValueError( + "CachedMultipleNegativesSymmetricRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. " + "Consider using MultipleNegativesSymmetricRankingLoss instead." + ) + self.model = model self.scale = scale self.similarity_fct = similarity_fct diff --git a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py index bb1cf8bef..8f38342d7 100644 --- a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py +++ b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py @@ -7,6 +7,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel from sentence_transformers import SentenceTransformer +from sentence_transformers.models import StaticEmbedding logger = logging.getLogger(__name__) @@ -73,6 +74,12 @@ def __init__( ) """ super().__init__() + + if isinstance(model[0], StaticEmbedding): + raise ValueError( + "DenoisingAutoEncoderLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding." + ) + self.encoder = model # This will be the final model used during the inference time. self.tokenizer_encoder = model.tokenizer diff --git a/sentence_transformers/losses/GISTEmbedLoss.py b/sentence_transformers/losses/GISTEmbedLoss.py index f1bb833bd..51958da5e 100644 --- a/sentence_transformers/losses/GISTEmbedLoss.py +++ b/sentence_transformers/losses/GISTEmbedLoss.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn -from sentence_transformers.models import Transformer +from sentence_transformers.models import StaticEmbedding, Transformer from sentence_transformers.SentenceTransformer import SentenceTransformer @@ -91,6 +91,12 @@ def __init__( if self.must_retokenize: self.tokenizer = self.model.tokenizer + if isinstance(self.model[0], StaticEmbedding): + raise ValueError( + "If we must retokenize because the guide model has a different tokenizer, " + "then the Sentence Transformer model must not be based on a StaticEmbedding." + ) + def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor: return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0)) diff --git a/sentence_transformers/losses/Matryoshka2dLoss.py b/sentence_transformers/losses/Matryoshka2dLoss.py index 4b77b9c74..7c85884d5 100644 --- a/sentence_transformers/losses/Matryoshka2dLoss.py +++ b/sentence_transformers/losses/Matryoshka2dLoss.py @@ -95,21 +95,23 @@ def __init__( Example: :: - from sentence_transformers import SentenceTransformer, losses, InputExample - from torch.utils.data import DataLoader + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses + from datasets import Dataset model = SentenceTransformer("microsoft/mpnet-base") - train_examples = [ - InputExample(texts=['Anchor 1', 'Positive 1']), - InputExample(texts=['Anchor 2', 'Positive 2']), - ] - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32) - train_loss = losses.MultipleNegativesRankingLoss(model=model) - train_loss = losses.Matryoshka2dLoss(model, train_loss, [768, 512, 256, 128, 64]) - model.fit( - [(train_dataloader, train_loss)], - epochs=10, + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + loss = losses.MultipleNegativesRankingLoss(model) + loss = losses.Matryoshka2dLoss(model, loss, [768, 512, 256, 128, 64]) + + trainer = SentenceTransformerTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, ) + trainer.train() """ matryoshka_loss = MatryoshkaLoss( model, diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py index e6a18aac0..997e7be0b 100644 --- a/sentence_transformers/losses/MatryoshkaLoss.py +++ b/sentence_transformers/losses/MatryoshkaLoss.py @@ -101,21 +101,23 @@ def __init__( Example: :: - from sentence_transformers import SentenceTransformer, losses, InputExample - from torch.utils.data import DataLoader + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses + from datasets import Dataset model = SentenceTransformer("microsoft/mpnet-base") - train_examples = [ - InputExample(texts=['Anchor 1', 'Positive 1']), - InputExample(texts=['Anchor 2', 'Positive 2']), - ] - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32) - train_loss = losses.MultipleNegativesRankingLoss(model=model) - train_loss = losses.MatryoshkaLoss(model, train_loss, [768, 512, 256, 128, 64]) - model.fit( - [(train_dataloader, train_loss)], - epochs=10, + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + loss = losses.MultipleNegativesRankingLoss(model) + loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64]) + + trainer = SentenceTransformerTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, ) + trainer.train() """ super().__init__() self.model = model diff --git a/sentence_transformers/losses/MegaBatchMarginLoss.py b/sentence_transformers/losses/MegaBatchMarginLoss.py index a964eb726..22dbbe5ea 100644 --- a/sentence_transformers/losses/MegaBatchMarginLoss.py +++ b/sentence_transformers/losses/MegaBatchMarginLoss.py @@ -59,25 +59,30 @@ def __init__( Example: :: - from sentence_transformers import SentenceTransformer, InputExample, losses - from torch.utils.data import DataLoader + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments, SentenceTransformerTrainer, losses + from datasets import Dataset - model = SentenceTransformer('all-MiniLM-L6-v2') - - total_examples = 500 train_batch_size = 250 train_mini_batch_size = 32 - train_examples = [ - InputExample(texts=[f"This is sentence number {i}", f"This is sentence number {i+1}"]) for i in range(total_examples) - ] - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size) - train_loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size) - - model.fit( - [(train_dataloader, train_loss)], - epochs=10, + model = SentenceTransformer('all-MiniLM-L6-v2') + train_dataset = Dataset.from_dict({ + "anchor": [f"This is sentence number {i}" for i in range(500)], + "positive": [f"This is sentence number {i}" for i in range(1, 501)], + }) + loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size) + + args = SentenceTransformerTrainingArguments( + output_dir="output", + per_device_train_batch_size=train_batch_size, + ) + trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + loss=loss, ) + trainer.train() """ super().__init__() self.model = model diff --git a/sentence_transformers/model_card.py b/sentence_transformers/model_card.py index 99da35d96..ef279e11f 100644 --- a/sentence_transformers/model_card.py +++ b/sentence_transformers/model_card.py @@ -423,10 +423,15 @@ def set_widget_examples(self, dataset: Dataset | DatasetDict) -> None: columns = [ column for column, feature in dataset[dataset_name].features.items() - if isinstance(feature, Value) and feature.dtype == "string" and column != "dataset_name" + if isinstance(feature, Value) + and (feature.dtype == "string" or feature.dtype == "large_string") + and column != "dataset_name" ] str_dataset = dataset[dataset_name].select_columns(columns) dataset_size = len(str_dataset) + if dataset_size == 0: + continue + lengths = {} for idx, sample in enumerate( str_dataset.select(random.sample(range(dataset_size), k=min(num_samples_to_check, dataset_size))) diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py index de69285b2..fae3756e2 100644 --- a/sentence_transformers/models/StaticEmbedding.py +++ b/sentence_transformers/models/StaticEmbedding.py @@ -159,9 +159,11 @@ def from_distillation( """ try: - from model2vec import distill + from model2vec.distill import distill except ImportError: - raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`") + raise ImportError( + "To use this method, please install the `model2vec` package: `pip install model2vec[distill]`" + ) device = get_device_name() static_model = distill( @@ -172,7 +174,10 @@ def from_distillation( apply_zipf=apply_zipf, use_subword=use_subword, ) - embedding_weights = static_model.embedding.weight + if isinstance(static_model.embedding, np.ndarray): + embedding_weights = torch.from_numpy(static_model.embedding) + else: + embedding_weights = static_model.embedding.weight tokenizer: Tokenizer = static_model.tokenizer return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name) @@ -200,7 +205,10 @@ def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding: raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`") static_model = StaticModel.from_pretrained(model_id_or_path) - embedding_weights = static_model.embedding.weight + if isinstance(static_model.embedding, np.ndarray): + embedding_weights = torch.from_numpy(static_model.embedding) + else: + embedding_weights = static_model.embedding.weight tokenizer: Tokenizer = static_model.tokenizer return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 7592278bf..fca50225a 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -155,7 +155,7 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_ar else: model_args["ov_config"] = {} - # Either load an exported model, or export the model to ONNX + # Either load an exported model, or export the model to OpenVINO self.auto_model: OVModelForFeatureExtraction = OVModelForFeatureExtraction.from_pretrained( model_name_or_path, config=config, @@ -352,8 +352,8 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]}) - if self.auto_model.config.output_hidden_states: - all_layer_idx = 2 + if self.auto_model.config.output_hidden_states and len(output_states) > 2: + all_layer_idx = 2 # I.e. after last_hidden_states and pooler_output if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states all_layer_idx = 1 diff --git a/sentence_transformers/quantization.py b/sentence_transformers/quantization.py index 37402cae7..aa5be00f0 100644 --- a/sentence_transformers/quantization.py +++ b/sentence_transformers/quantization.py @@ -216,8 +216,8 @@ def semantic_search_usearch( `corpus_embeddings` or `corpus_index` should be used, not both. corpus_precision: Precision of the corpus embeddings. The - options are "float32", "int8", or "binary". Default is - "float32". + options are "float32", "int8", "ubinary" or "binary". Default + is "float32". top_k: Number of top results to retrieve. Default is 10. ranges: Ranges for quantization of embeddings. This is only used for int8 quantization, where the ranges refers to the @@ -263,8 +263,8 @@ def semantic_search_usearch( raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.") if corpus_embeddings is None and corpus_index is None: raise ValueError("Either corpus_embeddings or corpus_index should be used.") - if corpus_precision not in ["float32", "int8", "binary"]: - raise ValueError('corpus_precision must be "float32", "int8", or "binary" for usearch') + if corpus_precision not in ["float32", "int8", "ubinary", "binary"]: + raise ValueError('corpus_precision must be "float32", "int8", "ubinary", "binary" for usearch') # If corpus_index is not provided, create a new index if corpus_index is None: @@ -284,6 +284,12 @@ def semantic_search_usearch( corpus_index = Index( ndim=corpus_embeddings.shape[1], metric="hamming", + dtype="i8", + ) + elif corpus_precision == "ubinary": + corpus_index = Index( + ndim=corpus_embeddings.shape[1] * 8, + metric="hamming", dtype="b1", ) corpus_index.add(np.arange(len(corpus_embeddings)), corpus_embeddings) @@ -331,7 +337,7 @@ def semantic_search_usearch( if rescore_embeddings is not None: top_k_embeddings = np.array([corpus_index.get(query_indices) for query_indices in indices]) # If the corpus precision is binary, we need to unpack the bits - if corpus_precision == "binary": + if corpus_precision in ("ubinary", "binary"): top_k_embeddings = np.unpackbits(top_k_embeddings.astype(np.uint8), axis=-1) top_k_embeddings = top_k_embeddings.astype(int) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 97b0f13df..50115f6b5 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -2,7 +2,6 @@ import logging import os -import warnings from collections import OrderedDict from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable @@ -156,14 +155,19 @@ def __init__( raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") else: if model_init is not None: - warnings.warn( + logger.warning( "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" - " overwrite your model when calling the `train` method. This will become a fatal error in the next" - " release.", - FutureWarning, + " overwrite your model when calling the `train` method." ) self.model_init = model_init + if compute_metrics is not None: + logger.warning( + "`compute_metrics` is currently not compatible with the SentenceTransformerTrainer. Please use the " + "`evaluator` argument instead for detailed evaluation metrics, or the `eval_dataset` argument for " + "the evaluation loss." + ) + # Get a dictionary of the default training arguments, so we can determine which arguments have been changed # for the model card default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict() diff --git a/tests/models/test_static_embedding.py b/tests/models/test_static_embedding.py new file mode 100644 index 000000000..75041d852 --- /dev/null +++ b/tests/models/test_static_embedding.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +from tokenizers import Tokenizer + +from sentence_transformers.models.StaticEmbedding import StaticEmbedding + +try: + import model2vec +except ImportError: + model2vec = None + +skip_if_no_model2vec = pytest.mark.skipif(model2vec is None, reason="The model2vec library is not installed.") + + +@pytest.fixture +def tokenizer() -> Tokenizer: + return Tokenizer.from_pretrained("bert-base-uncased") + + +@pytest.fixture +def embedding_weights(): + return np.random.rand(30522, 768) + + +@pytest.fixture +def static_embedding(tokenizer: Tokenizer, embedding_weights) -> StaticEmbedding: + return StaticEmbedding(tokenizer, embedding_weights=embedding_weights) + + +def test_initialization_with_embedding_weights(tokenizer: Tokenizer, embedding_weights) -> None: + model = StaticEmbedding(tokenizer, embedding_weights=embedding_weights) + assert model.embedding.weight.shape == (30522, 768) + + +def test_initialization_with_embedding_dim(tokenizer: Tokenizer) -> None: + model = StaticEmbedding(tokenizer, embedding_dim=768) + assert model.embedding.weight.shape == (30522, 768) + + +def test_tokenize(static_embedding: StaticEmbedding) -> None: + texts = ["Hello world!", "How are you?"] + tokens = static_embedding.tokenize(texts) + assert "input_ids" in tokens + assert "offsets" in tokens + + +def test_forward(static_embedding: StaticEmbedding) -> None: + texts = ["Hello world!", "How are you?"] + tokens = static_embedding.tokenize(texts) + output = static_embedding(tokens) + assert "sentence_embedding" in output + + +def test_save_and_load(tmp_path: Path, static_embedding: StaticEmbedding) -> None: + save_dir = tmp_path / "model" + save_dir.mkdir() + static_embedding.save(str(save_dir)) + + loaded_model = StaticEmbedding.load(str(save_dir)) + assert loaded_model.embedding.weight.shape == static_embedding.embedding.weight.shape + + +@skip_if_no_model2vec() +def test_from_distillation() -> None: + model = StaticEmbedding.from_distillation("sentence-transformers-testing/stsb-bert-tiny-safetensors", pca_dims=32) + assert model.embedding.weight.shape == (29528, 32) + + +@skip_if_no_model2vec() +def test_from_model2vec() -> None: + model = StaticEmbedding.from_model2vec("minishlab/M2V_base_output") + assert model.embedding.weight.shape == (29528, 256)