Skip to content

Commit

Permalink
fix extra_match low if batch_size > 1 (#2595)
Browse files Browse the repository at this point in the history
* fix extra_match low if batch_size > 1

Signed-off-by: Wang, Yi A <[email protected]>

* add sorting to logprobs

* nit

---------

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Baber <[email protected]>
  • Loading branch information
sywangyi and baberabb authored Dec 25, 2024
1 parent 932e8f9 commit 59f9ad4
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions lm_eval/models/openai_completions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from functools import cached_property
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
Expand Down Expand Up @@ -68,7 +69,9 @@ def parse_logprobs(
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choice, ctxlen in zip(out["choices"], ctxlens):
for choice, ctxlen in zip(
sorted(out["choices"], key=itemgetter("index")), ctxlens
):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
Expand All @@ -87,8 +90,10 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["text"])
tmp[choices["index"]] = choices["text"]
res = res + tmp
return res

@property
Expand Down Expand Up @@ -157,8 +162,10 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["message"]["content"])
tmp[choices["index"]] = choices["message"]["content"]
res = res + tmp
return res

def tok_encode(
Expand Down

0 comments on commit 59f9ad4

Please sign in to comment.