diff --git a/operators/tokenizer/sentencepiece_tokenizer.cc b/operators/tokenizer/sentencepiece_tokenizer.cc
index 42faa2493..e0cbbe679 100644
--- a/operators/tokenizer/sentencepiece_tokenizer.cc
+++ b/operators/tokenizer/sentencepiece_tokenizer.cc
@@ -83,32 +83,33 @@ OrtStatusPtr KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
         token_indices.push_back(ort_extensions::narrow<int32_t>(str_input[i].length()));
       }
-
-      if (fairseq.has_value() && (*fairseq)) {
-        // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
-        //
-        // Original fairseq vocab and spm vocab must be "aligned":
-        // Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
-        // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
-        // fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
-        // spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
-        //
-        // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
-        // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
-        std::for_each(content.begin(), content.end(), [](int& n) {
-          if (n == 0) {  // '<unk>': 0 -> 3
-            n = 3;
-          } else if (n == 1) {  // '<s>': 1 -> 0
-            n = 0;
-          } else if (n != 2) {  // '</s>': 2 -> 2, '<*>': x -> x + 1
-            n++;
-          }
-        });
-      }
     }
   }
   instance_indices.push_back(content.size());
 
+  // Patch fairseq indices
+  if (fairseq.has_value() && (*fairseq) && !add_rev) {
+    // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
+    //
+    // Original fairseq vocab and spm vocab must be "aligned":
+    // Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+    // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+    // fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+    // spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+    //
+    // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
+    // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
+    std::for_each(content.begin(), content.end(), [](int& n) {
+      if (n == 0) {  // '<unk>': 0 -> 3
+        n = 3;
+      } else if (n == 1) {  // '<s>': 1 -> 0
+        n = 0;
+      } else if (n != 2) {  // '</s>': 2 -> 2, '<*>': x -> x + 1
+        n++;
+      }
+    });
+  }
+
   // Setup output
   std::vector<int64_t> size_content(1);
   size_content[0] = content.size();
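
For reference, the remapping this patch moves out of the per-instance loop can be exercised in isolation. The sketch below is not part of the operator: it assumes a plain std::vector<int> of SentencePiece ids (the sample values are made up, not from a real model run) and applies the same '<unk>'/'<s>' swap and +1 shift that the for_each above performs.

// Standalone sketch of the fairseq id remapping above (assumed sample ids).
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // SPM ids: '<unk>' = 0, '<s>' = 1, '</s>' = 2, real tokens start at 3.
  std::vector<int> content = {1, 3, 7, 0, 2};

  // Same lambda as in the patch: swap '<unk>'/'<s>' and shift the rest by +1.
  std::for_each(content.begin(), content.end(), [](int& n) {
    if (n == 0) {         // '<unk>': 0 -> 3
      n = 3;
    } else if (n == 1) {  // '<s>': 1 -> 0
      n = 0;
    } else if (n != 2) {  // '</s>': 2 -> 2, '<*>': x -> x + 1
      n++;
    }
  });

  // Prints "0 4 8 3 2": fairseq ids for '<s>', token 3, token 7, '<unk>', '</s>'.
  for (int id : content) {
    std::printf("%d ", id);
  }
  std::printf("\n");
  return 0;
}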