From 5418f62946ecac0715221f471e3eb9b1a664b9f8 Mon Sep 17 00:00:00 2001 From: malte Date: Mon, 6 Jan 2025 11:11:23 +0000 Subject: [PATCH 1/4] Fix vllm sampling parameters filter --- swift/llm/infer/infer_engine/vllm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index c6dcaad132..ff50bf41b9 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -167,7 +167,7 @@ def _load_generation_config(self) -> None: max_new_tokens = kwargs.get('max_new_tokens') if max_new_tokens is not None: kwargs['max_tokens'] = max_new_tokens - parameters = inspect.signature(SamplingParams.__init__).parameters + parameters = inspect.signature(SamplingParams).parameters for k, v in kwargs.copy().items(): if k not in parameters or v is None: kwargs.pop(k) From 7db491e98f860735e2c077a4920e966a39dddc64 Mon Sep 17 00:00:00 2001 From: malte Date: Mon, 6 Jan 2025 13:27:34 +0000 Subject: [PATCH 2/4] Fixes loading generation config from nested dicts using from_optional --- swift/llm/infer/infer_engine/vllm_engine.py | 39 ++++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index ff50bf41b9..4222e36ef6 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -161,19 +161,32 @@ def __len__(self) -> int: def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') - if os.path.isfile(generation_config_path): - generation_config = GenerationConfig.from_pretrained(self.model_dir) - kwargs = generation_config.to_dict() - max_new_tokens = kwargs.get('max_new_tokens') - if max_new_tokens is not None: - kwargs['max_tokens'] = max_new_tokens - parameters = inspect.signature(SamplingParams).parameters - for k, v in kwargs.copy().items(): - if k not in parameters or v is None: - kwargs.pop(k) - self.generation_config = SamplingParams(**kwargs) - else: - self.generation_config = SamplingParams() + if os.path.isfile(generation_config_path): + generation_config = GenerationConfig.from_pretrained(self.model_dir) + kwargs = generation_config.to_dict() + max_new_tokens = kwargs.get('max_new_tokens') + if max_new_tokens is not None: + kwargs['max_tokens'] = max_new_tokens + parameters = inspect.signature(SamplingParams).parameters + for k, v in kwargs.copy().items(): + if k not in parameters or v is None: + kwargs.pop(k) + continue + # Check if parameter class has from_optional method + param_type = parameters[k].annotation + + if str(param_type).startswith('typing.Optional'): + # Extract the actual type from Optional + param_type = param_type.__args__[0] + print(param_type) + if hasattr(param_type, 'from_optional') and v is not None: + kwargs[k] = param_type.from_optional(v) + + self.generation_config = SamplingParams(**kwargs) + print("self.generation_config") + print(self.generation_config) + else: + self.generation_config = SamplingParams() def _add_stop_words(self, generation_config: SamplingParams, request_config: RequestConfig, template_meta: TemplateMeta) -> None: From f5e8984ff6dd95f02bc7cbd54cda728c7f340747 Mon Sep 17 00:00:00 2001 From: malte Date: Mon, 6 Jan 2025 13:28:58 +0000 Subject: [PATCH 3/4] Fixes formatting and removes print statements --- swift/llm/infer/infer_engine/vllm_engine.py | 47 ++++++++++----------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 4222e36ef6..058053dc12 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -161,32 +161,29 @@ def __len__(self) -> int: def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') - if os.path.isfile(generation_config_path): - generation_config = GenerationConfig.from_pretrained(self.model_dir) - kwargs = generation_config.to_dict() - max_new_tokens = kwargs.get('max_new_tokens') - if max_new_tokens is not None: - kwargs['max_tokens'] = max_new_tokens - parameters = inspect.signature(SamplingParams).parameters - for k, v in kwargs.copy().items(): - if k not in parameters or v is None: - kwargs.pop(k) - continue - # Check if parameter class has from_optional method - param_type = parameters[k].annotation + if os.path.isfile(generation_config_path): + generation_config = GenerationConfig.from_pretrained(self.model_dir) + kwargs = generation_config.to_dict() + max_new_tokens = kwargs.get('max_new_tokens') + if max_new_tokens is not None: + kwargs['max_tokens'] = max_new_tokens + parameters = inspect.signature(SamplingParams).parameters + for k, v in kwargs.copy().items(): + if k not in parameters or v is None: + kwargs.pop(k) + continue + # Check if parameter class has from_optional method + param_type = parameters[k].annotation + + if str(param_type).startswith('typing.Optional'): + # Extract the actual type from Optional + param_type = param_type.__args__[0] + if hasattr(param_type, 'from_optional') and v is not None: + kwargs[k] = param_type.from_optional(v) - if str(param_type).startswith('typing.Optional'): - # Extract the actual type from Optional - param_type = param_type.__args__[0] - print(param_type) - if hasattr(param_type, 'from_optional') and v is not None: - kwargs[k] = param_type.from_optional(v) - - self.generation_config = SamplingParams(**kwargs) - print("self.generation_config") - print(self.generation_config) - else: - self.generation_config = SamplingParams() + self.generation_config = SamplingParams(**kwargs) + else: + self.generation_config = SamplingParams() def _add_stop_words(self, generation_config: SamplingParams, request_config: RequestConfig, template_meta: TemplateMeta) -> None: From 5bfafebd15f9e6c2c3898d565358f900dfd8f8c8 Mon Sep 17 00:00:00 2001 From: malte Date: Mon, 6 Jan 2025 13:33:18 +0000 Subject: [PATCH 4/4] Updates preparation of generation config to use loaded generation config as default values. --- swift/llm/infer/infer_engine/vllm_engine.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 058053dc12..4cda410455 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -256,18 +256,18 @@ def _get_logprobs(tokenizer: PreTrainedTokenizerBase, return {'content': res} def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingParams: - kwargs = {'max_tokens': request_config.max_tokens} - for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: + import msgspec + kwargs = msgspec.structs.asdict(self.generation_config) + + for key in ['max_tokens', 'temperature', 'top_k', 'top_p', 'repetition_penalty']: new_value = getattr(request_config, key) - if new_value is None: - kwargs[key] = getattr(self.generation_config, key) - else: + if new_value: kwargs[key] = new_value if request_config.logprobs: - kwargs['logprobs'] = 1 - if request_config.top_logprobs is not None: - kwargs['logprobs'] = max(1, request_config.top_logprobs) + top_logprobs = request_config.top_logprobs or 1 + kwargs['logprobs'] = max(1, top_logprobs) + # TODO: beam search for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']: