diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index d33e2f82e..8aa0664ba 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -721,13 +721,17 @@ def state_dict_to_bytes(self, state_dict: Dict[str, torch.Tensor]) -> bytes:
             torch.bfloat16: torch.float16,
         }
+        # Materialize all tensors before copying them to the CPU.
+        # This prevents triggering a graph compilation for each tensor.
         xm.mark_step()
+
+        # It is important to first move the tensors to CPU and then cast them, because all XLA tensor
+        # operations, and in particular `to()`, behave differently when running `neuron_parallel_compile`.
         cpu_state_dict = move_all_tensor_to_cpu(state_dict)
+
         bytes_to_join = []
         for name, tensor in cpu_state_dict.items():
             memfile = io.BytesIO()
-            # It is actually important to first move the tensor to CPU then cast, because all XLA tensor operations,
-            # and in particular `to()` behave differently when doing `neuron_parallel_compile`.
             np.save(memfile, tensor.to(cast_to_mapping.get(tensor.dtype, tensor.dtype)).numpy())
             bytes_to_join.append(name.encode("utf-8"))
             bytes_to_join.append(memfile.getvalue())
         return b"".join(bytes_to_join)
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index be6ca4ba7..9a83c7fe5 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -344,6 +344,7 @@ def _test_list_in_registry(use_private_cache_repo: bool):
         _test_list_in_registry(True)
 
 
+@is_trainium_test
 class NeuronHashTestCase(TestCase):
     def test_neuron_hash_is_not_mutable(self):
         bert_model = BertModel(BertConfig())
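
For reference, below is a minimal, CPU-only sketch of the serialization pattern the first hunk annotates. The function name `state_dict_to_bytes_sketch` and the toy example are illustrative, not the library's code: the real method additionally calls `xm.mark_step()` once to materialize all pending XLA tensors in a single graph execution, and uses `move_all_tensor_to_cpu` to move them off the device before any dtype cast.

import io
from typing import Dict

import numpy as np
import torch

# NumPy has no bfloat16 dtype, so bfloat16 tensors must be downcast
# (here to float16) before `.numpy()` can be called on them.
CAST_TO_MAPPING = {torch.bfloat16: torch.float16}

def state_dict_to_bytes_sketch(state_dict: Dict[str, torch.Tensor]) -> bytes:
    # In the real code, `xm.mark_step()` runs at this point so that every
    # pending XLA tensor is materialized at once, and all tensors are moved
    # to CPU *before* any cast (a plain `.cpu()` stands in for
    # `move_all_tensor_to_cpu` here).
    cpu_state_dict = {name: tensor.cpu() for name, tensor in state_dict.items()}

    bytes_to_join = []
    for name, tensor in cpu_state_dict.items():
        memfile = io.BytesIO()
        np.save(memfile, tensor.to(CAST_TO_MAPPING.get(tensor.dtype, tensor.dtype)).numpy())
        bytes_to_join.append(name.encode("utf-8"))
        bytes_to_join.append(memfile.getvalue())
    return b"".join(bytes_to_join)

# A bfloat16 tensor would make `.numpy()` raise without the cast mapping.
blob = state_dict_to_bytes_sketch({"weight": torch.ones(2, 2, dtype=torch.bfloat16)})
print(len(blob))

Since the output appears to feed a hash (see `NeuronHashTestCase` in the second file) rather than a deserializer, names and payloads can simply be concatenated without length prefixes.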