Fix BoN (#32)
lewtun authored Jan 3, 2025
1 parent bc42252 commit 179c75b
Showing 5 changed files with 27 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -160,3 +160,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data/
19 changes: 12 additions & 7 deletions recipes/README.md
@@ -1,30 +1,33 @@
# Recipes

Here we include yaml configs to run the three test time compute variants detailed in the blog post:
Here we include YAML configs to run the three test time compute variants detailed in the blog post:

- Best of N: [`recipes/Llama-3.2-1B-Instruct/best_of_n.yaml`](Llama-3.2-1B-Instruct/best_of_n.yaml)
- Beam Search: [`recipes/Llama-3.2-1B-Instruct/beam_search.yaml`](Llama-3.2-1B-Instruct/beam_search.yaml)
- Diverse Verifier Beam Search (DVTS): [`recipes/Llama-3.2-1B-Instruct/dvts.yaml`](Llama-3.2-1B-Instruct/dvts.yaml)
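
For orientation, here is a hypothetical sketch of the kind of fields these recipes expose. The field names (`approach`, `model_path`, `n`, `hub_dataset_id`, `output_dir`) are inferred from flags and `Config` attributes that appear elsewhere in this commit; treat the values as illustrative and check the actual YAML files for the authoritative set.

```yaml
# Hypothetical recipe sketch -- field names inferred from this commit, values illustrative.
approach: best_of_n
model_path: meta-llama/Llama-3.2-1B-Instruct
n: 256
hub_dataset_id: <YOUR_ORG>/Llama-3.2-1B-Instruct-bon-completions
output_dir: null  # when unset, completions are written under data/<model_path>
```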

Each approach can be launched by specifying the associated yaml file:
Each approach can be launched by specifying the associated YAML file:

```
python scripts/test_time_compute.py <YAML_CONFIG>
# for example:
python scripts/test_time_compute.py recipes/Llama-3.2-1B-Instruct/best_of_n.yaml
```


The configs shown here are for the `Llama-3.2-1B-Instruct` model, you can override the size of the llama model evaluated by including it in the command line arguments:
The configs shown here are for the `Llama-3.2-1B-Instruct` model; you can override the choice of model by including it in the command line arguments:

```shell
python scripts/test_time_compute.py recipes/Llama-3.2-1B-Instruct/best_of_n.yaml --model_path=Llama-3.2-3B-Instruct --hub_dataset_id=<YOUR_ORG>/Llama-3.2-3B-Instruct-bon-completions
```

> [!WARNING]
> __best of n__ and __DVTS__ can be run at `n=256` and then subsampled for get complarable solutions for running at `n=4,16,64` etc. The beam search variant **must** be run at the correct `n` in order to make a valid comparison.
> __best of n__ and __DVTS__ can be run at `n=256` and then subsampled to get comparable solutions for running at `n=4,16,64` etc. The beam search variant **must** be run at the correct `n` in order to make a valid comparison.
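
As a minimal standalone sketch (not the repository's code) of what that subsampling can look like for best-of-n: given the pool of `n=256` completions and their verifier scores for one problem, take the first `k` and keep the highest-scoring one.

```python
from typing import List


def best_of_k(completions: List[str], scores: List[float], k: int) -> str:
    """Pick the highest-scoring completion among the first k of a larger pool."""
    subset = list(zip(completions[:k], scores[:k]))
    return max(subset, key=lambda pair: pair[1])[0]


# With 256 scored completions per problem, the same pool can be reused for
# k = 4, 16, 64, 256 without generating anything new, e.g.:
# answers = {k: best_of_k(pool_completions, pool_scores, k) for k in (4, 16, 64, 256)}
```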

## Reproducing results on the MATH-500 dataset

## Reproducing results on the MATH-500 dataset:
We provide slurm scripts to configure array jobs to parallelize the evaluation of the three methods:
We provide Slurm scripts to configure array jobs to parallelize the evaluation of the three methods:


```shell
@@ -41,11 +44,13 @@ sbatch recipes/launch_array.slurm recipes/Llama-3.2-1B-Instruct/dvts.yaml --n=16
By default this will shard the dataset into 20 chunks in order to run the algorithm in parallel; the dataset will be pushed to the Hugging Face Hub.

The full dataset can then be reconstructed with:

```shell
python scripts/merge_chunks.py --dataset_name=<YOUR_ORG>/Llama-3.2-1B-Instruct-bon-completions
```
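
Once merged and pushed, the completions can be pulled back with the `datasets` library (same placeholder dataset name as above; the `train` split is an assumption):

```python
from datasets import load_dataset

completions = load_dataset("<YOUR_ORG>/Llama-3.2-1B-Instruct-bon-completions", split="train")
print(completions)
```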

## Extracting the MATH-500 accuracy numbers:
## Extracting the MATH-500 accuracy numbers

To get the final numbers for the evaluations, we use the [Qwen2.5-Math evaluation repo](https://github.com/QwenLM/Qwen2.5-Math); their codebase is well documented, so please refer to their instructions.


2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
"pebble", # for parallel processing
"latex2sympy2==1.9.1", # for MATH answer parsing
"word2number", # for MATH answer parsing
"transformers>=4.47.0",
"transformers>=4.47.0",
"fastapi",
]

9 changes: 6 additions & 3 deletions src/sal/search/best_of_n.py
@@ -29,17 +29,20 @@ def best_of_n(x, config: Config, llm: LLM, prm: PRM):
{"role": "system", "content": config.system_prompt},
{"role": "user", "content": prompt},
]
for prompt in x["problem"] * config.n
for prompt in x["problem"]
]
tokenizer = llm.get_tokenizer()
# TODO: set the augmented template from a file
if config.custom_chat_template is not None:
tokenizer.chat_template = config.custom_chat_template
templated_convs = tokenizer.apply_chat_template(
convs,
tokenize=False,
convs, tokenize=False, add_generation_prompt=True
)

# Duplicate convs to generate config.n completions per prompt so we can do continuous batching
# This makes [p1, p2, p3, p4] become [p1, p1, p2, p2, p3, p3, p4, p4] for e.g. config.n=2
templated_convs = [c for conv in templated_convs for c in [conv] * config.n]

# Initialize empty lists for completions and completion tokens
completions = [[] for _ in range(len(x["problem"]))]
completion_tokens = [[] for _ in range(len(x["problem"]))]
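
The new comment above is the core of the fix: conversations are now duplicated after templating, one prompt at a time, instead of repeating the whole problem list (`x["problem"] * config.n`) before the conversations are built. A standalone sketch (not the repository's code) of the ordering difference:

```python
# Illustrates the two orderings for 4 prompts and n=2 completions per prompt.
convs = ["p1", "p2", "p3", "p4"]
n = 2

# New behaviour: each conversation is repeated n times, so per-prompt groups stay contiguous.
duplicated = [c for conv in convs for c in [conv] * n]
print(duplicated)  # ['p1', 'p1', 'p2', 'p2', 'p3', 'p3', 'p4', 'p4']

# Old behaviour: repeating the whole list interleaves prompts, which breaks any downstream
# grouping that assumes contiguous chunks of n completions per problem.
print(convs * n)  # ['p1', 'p2', 'p3', 'p4', 'p1', 'p2', 'p3', 'p4']
```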
8 changes: 6 additions & 2 deletions src/sal/utils/data.py
@@ -72,5 +72,9 @@ def save_dataset(dataset, config):
if config.output_dir is None:
config.output_dir = f"data/{config.model_path}"
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
dataset.to_json(f"{config.output_dir}/{config.approach}_completions.jsonl", lines=True)
logger.info(f"Saved completions to {config.output_dir}/{config.approach}_completions.jsonl")
dataset.to_json(
f"{config.output_dir}/{config.approach}_completions.jsonl", lines=True
)
logger.info(
f"Saved completions to {config.output_dir}/{config.approach}_completions.jsonl"
)
