Skip to content

Commit

Permalink
Add string strip text operator (#460)
Browse files Browse the repository at this point in the history
* add string strip text operator

---------

Co-authored-by: Wenbing Li <[email protected]>
  • Loading branch information
aidanryan-msft and wenbingl authored May 30, 2023
1 parent 93f239c commit 30eb7af
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 8 deletions.
54 changes: 54 additions & 0 deletions operators/text/string_strip.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_strip.hpp"
#include "string_tensor.h"
#include <vector>
#include <cmath>
#include <algorithm>

// Set of characters treated as whitespace when stripping: space, tab, newline,
// carriage return, form feed and vertical tab.
// Declared `static ... const` so the pointer cannot be reseated at runtime and
// the symbol stays local to this translation unit (no clash with a same-named
// global defined in another operator source file).
static const char* const WHITE_SPACE_CHARS = " \t\n\r\f\v";

// Constructs the kernel. Both arguments are forwarded unchanged to BaseKernel;
// this kernel adds no construction logic of its own.
KernelStringStrip::KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info)
    : BaseKernel(api, info) {}

// Strips leading and trailing whitespace (any character in WHITE_SPACE_CHARS)
// from every string of input 0 and writes the results to output 0, which is
// created with the same dimensions as the input.
//
// Strings that are empty or consist solely of whitespace become "".
//
// @param context  kernel context used to fetch input 0 and create output 0.
void KernelStringStrip::Compute(OrtKernelContext* context) {
  // Setup inputs: X receives the input tensor's strings via the helper
  // (GetTensorMutableDataString is declared elsewhere in the project).
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  std::vector<std::string> X;
  GetTensorMutableDataString(api_, ort_, context, input_X, X);

  // Trim each string in place.
  for (std::string& text : X) {
    const size_t first = text.find_first_not_of(WHITE_SPACE_CHARS);
    if (first == std::string::npos) {
      // No non-whitespace character at all: the stripped result is empty.
      // (Previously such strings were passed through unchanged.)
      text.clear();
      continue;
    }

    // `first` exists, so find_last_not_of is guaranteed to succeed as well.
    const size_t last = text.find_last_not_of(WHITE_SPACE_CHARS);
    // Drop the tail before the head so `last` stays a valid index.
    text.erase(last + 1);
    text.erase(0, first);
  }

  // Fills the output
  OrtTensorDimensions dimensions(ort_, input_X);
  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
  FillTensorDataString(api_, ort_, context, X, output);
}

// Returns the operator name, "StringStrip".
const char* CustomOpStringStrip::GetName() const {
  return "StringStrip";
}

// Returns the number of inputs the operator declares: exactly one.
size_t CustomOpStringStrip::GetInputTypeCount() const {
  const size_t input_count = 1;
  return input_count;
}

// Returns the element type of the input tensor. The index is ignored: the
// answer is always the string element type.
ONNXTensorElementDataType CustomOpStringStrip::GetInputType(size_t /*index*/) const {
  const ONNXTensorElementDataType input_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return input_type;
}

// Returns the number of outputs the operator declares: exactly one.
size_t CustomOpStringStrip::GetOutputTypeCount() const {
  const size_t output_count = 1;
  return output_count;
}

// Returns the element type of the output tensor. The index is ignored: the
// answer is always the string element type.
ONNXTensorElementDataType CustomOpStringStrip::GetOutputType(size_t /*index*/) const {
  const ONNXTensorElementDataType output_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return output_type;
}
20 changes: 20 additions & 0 deletions operators/text/string_strip.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "string_utils.h"

// Kernel for the StringStrip operator. Derives from BaseKernel and declares
// a constructor plus the Compute entry point; definitions live in the
// corresponding source file.
struct KernelStringStrip : BaseKernel {
  // Builds the kernel from the ORT API handle and the kernel info.
  KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info);

  // Runs the operator for one invocation using the given kernel context.
  void Compute(OrtKernelContext* context);
};

// Custom-op descriptor that pairs the StringStrip operator metadata (name,
// input/output counts and element types) with KernelStringStrip through
// OrtW::CustomOpBase.
struct CustomOpStringStrip : OrtW::CustomOpBase<CustomOpStringStrip, KernelStringStrip> {
  // Operator name.
  const char* GetName() const;

  // Input signature: how many inputs, and the element type of input `index`.
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;

