Skip to content

Commit

Permalink
Update training pytests to reduce total time
Browse files: browse the repository at this point in the history
  • Loading branch information
jiminha committed Jan 22, 2025
1 parent c5d679d commit bcbcb2a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions tests/baselines/Llama_3_2_11B_Vision_Instruct.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"gaudi2": {
"image2text_lora_finetune": {
"num_train_epochs": 2,
"num_train_epochs": 1,
"eval_batch_size": 4,
"distribution": {
"multi_card": {
"learning_rate": 5e-5,
"train_batch_size": 2,
"train_runtime": 470,
"train_runtime": 350,
"train_samples_per_second": 20.48,
"eval_accuracy": 0.6,
"extra_arguments": [
Expand Down
8 changes: 4 additions & 4 deletions tests/baselines/falcon_40b.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
{
"gaudi2": {
"timdettmers/openassistant-guanaco": {
"num_train_epochs": 3,
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"multi_card": {
"learning_rate": 4e-4,
"train_batch_size": 1,
"perplexity": 4.0893,
"train_runtime": 931.1213,
"train_runtime": 360,
"train_samples_per_second": 28.162,
"extra_arguments": [
"--bf16",
Expand Down | Expand Up | @@ -36,14 +36,14 @@
}
},
"mamamiya405/finred": {
"num_train_epochs": 3,
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"multi_card": {
"learning_rate": 4e-4,
"train_batch_size": 1,
"perplexity": 4.0893,
"train_runtime": 1170,
"train_runtime": 470,
"train_samples_per_second": 28.162,
"extra_arguments": [
"--bf16",
Expand Down
4 changes: 2 additions & 2 deletions tests/baselines/gpt_neox_20b.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
{
"gaudi2": {
"wikitext": {
"num_train_epochs": 2,
"num_train_epochs": 1,
"eval_batch_size": 2,
"distribution": {
"deepspeed": {
"learning_rate": 5e-5,
"train_batch_size": 2,
"perplexity": 8.169664686471043,
"train_runtime": 781.7156,
"train_runtime": 445,
"train_samples_per_second": 7.328,
"extra_arguments": [
"--dataset_config_name wikitext-2-raw-v1",
Expand Down
18 changes: 9 additions & 9 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,30 @@
("bert-large-uncased-whole-word-masking", "Habana/bert-large-uncased-whole-word-masking"),
],
"roberta": [
("roberta-base", "Habana/roberta-base"),
#("roberta-base", "Habana/roberta-base"),
("roberta-large", "Habana/roberta-large"),
],
"albert": [
("albert-large-v2", "Habana/albert-large-v2"),
("albert-xxlarge-v1", "Habana/albert-xxlarge-v1"),
#("albert-xxlarge-v1", "Habana/albert-xxlarge-v1"),
],
"distilbert": [
("distilbert-base-uncased", "Habana/distilbert-base-uncased"),
#("distilbert-base-uncased", "Habana/distilbert-base-uncased"),
],
"gpt2": [
("gpt2", "Habana/gpt2"),
#("gpt2", "Habana/gpt2"),
("gpt2-xl", "Habana/gpt2"),
],
"t5": [
("t5-small", "Habana/t5"),
#("t5-small", "Habana/t5"),
("google/flan-t5-xxl", "Habana/t5"),
],
"vit": [
("google/vit-base-patch16-224-in21k", "Habana/vit"),
],
"wav2vec2": [
("facebook/wav2vec2-base", "Habana/wav2vec2"),
("facebook/wav2vec2-large-lv60", "Habana/wav2vec2"),
#("facebook/wav2vec2-large-lv60", "Habana/wav2vec2"),
],
"swin": [("microsoft/swin-base-patch4-window7-224-in22k", "Habana/swin")],
"clip": [("./clip-roberta", "Habana/clip")],
Expand All @@ -68,10 +68,10 @@
}

MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [
"bert",
#"bert",
"roberta",
"albert",
"distilbert",
#"albert",
#"distilbert",
]

# Only BERT has been officially validated for sequence classification
Expand Down

0 comments on commit bcbcb2a

Please sign in to comment.