
Commit

update
Jintao-Huang committed Oct 31, 2024
1 parent 9d11fdc commit 2e70fdf
Showing 4 changed files with 53 additions and 55 deletions.
32 changes: 5 additions & 27 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -68,9 +68,9 @@ def _get_stop_words(self, stop_words: List[Union[str, List[int], None]]) -> List
return stop

@staticmethod
def __infer_stream(tasks,
stream: bool = True,
use_tqdm: bool = True) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]:
def _batch_infer_stream(tasks,
stream: bool = True,
use_tqdm: bool = True) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]:

async def _run_infer(i, task, queue, stream: bool = False):
# task with queue
@@ -142,13 +142,13 @@ def infer(self,
if request_config.stream:

def _gen_wrapper():
for res in self.__infer_stream(tasks, True, use_tqdm):
for res in self._batch_infer_stream(tasks, True, use_tqdm):
yield res
self._update_metrics(res, metrics)

return _gen_wrapper()
else:
for outputs in self.__infer_stream(tasks, False, use_tqdm):
for outputs in self._batch_infer_stream(tasks, False, use_tqdm):
pass
return self._update_metrics(outputs, metrics)

@@ -164,28 +164,6 @@ def _get_toolcall(self, response: Union[str, List[int]],

return [ChatCompletionMessageToolCall(function=Function(name=action, arguments=action_input))]

@torch.inference_mode()
async def infer_async(self,
infer_request: InferRequest,
request_config: Optional[RequestConfig] = None,
*,
template: Optional[Template] = None,
**kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
request_config = deepcopy(request_config or RequestConfig())
if template is None:
template = self.default_template

inputs = template.encode(infer_request)
assert len(inputs) >= 0
self.set_default_max_tokens(request_config, inputs)
generation_config = self._prepare_generation_config(request_config)
self._add_stop_words(generation_config, request_config, template.template_meta)
infer_args = (template, inputs, generation_config)
if request_config.stream:
return self._infer_stream_async(*infer_args, **kwargs)
else:
return await self._infer_full_async(*infer_args, **kwargs)

@staticmethod
def _get_num_tokens(inputs: Dict[str, Any]) -> int:
if 'input_ids' in inputs: # 1d or 2d
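
Note on the hunks above: the private batching helper is renamed from __infer_stream to _batch_infer_stream, and infer_async is dropped from the base class (it reappears in the engine subclasses below). One plausible motivation for the rename, not stated in the commit, is Python's name mangling: a double-underscore method is rewritten per class, which makes it awkward for subclasses to reach. A minimal sketch of that behavior, with hypothetical class names:

# Hypothetical classes, only to illustrate name mangling; not part of swift.
class Base:
    def __helper(self):        # stored on the class as _Base__helper
        return 'from Base'

    def _shared_helper(self):  # single underscore: no mangling, inherited normally
        return 'from Base'


class Child(Base):
    def run(self):
        # `self.__helper()` here would be mangled to `self._Child__helper()`
        # and raise AttributeError; the single-underscore name resolves normally.
        return self._shared_helper()


print(Child().run())           # -> 'from Base'
print(Base()._Base__helper())  # the mangled name is still reachable explicitly
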
22 changes: 22 additions & 0 deletions swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -211,3 +211,25 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
logprobs=logprobs)
]
return ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info)

@torch.inference_mode()
async def infer_async(self,
infer_request: InferRequest,
request_config: Optional[RequestConfig] = None,
*,
template: Optional[Template] = None,
**kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
request_config = deepcopy(request_config or RequestConfig())
if template is None:
template = self.default_template

inputs = template.encode(infer_request)
assert len(inputs) >= 0
self.set_default_max_tokens(request_config, inputs)
generation_config = self._prepare_generation_config(request_config)
self._add_stop_words(generation_config, request_config, template.template_meta)
infer_args = (template, inputs, generation_config)
if request_config.stream:
return self._infer_stream_async(*infer_args, **kwargs)
else:
return await self._infer_full_async(*infer_args, **kwargs)
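
For context, a minimal usage sketch of the async entry point added here. It assumes the engine class in this file is exposed as LmdeployEngine, that InferRequest and RequestConfig are importable from swift.llm, and that the response objects carry OpenAI-style fields; treat the import paths and model id as placeholders rather than a confirmed public API.

import asyncio

# Assumed import locations; adjust to the actual package layout.
from swift.llm import InferRequest, RequestConfig
from swift.llm.infer.infer_engine import LmdeployEngine


async def main():
    engine = LmdeployEngine('Qwen/Qwen2-7B-Instruct')  # placeholder model id
    request = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])

    # Non-streaming: infer_async awaits _infer_full_async and returns a ChatCompletionResponse.
    response = await engine.infer_async(request, RequestConfig(max_tokens=64))
    print(response.choices[0].message.content)

    # Streaming: infer_async returns an async iterator of ChatCompletionStreamResponse.
    stream = await engine.infer_async(request, RequestConfig(stream=True, max_tokens=64))
    async for chunk in stream:
        print(chunk.choices[0].delta.content, end='')


asyncio.run(main())
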
15 changes: 14 additions & 1 deletion swift/llm/infer/infer_engine/vllm_engine.py
@@ -342,4 +342,17 @@ async def infer_async(
template: Optional[Template] = None,
lora_request: Optional['LoRARequest'] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
return await super().infer_async(infer_request, request_config, template=template, lora_request=lora_request)
request_config = deepcopy(request_config or RequestConfig())
if template is None:
template = self.default_template

inputs = template.encode(infer_request)
assert len(inputs) >= 0
self.set_default_max_tokens(request_config, inputs)
generation_config = self._prepare_generation_config(request_config)
self._add_stop_words(generation_config, request_config, template.template_meta)
infer_args = (template, inputs, generation_config)
if request_config.stream:
return self._infer_stream_async(*infer_args, **kwargs)
else:
return await self._infer_full_async(*infer_args, **kwargs)
39 changes: 12 additions & 27 deletions swift/llm/model/register.py
@@ -2,6 +2,7 @@
import inspect
import itertools
import os
import re
from dataclasses import dataclass, field
from functools import partial, update_wrapper
from types import MethodType
@@ -84,7 +85,7 @@ def get_model_names(self) -> List[str]:

if isinstance(value, str):
model_name = value.rsplit('/', 1)[-1]
res.add(model_name)
res.add(model_name.lower())
return list(res)

def check_requires(self):
@@ -289,49 +290,33 @@ def get_default_torch_dtype(torch_dtype: torch.dtype):

def _get_model_name(model_id_or_path: str) -> str:
# compat hf hub
match_ = re.search('/models--.+?--(.+?)/snapshots/', model_dir)
match_ = re.search('/models--.+?--(.+?)/snapshots/', model_id_or_path)
if match_ is not None:
model_name = match_.group(1)
else:
model_name = model_dir.rsplit('/', 1)[-1]
return model_name
model_name = model_id_or_path.rsplit('/', 1)[-1]
return model_name.lower()


_model_name_mapping = {}


def get_arch_mapping() -> Dict[str, Dict[str, List[str]]]:
global ARCH_MAPPING
if ARCH_MAPPING is None:
# arch(str) -> Dict[model_type(str), List[model_name(str)]]
ARCH_MAPPING = {}
for model_type, model_info in MODEL_MAPPING.items():
model_meta = model_info['model_meta']
archs = model_meta.architectures
model_names = model_meta.get_model_names()
for arch in archs:
if arch not in ARCH_MAPPING:
ARCH_MAPPING[arch] = {}
ARCH_MAPPING[arch][model_type] = model_names
return ARCH_MAPPING


def get_model_name_mapping():
# model_name -> model_type
global _model_name_mapping
if _model_name_mapping is not None:
return _model_name_mapping
for model_type, model_meta in MODEL_MAPPING.items():
model_meta.get_model_names()
model_names = model_meta.get_model_names()
for model_name in model_names:
_model_name_mapping[model_name] = model_type
return _model_name_mapping


def get_matched_model_meta(model_id_or_path: str):
def get_matched_model_meta(model_id_or_path: str) -> Optional[str]:
model_name = _get_model_name(model_id_or_path).lower()
model_type_dict_reversed = {}
for model_type, model_names in model_type_dict.items():
model_type_dict_reversed.update({model_name.lower(): model_type for model_name in model_names})
model_type = model_type_dict_reversed.get(model_name)

model_name_mapping = get_model_name_mapping()
return model_name_mapping.get(model_name)


def get_model_tokenizer(

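Besides lowercasing model names throughout, the register.py hunk fixes _get_model_name to use its model_id_or_path argument (the old body referenced an undefined model_dir). That helper is small enough to lift out and run on its own; a standalone sketch copied from the diff, with made-up example paths:

import re


def _get_model_name(model_id_or_path: str) -> str:
    # compat hf hub: cached repos live under .../models--{org}--{name}/snapshots/{revision}
    match_ = re.search('/models--.+?--(.+?)/snapshots/', model_id_or_path)
    if match_ is not None:
        model_name = match_.group(1)
    else:
        model_name = model_id_or_path.rsplit('/', 1)[-1]
    return model_name.lower()


# A hub-cache path and a plain repo id both resolve to the same lowercased name.
print(_get_model_name('/root/.cache/huggingface/hub/models--Qwen--Qwen2-7B-Instruct/snapshots/abc123'))
print(_get_model_name('Qwen/Qwen2-7B-Instruct'))
# both print: qwen2-7b-instruct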