
LLM models with TorchScript backends #74

Open
Devjiu opened this issue Jan 31, 2024 · 0 comments
Devjiu commented Jan 31, 2024

Many Hugging Face models can be loaded with the torchscript=True option:

model = AutoModelForCausalLM.from_pretrained(model_name, torchscript=True, torch_dtype=self.dtype)

According to the guide https://huggingface.co/docs/transformers/torchscript, tracing requires the following steps:

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
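
For completeness, tokens_tensor and segments_tensors above are just example inputs for the trace; a minimal sketch of building them, roughly following the same guide (the sentence and the segment split are arbitrary placeholders):

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Any representative sentence works as the example input for tracing.
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Token ids and segment (token type) ids in the shape BertModel.forward expects,
# matching the two positional inputs passed to torch.jit.trace above.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0] * 7 + [1] * (len(indexed_tokens) - 7)

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])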

We currently call the model's generate method, not forward, so to pass it into torch.jit.trace a wrapper like this can be used:

import torch

class Wrapper(torch.nn.Module):
    """Exposes generate() through forward() so the model can be traced."""

    def __init__(self, llm_net):
        super().__init__()
        self.net = llm_net
        self.args = []
        self.kwargs = {}

    def set_args(self, *args, **kwargs):
        # Generation arguments are captured once and replayed on every forward call.
        self.args = args
        self.kwargs = kwargs

    def forward(self, input_ids):
        print("passing args: ", self.args, " kwargs: ", self.kwargs)
        return self.net.generate(input_ids, *self.args, **self.kwargs)
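
For context, a minimal sketch of how this wrapper is wired up and traced; the model name, prompt, and generation kwargs below are placeholders, not the exact configuration used here. This is the setup that then hits the error shown below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-j-6b"  # placeholder; any causal LM loaded with torchscript=True
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torchscript=True)
model.eval()

wrapper = Wrapper(model)
# Example generation arguments; these are the kwargs forwarded to generate().
wrapper.set_args(max_new_tokens=128, do_sample=False)

input_ids = tokenizer("There are several ways to travel", return_tensors="pt").input_ids
traced = torch.jit.trace(wrapper, [input_ids])
torch.jit.save(traced, "traced_llm.pt")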

With this approach, however, we get the following error:

ValueError: `prompt_length_to_skip` has to be a positive integer, but is 11

The trace can also be created without any generation args, but then it produces incorrect output:

localdisk/dmitriim/miniconda3/envs/test_static/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py:132: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return any(criteria(input_ids, scores) for criteria in self)
/localdisk/dmitriim/miniconda3/envs/test_static/lib/python3.11/site-packages/transformers/models/gptj/modeling_gptj.py:785: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_ids.shape[1] > past_length:
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Warmup started
input size:  torch.Size([1, 11])
output size:  torch.Size([1, 20])  zero:  tensor([ 1858,   389,  1811,  2842,   284,  3067,    11,   475,   616, 12507,
          318,   416,  4512,    13,   314,  1842,   262,  7002,    11,   262])
Warmup done
input size:  torch.Size([1, 23])
output size:  torch.Size([1, 32])  zero:  tensor([ 4342,   318,   257,  1621,   546,   257,  1048,   326,  1064,   503,
          339,   373,  8197,    25,   530,  1110,  1310,  5045,  1820,   373,
         2045,   832,  1468,  4590, 16788,   290,   339,  1043,   257,  4286,
          286,   465])

expected:

Warmup started
input size:  torch.Size([1, 11])
output size:  torch.Size([1, 139])  zero:  tensor([ 1858,   389,  1811,  2842,   284,  3067,    11,   475,   616, 12507,
          318,   416,  4512,    13,   632,   447,   247,    82,  7026,    11,
          340,   447,   247,    82,  3049,    11,   340,   447,   247,    82,
         6792,    11,   340,   447,   247,    82,  3338,   290,   340,   447,
          247,    82, 34286,    12, 13120,    13,   198,   198,    40,   447,
          247,   303,   587, 16574,   416,  4512,  1201,   314,   373,   257,
         5141,    11,   290,   314,   447,   247,   303,  1464,  6151,   340,
           13,   632,   447,   247,    82,   257,  1049,   835,   284,   651,
          284,   760,   257,  1499,    11,   290,   340,   447,   247,    82,
          257,  1049,   835,   284,   651,   284,   760,  3511,    13,   198,
          198,    40,   447,   247,   303,   587, 16574,   416,  4512,  1201,
          314,   373,   257,  5141,    11,   290,   314,   447,   247,   303,
         1464,  6151,   340,    13,   632,   447,   247,    82,   257,  1049,
          835,   284,   651,   284,   760,   257,  1499,    11,   290])
Warmup done
input size:  torch.Size([1, 23])
output size:  torch.Size([1, 151])  zero:  tensor([ 4342,   318,   257,  1621,   546,   257,  1048,   326,  1064,   503,
          339,   373,  8197,    25,   530,  1110,  1310,  5045,  1820,   373,
         2045,   832,  1468,  4590, 16788,   290,   339,  1625,  1973,   257,
         4286,   286,   257,  1310,  2933,   326,  3114,   655,   588,   683,
           13,   679,  1965,   465,  2802,   546,   340,   290,   673,  1297,
          683,   326,   339,   373,  8197,   290,   326,   339,   373,  4642,
          319,   257,  5318,   287,   262,  3504,   286, 12062,    13,   198,
          198, 22253,  5045,  1820,  1422,   470,  1975,   607,    11,   523,
          339,  1965,   465,  2988,   546,   340,    13,  2399,  2988,  1297,
          683,   326,   339,   373,  8197,   290,   326,   339,   373,  4642,
          319,   257,  5318,   287,   262,  3504,   286, 12062,    13,   198,
          198, 22253,  5045,  1820,  1422,   470,  1975,   683,    11,   523,
          339,  1965,   465,  2802,   546,   340,    13,  1375,  1297,   683,
          326,   339,   373,  8197,   290,   326,   339,   373,  4642,   319,
          257,  5318,   287,   262,  3504,   286, 12062,    13,   198,   198,
        22253])

So the TorchScripted model differs from the original one.

Passing more arguments triggers internal errors during TorchScript conversion. Also, according to this discussion: https://discuss.pytorch.org/t/converting-gpt-2-to-torchscript/175593, PyTorch community admins recommend avoiding TorchScript for models with data-dependent control flow, which describes almost any LLM.
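
To illustrate why, here is a toy example unrelated to any particular model: torch.jit.trace only records the ops executed for the example input, so a loop whose length depends on the data (as the generation loop and stopping criteria do) gets unrolled into a fixed number of steps:

import torch

def generate_like(x):
    # Data-dependent loop: run until the running sum exceeds a threshold,
    # similar in spirit to generate() stopping on EOS / max-length criteria.
    out = x
    while out.sum() < 10:
        out = out + 1
    return out

# Traced with an input that needs 5 iterations, so 5 additions are baked in.
traced = torch.jit.trace(generate_like, torch.zeros(2))

# Eager result: the loop never runs (sum is already >= 10) -> tensor([8., 8.])
# Traced result: the frozen 5 additions are replayed        -> tensor([13., 13.])
print(generate_like(torch.full((2,), 8.0)))
print(traced(torch.full((2,), 8.0)))

Tracing this function also emits the same "Converting a tensor to a Python boolean" TracerWarning seen in the log above.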

So the question is: do we need LLM models on TorchScript?
Do we have TorchScript oneDNN reference performance numbers for LLM models?
