
Save function fix #329

Merged · 8 commits merged into main from fix_long_save on Dec 18, 2023
Conversation

michaelbenayoun (Member):

This PR aims to fix the save function.

Since transformers==4.35.0, the default saving function is safetensors.torch.save_file. This PR patches the saving mechanism to account for that.
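
For context, a minimal illustration of the behavior change (assumed usage, not code from this PR): once safetensors serialization became the default, the `save_function` argument of `save_pretrained` is ignored on the safetensors path, so routing saving through `xm.save` this way no longer takes effect.

```python
# Illustration only: assumed usage, not code from this PR.
import torch_xla.core.xla_model as xm
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# Before transformers 4.35.0, `save_function=xm.save` routed serialization
# through torch_xla. With safetensors now the default, the weights are written
# by `safetensors.torch.save_file` and `save_function` is silently ignored.
model.save_pretrained("./checkpoint", save_function=xm.save)
```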

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@michaelbenayoun changed the title from "Save function taking a lot of time" to "Save function fix" on Nov 27, 2023
@michaelbenayoun marked this pull request as ready for review on November 27, 2023 at 16:31
JingyaHuang (Collaborator) left a comment:


LGTM, thanks for the fix!

I just left some very minor nits, mostly because I want to understand better.

```diff
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+            with safe_save_function_patcher:
```
Collaborator:

What is self.model in this case? If unwrap_model(self.model) is not an instance of PreTrainedModel, can save_pretrained() still be applied?

michaelbenayoun (Member Author):

self.model is the model we are currently training, and it is an instance of PreTrainedModel in this case.

No, if unwrap_model(self.model) is not an instance of PreTrainedModel, we cannot call save_pretrained().

@@ -133,14 +133,22 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru


```python
def is_private_repo(repo_id: str) -> bool:
    """Tells whether `repo_id` is private."""
```
Collaborator:

Does this function check whether a repo is private to the general public? If so, why do we need to check in advance whether it's private for a particular user?

michaelbenayoun (Member Author):

This function checks whether the repo is private to the general public but visible to the current user.
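
A minimal sketch of such a check with `huggingface_hub` (assumed; not necessarily the exact implementation in `optimum/neuron/utils/cache_utils.py`): the repo is queried with the user's credentials, then anonymously; an anonymous failure means it is hidden from the general public even though the current user can access it.

```python
# A sketch, not necessarily the repo's exact implementation.
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError


def is_private_repo(repo_id: str) -> bool:
    """Tells whether `repo_id` is private (hidden from the public, visible to the user)."""
    api = HfApi()
    # Succeeds with the current user's cached token, proving the user can see it.
    api.list_repo_files(repo_id=repo_id)
    try:
        # `token=False` forces an anonymous request; a private repo raises
        # RepositoryNotFoundError for the general public.
        api.list_repo_files(repo_id=repo_id, token=False)
    except RepositoryNotFoundError:
        return True
    return False
```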

optimum/neuron/utils/cache_utils.py (outdated; resolved)
```python
    global_master: bool = False,
):
    """
    Torch XLA compatible implementation of `safetensors.torch.save_file`.
```
Collaborator:

Could you elaborate on those arguments a bit? I am not familiar with what master_only and global_master indicate.

michaelbenayoun (Member Author):

So those two parameters are related to distributed training.
Basically, when master_only is True, only the master rank saves the file instead of all ranks, and global_master controls whether only the global master (across multiple nodes) should save, or the master of each node.
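
A minimal sketch of how such a function can look, assuming torch_xla's `xm.is_master_ordinal` semantics; it mirrors the snippet above but is illustrative, not necessarily the PR's exact code:

```python
# A sketch, not necessarily the PR's exact implementation.
import os
from typing import Dict, Optional, Union

import torch
import torch_xla.core.xla_model as xm
from safetensors.torch import save_file


def torch_xla_safe_save_file(
    tensors: Dict[str, torch.Tensor],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
    master_only: bool = True,
    global_master: bool = False,
):
    """
    Torch XLA compatible implementation of `safetensors.torch.save_file`.
    """
    # local=False checks the global master (rank 0 across all nodes);
    # local=True checks the master of each node.
    should_write = not master_only or xm.is_master_ordinal(local=not global_master)
    if should_write:
        # XLA tensors live on device; move them to CPU before serialization.
        cpu_tensors = {name: tensor.cpu() for name, tensor in tensors.items()}
        save_file(cpu_tensors, filename, metadata=metadata)
```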

michaelbenayoun (Member Author):

More information here

```diff
-                    save_function=xm.save,
-                )
+                with safe_save_function_patcher:
+                    unwrap_model(self.model).save_pretrained(
```
Collaborator:

  1. Could you explain what you want to avoid/achieve here?
  2. Could you explain what the exact chain of calls will be here?

Collaborator:

To be more specific, why couldn't you just continue with the typical transformers paradigm: pass the correct flag to tag the main process and provide the specific save function? Looking at the patch, it is not clear why it is not redundant.

michaelbenayoun (Member Author), Dec 18, 2023:

So basically the chain of calls here is:

  1. We enter a context manager that patches the transformers.modeling_utils.safe_save_file function (which is safetensors.torch.save_file) to be my torch_xla-compatible version of this function (see the sketch after this list).
  2. The model is unwrapped (not really important here; it's related to the Trainer and how the model is sometimes wrapped for features I am not sure we support).
  3. The model is saved, and since the last Transformers release the checkpoint is saved using safetensors. The issue is that when safetensors is used, the save_function parameter is ignored. That is why we do it this way instead of simply passing torch_xla_safe_save_file as the value of save_function.
  4. We exit the context manager, and everything that was patched is restored to its original value.
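
A minimal sketch of a patcher matching that description, under the hypothetical name `patch_safe_save_file`; the PR's actual `safe_save_function_patcher` may be built differently (e.g. on a generic patching utility):

```python
# A sketch matching the chain of calls described above; `patch_safe_save_file`
# is a hypothetical name, not necessarily the PR's `safe_save_function_patcher`.
from contextlib import contextmanager

import transformers.modeling_utils


@contextmanager
def patch_safe_save_file(replacement):
    """Temporarily replace `transformers.modeling_utils.safe_save_file` (step 1)."""
    original = transformers.modeling_utils.safe_save_file
    transformers.modeling_utils.safe_save_file = replacement
    try:
        yield
    finally:
        # Step 4: restore everything that was patched to its original value.
        transformers.modeling_utils.safe_save_file = original
```

Used as `with patch_safe_save_file(torch_xla_safe_save_file): unwrap_model(self.model).save_pretrained(...)`, the safetensors code path inside save_pretrained ends up calling the XLA-compatible writer even though save_function is ignored there.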

Collaborator:

Thanks for the explanation!

dacorvo (Collaborator) left a comment:

LGTM, thanks!

Collaborator:

Thanks for the explanation!

@michaelbenayoun merged commit c2367ae into main on Dec 18, 2023
7 checks passed
@michaelbenayoun deleted the fix_long_save branch on December 18, 2023 at 16:04