Skip to content

Commit

Permalink
Support audio attention mask for multiple audio file preprocessing for Phi4 model;
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Mar 4, 2025
1 parent d374c03 commit efc86e7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
3 changes: 2 additions & 1 deletion shared/api/runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,8 @@ class OrtxRunner {
if (shape != ts[axis]->Shape()) {
is_same_shape = false;
auto dtype = ts[axis]->Type();
if (dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 && dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
if (dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 &&
dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) {
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
}
if (IsGreaterShape(ts[axis]->Shape(), shape)) {
Expand Down
25 changes: 11 additions & 14 deletions shared/api/speech_features.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,12 @@ class SpeechFeatures {
return stft_norm_.Compute(pcm, n_fft_, hop_length_, {fft_win_.data(), fft_win_.size()}, n_fft_, stft_norm);
}

OrtxStatus SpeechLibSTFTNorm(const ortc::Tensor<float>& pcm,
ortc::Tensor<float>& stft_norm,
ortc::Tensor<int64_t>& audio_frames) {
constexpr int64_t feat_stride = 1;
OrtxStatus SpeechLibSTFTNorm(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& stft_norm) {
const float preemphasis = 0.97f;
// # Spec 1: SpeechLib cut remaining sample insufficient for a hop
// n_batch = (wav.shape[0] - win_length) // hop_length + 1
auto pcm_length = pcm.Shape()[1];
auto n_batch = (pcm_length - frame_length_) / hop_length_ + 1;
audio_frames.Allocate({1})[0] = n_batch * feat_stride;
auto pcm_data = pcm.Data();
dlib::matrix<float> dm_x = dlib::mat(pcm_data, 1, pcm_length);

Expand Down Expand Up @@ -609,26 +605,24 @@ class Phi4AudioEmbed {
OrtxStatus Compute(const ortc::Tensor<float>& pcm,
const ortc::Tensor<int64_t>& sr,
ortc::Tensor<float>& ts_logmel,
ortc::Tensor<int64_t>& audio_frames,
ortc::Tensor<bool>& audio_attention_mask,
ortc::Tensor<int64_t>& embeded_size) {
int64_t sr_val = sr.Data()[0];
ortc::Tensor<float> stft_norm(&CppAllocator::Instance());
ortc::Tensor<int64_t> num_audio_frames(&CppAllocator::Instance());
SpeechFeatures stft_normal;
stft_normal.Init(sr_val == 8000? stft_normal_8k_attrs_: stft_normal_attrs_);
auto status = stft_normal.SpeechLibSTFTNorm(pcm, stft_norm, num_audio_frames);
auto status = stft_normal.SpeechLibSTFTNorm(pcm, stft_norm);
if (!status.IsOk()) {
return status;
}

SpeechLibLogMel logmel;
// already checked in Init

// Currently we only support 8k and 16k Hz sampling rate.
if (sr_val != 8000 && sr_val != 16000){
return OrtxStatus(kOrtxErrorNotImplemented, "Currently only 8k and 16k Hz sampling rate is supported. Please resample your audio file with unsupported audio sampling rate: " + sr_val);
return {kOrtxErrorInvalidArgument, "Only 8k and 16k Hz target sampling rate is supported."};
}

SpeechLibLogMel logmel;
// attributes already are verified in Init method
logmel.Init(sr_val == 8000 ? logmel_8k_attrs_: logmel_attrs_);
status = logmel.Compute(stft_norm, ts_logmel);
if (!status.IsOk()) {
Expand All @@ -648,10 +642,13 @@ class Phi4AudioEmbed {
return result
*/
auto audio_frames = ts_logmel.Shape()[0];
auto embedded_size_data = embeded_size.Allocate({1});
embedded_size_data[0] = std::ceil(static_cast<float>(ts_logmel.Shape()[0]) / audio_compression_rate_);
embedded_size_data[0] = std::ceil(static_cast<float>(audio_frames) / audio_compression_rate_);

audio_frames.Allocate({1})[0] = num_audio_frames.Data()[0];
constexpr int64_t feat_stride = 1;
auto attention = audio_attention_mask.Allocate({audio_frames * feat_stride});
std::memset(attention, 1, audio_frames * feat_stride * sizeof(bool));
return status;
}

Expand Down
29 changes: 15 additions & 14 deletions test/pp_api_test/test_feature_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,17 @@ TEST(ExtractorTest, TestPhi4AudioFeatureExtraction) {
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 1344, 80}));

tensor.reset();
const int64_t* audio_frames{};
const int64_t* audio_frames_shape{};
size_t audio_frames_num_dims;
const bool* audio_attention_mask{};
const int64_t* audio_mask_shape{};
size_t audio_mask_dims;
err = OrtxTensorResultGetAt(result.get(), 1, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&audio_frames), &audio_frames_shape, &audio_frames_num_dims);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&audio_attention_mask), &audio_mask_shape, &audio_mask_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(audio_frames_shape, audio_frames_shape + audio_frames_num_dims), std::vector<int64_t>({3, 1}));
const size_t num_elements = std::accumulate(audio_frames_shape, audio_frames_shape + audio_frames_num_dims, 1, std::multiplies<size_t>());
ASSERT_EQ(std::vector<int64_t>(audio_frames, audio_frames + num_elements), std::vector<int64_t>({1098, 1332, 1344}));
ASSERT_EQ(std::vector<int64_t>(audio_mask_shape, audio_mask_shape + audio_mask_dims), std::vector<int64_t>({3, 1344}));
ASSERT_EQ(std::count(audio_attention_mask + 0 * 1344, audio_attention_mask + 1 * 1344, true), 1098);
ASSERT_EQ(std::count(audio_attention_mask + 1 * 1344, audio_attention_mask + 2 * 1344, true), 1332);
ASSERT_EQ(std::count(audio_attention_mask + 2 * 1344, audio_attention_mask + 3 * 1344, true), 1344);

tensor.reset();
err = OrtxTensorResultGetAt(result.get(), 2, tensor.ToBeAssigned());
Expand Down Expand Up @@ -109,16 +110,16 @@ TEST(ExtractorTest, TestPhi4AudioFeatureExtraction8k) {
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({1, 2938, 80}));

tensor.reset();
const int64_t* audio_frames{};
const int64_t* audio_frames_shape{};
size_t audio_frames_num_dims;
const bool* audio_attention_mask{};
const int64_t* audio_mask_shape{};
size_t audio_mask_dims{};
err = OrtxTensorResultGetAt(result.get(), 1, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&audio_frames), &audio_frames_shape, &audio_frames_num_dims);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&audio_attention_mask), &audio_mask_shape, &audio_mask_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(audio_frames_shape, audio_frames_shape + audio_frames_num_dims), std::vector<int64_t>({1, 1}));
const size_t num_elements = std::accumulate(audio_frames_shape, audio_frames_shape + audio_frames_num_dims, 1, std::multiplies<size_t>());
ASSERT_EQ(std::vector<int64_t>(audio_frames, audio_frames + num_elements), std::vector<int64_t>({2938}));
ASSERT_EQ(std::vector<int64_t>(audio_mask_shape, audio_mask_shape + audio_mask_dims), std::vector<int64_t>({1, 2938}));
const size_t num_elements = std::count(audio_attention_mask, audio_attention_mask + 2938, true);
ASSERT_EQ(num_elements, 2938);

tensor.reset();
err = OrtxTensorResultGetAt(result.get(), 2, tensor.ToBeAssigned());
Expand Down

0 comments on commit efc86e7

Please sign in to comment.