diff --git a/CMakeLists.txt b/CMakeLists.txt index e16efc9e8..885625197 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,6 +3,7 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.25) +cmake_policy(SET CMP0104 OLD) project(onnxruntime_extensions LANGUAGES C CXX) # set(CMAKE_VERBOSE_MAKEFILE ON) @@ -294,6 +295,8 @@ endmacro() if(OCOS_USE_CUDA) include(ext_cuda) + include(cutlass) +# include(flash_attention) endif() ####################################################################################################################### @@ -358,12 +361,10 @@ if(OCOS_ENABLE_MATH) list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE}) endif() -file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*") if (OCOS_USE_CUDA) - file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*") - list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA}) + file(GLOB_RECURSE TARGET_SRC_CUDA "operators/cuda/*.*") + list(APPEND TARGET_SRC ${TARGET_SRC_CUDA}) endif() -list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB}) # enable the opencv dependency if we have ops that require it if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION) @@ -578,6 +579,10 @@ target_include_directories(ocos_operators PUBLIC ${PROJECT_SOURCE_DIR}/base ${PROJECT_SOURCE_DIR}/operators) +if (OCOS_USE_CUDA) + target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) +endif() + set(ocos_libraries) set(OCOS_COMPILE_DEFINITIONS) diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index 15e66ff99..6a2886289 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -6,6 +6,14 @@ enable_language(CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) +include(CMakeDependentOption) +cmake_dependent_option(USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF) +option(USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) + message( STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6") + set(USE_FLASH_ATTENTION OFF) + set(USE_MEMORY_EFFICIENT_ATTENTION OFF) +endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11) @@ -22,3 +30,11 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"") add_compile_definitions(USE_CUDA) +if (USE_FLASH_ATTENTION) + message( STATUS "Enable flash attention") + add_compile_definitions(USE_FLASH_ATTENTION) +endif() +if (USE_MEMORY_EFFICIENT_ATTENTION) + message( STATUS "Enable memory efficient attention") + add_compile_definitions(USE_MEMORY_EFFICIENT_ATTENTION) +endif() diff --git a/cmake/externals/cutlass.cmake b/cmake/externals/cutlass.cmake new file mode 100644 index 000000000..24b9bf72e --- /dev/null +++ b/cmake/externals/cutlass.cmake @@ -0,0 +1,11 @@ +include(FetchContent) +FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG v3.1.0 +) + +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) +endif() diff --git a/cmake/externals/flash_attention.cmake b/cmake/externals/flash_attention.cmake new file mode 100644 index 000000000..f1633e6d1 --- /dev/null +++ b/cmake/externals/flash_attention.cmake @@ -0,0 +1,22 @@ 
+include(FetchContent) +FetchContent_Declare( + flash_attention + GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git + GIT_TAG v2.3.0 +) + +#FetchContent_GetProperties(flash_attention) +#if(NOT flash_attention_POPULATED) +# FetchContent_Populate(flash_attention) +# file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_kernel.h DESTINATION ${PROJECT_SOURCE_DIR}/includes) +#endif() +FetchContent_MakeAvailable(flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/utils.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/block_info.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/kernel_traits.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/softmax.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_kernel.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/philox.cuh DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +#file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/dropout.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +#file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/mask.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) +#file(COPY ${flash_attention_SOURCE_DIR}/csrc/flash_attn/src/rotary.h DESTINATION ${PROJECT_SOURCE_DIR}/operators/contrib/cuda/flash_attention) diff --git a/operators/cuda/attention_lib/attention_common.h b/operators/cuda/attention_lib/attention_common.h new file mode 100644 index 000000000..6f32bb94b --- /dev/null +++ b/operators/cuda/attention_lib/attention_common.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace contrib { + +enum AttentionMaskType { + MASK_NONE, // No mask + MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length + MASK_1D_END_START, // [2 * batch_size] with end positions and start positions + MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. 
+ MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length] + MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length] + MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + MASK_UNKNOWN +}; + +enum AttentionQkvFormat { + UNKNOWN, // enum value not set, or depends on qkv projection implementation details + Q_K_V_BNSH, // for non-packed qkv, permuted + Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed +}; + +enum AttentionKernelType { + AttentionKernel_Unfused, + AttentionKernel_TrtFusedAttention, + AttentionKernel_TrtFlashAttention, + AttentionKernel_TrtFusedCrossAttention, + AttentionKernel_CutlassMemoryEfficientAttention, + AttentionKernel_FlashAttention, + AttentionKernel_Default +}; + +// Parameters deduced from node attributes and inputs/outputs. +struct AttentionParameters { + int batch_size; + int sequence_length; + int kv_sequence_length; // input sequence length of K or V + int past_sequence_length; // sequence length in past state of K or V + int total_sequence_length; // total sequence length of K or V + int max_sequence_length; // max sequence length from 4D mask + int input_hidden_size; // first dimension of weights for input projection + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + int num_splits; + int rotary_embedding; + bool is_unidirectional; + bool past_present_share_buffer; + bool do_rotary; + bool broadcast_res_pos_bias; + bool pass_past_in_kv; + float mask_filter_value; + float scale; + bool use_tf32; + AttentionMaskType mask_type; + AttentionQkvFormat qkv_format; +}; + +// Parameters deduced from node attributes and inputs/outputs. +struct PackedAttentionParameters { + int batch_size; + int sequence_length; + int input_hidden_size; // hidden size of input + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + float scale; + int token_count; + bool has_relative_position_bias; + bool broadcast_res_pos_bias; + bool use_tf32; +}; + +// Parameters deduced from node attributes and inputs/outputs. 
+struct GroupQueryAttentionParameters { + int batch_size; + int sequence_length; // sequence length of input query, key, value + int seqlen_past_kv_cache; // sequence length of past kv tensor + int seqlen_present_kv_cache; // sequence length of present kv tensor + int hidden_size; + int num_heads; + int head_size; + int kv_hidden_size; + int kv_num_heads; + int num_splits; // number of splits for splitkv + bool is_unidirectional; // causal + int local_window_size; + bool kv_share_buffer; + bool is_packed_qkv; + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; + float scale; + AttentionQkvFormat qkv_format; + AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; +}; + +namespace attention { +// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; + +// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; + +// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled). +// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. +constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; + +// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). +constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; + +// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). +constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; + +// Environment variable to enable or disable flash attention. Default is 0 (enabled). +constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; + +// Minimum sequence length to enable memory efficient attention in FP32. +constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; + +// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention +constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; +// Default value for the above setting. +constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; + +// Environment variable to enable loading more KV data in flight in +// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels +constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT"; + +} // namespace attention + +} // namespace contrib diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h new file mode 100644 index 000000000..bc4bd278a --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if USE_MEMORY_EFFICIENT_ATTENTION + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "memory_efficient_attention.h" +#include "41_fused_multi_head_attention/kernel_forward.h" + +namespace contrib { +namespace cuda { + +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this 
to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + +template +void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { + using Attention = AttentionKernel; + typename Attention::Params p; + { // set parameters + p.query_ptr = const_cast(reinterpret_cast(params.query)); + p.key_ptr = const_cast(reinterpret_cast(params.key)); + p.value_ptr = const_cast(reinterpret_cast(params.value)); + p.attn_bias_ptr = const_cast(reinterpret_cast(params.attn_bias)); + p.seqstart_q_ptr = params.seqstart_q_ptr; + p.seqstart_k_ptr = params.seqstart_k_ptr; + p.seqlen_k_ptr = params.seqlen_k_ptr; + + p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward + p.output_ptr = reinterpret_cast(params.output); + if (Attention::kNeedsOutputAccumulatorBuffer) { + using Acc = typename Attention::accum_t; + // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float) + // TODO: ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided"); + p.output_accum_ptr = reinterpret_cast(params.workspace); + } else { + p.output_accum_ptr = nullptr; + } + p.num_heads = params.num_heads; + p.num_batches = params.batch_size; + p.head_dim = params.qk_head_size; + p.head_dim_value = params.v_head_size; + + p.scale = params.scale; + + // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel + p.num_queries = params.sequence_length; + p.num_keys = params.kv_sequence_length; + + if (params.causal) { + p.custom_mask_type = Attention::CausalFromBottomRight; + } + + // We use max_sequence_length to 
calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + } + + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; + } + + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + // TODO: ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); + static bool once = [&]() { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + // TODO: ORT_ENFORCE(Attention::check_supported(p)); + kernel_fn<<>>(p); +} + +template +void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { + using AlignedAK = AttentionKernel; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 6287) +#endif + // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. 
+ bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && + params.qk_head_size % AlignedAK::kAlignmentK == 0 && + params.v_head_size % AlignedAK::kAlignmentV == 0; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { + LaunchCutlassFmha(params); + })); +} + +template +void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { + if (params.v_head_size <= 64) { + DispatchIsAligned(params); + } else if (params.v_head_size <= 128) { + DispatchIsAligned(params); + } else { + DispatchIsAligned(params); + } +} + +} // namespace cuda +} // namespace contrib + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu new file mode 100644 index 000000000..1900ee46a --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu new file mode 100644 index 000000000..8cd8d6f89 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu new file mode 100644 index 000000000..9454953d9 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu new file mode 100644 index 000000000..f5d956fb7 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "fmha_launch_template.h"
+
+namespace contrib {
+namespace cuda {
+
+void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) {
+  if (params.is_half) {
+    DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm80>(params);
+  } else {
+    DispatchBlockSize<float, cutlass::arch::Sm80>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace contrib
+
+#endif  // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu
new file mode 100644
index 000000000..608b79798
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#if USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "memory_efficient_attention.h"
+#include <cassert>
+
+namespace contrib {
+namespace cuda {
+
+void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) {
+  const int32_t& sm = params.sm;
+  if (sm >= 80) {
+    run_memory_efficient_attention_sm80(params);
+  } else if (sm >= 75) {
+    run_memory_efficient_attention_sm75(params);
+  } else if (sm >= 70) {
+    run_memory_efficient_attention_sm70(params);
+  } else if (sm >= 50) {
+    run_memory_efficient_attention_sm50(params);
+  } else {
+    assert(false);  // shall not reach here.
+  }
+}
+
+}  // namespace cuda
+}  // namespace contrib
+
+#endif  // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h
new file mode 100644
index 000000000..99188ba01
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#if USE_MEMORY_EFFICIENT_ATTENTION
+#include <cuda_runtime.h>
+
+namespace contrib {
+namespace cuda {
+
+struct MemoryEfficientAttentionParams {
+  int32_t sm;
+  bool is_half;
+  bool is_kv_bsnh = true;
+  int32_t batch_size;
+  int32_t num_heads;
+  int32_t sequence_length;
+  int32_t kv_sequence_length;
+  int32_t max_sequence_length;
+  int32_t qk_head_size;
+  int32_t v_head_size;
+  bool causal;
+  // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
+  bool is_attn_bias_batched;
+
+  float scale;
+
+  int32_t* seqstart_q_ptr;
+  int32_t* seqstart_k_ptr;
+  int32_t* seqlen_k_ptr;
+
+  const void* query;      // [B, S, N, H]
+  const void* key;        // [B, L, N, H], where L is kv_sequence_length
+  const void* value;      // [B, L, N, H_v]
+  const void* attn_bias;  // [N, S, S*] or null
+  void* output;           // [B, S, N, H_v]
+  void* workspace;        // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
+  cudaStream_t stream;
+
+  static bool need_workspace(size_t v_head_size, bool is_float) {
+    return (v_head_size > 128 && !is_float);
+  }
+
+  bool has_custom_right_padding = false;
+};
+
+void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
+
+inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
+  return sm >= (is_half ?
53 : 50); +} + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params); + +} +} +#endif diff --git a/operators/cuda/attention_lib/flash_attention/block_info.h b/operators/cuda/attention_lib/flash_attention/block_info.h new file mode 100644 index 000000000..1ec632658 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/block_info.h @@ -0,0 +1,44 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +namespace flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/flash.h b/operators/cuda/attention_lib/flash_attention/flash.h new file mode 100644 index 000000000..603a6e068 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash.h @@ -0,0 +1,114 @@ +#pragma once +#include + +namespace flash { +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; + + // The number of heads. + int h = 0; + int h_k = 0; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio = 0; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; + + // The stride between rows of O. + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; + + // The pointer to the P matrix. + void* __restrict__ p_ptr = nullptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; + + // The dimensions. + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; + int rotary_dim = 0; + + // The scaling factors for the kernel. + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; + + int* __restrict__ blockmask = nullptr; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + + bool is_bf16 = false; + bool is_causal = false; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. 
+ bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + + int num_splits = 0; // For split-KV version + + const cudaDeviceProp* dprops = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +} \ No newline at end of file diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.cc b/operators/cuda/attention_lib/flash_attention/flash_api.cc new file mode 100644 index 000000000..73dd51fec --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_api.cc @@ -0,0 +1,465 @@ +#if USE_FLASH_ATTENTION + +#include "flash_api.h" +#include "flash.h" +#include "static_switch.h" +#include + +namespace flash { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. 
+  params.scale_softmax = softmax_scale;
+  params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+  // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates
+  // local and causal, meaning that when a local window size is given, causal masking is expressed through the
+  // window sizes below rather than through the is_causal flag.
+  params.is_causal = is_causal;
+  if (is_causal && (window_size_left >= 0 || window_size_right != 0)) {
+    params.is_causal = false;
+  }
+  if (window_size_left < 0 && window_size_right >= 0) {
+    window_size_left = seqlen_k;
+  }
+  if (window_size_left >= 0 && window_size_right < 0) {
+    window_size_right = seqlen_k;
+  }
+  params.window_size_left = window_size_left;
+  params.window_size_right = window_size_right;
+
+  params.is_seqlens_k_cumulative = true;
+}
+
+size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) {
+  size_t bytes = sizeof(float) * batch_size * num_heads * seqlen;
+  return bytes;
+}
+
+size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) {
+  size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads;
+  return bytes;
+}
+
+size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) {
+  size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded;
+  return bytes;
+}
+
+void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {
+  FP16_SWITCH(!params.is_bf16, [&] {
+    FWD_HEADDIM_SWITCH(params.d, [&] {
+      if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it, num_splits == 0
+        run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+      } else {
+        run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+      }
+    });
+  });
+}
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 85%
+// of the best efficiency.
+int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs,
+                         int max_splits) {
+  // This needs to match with run_mha_fwd_splitkv_dispatch
+  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+  const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
+  // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+  // In any case we don't expect seqlen_q to be larger than 64 for inference.
+  const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
+  int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks;
+  // If we have enough to almost fill the SMs, then just use 1 split
+  if (batch_nheads_mblocks >= 0.8f * num_SMs) {
+    return 1;
+  }
+  max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+  float max_efficiency = 0.f;
+  std::vector<float> efficiency;
+  efficiency.reserve(max_splits);
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+  // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+  // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+  // (i.e. it's 11 splits anyway).
+  // So we check if the number of blocks per split is the same as the previous num_splits.
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh, + int local_window_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q*/ nullptr, + /*cu_seqlens_k*/ nullptr, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + kv_bsnh, + local_window_size, + is_causal ? 
0 : -1); + params.dprops = &dprops; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + run_mha_fwd(params, stream); + return nullptr; +} + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, head_size) + void* out, // half (total_q, num_heads, head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_k, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + true, + -1, + is_causal ? 0 : -1); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + run_mha_fwd(params, stream); + return nullptr; +} + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. 
+OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. 
+ if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); + + return nullptr; +} + +} // namespace flash + +#endif // USE_FLASH_ATTENTION diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.h b/operators/cuda/attention_lib/flash_attention/flash_api.h new file mode 100644 index 000000000..512b7a6d9 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_api.h @@ -0,0 +1,92 @@ +#pragma once + +#if USE_FLASH_ATTENTION + +#include +#include +#include +#include "onnxruntime_c_api.h" + +namespace flash { + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true, + int local_window_size = -1); + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, v_head_size) + void* out, // half (total_q, 
num_heads, v_head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16); + +OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); + +size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); + +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); + +} // namespace flash + +#endif // USE_FLASH_ATTENTION diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..778941e8e --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..5cfb3019f --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu
new file mode 100644
index 000000000..dda68cafc
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 000000000..3eb91029e
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu
new file mode 100644
index 000000000..1d6cec57b
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 000000000..166b9a124
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 000000000..fd6e6693c
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..520c5482f --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..6b93f9627 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..12def28cb --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..6400a4829 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..81d19b481 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
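Taken together, these translation units cover head-dimension buckets 32, 64, 96, 128, 160, 192, 224 and 256, each in an fp16 and a bf16 flavour. Presumably the launcher rounds an arbitrary head size up to the nearest compiled bucket; a hypothetical helper, shown only to make that selection concrete (the real dispatch lives in the launch templates / flash API, not in these files):

// Hypothetical: pick the smallest compiled head-dim bucket that fits, or -1 if unsupported.
inline int round_up_head_dim(int head_size) {
  constexpr int kBuckets[] = {32, 64, 96, 128, 160, 192, 224, 256};
  for (int hd : kBuckets) {
    if (head_size <= hd) return hd;
  }
  return -1;
}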
+#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..d84464cc8 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..98fbc9a2e --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..c788cc92f --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..377d6118a --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h new file mode 100644 index 000000000..c44a470f6 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h @@ -0,0 +1,1259 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#pragma once + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#endif + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, + Tensor2& acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + cute::Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_sum); ++mi) { + scores_sum(mi) += scores_sum_cur(mi); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + cute::Layout l = tOrP.layout(); + cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); +#pragma unroll + for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
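softmax_rescale_o above is the online-softmax update at the heart of flash attention: each query row keeps a running maximum and a running denominator, and whenever a new tile of scores raises the maximum, the previously accumulated output and denominator are rescaled before the new tile is folded in. A scalar sketch of that arithmetic, for illustration only (the kernel works on MMA fragments and folds the softmax scale into exp2 via scale_softmax_log2 = softmax_scale * log2(e)):

#include <cmath>

struct OnlineSoftmaxRow {
  float m = -INFINITY;  // running max of the raw scores for this query row
  float l = 0.0f;       // running sum of exp2((s - m) * scale_log2)
};

// Fold one tile of raw scores s[0..n) into the running state, rescaling the
// already-accumulated output row acc_o[0..head_dim) by exp2((m_old - m_new) * scale_log2).
inline void online_softmax_step(OnlineSoftmaxRow& st, const float* s, int n,
                                float scale_log2, float* acc_o, int head_dim) {
  float m_new = st.m;
  for (int i = 0; i < n; ++i) m_new = std::fmax(m_new, s[i]);
  if (m_new == -INFINITY) m_new = 0.0f;  // fully masked row, mirrors the Check_inf handling
  const float rescale = std::exp2((st.m - m_new) * scale_log2);  // 0 on the very first tile
  for (int d = 0; d < head_dim; ++d) acc_o[d] *= rescale;        // acc_o starts out cleared
  st.l *= rescale;
  for (int i = 0; i < n; ++i) st.l += std::exp2((s[i] - m_new) * scale_log2);
  st.m = m_new;
}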
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= n_block_min) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. 
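The n_block_min / n_block_max arithmetic above decides which kBlockN-wide key blocks a query block has to visit at all; for causal and local attention the diagonal is aligned to the bottom-right of the score matrix, hence the seqlen_k - seqlen_q offset. The same bounds restated as plain host code, for illustration (identical formulas, just without the cute helpers):

#include <algorithm>

inline int ceil_div_i(int a, int b) { return (a + b - 1) / b; }

// Range [n_block_min, n_block_max) of key blocks visited by query block m_block.
inline void key_block_range(bool is_causal, bool is_local,
                            int m_block, int kBlockM, int kBlockN,
                            int seqlen_q, int seqlen_k,
                            int window_left, int window_right,
                            int& n_block_min, int& n_block_max) {
  n_block_min = !is_local
      ? 0
      : std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_left) / kBlockN);
  n_block_max = ceil_div_i(seqlen_k, kBlockN);
  if (is_causal || is_local) {
    n_block_max = std::min(
        n_block_max,
        ceil_div_i((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_right, kBlockN));
  }
  // If the range is empty the kernel exits early, writing zeros to gO and +inf to gLSE.
}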
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + cute::Shape, cute::Int>{}, + make_stride(params.q_row_stride, _1{})); + cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + cute::Shape, cute::Int>{}, + make_stride(params.k_row_stride, _1{})); + cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + cute::Shape, cute::Int>{}, + make_stride(params.v_row_stride, _1{})); + cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + cute::Shape, cute::Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = 
make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); + cute::Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); + cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < cute::size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + cute::Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + // I can't get the stride from idx_row + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
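apply_mask and apply_mask_local above implement, per score element, the visibility check between a query row and a key column; masked entries are set to -INF before the online softmax so they contribute nothing. A scalar restatement of that predicate, for illustration (the window_size_left / window_size_right convention follows the bounds used above: the right bound applies to both causal and local attention, the left bound only to local):

// True if key position `col` must be masked out for query position `row`.
inline bool is_masked(int row, int col, int seqlen_q, int seqlen_k,
                      bool causal_or_local, bool local,
                      int window_left, int window_right) {
  if (col >= seqlen_k) return true;                            // padding beyond the real key length
  if (!causal_or_local) return false;
  const int shift = seqlen_k - seqlen_q;                       // bottom-right aligned diagonal
  if (col > row + shift + window_right) return true;           // causal / right window bound
  if (local && col < row + shift - window_left) return true;   // left window bound
  return false;
}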
+ cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + cute::Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? 
INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + + // Convert acc_o from fp32 to fp16/bf16 + cute::Tensor rO = flash::convert_type(acc_o); + cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + cute::Shape, cute::Int>{}, + make_stride(params.o_row_stride, _1{})); + cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + cute::Shape>{}, cute::Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
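To summarise this epilogue: with m_i the final running maximum for query row i, \ell_i the final denominator and \tau = scale_softmax, the value written to gLSE and the normalisation applied to acc_o are the standard softmax identities

\mathrm{LSE}_i = \tau\, m_i + \log \ell_i = \log \sum_j e^{\tau\, q_i^\top k_j},
\qquad
o_i = \frac{1}{\ell_i} \sum_j e^{\tau (q_i^\top k_j - m_i)}\, v_j .

Rows whose denominator is zero or NaN (fully masked) get \mathrm{LSE}_i = +\infty and their output is left unscaled (inv_sum is forced to 1), exactly as the ternaries above do.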
+ cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < cute::size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. 
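compute_attn_1rowblock_splitkv distributes the key/value blocks of one (batch, head, query-block) triple across num_n_splits thread blocks; each split produces a partial output (Oaccum) and a partial LSE (LSEaccum) that combine_attn_seqk_parallel reduces afterwards. The partitioning above, restated outside of cute for illustration:

#include <algorithm>

// Key-block range owned by split `n_split_idx` out of `num_n_splits`.
inline void split_key_block_range(int seqlen_k, int kBlockN,
                                  int n_split_idx, int num_n_splits,
                                  int& n_block_min, int& n_block_max) {
  const int n_blocks = (seqlen_k + kBlockN - 1) / kBlockN;
  const int n_blocks_per_split = (n_blocks + num_n_splits - 1) / num_n_splits;
  n_block_min = n_split_idx * n_blocks_per_split;
  n_block_max = std::min(n_blocks, (n_split_idx + 1) * n_blocks_per_split);
  // The causal / local upper bound is then intersected with this range, as above.
}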
+ // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? -INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? 
bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = 
make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that 
conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + Tensor tQrQ = make_fragment_like(tQgQ); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? 
m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? 
-INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
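The per-split accumulators written in this epilogue use the split index as the slowest-moving dimension, as the row_offset_oaccum / row_offset_lseaccum expressions above show. The same addressing restated as plain index arithmetic, for illustration (the O accumulator simply appends a trailing d_rounded axis to this layout):

#include <cstddef>

// Flattened element index into softmax_lse_accum for (split, batch, head, query row),
// matching row_offset_lseaccum above.
inline std::size_t lse_accum_index(int split, int batch, int head, int row,
                                   int num_batches, int num_heads, int seqlen_q) {
  return ((static_cast<std::size_t>(split) * num_batches + batch) * num_heads + head) * seqlen_q + row;
}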
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. 
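combine_attn_seqk_parallel reduces those per-split partial results. For each output row it takes a numerically stable log-sum-exp over the per-split LSEs and uses the resulting weights to blend the partial outputs; with \mathrm{lse}_s and o_s the values produced by split s:

m = \max_s \mathrm{lse}_s,\qquad
\mathrm{lse} = m + \log \sum_s e^{\mathrm{lse}_s - m} = \log \sum_s e^{\mathrm{lse}_s},\qquad
o = \sum_s e^{\mathrm{lse}_s - \mathrm{lse}}\, o_s .

Splits that owned no key blocks wrote \mathrm{lse}_s = -\infty in their early exit, so they receive weight zero here.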
+ // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? 
INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = 
make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h new file mode 100644 index 000000000..e2f2505a7 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h @@ -0,0 +1,294 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +namespace flash { + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn_splitkv(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
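+ // [Editor's note] Stand-alone sketch, not part of the upstream patch: the BOOL_SWITCH
+ // macro used here (defined in static_switch.h later in this patch) turns runtime flags
+ // into constexpr template parameters by instantiating both branches, and the comments
+ // above describe collapsing some flag combinations to keep the number of instantiations
+ // (and compile time) down. MY_BOOL_SWITCH / run_tile are hypothetical stand-ins.
+ // #include <cstdio>
+ //
+ // #define MY_BOOL_SWITCH(COND, CONST_NAME, ...) \
+ //   [&] {                                       \
+ //     if (COND) {                               \
+ //       constexpr bool CONST_NAME = true;       \
+ //       return __VA_ARGS__();                   \
+ //     } else {                                  \
+ //       constexpr bool CONST_NAME = false;      \
+ //       return __VA_ARGS__();                   \
+ //     }                                         \
+ //   }()
+ //
+ // template <bool IsEvenMN, bool IsEvenK>
+ // void run_tile() { std::printf("IsEvenMN=%d IsEvenK=%d\n", int(IsEvenMN), int(IsEvenK)); }
+ //
+ // void dispatch(bool is_even_MN, bool is_even_K) {
+ //   MY_BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+ //     MY_BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+ //       // Collapsing as in run_flash_fwd: only treat MN as even when K is even too.
+ //       run_tile<IsEvenMNConst && IsEvenKConst, IsEvenKConst>();
+ //     });
+ //   });
+ // }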
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int kBlockM = 64; // Fixed for all head dimensions + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 96; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
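+ // [Editor's note] Stand-alone sketch, not part of the upstream patch: it just restates
+ // the combine-step kBlockM heuristic chosen above, which keeps the per-pass tile near
+ // 512 elements (the source comment notes 128 threads can load 512 elements at a time,
+ // so head dims divisible by 128 get kBlockM = 4) while preferring the smallest kBlockM
+ // for more parallelism. select_combine_block_m is a hypothetical helper.
+ // constexpr int select_combine_block_m(int head_dim) {
+ //   return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
+ // }
+ // static_assert(select_combine_block_m(128) == 4);   // 4 rows * 128 cols = 512 elements
+ // static_assert(select_combine_block_m(192) == 8);   // divisible by 64 but not 128 -> 8 rows
+ // static_assert(select_combine_block_m(96) == 16);   // fallback for the remaining head dims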
+ if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 256; + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. 
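+ // [Editor's note] Stand-alone sketch, not part of the upstream patch: the shared-memory
+ // checks in the hdim224/hdim256 paths budget 2 bytes per element for one Q tile plus one
+ // K and one V tile, i.e. sizeof(half) * head_dim * (kBlockM + 2 * kBlockN). smem_bytes is
+ // a hypothetical helper that just spells that product out and confirms the "112 KB" note.
+ // constexpr int smem_bytes(int elem_size, int head_dim, int block_m, int block_n) {
+ //   return elem_size * head_dim * (block_m + 2 * block_n);
+ // }
+ // static_assert(smem_bytes(2, 224, 128, 64) == 112 * 1024);  // hdim224, 128 x 64 tile
+ // static_assert(smem_bytes(2, 256, 128, 64) == 128 * 1024);  // hdim256, 128 x 64 (A100 path)
+ // static_assert(smem_bytes(2, 256, 64, 64) == 96 * 1024);    // hdim256, 64 x 64 (H100 path)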
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..8553913b2 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..8ed5afc6d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..9d74a13ce --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..235eeaf69 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..a95bda783 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..23546d313 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..18b0d8010 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..5df080b63 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..f6eb273df --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..d2a3dfdc8 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..bbfe75396 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..75123f1d2 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..366ecefef --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..90845fb31 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..b71a69fce --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
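+ // [Editor's note] Stand-alone sketch, not part of the upstream patch: these per-head-dim,
+ // per-dtype .cu stubs rely on explicit template instantiation, so each translation unit
+ // generates code for exactly one (dtype, head_dim) combination and the overall build
+ // parallelizes instead of compiling every kernel in one file. do_work is a hypothetical
+ // stand-in used only to illustrate the mechanism.
+ // #include <cstdio>
+ //
+ // template <typename T, int HeadDim>
+ // void do_work() { std::printf("code generated for HeadDim=%d\n", HeadDim); }
+ //
+ // // Explicit instantiation definition: forces code generation in this translation unit.
+ // // Other translation units can declare it with
+ // //   extern template void do_work<float, 96>();
+ // // and simply link against this object file.
+ // template void do_work<float, 96>();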
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..81d87e161 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/kernel_traits.h b/operators/cuda/attention_lib/flash_attention/kernel_traits.h new file mode 100644 index 000000000..48e899c2a --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/kernel_traits.h @@ -0,0 +1,367 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +using namespace cute; + +namespace flash { + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + using ValLayoutMNK = cute::Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = cute::Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
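+ // [Editor's note] Stand-alone sketch, not part of the upstream patch: the tiled copies
+ // below are sized around 16-byte vectorized accesses, so each thread moves
+ // sizeof(uint128_t) / sizeof(half) = 8 elements per load and a kBlockKSmem-wide row is
+ // covered by kBlockKSmem / 8 threads. Vec16B / Elem16 are hypothetical stand-ins for
+ // cute::uint128_t and the 2-byte element type, used only for this arithmetic check.
+ // #include <cstdint>
+ //
+ // struct Vec16B { std::uint64_t lo, hi; };
+ // using Elem16 = std::uint16_t;
+ //
+ // constexpr int kElemsPerLoadExample = sizeof(Vec16B) / sizeof(Elem16);  // 16 B / 2 B
+ // static_assert(kElemsPerLoadExample == 8);
+ // static_assert(64 / kElemsPerLoadExample == 8);   // kBlockKSmem = 64 -> 8 threads per row
+ // static_assert(128 / 8 == 16);                    // 128 threads cover 16 rows per copy step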
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + cute::Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + cute::Layout>, + cute::Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); + // static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? 
kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_32, _1>>{}, + cute::Layout>{})); // Val layout, 1 val per store +}; + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/softmax.h b/operators/cuda/attention_lib/flash_attention/softmax.h new file mode 100644 index 000000000..9c31336c9 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/softmax.h @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
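+ // [Editor's note] Stand-alone host-side check, not part of the upstream patch and not
+ // meant to sit inside this kernel: it only verifies the identity the comment above uses,
+ // exp(x - m) == exp2(x * log2(e) - m * log2(e)), which lets the softmax scale be folded
+ // into a single fused multiply-add feeding exp2f.
+ // #include <cassert>
+ // #include <cmath>
+ //
+ // int main() {
+ //   const float kLog2e = 1.4426950408889634f;  // log2(e)
+ //   const float x = 3.25f, m = 7.5f;
+ //   const float reference = std::exp(x - m);
+ //   const float fused = std::exp2(x * kLog2e - m * kLog2e);
+ //   assert(std::fabs(reference - fused) < 1e-6f);
+ //   return 0;
+ // }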
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == 
size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/static_switch.h b/operators/cuda/attention_lib/flash_attention/static_switch.h new file mode 100644 index 000000000..5b7098894 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/static_switch.h @@ -0,0 +1,64 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/operators/cuda/attention_lib/flash_attention/utils.h b/operators/cuda/attention_lib/flash_attention/utils.h new file mode 100644 index 000000000..cd10bd534 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/utils.h @@ -0,0 +1,499 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cuda_fp16.h>
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+#include <cute/algorithm/copy.hpp>
+#include <cute/algorithm/gemm.hpp>
+
+#include <cutlass/array.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/numeric_conversion.h>
+#include <cutlass/numeric_types.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+inline __device__ uint32_t relu2(const uint32_t x);
+
+template <>
+inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("max.f16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+#else
+  asm volatile(
+      "{\n"
+      "\t .reg .f16x2 sela;\n"
+      "\t set.gtu.u32.f16x2 sela, %1, %2;\n"
+      "\t and.b32 %0, sela, %1;\n"
+      "}\n"
+      : "=r"(res)
+      : "r"(x), "r"(zero));
+#endif
+  return res;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+  asm volatile("max.bf16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+  return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+template <typename T>
+inline __device__ uint32_t convert_relu2(const float2 x);
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct MaxOp {
+  __device__ inline T operator()(T const& x, T const& y) { return x > y ?
x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from 
(MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); 
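+  // Each uint32_t in out_uint32 packs two 16-bit values: convert_relu2 converts one
+  // float2 into a fp16x2/bf16x2 pair with the ReLU fused into the cvt.rn.relu instruction.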
+#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int 
min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 
0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/operators/cuda/attention_lib/group_query_attention.h b/operators/cuda/attention_lib/group_query_attention.h new file mode 100644 index 000000000..312b0dcbf --- /dev/null +++ b/operators/cuda/attention_lib/group_query_attention.h @@ -0,0 +1,455 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
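+//
+// GroupQueryAttention custom op: CheckInputs validates query/key/value (separate or
+// packed qkv), the past/present KV cache (BSNH or BNSH) and the optional rotary
+// cos/sin caches, then Compute dispatches to flash attention or memory efficient
+// attention via cuda::QkvToContext in group_query_attention_impl.cu.
+//
+// Illustrative (hypothetical) sizes: with num_heads = 32, kv_num_heads = 8 and
+// head_size = 128, a non-packed query has hidden size 32 * 128 = 4096, while a packed
+// qkv input has hidden size (32 + 2 * 8) * 128 = 6144, i.e.
+// head_size * (num_heads + 2 * kv_num_heads) as computed in CheckInputs below.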
+ +#pragma once +#include "ocos.h" +#include "../cuda_type.h" +#include "attention_common.h" +#include "group_query_attention_impl.cuh" +#include "../device_prop.cuh" +#if USE_FLASH_ATTENTION +#include "flash_attention/flash_api.h" +#endif +#if USE_MEMORY_EFFICIENT_ATTENTION +#include "cutlass_fmha/memory_efficient_attention.h" +#endif + +#include "ortx_common.h" + +namespace contrib { + +template +using IAllocatorUniquePtr = std::unique_ptr>; + +template +inline IAllocatorUniquePtr GetScrachBuffer(void* p, std::function deleter) { + return IAllocatorUniquePtr{static_cast(p), deleter}; +} + +template +inline IAllocatorUniquePtr GetCudaScrachBuffer(void* p, Ort::Custom::CUDAKernelContext* ctx) { + return GetScrachBuffer(p, [=](T* p) { + if (p) + ctx->FreeCudaScratchBuffer(p); + }); +} + +template +OrtStatusPtr CheckInputs(const Ort::Custom::Tensor& query, + std::optional*> key, + std::optional*> value, + std::optional*> past_key, + std::optional*> past_value, + std::optional*> cos_cache, + std::optional*> sin_cache, + void* parameters, + int num_heads, + int kv_num_heads, + const Ort::Custom::Tensor& seqlens_k, + const Ort::Custom::Tensor& total_seqlen, + bool is_past_bsnh, + float scale, + int max_threads_per_block) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return OrtW::CreateStatus(MakeString("num_heads should be no larger than ", max_threads_per_block), ORT_INVALID_ARGUMENT); + } + + // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // no packing for q/k/v: + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr + + AttentionQkvFormat qkv_format = Q_K_V_BSNH; + AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; + const bool is_packed_qkv = !key.has_value(); + const auto& query_dims = query.Shape(); + + if (query_dims.size() != 3) { + return OrtW::CreateStatus(MakeString("Input 'query' is expected to have 3 dimensions, got ", query_dims.size()), ORT_INVALID_ARGUMENT); + } + + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int q_hidden_size = static_cast(query_dims[2]); + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return OrtW::CreateStatus(MakeString("num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", num_heads % kv_num_heads), ORT_INVALID_ARGUMENT); + } + + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return OrtW::CreateStatus(MakeString("head_size must be a multiple of 8. 
Got head_size % 8 == ", head_size % 8), ORT_INVALID_ARGUMENT); + } + if (!value.has_value()) { + return OrtW::CreateStatus("Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.", ORT_INVALID_ARGUMENT); + } + const auto& key_dims = (*key)->Shape(); + if (key_dims.size() != 3) { + return OrtW::CreateStatus(MakeString("Input 'key' is expected to have 3 dimensions, got ", key_dims.size()), ORT_INVALID_ARGUMENT); + } else if (query_dims[0] != key_dims[0]) { + return OrtW::CreateStatus("Input 'query' and 'key' shall have same dim 0 (batch size)", ORT_INVALID_ARGUMENT); + } else if (query_dims[1] != key_dims[1]) { + return OrtW::CreateStatus("Input 'query' and 'key' shall have same dim 1 (sequence length)", ORT_INVALID_ARGUMENT); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = (*value)->Shape(); + if (value_dims.size() != 3) { + return OrtW::CreateStatus(MakeString("Input 'value' is expected to have 3 dimensions, got ", value_dims.size()), ORT_INVALID_ARGUMENT); + } else if (query_dims[0] != value_dims[0]) { + return OrtW::CreateStatus("Input 'query' and 'value' shall have same dim 0 (batch size)", ORT_INVALID_ARGUMENT); + } else if (query_dims[1] != value_dims[1]) { + return OrtW::CreateStatus("Input 'query' and 'value' shall have same dim 1 (sequence length)", ORT_INVALID_ARGUMENT); + } else if (value_dims[2] != kv_hidden_size) { + return OrtW::CreateStatus("Input 'value' is expected to have same hidden size as key.", ORT_INVALID_ARGUMENT); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return OrtW::CreateStatus(MakeString("head_size must be a multiple of 8. Got head_size % 8 == ", head_size % 8), ORT_INVALID_ARGUMENT); + } + if (value.has_value()) { + return OrtW::CreateStatus("Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.", ORT_INVALID_ARGUMENT); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + + // Check past-present KV + int32_t past_sequence_length = 0; + if (past_key.has_value() && past_value.has_value()) { + const auto& past_key_dims = (*past_key)->Shape(); + const auto& past_value_dims = (*past_value)->Shape(); + + if (past_key_dims.size() != 4) { + return OrtW::CreateStatus(MakeString("Input 'past_key' is expected to have 4 dimensions, got ", past_key_dims.size()), ORT_INVALID_ARGUMENT); + } + if (past_value_dims.size() != 4) { + return OrtW::CreateStatus(MakeString("Input 'past_value' is expected to have 4 dimensions, got ", past_value_dims.size()), ORT_INVALID_ARGUMENT); + } + + if (past_key_dims[0] != batch_size) { + return OrtW::CreateStatus(MakeString("Input 'past_key' dimension 0 should be batch_size, got ", past_key_dims[0]), ORT_INVALID_ARGUMENT); + } + if (past_value_dims[0] != batch_size) { + return OrtW::CreateStatus(MakeString("Input 'past_value' dimension 0 should be batch_size, got ", past_value_dims[0]), ORT_INVALID_ARGUMENT); + } + + // BNSH + if (!is_past_bsnh) { + if (past_key_dims[2] != past_value_dims[2]) { + return OrtW::CreateStatus(MakeString("BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence length or past sequence length), got ", past_key_dims[1]), ORT_INVALID_ARGUMENT); + } + if (past_key_dims[1] != kv_num_heads) { + return OrtW::CreateStatus("Input 'past_key' shall have kv_num_heads", ORT_INVALID_ARGUMENT); + } + if (past_value_dims[1] != kv_num_heads) { + return 
OrtW::CreateStatus("Input 'past_value' shall have kv_num_heads", ORT_INVALID_ARGUMENT); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + // BSNH + } else { + if (past_key_dims[1] != past_value_dims[1]) { + return OrtW::CreateStatus(MakeString("BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence length or past sequence length), got ", past_key_dims[1]), ORT_INVALID_ARGUMENT); + } + if (past_key_dims[2] != kv_num_heads) { + return OrtW::CreateStatus("Input 'past_key' shall have kv_num_heads", ORT_INVALID_ARGUMENT); + } + if (past_value_dims[2] != kv_num_heads) { + return OrtW::CreateStatus("Input 'past_value' shall have kv_num_heads", ORT_INVALID_ARGUMENT); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[1]); + } + + if (past_key_dims[3] != head_size) { + return OrtW::CreateStatus(MakeString("Input 'past_key' dimension 3 should be same as head_size, got ", past_key_dims[3]), ORT_INVALID_ARGUMENT); + } + if (past_value_dims[3] != head_size) { + return OrtW::CreateStatus(MakeString("Input 'past_value' dimension 3 should be same as head_size, got ", past_value_dims[3]), ORT_INVALID_ARGUMENT); + } + } else if (past_key.has_value() || past_value.has_value()) { + return OrtW::CreateStatus("Input 'past_key' and 'past_value' shall be both present or both absent.", ORT_INVALID_ARGUMENT); + } + + // Check seqlens_k tensor (holding past seqlen for token gen) + const auto& seqlens_dim = seqlens_k.Shape(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return OrtW::CreateStatus("seqlens_k must be shape (batch_size).", ORT_INVALID_ARGUMENT); + } + + // Set present sequence length and kv_share_buffer from input total_seqlen tensor + size_t num_dimensions = total_seqlen.Shape().size(); + int64_t shape_size = total_seqlen.NumberOfElement(); + if (!IsScalarOr1ElementVector(num_dimensions, shape_size)) { + return OrtW::CreateStatus("total_sequence_length tensor must be of one element.", ORT_INVALID_ARGUMENT); + } + int total_sequence_length = *(total_seqlen.Data()); + int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + + if (cos_cache.has_value() && sin_cache.has_value()) { + const auto& cos_dims = (*cos_cache)->Shape(); + const auto& sin_dims = (*sin_cache)->Shape(); + + if (head_size % 16 != 0) { + return OrtW::CreateStatus(MakeString("head_size shall be a multiple of 16. 
Got head_size % 16 == ", head_size % 16), ORT_INVALID_ARGUMENT); + } + if (cos_dims[0] != present_sequence_length) { + return OrtW::CreateStatus("cos_cache dimension 0 must be of present_sequence_length.", ORT_INVALID_ARGUMENT); + } + if (sin_dims[0] != present_sequence_length) { + return OrtW::CreateStatus("sin_cache dimension 0 must be of present_sequence_length.", ORT_INVALID_ARGUMENT); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return OrtW::CreateStatus("cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.", ORT_INVALID_ARGUMENT); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return OrtW::CreateStatus("sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.", ORT_INVALID_ARGUMENT); + } + } else if (cos_cache.has_value() || sin_cache.has_value()) { + return OrtW::CreateStatus("Input 'cos_cache' and 'sin_cache' shall be both present or both absent.", ORT_INVALID_ARGUMENT); + } + + bool is_prompt = sequence_length != 1; + + if (parameters != nullptr) { + GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; // sequence length of Q + output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors + output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->hidden_size = q_hidden_size; + output_parameters->num_heads = num_heads; + output_parameters->head_size = head_size; + output_parameters->kv_hidden_size = kv_hidden_size; + output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; + output_parameters->is_unidirectional = true; + output_parameters->is_prompt = is_prompt; + output_parameters->scale = scale; + output_parameters->qkv_format = qkv_format; + output_parameters->past_kv_format = past_kv_format; + } + + return nullptr; +} + +template +struct GroupQueryAttention { + static OrtMemType GetInputMemoryType(size_t input_index) { + if (input_index == 6) return OrtMemType::OrtMemTypeCPUInput; + return OrtMemType::OrtMemTypeDefault; + } + + template + OrtStatusPtr OnModelAttach(const TDict& dict) { + int64_t num_heads = dict.TryToGetAttributeWithDefault("num_heads", (int64_t)0); + int64_t kv_num_heads = dict.TryToGetAttributeWithDefault("kv_num_heads", (int64_t)0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + int64_t local_window_size = dict.TryToGetAttributeWithDefault("local_window_size", (int64_t)-1); + local_window_size_ = static_cast(local_window_size); + int64_t do_rotary = dict.TryToGetAttributeWithDefault("do_rotary", 0); + do_rotary_ = do_rotary == 1; + int64_t rotary_interleaved = dict.TryToGetAttributeWithDefault("rotary_interleaved", (int64_t)0); + rotary_interleaved_ = rotary_interleaved == 1; + scale_ = dict.TryToGetAttributeWithDefault("scale", 0.0f); + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); +#else + disable_flash_attention_ = true; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif + return nullptr; + } + + OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const 
ortc::Tensor& query, std::optional*> key, + std::optional*> value, std::optional*> past_key, std::optional*> past_value, + const ortc::Tensor& seqlens_k, const ortc::Tensor& total_seqlen, std::optional*> cos_cache, + std::optional*> sin_cache, ortc::Tensor& attn_out, std::optional*> present_key, std::optional*> present_value) const { + // TODO: will initialize disable_flash_attention_ be put here or OnModelAttach()? if latter, need to expose a function to get allocator from kernelInfo + IAllocatorUniquePtr zeros_ = disable_flash_attention_ ? nullptr : GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(kZerosCount), ctx); + + GroupQueryAttentionParameters parameters; + ORTX_RETURN_IF_ERROR(CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, ¶meters, num_heads_, kv_num_heads_, + seqlens_k, total_seqlen, is_past_bsnh_, scale_, DeviceProp::GetCudaDeviceProp().maxThreadsPerBlock)); + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + std::vector output_shape(3, 0); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && flash::is_supported(DeviceProp::GetCudaDeviceProp(), parameters.head_size, parameters.num_heads, parameters.kv_num_heads); + // Allocate buffers + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + // softmax buffer + softmax_lse_bytes = flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, + parameters.head_size, DeviceProp::GetCudaDeviceProp().multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(softmax_lse_bytes), ctx); + auto softmax_lse_accum_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(softmax_lse_accum_bytes), ctx); + auto out_accum_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(out_accum_bytes), ctx); +#else + constexpr bool use_flash_attention = false; + IAllocatorUniquePtr softmax_lse_buffer = nullptr; + IAllocatorUniquePtr softmax_lse_accum_buffer = nullptr; + IAllocatorUniquePtr out_accum_buffer = nullptr; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (DeviceProp::GetCudaDeviceProp().major * 10) + DeviceProp::GetCudaDeviceProp().minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + 
cuda::has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && cuda::MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(kv_buffer_bytes), ctx); + auto v_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(kv_buffer_bytes), ctx); + auto fmha_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(fmha_buffer_bytes), ctx); +#else + constexpr bool use_memory_efficient_attention = false; + IAllocatorUniquePtr k_buffer = nullptr; + IAllocatorUniquePtr v_buffer = nullptr; + IAllocatorUniquePtr fmha_buffer = nullptr; +#endif + + // seqlens_k buffer + size_t seqlens_k_bytes = 0; + seqlens_k_bytes = sizeof(int) * parameters.batch_size; + auto seqlens_k_buffer = GetCudaScrachBuffer(ctx->AllocCudaScratchBuffer(seqlens_k_bytes), ctx); + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; + } + + using TT = typename CudaT::MappedType; + cuda::GroupQueryAttentionData data; + data.query = reinterpret_cast(query.Data()); + data.key = key.has_value() ? reinterpret_cast((*key)->Data()) : nullptr; + data.value = value.has_value() ? reinterpret_cast((*value)->Data()) : nullptr; + data.past_key = past_key.has_value() ? reinterpret_cast((*past_key)->Data()) : nullptr; + data.past_value = past_value.has_value() ? reinterpret_cast((*past_value)->Data()) : nullptr; + data.output = reinterpret_cast(attn_out.Allocate(output_shape)); + data.present_key = present_key.has_value() ? reinterpret_cast((*present_key)->Allocate(present_dims)) : nullptr; + data.present_value = present_value.has_value() ? 
reinterpret_cast((*present_value)->Allocate(present_dims)) : nullptr; + data.seqlens_k = const_cast(seqlens_k.Data()); + data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; + if (data.past_key == data.present_key) { + parameters.kv_share_buffer = true; + } else { + parameters.kv_share_buffer = false; + } + // Flash Buffers + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } + if (seqlens_k_buffer != nullptr) { + data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + } + // Memory Efficient Buffers + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast((*cos_cache)->Data()); + data.sin_cache = reinterpret_cast((*sin_cache)->Data()); + } + + return cuda::QkvToContext( + /*device_prop, ctx.cublas,*/ reinterpret_cast(ctx->GetCudaStream()), parameters, data); + } + + private: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + bool disable_flash_attention_; + bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/attention_lib/group_query_attention_impl.cu b/operators/cuda/attention_lib/group_query_attention_impl.cu new file mode 100644 index 000000000..0c6cdabae --- /dev/null +++ b/operators/cuda/attention_lib/group_query_attention_impl.cu @@ -0,0 +1,663 @@ +#include +#include +#include "group_query_attention_impl.cuh" +#include "../utils.cuh" +#include "ortx_common.h" +#include "onnxruntime_f16.h" +#ifdef USE_FLASH_ATTENTION +#include "flash_attention/flash_api.h" +#endif +#ifdef USE_MEMORY_EFFICIENT_ATTENTION +#include "cutlass_fmha/memory_efficient_attention.h" +#endif +#include "../device_prop.cuh" + +namespace contrib { +namespace cuda { + +////////// Auxiliary Kernels for KV prep + +// Kernel for seqlens_k +__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { + int id = blockDim.x * blockIdx.x + threadIdx.x; + if (id < batch_size) seqlens_k[id] = seqlen; +} + +// Kernel to append new and past kv in either BSNH or BNSH format +// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file +template +__global__ void ConcatNewToPastKV(const int new_seqlen, + const int past_buffer_seqlen, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to past; otherwise 
bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int present_buffer_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } +} + +// Use when (H*)*num_heads > 1024 +template +__global__ void ConcatNewToPastKVLarge(const int new_seqlen, + const int past_buffer_seqlen, + const int H, + const int num_heads, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_buffer_seqlen = gridDim.y; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } + } +} + +// Concat new to past in present. Supports past BSNH or past BNSH +template +OrtStatusPtr LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block, + const bool past_only = false) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = past_only ? 
0 : parameters.sequence_length; + const int past_sequence_length = parameters.seqlen_past_kv_cache; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CudaCall(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int max_seqlen, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 
0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int max_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +OrtStatusPtr LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + // Indicates past sequence_length of each sequence + const int* seqlens_k = parameters.is_prompt ? 
nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CudaCall(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? 
H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +OrtStatusPtr LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CudaCall(cudaGetLastError()); +} + +__global__ void PastToTotalSeqlen(int32_t* seqlens_k, + int32_t* seqlens_k_buff, + const int add_seqlen) { + seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; +} + +// Convert Past to Total sequence length tensor +OrtStatusPtr LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, + const int threads_per_block) { + if (parameters.is_prompt) { + return nullptr; + } + const int batch_size = parameters.batch_size; + const int add_seqlen = is_total ? 
parameters.sequence_length : 0; + + const dim3 grid(1, 1, 1); + // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads + const dim3 block(batch_size, 1, 1); + + // TODO(aciddelgado): small version + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); + + return CudaCall(cudaGetLastError()); +} + +////////// Launch Kernels + +using BFloat16 = Ort::Custom::BFloat16; + +#if USE_FLASH_ATTENTION +template +OrtStatusPtr FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + bool is_causal = true; + bool is_bf16 = std::is_same::value; + + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; + + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); + } else { + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } + + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; + } + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORTX_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); + } + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORTX_RETURN_IF_ERROR(flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORTX_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + + // TODO: DUMP_TENSOR_INIT(); + // TODO: DUMP_TENSOR("flash 
attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return nullptr; +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +OrtStatusPtr EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } else { + ORTX_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + // Concatenate new kv in place + ORTX_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + // Not share buffer case + if (data.past_key != nullptr && data.past_key == data.present_key) { + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + // Copy past and concat new KV to present buffer + ORTX_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + } + + // Ungroup if grouped, otherwise use present kv directly + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... 
best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORTX_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + // TODO: DUMP_TENSOR_INIT(); + // TODO: DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.max_sequence_length = present_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = true; + p.scale = scale; + p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + p.has_custom_right_padding = true; + run_memory_efficient_attention(p); + + // TODO: DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return nullptr; +} +#endif + +////////// API Functions + +template +OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); + const float scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, cuda_stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, cuda_stream, parameters, data, scale); + } +#endif + + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); +} + +template struct GroupQueryAttentionData; + +template OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData; + +template OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); +} +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/attention_lib/group_query_attention_impl.cuh b/operators/cuda/attention_lib/group_query_attention_impl.cuh new file mode 100644 index 000000000..00f11a500 --- /dev/null +++ b/operators/cuda/attention_lib/group_query_attention_impl.cuh @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "onnxruntime_c_api.h" +#include "attention_common.h" + +namespace contrib { +namespace cuda { + +template +struct GroupQueryAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; + int* seqlens_k_total = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors + T* output = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + // Kernel Flags + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; +}; + +template +OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, // TODO: cublas is not used at all + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib diff --git a/operators/contrib/contrib.cc b/operators/cuda/cuda_ops.cc similarity index 69% rename from operators/contrib/contrib.cc rename to operators/cuda/cuda_ops.cc index 39cc02f85..6701a6948 100644 --- a/operators/contrib/contrib.cc +++ b/operators/cuda/cuda_ops.cc @@ -4,7 +4,8 @@ #include "ocos.h" #ifdef USE_CUDA -#include "cuda/fast_gelu.h" +#include "fast_gelu.h" +#include "attention_lib/group_query_attention.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -15,6 +16,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), #if ORT_API_VERSION >= 16 + CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), + CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), CustomCudaStructV2("FastGelu", contrib::FastGelu), 
CustomCudaStructV2("FastGelu", contrib::FastGelu) #endif diff --git a/operators/contrib/cuda/cuda_type.h b/operators/cuda/cuda_type.h similarity index 100% rename from operators/contrib/cuda/cuda_type.h rename to operators/cuda/cuda_type.h diff --git a/operators/contrib/cuda/device_prop.cuh b/operators/cuda/device_prop.cuh similarity index 100% rename from operators/contrib/cuda/device_prop.cuh rename to operators/cuda/device_prop.cuh diff --git a/operators/contrib/cuda/fast_gelu.h b/operators/cuda/fast_gelu.h similarity index 100% rename from operators/contrib/cuda/fast_gelu.h rename to operators/cuda/fast_gelu.h diff --git a/operators/contrib/cuda/fast_gelu_impl.cu b/operators/cuda/fast_gelu_impl.cu similarity index 100% rename from operators/contrib/cuda/fast_gelu_impl.cu rename to operators/cuda/fast_gelu_impl.cu diff --git a/operators/contrib/cuda/fast_gelu_impl.cuh b/operators/cuda/fast_gelu_impl.cuh similarity index 100% rename from operators/contrib/cuda/fast_gelu_impl.cuh rename to operators/cuda/fast_gelu_impl.cuh diff --git a/operators/contrib/cuda/utils.cuh b/operators/cuda/utils.cuh similarity index 97% rename from operators/contrib/cuda/utils.cuh rename to operators/cuda/utils.cuh index fe3d27daa..552f9b146 100644 --- a/operators/contrib/cuda/utils.cuh +++ b/operators/cuda/utils.cuh @@ -189,5 +189,11 @@ __device__ __inline__ half2 _Tanh(half2 a) { return __float22half2_rn(tmp); } -template <> + +// TODO: +inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { + if (cuda_error == cudaSuccess) return nullptr; + return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); +} + __device__ __inline__ ortc::BFloat16 _Tanh(ortc::BFloat16 a) { return tanhf(static_cast(a)); } diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index d868fe675..be6b3753a 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -6,7 +6,601 @@ from onnxruntime_extensions import get_library_path as _get_library_path import onnxruntime as _ort +import pdb +import math +import os +import platform +import random + +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, OrtValue, SessionOptions + +torch.manual_seed(0) +class Formats: + BSNH = 0 + BNSH = 1 + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + past_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, sp, n, n2, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.past_sequence_length = sp + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - 
window_size[0], + ) + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + +def create_group_query_attention_graph_past( + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, +): + past_kv_seqlen = config.kv_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length + ) + #pdb.set_trace() + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not packed else "", + "value" if not packed else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", +# "cos_cache" if rotary else "", +# "sin_cache" if 
rotary else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="ai.onnx.contrib", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = make_onnx_model(graph) + return model + +def gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, +): + onnx_model = create_group_query_attention_graph_past( + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + 
rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, + ) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + past_k = k.clone() + past_v = v.clone() + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(np.int32), + "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + sess_options.register_custom_ops_library(_get_library_path()) + ort_session = InferenceSession(onnx_model.SerializeToString(), sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_input( + "past_key", "cuda", 0, np.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + np.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = np.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": past_k.detach().cpu().numpy(), + "past_value": past_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(np.int32), + "total_sequence_length": torch.tensor( + [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 + ) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + sess_options.register_custom_ops_library(_get_library_path()) + ort_session = InferenceSession(onnx_model.SerializeToString(), sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", 
ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = np.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + +def parity_check_gqa_past_no_buff( + config, + causal=False, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + torch.manual_seed(69) + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + cos, sin = None, None + q_ro, k_ro = q, new_k + + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + # Compare results + print( + "NO buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + np.mean(np.abs(out - out_ref)), + np.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) class TestCudaOps(unittest.TestCase): @staticmethod @@ -115,7 +709,83 @@ def test_cuda_fastgelu_f16(self): assert_almost_equal(y, expected_y) else: print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.') + + @staticmethod + def _create_GroupQueryAttention_test_model(domain='ai.onnx.contrib'): + nodes = [ + helper.make_node( + 'GroupQueryAttention', + #['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seqlen', 'cos_cache', 'sin_cache'], + ['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seqlen'], + ['attn_out', 'present_key', 'present_value'], + #domain=domain, num_heads=32, kv_num_heads=32, scale=0.0, local_window_size=-1, do_rotary=0, rotary_interleaved=0) + domain=domain, num_heads=32, kv_num_heads=32) + ] + + query = helper.make_tensor_value_info( + 'query', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + key = helper.make_tensor_value_info( + 'key', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + value = helper.make_tensor_value_info( + 'value', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + past_key = helper.make_tensor_value_info( + 'past_key', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) + past_value = helper.make_tensor_value_info( + 'past_value', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) + seqlens_k = helper.make_tensor_value_info( + 'seqlens_k', onnx_proto.TensorProto.INT32, [5]) + total_seqlen = helper.make_tensor_value_info( + 'total_seqlen', onnx_proto.TensorProto.INT32, [1]) +# cos_cache = helper.make_tensor_value_info( +# 'cos_cache', onnx_proto.TensorProto.FLOAT, []) +# sin_cache = helper.make_tensor_value_info( +# 
'sin_cache', onnx_proto.TensorProto.FLOAT, []) + attn_out = helper.make_tensor_value_info( + 'attn_out', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + present_key = helper.make_tensor_value_info( + 'present_key', onnx_proto.TensorProto.FLOAT16, [5,32,129,16]) + present_value = helper.make_tensor_value_info( + 'present_value', onnx_proto.TensorProto.FLOAT16, [5,32,129,16]) + + graph = helper.make_graph(nodes, 'testgqa', + #[query, key, value, past_key, past_value, seqlens_k, total_seqlen, cos_cache, sin_cache], + [query, key, value, past_key, past_value, seqlens_k, total_seqlen], + [attn_out, present_key, present_value]) + model = make_onnx_model(graph) + return model + + def test_cuda_GroupQueryAttention(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_GroupQueryAttention_test_model() + #self.assertIn('op_type: "NegPos"', str(onnx_model)) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + query = np.random.randn(5,1,512).astype(np.float16) + key = np.random.randn(5,1,512).astype(np.float16) + value = np.random.randn(5,1,512).astype(np.float16) + past_key = np.random.randn(5,32,128,16).astype(np.float16) + past_value = np.random.randn(5,32,128,16).astype(np.float16) + seqlens_k = np.array([128, 87, 0, 22, 125]).astype(np.int32) + total_seqlen = np.array([129]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'past_key':past_key, 'past_value':past_value, 'seqlens_k':seqlens_k, 'total_seqlen':total_seqlen}) + def test_cuda_GroupQueryAttention2(self): + random.seed(69) + for b in [5]: + for s, s2 in [(1,128)]: + for n, n2 in [(32, 32)]: + for h in [16]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() diff --git a/test/static_test/test_cuda_eager.cc b/test/static_test/test_cuda_eager.cc index 65de140de..ed594fdd6 100644 --- a/test/static_test/test_cuda_eager.cc +++ b/test/static_test/test_cuda_eager.cc @@ -9,10 +9,13 @@ #ifdef USE_CUDA #include "math/cuda/negpos_def.h" -#include "contrib/cuda/fast_gelu.h" +#include "cuda/attention_lib/group_query_attention.h" +#include "cuda/fast_gelu.h" #include #include +#include + class CudaAllocator : public Ort::Custom::IAllocator { public: @@ -100,4 +103,90 @@ TEST(CudaOp, test_fastgelu_eager) { ASSERT_NEAR(host_output[5], 5.5, 0.01f); } +TEST(CudaOp, test_gqa_eager) { + + MockCudaKernelContext mock_cuda_kc; + ortc::NamedArgumentDict dict({"num_heads", "kv_num_heads", "local_window_size", "rotary", "rotary_interleaved"}, + std::make_tuple((int64_t)4, (int64_t)4, (int64_t)-1, (int64_t)false, (int64_t)false)); + contrib::GroupQueryAttention GQA; + GQA.OnModelAttach(dict); + + std::vector query_fp32_data{-1.6592, 1.9277, 0.8760, 0.3105, 1.1377, -0.7349, -0.8086, + 0.5542, 0.4773, -0.7651, -0.3364, 0.8901, -1.6172, -1.3828, + 2.2129, -0.6030, -0.8359, 0.8130, -0.2239, -0.3994, 0.2673, + -0.1252, 0.3840, -0.5801, 0.1830, -1.0537, -1.7383, -0.9712, + 0.2480, -1.3701, 0.7559, -0.5557}; + std::vector query_data(query_fp32_data.begin(), query_fp32_data.end()); + + std::vector past_key_fp32_data{ 0.5010, -0.0542, 0.5386, 0.2764, -1.4385, -1.5312, 0.1119, 1.7080, + -0.1099, -0.3079, 0.6372, -0.7539, -0.0911, -0.9551, 0.5029, 0.2251, + 1.3135, 2.0723, 1.2764, -0.2993, 1.6289, 
-0.5664, -1.5410, 0.8188, + 0.3479, -0.6240, -0.1943, 0.0476, 0.5396, -0.3943, -1.1904, 1.7070, + -0.7700, -1.3760, 0.5176, -0.7925, -0.0111, 0.4668, 0.7832, -2.2246, + 1.0742, -0.0551, -0.3535, -2.1895, 0.6045, -0.1617, 1.8232, 0.5317, + -0.2417, 0.6602, 0.1171, 2.5059, -0.8545, 1.5771, 0.7280, -0.6860, + 0.2258, 0.4800, -0.3633, 1.7559, 1.8066, 0.0654, 0.0540, -1.3291}; + std::vector past_key_data(past_key_fp32_data.begin(), past_key_fp32_data.end()); + + std::vector past_value_fp32_data {-0.5835, -0.8921, -0.5298, 1.0850, 1.2051, -0.5659, -0.1124, -0.6567, + -1.1182, -0.5957, 0.2952, 1.0215, -0.7632, 0.7295, -0.4319, -0.4116, + -0.5938, -1.2607, -0.3037, 0.5249, 0.1610, -0.0620, -0.1490, -0.1721, + -1.3164, 0.5884, 1.0400, 1.2471, -0.9409, -2.7012, -0.1023, -0.5967, + -0.7583, 0.8965, -1.5752, -0.8535, -0.2247, -0.7705, 0.8159, 0.2113, + -1.5742, -0.3538, -0.6343, -0.3789, 0.2079, 1.6826, 1.7314, -1.3691, + 0.4917, 0.7573, 0.5498, -0.3804, -0.0951, -0.8687, -2.8359, -0.5874, + -0.9648, 0.2649, -0.0262, 0.5845, 0.3723, 1.0117, 0.3867, -2.3340}; + std::vector past_value_data(past_value_fp32_data.begin(), past_value_fp32_data.end()); + + std::vector key_fp32_data {-0.9658, -0.2551, -0.3589, 0.7075, 0.5664, -0.8550, -1.8037, -0.0263, + -2.0117, 1.2432, -0.1371, -0.6460, 1.6084, -0.7856, 0.3774, 0.0493, + -1.9062, 1.6357, 1.6689, 0.6250, -0.9961, -1.1406, -0.5303, -0.5591, + -0.2861, -1.4609, -0.3911, 0.9136, 0.4893, 0.1588, 0.5972, -0.9507}; + std::vector key_data(key_fp32_data.begin(), key_fp32_data.end()); + + std::vector value_fp32_data {1.7578, 0.7573, -0.3792, -0.2634, 0.0267, 0.1066, -0.4268, 1.8516, + -1.1758, 0.5981, -0.3325, 1.5234, 0.7876, -0.1825, 0.6123, 0.9810, + 0.2473, 1.1494, 1.4395, -0.8579, 1.0684, -0.4692, -0.1188, -1.5713, + -1.5430, -2.5391, 0.8301, -0.3464, -0.3789, -2.0332, -2.0508, -0.3186}; + std::vector value_data(value_fp32_data.begin(), value_fp32_data.end()); + + std::vector seqlens_k = {2}; + std::vector total_sequence_length = {3}; + + auto cuda_alloc = mock_cuda_kc.GetCudaAllocator(); + void* query_data_gpu = cuda_alloc->Alloc(sizeof(ortc::MFloat16) * query_data.size()); + cudaMemcpyAsync(query_data_gpu, query_data.data(), sizeof(ortc::MFloat16)*query_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + void* past_key_data_gpu = cuda_alloc->Alloc(sizeof(ortc::MFloat16) * past_key_data.size()); + cudaMemcpyAsync(past_key_data_gpu, past_key_data.data(), sizeof(ortc::MFloat16)*past_key_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + void* past_value_data_gpu = cuda_alloc->Alloc(sizeof(ortc::MFloat16) * past_value_data.size()); + cudaMemcpyAsync(past_value_data_gpu, past_value_data.data(), sizeof(ortc::MFloat16)*past_value_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + void* key_data_gpu = cuda_alloc->Alloc(sizeof(ortc::MFloat16) * key_data.size()); + cudaMemcpyAsync(key_data_gpu, key_data.data(), sizeof(ortc::MFloat16)*key_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + void* value_data_gpu = cuda_alloc->Alloc(sizeof(ortc::MFloat16) * value_data.size()); + cudaMemcpyAsync(value_data_gpu, value_data.data(), sizeof(ortc::MFloat16)*value_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + void* seqlens_k_data_gpu = cuda_alloc->Alloc(sizeof(int32_t)); + cudaMemcpyAsync(seqlens_k_data_gpu, seqlens_k.data(), sizeof(int32_t)*seqlens_k.size(), cudaMemcpyHostToDevice, 
static_cast(mock_cuda_kc.GetCudaStream())); + // input tensors + ortc::Tensor query(std::vector{1, 1, 32}, query_data_gpu); + ortc::Tensor key(std::vector{1, 1, 32}, key_data_gpu); + ortc::Tensor value(std::vector{1, 1, 32}, value_data_gpu); + + ortc::Tensor past_key(std::vector{1, 4, 2, 8}, past_key_data_gpu); + ortc::Tensor past_value(std::vector{1, 4, 2, 8}, past_value_data_gpu); + ortc::Tensor seqlens_k_gpu(std::vector{1,}, seqlens_k_data_gpu); + ortc::Tensor total_sequence_cpu(std::vector{1,}, total_sequence_length.data()); + ortc::Tensor output(cuda_alloc); + + auto status = GQA.Compute(&mock_cuda_kc, query, &key, &value, &past_key, &past_value, seqlens_k_gpu, total_sequence_cpu, std::nullopt, std::nullopt, output, std::nullopt, std::nullopt); + + cudaDeviceSynchronize(); + + assert(status == nullptr); +} + #endif \ No newline at end of file diff --git a/transformers/__init__.py b/transformers/__init__.py new file mode 100644 index 000000000..90c03fa9b --- /dev/null +++ b/transformers/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) diff --git a/transformers/dynamo_onnx_helper.py b/transformers/dynamo_onnx_helper.py new file mode 100644 index 000000000..9a66afe3a --- /dev/null +++ b/transformers/dynamo_onnx_helper.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +import onnx + + +class DynamoOnnxHelper: + """ + Helper class for processing ONNX models exported by torch Dynamo. + """ + + def __init__(self, model: onnx.ModelProto): + self.model = model + + def update_edges(self, edge_mapping: dict) -> None: + """ + Updates the edges in the model according to the given mapping. + """ + for node in self.model.graph.node: + for i in range(len(node.input)): + if node.input[i] in edge_mapping: + node.input[i] = edge_mapping[node.input[i]] + for i in range(len(node.output)): + if node.output[i] in edge_mapping: + node.output[i] = edge_mapping[node.output[i]] + + for graph_input in self.model.graph.input: + if graph_input.name in edge_mapping: + graph_input.name = edge_mapping[graph_input.name] + for graph_output in self.model.graph.output: + if graph_output.name in edge_mapping: + graph_output.name = edge_mapping[graph_output.name] + + def unroll_function(self, func_name: str) -> None: + """ + Unrolls the function with the given name in the model. 
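+        Illustrative usage (a sketch; the model path and function name below are made-up
+        examples, not artifacts produced by this change):
+
+            helper = DynamoOnnxHelper(onnx.load("model.onnx"))
+            helper.unroll_function("transformer_block_0")
+            onnx.save(helper.model, "model_unrolled.onnx")
+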
+ """ + logging.info(f"Unrolling function {func_name}...") + nodes_to_remove = [] + nodes_to_add = [] + edges_to_remove = [] + edges_to_add = [] + for node in self.model.graph.node: + if node.op_type == func_name: + nodes_to_remove.append(node) + edges_to_remove.extend(list(node.input) + list(node.output)) + + func_to_remove = None + for f in self.model.functions: + if f.name == func_name: + nodes_to_add.extend(list(f.node)) + edges_to_add.extend(list(f.input) + list(f.output)) + func_to_remove = f + + assert len(edges_to_remove) == len(edges_to_add) + + for node in nodes_to_remove: + self.model.graph.node.remove(node) + for node in nodes_to_add: + self.model.graph.node.append(node) + if func_to_remove is not None: + self.model.functions.remove(func_to_remove) + + edge_mapping = {} + for i in range(len(edges_to_remove)): + k = edges_to_remove[i] + v = edges_to_add[i] + if k != v: + edge_mapping[k] = v + + return self.update_edges(edge_mapping) + + def remove_function(self, func_name: str, input_id: int, output_id: int) -> None: + """ + Removes the function in the model. + """ + edge_mapping = {} + nodes_to_remove = [] + for node in self.model.graph.node: + if node.op_type.find(func_name) != -1: + edge_mapping[node.input[input_id]] = node.output[output_id] + nodes_to_remove.append(node) + for node in nodes_to_remove: + self.model.graph.node.remove(node) + + self.update_edges(edge_mapping) + + def remove_dropout_layer(self) -> None: + """ + Removes the dropout layer in the model. + """ + logging.info("Removing dropout layer...") + self.remove_function("Dropout", 0, 0) + + def remove_lm_head_layer(self) -> None: + """ + Removes the LM head layer in the model. + """ + logging.info("Removing LM head layer...") + # bugbug: need to copy the right vi over + self.remove_function("Linear_lm_head", 2, 0) diff --git a/transformers/float16.py b/transformers/float16.py new file mode 100644 index 000000000..48c79b1d5 --- /dev/null +++ b/transformers/float16.py @@ -0,0 +1,503 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +# This file is modified from https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py +# Modifications: +# (1) Update default value of min_positive_val and max_finite_val +# (2) keep_io_types can be list of names +# (3) convert initializers if needed to preserve precision +# (4) add force_fp16_initializers option +# (5) handle Resize and GroupNorm with mixed float inputs +# (6) allow convert_float_to_float16 to accept model path + +import itertools +import logging +import os +import tempfile +from typing import Dict + +import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper +from onnx.shape_inference import infer_shapes, infer_shapes_path +from packaging import version + +logger = logging.getLogger(__name__) + + +def _npfloat16_to_int(np_list): + """ + Convert numpy float16 to python int. + + :param np_list: numpy float16 list + :return int_list: python int list + """ + return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list] + + +def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): + """ + Convert float32 numpy array to float16 without changing sign or finiteness. 
+ Positive values less than min_positive_val are mapped to min_positive_val. + Positive finite values greater than max_finite_val are mapped to max_finite_val. + Similar for negative values. NaN, 0, inf, and -inf are unchanged. + """ + + def between(a, b, c): + return np.logical_and(a < b, b < c) + + if np_array[np.where(np_array > 0)].shape[0] > 0: + positive_max = np_array[np.where(np_array > 0)].max() + positive_min = np_array[np.where(np_array > 0)].min() + if positive_max >= max_finite_val: + logger.debug(f"the float32 number {positive_max} will be truncated to {max_finite_val}") + if positive_min <= min_positive_val: + logger.debug(f"the float32 number {positive_min} will be truncated to {min_positive_val}") + + if np_array[np.where(np_array < 0)].shape[0] > 0: + negative_max = np_array[np.where(np_array < 0)].max() + negative_min = np_array[np.where(np_array < 0)].min() + if negative_min <= -max_finite_val: + logger.debug(f"the float32 number {negative_min} will be truncated to {-max_finite_val}") + if negative_max >= -min_positive_val: + logger.debug(f"the float32 number {negative_max} will be truncated to {-min_positive_val}") + + np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array) + np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array) + np_array = np.where(between(max_finite_val, np_array, float("inf")), max_finite_val, np_array) + np_array = np.where(between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array) + return np.float16(np_array) + + +def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): + """Convert tensor float to float16. + + Args: + tensor (TensorProto): the tensor to convert. + min_positive_val (float, optional): minimal positive value. Defaults to 1e-7. + max_finite_val (float, optional): maximal finite value. Defaults to 1e4. + + Raises: + ValueError: input type is not TensorProto. + + Returns: + TensorProto: the converted tensor. 
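+    Example (a minimal sketch; the initializer below is made up for illustration):
+
+        w = numpy_helper.from_array(np.zeros((4, 4), dtype=np.float32), name="weight")
+        w = convert_tensor_float_to_float16(w)
+        assert w.data_type == TensorProto.FLOAT16  # raw_data now holds float16 bytes
+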
+ """ + + if not isinstance(tensor, TensorProto): + raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}") + + if tensor.data_type == TensorProto.FLOAT: + tensor.data_type = TensorProto.FLOAT16 + # convert float_data (float type) to float16 and write to int32_data + if tensor.float_data: + float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val) + int_list = _npfloat16_to_int(float16_data) + tensor.int32_data[:] = int_list + tensor.float_data[:] = [] + # convert raw_data (bytes type) + if tensor.raw_data: + # convert n.raw_data to float + float32_list = np.frombuffer(tensor.raw_data, dtype="float32") + # convert float to float16 + float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val) + # convert float16 to bytes and write back to raw_data + tensor.raw_data = float16_list.tobytes() + return tensor + + +def make_value_info_from_tensor(tensor): + shape = numpy_helper.to_array(tensor).shape + return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape) + + +DEFAULT_OP_BLOCK_LIST = [ + "ArrayFeatureExtractor", + "Binarizer", + "CastMap", + "CategoryMapper", + "DictVectorizer", + "FeatureVectorizer", + "Imputer", + "LabelEncoder", + "LinearClassifier", + "LinearRegressor", + "Normalizer", + "OneHotEncoder", + "RandomUniformLike", + "SVMClassifier", + "SVMRegressor", + "Scaler", + "TreeEnsembleClassifier", + "TreeEnsembleRegressor", + "ZipMap", + "NonMaxSuppression", + "TopK", + "RoiAlign", + "Range", + "CumSum", + "Min", + "Max", + "Upsample", +] + + +# Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices +# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. +ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]} + + +class InitializerTracker: + """Class for keeping track of initializer.""" + + def __init__(self, initializer: TensorProto): + self.initializer = initializer + self.fp32_nodes = [] + self.fp16_nodes = [] + + def add_node(self, node: NodeProto, is_node_blocked): + if is_node_blocked: + self.fp32_nodes.append(node) + else: + self.fp16_nodes.append(node) + + +def convert_float_to_float16( + model, + min_positive_val=5.96e-08, + max_finite_val=65504.0, + keep_io_types=False, + disable_shape_infer=False, + op_block_list=None, + node_block_list=None, + force_fp16_initializers=False, + force_fp16_inputs=None, + use_bfloat16_as_blocked_nodes_dtype=False, +): + """Convert tensor float type in the input ONNX model to tensor float16. + + Args: + model (ModelProto or str): The ONNX model or path of the model to convert. + min_positive_val (float, optional): minimal positive value. Defaults to 5.96e-08. + max_finite_val (float, optional): maximal finite value of float16. Defaults to 65504. + keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names. + If True, model inputs/outputs should be left as float32. + Defaults to False. + disable_shape_infer (bool, optional): Skips running onnx shape/type inference. + Useful if shape inference has been done. Defaults to False. + op_block_list (List[str], optional): List of op types to leave as float32. + Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`. + node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None. 
+ force_fp16_initializers(bool): force converting all float initializers to float16. + Default to false, which will convert only the one needed to avoid precision loss. + force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if + this script's preference it to keep them in float32. + Raises: + ValueError: input type is not ModelProto. + + Returns: + ModelProto: converted model. + """ + assert ( + min_positive_val >= 5.96e-08 + ), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05" + assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504" + + force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs + + if isinstance(model, str): + model_path = model + if version.parse(onnx.__version__) >= version.parse("1.8.0") and not disable_shape_infer: + # shape_infer_model_path should be in the same folder of model_path + with tempfile.NamedTemporaryFile(dir=os.path.dirname(model_path)) as tmpfile: + shape_infer_model_path = tmpfile.name + # infer_shapes_path can be used for model >2GB, and infer_shapes cannot. + infer_shapes_path(model_path, shape_infer_model_path) + model = onnx.load(shape_infer_model_path) + disable_shape_infer = True + else: + model = onnx.load(model_path) + + if not isinstance(model, ModelProto): + raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}") + + func_infer_shape = None + if not disable_shape_infer and version.parse(onnx.__version__) >= version.parse("1.2.0"): + try: + func_infer_shape = infer_shapes + finally: + pass + + # create blocklists + if op_block_list is None: + op_block_list = DEFAULT_OP_BLOCK_LIST + if node_block_list is None: + node_block_list = [] + op_block_list = set(op_block_list) + node_block_list = set(node_block_list) + + logger.debug( + f"fp16 parameters: min_positive_val={min_positive_val} max_finite_val={max_finite_val} keep_io_types={keep_io_types} disable_shape_infer={disable_shape_infer} op_block_list={op_block_list} node_block_list={node_block_list} force_fp16_initializers={force_fp16_initializers}" + ) + + # create a queue for BFS + queue = [] + value_info_list = [] + node_list = [] + + # Some operators (Like Resize or GroupNorm) have data type fixed as float for some input. + # When it is converted to float16, there are mixed types: some inputs are float32 and some are float16. + # This list keeps track of such nodes that are not in block list. 
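+    # For example, with the default ALWAYS_FLOAT_INPUTS = {"Resize": [2], ...}, input 2 of a Resize node
+    # (its scales) stays float32 while its data input becomes float16; a Cast back to float is inserted
+    # for such inputs further below. (Illustrative note on the existing mechanism, not new behavior.)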
+ mixed_float_type_node_list = [] + + # type inference on input model + if func_infer_shape is not None: + model = func_infer_shape(model) + queue.append(model) + name_mapping = {} + graph_io_to_skip = set() + io_casts = set() + + fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT] + fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT] + if isinstance(keep_io_types, list): + fp32_inputs = [n for n in fp32_inputs if n in keep_io_types] + fp32_outputs = [n for n in fp32_outputs if n in keep_io_types] + elif not keep_io_types: + fp32_inputs = [] + fp32_outputs = [] + + for i, n in enumerate(model.graph.input): + if n.name in fp32_inputs: + output_name = "graph_input_cast_" + str(i) + name_mapping[n.name] = output_name + graph_io_to_skip.add(n.name) + + node_name = "graph_input_cast" + str(i) + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(n) + new_value_info.name = output_name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16 + # add Cast node (from tensor(float) to tensor(float16) after graph input + new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)] + model.graph.node.extend(new_node) + value_info_list.append(new_value_info) + io_casts.add(node_name) + + for i, n in enumerate(model.graph.output): + if n.name in fp32_outputs: + input_name = "graph_output_cast_" + str(i) + name_mapping[n.name] = input_name + graph_io_to_skip.add(n.name) + + node_name = "graph_output_cast" + str(i) + # add Cast node (from tensor(float16) to tensor(float) before graph output + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(n) + new_value_info.name = input_name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16 + new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)] + model.graph.node.extend(new_node) + value_info_list.append(new_value_info) + io_casts.add(node_name) + + fp32_initializers: Dict[str, InitializerTracker] = {} + while queue: + next_level = [] + for q in queue: + # if q is model, push q.graph (GraphProto) + if isinstance(q, ModelProto): + next_level.append(q.graph) + # if q is model.graph, push q.node.attribute (AttributeProto) + if isinstance(q, GraphProto): + for n in q.initializer: # TensorProto type + if n.data_type == TensorProto.FLOAT: + assert n.name not in fp32_initializers + fp32_initializers[n.name] = InitializerTracker(n) + + for n in q.node: + # if n is in the block list (doesn't support float16), no conversion for the node, + # and save the node for further processing + if n.name in io_casts: + continue + for i in range(len(n.input)): + if n.input[i] in name_mapping: + n.input[i] = name_mapping[n.input[i]] + for i in range(len(n.output)): + if n.output[i] in name_mapping: + n.output[i] = name_mapping[n.output[i]] + + is_node_blocked = n.op_type in op_block_list or n.name in node_block_list + for i, input_name in enumerate(n.input): + if input_name in fp32_initializers: + # For Resize/GroupNorm, only the first input can be float16 + use_fp32_weight = is_node_blocked or ( + i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) + and i not in force_fp16_inputs_dict.get(n.op_type, []) + ) + fp32_initializers[input_name].add_node(n, use_fp32_weight) + + if is_node_blocked: + node_list.append(n) + else: + if n.op_type == "Cast": + for attr in n.attribute: + if attr.name == "to" and attr.i == TensorProto.FLOAT: + attr.i = TensorProto.FLOAT16 
+ break + + if n.op_type in [ + "EyeLike", + "Multinomial", + "RandomNormal", + "RandomNormalLike", + "RandomUniform", + "RandomUniformLike", + "SequenceEmpty", + "Bernoulli", + ]: + has_dtype = False + for attr in n.attribute: + if attr.name == "dtype": + has_dtype = True + if attr.i == TensorProto.FLOAT: + attr.i = TensorProto.FLOAT16 + + # The dtype attribute is optional and default is FLOAT in the following operators + # so we need add dtype attribute to specify the data type float16 + if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype: + n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)]) + + # For Resize/GroupNorm, attribute data type cannot be changed + if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict: + for attr in n.attribute: + next_level.append(attr) # noqa: PERF402 + else: + mixed_float_type_node_list.append(n) + + # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) + # and process node.attribute.t and node.attribute.tensors (TensorProto) + if isinstance(q, AttributeProto): + next_level.append(q.g) + for n in q.graphs: + next_level.append(n) # noqa: PERF402 + q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val)) + for n in q.tensors: + n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) # noqa: PLW2901 + # if q is graph, process input, output and value_info (ValueInfoProto) + if isinstance(q, GraphProto): + # Note that float initializers tracked by fp32_initializers will be processed later. + # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to + # tensor(float16) except map and seq(map). And save them in value_info_list for further processing + for n in itertools.chain(q.input, q.output, q.value_info): + if n.type.tensor_type.elem_type == TensorProto.FLOAT: + if n.name not in graph_io_to_skip: + n.type.tensor_type.elem_type = TensorProto.FLOAT16 + value_info_list.append(n) + if n.type.HasField("sequence_type"): + if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT: + if n.name not in graph_io_to_skip: + n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16 + value_info_list.append(n) + + queue = next_level + + for value in fp32_initializers.values(): + # By default, to avoid precision loss, do not convert an initializer to fp16 when it is used only by fp32 nodes. + if force_fp16_initializers or value.fp16_nodes: + value.initializer = convert_tensor_float_to_float16(value.initializer, min_positive_val, max_finite_val) + value_info_list.append(make_value_info_from_tensor(value.initializer)) + if value.fp32_nodes and not force_fp16_initializers: + logger.info( + "initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{}".format( + value.fp16_nodes + ) + ) + + # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. 
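+    # The loop below renames each affected input to "<node_name>_input_cast_<i>" and inserts a
+    # Cast(to=FLOAT) node in front of it, so only that specific input is fed back as float32.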
+ for node in mixed_float_type_node_list: + for i, input_name in enumerate(node.input): + if i not in ALWAYS_FLOAT_INPUTS[node.op_type] or i in force_fp16_inputs_dict.get(node.op_type, []): + continue + for value_info in value_info_list: + if input_name == value_info.name: + # create new value_info for current node's new input name + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(value_info) + output_name = node.name + "_input_cast_" + str(i) + new_value_info.name = output_name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + # add Cast node (from tensor(float16) to tensor(float) before current node + node_name = node.name + "_input_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] + model.graph.node.extend(new_node) + # change current node's input name + node.input[i] = output_name + break + + accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT + # process the nodes in block list that doesn't support tensor(float16) + for node in node_list: + # if input's name is in the value_info_list meaning input is tensor(float16) type, + # insert a float16 to float Cast node before the node, + # change current node's input name and create new value_info for the new name + for i in range(len(node.input)): + input_name = node.input[i] + for value_info in value_info_list: + if input_name == value_info.name: + # create new value_info for current node's new input name + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(value_info) + output_name = node.name + "_input_cast_" + str(i) + new_value_info.name = output_name + new_value_info.type.tensor_type.elem_type = accuracy_type + # add Cast node (from tensor(float16) to tensor(float) before current node + node_name = node.name + "_input_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] + model.graph.node.extend(new_node) + # change current node's input name + node.input[i] = output_name + break + # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float to + # float16 Cast node after the node, change current node's output name and create new value_info for the new name + for i in range(len(node.output)): + output = node.output[i] + for value_info in value_info_list: + if output == value_info.name: + # create new value_info for current node's new output + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(value_info) + input_name = node.name + "_output_cast_" + str(i) + new_value_info.name = input_name + new_value_info.type.tensor_type.elem_type = accuracy_type + # add Cast node (from tensor(float) to tensor(float16) after current node + node_name = node.name + "_output_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] + model.graph.node.extend(new_node) + # change current node's input name + node.output[i] = input_name + break + return model + + +def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): + """Measure the maximum absolute difference after converting a float tensor to float16.""" + if not isinstance(tensor, TensorProto): + raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}") + if tensor.data_type != TensorProto.FLOAT: + raise ValueError("Expected tensor data type is float.") + + float32_data = None + if tensor.float_data: + 
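+ # float32 values may be stored in the float_data field or packed into raw_data; both cases are handled here.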
float32_data = np.array(tensor.float_data) + + if tensor.raw_data: + float32_data = np.frombuffer(tensor.raw_data, dtype="float32") + + if float32_data is None: + raise RuntimeError("external data not loaded!") + + float16_data = convert_np_to_float16(float32_data, min_positive_val, max_finite_val) + return np.amax(np.abs(float32_data - np.float32(float16_data))) diff --git a/transformers/fusion_base.py b/transformers/fusion_base.py new file mode 100644 index 000000000..67f4f0b55 --- /dev/null +++ b/transformers/fusion_base.py @@ -0,0 +1,137 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import defaultdict +from logging import getLogger +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +from onnx import NodeProto, helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class Fusion: + """ + Base class for Graph Fusion + """ + + def __init__( + self, + model: OnnxModel, + fused_op_type: str, + search_op_types: Union[str, List[str]], + description: str = "", + ): + self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types + self.fused_op_type: str = fused_op_type + self.description: str = f"{fused_op_type}({description})" if description else fused_op_type + self.model: OnnxModel = model + self.nodes_to_remove: List = [] + self.nodes_to_add: List = [] + self.prune_graph: bool = False + self.node_name_to_graph_name: dict = {} + self.this_graph_name: Optional[str] = None + # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter. + self.fused_count: defaultdict = defaultdict(int) + + def increase_counter(self, fused_op_name: str): + """ + Increase counter of a fused operator. + """ + self.fused_count[fused_op_name] += 1 + + def fuse( + self, + node: NodeProto, + input_name_to_nodes: Dict[str, List[NodeProto]], + output_name_to_node: Dict[str, NodeProto], + ): + """Interface for fusion that starts from a node""" + raise NotImplementedError + + def apply(self): + """ + Apply graph fusion on the whole model graph. + It searched nodes of given operators, and start fusion on each of those nodes. + """ + logger.debug(f"start {self.description} fusion...") + input_name_to_nodes = self.model.input_name_to_nodes() + output_name_to_node = self.model.output_name_to_node() + + # This assumes that two search ops will not be fused at same time! 
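+ # Candidate nodes are collected per search op type; removals and additions recorded by fuse() are applied to the graph only after the scan completes.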
+ for search_op_type in self.search_op_types: + for node in self.model.get_nodes_by_op_type(search_op_type): + graph = self.model.get_graph_by_node(node) + if graph is None: + raise Exception("Can not find node in any graph") + self.this_graph_name = graph.name + self.fuse(node, input_name_to_nodes, output_name_to_node) + + op_list = [node.op_type for node in self.nodes_to_add] + if self.fused_count: + for key, value in self.fused_count.items(): + if value: + logger.info(f"Fused {key}: {value}") + else: + count = op_list.count(self.fused_op_type) + if count > 0: + logger.info(f"Fused {self.description}: {count}") + + self.model.remove_nodes(self.nodes_to_remove) + self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name) + + if self.prune_graph: + self.model.prune_graph() + elif self.nodes_to_remove or self.nodes_to_add: + self.model.update_graph() + + def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True): + if raw: + np_type = helper.tensor_dtype_to_np_dtype(data_type) + if not isinstance(vals, np.ndarray): + bytes = np.array(vals, dtype=np_type).tobytes() + else: + bytes = vals.astype(np_type).tobytes() + tensor = helper.make_tensor( + name=name, + data_type=data_type, + dims=dims, + vals=bytes, + raw=True, + ) + else: + tensor = helper.make_tensor( + name=name, + data_type=data_type, + dims=dims, + vals=vals, + raw=False, + ) + + self.model.add_initializer(tensor, self.this_graph_name) + return tensor + + def add_nodes_to_remove(self, nodes: List[NodeProto]): + # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths). + # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B + # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are + # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first. + # Since path A's shared nodes are removed, path B's shared nodes are not removed because they + # were previously removed for path A. This causes an error to print in remove_node that a node + # has failed to be removed. + # + # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`. + # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could + # be scenarios where the nodes need to be removed in a specific order and converting to a set would + # lose this order. + for node in nodes: + if node not in self.nodes_to_remove: + self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/transformers/fusion_options.py b/transformers/fusion_options.py new file mode 100644 index 000000000..edac1989e --- /dev/null +++ b/transformers/fusion_options.py @@ -0,0 +1,340 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from argparse import ArgumentParser +from enum import Enum + + +class AttentionMaskFormat: + # Build 1D mask indice (sequence length). It requires right side padding! Recommended for BERT model to get best performance. + MaskIndexEnd = 0 + + # For experiment only. Do not use it in production. 
+ MaskIndexEndAndStart = 1 + + # Raw attention mask with 0 means padding (or no attention) and 1 otherwise. + AttentionMask = 2 + + # No attention mask + NoMask = 3 + + +class AttentionOpType(Enum): + Attention = "Attention" + MultiHeadAttention = "MultiHeadAttention" + GroupQueryAttention = "GroupQueryAttention" + PagedAttention = "PagedAttention" + + def __str__(self): + return self.value + + # Override __eq__ to return string comparison + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value + + +class FusionOptions: + """Options of fusion in graph optimization""" + + def __init__(self, model_type): + self.enable_gelu = True + self.enable_layer_norm = True + self.enable_attention = True + self.enable_rotary_embeddings = True + + # Use MultiHeadAttention instead of Attention operator. The difference: + # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is + # merged into one. + # (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention. + self.use_multi_head_attention = False + self.disable_multi_head_attention_bias = False + + self.enable_skip_layer_norm = True + self.enable_embed_layer_norm = True + self.enable_bias_skip_layer_norm = True + self.enable_bias_gelu = True + self.enable_gelu_approximation = False + self.enable_qordered_matmul = True + + self.enable_shape_inference = True + self.enable_gemm_fast_gelu = False + self.group_norm_channels_last = True + + if model_type == "clip": + self.enable_embed_layer_norm = False + + # Set default to sequence length for BERT model to use fused attention to speed up. + # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. 
+ self.attention_mask_format = AttentionMaskFormat.AttentionMask + if model_type == "bert": + self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd + elif model_type == "vit": + self.attention_mask_format = AttentionMaskFormat.NoMask + + self.attention_op_type = None + + # options for stable diffusion + if model_type in ["unet", "vae", "clip"]: + self.enable_nhwc_conv = True + self.enable_group_norm = True + self.enable_skip_group_norm = True + self.enable_bias_splitgelu = True + self.enable_packed_qkv = True + self.enable_packed_kv = True + self.enable_bias_add = True + + def use_raw_attention_mask(self, use_raw_mask=True): + if use_raw_mask: + self.attention_mask_format = AttentionMaskFormat.AttentionMask + else: + self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd + + def disable_attention_mask(self): + self.attention_mask_format = AttentionMaskFormat.NoMask + + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attention_op_type = attn_op_type + + @staticmethod + def parse(args): + options = FusionOptions(args.model_type) + if args.disable_gelu: + options.enable_gelu = False + if args.disable_layer_norm: + options.enable_layer_norm = False + if args.disable_rotary_embeddings: + options.enable_rotary_embeddings = False + if args.disable_attention: + options.enable_attention = False + if args.use_multi_head_attention: + options.use_multi_head_attention = True + if args.disable_skip_layer_norm: + options.enable_skip_layer_norm = False + if args.disable_embed_layer_norm: + options.enable_embed_layer_norm = False + if args.disable_bias_skip_layer_norm: + options.enable_bias_skip_layer_norm = False + if args.disable_bias_gelu: + options.enable_bias_gelu = False + if args.enable_gelu_approximation: + options.enable_gelu_approximation = True + if args.disable_shape_inference: + options.enable_shape_inference = False + if args.enable_gemm_fast_gelu: + options.enable_gemm_fast_gelu = True + if args.use_mask_index: + options.use_raw_attention_mask(False) + if args.use_raw_attention_mask: + options.use_raw_attention_mask(True) + if args.no_attention_mask: + options.disable_attention_mask() + + if args.model_type in ["unet", "vae", "clip"]: + if args.use_group_norm_channels_first: + options.group_norm_channels_last = False + if args.disable_nhwc_conv: + options.enable_nhwc_conv = False + if args.disable_group_norm: + options.enable_group_norm = False + if args.disable_skip_group_norm: + options.enable_skip_group_norm = False + if args.disable_bias_splitgelu: + options.enable_bias_splitgelu = False + if args.disable_packed_qkv: + options.enable_packed_qkv = False + if args.disable_packed_kv: + options.enable_packed_kv = False + if args.disable_bias_add: + options.enable_bias_add = False + + return options + + @staticmethod + def add_arguments(parser: ArgumentParser): + parser.add_argument( + "--disable_attention", + required=False, + action="store_true", + help="disable Attention fusion", + ) + parser.set_defaults(disable_attention=False) + + parser.add_argument( + "--disable_skip_layer_norm", + required=False, + action="store_true", + help="disable SkipLayerNormalization fusion", + ) + parser.set_defaults(disable_skip_layer_norm=False) + + parser.add_argument( + "--disable_embed_layer_norm", + required=False, + action="store_true", + help="disable EmbedLayerNormalization fusion", + ) + parser.set_defaults(disable_embed_layer_norm=False) + + parser.add_argument( + "--disable_bias_skip_layer_norm", + required=False, + action="store_true", + help="disable Add 
Bias and SkipLayerNormalization fusion", + ) + parser.set_defaults(disable_bias_skip_layer_norm=False) + + parser.add_argument( + "--disable_bias_gelu", + required=False, + action="store_true", + help="disable Add Bias and Gelu/FastGelu fusion", + ) + parser.set_defaults(disable_bias_gelu=False) + + parser.add_argument( + "--disable_layer_norm", + required=False, + action="store_true", + help="disable LayerNormalization fusion", + ) + parser.set_defaults(disable_layer_norm=False) + + parser.add_argument( + "--disable_gelu", + required=False, + action="store_true", + help="disable Gelu fusion", + ) + parser.set_defaults(disable_gelu=False) + + parser.add_argument( + "--enable_gelu_approximation", + required=False, + action="store_true", + help="enable Gelu/BiasGelu to FastGelu conversion", + ) + parser.set_defaults(enable_gelu_approximation=False) + + parser.add_argument( + "--disable_shape_inference", + required=False, + action="store_true", + help="disable symbolic shape inference", + ) + parser.set_defaults(disable_shape_inference=False) + + parser.add_argument( + "--enable_gemm_fast_gelu", + required=False, + action="store_true", + help="enable GemmfastGelu fusion", + ) + parser.set_defaults(enable_gemm_fast_gelu=False) + + parser.add_argument( + "--use_mask_index", + required=False, + action="store_true", + help="use mask index to activate fused attention to speed up. It requires right-side padding!", + ) + parser.set_defaults(use_mask_index=False) + + parser.add_argument( + "--use_raw_attention_mask", + required=False, + action="store_true", + help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.", + ) + parser.set_defaults(use_raw_attention_mask=False) + + parser.add_argument( + "--no_attention_mask", + required=False, + action="store_true", + help="no attention mask. Only works for model_type=bert", + ) + parser.set_defaults(no_attention_mask=False) + + parser.add_argument( + "--use_multi_head_attention", + required=False, + action="store_true", + help="Use MultiHeadAttention instead of Attention operator for testing purpose. " + "Note that MultiHeadAttention might be slower than Attention when qkv are not packed. ", + ) + parser.set_defaults(use_multi_head_attention=False) + + parser.add_argument( + "--disable_group_norm", + required=False, + action="store_true", + help="not fuse GroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_group_norm=False) + + parser.add_argument( + "--disable_skip_group_norm", + required=False, + action="store_true", + help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_skip_group_norm=False) + + parser.add_argument( + "--disable_packed_kv", + required=False, + action="store_true", + help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet", + ) + parser.set_defaults(disable_packed_kv=False) + + parser.add_argument( + "--disable_packed_qkv", + required=False, + action="store_true", + help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet", + ) + parser.set_defaults(disable_packed_qkv=False) + + parser.add_argument( + "--disable_bias_add", + required=False, + action="store_true", + help="not fuse BiasAdd. 
Only works for model_type=unet", + ) + parser.set_defaults(disable_bias_add=False) + + parser.add_argument( + "--disable_bias_splitgelu", + required=False, + action="store_true", + help="not fuse BiasSplitGelu. Only works for model_type=unet", + ) + parser.set_defaults(disable_bias_splitgelu=False) + + parser.add_argument( + "--disable_nhwc_conv", + required=False, + action="store_true", + help="Do not use NhwcConv. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_nhwc_conv=False) + + parser.add_argument( + "--use_group_norm_channels_first", + required=False, + action="store_true", + help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(use_group_norm_channels_first=False) + + parser.add_argument( + "--disable_rotary_embeddings", + required=False, + action="store_true", + help="Do not fuse rotary embeddings into RotaryEmbedding op", + ) diff --git a/transformers/fusion_skiplayernorm.py b/transformers/fusion_skiplayernorm.py new file mode 100644 index 000000000..1ec5edf68 --- /dev/null +++ b/transformers/fusion_skiplayernorm.py @@ -0,0 +1,200 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger + +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionSkipLayerNormalization(Fusion): + """ + Fuse Add + LayerNormalization into one node: SkipLayerNormalization + Note: This fusion does not check the input shape of Add and LayerNormalization. + """ + + def __init__( + self, + model: OnnxModel, + fused_op_type: str = "SkipLayerNormalization", + search_op_types: str = "LayerNormalization", + ): + super().__init__(model, fused_op_type, search_op_types) + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) + + if self.shape_infer_helper is None: + # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. 
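+            # fuse() relies on shape inference to verify that both inputs of the residual Add have the same shape; without it the fusion is skipped.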
+            logger.warning("symbolic shape inference disabled or failed.")
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        add = self.model.get_parent(node, 0, output_name_to_node)
+
+        # In some models there is an input_ids -> Gather -> Add -> LayerNorm pattern, where one input of the
+        # Add node is an initializer with a fixed shape; such an Add should not be fused into SkipLayerNorm.
+        if add is None or add.op_type != "Add":
+            return
+
+        # The Add node should have exactly two inputs.
+        if len(add.input) != 2:
+            return
+
+        for add_input in add.input:
+            if self.model.get_initializer(add_input) is not None:
+                return
+
+        # To avoid an Add node having two LayerNormalization children, only fuse one SkipLayerNormalization.
+        if add in self.nodes_to_remove:
+            return
+
+        # Root Mean Square Layer Normalization
+        simplified = node.op_type == "SimplifiedLayerNormalization"
+
+        if self.shape_infer_helper is not None:
+            # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
+            if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
+                logger.debug(
+                    "skip SkipLayerNormalization fusion since the shapes of inputs (%s, %s) are not the same",
+                    add.input[0],
+                    add.input[1],
+                )
+                return
+        else:
+            logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
+            return
+
+        gather_path = self.model.match_parent_path(add, ["Gather"], [None])
+        if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
+            if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
+                return
+
+        # The residual Add before the LayerNormalization may produce an output that is consumed by other
+        # nodes or by a graph output besides the LayerNormalization itself. We can still apply the
+        # SkipLayerNormalization fusion, but the output of the Add must then be preserved and produced
+        # by SkipLayerNormalization.
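+        # SkipLayerNormalization exposes the residual sum as an optional (fourth) output, so the Add
+        # output can be preserved by routing it through the fused node.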
+ add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None + residual_add_has_multiple_consumers = ( + add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1 + ) + + outputs_to_keep = node.output + + if residual_add_has_multiple_consumers: + outputs_to_keep.extend([add.output[0]]) + + outputs = [node.output[0]] + + # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output + if residual_add_has_multiple_consumers: + outputs.extend(["", "", add.output[0]]) + + if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node): + self.nodes_to_remove.extend([add, node]) + + inputs = ( + [add.input[0], add.input[1], node.input[1], node.input[2]] + if not simplified + else [add.input[0], add.input[1], node.input[1]] + ) + normalize_node = helper.make_node( + self.fused_op_type, + inputs=inputs, + outputs=outputs, + name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"), + ) + normalize_node.domain = "com.microsoft" + + # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization + for att in node.attribute: + if att.name == "epsilon": + normalize_node.attribute.extend([att]) + + # Set default epsilon if no epsilon exists from layernorm + if len(normalize_node.attribute) == 0: + normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)]) + + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + + +class FusionBiasSkipLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias") + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + if len(node.input) != 4: + return + + return_indice = [] + nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice) + if nodes is not None: + (add, _matmul) = nodes + else: + # In case of fp16, we could have a Cast between the MatMul and the bias Add + return_indice = [] + nodes = self.model.match_parent_path( + node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice + ) + if nodes is not None: + (add, _cast, _matmul) = nodes + else: + return + + assert len(return_indice) == 2 or len(return_indice) == 3 + add_input_index = return_indice[0] + if add_input_index >= 2: + return + sln_input = add.input[return_indice[1]] + bias_input = add.input[1 - return_indice[1]] + skip_input = node.input[1 - add_input_index] + + # bias should be one dimension + initializer = self.model.get_initializer(bias_input) + if initializer is None: + return + bias_weight = NumpyHelper.to_array(initializer) + if bias_weight is None: + logger.debug("Bias weight not found") + return + if len(bias_weight.shape) != 1: + logger.debug("Bias weight is not 1D") + return + + subgraph_nodes = [node, add] + if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node): + logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe") + return + + self.nodes_to_remove.extend(subgraph_nodes) + inputs = [ + sln_input, + skip_input, + node.input[2], + node.input[3], + bias_input, + ] + new_node = helper.make_node( + "SkipLayerNormalization", + inputs=inputs, + outputs=node.output, + name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"), + ) + new_node.domain = "com.microsoft" + + 
# Pass the "epsilon" attribute from the SkipLayerNormalization node to the fused "add bias" SkipLayerNormalization node.
+        for att in node.attribute:
+            if att.name == "epsilon":
+                new_node.attribute.extend([att])
+
+        # Set a default epsilon if the SkipLayerNormalization node has none.
+        if len(new_node.attribute) == 0:
+            new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
diff --git a/transformers/fusion_utils.py b/transformers/fusion_utils.py
new file mode 100644
index 000000000..726c587ff
--- /dev/null
+++ b/transformers/fusion_utils.py
@@ -0,0 +1,307 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+from typing import Optional, Tuple
+
+import numpy
+from numpy import array_equal, ndarray
+from onnx import NodeProto, TensorProto, helper, numpy_helper
+from onnx import onnx_pb as onnx_proto
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionUtils:
+    def __init__(self, model: OnnxModel):
+        self.model: OnnxModel = model
+
+    def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
+        graph_input = self.model.find_graph_input(input_name)
+        if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
+            cast_output, cast_node = self.cast_input_to_int32(input_name)
+            logger.debug(f"Cast graph input {input_name} to int32")
+            return True, cast_output
+
+        logger.debug(f"Did not cast graph input {input_name} to int32: graph input found = {graph_input is not None}")
+        return False, input_name
+
+    def cast_input(self, input_name: str, target_type="int32"):
+        output_name = input_name + "_" + target_type
+
+        if target_type == "int32":
+            to_type = int(TensorProto.INT32)
+        elif target_type == "float32":
+            to_type = int(TensorProto.FLOAT)
+        elif target_type == "float16":
+            to_type = int(TensorProto.FLOAT16)
+        else:
+            raise ValueError(f"Invalid target_type: {target_type}")
+
+        cast_node = self.add_cast_node(input_name, to_type, output_name)
+
+        return output_name, cast_node
+
+    def add_cast_node(
+        self,
+        input_name: str,
+        to_type: int,
+        output_name: Optional[str] = None,
+        output_name_to_node=None,
+        graph_name: Optional[str] = None,
+    ):
+        if output_name is None:
+            output_name = input_name + f"_cast_to_{to_type}"
+
+        # Avoid consecutive Cast nodes.
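+        # If input_name is already produced by a Cast node, cast from that node's source instead of stacking two Cast nodes.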
+ inputs = [input_name] + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + if input_name in output_name_to_node: + parent_node = output_name_to_node[input_name] + if parent_node and parent_node.op_type == "Cast": + inputs = [parent_node.input[0]] + + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name]) + + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) + self.model.add_node(cast_node, graph_name=graph_name) + + return cast_node + + def cast_input_to_int32(self, input_name: str): + return self.cast_input(input_name, "int32") + + def remove_cast_int32(self, input_name: str): + input_name_to_nodes = self.model.input_name_to_nodes() + nodes = input_name_to_nodes[input_name] + for node in nodes: + if node.op_type == "Cast": + is_int32 = False + for att in node.attribute: + if att.name == "to" and att.i == int(TensorProto.INT32): + is_int32 = True + break + if is_int32: + output_name = node.output[0] + self.model.remove_node(node) + self.model.replace_input_of_all_nodes(output_name, input_name) + + @staticmethod + def update_node_input(node, i, new_input_name, input_name_to_nodes): + old_input_reference = 0 + if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]: + input_name_to_nodes[node.input[i]].remove(node) + old_input_reference = len(input_name_to_nodes[node.input[i]]) + + node.input[i] = new_input_name + + if new_input_name in input_name_to_nodes: + input_name_to_nodes[new_input_name].append(node) + else: + input_name_to_nodes[new_input_name] = [node] + + return old_input_reference + + @staticmethod + def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0): + """ + Before: + (input)-->parent-->node-->(output) + After: + (input)-->parent--> + | + +----->node-->(output) + + This function returns a flag whether the parent node can be removed. + """ + + old_input_name = node.input[node_input_index] + new_input_name = parent_node.input[parent_input_index] + old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes) + + # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore. + parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name) + + return parent_can_be_removed + + @staticmethod + def check_node_attribute(node, attribute_name: str, expected_value, default_value=None): + """Verify that a node has expected value for an attribute. + + Args: + node (NodeProto): a node to check + attribute_name (str): name of attribute + expected_value (Any): expected value of the attribute + default_value (Any, optional): default value if the attribute does not exist. Defaults to None. 
+ + Returns: + bool: whether the check is passed or not + """ + value = default_value + for attr in node.attribute: + if attr.name == attribute_name: + value = helper.get_attribute_value(attr) + + if isinstance(expected_value, list): + return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False) + else: + return value == expected_value + + @staticmethod + def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto): + """Transpose a 2-D INT8 TensorProto + Args: + tensor (TensorProto): tensor to be transposed + Returns: + tensor (TensorProto): transposed tensor + """ + if not isinstance(tensor, onnx_proto.TensorProto): + raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor)) + + if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8: + raise ValueError("Only INT8 2-D tensors can be transposed") + + if tensor.raw_data: + int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims) + int32_transposed_data = numpy.transpose(int32_data, [1, 0]) + tensor.raw_data = int32_transposed_data.tobytes() + + else: + raise ValueError("only raw buffer supported") + + return tensor + + @staticmethod + def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True): + """Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion. + It is a good candidate for fusion if: + (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True` + (2) The Q/DQ node should have constant scale + (3) The Q/DQ node should have a zero point of 0 + Args: + node (NodeProto): a Q/DQ node to check + Returns: + bool: whether the check is passed or not + """ + if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}: + logger.debug(f"Provided node is not a Q/DQ node. 
Op Type: {node.op_type}")
+
+        scale = model.get_constant_value(node.input[1])
+
+        # Scale is not constant
+        if scale is None:
+            return False
+
+        # Not per-tensor quantization
+        scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
+        if allow_per_tensor_quantization_only and not scale_has_single_element:
+            return False
+
+        # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
+        if len(node.input) == 2:
+            return True
+
+        # Zero point should be constant and should have a value of 0
+        zero_point = model.get_constant_value(node.input[2])
+
+        # Zero point is not constant
+        if zero_point is None:
+            return False
+
+        # Zero point and scale should have the same number of dims
+        if scale.ndim != zero_point.ndim:
+            return False
+
+        return numpy.all(zero_point == 0)
+
+    def check_node_input_value(self, node, input_index: int, expected_value):
+        """Verify that a node has the expected input value.
+
+        Args:
+            node (NodeProto): a node to check
+            input_index (int): index of its input to be verified
+            expected_value (Any): expected value of the input
+
+        Returns:
+            bool: whether the check is passed or not
+        """
+        assert len(node.input) > input_index
+
+        value = self.model.get_constant_value(node.input[input_index])
+
+        if isinstance(expected_value, list):
+            return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
+        else:
+            return value == expected_value
+
+    def remove_identity_nodes(self):
+        """Remove Identity nodes, except those right before graph output."""
+        nodes_to_remove = []
+        graph_output_names = self.model.get_graphs_output_names()
+        for node in self.model.nodes():
+            if node.op_type == "Identity":
+                if node.output[0] not in graph_output_names:
+                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
+                    nodes_to_remove.append(node)
+
+        if nodes_to_remove:
+            self.model.remove_nodes(nodes_to_remove)
+            logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
+
+    def remove_cascaded_cast_nodes(self):
+        self.model.remove_cascaded_cast_nodes()
+
+    def remove_useless_cast_nodes(self):
+        self.model.remove_useless_cast_nodes()
+
+    def remove_useless_reshape_nodes(self):
+        """Remove Reshape nodes that are not needed based on symbolic shape inference: input and output have the same shape."""
+        shape_infer = self.model.infer_runtime_shape(update=True)
+        if shape_infer is None:
+            return
+
+        nodes_to_remove = []
+        for node in self.model.nodes():
+            if node.op_type == "Reshape":
+                input_shape = shape_infer.get_edge_shape(node.input[0])
+                output_shape = shape_infer.get_edge_shape(node.output[0])
+                if input_shape and output_shape and input_shape == output_shape:
+                    logger.info(
+                        f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
+                    )
+                    nodes_to_remove.append(node)
+
+        if nodes_to_remove:
+            graph_input_names = set(self.model.get_graphs_input_names())
+            graph_output_names = set(self.model.get_graphs_output_names())
+            for node in nodes_to_remove:
+                if bool(set(node.output) & graph_output_names):
+                    if (
+                        not bool(set(node.input) & graph_input_names)
+                        and len(self.model.input_name_to_nodes()[node.input[0]]) == 1  # parent has only one child
+                    ):
+                        self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
+                    else:
+                        continue
+                else:
+                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
+                self.model.remove_node(node)
+
+
+class NumpyHelper:
+    @staticmethod
+    def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
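+        """Convert an ONNX TensorProto to a numpy ndarray.
+
+        When fill_zeros is True, return zeros with the tensor's shape and dtype without reading its weight data.
+        """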
+ # When weights are in external data format but not presented, we can still test the optimizer with two changes: + # (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py + if fill_zeros: + from onnx import mapping + + return ndarray( + shape=tensor.dims, + dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type], + ) + + return numpy_helper.to_array(tensor) diff --git a/transformers/inference_example.py b/transformers/inference_example.py new file mode 100644 index 000000000..de29bddea --- /dev/null +++ b/transformers/inference_example.py @@ -0,0 +1,217 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import numpy as np +import torch +from transformers import AutoTokenizer + +import onnxruntime as ort +from onnxruntime_extensions import get_library_path as _get_library_path + +pt_to_np = { + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.float32": np.float32, + "torch.float16": np.float16, +} + + +class ORTGenerator: + def __init__(self, decoder_path): + self.onnx_decoder_path = decoder_path + self.num_heads = 32 + self.head_size = 80 + self.num_layers = 32 + self.max_sequence_length = 2048 + + def get_initial_inputs_and_outputs(self, encodings_dict): + self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 + + input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) + attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) + step = torch.tensor([0], device=self.device, dtype=torch.int64) + + inputs = { + "input_ids": input_ids.contiguous(), + "attention_mask": attention_mask.contiguous(), + } + + if self.use_step: + inputs["step"] = step.contiguous() + + batch_size, sequence_length = input_ids.shape + + past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 + past_shape = ( + (2, batch_size, self.num_heads, past_seq_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, past_seq_length, self.head_size) + ) + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) + outputs = {"logits": logits.contiguous()} + + if not self.use_buffer_share: + present_shape = ( + (2, batch_size, self.num_heads, sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + return inputs, outputs + + def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict): + io_binding = model.io_binding() + device = None + + for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type=v.device.type, + device_id=0 if v.device.type == "cpu" else v.device.index, + 
element_type=pt_to_np[repr(v.dtype)], + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + device = v.device + + for output in model.get_outputs(): + name = output.name + if self.use_buffer_share and "present" in name: + v = inputs[name.replace("present", "past")] + io_binding.bind_output( + name=name, + device_type=v.device.type, + device_id=v.device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + else: + v = outputs[name] + io_binding.bind_output( + name=name, + device_type=device.type, + device_id=0 if device.type == "cpu" else device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + + return io_binding + + def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + sess_options = ort.SessionOptions() + sess_options.register_custom_ops_library(_get_library_path()) + ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + + self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.use_fp16 = use_fp16 + self.use_buffer_share = use_buffer_share + self.packed_kv = packed_kv + self.use_step = use_step + + self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + self.tokenizer.pad_token = "[PAD]" + + def generate(self, prompt, max_length): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) + + all_token_ids = inputs["input_ids"].clone() + batch_size, sequence_length = all_token_ids.shape + + current_length = sequence_length + has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + + while current_length < max_length: + io_binding = self.apply_io_binding(self.sess, inputs, outputs) + + io_binding.synchronize_inputs() + self.sess.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # Sample with argmax (greedy search) + next_token_logits = outputs["logits"][:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # Check if we previously reached EOS token id or if generated token id is EOS token id + has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id + + # Determine which new tokens to add to list of all token ids + # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't) + tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1]) + all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) + + # Return early if all batch entries have reached EOS token id + if torch.all(has_eos): + break + + # Update inputs for next inference run + current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_step: + inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) + inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( + torch.int32 + ) + + # Set logits to zeros for next inference run and re-use memory buffer + if outputs["logits"].shape[1] != 1: + outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + outputs["logits"].zero_() + + if not self.use_buffer_share: + for 
i in range(self.num_layers): + if not self.packed_kv: + inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"] + inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"] + else: + inputs[f"past_{i}"] = outputs[f"present_{i}"] + + new_sequence_length = inputs["attention_mask"].shape[1] + present_shape = ( + (2, batch_size, self.num_heads, new_sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, new_sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) + return texts + + +def run_phi2(onnx_model_path, use_buffer_share=True, device_id=0, packed_kv=False, use_fp16=True, use_step=True): + prompt = [ + '''```python + def print_prime(n): + """ + Print all primes between 1 and n + """''' + ] + + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) + texts = generator.generate(prompt, max_length=200) + + for i in range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) diff --git a/transformers/onnx_model.py b/transformers/onnx_model.py new file mode 100644 index 000000000..a8fc6e661 --- /dev/null +++ b/transformers/onnx_model.py @@ -0,0 +1,1524 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import itertools +import logging +import os +import sys +from collections import deque +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from float16 import convert_float_to_float16 +from onnx import ( + AttributeProto, + GraphProto, + ModelProto, + NodeProto, + TensorProto, + ValueInfoProto, + helper, + numpy_helper, + save_model, +) +from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data +from shape_infer_helper import SymbolicShapeInferenceHelper + +logger = logging.getLogger(__name__) + + +class OnnxModel: + def __init__(self, model): + self.initialize(model) + + def initialize(self, model): + self.model: ModelProto = model + self._node_name_suffix: Dict[str, int] = {} # key is node name prefix, value is the last suffix generated + self.shape_infer_helper: SymbolicShapeInferenceHelper = None + self.enable_shape_infer: bool = True + self.all_graphs: Optional[List[GraphProto]] = None + + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. 
+ self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + + def disable_shape_inference(self): + self.enable_shape_infer = False + + def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): # noqa: B006 + if self.enable_shape_infer: + if self.shape_infer_helper is None or update: + self.shape_infer_helper = SymbolicShapeInferenceHelper(self.model) + + try: + if self.shape_infer_helper.infer(dynamic_axis_mapping): + return self.shape_infer_helper + except Exception: + self.enable_shape_infer = False # disable shape inference to suppress same error message. + print("failed in shape inference", sys.exc_info()[0]) + + return None + + def input_name_to_nodes(self): + input_name_to_nodes = {} + for node in self.nodes(): + for input_name in node.input: + if input_name: # could be empty when it is optional + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) + return input_name_to_nodes + + def output_name_to_node(self): + output_name_to_node = {} + for node in self.nodes(): + for output_name in node.output: + if output_name: # could be empty when it is optional + output_name_to_node[output_name] = node + return output_name_to_node + + def functions(self): + all_functions = [list(self.model.functions)] + return all_functions + + def nodes(self): + all_nodes = [] + for graph in self.graphs(): + for node in graph.node: + all_nodes.append(node) # noqa: PERF402 + return all_nodes + + def graph(self): + return self.model.graph + + def graphs(self): + if self.all_graphs is not None: + return self.all_graphs + self.all_graphs = [] + graph_queue = [self.model.graph] + while graph_queue: + graph = graph_queue.pop(0) + self.all_graphs.append(graph) + for node in graph.node: + for attr in node.attribute: + if attr.type == AttributeProto.AttributeType.GRAPH: + assert isinstance(attr.g, GraphProto) + graph_queue.append(attr.g) + if attr.type == AttributeProto.AttributeType.GRAPHS: + for g in attr.graphs: + assert isinstance(g, GraphProto) + graph_queue.append(g) + return self.all_graphs + + def get_graphs_input_names(self): + input_names = [] + for graph in self.graphs(): + for input in graph.input: + input_names.append(input.name) + return input_names + + def get_graphs_output_names(self): + output_names = [] + for graph in self.graphs(): + for output in graph.output: + output_names.append(output.name) + return output_names + + def get_graph_by_node(self, node): + for graph in self.graphs(): + if node in graph.node: + return graph + return None + + def get_graph_by_name(self, graph_name): + for graph in self.graphs(): + if graph_name == graph.name: + return graph + return None + + def get_topological_insert_id(self, graph, outputs): + for idx, node in enumerate(graph.node): + for input in node.input: + if input in outputs: + return idx + return len(graph.node) + + def remove_node(self, node): + for graph in self.graphs(): + if node in graph.node: + graph.node.remove(node) + return + logger.warning("Failed to remove node %s", node) # It might be a bug to hit this line. 
+ + def remove_nodes(self, nodes_to_remove): + for node in nodes_to_remove: + self.remove_node(node) + + def add_node(self, node, graph_name=None): + if graph_name is None or graph_name == self.model.graph.name: + self.model.graph.node.extend([node]) + else: + graph = self.get_graph_by_name(graph_name) + insert_idx = self.get_topological_insert_id(graph, node.output) + graph.node.insert(insert_idx, node) + + def add_nodes(self, nodes_to_add, node_name_to_graph_name=None): + if node_name_to_graph_name is None: + self.model.graph.node.extend(nodes_to_add) + else: + for node in nodes_to_add: + graph_name = node_name_to_graph_name[node.name] + self.add_node(node, graph_name) + + def add_initializer(self, tensor, graph_name=None): + if graph_name is None or graph_name == self.model.graph.name: + self.model.graph.initializer.extend([tensor]) + else: + graph = self.get_graph_by_name(graph_name) + graph.initializer.extend([tensor]) + + def add_input(self, input, graph_name=None): + if graph_name is None or graph_name == self.model.graph.name: + self.model.graph.input.extend([input]) + else: + graph = self.get_graph_by_name(graph_name) + graph.input.extend([input]) + + @staticmethod + def replace_node_input(node, old_input_name, new_input_name): + assert isinstance(old_input_name, str) and isinstance(new_input_name, str) + for j in range(len(node.input)): + if node.input[j] == old_input_name: + node.input[j] = new_input_name + + def replace_input_of_all_nodes(self, old_input_name, new_input_name): + for node in self.model.graph.node: + OnnxModel.replace_node_input(node, old_input_name, new_input_name) + + @staticmethod + def replace_node_output(node, old_output_name, new_output_name): + assert isinstance(old_output_name, str) and isinstance(new_output_name, str) + for j in range(len(node.output)): + if node.output[j] == old_output_name: + node.output[j] = new_output_name + + def replace_output_of_all_nodes(self, old_output_name, new_output_name): + # This function shall be used carefully. For example: + # Add --[old_name]--> Cast ---> [new_name] + # | + # +----[old_name]--> Transpose --> + # If we want to remove the Cast node: replace output of Add to new_name is not enough; + # The input of Transpose shall also be updated to new_name. 
+ for node in self.model.graph.node: + OnnxModel.replace_node_output(node, old_output_name, new_output_name) + + def get_initializer(self, name): + for graph in self.graphs(): + for tensor in graph.initializer: + if tensor.name == name: + return tensor + return None + + def get_nodes_by_op_type(self, op_type): + nodes = [] + for node in self.nodes(): + if node.op_type == op_type: + nodes.append(node) + return nodes + + def get_children(self, node, input_name_to_nodes=None): + if input_name_to_nodes is None: + input_name_to_nodes = self.input_name_to_nodes() + + children = [] + for output in node.output: + if output in input_name_to_nodes: + for node in input_name_to_nodes[output]: + children.append(node) # noqa: PERF402 + return children + + def get_parents(self, node, output_name_to_node=None): + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + parents = [] + for input in node.input: + if input in output_name_to_node: + parents.append(output_name_to_node[input]) + return parents + + def get_parent(self, node, i, output_name_to_node=None): + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + if len(node.input) <= i: + return None + + input = node.input[i] + if input not in output_name_to_node: + return None + + return output_name_to_node[input] + + def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]): # noqa: B006 + """ + Find parent node based on constraints on op_type. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + for i, input in enumerate(node.input): + if input in output_name_to_node: + parent = output_name_to_node[input] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + else: + logger.debug(f"To find first {parent_op_type}, current {parent.op_type}") + return None, None + + def match_parent( + self, + node, + parent_op_type, + input_index=None, + output_name_to_node=None, + exclude=[], # noqa: B006 + return_indice=None, + ): + """ + Find parent node based on constraints on op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. 
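+ None if no parent node satisfies the constraints.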
+ """ + assert node is not None + assert input_index is None or input_index >= 0 + + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + logger.debug(f"input_index {input_index} >= node inputs {len(node.input)}") + return None + + parent = self.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + if parent is not None: + logger.debug(f"Expect {parent_op_type}, Got {parent.op_type}") + + return None + + def match_parent_paths(self, node, paths, output_name_to_node): + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + return i, matched, return_indice + return -1, None, None + + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + + def match_parent_path( + self, + node, + parent_op_types, + parent_input_index=None, + output_name_to_node=None, + return_indice=None, + ): + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_types (str): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + parents: a list of matched parent node. 
+ """ + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i] if parent_input_index is not None else None, + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + if parent_input_index is not None: + logger.debug( + f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", + stack_info=True, + ) + else: + logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True) + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, recursive=True): + children = self.get_children(node, input_name_to_nodes) + dq = deque(children) + while len(dq) > 0: + current_node = dq.pop() + if current_node.op_type == child_type: + return current_node + + if recursive: + children = self.get_children(current_node, input_name_to_nodes) + for child in children: + dq.appendleft(child) + + return None + + def match_child_path( + self, + node, + child_op_types, + child_output_index=None, + return_indice=None, + exclude=[], # noqa: B006 + ): + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + child_op_types (str): constraint of child node op_type of each input edge. + child_output_index (list): constraint of input index of each input edge. None means no constraint. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + children: a list of matched children node. 
+ """ + if child_output_index is not None: + assert len(child_output_index) == len(child_op_types) + + current_node = node + matched_children = [] + for i, op_type in enumerate(child_op_types): + matched_child = None + node_children = self.get_children(current_node) + for child_i, child in enumerate(node_children): + if child.op_type == op_type and child not in exclude: + if child_output_index is not None and child_output_index[i] != child_i: + logger.debug( + f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", + stack_info=True, + ) + return None + matched_child = child + if matched_child is None: + logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + return None + + matched_children.append(matched_child) + current_node = matched_child + return matched_children + + def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + parents = self.get_parents(node, output_name_to_node) + dq = deque(parents) + while len(dq) > 0: + current_node = dq.pop() + if current_node.op_type == parent_type: + return current_node + + if recursive: + parents = self.get_parents(current_node, output_name_to_node) + for parent in parents: + dq.appendleft(parent) + + return None + + def get_constant_value(self, output_name): + for node in self.get_nodes_by_op_type("Constant"): + if node.output[0] == output_name: + for att in node.attribute: + if att.name == "value": + return numpy_helper.to_array(att.t) + + # Fall back to intializer since constant folding might have been applied. + initializer = self.get_initializer(output_name) + if initializer is not None: + return numpy_helper.to_array(initializer) + + return None + + def get_constant_input(self, node): + for i, input in enumerate(node.input): + value = self.get_constant_value(input) + if value is not None: + return i, value + + return None, None + + def find_constant_input(self, node, expected_value, delta=0.000001): + i, value = self.get_constant_input(node) + if value is not None and value.size == 1 and abs(value - expected_value) < delta: + return i + + return -1 + + def is_constant_with_specified_dimension(self, output_name, dimensions, description): + value = self.get_constant_value(output_name) + if value is None: + logger.debug(f"{description} {output_name} is not initializer.") + return False + + if len(value.shape) != dimensions: + logger.debug(f"{description} {output_name} shall have {dimensions} dimensions. 
Got shape {value.shape}") + return False + + return True + + def has_constant_input(self, node, expected_value, delta=0.000001): + return self.find_constant_input(node, expected_value, delta) >= 0 + + def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes=None): + if input_name_to_nodes is None: + input_name_to_nodes = self.input_name_to_nodes() + + children = input_name_to_nodes[root_node.output[0]] + + unique_nodes = [] + + dq = deque(children) + while len(dq) > 0: + current_node = dq.pop() + if current_node in stop_nodes: + continue + + if current_node not in unique_nodes: + unique_nodes.append(current_node) + + for output in current_node.output: + if output in input_name_to_nodes: + children = input_name_to_nodes[output] + for child in children: + dq.appendleft(child) + + return unique_nodes + + def tensor_shape_to_list(self, tensor_type): + """Convert tensor shape to list""" + shape_list = [] + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + shape_list.append(d.dim_value) # known dimension + elif d.HasField("dim_param"): + shape_list.append(d.dim_param) # unknown dimension with symbolic name + else: + shape_list.append("?") # shall not happen + return shape_list + + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type + + if name in self._dtype_dict: + return self._dtype_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None + + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None + + @staticmethod + def get_node_attribute(node: NodeProto, attribute_name: str): + for attr in node.attribute: + if attr.name == attribute_name: + value = helper.get_attribute_value(attr) + return value + return None + + def remove_cascaded_cast_nodes(self): + """Remove Cast node that are followed by another Cast node like --> Cast --> Cast --> + Note that this shall be used 
carefully since it might introduce semantic change. + For example, float -> int -> float could get different value than the original float value. + So, it is recommended to used only in post-processing of mixed precision conversion. + """ + output_name_to_node = self.output_name_to_node() + removed_count = 0 + for node in self.nodes(): + if node.op_type == "Cast": + parent = self.get_parent(node, 0, output_name_to_node=output_name_to_node) + if parent and parent.op_type == "Cast": + node.input[0] = parent.input[0] + removed_count += 1 + + if removed_count > 0: + logger.info("Removed %d cascaded Cast nodes", removed_count) + self.prune_graph() + + def remove_useless_cast_nodes(self): + """Remove cast nodes that are not needed: input and output has same data type.""" + shape_infer = self.infer_runtime_shape(update=True) + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might impact useless cast node detection.") + + nodes_to_remove = [] + for node in self.nodes(): + if node.op_type == "Cast": + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) + if input_dtype and input_dtype == output_dtype: + nodes_to_remove.append(node) + + if nodes_to_remove: + graph_input_names = set(self.get_graphs_input_names()) + graph_output_names = set(self.get_graphs_output_names()) + for node in nodes_to_remove: + if bool(set(node.output) & graph_output_names): + if (not bool(set(node.input) & graph_input_names)) and len( + self.input_name_to_nodes()[node.input[0]] + ) == 1: + self.replace_output_of_all_nodes(node.input[0], node.output[0]) + else: + continue + else: + self.replace_input_of_all_nodes(node.output[0], node.input[0]) + self.remove_node(node) + + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) + + def convert_model_float32_to_float16(self, cast_input_output=True): + logger.warning( + "The function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!" + ) + self.convert_float_to_float16(use_symbolic_shape_infer=True, keep_io_types=cast_input_output) + + def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): + """Convert a model to half (default) or mixed precision. + To use mixed precision, user need specify which graph inputs, outputs, operator type + or list of nodes shall keep in float32. + + Note that the conversion might not proceed without type information for the whole graph. + + By default, we use symbolic shape inference to get type information. The benefit of symbolic shape inference + is that it could handle fused operators in com.microsoft domain. Those operators cannot be handled in onnx shape + inference so symbolic shape inference is recommended for optimized model. + + When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled. + + Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable + symbolic shape inference. If your model is not optimized, you can also use model path to call + convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to + avoid the 2GB limit. + + Args: + use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. + Defaults to True. + keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names. 
+ If True, model inputs/outputs should be left as float32. + Defaults to True. + op_block_list (List[str], optional): List of operator types to leave as float32. + Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`. + node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None. + force_fp16_initializers(bool): force converting all float initializers to float16. + Default to false. + min_positive_val (float, optional): minimal positive value. Defaults to 1e-7. + max_finite_val (float, optional): maximal finite value. Defaults to 1e4. + force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if + this script's preference it to keep them in float32. + """ + if "keep_io_types" not in kwargs: + kwargs["keep_io_types"] = True + + model = self.model + if use_symbolic_shape_infer: + # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) + # are not recognized by onnx shape inference. + shape_infer_helper = SymbolicShapeInferenceHelper(model) + try: + model_with_shape = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) + + # auto_merge might cause issue (see https://github.com/microsoft/onnxruntime/issues/15521) + # we only merge tensor data type but not shape information back to the original onnx model. + # Note that float16 conversion need data type but not shape information. + if model_with_shape is not None: + name_vi = {} + for vi in model_with_shape.graph.value_info: + if ( + hasattr(vi.type, "tensor_type") + and hasattr(vi.type.tensor_type, "elem_type") + and vi.type.tensor_type.elem_type != TensorProto.UNDEFINED + and vi.name + ): + vi_copy = ValueInfoProto() + vi_copy.CopyFrom(vi) + if hasattr(vi_copy.type.tensor_type, "shape"): + vi_copy.type.tensor_type.ClearField("shape") + name_vi[vi.name] = vi_copy + for vi in model.graph.value_info: + if vi.name in name_vi: + del name_vi[vi.name] + for vi in name_vi.values(): + model.graph.value_info.append(vi) + except Exception: + logger.warning( + "Failed to run symbolic shape inference. Please file an issue in https://github.com/microsoft/onnxruntime." + ) + + parameters = {"disable_shape_infer": use_symbolic_shape_infer} + parameters.update( + { + key: kwargs[key] + for key in [ + "keep_io_types", + "min_positive_val", + "max_finite_val", + "op_block_list", + "node_block_list", + "force_fp16_initializers", + "force_fp16_inputs", + "use_bfloat16_as_blocked_nodes_dtype", + ] + if key in kwargs + } + ) + + fp16_model = convert_float_to_float16(model, **parameters) + self.initialize(fp16_model) + + self.remove_cascaded_cast_nodes() + + self.remove_useless_cast_nodes() + + def create_node_name(self, op_type, name_prefix=None): + """Create a unique node name that starts with a prefix (default is operator type). + The name will not be duplicated with any name that generated or existed in current graphs. + Args: + op_type (str): operator type + name_prefix (str, optional): prefix of node name. Defaults to None. + + Returns: + str: node name + """ + + if name_prefix: + prefix = name_prefix if name_prefix.endswith("_") else (name_prefix + "_") + else: + prefix = op_type + "_" + + suffix: int = 0 + if prefix in self._node_name_suffix: + suffix = self._node_name_suffix[prefix] + 1 + else: + # Check existed node name only once for a prefix + # as we assume create_node_name is called for every new node in fusion. 
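The suffix caching described in the comment above is easiest to see in isolation. A sketch with a graph that already contains a name using the `Gelu_` prefix (assumed import path; the pre-existing node name is arbitrary):

```python
from onnx import TensorProto, helper

from onnx_model import OnnxModel  # assumed import path

# The graph already contains a node whose *name* uses the "Gelu_" prefix.
graph = helper.make_graph(
    [helper.make_node("Relu", ["X"], ["Y"], name="Gelu_7")],
    "toy",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
)
model = OnnxModel(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]))

# The first call scans existing node names once and continues after the largest suffix
# found for the prefix; later calls only bump the cached counter.
assert model.create_node_name("Gelu") == "Gelu_8"
assert model.create_node_name("Gelu") == "Gelu_9"
assert model.create_node_name("Gelu", name_prefix="MyGelu") == "MyGelu_0"
```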
+ for node in self.nodes(): + if node.name and node.name.startswith(prefix): + try: + index = int(node.name[len(prefix) :]) + suffix = max(index + 1, suffix) + except ValueError: + continue + + # Record the generated suffix so that we can avoid generating duplicated name. + self._node_name_suffix[prefix] = suffix + + return prefix + str(suffix) + + def find_graph_input(self, input_name): + for input in self.model.graph.input: + if input.name == input_name: + return input + return None + + def find_graph_output(self, output_name): + for output in self.model.graph.output: + if output.name == output_name: + return output + return None + + def get_parent_subgraph_nodes(self, node, stop_nodes, output_name_to_node=None): + if output_name_to_node is None: + output_name_to_node = self.output_name_to_node() + + unique_nodes = [] + + parents = self.get_parents(node, output_name_to_node) + dq = deque(parents) + while len(dq) > 0: + current_node = dq.pop() + if current_node in stop_nodes: + continue + + if current_node not in unique_nodes: + unique_nodes.append(current_node) + + for input in current_node.input: + if input in output_name_to_node: + dq.appendleft(output_name_to_node[input]) + + return unique_nodes + + def get_graph_inputs(self, current_node, recursive=False): + """ + Find graph inputs that linked to current node. + """ + graph_inputs = [] + for input in current_node.input: + if self.find_graph_input(input) and input not in graph_inputs: + graph_inputs.append(input) + + if recursive: + parent_nodes = self.get_parent_subgraph_nodes(current_node, []) + for node in parent_nodes: + for input in node.input: + if self.find_graph_input(input) and input not in graph_inputs: + graph_inputs.append(input) + return graph_inputs + + @staticmethod + def input_index(node_output, child_node): + for index, input in enumerate(child_node.input): + if input == node_output: + return index + return -1 + + def remove_unused_constant(self): + input_name_to_nodes = self.input_name_to_nodes() + + # remove unused constant + unused_nodes = [] + nodes = self.nodes() + for node in nodes: + if node.op_type == "Constant" and node.output[0] not in input_name_to_nodes: + unused_nodes.append(node) + + self.remove_nodes(unused_nodes) + + if len(unused_nodes) > 0: + logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}") + + def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): + """ + Prune graph to keep only required outputs. It removes unnecessary nodes that are not linked + (directly or indirectly) to any required output. + + There is also an option to remove graph inputs that are not used to generate any required output. + + Args: + outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept. + allow_remove_graph_inputs (bool): allow remove graph inputs. + """ + + if len(self.graphs()) > 1: + # TODO(tianleiwu): handle subgraph + logger.debug("Skip prune_graph since graph has subgraph") + return + + keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs + + output_name_to_node = self.output_name_to_node() + + def get_first_output(node): + if node.output[0]: + return node.output[0] + return next(iter([o for o in node.output if o]), None) + + # Keep track of nodes to keep. The key is first output of node, and the value is the node. + output_to_node = {} + + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. 
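The traversal sketched in the comments here keeps only nodes that (directly or indirectly) feed a retained graph output. A toy example with one dead branch (assumed import path):

```python
from onnx import TensorProto, helper

from onnx_model import OnnxModel  # assumed import path

# "dead" consumes X but feeds no graph output, so pruning should drop it.
graph = helper.make_graph(
    [
        helper.make_node("Relu", ["X"], ["Y"], name="keep"),
        helper.make_node("Neg", ["X"], ["unused"], name="dead"),
    ],
    "toy",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
)
model = OnnxModel(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]))

model.prune_graph()  # keep only what is reachable (upstream) from the graph outputs

assert [n.name for n in model.model.graph.node] == ["keep"]
```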
+ dq = deque() + for output in keep_outputs: + if output in output_name_to_node: + dq.append(output_name_to_node[output]) + while len(dq) > 0: + node = dq.pop() + first_output = get_first_output(node) + if first_output and (first_output not in output_to_node): + output_to_node[first_output] = node + for name in node.input: + if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node): + dq.appendleft(output_name_to_node[name]) + + # Keep only those nodes in the output_to_node dictionary. + nodes_to_keep = [] + num_nodes_removed = 0 + for node in self.model.graph.node: + first_output = get_first_output(node) + kept_node = output_to_node.get(first_output) + + # Need double check the node since fused node might reuse output name of some nodes to be removed. + # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. + if kept_node and kept_node.op_type == node.op_type and kept_node == node: + nodes_to_keep.append(node) + else: + num_nodes_removed += 1 + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes_to_keep) + + # Remove graph outputs not in list + output_to_remove = [] + if outputs is not None: + for output in self.model.graph.output: + if output.name not in outputs: + output_to_remove.append(output) + for output in output_to_remove: + self.model.graph.output.remove(output) + + # Remove graph inputs not used by any node. + input_to_remove = [] + if allow_remove_graph_inputs: + input_name_to_nodes = self.input_name_to_nodes() + input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes] + for name in input_to_remove: + self.model.graph.input.remove(name) + + if input_to_remove or output_to_remove or num_nodes_removed > 0: + removed = [] + if input_to_remove: + removed.append(f"{len(input_to_remove)} inputs") + if output_to_remove: + removed.append(f"{len(output_to_remove)} outputs") + if num_nodes_removed > 0: + removed.append(f"{num_nodes_removed} nodes") + logger.info("Removed %s", ", ".join(removed)) + + self.update_graph() + + def update_graph(self, verbose=False, allow_remove_graph_inputs=False): + graph = self.model.graph + + remaining_input_names = [] + for node in graph.node: + if node.op_type in ["Loop", "Scan", "If"]: + # TODO: handle inner graph + logger.debug(f"Skip update_graph since graph has operator: {node.op_type}") + return + if node.op_type != "Constant": + for input_name in node.input: + if input_name not in remaining_input_names: + remaining_input_names.append(input_name) + if verbose: + logger.debug(f"remaining input names: {remaining_input_names}") + + # remove graph input that is not used + inputs_to_remove = [] + if allow_remove_graph_inputs: + for input in graph.input: + if input.name not in remaining_input_names: + inputs_to_remove.append(input) + for input in inputs_to_remove: + graph.input.remove(input) + + names_to_remove = [input.name for input in inputs_to_remove] + logger.debug(f"remove {len(inputs_to_remove)} unused inputs: {names_to_remove}") + + # remove weights that are not used + weights_to_remove = [] + weights_to_keep = [] + for initializer in graph.initializer: + if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name): + weights_to_remove.append(initializer) + else: + weights_to_keep.append(initializer.name) + for initializer in weights_to_remove: + graph.initializer.remove(initializer) + + names_to_remove = [initializer.name for initializer in weights_to_remove] + 
logger.debug(f"remove {len(weights_to_remove)} unused initializers: {names_to_remove}") + if verbose: + logger.debug(f"remaining initializers:{weights_to_keep}") + + self.remove_unused_constant() + + def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nodes, output_name_to_node): + for node_to_remove in nodes_to_remove: + for output_to_remove in node_to_remove.output: + if output_to_remove in keep_outputs: + continue + + if output_to_remove in input_name_to_nodes: + for impacted_node in input_name_to_nodes[output_to_remove]: + if impacted_node not in nodes_to_remove: + logger.debug( + "it is not safe to remove nodes since output %s is used by %s", + output_to_remove, + impacted_node, + ) + return False + return True + + @staticmethod + def graph_topological_sort(graph, is_deterministic=False): + deps_set = set() # dependency set of all node + sorted_node_set = set() # sorted node set + sorted_nodes = [] # initialize sorted_nodes + + initializer_names = [init.name for init in graph.initializer] + graph_input_names = [input.name for input in graph.input] + input_names = initializer_names + graph_input_names + + if is_deterministic: + input_names.sort() + + for input_name in input_names: + deps_set.add(input_name) + + sorted_node_set_len = -1 + graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name) + + last_node_name = None + while len(sorted_node_set) != len(graph_nodes): + if len(sorted_node_set) == sorted_node_set_len: + break + sorted_node_set_len = len(sorted_node_set) + for node_idx, node in enumerate(graph_nodes): + if node_idx in sorted_node_set: + continue + input_count = sum(1 for _ in node.input if _) + if input_count == 0: + sorted_nodes.append(node) + sorted_node_set.add(node_idx) + for output in node.output: + if output: + deps_set.add(output) + continue + failed = False + for input_name in node.input: + if input_name and input_name not in deps_set: + failed = True + last_node_name = node.name + if not failed: + sorted_nodes.append(node) + sorted_node_set.add(node_idx) + for output in node.output: + if output: + deps_set.add(output) + else: + continue + + if len(sorted_node_set) != len(graph.node): + raise RuntimeError( + f"Graph is not a DAG: len(sorted_node_set)={len(sorted_node_set)}, len(graph.node)={len(graph.node)}, failed at node {last_node_name}" + ) + + graph.ClearField("node") + graph.node.extend(sorted_nodes) + + def topological_sort(self, is_deterministic=False): + # TODO: support graph_topological_sort() in subgraphs + # for graph in self.graphs(): + # self.graph_topological_sort(graph) + OnnxModel.graph_topological_sort(self.model.graph, is_deterministic) + + @staticmethod + def save( + model, + output_path, + save_as_external_data=False, + all_tensors_to_one_file=True, + size_threshold=1024, + convert_attribute=False, + ): + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + # Add ms domain if needed + ms_opset = [opset for opset in model.opset_import if opset.domain == "com.microsoft"] + # Check whether there is custom op in top level graph (our fusion is on top level right now). + # May need to extend to subgraph if our fusion are extended to subgraphs. 
+ ms_node = [node for node in model.graph.node if node.domain == "com.microsoft"] + if ms_node and not ms_opset: + opset = model.opset_import.add() + opset.version = 1 + opset.domain = "com.microsoft" + + if save_as_external_data: + # Save model to external data, which is needed for model size > 2GB + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + external_data_path = output_path + ".data" + location = Path(external_data_path).name if all_tensors_to_one_file else None + + if os.path.exists(output_path): + logger.info(f"Delete the existing onnx file: {output_path}") + os.remove(output_path) + + if all_tensors_to_one_file: + if os.path.exists(external_data_path): + # Delete the external data file. Otherwise, data will be appended to existing file. + logger.info(f"Delete the existing external data file: {external_data_path}") + os.remove(external_data_path) + else: + if os.listdir(output_dir): + raise RuntimeError(f"Output directory ({output_dir}) for external data is not empty.") + + save_model( + model, + output_path, + save_as_external_data=True, + all_tensors_to_one_file=all_tensors_to_one_file, + location=location, + size_threshold=size_threshold, + convert_attribute=convert_attribute, + ) + else: + save_model(model, output_path) + + def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True): + logger.info("Sort graphs in topological order") + self.topological_sort() + + # Note: After the model is saved to another directory with external data, + # You need reload the onnx model if you want to read tensor from self.model object. + # It is because the base directory is not updated for self.model object so attempt to read tensor data + # might encounter error since external data cannot be located. + OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) + logger.info(f"Model saved to {output_path}") + + def get_graph_inputs_excluding_initializers(self): + """ + Returns real graph inputs (excluding initializers from older onnx model). + """ + graph_inputs = [] + for input in self.model.graph.input: + if self.get_initializer(input.name) is None: + graph_inputs.append(input) + return graph_inputs + + def get_opset_version(self): + """Get opset version of onnx domain + + Raises: + RuntimeError: ONNX model has no opset for default domain. + + Returns: + int: opset version of onnx domain. + """ + for opset in self.model.opset_import: + if opset.domain in ["", "ai.onnx"]: + return opset.version + raise RuntimeError("ONNX model has no opset for default domain") + + def get_operator_statistics(self, include_domain=False): + """ + Returns node count of operators. + """ + op_count = {} + for node in self.nodes(): + op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type + op_count[op] = 1 if op not in op_count else (op_count[op] + 1) + + # Sorted by count in the descending order, then by key in alphabetical order. + logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv:(-kv[1], kv[0]))}") + + return op_count + + @staticmethod + def to_data_hash(tensor: TensorProto, base_dir: str = "") -> int: + """Converts a tensor def object to a hash for data comparison purposes. + Args: + tensor: a TensorProto object. + base_dir: if external tensor exists, base_dir can help to find the path to it + Returns: + hash: a hash of the data. 
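A usage sketch for the save path and the small introspection helpers above; the temporary directory and tiny graph are stand-ins, and `use_external_data_format=True` is only required once a model approaches the 2GB protobuf limit.

```python
import tempfile
from pathlib import Path

from onnx import TensorProto, helper

from onnx_model import OnnxModel  # assumed import path

graph = helper.make_graph(
    [helper.make_node("Relu", ["X"], ["Y"], name="relu")],
    "toy",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
)
model = OnnxModel(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]))

print(model.get_opset_version())        # 14
print(model.get_operator_statistics())  # {'Relu': 1}

with tempfile.TemporaryDirectory() as tmp:
    # With use_external_data_format=True, tensor data goes to "<model>.onnx.data"
    # next to the model file, which is how >2GB models are persisted.
    output_path = str(Path(tmp) / "toy.onnx")
    model.save_model_to_file(output_path, use_external_data_format=True)
    assert Path(output_path).exists()
```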
+ """ + if tensor.HasField("segment"): + raise ValueError("Currently not supporting loading segments.") + if tensor.data_type == TensorProto.UNDEFINED: + raise TypeError("The element type in the input tensor is not defined.") + tensor_dtype = tensor.data_type + storage_field = helper.tensor_dtype_to_field(tensor_dtype) + + if tensor.data_type == TensorProto.STRING: + utf8_strings = getattr(tensor, storage_field) + return hash(tuple(s.decode("utf-8") for s in utf8_strings)) + # Load raw data from external tensor if it exists + if uses_external_data(tensor): + load_external_data_for_tensor(tensor, base_dir) + if tensor.HasField("raw_data"): + return hash(tensor.raw_data) + else: + np_data = numpy_helper.to_array(tensor) + return hash(np_data.tobytes()) + + @staticmethod + def has_same_value( + tensor1: TensorProto, + tensor2: TensorProto, + signature_cache1: Optional[dict] = None, + signature_cache2: Optional[dict] = None, + ) -> bool: + """Returns True when two tensors have same value. + Note that name can be different. + + Args: + tensor1 (TensorProto): initializer 1 + tensor2 (TensorProto): initializer 2 + signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. + signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. + Returns: + bool: True when two initializers has same value. + """ + sig1 = ( + signature_cache1[tensor1.name] + if signature_cache1 and tensor1.name in signature_cache1 + else OnnxModel.to_data_hash(tensor1) + ) + sig2 = ( + signature_cache2[tensor2.name] + if signature_cache2 and tensor2.name in signature_cache2 + else OnnxModel.to_data_hash(tensor2) + ) + if signature_cache1 is not None: + signature_cache1[tensor1.name] = sig1 + if signature_cache2 is not None: + signature_cache2[tensor2.name] = sig2 + if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: + # Same signature, now do the expensive check to confirm the data is the same + return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all() + + return False + + def remove_duplicated_initializer(self, cache: Optional[dict] = None): + """Remove initializers with duplicated values, and only keep the first one. + It could help reduce size of models (like ALBert) with shared weights. + If require_raw_data passed, method will only compare raw_data initializers to speed runtime + Note: this function does not process subgraph. + """ + if len(self.graphs()) > 1: + logger.warning("remove_duplicated_initializer does not process subgraphs.") + + initializer_count = len(self.model.graph.initializer) + + same = [-1] * initializer_count + for i in range(initializer_count - 1): + if same[i] >= 0: + continue + for j in range(i + 1, initializer_count): + if OnnxModel.has_same_value( + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, + ): + same[j] = i + + count = 0 + for i in range(initializer_count): + if same[i] >= 0: + count += 1 + self.replace_input_of_all_nodes( + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, + ) + + if count > 0: + self.update_graph() + print(f"Removed {count} initializers with duplicated value") + + def add_prefix_to_names(self, prefix: str): + """Add prefix to initializer or intermediate outputs in graph. Main graph inputs and outputs are excluded. + It could help avoid conflicting in name of node_args when merging two graphs. 
+ Note: this function does not process subgraph. + """ + if len(self.graphs()) > 1: + logger.warning("add_prefix_to_names does not process subgraphs.") + + # Exclude the names of inputs and outputs of main graph (but not subgraphs) + # and empty names ("") as they have special meaning to denote missing optional inputs + excluded = [i.name for i in self.model.graph.input] + [o.name for o in self.model.graph.output] + [""] + + for initializer in self.model.graph.initializer: + if initializer.name not in excluded: + if prefix + initializer.name not in excluded: + initializer.name = prefix + initializer.name + + for node in self.model.graph.node: + # update name of node inputs + for j in range(len(node.input)): + if node.input[j] not in excluded: + if prefix + node.input[j] not in excluded: + node.input[j] = prefix + node.input[j] + + # update name of node outputs + for j in range(len(node.output)): + if node.output[j] not in excluded: + if prefix + node.output[j] not in excluded: + node.output[j] = prefix + node.output[j] + + for value_info in self.model.graph.value_info: + if value_info.name not in excluded: + value_info.name = prefix + value_info.name + + def clean_shape_infer(self): + self.model.graph.ClearField("value_info") + + def use_float16(self): + """Check whether the model uses float16""" + queue = [] # queue for BFS + queue.append(self.model.graph) + while queue: + sub_graphs = [] + for graph in queue: + if not isinstance(graph, GraphProto): + continue + + for v in itertools.chain(graph.input, graph.output, graph.value_info): + if v.type.tensor_type.elem_type == TensorProto.FLOAT16: + return True + if v.type.HasField("sequence_type"): + if v.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT16: + return True + + for t in graph.initializer: + if t.data_type == TensorProto.FLOAT16: + return True + + for node in graph.node: + if node.op_type == "Cast": + for attr in node.attribute: + if attr.name == "to" and attr.i == TensorProto.FLOAT16: + return True + + for attr in node.attribute: + if attr.type == AttributeProto.GRAPH: + sub_graphs.append(attr.g) + + for g in attr.graphs: + sub_graphs.append(g) # noqa: PERF402 + + if isinstance(attr.t, TensorProto) and attr.t.data_type == TensorProto.FLOAT16: + return True + + for t in attr.tensors: + if isinstance(t, TensorProto) and t.data_type == TensorProto.FLOAT16: + return True + + queue = sub_graphs + + return False + + def change_graph_input_type( + self, + graph_input: ValueInfoProto, + new_type: int, + ): + """Change graph input type, and add Cast node if needed. + + Args: + graph_input (ValueInfoProto): input of the graph + new_type (int): new data type like TensorProto.INT32. + + Returns: + NodeProto: a new Cast node that added. None if Cast node is not added. + List[NodeProto]: Cast nodes that have been removed. + """ + assert isinstance(graph_input, ValueInfoProto) + assert self.find_graph_input(graph_input.name) + + if graph_input.type.tensor_type.elem_type == int(new_type): + return None, [] + + graph = self.graph() + new_cast_node = None + nodes_to_remove = [] + + input_name_to_nodes = self.input_name_to_nodes() + if graph_input.name in input_name_to_nodes: + nodes = input_name_to_nodes[graph_input.name] + + # For children that is not Cast node, insert a Cast node to convert int32 to original data type. 
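The graph-input retyping implemented here is typically used to expose int32 `input_ids` on a model that was exported with int64 inputs. A small sketch of the observable effect (assumed import path); the inserted Cast converts back to the original element type for consumers that are not themselves Cast nodes.

```python
from onnx import TensorProto, helper

from onnx_model import OnnxModel  # assumed import path

graph = helper.make_graph(
    [helper.make_node("Identity", ["input_ids"], ["Y"], name="consumer")],
    "toy",
    [helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch", "seq"])],
    [helper.make_tensor_value_info("Y", TensorProto.INT64, ["batch", "seq"])],
)
model = OnnxModel(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]))

graph_input = model.find_graph_input("input_ids")
cast_node, removed_casts = model.change_graph_input_type(graph_input, TensorProto.INT32)

# The graph input is now int32, and a new Cast back to int64 feeds the old consumer.
assert graph_input.type.tensor_type.elem_type == TensorProto.INT32
assert cast_node is not None and cast_node.op_type == "Cast"
assert removed_casts == []
```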
+ nodes_not_cast = [node for node in nodes if node.op_type != "Cast"] + if nodes_not_cast: + node_name = self.create_node_name("Cast") + output_name = node_name + "_" + graph_input.name + new_value_info = graph.value_info.add() + new_value_info.CopyFrom(graph_input) + new_value_info.name = output_name + new_cast_node = helper.make_node( + "Cast", + [graph_input.name], + [output_name], + to=int(graph_input.type.tensor_type.elem_type), + name=node_name, + ) + graph.node.extend([new_cast_node]) + + for node in nodes_not_cast: + OnnxModel.replace_node_input(node, graph_input.name, output_name) + + # For children that is Cast node, no need to insert Cast. + # When the children is Cast to int32, we can remove that Cast node since input type is int32 now. + nodes_cast = [node for node in nodes if node.op_type == "Cast"] + for node in nodes_cast: + if OnnxModel.get_node_attribute(node, "to") == int(new_type): + self.replace_input_of_all_nodes(node.output[0], graph_input.name) + if not self.find_graph_output(node.output[0]): + nodes_to_remove.append(node) + if nodes_to_remove: + self.remove_nodes(nodes_to_remove) + + graph_input.type.tensor_type.elem_type = int(new_type) + return new_cast_node, nodes_to_remove + + def change_graph_output_type( + self, + graph_output: ValueInfoProto, + new_type: int, + ): + """Change graph input type, and add Cast node if needed. + + Args: + graph_input (str | ValueInfoProto): output of the graph + new_type (int): new data type. + + Returns: + NodeProto: a new Cast node that added. None if Cast node is not added. + """ + assert isinstance(graph_output, ValueInfoProto) + assert self.find_graph_output(graph_output.name) + + if graph_output.type.tensor_type.elem_type == int(new_type): + return None + + cast_node = None + graph = self.graph() + + # Add a cast node + node_name = self.create_node_name("Cast") + input_name = node_name + "_" + graph_output.name + self.replace_input_of_all_nodes(graph_output.name, input_name) + new_value_info = graph.value_info.add() + new_value_info.CopyFrom(graph_output) + new_value_info.name = input_name + cast_node = helper.make_node( + "Cast", + [input_name], + [graph_output.name], + to=int(new_type), + name=node_name, + ) + graph.node.extend([cast_node]) + graph_output.type.tensor_type.elem_type = int(new_type) + return cast_node + + def rename_graph_output(self, old_name: str, new_name: str): + if new_name in self.output_name_to_node(): + raise RuntimeError("{new_name} exists in graph") + + graph = self.graph() + for output in graph.output: + if output.name == old_name: + logger.debug("replace output name from %s to %s", old_name, new_name) + self.replace_input_of_all_nodes(old_name, new_name) + self.replace_output_of_all_nodes(old_name, new_name) + output.name = new_name diff --git a/transformers/onnx_model_phi.py b/transformers/onnx_model_phi.py new file mode 100644 index 000000000..f7fc8ba99 --- /dev/null +++ b/transformers/onnx_model_phi.py @@ -0,0 +1,928 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import List, Optional + +import numpy as np +from dynamo_onnx_helper import DynamoOnnxHelper +from fusion_base import Fusion +from fusion_options import AttentionOpType, FusionOptions +from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization +from fusion_utils import NumpyHelper +from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class ProcessGemmWFunc: + def __call__(self, x): + return np.transpose(x, (1, 0)) + + +class ProcessMatMulQFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[0], (1, 0)) + + +class ProcessMatMulKFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[1], (1, 0)) + + +class ProcessMatMulVFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[2], (1, 0)) + + +class ProcessBiasQFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[0] + return x + + +class ProcessBiasKFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[1] + return x + + +class ProcessBiasVFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[2] + return x + + +class ProcessRotCacheFunc: + def __call__(self, x): + # half rotary embedding + assert len(x.shape) == 2 + if x.shape[1] == 32: + return x[:, 0:16] + return x + + +# TODO: move to a seperate file +class Fission(Fusion): + def __init__( + self, + model: OnnxModel, + nodes_to_find: List[str], + ): + super().__init__(model, "DONOTUSE", nodes_to_find) + + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attn_op_type = attn_op_type + + def get_uname(self, layer_id, name): + return name + "_" + str(layer_id) + + def get_edge_by_name(self, edges, name): + for edge in edges: + if edge == name or edge.endswith(name) or edge.startswith(name): + return edge + raise ValueError(f"Edge {name} not found") + + def get_input_by_name(self, node, name): + return self.get_edge_by_name(node.input, name) + + def get_output_by_name(self, node, name): + return self.get_edge_by_name(node.output, name) + + def process_initializer(self, initializer_name, functor, custom_name=None): + i = self.model.get_initializer(initializer_name) + i_np_array = NumpyHelper.to_array(i) + processed_i_np_array = functor(i_np_array) + new_tensor = helper.make_tensor( + initializer_name + "_processed" if custom_name is None else custom_name, + data_type=TensorProto.FLOAT, + dims=processed_i_np_array.shape, + vals=processed_i_np_array.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(new_tensor, self.this_graph_name) + return new_tensor.name + + def add_fp32_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + + def add_int64_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.INT64 + + def replace_fp32_value_info(self, name, shape): + for value_info in self.model.graph().value_info: + if value_info.name == name: + self.model.graph().value_info.remove(value_info) + break + new_value_info = helper.make_tensor_value_info( + name, + elem_type=TensorProto.FLOAT, + shape=shape, + ) + self.model.graph().value_info.extend([new_value_info]) + + def set_unique_name_and_add_nodes( + self, subgraph_nodes: List[NodeProto], 
layer_id: int, layer_known_edges_names: List[str] + ): + for new_node in subgraph_nodes: + for i, name in enumerate(new_node.input): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.input[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.input[i]) + for i, name in enumerate(new_node.output): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.output[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.output[i]) + new_node.name = self.get_uname(layer_id, new_node.name) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + node = helper.make_node( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + name=prefix + "_LayerNormalization", + epsilon=9.999999747378752e-06, + ) + return [node] + + def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + matmul = helper.make_node( + "MatMul", + inputs=[inputs[0], inputs[1]], + outputs=[prefix + "matmul_out"], + name=prefix + "MatMul", + ) + add = helper.make_node( + "Add", + inputs=[prefix + "matmul_out", inputs[2]], + outputs=outputs, + name=prefix + "Bias", + ) + return [matmul, add] + + def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32): + assert len(inputs) == 4 + assert len(outputs) == 1 + node = helper.make_node( + "RotaryEmbedding", + inputs=inputs, + outputs=outputs, + name=prefix + "RotaryEmbedding", + domain="com.microsoft", + rotary_embedding_dim=rot_dim, + num_heads=num_heads, + ) + return [node] + + def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 1 + assert len(outputs) == 1 + node = helper.make_node( + "FastGelu", + inputs=inputs, + outputs=outputs, + name=prefix + "FastGelu", + domain="com.microsoft", + ) + return [node] + + def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 2 + assert len(outputs) == 1 + node = helper.make_node( + "Add", + inputs=inputs, + outputs=outputs, + name=prefix + "Add", + ) + return [node] + + def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 8 + assert len(outputs) == 3 + node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "MultiHeadAttention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + ) + return [node] + + def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 7 + assert len(outputs) == 3 + node = helper.make_node( + "GroupQueryAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "GroupQueryAttention", + domain="ai.onnx.contrib", + num_heads=num_heads, + kv_num_heads=num_heads, + ) + return [node] + + def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 5 + assert len(outputs) == 2 + node = helper.make_node( + "Attention", + inputs=inputs, + outputs=outputs, + name=prefix + "Attention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + do_rotary=1, + rotary_embedding_dim=32, + ) + return [node] + + def paged_attn( + self, + inputs: List[str], + outputs: List[str], + prefix: str = "", + num_heads=32, + 
head_size=80, + scale=0.11180339753627777, + ): + assert len(inputs) == 6 + assert len(outputs) == 1 + node = helper.make_node( + "PagedAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "PagedAttention", + domain="vllm.ort.ext", + num_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_size, + scale=scale, + ) + return [node] + + +class Phi2PreProcessor(DynamoOnnxHelper): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.num_hidden_layers = 32 + self.num_attention_heads = num_heads + self.hidden_size = hidden_size + + self.func_name = "modeling_phi_PhiModel_model_1" + + def get_phi2_edge_dict(self) -> dict: + edge_dict = {} + edge_dict["lm_head_1"] = "logits" + edge_dict["l_input_ids_"] = "input_ids" + edge_dict["key_states"] = "past_key_0" + edge_dict["value_states"] = "past_value_0" + for i in range(1, self.num_hidden_layers, 1): + edge_dict[f"key_states_{i}"] = f"past_key_{i}" + edge_dict[f"value_states_{i}"] = f"past_value_{i}" + edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" + edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + + outputs = [o.name for o in self.model.graph.output] + if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs: + edge_dict["model_layers_0_1_1"] = "present_key_0" + edge_dict["model_layers_0_1_2"] = "present_value_0" + else: + assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs + edge_dict["model_layers_0_1"] = "present_key_0" + edge_dict["model_layers_0_1_1"] = "present_value_0" + return edge_dict + + def simplify_phi2_op_type(self): + phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers" + for node in self.model.graph.node: + index = node.op_type.find(phi2_transformer_layer_name) + if index != -1: + node.op_type = node.op_type[index:] + + def process_graph_io(self, attn_op_type: AttentionOpType): + self.use_attn = attn_op_type == AttentionOpType.Attention + self.use_vllm = attn_op_type == AttentionOpType.PagedAttention + graph = self.model.graph + new_inputs = [] + for vi in graph.input: + if "input_ids" in vi.name: + vi_iid = helper.make_tensor_value_info( + vi.name, + elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) + vi_step = helper.make_tensor_value_info( + "step", + elem_type=TensorProto.INT64, + shape=[1], + ) + vi_pid = helper.make_tensor_value_info( + "position_ids", + elem_type=TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) + vi_mask = helper.make_tensor_value_info( + "attention_mask", + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + vi_meta = helper.make_tensor_value_info( + "input_metadata", + elem_type=TensorProto.INT64, + shape=[1], + ) + new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend( + [vi_iid, vi_pid, vi_meta] + ) + if self.use_attn: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("past_key", "past"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + elif self.use_vllm: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"], + ) + new_inputs.extend([vi_cache]) + if "past_value" in vi.name: + vi_cache = 
helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "num_blocks", + "num_heads", + "head_size", + "block_size", + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + new_outputs = [] + for i, vi in enumerate(graph.output): + if i == 0: + new_outputs.extend([vi]) + else: + if self.use_attn: + if "present_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("present_key", "present"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + elif self.use_vllm: + pass + else: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + + graph.ClearField("output") + graph.output.extend(new_outputs) + + def preprocess_onnx(self, attn_op_type: AttentionOpType): + function_name = None + for func in self.model.functions: + if func.name.endswith(self.func_name): + function_name = func.name + break + assert function_name is not None + self.unroll_function(function_name) + self.update_edges(self.get_phi2_edge_dict()) + self.simplify_phi2_op_type() + self.remove_dropout_layer() + if attn_op_type == AttentionOpType.PagedAttention: + self.remove_lm_head_layer() + self.process_graph_io(attn_op_type) + + +class FissionTransformerEmbeddingPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 2 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + embedding = self.get_input_by_name(node, "embed_tokens.weight") + + layer_known_edges_names = [input, output, embedding] + + subgraph_nodes = [ + helper.make_node( + "Gather", + inputs=[embedding, input], + outputs=[output], + name="Embedding_Gather", + ), + ] + + self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names) + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerLayerNormPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 3 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + ln_weight = self.get_input_by_name(node, "final_layernorm.weight") + ln_bias = self.get_input_by_name(node, "final_layernorm.bias") + + layer_known_edges_names = [input, output, ln_weight, ln_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) 
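The fission passes below rebuild each decoder layer from primitive nodes and feed them initializers rewritten by the `Process*Func` helpers defined at the top of this file. A pure-NumPy sketch of those rewrites; the import assumes the sibling `fusion_*`/`dynamo_onnx_helper` modules from this directory are importable, and the toy `hidden` size stands in for phi-2's 32 heads × 80 = 2560.

```python
import numpy as np

from onnx_model_phi import (  # assumed import path; pulls in the sibling fusion modules
    ProcessBiasQFunc,
    ProcessGemmWFunc,
    ProcessMatMulQFunc,
    ProcessRotCacheFunc,
)

hidden = 8  # toy hidden size

# Exported Linear weights are (out_features, in_features); the MatMul/Gemm rewrite
# wants them transposed to (in_features, out_features).
w = np.random.rand(hidden, hidden).astype(np.float32)
assert np.array_equal(ProcessGemmWFunc()(w), w.T)

# A packed QKV projection has out_features == 3 * hidden; the Q slice is the first
# third of the packed weight (rows [0, hidden)), transposed the same way.
qkv_w = np.random.rand(3 * hidden, hidden).astype(np.float32)
assert np.array_equal(ProcessMatMulQFunc()(qkv_w), qkv_w[:hidden].T)

qkv_b = np.random.rand(3 * hidden).astype(np.float32)
assert np.array_equal(ProcessBiasQFunc()(qkv_b), qkv_b[:hidden])

# The exported rotary cache carries 32 columns; only the first half is kept for the
# partial rotary embedding, so a (seq, 32) cache becomes (seq, 16).
cache = np.random.rand(4, 32).astype(np.float32)
assert ProcessRotCacheFunc()(cache).shape == (4, 16)
```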
+ + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerCausalLMHeadPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 5 + assert len(node.output) == 1 + + input = node.input[2] + output = node.output[0] + + fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_input_by_name(node, "lm_head.bias") + + layer_known_edges_names = [input, output, fc_weight, fc_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerBlockPhi(Fission): + def __init__( + self, + model: OnnxModel, + num_heads: int, + ): + self.num_heads = num_heads + max_num_layers = 32 + self.func_to_layer_id = {} + nodes_to_find = [] + for layer in range(max_num_layers): + func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1" + nodes_to_find.append(func_name) + self.func_to_layer_id[func_name] = layer + + super().__init__(model, nodes_to_find) + + def get_layer_id(self, node): + return self.func_to_layer_id[node.op_type] + + def get_gqa_aux_nodes(self): + gqa_aux_nodes = [ + helper.make_node( + "Cast", + inputs=["attention_mask"], + outputs=["mask_int64"], + name="Cast_gqa_aux_0", + to=TensorProto.INT64, + ), + helper.make_node( + "ReduceSum", + inputs=["mask_int64", "one"], + outputs=["mask_row_sums"], + name="ReduceSum_gqa_aux", + ), + helper.make_node( + "Sub", + inputs=["mask_row_sums", "one"], + outputs=["seqlens_k_int64"], + name="Sub_gqa_aux", + ), + helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name="Cast_gqa_aux_1", + to=TensorProto.INT32, + ), + helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"), + helper.make_node( + "Gather", + inputs=["mask_shape", "one"], + outputs=["total_seq_len_int64"], + name="Gather_gqa_aux_0", + axis=0, + ), + helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_sequence_length"], + name="Cast_gqa_aux_2", + to=TensorProto.INT32, + ), + ] + return gqa_aux_nodes + + def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name): + q_weight = self.model.get_initializer(q_w) + k_weight = self.model.get_initializer(k_w) + v_weight = self.model.get_initializer(v_w) + qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0)) + kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0)) + vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0)) + qkv_weight = np.stack((qw, kw, vw), axis=1) + + q_bias = self.model.get_initializer(q_b) + k_bias = self.model.get_initializer(k_b) + v_bias = self.model.get_initializer(v_b) + qb = NumpyHelper.to_array(q_bias) + kb = NumpyHelper.to_array(k_bias) + vb = NumpyHelper.to_array(v_bias) + qkv_bias = np.stack((qb, kb, vb), axis=0) + + hidden_size = 
qkv_weight.shape[0] + + weight = helper.make_tensor( + weight_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size, hidden_size * 3], + vals=qkv_weight.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(weight, self.this_graph_name) + + bias = helper.make_tensor( + bias_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size * 3], + vals=qkv_bias.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(bias, self.this_graph_name) + + self.add_fp32_value_info(weight.name) + self.add_fp32_value_info(bias.name) + + return weight_name, bias_name + + def fuse( + self, + node, + input_name_to_nodes, + output_name_to_node, + ): + logger.info("Optimizing %s...", node.name) + + logger.info(f"AttentionOpType: {self.attn_op_type}") + + layer_id = self.get_layer_id(node) + + i_hidden_states = node.input[0] + i_key_cache = self.get_input_by_name(node, "past_key") + i_value_cache = self.get_input_by_name(node, "past_value") + + o_hidden_states = node.output[-1] + o_key_cache = self.get_output_by_name(node, "present_key") + o_value_cache = self.get_output_by_name(node, "present_value") + + ln_weight = self.get_input_by_name(node, "input_layernorm.weight") + ln_bias = self.get_input_by_name(node, "input_layernorm.bias") + + attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( + None, + None, + None, + None, + None, + None, + ) + attn_qkv_weight, attn_qkv_bias = None, None + cos_cache, sin_cache = None, None + + if self.attn_op_type != AttentionOpType.Attention: + attn_q_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + ) + attn_k_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + ) + attn_v_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + ) + attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias") + + cos_cache = self.process_initializer( + self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + ) + sin_cache = self.process_initializer( + self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + ) + else: + attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( + self.get_input_by_name(node, "self_attn.q_proj.weight"), + self.get_input_by_name(node, "self_attn.k_proj.weight"), + self.get_input_by_name(node, "self_attn.v_proj.weight"), + self.get_input_by_name(node, "self_attn.q_proj.bias"), + self.get_input_by_name(node, "self_attn.k_proj.bias"), + self.get_input_by_name(node, "self_attn.v_proj.bias"), + self.get_uname(layer_id, "attn_qkv_weight"), + self.get_uname(layer_id, "attn_qkv_bias"), + ) + + attn_out_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + ) + attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias") + + mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias") + + layer_known_edges_names = [] + layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) + 
layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache]) + layer_known_edges_names.extend([ln_weight, ln_bias]) + if self.attn_op_type != AttentionOpType.Attention: + layer_known_edges_names.extend( + [ + attn_q_weight, + attn_q_bias, + attn_k_weight, + attn_k_bias, + attn_v_weight, + attn_v_bias, + cos_cache, + sin_cache, + ] + ) + else: + layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias]) + layer_known_edges_names.extend( + [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias] + ) + layer_known_edges_names.extend( + ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"] + ) + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) + subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_")) + subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_")) + subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"])) + subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_")) + subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1")) + subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2")) + if self.attn_op_type != AttentionOpType.Attention: + subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + # vllm engine requires full position ids as the input + pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step" + subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_")) + if self.attn_op_type == AttentionOpType.MultiHeadAttention: + subgraph_nodes.extend( + self.mha( + ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + elif self.attn_op_type == AttentionOpType.GroupQueryAttention: + subgraph_nodes.extend( + self.gqa( + [ + "query_rot", + "key_rot", + "value", + i_key_cache, + i_value_cache, + "seqlens_k", + "total_sequence_length", + ], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + if layer_id == 0: + gqa_aux_nodes = self.get_gqa_aux_nodes() + for new_node in gqa_aux_nodes: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.model.add_initializer( + numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + ) + elif self.attn_op_type == AttentionOpType.PagedAttention: + subgraph_nodes.extend( + self.paged_attn( + ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"], + ["attn_out"], + ) + ) + else: + past_name = f"past_{layer_id}" + present_name = f"present_{layer_id}" + layer_known_edges_names.extend([past_name, present_name]) + subgraph_nodes.extend( + self.attention( + ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name] + ) + ) + + self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names) + + 
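# A small numpy sketch of what the auxiliary subgraph emitted by get_gqa_aux_nodes
# (added above for layer 0 in the GroupQueryAttention path) computes from the 2D
# attention_mask: Cast/ReduceSum/Sub yield per-row seqlens_k, and Shape/Gather/Cast
# yield total_sequence_length. The mask below is a toy, right-padded example.
import numpy as np

attention_mask = np.array([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]], dtype=np.int32)   # 1 = valid token, 0 = padding

mask_int64 = attention_mask.astype(np.int64)                   # Cast_gqa_aux_0
mask_row_sums = mask_int64.sum(axis=1, keepdims=True)          # ReduceSum_gqa_aux over axes "one"
seqlens_k = (mask_row_sums - 1).astype(np.int32)               # Sub_gqa_aux + Cast_gqa_aux_1
total_sequence_length = np.int32(mask_int64.shape[1])          # Shape_gqa_aux_0 + Gather + Cast_gqa_aux_2

print(seqlens_k.ravel())        # [2 4]  (index of the last valid token per row)
print(total_sequence_length)    # 5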
self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class PhiOnnxModel(OnnxModel): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size) + self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads) + self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self) + self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self) + self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert options is not None + attn_op_type = options.attention_op_type + + self.fission_transformer_block.set_attention_op_type(attn_op_type) + + self.phi2_preprocessor.preprocess_onnx(attn_op_type) + + self.fission_transformer_block.apply() + self.fission_transformer_layernorm.apply() + self.fission_causal_lm_head.apply() + self.fission_transformer_embedding.apply() + + super().prune_graph() + + # SLN ctor is placed here intentionally to delay the symbolic shape inference + self.fuse_sln = FusionSkipLayerNormalization(self) + self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self) + self.fuse_sln.apply() + self.fuse_bias_sln.apply() + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "MultiHeadAttention", + "GroupQueryAttention", + "PagedAttention", + "Gelu", + "BiasGelu", + "FastGelu", + "LayerNormalization", + "SkipLayerNormalization", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators: {op_count}") + return op_count + + def is_fully_optimized(self, fused_op_count=None): + """ + Returns True when the model is fully optimized. 
+ """ + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + attention = ( + op_count("Attention") + + op_count("MultiHeadAttention") + + op_count("GroupQueryAttention") + + op_count("PagedAttention") + ) + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + + is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention) + + if layer_norm == 0: + logger.debug("Layer Normalization not fused") + + if gelu == 0: + logger.debug("Gelu (or FastGelu) not fused") + + if attention == 0: + logger.warning("Attention (or MultiHeadAttention) not fused") + + return is_perfect diff --git a/transformers/optimize_GroupQueryAttention.py b/transformers/optimize_GroupQueryAttention.py new file mode 100644 index 000000000..2be92ea7e --- /dev/null +++ b/transformers/optimize_GroupQueryAttention.py @@ -0,0 +1,120 @@ +import argparse +import logging +import torch + +from onnx import ModelProto, load_model +from transformers import AutoConfig +from typing import Dict, List, Optional + +from fusion_options import FusionOptions, AttentionOpType +from onnx_model_phi import PhiOnnxModel + +# Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level +MODEL_TYPES = { + "phi": (PhiOnnxModel, "pytorch", 0), +} + +def optimize_by_fusion( + model: ModelProto, + model_type: str = "bert", + num_heads: int = 0, + hidden_size: int = 0, + optimization_options: Optional[FusionOptions] = None, +): + """Optimize Model by graph fusion logic. + + Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable + constant folding during exporting ONNX model, or run optimize_by_onnxruntime on the model first like optimize_model. + + For BERT model, num_heads and hidden_size are optional. For other model types, you need to specify these parameters. + + Args: + model (ModelProto): model object + model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. + num_heads (int, optional): number of attention heads. Defaults to 0. + 0 allows detect the parameter from graph automatically. + hidden_size (int, optional): hidden size. Defaults to 0. + 0 allows detect the parameter from graph automatically. + optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. + Defaults to None. + + Returns: + object of an optimizer class. + """ + if model_type not in ["bert", "swin", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0): + logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") + + if model_type not in MODEL_TYPES: + logger.warning(f"Unsupported model type: {model_type} for graph fusion, directly return model.") + return OnnxModel(model) + + (optimizer_class, producer, _) = MODEL_TYPES[model_type] + + if model.producer_name and producer != model.producer_name: + logger.warning( + f'Model producer not matched: Expected "{producer}", Got "{model.producer_name}".' + "Please specify correct --model_type parameter." 
+ ) + + if optimization_options is None: + optimization_options = FusionOptions(model_type) + + optimizer = optimizer_class(model, num_heads, hidden_size) + + optimizer.optimize(optimization_options) + + optimizer.topological_sort() + + optimizer.model.producer_name = "onnxruntime.transformers" + from onnxruntime import __version__ as onnxruntime_version + + optimizer.model.producer_version = onnxruntime_version + + return optimizer + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_onnx_path", type=str, default="./phi2_original.onnx", help="Input ONNX model") + parser.add_argument("--optimized_GQA_path", type=str, default="./phi2_decoder_fp16_gpu_sm8x.onnx", help="Optimized GQA model") + parser.add_argument("--skip_export", required=False, action="store_true", help="skip export phi2 model and optiomization") + parser.add_argument("--run_example", required=False, action="store_true", help="Run phi2 model example with ORT-extension") + args = parser.parse_args() + return args + +def optimize_to_GroupQueryAttention(input_onnx_path, optimized_GQA_path): + model = load_model(input_onnx_path) + phi_config = AutoConfig.from_pretrained("microsoft/phi-2", trust_remote_code=True, cache_dir="./cache") + optimization_options = FusionOptions("phi") + optimization_options.set_attention_op_type(AttentionOpType.GroupQueryAttention) + optimizer = optimize_by_fusion( + model, + "phi", + num_heads=phi_config.num_attention_heads, + hidden_size=phi_config.hidden_size, + optimization_options=optimization_options) + + node_block_list = ( + [ + "Attention_29", + "Attention_30", + "Attention_31", + ] + ) + logging.info("Converting onnx model to float16/bfloat16...") + optimizer.convert_float_to_float16( + keep_io_types=False, + node_block_list=node_block_list, + use_symbolic_shape_infer=True, + use_bfloat16_as_blocked_nodes_dtype=True, + ) + logging.info("Converting onnx model to float16/bfloat16 done.") + optimizer.save_model_to_file(optimized_GQA_path, use_external_data_format=True) + +if __name__ == "__main__": + args = parse_arguments() + if not args.skip_export: + optimize_to_GroupQueryAttention(args.input_onnx_path, args.optimized_GQA_path) + if args.run_example: + from inference_example import run_phi2 + run_phi2(onnx_model_path=args.optimized_GQA_path) + \ No newline at end of file diff --git a/transformers/shape_infer_helper.py b/transformers/shape_infer_helper.py new file mode 100644 index 000000000..14a24ecf4 --- /dev/null +++ b/transformers/shape_infer_helper.py @@ -0,0 +1,122 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
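# Example invocation of the optimizer script above, using the argparse defaults shown
# in parse_arguments (the paths are illustrative):
#
#   python transformers/optimize_GroupQueryAttention.py \
#       --input_onnx_path ./phi2_original.onnx \
#       --optimized_GQA_path ./phi2_decoder_fp16_gpu_sm8x.onnx \
#       --run_example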
+# -------------------------------------------------------------------------- + +import logging +import os +import sys +from typing import Dict + +# In ORT Package the symbolic_shape_infer.py is in ../tools +#file_path = os.path.dirname(__file__) +#if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")): +# sys.path.append(os.path.join(file_path, "../tools")) +#else: +# sys.path.append(os.path.join(file_path, "..")) + +from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402 + +logger = logging.getLogger(__name__) + + +class SymbolicShapeInferenceHelper(SymbolicShapeInference): + def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False): + super().__init__(int_max, auto_merge, guess_output_rank, verbose) + self.model_ = model + self.all_shapes_inferred_: bool = False + self.is_inferred_: bool = False + self.dynamic_axis_mapping_: Dict[str, int] = {} + + def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200): + """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided. + + Args: + dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4} + max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200. + + Returns: + bool: whether all shapes has been inferred or not. + """ + assert dynamic_axis_mapping is not None + + if self.is_inferred_ and self.dynamic_axis_mapping_ == dynamic_axis_mapping: + return self.all_shapes_inferred_ + + self.dynamic_axis_mapping_ = dynamic_axis_mapping + + self._preprocess(self.model_) + + count = 0 + while self.run_: + logger.debug(f"shape infer run {count}") + self.all_shapes_inferred_ = self._infer_impl() + count += 1 + if max_runs > 0 and count >= max_runs: + break + + self.is_inferred_ = True + return self.all_shapes_inferred_ + + def _get_sympy_shape(self, node, idx): + """Override it to ensure shape inference by giving the actual value of dynamic axis.""" + sympy_shape = [] + + shape = self._get_shape(node, idx) + if shape: + for dim in shape: + if isinstance(dim, str): + if dim in self.dynamic_axis_mapping_: + sympy_shape.append(self.dynamic_axis_mapping_[dim]) + elif dim in self.symbolic_dims_: + sympy_shape.append(self.symbolic_dims_[dim]) + else: + sympy_shape.append(sympy.Symbol(dim, integer=True)) + else: + assert dim is not None + sympy_shape.append(dim) + return sympy_shape + + def get_edge_shape(self, edge): + """Get shape of an edge. + + Args: + edge (str): name of edge + + Returns: + Optional[List[int]]: the shape, or None if shape is unknown + """ + assert self.all_shapes_inferred_ + if edge not in self.known_vi_: + print("Cannot retrieve the shape of " + str(edge)) + return None + + type_proto = self.known_vi_[edge].type + shape = get_shape_from_type_proto(type_proto) + + if shape is not None: + for i, dim in enumerate(shape): + if isinstance(dim, str) and dim in self.dynamic_axis_mapping_: + shape[i] = self.dynamic_axis_mapping_[dim] + + return shape + + def compare_shape(self, edge, edge_other): + """Compare shape of two edges. 
+ + Args: + edge (str): name of edge + edge_other (str): name of another edge + + Raises: + Exception: At least one shape is missed for edges to compare + + Returns: + bool: whether the shape is same or not + """ + assert self.all_shapes_inferred_ + shape = self.get_edge_shape(edge) + shape_other = self.get_edge_shape(edge_other) + if shape is None or shape_other is None: + raise Exception("At least one shape is missed for edges to compare") + return shape == shape_other diff --git a/transformers/symbolic_shape_infer.py b/transformers/symbolic_shape_infer.py new file mode 100755 index 000000000..4b56bc1e8 --- /dev/null +++ b/transformers/symbolic_shape_infer.py @@ -0,0 +1,2982 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# -*- coding: UTF-8 -*- +import argparse +import logging + +import numpy as np +import onnx +import sympy +from onnx import helper, numpy_helper, shape_inference +from packaging import version + +assert version.parse(onnx.__version__) >= version.parse("1.8.0") + +logger = logging.getLogger(__name__) + + +def get_attribute(node, attr_name, default_value=None): + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def get_dim_from_proto(dim): + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None # noqa: E721 + + +def is_sequence(type_proto): + cls_type = type_proto.WhichOneof("value") + assert cls_type in ["tensor_type", "sequence_type"] + return cls_type == "sequence_type" + + +def get_shape_from_type_proto(type_proto): + assert not is_sequence(type_proto) + if type_proto.tensor_type.HasField("shape"): + return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] + else: + return None # note no shape is different from shape without dim (scalar) + + +def get_elem_type_from_type_proto(type_proto): + if is_sequence(type_proto): + return type_proto.sequence_type.elem_type.tensor_type.elem_type + else: + return type_proto.tensor_type.elem_type + + +def get_shape_from_value_info(vi): + cls_type = vi.type.WhichOneof("value") + if cls_type is None: + return None + if is_sequence(vi.type): + if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type": + return get_shape_from_type_proto(vi.type.sequence_type.elem_type) + else: + return None + else: + return get_shape_from_type_proto(vi.type) + + +def make_named_value_info(name): + vi = onnx.ValueInfoProto() + vi.name = name + return vi + + +def get_shape_from_sympy_shape(sympy_shape): + return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] + + +def is_literal(dim): + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) + + +def handle_negative_axis(axis, rank): + assert axis < rank and axis >= -rank + return axis if axis >= 0 else rank + axis + + +def get_opset(mp, domain=None): + domain = domain or ["", "onnx", "ai.onnx"] + if type(domain) != list: # noqa: E721 + domain = [domain] + for opset in mp.opset_import: + if opset.domain in domain: + return opset.version + + return None + + +def as_scalar(x): + if type(x) == list: # noqa: E721 + assert len(x) == 1 + return x[0] + elif type(x) == np.ndarray: + return x.item() + else: + return x + + +def as_list(x, keep_none): + if type(x) == list: # noqa: E721 + return x + elif type(x) == np.ndarray: + return list(x) + elif keep_none and x is None: + return None + else: + 
return [x] + + +def sympy_reduce_product(x): + if type(x) == list: # noqa: E721 + value = sympy.Integer(1) + for v in x: + value = value * v + else: + value = x + return value + + +class SymbolicShapeInference: + def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): + self.dispatcher_ = { + "Add": self._infer_symbolic_compute_ops, + "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, + "AveragePool": self._infer_Pool, + "BatchNormalization": self._infer_BatchNormalization, + "Cast": self._infer_Cast, + "CategoryMapper": self._infer_CategoryMapper, + "Compress": self._infer_Compress, + "Concat": self._infer_Concat, + "ConcatFromSequence": self._infer_ConcatFromSequence, + "Constant": self._infer_Constant, + "ConstantOfShape": self._infer_ConstantOfShape, + "Conv": self._infer_Conv, + "CumSum": self._pass_on_shape_and_type, + "Div": self._infer_symbolic_compute_ops, + "Einsum": self._infer_Einsum, + "Expand": self._infer_Expand, + "Equal": self._infer_symbolic_compute_ops, + "Floor": self._infer_symbolic_compute_ops, + "Gather": self._infer_Gather, + "GatherElements": self._infer_GatherElements, + "GatherND": self._infer_GatherND, + "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, + "If": self._infer_If, + "Loop": self._infer_Loop, + "MatMul": self._infer_MatMul, + "MatMulInteger16": self._infer_MatMulInteger, + "MaxPool": self._infer_Pool, + "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, + "Min": self._infer_symbolic_compute_ops, + "MoE": self._pass_on_shape_and_type, + "Mul": self._infer_symbolic_compute_ops, + "NonMaxSuppression": self._infer_NonMaxSuppression, + "NonZero": self._infer_NonZero, + "OneHot": self._infer_OneHot, + "Pad": self._infer_Pad, + "Range": self._infer_Range, + "Reciprocal": self._pass_on_shape_and_type, + "ReduceSum": self._infer_ReduceSum, + "ReduceProd": self._infer_ReduceProd, + "Reshape": self._infer_Reshape, + "Resize": self._infer_Resize, + "Round": self._pass_on_shape_and_type, + "Scan": self._infer_Scan, + "ScatterElements": self._infer_ScatterElements, + "SequenceAt": self._infer_SequenceAt, + "SequenceInsert": self._infer_SequenceInsert, + "Shape": self._infer_Shape, + "Size": self._infer_Size, + "Slice": self._infer_Slice, + "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, + "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "Split": self._infer_Split, + "SplitToSequence": self._infer_SplitToSequence, + "Squeeze": self._infer_Squeeze, + "Sub": self._infer_symbolic_compute_ops, + "Tile": self._infer_Tile, + "TopK": self._infer_TopK, + "Transpose": self._infer_Transpose, + "Unsqueeze": self._infer_Unsqueeze, + "Where": self._infer_symbolic_compute_ops, + "ZipMap": self._infer_ZipMap, + "Neg": self._infer_symbolic_compute_ops, + # contrib ops: + "Attention": self._infer_Attention, + "BiasAdd": self._infer_BiasAdd, + "BiasGelu": self._infer_BiasGelu, + "BiasSplitGelu": self._infer_BiasSplitGelu, + "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, + "DequantizeLinear": self._infer_DequantizeLinear, + "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, + "FastGelu": self._infer_FastGelu, + "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, + "Gelu": self._infer_Gelu, + "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": 
self._infer_GemmFloat8, + "GroupNorm": self._infer_GroupNorm, + "GroupQueryAttention": self._infer_GroupQueryAttention, + "SkipGroupNorm": self._infer_SkipGroupNorm, + "LayerNormalization": self._infer_LayerNormalization, + "LongformerAttention": self._infer_LongformerAttention, + "MultiHeadAttention": self._infer_MultiHeadAttention, + "NhwcConv": self._infer_NhwcConv, + "PackedAttention": self._infer_PackedAttention, + "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "PagedAttention": self._infer_PagedAttention, + "PythonOp": self._infer_PythonOp, + "QuantizeLinear": self._infer_QuantizeLinear, + "QuickGelu": self._infer_FastGelu, + "RelativePositionBias": self._infer_RelativePositionBias, + "RemovePadding": self._infer_RemovePadding, + "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, + "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipLayerNormalization": self._infer_SkipLayerNormalization, + "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + } + self.aten_op_dispatcher_ = { + "embedding": self._infer_Gather, + "bitwise_or": self._infer_aten_bitwise_or, + "diagonal": self._infer_aten_diagonal, + "max_pool2d_with_indices": self._infer_aten_pool2d, + "max": self._infer_aten_minmax, + "min": self._infer_aten_minmax, + "multinomial": self._infer_aten_multinomial, + "unfold": self._infer_aten_unfold, + "argmax": self._infer_aten_argmax, + "avg_pool2d": self._infer_aten_pool2d, + "_adaptive_avg_pool2d": self._infer_aten_pool2d, + "numpy_T": self._infer_Transpose, + "native_group_norm": self._infer_aten_group_norm, + "upsample_nearest1d": self._infer_aten_upsample, + "upsample_nearest2d": self._infer_aten_upsample, + "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, + } + self.run_ = True + self.suggested_merge_ = {} + self.symbolic_dims_ = {} + self.input_symbols_ = {} + self.auto_merge_ = auto_merge + self.guess_output_rank_ = guess_output_rank + self.verbose_ = verbose + self.int_max_ = int_max + self.subgraph_id_ = 0 + self.prefix_ = prefix + + def _add_suggested_merge(self, symbols, apply=False): + assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols]) # noqa: E721 + symbols = set(symbols) + for k, v in self.suggested_merge_.items(): + if k in symbols: + symbols.remove(k) + symbols.add(v) + map_to = None + # if there is literal, map to it first + for s in symbols: + if is_literal(s): + map_to = s + break + # when no literals, map to input symbolic dims, then existing symbolic dims + if map_to is None: + for s in symbols: + if s in self.input_symbols_: + map_to = s + break + if map_to is None: + for s in symbols: + if type(self.symbolic_dims_[s]) == sympy.Symbol: + map_to = s + break + # when nothing to map to, use the shorter one + if map_to is None: + if self.verbose_ > 0: + logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) + symbols_list = list(symbols) + lens = [len(s) for s in symbols_list] + map_to = symbols_list[lens.index(min(lens))] + symbols.remove(map_to) + + for s in symbols: + if s == map_to: + continue + if is_literal(map_to) and is_literal(s): + assert int(map_to) == int(s) + self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to + for k, v in self.suggested_merge_.items(): + if v == s: + self.suggested_merge_[k] = map_to + if apply and self.auto_merge_: + self._apply_suggested_merge() + + def 
_apply_suggested_merge(self, graph_input_only=False): + if not self.suggested_merge_: + return + for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): + for d in i.type.tensor_type.shape.dim: + if d.dim_param in self.suggested_merge_: + v = self.suggested_merge_[d.dim_param] + if is_literal(v): + d.dim_value = int(v) + else: + d.dim_param = v + + def _preprocess(self, in_mp): + self.out_mp_ = onnx.ModelProto() + self.out_mp_.CopyFrom(in_mp) + self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)} + self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer} + self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)} + self.known_vi_.update( + { + i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)) + for i in self.out_mp_.graph.initializer + } + ) + + def _merge_symbols(self, dims): + if not all([type(d) == str for d in dims]): # noqa: E721 + if self.auto_merge_: + unique_dims = list(set(dims)) + is_int = [is_literal(d) for d in unique_dims] + assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong + if sum(is_int) == 1: + int_dim = is_int.index(1) + if self.verbose_ > 0: + logger.debug( + "dim {} has been merged with value {}".format( + unique_dims[:int_dim] + unique_dims[int_dim + 1 :], + unique_dims[int_dim], + ) + ) + self._check_merged_dims(unique_dims, allow_broadcast=False) + return unique_dims[int_dim] + else: + if self.verbose_ > 0: + logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}") + return dims[0] + else: + return None + if all([d == dims[0] for d in dims]): + return dims[0] + merged = [self.suggested_merge_.get(d, d) for d in dims] + if all([d == merged[0] for d in merged]): + assert merged[0] in self.symbolic_dims_ + return merged[0] + else: + return None + + # broadcast from right to left, and merge symbolic dims if needed + def _broadcast_shapes(self, shape1, shape2): + new_shape = [] + rank1 = len(shape1) + rank2 = len(shape2) + new_rank = max(rank1, rank2) + for i in range(new_rank): + dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 + dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 + if dim1 == 1 or dim1 == dim2: + new_dim = dim2 + elif dim2 == 1: + new_dim = dim1 + else: + new_dim = self._merge_symbols([dim1, dim2]) + if not new_dim: + # warning about unsupported broadcast when not auto merge + # note that auto merge has the risk of incorrectly merge symbols while one of them being 1 + # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' + if self.auto_merge_: + self._add_suggested_merge([dim1, dim2], apply=True) + else: + logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) + new_shape = [new_dim, *new_shape] + return new_shape + + def _get_shape(self, node, idx): + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + else: + assert name in self.initializers_ + return list(self.initializers_[name].dims) + + def _try_get_shape(self, node, idx): + if idx > len(node.input) - 1: + return None + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + if name in self.initializers_: + return list(self.initializers_[name].dims) + return None + + def _get_shape_rank(self, node, idx): + return len(self._get_shape(node, idx)) + + def _get_sympy_shape(self, node, idx): + sympy_shape = [] + for d in 
self._get_shape(node, idx): + if type(d) == str: # noqa: E721 + sympy_shape.append( + self.symbolic_dims_[d] + if d in self.symbolic_dims_ + else sympy.Symbol(d, integer=True, nonnegative=True) + ) + else: + assert None is not d + sympy_shape.append(d) + return sympy_shape + + def _get_value(self, node, idx): + name = node.input[idx] + assert name in self.sympy_data_ or name in self.initializers_ + return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name]) + + def _try_get_value(self, node, idx): + if idx >= len(node.input): + return None + name = node.input[idx] + if name in self.sympy_data_ or name in self.initializers_: + return self._get_value(node, idx) + return None + + def _update_computed_dims(self, new_sympy_shape): + for i, new_dim in enumerate(new_sympy_shape): + if not is_literal(new_dim) and type(new_dim) != str: # noqa: E721 + str_dim = str(new_dim) + if str_dim in self.suggested_merge_: + if is_literal(self.suggested_merge_[str_dim]): + continue # no need to create dim for literals + new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]] + else: + # add new_dim if it's a computational expression + if str(new_dim) not in self.symbolic_dims_: + self.symbolic_dims_[str(new_dim)] = new_dim + + def _onnx_infer_single_node(self, node): + # skip onnx shape inference for some ops, as they are handled in _infer_* + skip_infer = node.op_type in [ + "If", + "Loop", + "Scan", + "SplitToSequence", + "ZipMap", # contrib ops + "Attention", + "BiasGelu", + "EmbedLayerNormalization", + "FastGelu", + "Gelu", + "GemmFastGelu", + "LayerNormalization", + "LongformerAttention", + "DequantizeLinear", + "QuantizeLinear", + "RelativePositionBias", + "RemovePadding", + "RestorePadding", + "SimplifiedLayerNormalization", + "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + "PackedAttention", + "PagedAttention", + "PythonOp", + "MultiHeadAttention", + "GroupNorm", + "GroupQueryAttention", + "SkipGroupNorm", + "BiasSplitGelu", + "BiasAdd", + "NhwcConv", + "QuickGelu", + "RotaryEmbedding", + ] + + if not skip_infer: + # Only pass initializers that satisfy the following condition: + # (1) Operator need value of some input for shape inference. + # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. + # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. + # (3) The initializer is not in graph input. The means the node input is "constant" in inference. 
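# A small sympy sketch of how the helpers above model dynamic dimensions:
# _get_sympy_shape maps string dims to integer Symbols (cached in symbolic_dims_),
# and _update_computed_dims later registers derived expressions such as
# batch_size*seq_len under their string form. Names below are illustrative only.
import sympy

symbolic_dims = {}

def to_sympy_shape(shape):
    # simplified mirror of _get_sympy_shape: strings become Symbols, ints pass through
    return [
        symbolic_dims.setdefault(d, sympy.Symbol(d, integer=True, nonnegative=True))
        if isinstance(d, str) else d
        for d in shape
    ]

shape = to_sympy_shape(["batch_size", "seq_len", 2560])
flattened = shape[0] * shape[1]                        # e.g. a dim produced by a Reshape
symbolic_dims.setdefault(str(flattened), flattened)    # what _update_computed_dims records

print(shape)           # [batch_size, seq_len, 2560]
print(str(flattened))  # batch_size*seq_len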
+ initializers = [] + if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]: + initializers = [ + self.initializers_[name] + for name in node.input + if (name in self.initializers_ and name not in self.graph_inputs_) + ] + + # run single node inference with self.known_vi_ shapes + tmp_graph = helper.make_graph( + [node], + "tmp", + [self.known_vi_[i] for i in node.input if i], + [make_named_value_info(i) for i in node.output], + initializers, + ) + + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) + + for i_o in range(len(node.output)): + o = node.output[i_o] + if o: # skip optional output + vi = self.out_mp_.graph.value_info.add() + if not skip_infer: + vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) + else: + vi.name = o + self.known_vi_[o] = vi + + def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): + if self.verbose_ > 2: + logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}") + # node inputs are not passed directly to the subgraph + # it's up to the node dispatcher to prepare subgraph input + # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape + # besides, inputs in subgraph could shadow implicit inputs + subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} + subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs} + tmp_graph = helper.make_graph( + list(subgraph.node), + "tmp", + list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], + [make_named_value_info(i.name) for i in subgraph.output], + ) + tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input]) + tmp_graph.initializer.extend(subgraph.initializer) + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + symbolic_shape_inference = SymbolicShapeInference( + self.int_max_, + self.auto_merge_, + self.guess_output_rank_, + self.verbose_, + prefix=self.prefix_ + "_" + str(self.subgraph_id_), + ) + if inc_subgraph_id: + self.subgraph_id_ += 1 + + symbolic_shape_inference._preprocess(self.tmp_mp_) + symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() + while symbolic_shape_inference.run_: + symbolic_shape_inference._infer_impl(self.sympy_data_.copy()) + symbolic_shape_inference._update_output_from_vi() + if use_node_input: + # if subgraph uses node input, it needs to update to merged dims + subgraph.ClearField("input") + subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)]) + subgraph.ClearField("output") + subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) + subgraph.ClearField("value_info") + subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info) + subgraph.ClearField("node") + subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) + # for new symbolic dims from subgraph output, add to main graph symbolic dims + subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output] + subgraph_new_symbolic_dims = { + d for s in subgraph_shapes if s for d in s if type(d) == str and d not in self.symbolic_dims_ # noqa: E721 + } + new_dims = {} + for d in subgraph_new_symbolic_dims: + assert d in symbolic_shape_inference.symbolic_dims_ + new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] + self.symbolic_dims_.update(new_dims) + return symbolic_shape_inference + + def 
_get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): + def int_or_float(value, allow_float_values): + # If casting into int has precision loss: keep float output + if allow_float_values and value % 1 != 0: + return value + return int(value) + + values = [self._try_get_value(node, i) for i in range(len(node.input))] + if all([v is not None for v in values]): + # some shape compute is in floating point, cast to int for sympy + for i, v in enumerate(values): + if type(v) != np.ndarray: + continue + if len(v.shape) > 1: + new_v = None # ignore value for rank > 1 + elif len(v.shape) == 0: + new_v = int_or_float(v.item(), allow_float_values) + else: + assert len(v.shape) == 1 + new_v = [int_or_float(vv, allow_float_values) for vv in v] + values[i] = new_v + values_len = [len(v) if isinstance(v, list) else 0 for v in values] + max_len = max(values_len) + if max_len >= 1 and broadcast: + # broadcast + for i, v in enumerate(values): + if v is None: + continue # don't broadcast if value is unknown + if isinstance(v, list): + if len(v) < max_len: + values[i] = v * max_len + else: + assert len(v) == max_len + else: + values[i] = [v] * max_len + return values + + def _compute_on_sympy_data(self, node, op_func): + assert len(node.output) == 1 + + # Before mul & div operations + # cast inputs into interger might lose decimal part and reduce precision + # keep them as float, finish the operation, then cast the result into integer + if node.op_type in ["Mul", "Div"]: + values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) + else: + values = self._get_int_or_float_values(node, broadcast=True) + + if all([v is not None for v in values]): + is_list = [isinstance(v, list) for v in values] + as_list = any(is_list) + if as_list: + self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] + else: + self.sympy_data_[node.output[0]] = op_func(values) + + def _pass_on_sympy_data(self, node): + assert len(node.input) == 1 or node.op_type in [ + "Reshape", + "Unsqueeze", + "Squeeze", + ] + self._compute_on_sympy_data(node, lambda x: x[0]) + + def _pass_on_shape_and_type(self, node): + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type), + self._get_shape(node, 0), + ) + ) + + def _new_symbolic_dim(self, prefix, dim): + new_dim = f"{prefix}_d{dim}" + if new_dim in self.suggested_merge_: + v = self.suggested_merge_[new_dim] + new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v + else: + new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True) + self.symbolic_dims_[new_dim] = new_symbolic_dim + return new_symbolic_dim + + def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): + return self._new_symbolic_dim( + "{}{}_{}_o{}_".format( + node.op_type, + self.prefix_, + list(self.out_mp_.graph.node).index(node), + out_idx, + ), + dim, + ) + + def _new_symbolic_shape(self, rank, node, out_idx=0): + return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] + + def _compute_conv_pool_shape(self, node, channels_last=False): + sympy_shape = self._get_sympy_shape(node, 0) + if len(node.input) > 1: + W_shape = self._get_sympy_shape(node, 1) # noqa: N806 + rank = len(W_shape) - 2 # number of spatial axes + kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:] + sympy_shape[3 if channels_last else 1] = W_shape[0] + else: + W_shape = None # noqa: N806 + 
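# The comment inside _compute_on_sympy_data above explains why Mul/Div keep float
# intermediates and only cast back to int afterwards; a tiny numeric illustration
# (toy numbers) of the precision that an early cast would throw away.
dim, scale_num, scale_den = 5, 3, 2                 # e.g. a Resize-style shape computation

premature_cast = dim * int(scale_num / scale_den)   # 5 * 1 = 5, fractional scale lost early
late_cast = int(dim * (scale_num / scale_den))      # int(7.5) = 7, matches the intended size

print(premature_cast, late_cast)                    # 5 7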
kernel_shape = get_attribute(node, "kernel_shape") + rank = len(kernel_shape) + + assert len(sympy_shape) == rank + 2 + + # only need to symbolic shape inference if input has symbolic dims in spatial axes + spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] + is_symbolic_dims = [not is_literal(i) for i in spatial_shape] + + if not any(is_symbolic_dims): + shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) + if len(shape) > 0: + assert len(sympy_shape) == len(shape) + if channels_last: + sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]] + else: + sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] + return sympy_shape + + dilations = get_attribute(node, "dilations", [1] * rank) + strides = get_attribute(node, "strides", [1] * rank) + effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] + pads = get_attribute(node, "pads") + if pads is None: + pads = [0] * (2 * rank) + auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") + if auto_pad != "VALID" and auto_pad != "NOTSET": + try: + residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] + total_pads = [ + max(0, (k - s) if r == 0 else (k - r)) + for k, s, r in zip(effective_kernel_shape, strides, residual) + ] + except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational + total_pads = [ + max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) + ] # assuming no residual if sympy throws error + elif auto_pad == "VALID": + total_pads = [] + else: + total_pads = [0] * rank + else: + assert len(pads) == 2 * rank + total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] + + ceil_mode = get_attribute(node, "ceil_mode", 0) + for i in range(rank): + effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)] + if len(total_pads) > 0: + effective_input_size = effective_input_size + total_pads[i] + if ceil_mode: + strided_kernel_positions = sympy.ceiling( + (effective_input_size - effective_kernel_shape[i]) / strides[i] + ) + else: + strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] + sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1 + return sympy_shape + + def _check_merged_dims(self, dims, allow_broadcast=True): + if allow_broadcast: + dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] + if not all([d == dims[0] for d in dims]): + self._add_suggested_merge(dims, apply=True) + + def _compute_matmul_shape(self, node, output_dtype=None): + lhs_shape = self._get_shape(node, 0) + rhs_shape = self._get_shape(node, 1) + lhs_rank = len(lhs_shape) + rhs_rank = len(rhs_shape) + lhs_reduce_dim = 0 + rhs_reduce_dim = 0 + assert lhs_rank > 0 and rhs_rank > 0 + if lhs_rank == 1 and rhs_rank == 1: + new_shape = [] + elif lhs_rank == 1: + rhs_reduce_dim = -2 + new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] + elif rhs_rank == 1: + lhs_reduce_dim = -1 + new_shape = lhs_shape[:lhs_reduce_dim] + else: + lhs_reduce_dim = -1 + rhs_reduce_dim = -2 + new_shape = [*self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), lhs_shape[-2], rhs_shape[-1]] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], + allow_broadcast=False, + ) + if output_dtype is None: + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = 
self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): + """ + update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches + """ + dst_tensor_type = ( + dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type + ) + src_tensor_type = ( + src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type + ) + if dst_tensor_type.elem_type != src_tensor_type.elem_type: + node_id = node.name if node.name else node.op_type + raise ValueError( + f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " + f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " + f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" + ) + if dst_tensor_type.HasField("shape"): + for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): + if ds[0] != ds[1]: + # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type + # for sequence_type, clear the dimension + new_dim = onnx.TensorShapeProto.Dimension() + if not is_sequence(dst_type): + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di)) + dst_tensor_type.shape.dim[di].CopyFrom(new_dim) + else: + dst_tensor_type.CopyFrom(src_tensor_type) + + def _infer_ArrayFeatureExtractor(self, node): # noqa: N802 + data_shape = self._get_shape(node, 0) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:-1] + indices_shape, + ) + ) + + def _infer_symbolic_compute_ops(self, node): + funcs = { + "Add": lambda l: l[0] + l[1], # noqa: E741 + "Div": lambda l: ( # noqa: E741 + int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] + ), # integer div in sympy + "Equal": lambda l: l[0] == l[1], # noqa: E741 + "Floor": lambda l: sympy.floor(l[0]), # noqa: E741 + "Max": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) < -self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) + ), + "Min": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) > self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) + ), + "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], # noqa: E741 + "Sub": lambda l: l[0] - l[1], # noqa: E741 + "Where": lambda l: l[1] if l[0] else l[2], # noqa: E741 + "Neg": lambda l: -l[0], # noqa: E741 + } + assert node.op_type in funcs + self._compute_on_sympy_data(node, funcs[node.op_type]) + + def _infer_Cast(self, node): # noqa: N802 + self._pass_on_sympy_data(node) + + def _infer_CategoryMapper(self, node): # noqa: N802 + input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if input_type == onnx.TensorProto.STRING: + output_type = onnx.TensorProto.INT64 + else: + output_type = onnx.TensorProto.STRING + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0))) + + def _infer_Compress(self, node): # noqa: N802 + input_shape = self._get_shape(node, 0) + # create a new symbolic dimension for Compress output + compress_len = str(self._new_symbolic_dim_from_output(node)) 
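# A quick numpy check of the shape rule implemented by _compute_matmul_shape above:
# batch dims broadcast right-to-left, the two reduce dims must agree, and the result
# keeps lhs[-2] and rhs[-1]. The shapes are toy values.
import numpy as np

lhs = np.zeros((2, 1, 4, 8), dtype=np.float32)   # [..., M, K]
rhs = np.zeros((3, 8, 5), dtype=np.float32)      # [..., K, N]

expected = (*np.broadcast_shapes(lhs.shape[:-2], rhs.shape[:-2]), lhs.shape[-2], rhs.shape[-1])
assert np.matmul(lhs, rhs).shape == expected == (2, 3, 4, 5)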
+ axis = get_attribute(node, "axis") + if axis is None: + # when axis is not specified, input is flattened before compress so output is 1D + output_shape = [compress_len] + else: + output_shape = input_shape + output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + def _infer_Concat(self, node): # noqa: N802 + if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]): + values = self._get_int_or_float_values(node) + if all([v is not None for v in values]): + assert get_attribute(node, "axis") == 0 + self.sympy_data_[node.output[0]] = [] + for i in range(len(node.input)): + value = values[i] + if isinstance(value, list): + self.sympy_data_[node.output[0]].extend(value) + else: + self.sympy_data_[node.output[0]].append(value) + + sympy_shape = self._get_sympy_shape(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape)) + for i_idx in range(1, len(node.input)): + input_shape = self._get_sympy_shape(node, i_idx) + if input_shape: + sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] + self._update_computed_dims(sympy_shape) + # merge symbolic dims for non-concat axes + for d in range(len(sympy_shape)): + if d == axis: + continue + dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)] + if all([d == dims[0] for d in dims]): + continue + merged = self._merge_symbols(dims) + if type(merged) == str: # noqa: E721 + sympy_shape[d] = self.symbolic_dims_[merged] if merged else None + else: + sympy_shape[d] = merged + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_ConcatFromSequence(self, node): # noqa: N802 + seq_shape = self._get_shape(node, 0) + new_axis = 1 if get_attribute(node, "new_axis") else 0 + axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) + concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) + new_shape = seq_shape + if new_axis: + new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] + else: + new_shape[axis] = concat_dim + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Constant(self, node): # noqa: N802 + t = get_attribute(node, "value") + self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) + + def _infer_ConstantOfShape(self, node): # noqa: N802 + sympy_shape = self._get_int_or_float_values(node)[0] + vi = self.known_vi_[node.output[0]] + if sympy_shape is not None: + if type(sympy_shape) != list: # noqa: E721 + sympy_shape = [sympy_shape] + self._update_computed_dims(sympy_shape) + # update sympy data if output type is int, and shape is known + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): + self.sympy_data_[node.output[0]] = np.ones( + [int(x) for x in sympy_shape], dtype=np.int64 + ) * numpy_helper.to_array(get_attribute(node, "value", 0)) + else: + # create new dynamic shape + # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length + sympy_shape 
= self._new_symbolic_shape(self._get_shape(node, 0)[0], node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_Conv(self, node): # noqa: N802 + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_NhwcConv(self, node): # noqa: N802 + sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) + self._update_computed_dims(sympy_shape) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_DequantizeLinear(self, node): # noqa: N802 + # Get the output data type from the scale input (index 1, required). + output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_QuantizeLinear(self, node): # noqa: N802 + # Get the output data type from the zero-point input (index 2, optional). + # Otherwise, default to uint8 + output_dtype = onnx.TensorProto.UINT8 + if len(node.input) > 2 and node.input[2]: + output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_Einsum(self, node): # noqa: N802 + # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 + equation = get_attribute(node, "equation") + equation = equation.replace(b" ", b"") + mid_index = equation.find(b"->") + left_equation = equation[:mid_index] if mid_index != -1 else equation + + num_operands = 0 + num_ellipsis = 0 + num_ellipsis_indices = 0 + + letter_to_dim = {} + + terms = left_equation.split(b",") + for term in terms: + ellipsis_index = term.find(b"...") + shape = self._get_shape(node, num_operands) + rank = len(shape) + if ellipsis_index != -1: + if num_ellipsis == 0: + num_ellipsis_indices = rank - len(term) + 3 + num_ellipsis = num_ellipsis + 1 + for i in range(1, rank + 1): + letter = term[-i] + if letter != 46: # letter != b'.' + dim = shape[-i] + if letter not in letter_to_dim: + letter_to_dim[letter] = dim + elif type(dim) != sympy.Symbol: + letter_to_dim[letter] = dim + num_operands = num_operands + 1 + + new_sympy_shape = [] + from collections import OrderedDict + + num_letter_occurrences = OrderedDict() + if mid_index != -1: + right_equation = equation[mid_index + 2 :] + right_ellipsis_index = right_equation.find(b"...") + if right_ellipsis_index != -1: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in right_equation: + if c != 46: # c != b'.' 
+ new_sympy_shape.append(letter_to_dim[c]) + else: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in left_equation: + if c != 44 and c != 46: # c != b',' and c != b'.': + if c in num_letter_occurrences: + num_letter_occurrences[c] = num_letter_occurrences[c] + 1 + else: + num_letter_occurrences[c] = 1 + for key, value in num_letter_occurrences.items(): + if value == 1: + new_sympy_shape.append(letter_to_dim[key]) + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape)) + + def _infer_Expand(self, node): # noqa: N802 + expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) + if expand_to_shape is not None: + # new_shape's dim can come from shape value + self._update_computed_dims(expand_to_shape) + shape = self._get_shape(node, 0) + new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Gather(self, node): # noqa: N802 + data_shape = self._get_shape(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:axis] + indices_shape + data_shape[axis + 1 :], + ) + ) + # for 1D input, do some sympy compute + if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0: + idx = self._try_get_value(node, 1) + if idx is not None: + data = self.sympy_data_[node.input[0]] + if type(data) == list: # noqa: E721 + if type(idx) == np.ndarray and len(idx.shape) == 1: + self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx] + else: + self.sympy_data_[node.output[0]] = data[int(idx)] + else: + assert idx == 0 or idx == -1 + self.sympy_data_[node.output[0]] = data + + def _infer_GatherElements(self, node): # noqa: N802 + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + indices_shape, + ) + ) + + def _infer_GatherND(self, node): # noqa: N802 + data_shape = self._get_shape(node, 0) + data_rank = len(data_shape) + indices_shape = self._get_shape(node, 1) + len(indices_shape) + last_index_dimension = indices_shape[-1] + assert is_literal(last_index_dimension) and last_index_dimension <= data_rank + new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_If(self, node): # noqa: N802 + # special case for constant condition, in case there are mismatching shape from the non-executed branch + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] + cond = self._try_get_value(node, 0) + if cond is not None: + if as_scalar(cond) > 0: + subgraphs[1].CopyFrom(subgraphs[0]) + else: + subgraphs[0].CopyFrom(subgraphs[1]) + + for i_sub, subgraph in enumerate(subgraphs): + subgraph_infer = 
self._onnx_infer_subgraph(node, subgraph, use_node_input=False) + for i_out in range(len(node.output)): + vi = self.known_vi_[node.output[i_out]] + if i_sub == 0: + vi.CopyFrom(subgraph.output[i_out]) + vi.name = node.output[i_out] + else: + self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) + + # pass on sympy data from subgraph, if cond is constant + if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1): + if subgraph.output[i_out].name in subgraph_infer.sympy_data_: + self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name] + + def _infer_Loop(self, node): # noqa: N802 + subgraph = get_attribute(node, "body") + assert len(subgraph.input) == len(node.input) + num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition + # when sequence_type is used as loop carried input + # needs to run subgraph infer twice if the tensor shape in sequence contains None + for i, si in enumerate(subgraph.input): + si_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + si.name = si_name + + self._onnx_infer_subgraph(node, subgraph) + + # check subgraph input/output for shape changes in loop carried variables + # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) + # for sequence_type, propagate from output to input + need_second_infer = False + for i_out in range(1, num_loop_carried + 1): + so = subgraph.output[i_out] + so_shape = get_shape_from_value_info(so) + if is_sequence(so.type): + if so_shape and None in so_shape: + # copy shape from output to input + # note that loop input is [loop_len, cond, input_0, input_1, ...] + # while loop output is [cond, output_0, output_1, ...] + subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type) + need_second_infer = True + else: + si = subgraph.input[i_out + 1] + si_shape = get_shape_from_value_info(si) + for di, dims in enumerate(zip(si_shape, so_shape)): + if dims[0] != dims[1]: + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di)) + si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + need_second_infer = True + + if need_second_infer: + if self.verbose_ > 2: + logger.debug( + "Rerun Loop: {}({}...), because of sequence in loop carried variables".format( + node.name, node.output[0] + ) + ) + self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) + + # create a new symbolic dimension for iteration dependent dimension + loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) + for i in range(len(node.output)): + vi = self.known_vi_[node.output[i]] + vi.CopyFrom(subgraph.output[i + 1]) # first subgraph output is condition, not in node output + if i >= num_loop_carried: + assert not is_sequence(vi.type) # TODO: handle loop accumulation in sequence_type + subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim + vi.type.tensor_type.shape.ClearField("dim") + vi_dim = vi.type.tensor_type.shape.dim + vi_dim.add().dim_param = loop_iter_dim + vi_dim.extend(list(subgraph_vi_dim)) + vi.name = node.output[i] + + def _infer_MatMul(self, node): # noqa: N802 + self._compute_matmul_shape(node) + + def _infer_MatMulInteger(self, node): # noqa: N802 + self._compute_matmul_shape(node, onnx.TensorProto.INT32) + + def _infer_NonMaxSuppression(self, node): # noqa: N802 + selected = str(self._new_symbolic_dim_from_output(node)) + vi = 
self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3])) + + def _infer_NonZero(self, node): # noqa: N802 + input_rank = self._get_shape_rank(node, 0) + # create a new symbolic dimension for NonZero output + nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) + + def _infer_OneHot(self, node): # noqa: N802 + sympy_shape = self._get_sympy_shape(node, 0) + depth = self._try_get_value(node, 1) + axis = get_attribute(node, "axis", -1) + axis = handle_negative_axis(axis, len(sympy_shape) + 1) + new_shape = get_shape_from_sympy_shape( + sympy_shape[:axis] + + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth] + + sympy_shape[axis:] + ) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[2]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Pad(self, node): # noqa: N802 + if get_opset(self.out_mp_) <= 10: + pads = get_attribute(node, "pads") + else: + pads = self._try_get_value(node, 1) + + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + + if pads is not None: + assert len(pads) == 2 * rank + new_sympy_shape = [ + d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) + ] + self._update_computed_dims(new_sympy_shape) + else: + # dynamic pads, create new symbolic dimensions + new_sympy_shape = self._new_symbolic_shape(rank, node) + output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)) + ) + + def _infer_Pool(self, node): # noqa: N802 + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + for o in node.output: + if not o: + continue + vi = self.known_vi_[o] + vi.CopyFrom( + helper.make_tensor_value_info( + o, + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_aten_bitwise_or(self, node): + shape0 = self._get_shape(node, 0) + shape1 = self._get_shape(node, 1) + new_shape = self._broadcast_shapes(shape0, shape1) + t0 = self.known_vi_[node.input[0]] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape)) + + def _infer_aten_diagonal(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + offset = self._try_get_value(node, 1) + dim1 = self._try_get_value(node, 2) + dim2 = self._try_get_value(node, 3) + + assert offset is not None and dim1 is not None and dim2 is not None + dim1 = handle_negative_axis(dim1, rank) + dim2 = handle_negative_axis(dim2, rank) + + new_shape = [] + for dim, val in enumerate(sympy_shape): + if dim not in [dim1, dim2]: + new_shape.append(val) + + shape1 = sympy_shape[dim1] + shape2 = sympy_shape[dim2] + if offset >= 0: + diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) + else: + diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) + new_shape.append(diag_shape) + + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + 
get_shape_from_sympy_shape(new_shape), + ) + ) + + def _infer_aten_multinomial(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + assert rank in [1, 2] + num_samples = self._try_get_value(node, 1) + di = rank - 1 + last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di)) + output_shape = sympy_shape[:-1] + [last_dim] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + onnx.TensorProto.INT64, + get_shape_from_sympy_shape(output_shape), + ) + ) + + def _infer_aten_pool2d(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + assert len(sympy_shape) == 4 + sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]] + self._update_computed_dims(sympy_shape) + for i, o in enumerate(node.output): + if not o: + continue + vi = self.known_vi_[o] + elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape))) + + def _infer_aten_minmax(self, node): + vi = self.known_vi_[node.output[0]] + if len(node.input) == 1: + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, [] + ) + ) + else: + assert len(node.input) == 3 + keepdim = self._try_get_value(node, 2) + assert keepdim is not None # can only handle known keepdim case. + dim = self._try_get_value(node, 1) + if dim is None: + rank = self._get_shape_rank(node, 0) + output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) + else: + shape = self._get_sympy_shape(node, 0) + dim = handle_negative_axis(dim, len(shape)) + output_shape = shape[:dim] + if keepdim: + output_shape += [1] + output_shape += shape[dim + 1 :] + + output_shape = get_shape_from_sympy_shape(output_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape + ) + ) + vi1 = self.known_vi_[node.output[1]] + vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape)) + + def _infer_aten_unfold(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + dimension = self._try_get_value(node, 1) + size = self._try_get_value(node, 2) + step = self._try_get_value(node, 3) + if dimension is not None and size is not None and step is not None: + assert dimension < len(sympy_shape) + sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 + sympy_shape.append(size) + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank + 1, node) + self._update_computed_dims(sympy_shape) + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_aten_argmax(self, node): + new_shape = None + if not node.input[1]: + # The argmax of the flattened input is returned. 
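+                # An empty shape list ([]) yields a rank-0 (scalar) value_info, matching the flattened-argmax case above.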
+ new_shape = [] + else: + dim = self._try_get_value(node, 1) + keepdim = self._try_get_value(node, 2) + if keepdim is not None: + sympy_shape = self._get_sympy_shape(node, 0) + if dim is not None: + dim = handle_negative_axis(dim, len(sympy_shape)) + if keepdim: + sympy_shape[dim] = 1 + else: + del sympy_shape[dim] + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) + self._update_computed_dims(sympy_shape) + new_shape = get_shape_from_sympy_shape(sympy_shape) + if node.output[0] and new_shape is not None: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) + + def _infer_aten_group_norm(self, node): + self._propagate_shape_and_type(node) + input_shape = self._get_shape(node, 0) + N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806 + group = self._try_get_value(node, 6) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + for i in [1, 2]: + if node.output[i]: + vi = self.known_vi_[node.output[i]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[i], + output_dtype, + [ + N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)), + ( + as_scalar(group) + if group is not None + else str(self._new_symbolic_dim_from_output(node, i, 1)) + ), + ], + ) + ) + + def _infer_aten_upsample(self, node): + new_shape = None + input_shape = self._get_shape(node, 0) + if input_shape is not None: + new_shape = input_shape[:2] + output_size = self._try_get_value(node, 1) + if output_size is not None: + new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size] + else: + rank = len(input_shape) + new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)] + if node.output[0] and new_shape is not None: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _infer_BatchNormalization(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop + for i in [1, 2, 3, 4]: + if i < len(node.output) and node.output[i]: + # all of these parameters have the same shape as the 1st input + self._propagate_shape_and_type(node, input_index=1, output_index=i) + + def _infer_Range(self, node): # noqa: N802 + vi = self.known_vi_[node.output[0]] + input_data = self._get_int_or_float_values(node) + if all([i is not None for i in input_data]): + start = as_scalar(input_data[0]) + limit = as_scalar(input_data[1]) + delta = as_scalar(input_data[2]) + new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)] + else: + new_sympy_shape = [self._new_symbolic_dim_from_output(node)] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_ReduceSum(self, node): # noqa: N802 + keep_dims = get_attribute(node, "keepdims", 1) + if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: + # ReduceSum changes axes to input[1] in opset 13 + axes = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if axes is None: + assert keep_dims # can only handle keep_dims==True when axes is 
unknown, by generating new ranks + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)), + ) + ) + else: + shape = self._get_shape(node, 0) + output_shape = [] + axes = [handle_negative_axis(a, len(shape)) for a in axes] + for i, d in enumerate(shape): + if i in axes: + if keep_dims: + output_shape.append(1) + else: + output_shape.append(d) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + def _infer_ReduceProd(self, node): # noqa: N802 + axes = get_attribute(node, "axes") + keep_dims = get_attribute(node, "keepdims", 1) + if keep_dims == 0 and axes == [0]: + data = self._get_int_or_float_values(node)[0] + if data is not None: + self.sympy_data_[node.output[0]] = sympy_reduce_product(data) + + def _infer_RelativePositionBias(self, node): # noqa: N802 + seq_len = self._try_get_value(node, 1) + real_seq_len = self._try_get_value(node, 2) + if seq_len is None or real_seq_len is None: + return + num_heads = self._get_sympy_shape(node, 0)[1] + + new_shape = [1, num_heads, str(seq_len), str(real_seq_len)] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _infer_Reshape(self, node): # noqa: N802 + shape_value = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if shape_value is None: + shape_shape = self._get_shape(node, 1) + assert len(shape_shape) == 1 + shape_rank = shape_shape[0] + assert is_literal(shape_rank) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), + ) + ) + else: + input_sympy_shape = self._get_sympy_shape(node, 0) + total = 1 + for d in input_sympy_shape: + total = total * d + new_sympy_shape = [] + deferred_dim_idx = -1 + non_deferred_size = 1 + for i, d in enumerate(shape_value): + if type(d) == sympy.Symbol: + new_sympy_shape.append(d) + elif d == 0: + new_sympy_shape.append(input_sympy_shape[i]) + non_deferred_size = non_deferred_size * input_sympy_shape[i] + else: + new_sympy_shape.append(d) + if d == -1: + deferred_dim_idx = i + elif d != 0: + non_deferred_size = non_deferred_size * d + + assert new_sympy_shape.count(-1) < 2 + if -1 in new_sympy_shape: + new_dim = total // non_deferred_size + new_sympy_shape[deferred_dim_idx] = new_dim + + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + self._pass_on_sympy_data(node) + + def _infer_Resize(self, node): # noqa: N802 + vi = self.known_vi_[node.output[0]] + input_sympy_shape = self._get_sympy_shape(node, 0) + if get_opset(self.out_mp_) <= 10: + scales = self._try_get_value(node, 1) + if scales is not None: + new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + else: + roi = self._try_get_value(node, 1) + scales = self._try_get_value(node, 2) + sizes = 
self._try_get_value(node, 3) + if sizes is not None: + new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes] + self._update_computed_dims(new_sympy_shape) + elif scales is not None: + rank = len(scales) + if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": + assert len(roi) == 2 * rank + roi_start = list(roi)[:rank] + roi_end = list(roi)[rank:] + else: + roi_start = [0] * rank + roi_end = [1] * rank + scales = list(scales) + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * (end - start) * scale)) + for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) + ] + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_Scan(self, node): # noqa: N802 + subgraph = get_attribute(node, "body") + num_scan_inputs = get_attribute(node, "num_scan_inputs") + scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) + num_scan_states = len(node.input) - num_scan_inputs + scan_input_axes = [ + handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) + for i, ax in enumerate(scan_input_axes) + ] + # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer, + # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. + assert len(subgraph.input) >= len(node.input) + subgraph_inputs = subgraph.input[: len(node.input)] + for i, si in enumerate(subgraph_inputs): + subgraph_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + if i >= num_scan_states: + scan_input_dim = si.type.tensor_type.shape.dim + scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) + si.name = subgraph_name + self._onnx_infer_subgraph(node, subgraph) + num_scan_outputs = len(node.output) - num_scan_states + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] + for i, o in enumerate(node.output): + vi = self.known_vi_[o] + if i >= num_scan_states: + shape = get_shape_from_type_proto(subgraph.output[i].type) + new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1) + shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] + vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape)) + else: + vi.CopyFrom(subgraph.output[i]) + vi.name = o + + def _infer_ScatterElements(self, node): # noqa: N802 + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape, + ) + ) + + def _infer_SequenceAt(self, node): # noqa: N802 + # need to create new symbolic dimension if sequence shape has None: + seq_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + if seq_shape is not None: + for di, d in enumerate(seq_shape): + if d is not None: + continue + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di)) + vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + + def _infer_SequenceInsert(self, node): # 
noqa: N802 + # workaround bug in onnx's shape inference + vi_seq = self.known_vi_[node.input[0]] + vi_tensor = self.known_vi_[node.input[1]] + vi_out_seq = self.known_vi_[node.output[0]] + vi_out_seq.CopyFrom(vi_seq) + vi_out_seq.name = node.output[0] + self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) + + def _infer_Shape(self, node): # noqa: N802 + self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) + + def _infer_Size(self, node): # noqa: N802 + sympy_shape = self._get_sympy_shape(node, 0) + self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) + self.known_vi_[node.output[0]].CopyFrom( + helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) + ) + + def _infer_Slice(self, node): # noqa: N802 + # SymPy fails to prove that `x_0 + ... + x_n >= 0` if one of `x_i` is a `sympy.Min(a, b)`, + # even when the relation holds for both `a` and `b`. + # + # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`, + # so that we can prove inequalities for both expressions separately. + # + # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. + def flatten_min(expr): + assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" + min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] + if len(min_positions) == 1: + min_pos = min_positions[0] + + def replace_min_with_arg(arg_idx): + replaced = list(expr.args) + assert isinstance( + replaced[min_pos], sympy.Min + ), f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}" + assert ( + len(replaced[min_pos].args) == 2 + ), f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}" + replaced[min_pos] = replaced[min_pos].args[arg_idx] + return sympy.Add(*replaced) + + return [ + replace_min_with_arg(0), + replace_min_with_arg(1), + ] + return [expr] + + def less_equal(x, y): + try: + return bool(x <= y) + except TypeError: + pass + try: + return bool(y >= x) + except TypeError: + pass + try: + return bool(-x >= -y) + except TypeError: + pass + try: + return bool(-y <= -x) + except TypeError: + pass + try: + return bool(y - x >= 0) + except TypeError: + # the last attempt; this may raise TypeError + return all(bool(d >= 0) for d in flatten_min(y - x)) + + def handle_negative_index(index, bound): + """normalizes a negative index to be in [0, bound)""" + try: + if not less_equal(0, index): + if is_literal(index) and index <= -self.int_max_: + # this case is handled separately + return index + return bound + index + except TypeError: + logger.warning(f"Cannot determine if {index} < 0") + return index + + if get_opset(self.out_mp_) <= 9: + axes = get_attribute(node, "axes") + starts = get_attribute(node, "starts") + ends = get_attribute(node, "ends") + if not axes: + axes = list(range(len(starts))) + steps = [1] * len(axes) + else: + starts = as_list(self._try_get_value(node, 1), keep_none=True) + ends = as_list(self._try_get_value(node, 2), keep_none=True) + axes = self._try_get_value(node, 3) + steps = self._try_get_value(node, 4) + if axes is None and not (starts is None and ends is None): + axes = list(range(0, len(starts if starts is not None else ends))) + if steps is None and not (starts is None and ends is None): + steps = [1] * len(starts if starts is not None else ends) + axes = as_list(axes, keep_none=True) + steps = as_list(steps, keep_none=True) + + new_sympy_shape = self._get_sympy_shape(node, 0) + if starts 
is None or ends is None: + if axes is None: + for i in range(len(new_sympy_shape)): + new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) + else: + new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) + for i in axes: + new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) + else: + for i, s, e, t in zip(axes, starts, ends, steps): + e = handle_negative_index(e, new_sympy_shape[i]) # noqa: PLW2901 + if is_literal(e): + if e >= self.int_max_: + e = new_sympy_shape[i] # noqa: PLW2901 + elif e <= -self.int_max_: + e = 0 if s > 0 else -1 # noqa: PLW2901 + elif is_literal(new_sympy_shape[i]): + if e < 0: + e = max(0, e + new_sympy_shape[i]) # noqa: PLW2901 + e = min(e, new_sympy_shape[i]) # noqa: PLW2901 + else: + if e > 0: + e = ( # noqa: PLW2901 + sympy.Min(e, new_sympy_shape[i]) if e > 1 else e + ) # special case for slicing first to make computation easier + else: + if is_literal(new_sympy_shape[i]): + e = sympy.Min(e, new_sympy_shape[i]) # noqa: PLW2901 + else: + try: + if not less_equal(e, new_sympy_shape[i]): + e = new_sympy_shape[i] # noqa: PLW2901 + except Exception: + logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal") + e = new_sympy_shape[i] # noqa: PLW2901 + + s = handle_negative_index(s, new_sympy_shape[i]) # noqa: PLW2901 + if is_literal(new_sympy_shape[i]) and is_literal(s): + s = max(0, min(s, new_sympy_shape[i])) # noqa: PLW2901 + + new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) + + self._update_computed_dims(new_sympy_shape) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + # handle sympy_data if needed, for slice in shape computation + if ( + node.input[0] in self.sympy_data_ + and [0] == axes + and starts is not None + and len(starts) == 1 + and ends is not None + and len(ends) == 1 + and steps is not None + and len(steps) == 1 + ): + input_sympy_data = self.sympy_data_[node.input[0]] + if type(input_sympy_data) == list or ( # noqa: E721 + type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 + ): + self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]] + + def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 + vi = self.known_vi_[node.output[0]] + elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + # If output type is explicit specified in attribute, we use it as output tensor type. 
+ specified_output_type = get_attribute(node, "output_type", None) + if specified_output_type is not None: + elem_type = specified_output_type + + vi.type.tensor_type.elem_type = elem_type + vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + + if len(node.output) > 1: + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape)) + + def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 + input_sympy_shape = self._get_sympy_shape(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) + split = get_attribute(node, "split") + if not split: + num_outputs = len(node.output) + split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs + self._update_computed_dims(split) + else: + split = [sympy.Integer(s) for s in split] + + for i_o in range(len(split)): + vi = self.known_vi_[node.output[i_o]] + vi.CopyFrom( + make_value_info_func( + node.output[i_o], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]), + ) + ) + self.known_vi_[vi.name] = vi + + def _infer_Split(self, node): # noqa: N802 + self._infer_Split_Common(node, helper.make_tensor_value_info) + + def _infer_SplitToSequence(self, node): # noqa: N802 + self._infer_Split_Common(node, helper.make_sequence_value_info) + + def _infer_Squeeze(self, node): # noqa: N802 + input_shape = self._get_shape(node, 0) + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'axes' are provided as attribute or via 2nd input + if op_set < 13: + axes = get_attribute(node, "axes") + assert self._try_get_value(node, 1) is None + else: + axes = self._try_get_value(node, 1) + assert get_attribute(node, "axes") is None + + if axes is None: + # No axes have been provided (neither via attribute nor via input). + # In this case the 'Shape' op should remove all axis with dimension 1. + # For symbolic dimensions we guess they are !=1. + output_shape = [s for s in input_shape if s != 1] + if self.verbose_ > 0: + symbolic_dimensions = [s for s in input_shape if type(s) != int] # noqa: E721 + if len(symbolic_dimensions) > 0: + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" + ) + else: + axes = [handle_negative_axis(a, len(input_shape)) for a in axes] + output_shape = [] + for i in range(len(input_shape)): + if i not in axes: + output_shape.append(input_shape[i]) + else: + assert input_shape[i] == 1 or type(input_shape[i]) != int # noqa: E721 + if self.verbose_ > 0 and type(input_shape[i]) != int: # noqa: E721 + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." 
+ ) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + self._pass_on_sympy_data(node) + + def _infer_Tile(self, node): # noqa: N802 + repeats_value = self._try_get_value(node, 1) + new_sympy_shape = [] + if repeats_value is not None: + input_sympy_shape = self._get_sympy_shape(node, 0) + for i, d in enumerate(input_sympy_shape): + new_dim = d * repeats_value[i] + new_sympy_shape.append(new_dim) + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_TopK(self, node): # noqa: N802 + rank = self._get_shape_rank(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis", -1), rank) + new_shape = self._get_shape(node, 0) + + if get_opset(self.out_mp_) <= 9: + k = get_attribute(node, "k") + else: + k = self._get_int_or_float_values(node)[1] + + if k is None: + k = self._new_symbolic_dim_from_output(node) + else: + k = as_scalar(k) + + if type(k) in [int, str]: + new_shape[axis] = k + else: + new_sympy_shape = self._get_sympy_shape(node, 0) + new_sympy_shape[axis] = k + self._update_computed_dims( + new_sympy_shape + ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape + new_shape = get_shape_from_sympy_shape(new_sympy_shape) + + for i_o in range(len(node.output)): + vi = self.known_vi_[node.output[i_o]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape)) + + def _infer_Transpose(self, node): # noqa: N802 + if node.input[0] in self.sympy_data_: + data_shape = self._get_shape(node, 0) + perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) + input_data = self.sympy_data_[node.input[0]] + self.sympy_data_[node.output[0]] = ( + np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist() + ) + + def _infer_Unsqueeze(self, node): # noqa: N802 + input_shape = self._get_shape(node, 0) + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'axes' are provided as attribute or via 2nd input + if op_set < 13: + axes = get_attribute(node, "axes") + assert self._try_get_value(node, 1) is None + else: + axes = self._try_get_value(node, 1) + assert get_attribute(node, "axes") is None + + output_rank = len(input_shape) + len(axes) + axes = [handle_negative_axis(a, output_rank) for a in axes] + + input_axis = 0 + output_shape = [] + for i in range(output_rank): + if i in axes: + output_shape.append(1) + else: + output_shape.append(input_shape[input_axis]) + input_axis += 1 + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + self._pass_on_sympy_data(node) + + def _infer_ZipMap(self, node): # noqa: N802 + map_key_type = None + if get_attribute(node, "classlabels_int64s") is not None: + map_key_type = onnx.TensorProto.INT64 + elif get_attribute(node, "classlabels_strings") is not None: + map_key_type = onnx.TensorProto.STRING + + assert map_key_type is not None + new_vi = onnx.ValueInfoProto() + new_vi.name = node.output[0] + 
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT + new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(new_vi) + + def _infer_Attention(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 3: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(tripled_hidden_size, int): + shape[2] = int(tripled_hidden_size / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else [] + mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] + + if past_shape and len(past_shape) == 5: + if mask_shape and len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + elif input_shape and len(input_shape) == 3: + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): + past_shape[3] = input_shape[1] + past_shape[3] + else: + past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + # No past input but present output still exists + else: + num_heads = get_attribute(node, "num_heads") + head_size = input_shape[2] // num_heads + present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + + def _infer_GatedRelativePositionBias(self, node): # noqa: N802 + # When padding is removed: + # query_layer: (token_count, num_heads x head_size) + # token_offset: (batch_size, seq_len) + # Otherwise: + # query_layer: (batch_size, seq_len, num_heads x head_size) + # token_offset: None + # Output shape: (batch_size, num_heads, seq_len, seq_len) + num_heads = get_attribute(node, "num_heads") + + token_offset_shape = self._try_get_shape(node, 6) + if token_offset_shape is not None: + output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]] + else: + query_layer_shape = self._get_shape(node, 0) + assert query_layer_shape is not None and len(query_layer_shape) == 3 + output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, 
output_shape)) + + def _infer_PackedAttention(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 2: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[1] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(tripled_hidden_size, int): + shape[1] = int(tripled_hidden_size / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 + shape_value = self._try_get_shape(node, 2) + if shape_value is not None and len(shape_value) == 2: + output_shape = shape_value + else: + shape_query = self._get_shape(node, 0) + assert shape_query is not None and len(shape_query) == 4 + output_shape = [shape_query[0], shape_query[1] * shape_query[3]] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_RemovePadding(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + if shape and len(shape) == 3: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]])) + + vi_token_offset = self.known_vi_[node.output[1]] + vi_token_offset.CopyFrom( + helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]]) + ) + + vi_cumulated_seq_len = self.known_vi_[node.output[2]] + vi_cumulated_seq_len.CopyFrom( + helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"]) + ) + + vi_max_seq_len = self.known_vi_[node.output[3]] + vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1])) + + def _infer_RestorePadding(self, node): # noqa: N802 + shape_input = self._get_shape(node, 0) + shape_token_offset = self._get_shape(node, 1) + if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + + output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_BiasGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_MultiHeadAttention(self, node): # noqa: N802 + # Output 0 has shape (batch_size, sequence_length, v_hidden_size) + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Packed KV: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + # 
Input 2 nullptr + # Packed QKV: + # Input 0 (batch_size, sequence_length, num_heads, 3, head_size) + # Input 1 nullptr + # Input 2 nullptr + + query_shape = self._get_shape(node, 0) + total_sequence_length = None + output_dtype = None + if query_shape is not None: + if len(query_shape) == 3: + key_shape = self._try_get_shape(node, 1) + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. + output_shape = query_shape + if key_shape is not None and len(key_shape) == 3: + value_shape = self._try_get_shape(node, 2) + if value_shape is not None and len(value_shape) == 3: + output_shape[2] = value_shape[2] + total_sequence_length = key_shape[1] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + elif len(query_shape) == 5: + if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): + output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]] + else: + output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"] + + total_sequence_length = query_shape[1] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + if len(node.output) > 1: + batch_size = query_shape[0] + num_heads = get_attribute(node, "num_heads") + + head_size = None + if len(query_shape) == 3: + head_size = ( + int(query_shape[2] / num_heads) + if isinstance(query_shape[2], int) + else f"{query_shape[2]}/{num_heads}" + ) + else: + head_size = query_shape[4] + + past_shape = self._try_get_shape(node, 6) + + if past_shape is not None: + if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int): + total_sequence_length = past_shape[2] + total_sequence_length + else: + total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" + + present_shape = [batch_size, num_heads, total_sequence_length, head_size] + + assert output_dtype is not None + if len(node.output) > 2 and node.output[1] and node.output[2]: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + + def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 + # Output 0 has shape (batch_size, 1, v_hidden_size) + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, 1, hidden_size) + # Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size) + + query_shape = self._get_shape(node, 0) + if query_shape is not None: + output_shape = query_shape + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + assert output_dtype is not None + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + if len(node.output) > 2 and node.output[1] and node.output[2]: + past_shape = self._try_get_shape(node, 5) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + def _infer_FastGelu(self, node): # noqa: N802 + 
self._propagate_shape_and_type(node) + + def _infer_Gelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_QuickGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_GemmFastGelu(self, node): # noqa: N802 + self._compute_matmul_shape(node) + + def _infer_GemmFloat8(self, node): # noqa: N802 + self._compute_matmul_shape(node) + + def _infer_LayerNormalization(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + if len(node.output) > 1: + axis = get_attribute(node, "axis") + if axis is None: + axis = -1 + x_shape = self._get_shape(node, 0) + if x_shape is not None: + rank = len(x_shape) + axis = handle_negative_axis(axis, rank) + mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] + mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16: + mean_dtype = onnx.TensorProto.FLOAT + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) + if len(node.output) > 2: + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape)) + + def _infer_LongformerAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_EmbedLayerNormalization(self, node): # noqa: N802 + input_ids_shape = self._get_shape(node, 0) + word_embedding_shape = self._get_shape(node, 2) + assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 + output_shape = [*input_ids_shape, word_embedding_shape[1]] + + word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) + + if len(node.output) > 1 and node.output[1]: + mask_index_shape = [input_ids_shape[0]] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) + + if len(node.output) > 2: + # Optional output of add before layer normalization is done + # shape is same as the output + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape)) + + def _infer_SkipLayerNormalization(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + # If the SkipLayerNormalization node contains the optional + # output for inference, infer the shape and type for it too + if len(node.output) > 3: + self._propagate_shape_and_type(node, 0, 3) + + def _infer_GroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_PagedAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_GroupQueryAttention(self, node): # noqa: N802 + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + past_shape = self._try_get_shape(node, 3) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + if node.input[1] != "" and node.input[2] != "": + self._propagate_shape_and_type(node, 0, 0) + else: + # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + assert node.input[1] == "" and node.input[2] == "" + num_heads = 
get_attribute(node, "num_heads") + kv_num_heads = get_attribute(node, "kv_num_heads") + query_shape = self._get_shape(node, 0) + if query_shape is not None: + hidden_size = query_shape[2] + if isinstance(hidden_size, int): + head_size = int(hidden_size / (num_heads + 2 * kv_num_heads)) + query_shape[2] = num_heads * head_size + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape)) + + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + + def _infer_BiasSplitGelu(self, node): # noqa: N802 + input_shape = self._get_shape(node, 0) + bias_shape = self._get_shape(node, 1) + if input_shape and bias_shape and isinstance(bias_shape[0], int): + output_shape = input_shape + output_shape[2] = int(bias_shape[0] / 2) + vi = self.known_vi_[node.output[0]] + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) + + def _infer_BiasAdd(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_RotaryEmbedding(self, node): # noqa: N802 + if len(node.output) == 1: + self._propagate_shape_and_type(node) + elif len(node.output) == 2: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output + elif len(node.output) == 3: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=1, output_index=1) + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output + + def _infer_PythonOp(self, node): # noqa: N802 + output_tensor_types = get_attribute(node, "output_tensor_types") + assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute." + output_tensor_ranks = get_attribute(node, "output_tensor_ranks") + assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute." + + from onnxruntime.capi._pybind_state import get_shape_inference_function + + func_name = get_attribute(node, "func_name").decode() + shape_inferer = get_shape_inference_function(func_name) + + # Set the context output separately. + # The first output is torch.autograd.Function''s context. + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) + + if shape_inferer is not None: + input_shapes = [] + input_dtypes = [] + for input_index in range(len(node.input)): + shape = self._get_shape(node, input_index) + input_shapes.append(shape) + input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type + input_dtypes.append(input_dtype) + output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) + assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( + f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " + f"but expected {len(node.output) - 1} outputs." 
+ ) + for i in range(len(node.output) - 1): + output_index = i + 1 + vi = self.known_vi_[node.output[output_index]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]) + ) + else: + # General shape inference for PythonOp. + # Outputs after torch.autograd.Function's context are tensors. + # We assume their ranks are fixed for different model inputs. + for i in range(len(node.output) - 1): + # Process the i-th tensor outputs. + vi = self.known_vi_[node.output[i + 1]] + sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) + shape = get_shape_from_sympy_shape(sympy_shape) + value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape) + vi.CopyFrom(value_info) + + def _propagate_shape_and_type(self, node, input_index=0, output_index=0): + shape = self._get_shape(node, input_index) + output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[output_index]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) + + def _is_none_dim(self, dim_value): + if type(dim_value) != str: # noqa: E721 + return False + if "unk__" not in dim_value: + return False + if dim_value in self.symbolic_dims_: + return False + return True + + def _is_shape_contains_none_dim(self, out_shape): + for out in out_shape: + if self._is_none_dim(out): + return out + return None + + def _infer_impl(self, start_sympy_data=None): + self.sympy_data_ = start_sympy_data or {} + self.out_mp_.graph.ClearField("value_info") + self._apply_suggested_merge(graph_input_only=True) + self.input_symbols_ = set() + for i in self.out_mp_.graph.input: + input_shape = get_shape_from_value_info(i) + if input_shape is None: + continue + + if is_sequence(i.type): + input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim + else: + input_dims = i.type.tensor_type.shape.dim + + for i_dim, dim in enumerate(input_shape): + if dim is None: + # some models use None for symbolic dim in input, replace it with a string + input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) + + self.input_symbols_.update([d for d in input_shape if type(d) == str]) # noqa: E721 + + for s in self.input_symbols_: + if s in self.suggested_merge_: + s_merge = self.suggested_merge_[s] + assert s_merge in self.symbolic_dims_ + self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] + else: + # Since inputs are not produced by other ops, we can assume positivity + self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) + # create a temporary ModelProto for single node inference + # note that we remove initializer to have faster inference + # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways + self.tmp_mp_ = onnx.ModelProto() + self.tmp_mp_.CopyFrom(self.out_mp_) + self.tmp_mp_.graph.ClearField("initializer") + + # compute prerequesite for node for topological sort + # node with subgraphs may have dependency on implicit inputs, which will affect topological sort + prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph + + def get_prereq(node): + names = {i for i in node.input if i} + subgraphs = [] + if node.op_type == "If": + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] + elif node.op_type in ["Loop", "Scan"]: + subgraphs = [get_attribute(node, "body")] + for g in subgraphs: + 
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                # remove subgraph inputs from g_prereq since those are local-only
+                for i in g.input:
+                    if i.name in names:
+                        names.remove(i.name)
+            return names
+
+        for n in self.tmp_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)}
+        if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
+            # Loop/Scan will have some graph output in graph inputs, so don't do topological sort
+            sorted_nodes = self.out_mp_.graph.node
+        else:
+            while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in self.out_mp_.graph.node:
+                    if (node.output[0] not in sorted_known_vi) and all(
+                        [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i]
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    [o.name in sorted_known_vi for o in self.out_mp_.graph.output]
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        for node in sorted_nodes:
+            assert all([i in self.known_vi_ for i in node.input if i])
+            self._onnx_infer_single_node(node)
+            known_aten_op = False
+            if node.op_type in self.dispatcher_:
+                self.dispatcher_[node.op_type](node)
+            elif node.op_type in ["ConvTranspose"]:
+                # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
+                # before adding symbolic compute for them
+                # mark the output type as UNDEFINED to allow guessing of rank
+                vi = self.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    # TODO: Is overload_name needed?
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        if aten_op_name in self.aten_op_dispatcher_:
+                            known_aten_op = True
+                            self.aten_op_dispatcher_[aten_op_name](node)
+                        break
+
+            if self.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)
+                for i, name in enumerate(node.input):
+                    logger.debug(
+                        "  Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "")
+                    )
+
+            # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
+            # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
+            if node.op_type in [
+                "Add",
+                "Sub",
+                "Mul",
+                "Div",
+                "MatMul",
+                "MatMulInteger",
+                "MatMulInteger16",
+                "Where",
+                "Sum",
+            ]:
+                vi = self.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        self._check_merged_dims(in_dims, allow_broadcast=True)
+
+            for i_o in range(len(node.output)):
+                # Special cases:
+                # 1) We do not care about the training related outputs of SkipLayerNormalization
+                # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
+                #    the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
+                #    contrib op
+                if (
+                    node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
+                ) and i_o in [1, 2]:
+                    continue
+                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                    # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
+                    # generated by `export_modules_as_functions`
+                    continue
+
+                vi = self.known_vi_[node.output[i_o]]
+                out_type = vi.type
+                out_type_kind = out_type.WhichOneof("value")
+
+                # do not process shape for non-tensors
+                if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]:
+                    if self.verbose_ > 2:
+                        if out_type_kind == "sequence_type":
+                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                            if seq_cls_type == "tensor_type":
+                                logger.debug(
+                                    "  {}: sequence of {} {}".format(
+                                        node.output[i_o],
+                                        str(get_shape_from_value_info(vi)),
+                                        onnx.TensorProto.DataType.Name(
+                                            vi.type.sequence_type.elem_type.tensor_type.elem_type
+                                        ),
+                                    )
+                                )
+                            else:
+                                logger.debug(f"  {node.output[i_o]}: sequence of {seq_cls_type}")
+                        else:
+                            logger.debug(f"  {node.output[i_o]}: {out_type_kind}")
+                    continue
+
+                out_shape = get_shape_from_value_info(vi)
+                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                if self.verbose_ > 2:
+                    logger.debug(
+                        "  {}: {} {}".format(
+                            node.output[i_o],
+                            str(out_shape),
+                            onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type),
+                        )
+                    )
+                    if node.output[i_o] in self.sympy_data_:
+                        logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
+
+                # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
+                if (
+                    out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
+                ) or out_type_undefined:
+                    if self.auto_merge_:
+                        if node.op_type in [
+                            "Add",
+                            "Sub",
+                            "Mul",
+                            "Div",
+                            "MatMul",
+                            "MatMulInteger",
+                            "MatMulInteger16",
+                            "Concat",
+                            "Where",
+                            "Sum",
+                            "Equal",
+                            "Less",
+                            "Greater",
+                            "LessOrEqual",
+                            "GreaterOrEqual",
+                            "Min",
+                            "Max",
+                        ]:
+                            shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                            if node.op_type in [
+                                "MatMul",
+                                "MatMulInteger",
+                                "MatMulInteger16",
+                            ]:
+                                if None in out_shape or self._is_shape_contains_none_dim(out_shape):
+                                    if None in out_shape:
+                                        idx = out_shape.index(None)
+                                    else:
+                                        idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
+                                    dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                    # only support auto merge for MatMul for dim < rank-2 when rank > 2
+                                    assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                    assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                        elif node.op_type == "Expand":
+                            # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
+                            shapes = [
+                                self._get_shape(node, 0),
+                                self._get_value(node, 1),
+                            ]
+                        else:
+                            shapes = []
+
+                        if shapes:
+                            for idx in range(len(out_shape)):
+                                if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
+                                    continue
+                                # note that the broadcasting rule aligns from right to left
+                                # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                if len(dim_idx) > 0:
+                                    self._add_suggested_merge(
+                                        [
+                                            s[i] if is_literal(s[i]) else str(s[i])
+                                            for s, i in zip(shapes, dim_idx)
+                                            if i >= 0
+                                        ]
+                                    )
+                            self.run_ = True
+                        else:
+                            self.run_ = False
+                    else:
+                        self.run_ = False
+
+                    # create new dynamic dims for ops not handled by symbolic shape inference
+                    if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op:
+                        is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                        if is_unknown_op:
+                            # unknown op to ONNX, maybe from higher opset or other domain
+                            # only guess the output rank from input 0 when using guess_output_rank option
+                            out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1
+                        else:
+                            # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
+                            out_rank = len(out_shape)
+
+                        if out_rank >= 0:
+                            new_shape = self._new_symbolic_shape(out_rank, node, i_o)
+                            if out_type_undefined:
+                                # guess output data type from input vi if not defined
+                                out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                            else:
+                                # otherwise, use original data type
+                                out_dtype = vi.type.tensor_type.elem_type
+                            vi.CopyFrom(
+                                helper.make_tensor_value_info(
+                                    vi.name,
+                                    out_dtype,
+                                    get_shape_from_sympy_shape(new_shape),
+                                )
+                            )
+
+                            if self.verbose_ > 0:
+                                if is_unknown_op:
+                                    logger.debug(
+                                        "Possible unknown op: {} node: {}, guessing {} shape".format(
+                                            node.op_type, node.name, vi.name
+                                        )
+                                    )
+                                if self.verbose_ > 2:
+                                    logger.debug(
+                                        "  {}: {} {}".format(
+                                            node.output[i_o],
+                                            str(new_shape),
+                                            vi.type.tensor_type.elem_type,
+                                        )
+                                    )
+
+                            self.run_ = True
+                            continue  # continue the inference after guess, no need to stop as no merge is needed
+
+                    if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
+                        logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                        logger.debug("node inputs:")
+                        for i in node.input:
+                            if i in self.known_vi_:
+                                logger.debug(self.known_vi_[i])
+                            else:
+                                logger.debug(f"not in known_vi_ for {i}")
+                        logger.debug("node outputs:")
+                        for o in node.output:
+                            if o in self.known_vi_:
+                                logger.debug(self.known_vi_[o])
+                            else:
+                                logger.debug(f"not in known_vi_ for {o}")
+                        if self.auto_merge_ and not out_type_undefined:
+                            logger.debug("Merging: " + str(self.suggested_merge_))
+                        return False
+
+        self.run_ = False
+        return True
+
+    def _update_output_from_vi(self):
+        for output in self.out_mp_.graph.output:
+            if output.name in self.known_vi_:
+                output.CopyFrom(self.known_vi_[output.name])
+
+    @staticmethod
+    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only support models of onnx opset 7 and above.")
+            return None
+        symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
+        all_shapes_inferred = False
+        symbolic_shape_inference._preprocess(in_mp)
+        while symbolic_shape_inference.run_:
+            all_shapes_inferred = symbolic_shape_inference._infer_impl()
+        symbolic_shape_inference._update_output_from_vi()
+        if not all_shapes_inferred:
+            onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
+            raise Exception("Incomplete symbolic shape inference")
+        return symbolic_shape_inference.out_mp_
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True, help="The input model file")
+    parser.add_argument("--output", help="The output model file")
+    parser.add_argument(
+        "--auto_merge",
+        help="Automatically merge symbolic dims when confliction happens",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--int_max",
+        help="maximum value for integer to be treated as boundless for ops like slice",
+        type=int,
+        default=2**31 - 1,
+    )
+    parser.add_argument(
+        "--guess_output_rank",
+        help="guess output rank to be the same as input 0 for unknown ops",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--save_as_external_data",
+        help="Saving an ONNX model to external data",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--all_tensors_to_one_file",
+        help="Saving all the external data to one file",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--external_data_location",
+        help="The file location to save the external file",
+        default="./",
+    )
+    parser.add_argument(
+        "--external_data_size_threshold",
+        help="The size threshold for external data",
+        type=int,
+        default=1024,
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    logger.info("input model: " + args.input)
+    if args.output:
+        logger.info("output model " + args.output)
+    logger.info("Doing symbolic shape inference...")
+    out_mp = SymbolicShapeInference.infer_shapes(
+        onnx.load(args.input),
+        args.int_max,
+        args.auto_merge,
+        args.guess_output_rank,
+        args.verbose,
+    )
+    if args.output and out_mp:
+        if args.save_as_external_data:
+            onnx.save_model(
+                out_mp,
+                args.output,
+                save_as_external_data=True,
+                all_tensors_to_one_file=args.all_tensors_to_one_file,
+                location=args.external_data_location,
+                size_threshold=args.external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(out_mp, args.output)
+        logger.info("Done!")
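
Note (not part of the patch): the script added above can be driven through the argparse CLI defined in parse_arguments, or programmatically via the static SymbolicShapeInference.infer_shapes method. A minimal usage sketch follows; the module name `symbolic_shape_infer` and the model file paths are assumptions for illustration and should be adjusted to wherever the file actually lands in this repository.

    import onnx
    from symbolic_shape_infer import SymbolicShapeInference  # assumed module name

    model = onnx.load("model.onnx")  # hypothetical input path
    inferred = SymbolicShapeInference.infer_shapes(
        model,
        int_max=2**31 - 1,        # values at or above this are treated as boundless (e.g. Slice ends)
        auto_merge=True,          # merge conflicting symbolic dims instead of stopping
        guess_output_rank=False,  # do not guess output rank for unknown ops
        verbose=0,
    )
    if inferred is not None:      # infer_shapes returns None for models below opset 7
        onnx.save(inferred, "model_inferred.onnx")  # hypothetical output path

The equivalent command-line invocation would look like `python symbolic_shape_infer.py --input model.onnx --output model_inferred.onnx --auto_merge` (script name again assumed), with the external-data flags available for models too large to save in a single protobuf.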