skip sum_bf16 if < 750: atomicAdd with floats
haricot committed Jan 21, 2025
1 parent ca5d81f commit 8a95a11
Showing 1 changed file with 8 additions and 39 deletions.
candle-kernels/src/reduce.cu — 47 changes: 8 additions & 39 deletions
@@ -581,45 +581,14 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
-#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800
-#define SUM_BF16_OP(FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const size_t numel, const size_t num_dims, const size_t num_sum_dims, \
-    const size_t *info, const __nv_bfloat16 *inp, __nv_bfloat16 *out) { \
-  const size_t *dims = info; \
-  const size_t *strides = info + num_dims; \
-  const size_t *sum_dims_l = info + 2 * num_dims; \
-  const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \
-  if (is_contiguous(num_dims, dims, strides)) { \
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
-         i += blockDim.x * gridDim.x) { \
-      size_t dst_index = i; \
-      for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
-        size_t stride = sum_dims_s[nd]; \
-        size_t pre = dst_index / stride; \
-        size_t post = dst_index % stride; \
-        dst_index = (pre / sum_dims_l[nd]) * stride + post; \
-      } \
-      float val = __bfloat162float(inp[i]); \
-      atomicAdd((float *)(out + dst_index), val); \
-    } \
-  } else { \
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
-         i += blockDim.x * gridDim.x) { \
-      unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
-      size_t dst_index = i; \
-      for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
-        size_t stride = sum_dims_s[nd]; \
-        size_t pre = dst_index / stride; \
-        size_t post = dst_index % stride; \
-        dst_index = (pre / sum_dims_l[nd]) * stride + post; \
-      } \
-      float val = __bfloat162float(inp[strided_i]); \
-      atomicAdd((float *)(out + dst_index), val); \
-    } \
-  } \
-}
-SUM_BF16_OP(sum_bf16)
+#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ >= 750
+SUM_OP(__nv_bfloat16, sum_bf16)
+#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 750
+// The automatic fallback mechanism for these architectures:
+// 1. Converts bfloat16 to float using __bfloat162float
+// 2. Performs atomicAdd with floats
+// 3. Converts back to bfloat16
+// SUM_OP(__nv_bfloat16, sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
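
The comment in the added branch describes the fallback these older architectures rely on: route the bf16 sum through float. Purely as an illustration of that pattern (a sketch under assumptions, not code from this commit or from candle; the kernel names, the single float scratch accumulator, and the separate convert-back kernel are invented for the example), the three steps could look like this:

#include <cuda_bf16.h>

// Sketch: accumulate bf16 inputs into a float scratch value with atomicAdd,
// since float atomics are available on every architecture this file targets.
extern "C" __global__ void sum_bf16_via_f32(const size_t numel,
                                            const __nv_bfloat16 *inp,
                                            float *acc) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {
    float val = __bfloat162float(inp[i]); // 1. bfloat16 -> float
    atomicAdd(acc, val);                  // 2. atomicAdd with floats
  }
}

// 3. Convert back to bfloat16 once the accumulation kernel has completed.
extern "C" __global__ void f32_to_bf16_scalar(const float *acc,
                                              __nv_bfloat16 *out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *out = __float2bfloat16(*acc);
  }
}

For comparison, the removed SUM_BF16_OP macro performed steps 1 and 2 inside the kernel but issued the float atomicAdd into the __nv_bfloat16 output buffer through a pointer cast; this commit drops that macro and leaves sum_bf16 uncompiled below sm_75, so the fallback handles those architectures instead.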