  // Output signature: how many outputs, and the element type of output `index`.
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
14 changes: 7 additions & 7 deletions operators/text/text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "text/string_join.hpp"
#include "text/string_lower.hpp"
#include "text/string_split.hpp"
#include "text/string_strip.hpp"
#include "text/string_to_vector.hpp"
#include "text/string_upper.hpp"
#include "text/vector_to_string.hpp"
Expand All @@ -17,15 +18,14 @@
#if defined(ENABLE_RE2_REGEX)
#include "text/re2_strings/string_regex_replace.hpp"
#include "text/re2_strings/string_regex_split.hpp"
#endif // ENABLE_RE2_REGEX
#endif // ENABLE_RE2_REGEX


FxLoadCustomOpFactory LoadCustomOpClasses_Text =
LoadCustomOpClasses<CustomOpClassBegin,
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
LoadCustomOpClasses<CustomOpClassBegin,
#if defined(ENABLE_RE2_REGEX)
CustomOpStringRegexReplace,
CustomOpStringRegexSplitWithOffsets,
#endif // ENABLE_RE2_REGEX
#endif // ENABLE_RE2_REGEX
CustomOpRaggedTensorToDense,
CustomOpRaggedTensorToSparse,
CustomOpStringRaggedTensorToDense,
Expand All @@ -38,10 +38,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Text =
CustomOpStringMapping,
CustomOpMaskedFill,
CustomOpStringSplit,
CustomOpStringStrip,
CustomOpStringToVector,
CustomOpVectorToString,
CustomOpStringLength,
CustomOpStringConcat,
CustomOpStringECMARegexReplace,
CustomOpStringECMARegexSplitWithOffsets
>;
CustomOpStringECMARegexSplitWithOffsets>;
37 changes: 36 additions & 1 deletion test/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,22 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
return model


def _create_test_model_string_strip(prefix, domain='ai.onnx.contrib'):
    """Build a two-node ONNX test model: Identity followed by StringStrip.

    The graph takes a string tensor named ``input_1`` with two unnamed
    dimensions, passes it through ``Identity`` and then through the
    ``<prefix>StringStrip`` node in ``domain``, and exposes the result as
    ``customout`` (same element type and rank).

    :param prefix: text prepended to the ``StringStrip`` op type name.
    :param domain: operator domain of the StringStrip node.
    :return: the model produced by ``make_onnx_model`` for this graph.
    """
    identity_node = helper.make_node('Identity', ['input_1'], ['identity1'])
    strip_node = helper.make_node(
        '%sStringStrip' % prefix,
        ['identity1'], ['customout'],
        domain=domain)

    graph_input = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None, None])
    graph_output = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, None])

    graph = helper.make_graph(
        [identity_node, strip_node], 'test0', [graph_input], [graph_output])
    return make_onnx_model(graph)

def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
nodes = []
nodes.append(helper.make_node('Identity', ['input'], ['id1']))
Expand Down Expand Up @@ -436,6 +452,26 @@ def test_check_types(self):
for t in type_list:
self.assertIn(t, def_list)

def test_string_strip_cc(self):
    """StringStrip removes the surrounding spaces of " a b c "."""
    session_options = _ort.SessionOptions()
    session_options.register_custom_ops_library(_get_library_path())

    model = _create_test_model_string_strip('')
    self.assertIn('op_type: "StringStrip"', str(model))

    session = _ort.InferenceSession(model.SerializeToString(), session_options)
    padded = np.array([[" a b c "]])
    outputs = session.run(None, {'input_1': padded})

    expected = np.array([["a b c"]]).tolist()
    self.assertEqual(outputs[0].tolist(), expected)

def test_string_strip_cc_empty(self):
    """StringStrip leaves an empty string as an empty string."""
    session_options = _ort.SessionOptions()
    session_options.register_custom_ops_library(_get_library_path())

    model = _create_test_model_string_strip('')
    self.assertIn('op_type: "StringStrip"', str(model))

    session = _ort.InferenceSession(model.SerializeToString(), session_options)
    empty_input = np.array([[""]])
    outputs = session.run(None, {'input_1': empty_input})

    expected = np.array([[""]]).tolist()
    self.assertEqual(outputs[0].tolist(), expected)

def test_string_upper_cc(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
Expand Down Expand Up @@ -1151,7 +1187,6 @@ def _CreateTable(vocab, num_oov=1):
res.__len__ = lambda self: len(vocab)

vocab_table = _CreateTable(["want", "##want", "##ed", "wa", "un", "runn", "##ing"])

text = tf.convert_to_tensor(["unwanted running", "unwantedX running"], dtype=tf.string)
try:
tf_tokens, tf_rows, tf_begins, tf_ends = (
Expand Down

0 comments on commit 30eb7af

Please sign in to comment.