From f2e7c9919e27f94d70a2da469fa3a0af62496bb9 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 16 Nov 2023 15:03:34 +0100
Subject: [PATCH 1/7] Add a safetensors function compatible with

---
 optimum/neuron/trainers.py             | 23 ++++++++++++++++-------
 optimum/neuron/utils/cache_utils.py    | 22 ++++++++++++++++++----
 optimum/neuron/utils/training_utils.py | 22 ++++++++++++++++++++--
 3 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 5258542f8..399003e9f 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -48,6 +48,7 @@
 from .utils import (
     DynamicPatch,
     ModelPatcher,
+    Patcher,
     is_torch_xla_available,
     patch_within_function,
 )
@@ -61,6 +62,7 @@
     patched_finfo,
     prepare_environment_for_neuron,
     skip_first_batches,
+    torch_xla_safe_save_file,
 )
@@ -392,20 +394,27 @@ def _save_xla(self, output_dir: Optional[str] = None):
             xm.mark_step()
             parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
         else:
+            safe_save_function_patcher = Patcher(
+                [("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
+            )
             if not isinstance(self.model, PreTrainedModel):
                 if isinstance(unwrap_model(self.model), PreTrainedModel):
-                    unwrap_model(self.model).save_pretrained(
-                        output_dir,
-                        is_main_process=self.args.should_save,
-                        state_dict=self.model.state_dict(),
-                        save_function=xm.save,
-                    )
+                    with safe_save_function_patcher:
+                        unwrap_model(self.model).save_pretrained(
+                            output_dir,
+                            is_main_process=self.args.should_save,
+                            state_dict=self.model.state_dict(),
+                            save_function=xm.save,
+                        )
                 else:
                     logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                     state_dict = self.model.state_dict()
                     xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
             else:
-                self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+                with safe_save_function_patcher:
+                    self.model.save_pretrained(
+                        output_dir, is_main_process=self.args.should_save, save_function=xm.save
+                    )
 
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index e42c897a0..34e35f3e3 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -138,11 +138,18 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru
 
 def is_private_repo(repo_id: str) -> bool:
     if _DISABLE_IS_PRIVATE_REPO_CHECK:
         return False
-    HfApi().list_repo_files(repo_id=repo_id, token=HfFolder.get_token())
-    private = False
     try:
-        HfApi().list_repo_files(repo_id=repo_id, token=False)
+        HfApi().model_info(repo_id=repo_id, token=HfFolder.get_token())
+        private_to_user = False
     except RepositoryNotFoundError:
+        private_to_user = True
+    if not private_to_user:
+        try:
+            HfApi().list_repo_files(repo_id=repo_id, token=False)
+            private = False
+        except RepositoryNotFoundError:
+            private = True
+    else:
         private = True
     return private
@@ -829,10 +836,17 @@ def push_to_cache_on_hub(
     cache_repo_id: Optional[str] = None,
     overwrite_existing: bool = False,
     local_path_to_path_in_repo: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None,
-) -> CachedModelOnTheHub:
+) -> Optional[CachedModelOnTheHub]:
     if cache_repo_id is None:
         cache_repo_id = get_hf_hub_cache_repos()[0]
 
+    if not has_write_access_to_repo(cache_repo_id):
+        logger.warning(
+            f"The compilation files cannot be pushed to the cache repo {cache_repo_id} because you do not have write "
+            "access."
+        )
+        return
+
     try:
         create_registry_file_if_does_not_exist(cache_repo_id)
         _REGISTRY_FILE_EXISTS[cache_repo_id] = True
diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 55031438d..1f122a46d 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -15,7 +15,7 @@
 """Training utilities"""
 
 import os
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 import transformers
@@ -49,7 +49,7 @@
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import NeuronGenerationMixin
 from . import is_torch_xla_available
-from .require_utils import requires_torch_xla
+from .require_utils import requires_safetensors, requires_torch_xla
 
 
 if TYPE_CHECKING:
@@ -293,6 +293,24 @@ def skip_first_batches(dataloader, num_batches=0):
     return dataloader
 
 
+@requires_torch_xla
+@requires_safetensors
+def torch_xla_safe_save_file(
+    tensors: Dict[str, torch.Tensor],
+    filename: Union[str, os.PathLike],
+    metadata: Optional[Dict[str, str]] = None,
+    master_only: bool = True,
+    global_master: bool = False,
+):
+    from safetensors.torch import save_file
+    from torch_xla.core.xla_model import _maybe_convert_to_cpu, is_master_ordinal
+
+    should_write_data = not master_only or is_master_ordinal(local=not global_master)
+    cpu_data = _maybe_convert_to_cpu(tensors, convert=should_write_data)
+    if should_write_data:
+        save_file(cpu_data, filename, metadata=metadata)
+
+
 def get_model_param_count(model, trainable_only=False):
     """Wrapper around `transformers.trainer_pt_utils.get_model_param_count` to handle tensor parallelism."""
     # TODO: make it work for TP

From f03e9b90d835572b206116f01cf47d8f853fecca Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 16 Nov 2023 15:25:40 +0100
Subject: [PATCH 2/7] Add custom _maybe_move_to_cpu

---
 optimum/neuron/utils/training_utils.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 1f122a46d..721014b06 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -305,6 +305,23 @@ def torch_xla_safe_save_file(
     from safetensors.torch import save_file
     from torch_xla.core.xla_model import _maybe_convert_to_cpu, is_master_ordinal
 
+    def _maybe_convert_to_cpu(data, convert=True):
+        import torch_xla
+        from torch_xla.core.xla_model import ToXlaTensorArena, is_xla_tensor
+
+        def convert_fn(tensors):
+            torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)
+            if not convert:
+                return tensors
+            # return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
+            # Doing the same as neuronx_distributed.
+            return [tensor.to("cpu") for tensor in tensors]
+
+        def select_fn(v):
+            return type(v) == torch.Tensor and is_xla_tensor(v)
+
+        return ToXlaTensorArena(convert_fn, select_fn).transform(data)
+
     should_write_data = not master_only or is_master_ordinal(local=not global_master)
     cpu_data = _maybe_convert_to_cpu(tensors, convert=should_write_data)
     if should_write_data:
         save_file(cpu_data, filename, metadata=metadata)

From aea8db155f408adb65ce3bff0a576f165bab48a3 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 27 Nov 2023 17:09:50 +0100
Subject: [PATCH 3/7] Fix

---
 optimum/neuron/utils/training_utils.py | 26 +++++---------------------
 1 file changed, 5 insertions(+), 21 deletions(-)

diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 48017c950..00b0dad20 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -50,7 +50,7 @@
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import NeuronGenerationMixin
 from . import is_torch_xla_available
-from .require_utils import requires_safetensors, requires_torch_xla
+from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
 
 
 if TYPE_CHECKING:
@@ -315,7 +315,7 @@ def skip_first_batches(dataloader, num_batches=0):
     return dataloader
 
 
-@requires_torch_xla
+@requires_neuronx_distributed
 @requires_safetensors
 def torch_xla_safe_save_file(
     tensors: Dict[str, torch.Tensor],
@@ -324,28 +324,12 @@ def torch_xla_safe_save_file(
     master_only: bool = True,
     global_master: bool = False,
 ):
+    from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
     from safetensors.torch import save_file
-    from torch_xla.core.xla_model import _maybe_convert_to_cpu, is_master_ordinal
-
-    def _maybe_convert_to_cpu(data, convert=True):
-        import torch_xla
-        from torch_xla.core.xla_model import ToXlaTensorArena, is_xla_tensor
-
-        def convert_fn(tensors):
-            torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)
-            if not convert:
-                return tensors
-            # return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
-            # Doing the same as neuronx_distributed.
-            return [tensor.to("cpu") for tensor in tensors]
-
-        def select_fn(v):
-            return type(v) == torch.Tensor and is_xla_tensor(v)
-
-        return ToXlaTensorArena(convert_fn, select_fn).transform(data)
+    from torch_xla.core.xla_model import is_master_ordinal
 
     should_write_data = not master_only or is_master_ordinal(local=not global_master)
-    cpu_data = _maybe_convert_to_cpu(tensors, convert=should_write_data)
+    cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
     if should_write_data:
         save_file(cpu_data, filename, metadata=metadata)

From 36fcfc41b7fd6043112807376814141c018b15ab Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 27 Nov 2023 17:17:46 +0100
Subject: [PATCH 4/7] Add docstring

---
 optimum/neuron/utils/training_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 00b0dad20..14e01e4c6 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -324,6 +324,9 @@ def torch_xla_safe_save_file(
     master_only: bool = True,
     global_master: bool = False,
 ):
+    """
+    Torch XLA compatible implementation of `safetensors.torch.save_file`.
+ """ from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from safetensors.torch import save_file from torch_xla.core.xla_model import is_master_ordinal From 607ae214e18411bdd1845af173c00e562db22d91 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 28 Nov 2023 11:15:10 +0100 Subject: [PATCH 5/7] Fix test --- optimum/neuron/utils/cache_utils.py | 36 ++++++++++++++--------------- tests/test_cache_utils.py | 10 ++++---- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py index 26ca29a1c..cf1b50c0e 100644 --- a/optimum/neuron/utils/cache_utils.py +++ b/optimum/neuron/utils/cache_utils.py @@ -133,6 +133,7 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru def is_private_repo(repo_id: str) -> bool: + """Tells whether `repo_id` is private to the current user logged-in.""" if _DISABLE_IS_PRIVATE_REPO_CHECK: return False try: @@ -140,15 +141,7 @@ def is_private_repo(repo_id: str) -> bool: private_to_user = False except RepositoryNotFoundError: private_to_user = True - if not private_to_user: - try: - HfApi().list_repo_files(repo_id=repo_id, token=False) - private = False - except RepositoryNotFoundError: - private = True - else: - private = True - return private + return private_to_user def has_write_access_to_repo(repo_id: str) -> bool: @@ -933,10 +926,13 @@ def push_to_cache_on_hub( cache_repo_id = get_hf_hub_cache_repos()[0] if not has_write_access_to_repo(cache_repo_id): - logger.warning( - f"The compilation files cannot be pushed to the cache repo {cache_repo_id} because you do not have write " - "access." + error_message = ( + f"Could not push the cached model to {cache_repo_id} because you do not have write access to this repo." ) + if fail_when_could_not_push: + raise ValueError(error_message) + if is_main_worker(): + logger.warning(error_message) return try: @@ -947,10 +943,15 @@ def push_to_cache_on_hub( is_cache_repo_private = is_private_repo(cache_repo_id) if neuron_hash.is_private and not is_cache_repo_private: - raise ValueError( - f"Cannot push the cached model to {cache_repo_id} because this repo is not private but the original model is " - "coming from private repo." + error_message = ( + f"Could not push the cached model to {cache_repo_id} because this repo is not private but the original " + "model is coming from private repo." ) + if fail_when_could_not_push: + raise ValueError(error_message) + if is_main_worker(): + logger.warning(error_message) + return if local_path_to_path_in_repo == "default": local_path_to_path_in_repo = functools.partial(default_local_path_to_path_in_repo, neuron_hash=neuron_hash) @@ -984,10 +985,7 @@ def push_to_cache_on_hub( f"{local_cache_dir_or_file}" ) - could_not_push_message = ( - "Could not push the cached model to the repo {cache_repo_id}, most likely due to not having the write permission " - "for this repo. Exact error:\n{error}." - ) + could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Exact error:\n{error}." 
     success = True
     if local_cache_dir_or_file.is_dir():
         try:
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index be6ca4ba7..e260f3455 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -278,7 +278,6 @@ def remove_repo():
         if orig_token:
             HfFolder.save_token(orig_token)
 
-    @is_staging_test
     def test_has_write_access_to_repo(self):
         orig_token = HfFolder.get_token()
         wrong_token = "random_string"
@@ -292,7 +291,6 @@
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))
 
-    @is_staging_test
     def test_list_in_registry(self):
         def _test_list_in_registry(use_private_cache_repo: bool):
             if use_private_cache_repo:
@@ -466,7 +464,6 @@ def test_neuron_hash_is_private(self):
 
         bert_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         neuron_hash = NeuronHash(bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION)
-
         self.assertFalse(neuron_hash.is_private)
 
         with TemporaryDirectory() as tmpdirname:
@@ -495,7 +492,9 @@ def test_push_to_hub_fails_with_private_model_and_public_repo(self):
         # The model being loaded locally is assumed to be private, push to hub should prevent from pushing to a
         # public repo.
         with self.assertRaisesRegex(ValueError, "Cannot push the cached model"):
-            push_to_cache_on_hub(neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO)
+            push_to_cache_on_hub(
+                neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO, fail_when_could_not_push=True
+            )
 
         # It should work when using a private repo.
         cached_model_on_the_hub = push_to_cache_on_hub(
@@ -547,7 +546,6 @@ def test_push_to_hub_overwrite_existing(self):
         # With a directory
         with self.assertLogs("optimum", level="INFO") as cm:
             push_to_cache_on_hub(neuron_hash, cache_dir, self.CUSTOM_PRIVATE_CACHE_REPO)
-            print(cm.output)
             self.assertIn("Did not push the cached model located at", cm.output[0])
 
         with self.assertLogs("optimum", level="WARNING") as cm:
@@ -636,7 +634,7 @@ def test_push_to_hub_without_writing_rights(self):
         set_custom_cache_repo_name_in_hf_home(f"{TRANSFORMERS_USER}/{repo_name}")
         with self.assertLogs("optimum", "WARNING") as cm:
             push_to_cache_on_hub(neuron_hash, get_neuron_cache_path())
-        self.assertTrue(any("Could not push the cached model to the repo" in output for output in cm.output))
+        self.assertTrue(any("Could not push the cached model to" in output for output in cm.output))
 
         self.set_hf_hub_token(TRANSFORMERS_TOKEN)
         delete_repo(repo_name, repo_type="model")

From 158ff4e57ee43b2ce45f0a91ff6ec3f1d14512ec Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 28 Nov 2023 12:26:20 +0100
Subject: [PATCH 6/7] Fix test

---
 optimum/neuron/utils/cache_utils.py | 12 ++++++++++--
 tests/test_cache_utils.py           |  2 +-
 tests/utils.py                      |  1 +
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index cf1b50c0e..be5991c96 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -133,7 +133,7 @@ def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = Tru
 
 
 def is_private_repo(repo_id: str) -> bool:
-    """Tells whether `repo_id` is private to the current user logged-in."""
+    """Tells whether `repo_id` is private."""
     if _DISABLE_IS_PRIVATE_REPO_CHECK:
         return False
     try:
@@ -141,7 +141,15 @@ def is_private_repo(repo_id: str) -> bool:
         HfApi().model_info(repo_id=repo_id, token=HfFolder.get_token())
         private_to_user = False
     except RepositoryNotFoundError:
         private_to_user = True
-    return private_to_user
+    if private_to_user:
+        private = True
+    else:
+        try:
+            HfApi().model_info(repo_id=repo_id, token=False)
+            private = False
+        except RepositoryNotFoundError:
+            private = True
+    return private
 
 
 def has_write_access_to_repo(repo_id: str) -> bool:
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index e260f3455..6d00cba9a 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -491,7 +491,7 @@ def test_push_to_hub_fails_with_private_model_and_public_repo(self):
 
         # The model being loaded locally is assumed to be private, push to hub should prevent from pushing to a
         # public repo.
-        with self.assertRaisesRegex(ValueError, "Cannot push the cached model"):
+        with self.assertRaisesRegex(ValueError, "Could not push the cached model"):
             push_to_cache_on_hub(
                 neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO, fail_when_could_not_push=True
             )
diff --git a/tests/utils.py b/tests/utils.py
index c7b5be914..be069ddf1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -256,6 +256,7 @@ def push_tiny_pretrained_model_cache_to_hub(
         push_to_cache_on_hub(
             neuron_hash,
             tmp_cache_dir,
+            fail_when_could_not_push=True,
         )
         if cache_dir is not None:
             for file_or_dir in tmp_cache_dir.iterdir():

From ef9b15c8820b4804acf1aad7711565b1e205815c Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 18 Dec 2023 14:34:34 +0100
Subject: [PATCH 7/7] Apply suggestions

---
 optimum/neuron/utils/cache_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index be5991c96..698dde5e0 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -993,7 +993,7 @@ def push_to_cache_on_hub(
             f"{local_cache_dir_or_file}"
         )
 
-    could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Exact error:\n{error}."
+    could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Error message:\n{error}."
     success = True
     if local_cache_dir_or_file.is_dir():
         try:
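
Usage sketch (not part of the patches above): a minimal example of how the helper introduced in this series can be used, assuming an AWS Trainium / torch_xla environment with optimum-neuron, safetensors and neuronx_distributed installed. The tiny checkpoint name comes from the test suite above; the output file and directory names are illustrative only.

import torch_xla.core.xla_model as xm
from transformers import BertModel

from optimum.neuron.utils import Patcher
from optimum.neuron.utils.training_utils import torch_xla_safe_save_file

# A tiny model moved to the XLA device.
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert").to(xm.xla_device())

# Direct call: XLA tensors are synced and moved to CPU, and only the master
# ordinal actually writes the .safetensors file.
torch_xla_safe_save_file(model.state_dict(), "model.safetensors")

# What the patched trainer does: temporarily swap transformers'
# `safe_save_file` so that `save_pretrained` can serialize XLA tensors.
safe_save_function_patcher = Patcher(
    [("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
)
with safe_save_function_patcher:
    model.save_pretrained("saved_model", save_function=xm.save)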