diff --git a/mlfoundry_utils.py b/mlfoundry_utils.py index 2f197de..6fb0851 100644 --- a/mlfoundry_utils.py +++ b/mlfoundry_utils.py @@ -178,9 +178,14 @@ def on_save(self, args, state, control, **kwargs): if TFY_INTERNAL_JOB_NAME: description = f"Checkpoint from finetuning job={TFY_INTERNAL_JOB_NAME} run={TFY_INTERNAL_JOB_RUN_NAME}" logger.info(f"Uploading checkpoint {ckpt_dir} ...") + metadata = {} + for log in state.log_history: + if isinstance(log, dict) and log.get("step") == state.global_step: + metadata = log.copy() self._run.log_artifact( name=self._checkpoint_artifact_name, artifact_paths=[(artifact_path,)], + metadata=metadata, step=state.global_step, description=description, )