Skip to content

Commit

Permalink
[Nvidia] Support fp8 to bf16 casting on RTX 4000 series (#5544)
Browse files Browse the repository at this point in the history
I noticed that some of the tests were failing when I was testing on a
workstation with a consumer RTX card. Turns out that sm_89 supports fp8,
but doesn't support cvt.bf16.f16

From the ptx spec:

```
cvt.bf16.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64/bf16}, cvt.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64}.bf16, and cvt.tf32.f32.{relu}.{rn/rz} require sm_90 or higher.
```

This adds a path that first converts to fp32 and then to bf16 when the
compute capability is < 90.

This is already hit in the tests (specifically several test cases in
test core, many variations on dot_scaled in particular).
  • Loading branch information
mbrookhart authored Jan 7, 2025
1 parent 4a4dac9 commit 4947a95
Showing 1 changed file with 35 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,39 @@ static const Fp8ConversionDesc Fp16_to_Fp8E4M3Nv = {
"}",
32, 16, 2};

// Fp8E4M3 (x2) -> Bf16 (x2) (packed)
//
// PTX sequence:
//   1. cvt.rn.f16x2.e4m3x2 converts the two packed fp8 (e4m3) inputs in $1
//      into a packed pair of f16 values held in the 32-bit register `a`.
//   2. The pair is unpacked into a0/a1, and each half is converted to bf16
//      with cvt.bf16.f16.
//   3. The two bf16 results are repacked into the 32-bit output $0.
//
// NOTE: per the PTX spec, cvt.bf16.f16 requires sm_90 or higher, so this
// sequence is not usable on sm_89 even though sm_89 supports fp8.
//
// NOTE(review): the trailing {16, 32, 2} presumably are the input width in
// bits, the output width in bits, and the number of elements converted per
// invocation -- confirm against the Fp8ConversionDesc declaration.
static const Fp8ConversionDesc Fp8E4M3Nv_to_Bf16 = {
"{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e4m3x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}",
16, 32, 2};
// Fp8E4M3 (x2) -> Bf16 (x2) (packed)
//
// Returns the PTX conversion descriptor for turning two packed fp8 (e4m3)
// values into two packed bf16 values.
//
// hasNativeFP selects between two instruction sequences:
//   - true:  f16 -> bf16 is done directly with cvt.bf16.f16. The PTX spec
//            requires sm_90 or higher for that instruction.
//   - false: the f16 intermediates are first widened to f32 (cvt.f32.f16) and
//            then converted to bf16 with cvt.rn.bf16.f32, avoiding the
//            sm_90-only instruction.
//
// Both sequences start by expanding the packed fp8 pair to a packed f16 pair
// (cvt.rn.f16x2.e4m3x2) and finish by repacking the two 16-bit results into
// the 32-bit output operand $0.
//
// NOTE(review): the trailing {16, 32, 2} presumably are the input width in
// bits, the output width in bits, and the number of elements converted per
// invocation -- confirm against the Fp8ConversionDesc declaration.
static const Fp8ConversionDesc Fp8E4M3Nv_to_Bf16(bool hasNativeFP) {
  if (hasNativeFP) {
    // Direct path: one f16 -> bf16 conversion per element.
    return {"{ \n"
            ".reg .b32 a; \n"
            ".reg .f16 a<2>; \n"
            ".reg .b16 b<2>; \n"
            "cvt.rn.f16x2.e4m3x2 a, $1; \n"
            "mov.b32 {a0, a1}, a; \n"
            "cvt.bf16.f16 b0, a0; \n"
            "cvt.bf16.f16 b1, a1; \n"
            "mov.b32 $0, {b0, b1}; \n"
            "}",
            16, 32, 2};
  }

  // Fallback path: route each element through f32 (f16 -> f32 -> bf16).
  return {"{ \n"
          ".reg .b32 a; \n"
          ".reg .f16 a<2>; \n"
          ".reg .f32 b<2>; \n"
          ".reg .b16 c<2>; \n"
          "cvt.rn.f16x2.e4m3x2 a, $1; \n"
          "mov.b32 {a0, a1}, a; \n"
          "cvt.f32.f16 b0, a0; \n"
          "cvt.f32.f16 b1, a1; \n"
          "cvt.rn.bf16.f32 c0, b0; \n"
          "cvt.rn.bf16.f32 c1, b1; \n"
          "mov.b32 $0, {c0, c1}; \n"
          "}",
          16, 32, 2};
}

// Bf16 (x2) -> Fp8E4M3 (x2) (packed)
static const Fp8ConversionDesc Bf16_to_Fp8E4M3Nv = {
Expand Down Expand Up @@ -424,7 +444,8 @@ struct FpToFpOpConversion
// F8 -> BF16
{{F8E5M2TyID, BF16TyID, undefRounding},
Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E4M3TyID, BF16TyID, undefRounding}, Fp8E4M3Nv_to_Bf16},
{{F8E4M3TyID, BF16TyID, undefRounding},
Fp8E4M3Nv_to_Bf16(computeCapability >= 90)},
// BF16 -> F8
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
Bf16_to_Fp8E5M2(computeCapability >= 90)},
Expand Down

0 comments on commit 4947a95

Please sign in to comment.