Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Nov 27, 2023
1 parent e86a7c9 commit e83eca5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
8 changes: 6 additions & 2 deletions optimum/neuron/utils/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,13 +721,17 @@ def state_dict_to_bytes(self, state_dict: Dict[str, torch.Tensor]) -> bytes:
torch.bfloat16: torch.float16,
}

# Materialize all tensors before copying them to the CPU.
# This avoids triggering a graph compilation for each tensor.
xm.mark_step()

# It is actually important to first move the tensor to CPU and then cast, because all XLA tensor operations,
# and in particular `to()`, behave differently when doing `neuron_parallel_compile`.
cpu_state_dict = move_all_tensor_to_cpu(state_dict)

bytes_to_join = []
for name, tensor in cpu_state_dict.items():
memfile = io.BytesIO()
# It is actually important to first move the tensor to CPU then cast, because all XLA tensor operations,
# and in particular `to()` behave differently when doing `neuron_parallel_compile`.
np.save(memfile, tensor.to(cast_to_mapping.get(tensor.dtype, tensor.dtype)).numpy())
bytes_to_join.append(name.encode("utf-8"))
bytes_to_join.append(memfile.getvalue())
Expand Down
1 change: 1 addition & 0 deletions tests/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _test_list_in_registry(use_private_cache_repo: bool):
_test_list_in_registry(True)


@is_trainium_test
class NeuronHashTestCase(TestCase):
def test_neuron_hash_is_not_mutable(self):
bert_model = BertModel(BertConfig())
Expand Down

0 comments on commit e83eca5

Please sign in to comment.