From 6347b75e0c94ab534ad5335f3638ca700fd2d29c Mon Sep 17 00:00:00 2001 From: Tim Rosenflanz Date: Mon, 11 Nov 2024 03:30:06 -0800 Subject: [PATCH] Moved Model Card Callback init in Trainer to a separate function (#3047) * Moved Model Card Callback init in Trainer to a separate function for subclassing * revert formatting * Rename slightly; add docstring which'll go in docs * Initialize default_args_dict before the super().__init__() --------- Co-authored-by: Tom Aarsen --- sentence_transformers/trainer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 68a60edbf..ef36d922f 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -277,8 +277,24 @@ def __init__( self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column( eval_dataset, args.prompts, dataset_name="eval" ) + self.add_model_card_callback(default_args_dict) + + def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None: + """ + Add a callback responsible for automatically tracking data required for the automatic model card generation + + This method is called in the ``__init__`` method of the + :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` class. + + Args: + default_args_dict (Dict[str, Any]): A dictionary of the default training arguments, so we can determine + which arguments have been changed for the model card. + + .. note:: + + This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases + """ - # Add a callback responsible for automatically tracking data required for the automatic model card generation model_card_callback = ModelCardCallback(self, default_args_dict) self.add_callback(model_card_callback) model_card_callback.on_init_end(self.args, self.state, self.control, self.model)