Skip to content

Commit

Permalink
Standardize the inputs for ONNX STFT op for Whisper model (#681)
Browse files Browse the repository at this point in the history
* Standardize the inputs for ONNX STFT op for Whisper model

* undo the format change

* Update _torch_cvt.py
  • Loading branch information
wenbingl authored Mar 29, 2024
1 parent 5aefc7e commit 00a594f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions onnxruntime_extensions/_torch_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def _to_onnx_stft(onnx_model, n_fft):

make_node = onnx.helper.make_node
replaced_nodes = [
make_node('Constant', inputs=[], outputs=['const_minus_1_output_0'], name='const_minus_1',
value=numpy_helper.from_array(np.array([-1], dtype='int64'))),
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
value=numpy_helper.from_array(np.array([0,
n_fft // 2, 0,
Expand All @@ -144,20 +146,23 @@ def _to_onnx_stft(onnx_model, n_fft):
make_node('Pad',
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
outputs=['pad_1_output_0'], mode='reflect'),
make_node('Unsqueeze',
inputs=['pad_1_output_0', 'const_minus_1_output_0'],
outputs=['unsqueeze_1_output_0'],
name='unsqueeze_1'),
make_node('STFT',
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
inputs=['unsqueeze_1_output_0', stft_norm_node.input[2],
stft_norm_node.input[3], stft_norm_node.input[4]],
outputs=['stft_output_0'], name='stft', onesided=1),
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
perm=[0, 2, 1, 3]),
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
Expand Down

0 comments on commit 00a594f

Please sign in to comment.