Save function fix #329
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
LGTM, thanks for the fix!
I just left some very minor nits, mostly because I want to understand things better.
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        state_dict = self.model.state_dict()
        xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
    self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
    with safe_save_function_patcher:
What is `self.model` in this case? If `unwrap_model(self.model)` is not an instance of `PreTrainedModel`, can it still call `save_pretrained()`?
`self.model` is the model we are currently training and is an instance of `PreTrainedModel` in this case.
No, if `unwrap_model(self.model)` is not an instance of `PreTrainedModel`, we cannot call `save_pretrained()`.
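For context, here is a minimal sketch of the branching described above. The helper name `save_model_sketch` is hypothetical; the attributes (`trainer.model`, `trainer.args.should_save`) mirror the Trainer code shown in the diff, and this is not the PR's exact implementation.

```python
import os

import torch_xla.core.xla_model as xm
from transformers import PreTrainedModel
from transformers.modeling_utils import unwrap_model
from transformers.utils import WEIGHTS_NAME


def save_model_sketch(trainer, output_dir: str):
    # Hypothetical helper: if the (possibly wrapped) model is a PreTrainedModel,
    # save_pretrained() can be used; otherwise only the raw state dict is saved.
    if isinstance(unwrap_model(trainer.model), PreTrainedModel):
        unwrap_model(trainer.model).save_pretrained(
            output_dir,
            is_main_process=trainer.args.should_save,
            save_function=xm.save,
        )
    else:
        # No config / serialization metadata available, so fall back to the state
        # dict, written with the XLA-aware save function.
        xm.save(trainer.model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
```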
@@ -133,14 +133,22 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru


def is_private_repo(repo_id: str) -> bool:
    """Tells whether `repo_id` is private."""
Does this function check whether a repo is private to the general public? If so, why do we need to check in advance whether it is private for a particular user?
This function checks whether the repo is private to the general public but accessible to the current user.
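For illustration only, a hypothetical sketch of such a check built on `huggingface_hub` (`is_private_repo_sketch` is an assumed name; the PR's actual implementation may differ):

```python
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError


def is_private_repo_sketch(repo_id: str) -> bool:
    # Hypothetical sketch: the repo must be visible to the authenticated user...
    HfApi().repo_info(repo_id=repo_id)
    try:
        # ...but if an anonymous request (token=False) cannot find it,
        # the repo is private to the general public.
        HfApi().repo_info(repo_id=repo_id, token=False)
    except RepositoryNotFoundError:
        return True
    return False
```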
    global_master: bool = False,
):
    """
    Torch XLA compatible implementation of `safetensors.torch.save_file`.
Could you elaborate on those arguments a bit? I am not familiar with what `master_only` and `global_master` indicate.
So those two parameters are related to distributed training.
Basically, when `master_only` is `True`, only the master rank saves the file instead of all ranks, and `global_master` controls whether only the global master (across multiple nodes) should save, or the master of each node.
More information here
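For intuition, here is a hypothetical sketch of what an XLA-compatible `save_file` with these two parameters could look like; the function name and body are assumptions based on the discussion above, not the exact code from this PR.

```python
import safetensors.torch
import torch_xla.core.xla_model as xm


def torch_xla_save_file_sketch(tensors, filename, metadata=None, master_only=True, global_master=False):
    # Move tensors off the XLA device before serializing them.
    cpu_tensors = {name: tensor.cpu() for name, tensor in tensors.items()}
    if master_only:
        # With global_master=True, only the single global master across all nodes
        # writes the file; otherwise the local master of each node writes it.
        should_write = xm.is_master_ordinal(local=not global_master)
    else:
        should_write = True
    if should_write:
        safetensors.torch.save_file(cpu_tensors, filename, metadata=metadata)
    # Synchronize so that no rank moves on before the checkpoint exists on disk.
    xm.rendezvous("save_file_sketch")
```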
        save_function=xm.save,
    )
with safe_save_function_patcher:
    unwrap_model(self.model).save_pretrained(
- Could you explain what you want to avoid/achieve here?
- Could you explain what the exact chain of calls will be here?
To be more specific, why couldn't you just continue with the typical Transformers paradigm: pass the correct flag to tag the main process and provide the specific save function? (Looking at the patch, it is not clear why it is not redundant.)
So basically the chain of calls here is:
- We enter a context manager that patches the `transformers.modeling_utils.safe_save_file` function, which is `safetensors.torch.save_file`, to be my `torch_xla`-compatible version of this function.
- The model is unwrapped (it's not really important here, it's related to the Trainer and how sometimes the model is wrapped for features I am not really sure we support).
- The model is saved, and since the last Transformers release, it will save the checkpoint using `safetensors`. The issue is that when `safetensors` is used, the `save_function` parameter is ignored. That is why we do it like that instead of simply passing `torch_xla_safe_save_file` as the value for `save_function`.
- We exit the context manager, and everything that was patched is restored to its original value.
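To make the patching concrete, here is a minimal sketch of such a context manager; the name `patch_safe_save_file_sketch` and the exact mechanics are assumptions based on the explanation above, not the PR's code.

```python
from contextlib import contextmanager

import transformers.modeling_utils


@contextmanager
def patch_safe_save_file_sketch(xla_safe_save_file):
    # Temporarily swap transformers.modeling_utils.safe_save_file (an alias of
    # safetensors.torch.save_file) for an XLA-compatible version.
    original = transformers.modeling_utils.safe_save_file
    transformers.modeling_utils.safe_save_file = xla_safe_save_file
    try:
        # Inside this block, save_pretrained() ends up calling the patched function,
        # even though its save_function argument is ignored when safetensors is used.
        yield
    finally:
        # Always restore the original function, even if saving fails.
        transformers.modeling_utils.safe_save_file = original
```

Used roughly as `with patch_safe_save_file_sketch(torch_xla_safe_save_file): unwrap_model(self.model).save_pretrained(output_dir)`, which matches the chain of calls described above.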
Thanks for the explanation!
LGTM, thanks!
This PR aims at fixing the save function.
Since `transformers==4.35.0`, the default saving function being used is the `safetensors.torch.save_file` one. This PR patches the saving mechanism to account for that.