
Commit: Save function fix (#329)

* Add a safetensors  function compatible with

* Add custom _maybe_move_to_cpu

* Fix

* Add docstring

* Fix test

* Fix test

* Apply suggestions

michaelbenayoun authored Dec 18, 2023
1 parent ecdeee8 commit c2367ae
Showing 5 changed files with 76 additions and 26 deletions.
optimum/neuron/trainers.py (16 additions & 7 deletions)

@@ -49,6 +49,7 @@
 from .utils import (
     DynamicPatch,
     ModelPatcher,
+    Patcher,
     is_torch_xla_available,
     patch_within_function,
 )
@@ -63,6 +64,7 @@
     prepare_environment_for_neuron,
     set_neuron_cc_optlevel_for_model,
     skip_first_batches,
+    torch_xla_safe_save_file,
 )


@@ -408,20 +410,27 @@ def _save_xla(self, output_dir: Optional[str] = None):
             xm.mark_step()
             parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
         else:
+            safe_save_function_patcher = Patcher(
+                [("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
+            )
             if not isinstance(self.model, PreTrainedModel):
                 if isinstance(unwrap_model(self.model), PreTrainedModel):
-                    unwrap_model(self.model).save_pretrained(
-                        output_dir,
-                        is_main_process=self.args.should_save,
-                        state_dict=self.model.state_dict(),
-                        save_function=xm.save,
-                    )
+                    with safe_save_function_patcher:
+                        unwrap_model(self.model).save_pretrained(
+                            output_dir,
+                            is_main_process=self.args.should_save,
+                            state_dict=self.model.state_dict(),
+                            save_function=xm.save,
+                        )
                 else:
                     logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                     state_dict = self.model.state_dict()
                     xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
             else:
-                self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+                with safe_save_function_patcher:
+                    self.model.save_pretrained(
+                        output_dir, is_main_process=self.args.should_save, save_function=xm.save
+                    )
 
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
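A note on the mechanism at work here: when safe serialization is enabled, `save_pretrained` writes checkpoints through `safe_save_file`, which `transformers.modeling_utils` imports from `safetensors.torch`. Entering `safe_save_function_patcher` therefore swaps in `torch_xla_safe_save_file` for the duration of the save and restores the original afterwards. Below is a minimal illustrative sketch of such a context-manager patcher, assuming the same interface as the `Patcher` used above (a list of ("dotted.path", replacement) pairs); the real `optimum.neuron.utils.Patcher` implementation may differ in its details.

import importlib
from typing import Any, Callable, List, Tuple


class SketchPatcher:
    """Illustrative only: swap module attributes on enter, restore them on exit."""

    def __init__(self, patches: List[Tuple[str, Callable]]):
        self.patches = patches
        self._originals: List[Tuple[Any, str, Any]] = []

    def __enter__(self) -> "SketchPatcher":
        for dotted_path, replacement in self.patches:
            module_name, attr_name = dotted_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            # Remember the original attribute so it can be restored on exit.
            self._originals.append((module, attr_name, getattr(module, attr_name)))
            setattr(module, attr_name, replacement)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore in reverse order, even if the with-block raised.
        for module, attr_name, original in reversed(self._originals):
            setattr(module, attr_name, original)
        self._originals.clear()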
optimum/neuron/utils/cache_utils.py (31 additions & 11 deletions)

@@ -133,14 +133,22 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = True


 def is_private_repo(repo_id: str) -> bool:
+    """Tells whether `repo_id` is private."""
     if _DISABLE_IS_PRIVATE_REPO_CHECK:
         return False
-    HfApi().list_repo_files(repo_id=repo_id, token=HfFolder.get_token())
-    private = False
     try:
-        HfApi().list_repo_files(repo_id=repo_id, token=False)
+        HfApi().model_info(repo_id=repo_id, token=HfFolder.get_token())
+        private_to_user = False
     except RepositoryNotFoundError:
+        private_to_user = True
+    if private_to_user:
         private = True
+    else:
+        try:
+            HfApi().model_info(repo_id=repo_id, token=False)
+            private = False
+        except RepositoryNotFoundError:
+            private = True
     return private


@@ -921,10 +929,20 @@ def push_to_cache_on_hub(
     overwrite_existing: bool = False,
     local_path_to_path_in_repo: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None,
     fail_when_could_not_push: bool = False,
-) -> CachedModelOnTheHub:
+) -> Optional[CachedModelOnTheHub]:
     if cache_repo_id is None:
         cache_repo_id = get_hf_hub_cache_repos()[0]
 
+    if not has_write_access_to_repo(cache_repo_id):
+        error_message = (
+            f"Could not push the cached model to {cache_repo_id} because you do not have write access to this repo."
+        )
+        if fail_when_could_not_push:
+            raise ValueError(error_message)
+        if is_main_worker():
+            logger.warning(error_message)
+        return
+
     try:
         create_registry_file_if_does_not_exist(cache_repo_id)
         _REGISTRY_FILE_EXISTS[cache_repo_id] = True
@@ -933,10 +951,15 @@
 
     is_cache_repo_private = is_private_repo(cache_repo_id)
     if neuron_hash.is_private and not is_cache_repo_private:
-        raise ValueError(
-            f"Cannot push the cached model to {cache_repo_id} because this repo is not private but the original model is "
-            "coming from private repo."
+        error_message = (
+            f"Could not push the cached model to {cache_repo_id} because this repo is not private but the original "
+            "model is coming from private repo."
         )
+        if fail_when_could_not_push:
+            raise ValueError(error_message)
+        if is_main_worker():
+            logger.warning(error_message)
+        return
 
     if local_path_to_path_in_repo == "default":
         local_path_to_path_in_repo = functools.partial(default_local_path_to_path_in_repo, neuron_hash=neuron_hash)
@@ -970,10 +993,7 @@ def push_to_cache_on_hub(
                 f"{local_cache_dir_or_file}"
             )
 
-    could_not_push_message = (
-        "Could not push the cached model to the repo {cache_repo_id}, most likely due to not having the write permission "
-        "for this repo. Exact error:\n{error}."
-    )
+    could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Error message:\n{error}."
     success = True
     if local_cache_dir_or_file.is_dir():
         try:
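The rewritten `is_private_repo` is a two-step probe: the repo is reported private when it is not visible to the user at all, and otherwise when an anonymous request cannot see it. A standalone sketch of that logic using only public `huggingface_hub` APIs (the helper name `check_repo_is_private` is made up for illustration):

from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError


def check_repo_is_private(repo_id: str, user_token: str) -> bool:
    """Returns True if `repo_id` is private or not visible at all."""
    api = HfApi()
    try:
        # Step 1: is the repo visible with the user's credentials?
        api.model_info(repo_id=repo_id, token=user_token)
    except RepositoryNotFoundError:
        return True  # Not visible even to the user: treated as private.
    try:
        # Step 2: is it also visible anonymously?
        api.model_info(repo_id=repo_id, token=False)
        return False  # Anonymously visible: public.
    except RepositoryNotFoundError:
        return True  # Visible only with credentials: private.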
optimum/neuron/utils/training_utils.py (23 additions & 1 deletion)

@@ -50,7 +50,7 @@
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import NeuronGenerationMixin
 from . import is_torch_xla_available
-from .require_utils import requires_torch_xla
+from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
 
 
 if TYPE_CHECKING:
@@ -319,6 +319,28 @@ def skip_first_batches(dataloader, num_batches=0):
     return dataloader
 
 
+@requires_neuronx_distributed
+@requires_safetensors
+def torch_xla_safe_save_file(
+    tensors: Dict[str, torch.Tensor],
+    filename: Union[str, os.PathLike],
+    metadata: Optional[Dict[str, str]] = None,
+    master_only: bool = True,
+    global_master: bool = False,
+):
+    """
+    Torch XLA compatible implementation of `safetensors.torch.save_file`.
+    """
+    from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
+    from safetensors.torch import save_file
+    from torch_xla.core.xla_model import is_master_ordinal
+
+    should_write_data = not master_only or is_master_ordinal(local=not global_master)
+    cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
+    if should_write_data:
+        save_file(cpu_data, filename, metadata=metadata)
+
+
 def get_model_param_count(model, trainable_only=False):
     """Wrapper around `transformers.trainer_pt_utils.get_model_param_count` to handle tensor parallelism."""
     # TODO: make it work for TP
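A hypothetical usage sketch for the new helper, assuming a machine with `torch_xla`, `neuronx_distributed`, and `safetensors` installed (the import path follows the file the function is defined in):

import torch

from optimum.neuron.utils.training_utils import torch_xla_safe_save_file

# Toy CPU tensors for illustration; in training these would be XLA tensors on a
# Neuron device. With the default master_only=True, only the master ordinal
# writes the file, after the tensors have been moved to CPU.
state_dict = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
torch_xla_safe_save_file(state_dict, "model.safetensors", metadata={"format": "pt"})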
tests/test_cache_utils.py (5 additions & 7 deletions)

@@ -278,7 +278,6 @@ def remove_repo():
         if orig_token:
             HfFolder.save_token(orig_token)
 
-    @is_staging_test
     def test_has_write_access_to_repo(self):
         orig_token = HfFolder.get_token()
         wrong_token = "random_string"
@@ -292,7 +291,6 @@ def test_has_write_access_to_repo(self):
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))
 
-    @is_staging_test
     def test_list_in_registry(self):
         def _test_list_in_registry(use_private_cache_repo: bool):
             if use_private_cache_repo:
@@ -466,7 +464,6 @@ def test_neuron_hash_is_private(self):
 
         bert_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         neuron_hash = NeuronHash(bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION)
-
         self.assertFalse(neuron_hash.is_private)
 
         with TemporaryDirectory() as tmpdirname:
@@ -494,8 +491,10 @@ def test_push_to_hub_fails_with_private_model_and_public_repo(self):
 
         # The model being loaded locally is assumed to be private, push to hub should prevent from pushing to a
         # public repo.
-        with self.assertRaisesRegex(ValueError, "Cannot push the cached model"):
-            push_to_cache_on_hub(neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO)
+        with self.assertRaisesRegex(ValueError, "Could not push the cached model"):
+            push_to_cache_on_hub(
+                neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO, fail_when_could_not_push=True
+            )
 
         # It should work when using a private repo.
         cached_model_on_the_hub = push_to_cache_on_hub(
@@ -547,7 +546,6 @@ def test_push_to_hub_overwrite_existing(self):
         # With a directory
         with self.assertLogs("optimum", level="INFO") as cm:
             push_to_cache_on_hub(neuron_hash, cache_dir, self.CUSTOM_PRIVATE_CACHE_REPO)
-            print(cm.output)
         self.assertIn("Did not push the cached model located at", cm.output[0])
 
         with self.assertLogs("optimum", level="WARNING") as cm:
@@ -636,7 +634,7 @@ def test_push_to_hub_without_writing_rights(self):
         set_custom_cache_repo_name_in_hf_home(f"{TRANSFORMERS_USER}/{repo_name}")
         with self.assertLogs("optimum", "WARNING") as cm:
             push_to_cache_on_hub(neuron_hash, get_neuron_cache_path())
-        self.assertTrue(any("Could not push the cached model to the repo" in output for output in cm.output))
+        self.assertTrue(any("Could not push the cached model to" in output for output in cm.output))
 
         self.set_hf_hub_token(TRANSFORMERS_TOKEN)
         delete_repo(repo_name, repo_type="model")
tests/utils.py (1 addition & 0 deletions)

@@ -256,6 +256,7 @@ def push_tiny_pretrained_model_cache_to_hub(
         push_to_cache_on_hub(
             neuron_hash,
             tmp_cache_dir,
+            fail_when_could_not_push=True,
         )
         if cache_dir is not None:
             for file_or_dir in tmp_cache_dir.iterdir():
