Update
[ghstack-poisoned]
vkuzo committed Feb 18, 2025
1 parent 6df15cc commit 7170dbd
Showing 1 changed file with 53 additions and 2 deletions.
torchao/float8/README.md
@@ -17,9 +17,9 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.

We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).
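
For concreteness, here is a minimal sketch of selecting a strategy per tensor role. The `CastConfig` and `ScalingType` names are assumed from `torchao.float8.config` and may differ between torchao versions, so treat this as illustrative rather than canonical:

```python
from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, ScalingType

# assumed API: one CastConfig per tensor role; each accepts DYNAMIC, DELAYED or STATIC
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
```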

-## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`
+## float8 linear with dynamic tensorwise scaling

-This is the most accurate recipe as every tensor is scaled dynamically.
+This is the default recipe, with a good balance of performance and accuracy.

```python
import torch
@@ -63,6 +63,57 @@ for _ in range(10):
    optimizer.step()
```
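
For intuition, here is a minimal sketch of what dynamic tensorwise scaling computes (illustrative only, not the torchao internals): a single scale per tensor, recomputed from its current absolute maximum.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # one scale for the whole tensor, derived from the live data ("dynamic")
    amax = x.abs().max().float()
    return FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(4096, 2048, dtype=torch.bfloat16)
s = tensorwise_scale(x)
x_fp8 = (x * s).to(torch.float8_e4m3fn)  # dequantize later as x_fp8.float() / s
```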

## float8 linear with rowwise scaling

This recipe is more accurate than tensorwise because its scaling is more granular.

:warning: <em>The composability of float8 with rowwise scaling with Tensor Parallelism is WIP, please see https://github.com/pytorch/ao/issues/1732 for more details.</em>

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training, Float8LinearConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# configure rowwise scaling
config = Float8LinearConfig.from_recipe_name("rowwise")

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```
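
To make "more granular" concrete, here is a minimal sketch of the idea behind rowwise scaling (illustrative only, not the torchao internals): one scale per row, so an outlier in one row no longer inflates the scale used for every other row.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def rowwise_scale(x: torch.Tensor) -> torch.Tensor:
    # one scale per row instead of one per tensor
    amax = x.abs().amax(dim=-1, keepdim=True).float()  # shape (num_rows, 1)
    return FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(4096, 2048, dtype=torch.bfloat16)
s = rowwise_scale(x)
x_fp8 = (x * s).to(torch.float8_e4m3fn)  # each row quantized with its own scale
```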

## float8 linear with delayed scaling

:warning: <em>We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details.</em>
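
For background, delayed scaling derives the current step's scale from an amax history recorded in earlier steps rather than from the live tensor (see Section 4.3 of the paper linked above). A rough conceptual sketch, not the torchao implementation:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

class DelayedAmaxHistory:
    """Toy illustration: the scale comes from past amaxes, avoiding a reduction on the current tensor."""

    def __init__(self, history_len: int = 16):
        # real implementations typically bootstrap the first iterations with dynamic scaling
        self.amax_history = torch.zeros(history_len)

    def scale_and_record(self, x: torch.Tensor) -> torch.Tensor:
        # scale for this step is based only on previously observed amaxes
        scale = FP8_E4M3_MAX / torch.clamp(self.amax_history.max(), min=1e-12)
        # record the current amax for use in future steps
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = x.abs().max().float()
        return scale
```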
