From a32e71f605e4bb3b81aef7b14e53d2c98b8ddcb6 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Thu, 5 Dec 2024 12:56:20 -0800 Subject: [PATCH] Add logs in `consistent_restore_mesh` to examine the invariant. PiperOrigin-RevId: 703218403 --- .../experimental/emergency/multihost.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multihost.py b/checkpoint/orbax/checkpoint/experimental/emergency/multihost.py index 1995a364..9884cd07 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multihost.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multihost.py @@ -132,11 +132,34 @@ def consistent_restore_mesh( """ # Map how device ids changed across restarts. device_id_across_restarts = {} - for i in range(len(previous_distributed_to_device_ids)): - for j in range(len(previous_distributed_to_device_ids[i])): - previous_id = previous_distributed_to_device_ids[i][j] - current_id = current_distributed_to_device_ids[i][j] + assert len(previous_distributed_to_device_ids) == len( + current_distributed_to_device_ids + ) + + logging.debug( + 'previous_distributed_to_device_ids: %s', + previous_distributed_to_device_ids, + ) + logging.debug( + 'current_distributed_to_device_ids: %s', + current_distributed_to_device_ids, + ) + # TODO(b/376748289): remove the following variables after bug is fixed. + previous_device_to_distributed_id = {} + current_device_to_distributed_id = {} + for distributed_id in range(len(previous_distributed_to_device_ids)): + logging.debug( + 'distributed_id: %s, previous_device_ids: %s, current_device_ids: %s', + distributed_id, + previous_distributed_to_device_ids[distributed_id], + current_distributed_to_device_ids[distributed_id], + ) + for j in range(len(previous_distributed_to_device_ids[distributed_id])): + previous_id = previous_distributed_to_device_ids[distributed_id][j] + current_id = current_distributed_to_device_ids[distributed_id][j] device_id_across_restarts[previous_id] = current_id + previous_device_to_distributed_id[previous_id] = distributed_id + current_device_to_distributed_id[current_id] = distributed_id logging.debug( 'device_id_across_restarts (key: previous_id, value: current_id): %s', device_id_across_restarts, @@ -150,6 +173,28 @@ def consistent_restore_mesh( jax_devices_by_id[device_id_across_restarts[id]] for id in previous_flattened_mesh_device_ids ] + logging.debug( + 'previous_flattened_mesh_device_ids: %s', + previous_flattened_mesh_device_ids, + ) + new_flattened_mesh_device_ids = [d.id for d in new_flattened_mesh_devices] + logging.debug( + 'new_flattened_mesh_device_ids: %s', + new_flattened_mesh_device_ids, + ) + + previous_flattened_distributed_ids = [ + previous_device_to_distributed_id[id] + for id in previous_flattened_mesh_device_ids + ] + current_flattened_distributed_ids = [ + current_device_to_distributed_id[id] + for id in new_flattened_mesh_device_ids + ] + # The following is the invariant considering the distributed ids are + # the same across restarts. + assert previous_flattened_distributed_ids == current_flattened_distributed_ids + new_mesh_devices = np.array(new_flattened_mesh_devices).reshape( user_mesh.devices.shape )