update torchao READMEs with new configuration APIs #1711

Merged (59 commits) on Feb 14, 2025
Changes from all commits
59 commits by vkuzo, all titled "Update", spanning Jan 22 to Feb 14, 2025.
README.md (26 changes: 13 additions & 13 deletions)
@@ -29,16 +29,16 @@ For inference, we have the option of
```python
from torchao.quantization.quant_api import (
quantize_,
- int8_dynamic_activation_int8_weight,
- int4_weight_only,
- int8_weight_only
+ Int8DynamicActivationInt8WeightConfig,
+ Int4WeightOnlyConfig,
+ Int8WeightOnlyConfig
)
- quantize_(m, int4_weight_only())
+ quantize_(m, Int4WeightOnlyConfig())
```

- For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
+ For gpt-fast `Int4WeightOnlyConfig()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.

- If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, int8_weight_only(), device="cuda")` which will send and quantize each layer individually to your GPU.
+ If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, Int8WeightOnlyConfig(), device="cuda")` which will send and quantize each layer individually to your GPU.

If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer.
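
For orientation, basic autoquant usage looks roughly like the following. This is a minimal sketch with a toy model; the model, input shape, and compile flags are placeholders rather than part of this PR:

```python
import torch
import torchao

# placeholder toy model; swap in your real nn.Module
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")

# wrap with autoquant, then compile; quantization decisions are made on the first real input
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(example_input)  # this run records shapes, benchmarks candidates, and quantizes each layer
```
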
**Contributor:**

I think this file is missing a few, e.g. I still see `intx_quantization_aware_training`, `int8_weight_only_quantized_training`, and `fpx_weight_only`

**Contributor Author (vkuzo):**

thank you! I just updated `intx_quantization_aware_training` and `fpx_weight_only`.

`int8_weight_only_quantized_training` is a prototype API and is not migrated yet - we can do this after the 0.9.0 release.


@@ -63,27 +63,27 @@ Post-training quantization can result in a fast and compact model, but may also
```python
from torchao.quantization import (
quantize_,
- int8_dynamic_activation_int4_weight,
+ Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
- from_intx_quantization_aware_training,
- intx_quantization_aware_training,
+ FromIntXQuantizationAwareTrainingConfig,
+ IntXQuantizationAwareTrainingConfig,
)

# Insert fake quantization
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
my_model,
- intx_quantization_aware_training(activation_config, weight_config),
+ IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# Run training... (not shown)

# Convert fake quantization to actual quantized operations
- quantize_(my_model, from_intx_quantization_aware_training())
- quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
+ quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
+ quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

### Float8
@@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`, so if you love writing kernels but hate packaging them so they work on all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow:

- 1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+ 1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

torchao/quantization/README.md (44 changes: 22 additions & 22 deletions)
@@ -82,7 +82,7 @@ model(input)

When used as in the example above, with the `autoquant` api called alongside torch.compile, autoquant sets up the model so that when it's run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model.

- When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.
+ When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.

Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
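
A minimal sketch of that save/restore flow, assuming a plain pickle round trip of the cache dictionary:

```python
import pickle
from torchao.quantization import AUTOQUANT_CACHE

# after autoquant has benchmarked a model once, persist its decisions
with open("autoquant_cache.pkl", "wb") as f:
    pickle.dump(AUTOQUANT_CACHE, f)

# in a later run, restore the cache before calling autoquant again so the same
# quantization choices are reused without re-benchmarking
with open("autoquant_cache.pkl", "rb") as f:
    AUTOQUANT_CACHE.update(pickle.load(f))
```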

@@ -109,13 +109,13 @@ be applied individually. While there are a large variety of quantization apis, t

```python
# for torch 2.4+
- from torchao.quantization import quantize_, int4_weight_only
+ from torchao.quantization import quantize_, Int4WeightOnlyConfig
group_size = 32

# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization, which is expected to improve accuracy, through the
- # use_hqq flag for `int4_weight_only` quantization
+ # use_hqq flag for `Int4WeightOnlyConfig` quantization
use_hqq = False
- quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+ quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq))

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
@@ -128,8 +128,8 @@ Note: The quantization error incurred by applying int4 quantization to your mode

```python
# for torch 2.4+
- from torchao.quantization import quantize_, int8_weight_only
- quantize_(model, int8_weight_only())
+ from torchao.quantization import quantize_, Int8WeightOnlyConfig
+ quantize_(model, Int8WeightOnlyConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -140,8 +140,8 @@ change_linear_weights_to_int8_woqtensors(model)

```python
# for torch 2.4+
- from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
- quantize_(model, int8_dynamic_activation_int8_weight())
+ from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
+ quantize_(model, Int8DynamicActivationInt8WeightConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -152,8 +152,8 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.5+
- from torchao.quantization import quantize_, float8_weight_only
- quantize_(model, float8_weight_only())
+ from torchao.quantization import quantize_, Float8WeightOnlyConfig
+ quantize_(model, Float8WeightOnlyConfig())
```

Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -162,8 +162,8 @@ Supports all dtypes for original weight and activation. This API is only tested

```python
# for torch 2.4+
- from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
- quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
```

Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -172,8 +172,8 @@ Supports all dtypes for original weight and activation. This API is only tested

```python
# for torch 2.5+
- from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
- quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
+ from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -182,14 +182,14 @@ Per-row scaling is only supported for bfloat16 weight and activation. This API i

```python
# for torch 2.4+
- from torchao.quantization import quantize_, fpx_weight_only
- quantize_(model, fpx_weight_only(3, 2))
+ from torchao.quantization import quantize_, FPXWeightOnlyConfig
+ quantize_(model, FPXWeightOnlyConfig(3, 2))
```

You can find more information [here](../dtypes/floatx/README.md). It should be noted that, while most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, fp6 works primarily with the fp16 dtype.

## Affine Quantization Details
- Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.
+ Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.
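
As a toy illustration of that formula in plain PyTorch (per-tensor symmetric int8 chosen only for simplicity; this does not use the torchao primitives described below):

```python
import torch

x = torch.randn(4, 8)  # high precision input

# per-tensor symmetric int8: zero_point is 0 and scale maps the max magnitude to 127
scale = x.abs().max() / 127
zero_point = 0

# quantized_val = high_precision_float_val / scale + zero_point
x_q = torch.clamp(torch.round(x / scale + zero_point), -128, 127).to(torch.int8)

# dequantize by inverting the affine transform
x_dq = (x_q.float() - zero_point) * scale
print((x - x_dq).abs().max())  # quantization error
```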

### Quantization Primitives
We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified the existing quant primitives into `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`, which can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization; these can be used to implement the unified quantized tensor subclass.
Expand All @@ -200,7 +200,7 @@ Note: these primitive ops supports two "types" of quantization, distinguished by
We also have a unified quantized tensor subclass that implements how to get a quantized tensor from a floating point tensor and what it means to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`. With this we can dispatch to different operators (e.g. the `int4mm` op) based on device (cpu, cuda), quantization settings (`int4`, `int8`), and packing formats (e.g. a format optimized for the cpu int4 mm kernel).

#### Layouts
- We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
+ We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for workflows backing `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig` and also as a default layout. `tensor_core_tiled` layout is used for workflows backing `Int4WeightOnlyConfig` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
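
As an illustration of selecting a layout, here is a sketch under the assumption that the config-style APIs expose a `layout` argument the same way the older `int4_weight_only` constructor did, and that `TensorCoreTiledLayout` is importable from `torchao.dtypes`; verify both against the current signatures before relying on this:

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.dtypes import TensorCoreTiledLayout  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)  # toy model

# assumed kwarg: explicitly request the tinygemm-compatible tensor_core_tiled packing
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)))
```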

### Zero Point Domains
```ZeroPointDomain``` is used to control the data types of zero points. ```ZeroPointDomain.NONE``` means zero_point is None, ```ZeroPointDomain.FLOAT``` means zero_point is in the floating point domain, and ```ZeroPointDomain.INT``` means zero_point is in the integer domain. For the detailed implementation of different zero point data types, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py).
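
For example, mirroring the commented-out line in the int4 example further down in this README (the toy model here is only for illustration):

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.quant_primitives import ZeroPointDomain

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)  # toy model

# explicitly select the floating point domain for the zero point
quantize_(model, Int4WeightOnlyConfig(group_size=32, zero_point_domain=ZeroPointDomain.FLOAT))
```
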
@@ -223,7 +223,7 @@ from torchao.dtypes import to_affine_quantized_intx
import copy
from torchao.quantization.quant_api import (
quantize_,
- int4_weight_only,
+ Int4WeightOnlyConfig,
)

class ToyLinearModel(torch.nn.Module):
@@ -249,9 +249,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
group_size = 32
# only works for torch 2.4+
- quantize_(m, int4_weight_only(group_size=group_size))
+ quantize_(m, Int4WeightOnlyConfig(group_size=group_size))
## If different zero_point_domain needed
- # quantize_(m, int4_weight_only(group_size=group_size), zero_point_domain=ZeroPointDomain.FLOAT)
+ # quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))

# temporary workaround for tensor subclass + torch.compile
# NOTE: this is only needed for torch versions < 2.5
@@ -360,7 +360,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f
| | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 |
| | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 |

- You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`.
+ You can try out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`.
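
A minimal sketch, assuming the config takes the target dtype and `group_size` the same way the former `uintx_weight_only` constructor did (the toy model is illustrative):

```python
import torch
from torchao.quantization import quantize_, UIntXWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)  # toy model

# uint4 weights with group size 64; the "uintx-4-64-hqq" rows above additionally enable hqq
quantize_(model, UIntXWeightOnlyConfig(torch.uint4, group_size=64))
```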

### int8_dynamic_activation_intx_weight Quantization
We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computer with Apple silicon). The benchmarks below were run on an M1 Mac Pro with 8 performance cores, 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used.
torchao/quantization/qat/README.md (18 changes: 9 additions & 9 deletions)
@@ -71,22 +71,22 @@ def train_loop(m: torch.nn.Module):

The recommended way to run QAT in torchao is through the `quantize_` API:
1. **Prepare:** specify how weights and/or activations are to be quantized through
- [`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+ [`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
- functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+ functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)

For example:


```python
from torchao.quantization import (
quantize_,
- int8_dynamic_activation_int4_weight,
+ Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
- from_intx_quantization_aware_training,
- intx_quantization_aware_training,
+ FromIntXQuantizationAwareTrainingConfig,
+ IntXQuantizationAwareTrainingConfig,
)
model = get_model()

@@ -96,7 +96,7 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
model,
- intx_quantization_aware_training(activation_config, weight_config),
+ IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# train
@@ -105,8 +105,8 @@ train_loop(model)
# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
- quantize_(model, from_intx_quantization_aware_training())
- quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+ quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+ quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

# inference or generate
```
@@ -117,7 +117,7 @@ the following with a filter function during the prepare step:
```
quantize_(
m,
- intx_quantization_aware_training(weight_config=weight_config),
+ IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```