add cuda fallback bf16 for compute_cap < 8.0 #2704

Draft: wants to merge 12 commits into base: main
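
Every kernel diff below follows the same pattern: the #if __CUDA_ARCH__ >= 800 guard around the bf16 kernels is widened so the same kernels also build for compute capability 5.3-7.5. The sketch below is editorial illustration, not code from this PR (the kernel name and the affine body are assumptions): it shows one way bf16 math stays expressible on pre-Ampere devices, since cuda_bf16.h (CUDA 11+) still provides the bf16/float conversion intrinsics there, so arithmetic can run in fp32 while only the storage type stays bf16.

// Minimal sketch, assuming only __bfloat162float / __float2bfloat16 are available.
#include <cuda_bf16.h>

extern "C" __global__ void affine_bf16_sketch(
    const __nv_bfloat16 *inp, __nv_bfloat16 *out,
    const size_t numel, const float mul, const float add) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= numel) return;
  // Promote to fp32, compute x * mul + add, convert back to bf16 for storage.
  float x = __bfloat162float(inp[i]);
  out[i] = __float2bfloat16(fmaf(x, mul, add));
}
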
2 changes: 1 addition & 1 deletion candle-kernels/src/affine.cu
@@ -28,7 +28,7 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
AFFINE_OP(__nv_bfloat16, affine_bf16)
#endif

2 changes: 1 addition & 1 deletion candle-kernels/src/binary.cu
@@ -1,7 +1,7 @@
#include "binary_op_macros.cuh"
#include<stdint.h>

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
28 changes: 11 additions & 17 deletions candle-kernels/src/cast.cu
@@ -71,34 +71,28 @@ extern "C" __global__ void FN_NAME( \
} \

#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16 )
#elif __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16 )

CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#else
#include <cuda.h>
#if CUDA_VERSION >= 11000
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800 // needed CUDA_VERSION >= 11000
CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16)
CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16)
#endif
#endif

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)

CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
CAST_OP(__half, uint32_t, cast_f16_u32)
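
The cast.cu changes lean on two macros: CAST_OP for direct conversions and CAST_THROUGH_OP for pairs that are routed through an intermediate type (here float) when no direct bf16 path exists on older architectures. Both macros are defined earlier in cast.cu, above the hunk shown; the sketch below is a hand-written approximation of one instance (bf16 to double) with an illustrative kernel name, not the actual macro expansion.

#include <cuda_bf16.h>

// Approximate shape of a cast-through-float kernel such as cast_bf16_f64:
// widen bf16 to fp32 with the always-available conversion intrinsic,
// then let the compiler convert fp32 to the destination type.
extern "C" __global__ void cast_bf16_f64_sketch(
    const __nv_bfloat16 *inp, double *out, const size_t numel) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= numel) return;
  out[i] = static_cast<double>(__bfloat162float(inp[i]));
}
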
30 changes: 15 additions & 15 deletions candle-kernels/src/compatibility.cuh
@@ -40,21 +40,21 @@ __device__ double atomicAdd(double* address, double val) {
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
// unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
// unsigned int old = *address_as_ui;
// unsigned int assumed;
// bool unaligned = (size_t) address & 2;
// do {
// assumed = old;
// unsigned int hsum;
// hsum = unaligned ? (old >> 16) : (old & 0xffff);
// hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
// old = atomicCAS(address_as_ui, assumed,
// unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
// );

// } while (assumed != old);
// return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
bool unaligned = (size_t) address & 2;
do {
assumed = old;
unsigned int hsum;
hsum = unaligned ? (old >> 16) : (old & 0xffff);
hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
old = atomicCAS(address_as_ui, assumed,
unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
);

} while (assumed != old);
return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}
#endif

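
The compatibility.cuh hunk enables the CAS-based atomicAdd(__half*, __half) overload that was previously commented out, so half-precision atomics also work on devices below compute capability 7.0, where no hardware __half atomicAdd exists. A small usage sketch with an illustrative kernel name, assuming compatibility.cuh is included:

#include <cuda_fp16.h>

// Naive accumulation into a single __half cell; on compute capability < 7.0
// this resolves to the CAS-based overload above, on >= 7.0 to the hardware one.
extern "C" __global__ void naive_sum_f16_sketch(
    const __half *inp, __half *acc, const size_t numel) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel) {
    atomicAdd(acc, inp[i]);
  }
}
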
43 changes: 22 additions & 21 deletions candle-kernels/src/cuda_utils.cuh
@@ -158,28 +158,8 @@ __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a,
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
__device__ __forceinline__ __half cosg(__half a) { return hcos(a); }
__device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
__device__ __forceinline__ __half ceilg(__half a) { return __float2half(ceilf(__half2float(a))); }
__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); }
__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); }
__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
@@ -199,3 +179,24 @@ __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a);
__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
#endif

#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
__device__ __forceinline__ __half cosg(__half a) { return hcos(a); }
__device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
__device__ __forceinline__ __half ceilg(__half a) { return __float2half(ceilf(__half2float(a))); }
__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); }
__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); }
__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif
5 changes: 3 additions & 2 deletions candle-kernels/src/fill.cu
@@ -37,12 +37,13 @@ COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
#include <cuda_bf16.h>
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
#endif

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif
#endif
2 changes: 1 addition & 1 deletion candle-kernels/src/indexing.cu
@@ -146,7 +146,7 @@ extern "C" __global__ void FN_NAME( \
) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \


#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
13 changes: 11 additions & 2 deletions candle-kernels/src/reduce.cu
@@ -571,15 +571,24 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 750
SUM_OP(__nv_bfloat16, sum_bf16)
#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 750
// The automatic fallback mechanism for these architectures:
// 1. Converts bfloat16 to float using __bfloat162float
// 2. Performs atomicAdd with floats
// 3. Converts back to bfloat16
// SUM_OP(__nv_bfloat16, sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
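
The comment block in the reduce.cu hunk describes a convert-to-float fallback for sum_bf16 on architectures below compute capability 7.5, while SUM_OP itself is left commented out there. The sketch below is an assumption about what such a fallback could look like, not the PR's implementation: it mirrors the __half atomicAdd in compatibility.cuh, doing the addition in fp32 and CAS-ing the packed 16-bit result back. The helper name is illustrative; the reinterpret intrinsics (__bfloat16_as_ushort, __ushort_as_bfloat16) come from cuda_bf16.h.

#include <cuda_bf16.h>

__device__ __nv_bfloat16 atomic_add_bf16_sketch(__nv_bfloat16 *address, __nv_bfloat16 val) {
  // Work on the aligned 32-bit word that contains the 16-bit target.
  unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  bool unaligned = (size_t) address & 2;
  do {
    assumed = old;
    unsigned short h = unaligned ? (old >> 16) : (old & 0xffff);
    // 1. bf16 -> float, 2. add in float, 3. float -> bf16
    float sum = __bfloat162float(__ushort_as_bfloat16(h)) + __bfloat162float(val);
    h = __bfloat16_as_ushort(__float2bfloat16(sum));
    old = atomicCAS(address_as_ui, assumed,
                    unaligned ? (old & 0xffff) | ((unsigned int)h << 16)
                              : (old & 0xffff0000) | h);
  } while (assumed != old);
  return __ushort_as_bfloat16(unaligned ? (old >> 16) : (old & 0xffff));
}
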
2 changes: 1 addition & 1 deletion candle-kernels/src/sort.cu
@@ -73,7 +73,7 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
ASORT_OP(__nv_bfloat16, bf16)
#endif

4 changes: 1 addition & 3 deletions candle-kernels/src/ternary.cu
@@ -32,13 +32,11 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 530
WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, int64_t, where_i64_f16)
WHERE_OP(__half, uint32_t, where_u32_f16)
WHERE_OP(__half, uint8_t, where_u8_f16)
3 changes: 1 addition & 2 deletions candle-kernels/src/unary.cu
@@ -96,8 +96,7 @@ __device__ T sign_(T t) {
return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}


#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
25 changes: 24 additions & 1 deletion candle-nn/tests/ops.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, DType, Result, Tensor};

fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@@ -249,6 +249,27 @@ fn sigmoid(device: &Device) -> Result<()> {
Ok(())
}

fn sigmoid_f16(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?.to_dtype(DType::F16)?;
let s1 = candle_nn::ops::sigmoid(&tensor)?;
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::f16>()?;
assert_eq!(diff, half::f16::from_f32(0.));
Ok(())
}

fn sigmoid_bf16(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?.to_dtype(DType::BF16)?;
let s1 = candle_nn::ops::sigmoid(&tensor)?;
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::bf16>()?;
assert_eq!(diff, half::bf16::from_f32(0.));
Ok(())
}


test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
@@ -258,3 +279,5 @@ test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
test_device!(sigmoid_f16, sigmoid_f16_cpu, sigmoid_f16_gpu, sigmoid_f16_metal);
test_device!(sigmoid_bf16, sigmoid_bf16_cpu, sigmoid_bf16_gpu, sigmoid_bf16_metal);