diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 68a60edbf..ef36d922f 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -277,8 +277,24 @@ def __init__( self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column( eval_dataset, args.prompts, dataset_name="eval" ) + self.add_model_card_callback(default_args_dict) + + def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None: + """ + Add a callback responsible for automatically tracking data required for the automatic model card generation + + This method is called in the ``__init__`` method of the + :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` class. + + Args: + default_args_dict (Dict[str, Any]): A dictionary of the default training arguments, so we can determine + which arguments have been changed for the model card. + + .. note:: + + This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases + """ - # Add a callback responsible for automatically tracking data required for the automatic model card generation model_card_callback = ModelCardCallback(self, default_args_dict) self.add_callback(model_card_callback) model_card_callback.on_init_end(self.args, self.state, self.control, self.model)