
For better numerical accuracy in LayerNorm #518

Open
nhamanasu opened this issue Jan 9, 2025 · 2 comments

Comments

@nhamanasu

nhamanasu commented Jan 9, 2025

🚀 The feature, motivation and pitch

In the implementation of LayerNorm, I wonder whether the var computation may produce a wrong value when BLOCK_SIZE is not the same as the feature dimension (in other words, when mask has some False elements).

My intuition is that the padded elements (where mask == False, loaded as 0.0) become -mean after the mean is subtracted, instead of staying 0.0, and var then sums over them as well.

I think we should explicitly zero those positions with tl.where or something similar:

var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
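A minimal pure-Python sketch (hypothetical values, not the actual kernel) of the concern above: when the block is padded with zeros, the padded slot contributes (0.0 - mean)**2 to the variance sum, while zeroing it first, in the spirit of tl.where(mask, X_row - mean, 0.0), recovers the correct value.

```python
# Hypothetical row: feature dim 3, block padded up to 4
n_cols, BLOCK_SIZE = 3, 4
x_block = [1.0, 2.0, 3.0, 0.0]  # tl.load with other=0.0 fills the tail with zeros
mask = [i < n_cols for i in range(BLOCK_SIZE)]
mean = sum(x_block[:n_cols]) / n_cols  # 2.0

# Buggy: the padded slot contributes (0.0 - mean)**2 to the sum
var_buggy = sum((v - mean) ** 2 for v in x_block) / n_cols

# Fixed: zero the padded slots first, mirroring tl.where(mask, X_row - mean, 0.0)
var_fixed = sum(((v - mean) if m else 0.0) ** 2
                for v, m in zip(x_block, mask)) / n_cols

print(var_buggy, var_fixed)  # 2.0 vs. the correct 0.666...
```

Here the true variance of [1, 2, 3] is 2/3, but summing over the padded block inflates it to 2.0.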

Alternatives

No response

Additional context

No response

@ByronHsu
Collaborator

In what case is BLOCK_SIZE != feature dim?

@nhamanasu
Author

Thank you for the comment. Please see the discussion in #519.

To put it simply, BLOCK_SIZE must be a power of 2 in Triton, so any feature dimension that is not itself a power of 2 gives BLOCK_SIZE != feat_dim, e.g., feat_dim == 768 or 1536.
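To illustrate, here is the power-of-2 rounding reproduced in pure Python (the real kernel would use triton.next_power_of_2): common transformer hidden sizes like 768 and 1536 get padded, so the mask has False elements.

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of 2 >= n
    return 1 << (n - 1).bit_length()

for feat_dim in (768, 1536, 4096):
    block = next_power_of_2(feat_dim)
    print(feat_dim, "->", block, "padded" if block != feat_dim else "exact")
# 768 -> 1024 padded, 1536 -> 2048 padded, 4096 -> 4096 exact
```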
