Save function fix #329
Changes from all commits
f2e7c99
f03e9b9
ca87dd4
aea8db1
36fcfc4
607ae21
158ff4e
ef9b15c
```diff
@@ -48,6 +48,7 @@
 from .utils import (
     DynamicPatch,
     ModelPatcher,
+    Patcher,
     is_torch_xla_available,
     patch_within_function,
 )
@@ -62,6 +63,7 @@
     prepare_environment_for_neuron,
     set_neuron_cc_optlevel_for_model,
     skip_first_batches,
+    torch_xla_safe_save_file,
 )
@@ -402,20 +404,27 @@ def _save_xla(self, output_dir: Optional[str] = None):
             xm.mark_step()
             parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
         else:
+            safe_save_function_patcher = Patcher(
+                [("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
+            )
             if not isinstance(self.model, PreTrainedModel):
                 if isinstance(unwrap_model(self.model), PreTrainedModel):
-                    unwrap_model(self.model).save_pretrained(
-                        output_dir,
-                        is_main_process=self.args.should_save,
-                        state_dict=self.model.state_dict(),
-                        save_function=xm.save,
-                    )
+                    with safe_save_function_patcher:
+                        unwrap_model(self.model).save_pretrained(
+                            output_dir,
+                            is_main_process=self.args.should_save,
+                            state_dict=self.model.state_dict(),
+                            save_function=xm.save,
+                        )
                 else:
                     logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                     state_dict = self.model.state_dict()
                     xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
             else:
-                self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+                with safe_save_function_patcher:
+                    self.model.save_pretrained(
+                        output_dir, is_main_process=self.args.should_save, save_function=xm.save
+                    )
 
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
```
```diff
@@ -133,14 +133,22 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru
 def is_private_repo(repo_id: str) -> bool:
     """Tells whether `repo_id` is private."""
     if _DISABLE_IS_PRIVATE_REPO_CHECK:
         return False
-    HfApi().list_repo_files(repo_id=repo_id, token=HfFolder.get_token())
-    private = False
     try:
-        HfApi().list_repo_files(repo_id=repo_id, token=False)
+        HfApi().model_info(repo_id=repo_id, token=HfFolder.get_token())
+        private_to_user = False
     except RepositoryNotFoundError:
-        private = True
+        private_to_user = True
+    if private_to_user:
+        private = True
+    else:
+        try:
+            HfApi().model_info(repo_id=repo_id, token=False)
+            private = False
+        except RepositoryNotFoundError:
+            private = True
     return private
```

Review discussion on `is_private_repo`:

Reviewer: Does this function check whether a repo is private to the general public? If so, why do we need to check whether it is private to a particular user first?

Author: This function checks whether the repo is private to the general public but still accessible to the current user.
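For illustration, the same two-step check can be written as a small standalone snippet against the public `huggingface_hub` API. Only the `model_info` / `RepositoryNotFoundError` calls come from the diff above; the helper name and repo ids below are hypothetical.

```python
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError


def repo_visibility(repo_id: str, token: str) -> str:
    """Classify a repo as 'public', 'private' (only visible with our token), or 'unknown'."""
    api = HfApi()
    try:
        # First call authenticates: it succeeds for public repos and for
        # private repos the token has access to.
        api.model_info(repo_id=repo_id, token=token)
    except RepositoryNotFoundError:
        # Not visible even with the token: either private to us as well, or missing.
        return "unknown"
    try:
        # Second call sends no token: it only succeeds for public repos.
        api.model_info(repo_id=repo_id, token=False)
    except RepositoryNotFoundError:
        return "private"
    return "public"


# Hypothetical usage:
# repo_visibility("my-org/private-model", token="hf_xxx")  -> "private"
# repo_visibility("gpt2", token="hf_xxx")                   -> "public"
```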
```diff
@@ -921,10 +929,20 @@ def push_to_cache_on_hub(
     overwrite_existing: bool = False,
     local_path_to_path_in_repo: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None,
     fail_when_could_not_push: bool = False,
-) -> CachedModelOnTheHub:
+) -> Optional[CachedModelOnTheHub]:
     if cache_repo_id is None:
         cache_repo_id = get_hf_hub_cache_repos()[0]
 
+    if not has_write_access_to_repo(cache_repo_id):
+        error_message = (
+            f"Could not push the cached model to {cache_repo_id} because you do not have write access to this repo."
+        )
+        if fail_when_could_not_push:
+            raise ValueError(error_message)
+        if is_main_worker():
+            logger.warning(error_message)
+        return
+
     try:
         create_registry_file_if_does_not_exist(cache_repo_id)
         _REGISTRY_FILE_EXISTS[cache_repo_id] = True
@@ -933,10 +951,15 @@ def push_to_cache_on_hub(
 
     is_cache_repo_private = is_private_repo(cache_repo_id)
     if neuron_hash.is_private and not is_cache_repo_private:
-        raise ValueError(
-            f"Cannot push the cached model to {cache_repo_id} because this repo is not private but the original model is "
-            "coming from private repo."
+        error_message = (
+            f"Could not push the cached model to {cache_repo_id} because this repo is not private but the original "
+            "model is coming from private repo."
         )
+        if fail_when_could_not_push:
+            raise ValueError(error_message)
+        if is_main_worker():
+            logger.warning(error_message)
+        return
 
     if local_path_to_path_in_repo == "default":
         local_path_to_path_in_repo = functools.partial(default_local_path_to_path_in_repo, neuron_hash=neuron_hash)
@@ -970,10 +993,7 @@
             f"{local_cache_dir_or_file}"
         )
 
-    could_not_push_message = (
-        "Could not push the cached model to the repo {cache_repo_id}, most likely due to not having the write permission "
-        "for this repo. Exact error:\n{error}."
-    )
+    could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Error message:\n{error}."
     success = True
     if local_cache_dir_or_file.is_dir():
        try:
```
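A quick note on the changed contract: with these hunks, `push_to_cache_on_hub` returns `None` instead of raising when the push has to be skipped, unless `fail_when_could_not_push=True`, so callers that relied on an exception should now check the return value. The warn-or-raise pattern used above can be sketched in isolation as follows; the helper below is purely illustrative, not part of the library.

```python
import logging

logger = logging.getLogger(__name__)


def warn_or_raise(message: str, fail: bool, is_main_worker: bool = True) -> None:
    """Raise if `fail` is set, otherwise only log a warning (once, on the main worker)."""
    if fail:
        raise ValueError(message)
    if is_main_worker:
        logger.warning(message)


# In push_to_cache_on_hub-style code, the early return then looks like:
# if not has_write_access_to_repo(cache_repo_id):
#     warn_or_raise(f"no write access to {cache_repo_id}", fail=fail_when_could_not_push)
#     return None
```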
```diff
@@ -50,7 +50,7 @@
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import NeuronGenerationMixin
 from . import is_torch_xla_available
-from .require_utils import requires_torch_xla
+from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
 
 
 if TYPE_CHECKING:
@@ -315,6 +315,28 @@ def skip_first_batches(dataloader, num_batches=0):
     return dataloader
 
 
+@requires_neuronx_distributed
+@requires_safetensors
+def torch_xla_safe_save_file(
+    tensors: Dict[str, torch.Tensor],
+    filename: Union[str, os.PathLike],
+    metadata: Optional[Dict[str, str]] = None,
+    master_only: bool = True,
+    global_master: bool = False,
+):
+    """
+    Torch XLA compatible implementation of `safetensors.torch.save_file`.
+    """
+    from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
+    from safetensors.torch import save_file
+    from torch_xla.core.xla_model import is_master_ordinal
+
+    should_write_data = not master_only or is_master_ordinal(local=not global_master)
+    cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
+    if should_write_data:
+        save_file(cpu_data, filename, metadata=metadata)
+
+
 def get_model_param_count(model, trainable_only=False):
     """Wrapper around `transformers.trainer_pt_utils.get_model_param_count` to handle tensor parallelism."""
     # TODO: make it work for TP
```

Review discussion on the `master_only` and `global_master` arguments:

Reviewer: Could you elaborate on and explain those arguments a bit? I am not familiar with them.

Author: Those two parameters are related to distributed training.
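For additional context on the two flags: they appear to mirror the same-named arguments of `torch_xla.core.xla_model.save`, where `master_only` restricts the actual file write to a single master process and `global_master` chooses between the per-host master and the single global master in multi-host runs. A rough usage sketch follows, assuming a torch_xla training process; the import path for `torch_xla_safe_save_file` is assumed from the diff above.

```python
import torch
import torch_xla.core.xla_model as xm

# Module path assumed from the diff above; adjust if the utilities live elsewhere.
from optimum.neuron.utils.training_utils import torch_xla_safe_save_file


def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    # All workers call this, but with master_only=True (the default) only the
    # master ordinal actually writes the safetensors file.
    torch_xla_safe_save_file(model.state_dict(), path, master_only=True, global_master=False)
    xm.rendezvous("checkpoint_saved")  # keep workers in sync after the save
```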
Review discussion on the patching approach:

Reviewer: To be more specific, why could you not just continue with the typical transformers paradigm: pass the correct flag to tag the main process and pass the specific save function? Looking at the patch, it is not clear why it is not redundant.

Author: So basically the chain of calls here is: `save_pretrained` ends up calling the `transformers.modeling_utils.safe_save_file` function, which is `safetensors.torch.save_file`, and the patch swaps it for my `torch_xla`-compatible version of this function when saving with `safetensors`. The issue is that when `safetensors` is used, the `save_function` parameter is ignored. That is the reason why we do it like that instead of simply passing `torch_xla_safe_save_file` as the value for `save_function`.

Reviewer: Thanks for the explanation!
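To make the mechanism concrete, here is a minimal sketch of the same idea using `unittest.mock.patch` in place of the project's `Patcher` helper. The patch target string comes from the diff above; the stand-in save function and the `save_model` wrapper are illustrative only.

```python
from unittest import mock

from safetensors.torch import save_file


def xla_aware_safe_save_file(tensors, filename, metadata=None):
    """Stand-in for torch_xla_safe_save_file: move tensors to CPU before writing."""
    cpu_tensors = {name: tensor.to("cpu") for name, tensor in tensors.items()}
    save_file(cpu_tensors, filename, metadata=metadata)


def save_model(model, output_dir: str) -> None:
    # While the patch is active, the call that save_pretrained makes to
    # transformers.modeling_utils.safe_save_file is routed to the XLA-aware
    # version above, even though save_pretrained ignores `save_function`
    # when safetensors serialization is used.
    with mock.patch("transformers.modeling_utils.safe_save_file", xla_aware_safe_save_file):
        model.save_pretrained(output_dir, safe_serialization=True)
```

Patching the function that performs the actual file write leaves the rest of the `save_pretrained` logic untouched, which is why `_save_xla` wraps the calls in the `Patcher` context manager rather than relying on `save_function` alone.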