Save function fix #329

Merged 8 commits on Dec 18, 2023
Changes from 7 commits
23 changes: 16 additions & 7 deletions optimum/neuron/trainers.py
@@ -48,6 +48,7 @@
from .utils import (
DynamicPatch,
ModelPatcher,
Patcher,
is_torch_xla_available,
patch_within_function,
)
@@ -62,6 +63,7 @@
prepare_environment_for_neuron,
set_neuron_cc_optlevel_for_model,
skip_first_batches,
torch_xla_safe_save_file,
)


@@ -402,20 +404,27 @@ def _save_xla(self, output_dir: Optional[str] = None):
xm.mark_step()
parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
else:
safe_save_function_patcher = Patcher(
[("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
)
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
is_main_process=self.args.should_save,
state_dict=self.model.state_dict(),
save_function=xm.save,
)
with safe_save_function_patcher:
unwrap_model(self.model).save_pretrained(
Collaborator:
  1. Could you explain what you want to avoid/achieve here?
  2. Could you explain what the exact chain of calls will be here?

Collaborator:
To be more specific, why couldn't you just continue with the typical Transformers paradigm: pass the correct flag to tag the main process and provide the specific save function? Looking at the patch, it is not clear why it is not redundant.

Member Author (@michaelbenayoun), Dec 18, 2023:
So basically the chain of calls here is:

  1. We enter a context manager that patches the transformers.modeling_utils.safe_save_file function (which is safetensors.torch.save_file) to be my torch_xla-compatible version of this function.
  2. The model is unwrapped (it's not really important here; it's related to the Trainer and how the model is sometimes wrapped for features I am not sure we support).
  3. The model is saved, and since the last Transformers release the checkpoint is saved using safetensors. The issue is that when safetensors is used, the save_function parameter is ignored. That is why we do it like this instead of simply passing torch_xla_safe_save_file as the value for save_function.
  4. We exit the context manager and everything that was patched is restored to its original value.
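
For illustration, here is a minimal sketch of what such a patching context manager does, assuming safetensors is installed so that transformers.modeling_utils exposes safe_save_file, and using a hypothetical helper name rather than the actual Patcher from optimum.neuron:

```python
# Hypothetical, simplified stand-in for the Patcher used in this PR: swap the
# module attribute on entry and restore it on exit.
import contextlib

import transformers.modeling_utils
from optimum.neuron.utils.training_utils import torch_xla_safe_save_file


@contextlib.contextmanager
def patch_safe_save_file():
    original = transformers.modeling_utils.safe_save_file
    # While the context is active, save_pretrained() resolves safe_save_file
    # to the torch_xla-compatible version (step 1 above).
    transformers.modeling_utils.safe_save_file = torch_xla_safe_save_file
    try:
        yield
    finally:
        # Everything patched is restored to its original value (step 4 above).
        transformers.modeling_utils.safe_save_file = original
```

Any save_pretrained() call made inside the with block then writes its safetensors files through the XLA-aware function, which is what step 3 relies on.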

Collaborator:
Thanks for the explanation!

output_dir,
is_main_process=self.args.should_save,
state_dict=self.model.state_dict(),
save_function=xm.save,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
with safe_save_function_patcher:
Collaborator:
What is self.model in this case? If unwrap_model(self.model) is not an instance of PreTrainedModel, can save_pretrained() still be called on it?

Member Author:
self.model is the model we are currently training and is an instance of PreTrainedModel in this case.

No, if unwrap_model(self.model) is not an instance of PreTrainedModel we cannot call save_pretrained().

self.model.save_pretrained(
output_dir, is_main_process=self.args.should_save, save_function=xm.save
)

if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)
42 changes: 31 additions & 11 deletions optimum/neuron/utils/cache_utils.py
@@ -133,14 +133,22 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru


def is_private_repo(repo_id: str) -> bool:
"""Tells whether `repo_id` is private."""
Collaborator:
Does this function check whether a repo is private to the general public? If so, why do we need to check whether it is private to a particular user in advance?

Member Author:
This function checks whether the repo is private to the general public but accessible to the current user.
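
As a quick illustration of the intended semantics of the two-step check below (the repo ids are hypothetical, and the import path is assumed from this file's location):

```python
from optimum.neuron.utils.cache_utils import is_private_repo

# Hypothetical repo ids, shown only to illustrate the outcomes.
is_private_repo("my-org/private-cache-repo")  # True: visible with the user's token,
                                              # but anonymous access raises RepositoryNotFoundError
is_private_repo("some-org/public-model")      # False: visible both with and without a token
is_private_repo("other-org/unknown-repo")     # True: not visible even with the user's token
```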

if _DISABLE_IS_PRIVATE_REPO_CHECK:
return False
HfApi().list_repo_files(repo_id=repo_id, token=HfFolder.get_token())
private = False
try:
HfApi().list_repo_files(repo_id=repo_id, token=False)
HfApi().model_info(repo_id=repo_id, token=HfFolder.get_token())
private_to_user = False
except RepositoryNotFoundError:
private_to_user = True
if private_to_user:
private = True
else:
try:
HfApi().model_info(repo_id=repo_id, token=False)
private = False
except RepositoryNotFoundError:
private = True
return private


@@ -921,10 +929,20 @@ def push_to_cache_on_hub(
overwrite_existing: bool = False,
local_path_to_path_in_repo: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None,
fail_when_could_not_push: bool = False,
) -> CachedModelOnTheHub:
) -> Optional[CachedModelOnTheHub]:
if cache_repo_id is None:
cache_repo_id = get_hf_hub_cache_repos()[0]

if not has_write_access_to_repo(cache_repo_id):
error_message = (
f"Could not push the cached model to {cache_repo_id} because you do not have write access to this repo."
)
if fail_when_could_not_push:
raise ValueError(error_message)
if is_main_worker():
logger.warning(error_message)
return

try:
create_registry_file_if_does_not_exist(cache_repo_id)
_REGISTRY_FILE_EXISTS[cache_repo_id] = True
Expand All @@ -933,10 +951,15 @@ def push_to_cache_on_hub(

is_cache_repo_private = is_private_repo(cache_repo_id)
if neuron_hash.is_private and not is_cache_repo_private:
raise ValueError(
f"Cannot push the cached model to {cache_repo_id} because this repo is not private but the original model is "
"coming from private repo."
error_message = (
f"Could not push the cached model to {cache_repo_id} because this repo is not private but the original "
"model is coming from private repo."
)
if fail_when_could_not_push:
raise ValueError(error_message)
if is_main_worker():
logger.warning(error_message)
return

if local_path_to_path_in_repo == "default":
local_path_to_path_in_repo = functools.partial(default_local_path_to_path_in_repo, neuron_hash=neuron_hash)
Expand Down Expand Up @@ -970,10 +993,7 @@ def push_to_cache_on_hub(
f"{local_cache_dir_or_file}"
)

could_not_push_message = (
"Could not push the cached model to the repo {cache_repo_id}, most likely due to not having the write permission "
"for this repo. Exact error:\n{error}."
)
could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Exact error:\n{error}."
success = True
if local_cache_dir_or_file.is_dir():
try:
24 changes: 23 additions & 1 deletion optimum/neuron/utils/training_utils.py
@@ -50,7 +50,7 @@
from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import NeuronGenerationMixin
from . import is_torch_xla_available
from .require_utils import requires_torch_xla
from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


if TYPE_CHECKING:
@@ -315,6 +315,28 @@ def skip_first_batches(dataloader, num_batches=0):
return dataloader


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
tensors: Dict[str, torch.Tensor],
filename: Union[str, os.PathLike],
metadata: Optional[Dict[str, str]] = None,
master_only: bool = True,
global_master: bool = False,
):
"""
Torch XLA compatible implementation of `safetensors.torch.save_file`.
Collaborator:
Could you elaborate and explain those arguments a bit? I am not familiar with what master_only and global_master indicate.

Member Author:
So those two parameters are related to distributed training.
Basically, when master_only is True, only the master rank saves the file instead of all ranks, and global_master controls whether only the global master (across multiple nodes) should save or the master of each node.
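
As a rough usage sketch (the call pattern below is an assumption for illustration, not something shown in this PR), every rank calls the function and the two flags decide which rank actually writes:

```python
import torch

from optimum.neuron.utils.training_utils import torch_xla_safe_save_file

# Illustrative tensors; in the trainer this would be the model's state dict.
tensors = {"weight": torch.zeros(2, 2)}

# Default (master_only=True, global_master=False): only the master ordinal of
# each node writes the file; the other ranks return without saving.
torch_xla_safe_save_file(tensors, "model.safetensors")

# Only the global master (rank 0 across all nodes) writes, e.g. when all nodes
# share the same output directory.
torch_xla_safe_save_file(tensors, "model.safetensors", global_master=True)
```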

Member Author:
More information here

"""
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
from safetensors.torch import save_file
from torch_xla.core.xla_model import is_master_ordinal

should_write_data = not master_only or is_master_ordinal(local=not global_master)
cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
if should_write_data:
save_file(cpu_data, filename, metadata=metadata)


def get_model_param_count(model, trainable_only=False):
"""Wrapper around `transformers.trainer_pt_utils.get_model_param_count` to handle tensor parallelism."""
# TODO: make it work for TP
12 changes: 5 additions & 7 deletions tests/test_cache_utils.py
@@ -278,7 +278,6 @@ def remove_repo():
if orig_token:
HfFolder.save_token(orig_token)

@is_staging_test
def test_has_write_access_to_repo(self):
orig_token = HfFolder.get_token()
wrong_token = "random_string"
@@ -292,7 +291,6 @@ def test_has_write_access_to_repo(self):
self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))

@is_staging_test
def test_list_in_registry(self):
def _test_list_in_registry(use_private_cache_repo: bool):
if use_private_cache_repo:
@@ -466,7 +464,6 @@ def test_neuron_hash_is_private(self):

bert_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
neuron_hash = NeuronHash(bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION)

self.assertFalse(neuron_hash.is_private)

with TemporaryDirectory() as tmpdirname:
@@ -494,8 +491,10 @@ def test_push_to_hub_fails_with_private_model_and_public_repo(self):

# The model being loaded locally is assumed to be private, push to hub should prevent from pushing to a
# public repo.
with self.assertRaisesRegex(ValueError, "Cannot push the cached model"):
push_to_cache_on_hub(neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO)
with self.assertRaisesRegex(ValueError, "Could not push the cached model"):
push_to_cache_on_hub(
neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO, fail_when_could_not_push=True
)

# It should work when using a private repo.
cached_model_on_the_hub = push_to_cache_on_hub(
@@ -547,7 +546,6 @@ def test_push_to_hub_overwrite_existing(self):
# With a directory
with self.assertLogs("optimum", level="INFO") as cm:
push_to_cache_on_hub(neuron_hash, cache_dir, self.CUSTOM_PRIVATE_CACHE_REPO)
print(cm.output)
self.assertIn("Did not push the cached model located at", cm.output[0])

with self.assertLogs("optimum", level="WARNING") as cm:
@@ -636,7 +634,7 @@ def test_push_to_hub_without_writing_rights(self):
set_custom_cache_repo_name_in_hf_home(f"{TRANSFORMERS_USER}/{repo_name}")
with self.assertLogs("optimum", "WARNING") as cm:
push_to_cache_on_hub(neuron_hash, get_neuron_cache_path())
self.assertTrue(any("Could not push the cached model to the repo" in output for output in cm.output))
self.assertTrue(any("Could not push the cached model to" in output for output in cm.output))

self.set_hf_hub_token(TRANSFORMERS_TOKEN)
delete_repo(repo_name, repo_type="model")
1 change: 1 addition & 0 deletions tests/utils.py
@@ -256,6 +256,7 @@ def push_tiny_pretrained_model_cache_to_hub(
push_to_cache_on_hub(
neuron_hash,
tmp_cache_dir,
fail_when_could_not_push=True,
)
if cache_dir is not None:
for file_or_dir in tmp_cache_dir.iterdir():