commit 0bc5d8c (1 parent: 9bddd52)
Signed-off-by: fengding <[email protected]>
4 changed files with 205 additions and 69 deletions.
# FP8 Quantization

1. [Introduction](#introduction)
2. [Supported Parameters](#supported-parameters)
3. [Get Started with FP8 Quantization](#get-started-with-fp8-quantization)
4. [Optimum-habana LLM example](#optimum-habana-llm-example)
5. [VLLM example](#vllm-example)

## Introduction
## Get Started with FP8 Quantization
[Demo Usage](https://github.com/intel/neural-compressor?tab=readme-ov-file#getting-started)
[Computer vision example](../../../examples/3.x_api/pytorch/cv/fp8_quant)

### Demo Usage

```python
from neural_compressor.torch.quantization import (
    FP8Config,
    prepare,
    convert,
)
import torchvision.models as models

model = models.resnet18()
qconfig = FP8Config(fp8_config="E4M3")

# Insert measurement/quantization hooks according to the config.
model = prepare(model, qconfig)

# User-defined calibration: run representative data through the model
# so that scales can be collected.
calib_func(model)

# Convert the calibrated model to FP8.
model = convert(model)
```
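`calib_func` is supplied by the user. A minimal sketch of what such a function might look like (the random input shape and batch count are illustrative assumptions, not part of the API):

```python
import torch


def calib_func(model, num_batches: int = 10):
    """Run a few representative batches through the prepared model so the
    observers inserted by prepare() can collect tensor statistics.
    Random data keeps this sketch self-contained; use real samples in practice."""
    model.eval()
    with torch.no_grad():
        for _ in range(num_batches):
            model(torch.randn(1, 3, 224, 224))
```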
## Optimum-habana LLM example
### Overview
[Optimum](https://huggingface.co/docs/optimum) is an extension of Transformers that provides a set of performance optimization tools to train and run models on targeted hardware with maximum efficiency.
[Optimum-habana](https://github.com/huggingface/optimum-habana) is the interface between the Transformers and Diffusers libraries and Intel Gaudi AI Accelerators (HPU). It provides higher performance through modified modeling files and uses Intel Neural Compressor internally for FP8 quantization; see [running-with-fp8](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8).
![](./imgs/optimum-habana.png)
### Installation
Refer to [optimum-habana, install-the-library-and-get-example-scripts](https://github.com/huggingface/optimum-habana?tab=readme-ov-file#install-the-library-and-get-example-scripts).
Alternatively, install from source:
```
$ git clone https://github.com/huggingface/optimum-habana
$ cd optimum-habana && git checkout v1.14.0   # change to the desired version
$ pip install -e .
$ pip install git+https://github.com/HabanaAI/[email protected]
$ cd examples/text-generation
$ pip install -r requirements.txt
$ pip install -r requirements_lm_eval.txt   # optional, for accuracy evaluation
```
### Check the neural_compressor code path
> optimum-habana/examples/text-generation/utils.py
>> initialize_model() -> setup_model() -> setup_quantization() -> FP8Config/prepare()/convert()
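Paraphrased, that call chain boils down to something like the following sketch (simplified, not the verbatim optimum-habana source; the function signature and comments here are illustrative):

```python
from neural_compressor.torch.quantization import FP8Config, convert, prepare


def setup_quantization(model, quant_config_path):
    # quant_config_path is the JSON file pointed to by QUANT_CONFIG.
    config = FP8Config.from_json_file(quant_config_path)
    if config.measure:
        # Measurement (calibration) run: insert observers and dump scales.
        model = prepare(model, config)
    elif config.quantize:
        # Quantization run: patch modules to FP8 using the saved scales.
        model = convert(model, config)
    return model
```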
### FP8 KV cache
Introduction: [kv-cache-quantization in huggingface transformers](https://huggingface.co/blog/kv-cache-quantization)

BF16 KV cache code -> [modeling_all_models.py -> KVCache()](https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/modeling_all_models.py)

FP8 KV cache code trace with Neural Compressor support, taking the Llama models as an example:
> optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py
>> GaudiLlamaForCausalLM() -> self.model()
>>> GaudiLlamaModel() -> forward() -> decoder_layer() -> GaudiLlamaDecoderLayer() forward() -> pre_attn() -> pre_attn_forward() -> self.k_cache.update

> neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
>> PatchedKVCache() -> update()
>> PatchedModuleFusedSDPA()
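Conceptually, the patched cache stores keys and values in FP8 and dequantizes them when attention reads the cache. A minimal, device-agnostic sketch of that idea in plain PyTorch (per-tensor scaling; an illustration only, not the HPU implementation):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def quantize_kv(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into the representable E4M3 range, then cast to FP8 for storage.
    return (x / scale).to(torch.float8_e4m3fn)


def dequantize_kv(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Cast back to BF16 and undo the scaling before attention consumes it.
    return x_fp8.to(torch.bfloat16) * scale


k = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)  # [batch, heads, seq, head_dim]
scale = k.abs().amax() / E4M3_MAX
k_cached = quantize_kv(k, scale)             # what the FP8 KV cache stores
k_restored = dequantize_kv(k_cached, scale)  # what attention actually uses
```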
Models that support FP8 KV cache include:
```
microsoft/Phi-3-mini-4k-instruct
bigcode/starcoder2-3b
Qwen/Qwen2.5-7B-Instruct
meta-llama/Llama-3.2-3B-Instruct
tiiuae/falcon-7b-instruct
mistralai/Mixtral-8x7B-Instruct-v0.1
EleutherAI/gpt-j-6b
mistralai/Mistral-Nemo-Instruct-2407
...
```

### Running with FP8
Refer to [Running with FP8](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8).
Set `--model_name_or_path` to your model, e.g.
"meta-llama/Llama-3.1-8B-Instruct",
"Qwen/Qwen2.5-7B-Instruct", or
"mistralai/Mixtral-8x7B-Instruct-v0.1".
Pass `--use_kv_cache` to enable the FP8 KV cache.

### Profiling
Append "--profiling_warmup_steps 5 --profiling_steps 2 --profiling_record_shapes" to the end of the run_generation.py command line.
Refer to [torch.profiler.ProfilerActivity.HPU](https://github.com/huggingface/optimum-habana/blob/c9e1c23620618e2f260c92c46dfeb163545ec5ba/optimum/habana/utils.py#L305).

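These flags drive `torch.profiler` under the hood; a minimal sketch of the same idea (assumes the Intel Gaudi PyTorch build, which exposes `ProfilerActivity.HPU`; the matmul is only a stand-in workload):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# ProfilerActivity.HPU is only available in the Intel Gaudi PyTorch build.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.HPU],
    record_shapes=True,
) as prof:
    x = torch.randn(1024, 1024)
    y = x @ x  # stand-in workload; in practice this is the generation loop

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```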
### FP8 Accuracy
"lm_eval.tasks", "lm_eval.evaluator", and "lm_eval" are installed from the requirements_lm_eval.txt above. The evaluated tasks can be configured; the default set is ["hellaswag", "lambada_openai", "piqa", "winogrande"] ([more info](https://github.com/EleutherAI/lm-evaluation-harness/)).

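For illustration, the harness can also be driven directly from Python; a hedged sketch using lm-evaluation-harness's `simple_evaluate` API (the model name and batch size are examples, and the optimum-habana scripts wrap this differently):

```python
import lm_eval

# Evaluate a Hugging Face model on the default task set used above.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct,dtype=bfloat16",
    tasks=["hellaswag", "lambada_openai", "piqa", "winogrande"],
    batch_size=8,
)
print(results["results"])
```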
| `Llama-2-7b-hf` | FP8 + FP8 KV cache | BF16 + BF16 KV cache |
|---------------|---------|--------|
| hellaswag | 0.5691097390957977 | 0.5704043019318861 |
| lambada_openai | 0.7360760721909567 | 0.7372404424607025 |
| piqa | 0.7850924918389554 | 0.7818280739934712 |
| winogrande | 0.6929755327545383 | 0.6929755327545383 |

| `Qwen2.5-7B-Instruct` | FP8 + FP8 KV cache | BF16 + BF16 KV cache |
|---------------|---------|--------|
| hellaswag | 0.2539334793865764 | 0.2539334793865764 |
| lambada_openai | 0.0 | 0.0 |
| piqa | 0.5391730141458106 | 0.5391730141458106 |
| winogrande | 0.4956590370955012 | 0.4956590370955012 |

| `Llama-3.1-8B-Instruct` | FP8 + FP8 KV cache | BF16 + BF16 KV cache |
|---------------|---------|--------|
| hellaswag | 0.5934076877116112 | 0.5975901214897431 |
| lambada_openai | 0.7230739375121289 | 0.7255967397632447 |
| piqa | 0.7932535364526659 | 0.8030467899891186 |
| winogrande | 0.7434885556432518 | 0.7371744277821626 |

| `Mixtral-8x7B-Instruct-v0.1` | FP8 + FP8 KV cache | BF16 + BF16 KV cache |
|---------------|---------|--------|
| hellaswag | 0.25323640709022105 | 0.25323640709022105 |
| lambada_openai | 0.0 | 0.0 |
| piqa | 0.528835690968444 | 0.528835690968444 |
| winogrande | 0.4956590370955012 | 0.4956590370955012 |

## VLLM example
### Overview
![](./imgs/vllm_gaudi.png)

### Installation
Refer to [Habana vllm-fork](https://github.com/HabanaAI/vllm-fork) for installation.
Alternatively, install `vllm-hpu-extension`, `neural_compressor`, and `vllm` from source:
```
$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ pip install -r requirements-hpu.txt
$ python setup.py develop --user
## Check the installation
$ pip list | grep vllm
vllm 0.6.3.dev1122+g2f43ebf5.d20241121.gaudi118 /home/fengding/vllm-fork
vllm-hpu-extension 0.1
## Validation
$ VLLM_SKIP_WARMUP=true python3 examples/offline_inference.py
......
Prompt: 'Hello, my name is', Generated text: ' Kelly and I have a job to do.\nI need someone to come over'
Prompt: 'The president of the United States is', Generated text: ' facing a sharp criticism of his handling of the coronavirus pandemic, including'
Prompt: 'The capital of France is', Generated text: ' the capital of the Socialist Party of France (SPF), with its state-'
Prompt: 'The future of AI is', Generated text: " in what's coming, not what's coming.\nI don't know what"
```

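The validation script exercises vLLM's offline API; a minimal sketch of the equivalent Python (the prompts, sampling values, and model name are illustrative):

```python
from vllm import LLM, SamplingParams

# Offline generation, roughly what examples/offline_inference.py does.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)

llm = LLM(model="facebook/opt-125m")  # small model, as in the example script
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(f"Prompt: {out.prompt!r}, Generated text: {out.outputs[0].text!r}")
```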
### Run FP8 calibration
Refer to [vllm-hpu-extension -> calibration](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration):
```
$ git clone https://github.com/HabanaAI/vllm-hpu-extension
$ cd vllm-hpu-extension/calibration
# For meta-llama/Llama-3.1-8B-Instruct
$ ./calibrate_model.sh -m meta-llama/Llama-3.1-8B-Instruct -d /home/fengding/processed-data.pkl -o ./output_llama3.1.8b.Instruct -b 128 -t 1 -l 128
## Scale factors are generated in ./output_llama3.1.8b.Instruct
```

### Start vllm server
```
$ cd vllm-fork/
$ PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_CONTIGUOUS_PA=true \
VLLM_SKIP_WARMUP=true \
QUANT_CONFIG=output_llama3.1.8b.Instruct/maxabs_quant_g2.json \
python3 -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3.1-8B-Instruct \
--port 8080 \
--gpu-memory-utilization 0.9 \
--tensor-parallel-size 1 \
--disable-log-requests \
--block-size 128 \
--quantization inc \
--kv-cache-dtype fp8_inc \
--device hpu \
--weights-load-device cpu \
--dtype bfloat16 \
--num_scheduler_steps 16 > vllm_serving.log 2>&1 &
```
Refer to [vllm-fork -> README_GAUDI.md](https://github.com/HabanaAI/vllm-fork/blob/habana_main/README_GAUDI.md) for more details.

### Start a client to test
```
$ curl --noproxy "*" http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "San Francisco is a", "max_tokens": 100}'
```

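The same endpoint can also be exercised with the OpenAI Python client; a short sketch (the api_key value is arbitrary, since the server as started above does not check it):

```python
from openai import OpenAI

# Point the OpenAI client at the local vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="San Francisco is a",
    max_tokens=100,
)
print(completion.choices[0].text)
```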
### Run benchmark
```
python benchmarks/benchmark_serving.py \
--backend vllm \
--model meta-llama/Llama-3.1-8B-Instruct \
--dataset-name sonnet \
--dataset-path benchmarks/sonnet.txt \
--request-rate 128 \
--num-prompts 128 \
--port 8080 \
--sonnet-input-len 128 \
--sonnet-output-len 128 \
--sonnet-prefix-len 100
```

### FP8 KV cache
Code trace:
> vllm-fork/vllm/attention/backends/hpu_attn.py
>> from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
>> HPUAttentionImpl() -> self.k_cache() / self.v_cache()

> neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
>> PatchedVLLMKVCache()

> neural_compressor/torch/algorithms/fp8_quant/common.py
>> "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache)