Add vision function calling models based on: InternVL 2.0 models #228

Status: Open. Wants to merge 31 commits into base `main`.

Commits (31):
- `d39eb21` add template for internlm2 (khai-meetkai, Jul 30, 2024)
- `e9e6731` update and fix code (khai-meetkai, Jul 30, 2024)
- `c51a337` minor fix (khai-meetkai, Jul 30, 2024)
- `70cd9f6` fix preprocess logits (khai-meetkai, Jul 31, 2024)
- `fd02d43` fix server_vision (khai-meetkai, Aug 1, 2024)
- `184e7d3` fix, add usage (khai-meetkai, Aug 1, 2024)
- `4778f65` fix bugs for batch_size>1 (khai-meetkai, Aug 1, 2024)
- `2eb5e72` handle truncation in training (khai-meetkai, Aug 1, 2024)
- `dab3c2f` fix bug if batch_size>1 (khai-meetkai, Aug 2, 2024)
- `7e5ea9a` add training params (khai-meetkai, Aug 2, 2024)
- `e809633` add source of code files (khai-meetkai, Aug 2, 2024)
- `4b54757` edit readme (khai-meetkai, Aug 2, 2024)
- `d8f8111` fix readme (khai-meetkai, Aug 5, 2024)
- `2ffd01a` fix original requirements (khai-meetkai, Aug 5, 2024)
- `9a9d816` remove test files (khai-meetkai, Aug 5, 2024)
- `b7c7b84` merge from add_llama3.1 (khai-meetkai, Aug 5, 2024)
- `56f6696` add streaming for vision function calling (khai-meetkai, Aug 5, 2024)
- `3919768` fix README (khai-meetkai, Aug 7, 2024)
- `94b75c4` merge with main (khai-meetkai, Aug 7, 2024)
- `06e4ac5` format after merge (khai-meetkai, Aug 7, 2024)
- `5fc2f75` merg from main (khai-meetkai, Aug 12, 2024)
- `6aff1dd` change max_token to request.max_tokens (khai-meetkai, Aug 13, 2024)
- `ecbb0bb` merge from main (khai-meetkai, Aug 21, 2024)
- `535fdf5` fix streaming for internlm2 (khai-meetkai, Aug 22, 2024)
- `9478eaf` fix conflicts from merge (khai-meetkai, Aug 22, 2024)
- `02f1de8` update vllm_inference to serve vision models (khai-meetkai, Aug 22, 2024)
- `af597e3` update vllm_inference to support vision LM (khai-meetkai, Aug 22, 2024)
- `5644454` fix readme (khai-meetkai, Aug 22, 2024)
- `9755af1` fix readme (khai-meetkai, Aug 23, 2024)
- `e0902e5` pull from main (jeffreymeetkai, Sep 17, 2024)
- `3017340` use jinja chat template (jeffreymeetkai, Sep 17, 2024)
117 changes: 100 additions & 17 deletions README.md
@@ -16,6 +16,7 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.meetkai.com)
<summary>Changelog: (click to expand)</summary>

+ [2024-08-11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/08/09] We release **new vision function calling models**: [meetkai/functionary-vision-medium-v0.1](https://huggingface.co/meetkai/functionary-vision-medium-v0.1); [meetkai/functionary-vision-small-v0.1](https://huggingface.co/meetkai/functionary-vision-small-v0.1)
+ [2024/08/08] We release the 128k-context length 70B model: [meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1), which is based on [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+ [2024/08/07] We release 2 128k-context length models that are based on [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct):
+ [meetkai/functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1): **using Meta's original prompt template** as described in: [User-defined Custom tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1#user-defined-custom-tool-calling)
@@ -29,7 +30,7 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.meetkai.com)

</details>

### Setup

To install the required dependencies, run:

@@ -112,8 +113,30 @@ you can start your environment like this:
sudo docker run --gpus all -it --ipc=host --name functionary -v ${PWD}/functionary_workspace:/workspace -p 8000:8000 nvcr.io/nvidia/pytorch:23.10-py3
```


### Vision Function Calling Models
We also use ``server_vllm.py`` to deploy vision function calling models. Note that vLLM currently supports only a single image in the input; multiple images will be supported in the future.

**Small Model:**
```shell
python3 server_vllm.py --model "meetkai/functionary-vision-small-v0.1" --max-model-len 8192
```

**Medium Model:**
```shell
# vLLM requires running this first: https://github.com/vllm-project/vllm/issues/6152
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python server_vllm.py --model "meetkai/functionary-vision-medium-v0.1" --max-model-len 8192 --tensor-parallel-size 2
```

You need 4x A6000 (or A40) or 2x A100 GPUs to run the medium model.
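
Once a server is up, a quick sanity check is to list the served models through the OpenAI client. This is a minimal sketch; it assumes the server exposes the standard OpenAI-compatible model-listing endpoint and that the API key is not verified:

```python
from openai import OpenAI

# Point the client at the local server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

# List the model IDs the server reports; the deployed model name should appear here.
print([model.id for model in client.models.list().data])
```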

</details>

### OpenAI Compatible Usage

**For text only:**
```python
from openai import OpenAI

@@ -145,7 +168,67 @@ client.chat.completions.create(
)
```

**For image inputs (vision models):**
```python
from openai import OpenAI
import base64
import os

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

def encode_image(image_path: str):
    # check if the image exists
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

encoded_img = encode_image("assets/example.png")
response = client.chat.completions.create(
    model="meetkai/functionary-vision-small-v0.1",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpg;base64,{encoded_img}"},
                },
                {
                    "type": "text",
                    "text": "can you translate the text in the image to Japanese",
                },
            ],
        }
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "translate",
                "description": "translate text from one language to another",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "src_language": {
                            "type": "string",
                            "description": "The source language, for example: en, vi, ja, ...",
                        },
                        "target_language": {
                            "type": "string",
                            "description": "The target language, for example: en, vi, ja, ...",
                        },
                        "text": {
                            "type": "string",
                            "description": "The text in the source language to translate",
                        },
                    },
                    "required": ["src_language", "text", "target_language"],
                },
            },
        }
    ],
    tool_choice="auto",
)
```
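
The tool call, if any, can then be read back from the standard OpenAI response fields. A minimal sketch, assuming the call above was assigned to `response` as shown:

```python
# Inspect the assistant's reply: either a tool call or plain content.
message = response.choices[0].message
if message.tool_calls:
    for call in message.tool_calls:
        # e.g. translate {"src_language": "en", "target_language": "ja", "text": "..."}
        print(call.function.name, call.function.arguments)
else:
    print(message.content)
```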

### Raw Usage:

@@ -197,22 +280,22 @@ print(response.text)


## Models Available
| Model | Description | VRAM FP16 |
|:-------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------|:------|
| [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 160GB |
| [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 24GB |
| [functionary-medium-v3.0](https://huggingface.co/meetkai/functionary-medium-v3.0) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.0-GGUF) | 8k context, based on [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 160GB |
| [functionary-small-v2.5](https://huggingface.co/meetkai/functionary-small-v2.5) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.5-GGUF) | 8k context, code interpreter | 24GB |
| [functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.4-GGUF) | 8k context, code interpreter | 24GB |
| [functionary-medium-v2.4](https://huggingface.co/meetkai/functionary-medium-v2.4) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v2.4-GGUF) | 8k context, code interpreter, better accuracy | 90GB |
| [functionary-small-v2.2](https://huggingface.co/meetkai/functionary-small-v2.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.2-GGUF) | 8k context | 24GB |
| [functionary-medium-v2.2](https://huggingface.co/meetkai/functionary-medium-v2.2) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v2.2-GGUF) | 8k context| 90GB |
| [functionary-7b-v2.1](https://huggingface.co/meetkai/functionary-7b-v2.1) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v2.1-GGUF) | 8k context | 24GB |
| [functionary-7b-v2](https://huggingface.co/meetkai/functionary-7b-v2) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v2-GGUF) | Parallel function call support. | 24GB |
| [functionary-7b-v1.4](https://huggingface.co/meetkai/functionary-7b-v1.4) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v1.4-GGUF) | 4k context, better accuracy (deprecated) | 24GB |
| [functionary-7b-v1.1](https://huggingface.co/meetkai/functionary-7b-v1.1) | 4k context (deprecated) | 24GB |
| functionary-7b-v0.1 | 2k context (deprecated) Not recommended, use 2.1 onwards | 24GB |
| Model | Model Type| Description | VRAM FP16 |
|:-------------------------------------------------------------------------------------|---|:--------------------------------------------------------------------------------------------------------------------------------------|:------|
| [meetkai/functionary-vision-medium-v0.1](https://huggingface.co/meetkai/functionary-vision-medium-v0.1) | Vision | 8k context| 160GB |
| [meetkai/functionary-vision-small-v0.1](https://huggingface.co/meetkai/functionary-vision-small-v0.1) | Vision | 8k context| 24GB |
| [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | Text-only | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | Text-only | 128k context, code interpreter, using **Meta's original prompt template** | 160GB |
| [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | Text-only | 128k context, code interpreter | 24GB |
| [functionary-medium-v3.0](https://huggingface.co/meetkai/functionary-medium-v3.0) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.0-GGUF)| Text-only | 8k context, based on [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 160GB |
| [functionary-small-v2.5](https://huggingface.co/meetkai/functionary-small-v2.5) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.5-GGUF)| Text-only | 8k context, code interpreter | 24GB |
| [functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.4-GGUF)| Text-only | 8k context, code interpreter | 24GB |
| [functionary-medium-v2.4](https://huggingface.co/meetkai/functionary-medium-v2.4) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v2.4-GGUF)| Text-only | 8k context, code interpreter, better accuracy | 90GB |
| [functionary-small-v2.2](https://huggingface.co/meetkai/functionary-small-v2.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v2.2-GGUF)| Text-only | 8k context | 24GB |
| [functionary-medium-v2.2](https://huggingface.co/meetkai/functionary-medium-v2.2) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v2.2-GGUF)| Text-only | 8k context| 90GB |
| [functionary-7b-v2.1](https://huggingface.co/meetkai/functionary-7b-v2.1) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v2.1-GGUF)| Text-only | 8k context | 24GB |
| [functionary-7b-v2](https://huggingface.co/meetkai/functionary-7b-v2) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v2-GGUF)| Text-only | Parallel function call support. | 24GB |
| [functionary-7b-v1.4](https://huggingface.co/meetkai/functionary-7b-v1.4) / [GGUF](https://huggingface.co/meetkai/functionary-7b-v1.4-GGUF)| Text-only | 4k context, better accuracy (deprecated) | 24GB |

### Compatibility information

Binary file added assets/example.png
80 changes: 74 additions & 6 deletions functionary/inference_vision.py
@@ -4,21 +4,89 @@
    FunctionCall,
    Tool,
    ChatCompletionRequest,
    UsageInfo,
)
from functionary.prompt_template import get_prompt_template_from_tokenizer
from functionary.prompt_template.prompt_utils import (
    prepare_messages_for_inference,
    extract_images_from_messages,
)
from typing import Any, List, Optional
from llava.mm_utils import process_images
from typing import Any, List, Optional, Tuple
import torch
from functionary.inference_utils import analyze_tools_and_tool_choice
from enum import Enum


class ModelType(str, Enum):
    llama_llava = "llama_llava"
    internvl_chat = "internvl_chat"


def generate(
    *, model_type: ModelType, model: Any, tokenizer: Any, request: ChatCompletionRequest
) -> Tuple[ChatMessage, UsageInfo]:
    generate_func = generate_internvl_chat
    if model_type == ModelType.llama_llava:
        generate_func = generate_llava
    return generate_func(model=model, tokenizer=tokenizer, request=request)


def generate_internvl_chat(
    *, model: Any, tokenizer: Any, request: ChatCompletionRequest
) -> Tuple[ChatMessage, UsageInfo]:
    tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request)

    prompt_token_ids = prepare_messages_for_inference(
        tokenizer=tokenizer,
        messages=request.messages,
        tools_or_functions=tools_or_functions,
        tool_choice=tool_func_choice,
        device=model.device,
    )
    input_ids = prompt_token_ids.unsqueeze(0)
    attention_mask = torch.ones_like(input_ids).to(model.device)
    images = extract_images_from_messages(
        [message.dict() for message in request.messages]
    )
    input_ids, attention_mask, _, pixel_values, _ = model.expand_input_ids(
        input_ids, None, attention_mask, images, training=False
    )

    prompt_template = get_prompt_template_from_tokenizer(tokenizer)
    eos_token_ids = [
        tokenizer.convert_tokens_to_ids(tok)
        for tok in prompt_template.get_stop_tokens_for_generation()
    ]

    generation_config = dict(
        max_new_tokens=request.max_tokens,
        do_sample=False,
        eos_token_id=eos_token_ids,
        temperature=request.temperature,
    )

    generation_output = model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_config
    )

    response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
    assistant_response = prompt_template.parse_assistant_response(response)
    usage = UsageInfo(
        prompt_tokens=input_ids.shape[-1],
        completion_tokens=generation_output.shape[-1],
        total_tokens=input_ids.shape[-1] + generation_output.shape[-1],
    )
    return ChatMessage(**assistant_response), usage


def generate_llava(
    *, model: Any, tokenizer: Any, request: ChatCompletionRequest
) -> ChatMessage:
) -> Tuple[ChatMessage, UsageInfo]:
    from llava.mm_utils import process_images

    tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request)

    prompt_token_ids = prepare_messages_for_inference(
@@ -55,11 +123,11 @@ def generate(
        images=image_tensor,
        image_sizes=image_sizes,
        do_sample=False,
        temperature=0,
        max_new_tokens=1024,
        temperature=request.temperature,
        max_new_tokens=request.max_tokens,
    )

    text_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]

    result = prompt_template.parse_assistant_response(text_output, tool_choice="auto")
    return ChatMessage(**result)
    return ChatMessage(**result), UsageInfo()
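
For reference, a minimal sketch of driving this entry point directly rather than through the server. The checkpoint name, `trust_remote_code` flag, and request fields are assumptions for illustration, and the model's remote code is assumed to provide the `expand_input_ids` method used above:

```python
from transformers import AutoModel, AutoTokenizer

from functionary.inference_vision import ModelType, generate
from functionary.openai_types import ChatCompletionRequest

model_name = "meetkai/functionary-vision-small-v0.1"  # assumed InternVL-based checkpoint
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

request = ChatCompletionRequest(
    model=model_name,
    messages=[{"role": "user", "content": "Hello, what can you do?"}],
    temperature=0.0,
    max_tokens=256,
)
chat_message, usage = generate(
    model_type=ModelType.internvl_chat, model=model, tokenizer=tokenizer, request=request
)
print(chat_message, usage)
```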
8 changes: 7 additions & 1 deletion functionary/prompt_template/__init__.py
@@ -2,6 +2,7 @@
from typing import Any, List

from functionary.prompt_template.base_template import PromptTemplate
from functionary.prompt_template.internlm2_prompt_template import InternLMChat
from functionary.prompt_template.llama3_prompt_template import Llama3Template
from functionary.prompt_template.llama3_prompt_template_v3 import Llama3TemplateV3
from functionary.prompt_template.llama31_prompt_template import Llama31Template
@@ -28,7 +29,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]:
    # directly add LLavaLlama as it is not a direct subclass of PromptTemplate but the subclass of: Llama3TemplateV3
    # we don't use get_prompt_template or this will return the parent class
    all_templates_obj.append(LlavaLlama.get_prompt_template())

    all_templates_obj.append(InternLMChat.get_prompt_template())
    return all_templates_obj


@@ -80,6 +81,11 @@ def get_prompt_template_from_tokenizer(tokenizer: Any) -> PromptTemplate:
    p4 = _TEMPLATE_DIC[Llama3TemplateV3.version]
    p5 = _TEMPLATE_DIC[LlavaLlama.version]
    p6 = _TEMPLATE_DIC[Llama31Template.version]
    p7 = _TEMPLATE_DIC[InternLMChat.version]

    token_ids = tokenizer.encode(p7.img_context, add_special_tokens=False)
    if len(token_ids) == 1:
        return p7

    token_ids = tokenizer.encode("<|eom_id|>", add_special_tokens=False)
    if len(token_ids) == 1 and token_ids[0] == 128008:  # tokenizer from llama-3.1
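
The new branch above picks the InternVL template by checking whether the tokenizer encodes the template's image-context token as a single special token. A sketch of exercising it, assuming a vision checkpoint whose tokenizer registers that token:

```python
from transformers import AutoTokenizer

from functionary.prompt_template import get_prompt_template_from_tokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meetkai/functionary-vision-small-v0.1", trust_remote_code=True  # assumed checkpoint
)
template = get_prompt_template_from_tokenizer(tokenizer)
print(template.version)  # expected: "internlm2-chat" for InternVL-based models
```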
8 changes: 5 additions & 3 deletions functionary/prompt_template/base_template.py
@@ -17,9 +17,11 @@ class PromptTemplate:
    _jinja_env.policies["json.dumps_kwargs"] = {"sort_keys": False}
    # Mapping from class --> instance to create singleton instance
    _instances = {}

    def __init__(self):
        self._jinja_template = self._jinja_env.from_string(self.get_chat_template_jinja())
        self._jinja_template = self._jinja_env.from_string(
            self.get_chat_template_jinja()
        )

    @abstractmethod
    def get_start_of_function_call_token(self) -> str:
@@ -341,7 +343,7 @@ def get_chat_template_jinja(self) -> str:
            json_to_ts_schema = f.read()
        with open(f"{path_prefix}{self.version}.txt", "r") as f:
            template = f.read()

        return (
            template[: template.index("{%")]
            + json_to_ts_schema
21 changes: 21 additions & 0 deletions functionary/prompt_template/internlm2_prompt_template.py
@@ -0,0 +1,21 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils
from functionary.prompt_template.llama3_prompt_template_v3 import Llama3TemplateV3
from functionary.schema import generate_schema_from_functions


class InternLMChat(Llama3TemplateV3):
    version = "internlm2-chat"
    img_token = "<img>"
    start_of_turn = "<|im_start|>"
    eos_token = "<|im_end|>"
    function_separator = ">>>"

    def get_assistant_prefixes(self) -> List[str]:
        return [f"{self.start_of_turn}assistant\n{self.function_separator}"]

    def get_stop_tokens_for_generation(self) -> List[str]:
        return [self.eos_token]
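
As a quick illustration, the assistant prefix and stop token this template contributes can be printed with the singleton accessor used in `__init__.py`; a small sketch:

```python
from functionary.prompt_template.internlm2_prompt_template import InternLMChat

template = InternLMChat.get_prompt_template()
print(template.get_assistant_prefixes())          # ['<|im_start|>assistant\n>>>']
print(template.get_stop_tokens_for_generation())  # ['<|im_end|>']
```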