From f286d9f210824d6ea1563e789f49894b19c24f0e Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 21 Oct 2024 12:58:02 +0200 Subject: [PATCH] [`fix`] Prevent IndexError if output_hidden_states & ONNX (#3008) --- sentence_transformers/models/Transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 061098c37..fca50225a 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -352,8 +352,8 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]}) - if self.auto_model.config.output_hidden_states: - all_layer_idx = 2 + if self.auto_model.config.output_hidden_states and len(output_states) > 2: + all_layer_idx = 2 # I.e. after last_hidden_states and pooler_output if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states all_layer_idx = 1