Commit

Merge branch 'main' into jamba-test
yubofredwang committed Sep 6, 2024
2 parents 7602cf7 + 7382a87 commit 3357a56
Showing 2 changed files with 31 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/layer_norm.py
@@ -88,7 +88,7 @@ def _layer_norm_backward_kernel(
stride_dy, # stride of each row in output grad
n_rows,
n_cols,
- rows_per_program,
+ rows_per_program: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
num_warps: tl.constexpr,
dtype: tl.constexpr,
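The only change in layer_norm.py is annotating rows_per_program as tl.constexpr, so Triton treats it as a compile-time constant and specializes the kernel per value, which also lets it unroll Python-level loops over that bound. Below is a minimal, hypothetical sketch of the pattern (my own toy kernel, not code from this repo; it assumes a CUDA device with Triton installed):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy_rows_kernel(x_ptr, y_ptr, row_stride, n_cols,
                          rows_per_program: tl.constexpr,  # baked in at compile time
                          BLOCK_SIZE: tl.constexpr):
        # Each program handles `rows_per_program` consecutive rows; because the
        # bound is constexpr, this loop is unrolled when the kernel is compiled.
        row_start = tl.program_id(0) * rows_per_program
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        for i in range(rows_per_program):
            offs = (row_start + i) * row_stride + cols
            tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask, other=0), mask=mask)

    x = torch.randn(8, 128, device="cuda")
    y = torch.empty_like(x)
    _copy_rows_kernel[(2,)](x, y, x.stride(0), x.shape[1], rows_per_program=4, BLOCK_SIZE=128)
    assert torch.equal(x, y)

Without the annotation, rows_per_program would arrive as an ordinary runtime scalar and the loop bound would only be known at launch time.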
60 changes: 30 additions & 30 deletions src/liger_kernel/ops/rms_norm.py
@@ -34,8 +34,8 @@ def _rms_norm_forward_kernel(
X_row_stride,
W_ptr,
W_row_stride,
- r_ptr,
- r_row_stride,
+ RSTD_ptr,
+ RSTD_row_stride,
n_cols,
eps,
offset,

Y_ptr += row_idx * Y_row_stride
X_ptr += row_idx * X_row_stride
- r_ptr += row_idx * r_row_stride
+ RSTD_ptr += row_idx * RSTD_row_stride

X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

- # On Llama, only inv_rms is computed on fp32
+ # On Llama, only rstd is computed on fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)

X_row = X_row.to(tl.float32)

mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
- inv_rms = rsqrt(mean_square + eps)
+ rstd = rsqrt(mean_square + eps)

# Caching rms costs little memory, since rms is a single value per row and thus
# much smaller than X_row, and on the computation side it saves 4 operations
# (*, sum, /, sqrt) that would otherwise be needed to recompute it in the backward pass.
- tl.store(r_ptr, inv_rms)
+ tl.store(RSTD_ptr, rstd)

- X_row = X_row * inv_rms
+ X_row = X_row * rstd

# On Llama, the multiplication with the weight is done on the original dtype
if casting_mode == _CASTING_MODE_LLAMA:
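For reference, here is my own plain-PyTorch sketch (not code from this commit) of the row-wise math the forward kernel above computes: rstd = rsqrt(mean(x^2) + eps) and y = (x * rstd) * (offset + w), where the (offset + w) form is inferred from W_row = W_row + offset in the backward kernel below. The per-casting-mode dtype handling is collapsed to plain fp32 here:

    import torch

    def rms_norm_ref(X, W, eps, offset=0.0):
        # Row-wise RMS norm in fp32, mirroring _rms_norm_forward_kernel:
        #   rstd = 1 / sqrt(mean(x^2) + eps);  y = (x * rstd) * (offset + w)
        X32 = X.float()
        rstd = torch.rsqrt(X32.pow(2).mean(dim=-1, keepdim=True) + eps)
        Y = (X32 * rstd) * (offset + W.float())
        return Y.to(X.dtype), rstd.squeeze(-1)  # rstd is the per-row value cached in RSTD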
@@ -99,8 +99,8 @@ def _rms_norm_backward_kernel(
X_row_stride,
W_ptr,
W_row_stride,
- r_ptr,
- r_row_stride,
+ RSTD_ptr,
+ RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_cols,

dY_ptr += row_idx * dY_row_stride
X_ptr += row_idx * X_row_stride
- r_ptr += row_idx * r_row_stride
+ RSTD_ptr += row_idx * RSTD_row_stride
dW_ptr += row_idx * dW_row_stride

dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
original_x_dtype = X_row.dtype

# Get cached rms
- inv_rms_row = tl.load(r_ptr)
+ rstd_row = tl.load(RSTD_ptr)

W_row = W_row + offset

Expand All @@ -146,18 +146,18 @@ def _rms_norm_backward_kernel(

m = dY_row * W_row

- dX_row = inv_rms_row * m
+ dX_row = rstd_row * m

- dX_row += (inv_rms_row) * (
-     -(1 / n_cols) * inv_rms_row * inv_rms_row * tl.sum(m * X_row, axis=0) * X_row
+ dX_row += (rstd_row) * (
+     -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
)

# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
- dW_row = dY_row * (X_row * inv_rms_row).to(original_x_dtype)
+ dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
else:
# here X_row is already in fp32 (see previous if block)
- dW_row = dY_row * (X_row * inv_rms_row)
+ dW_row = dY_row * (X_row * rstd_row)

tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
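Likewise, a plain-PyTorch sketch (mine, fp32 only) of the gradient math in the backward kernel above: with m = dY * (W + offset), each row gets dX = rstd * m - (1/n_cols) * rstd^3 * sum(m * x) * x, and each row also contributes dY * (x * rstd) to the weight gradient, which still has to be summed over rows to obtain the full dW:

    import torch

    def rms_norm_backward_ref(dY, X, W, rstd, offset=0.0):
        # Gradients matching the per-row math in _rms_norm_backward_kernel (fp32 only).
        X32, dY32 = X.float(), dY.float()
        r = rstd.float().unsqueeze(-1)           # cached rstd = 1/rms per row
        n_cols = X.shape[-1]
        m = dY32 * (W.float() + offset)          # m = dY * (W + offset)
        dX = r * m - (1.0 / n_cols) * r**3 * (m * X32).sum(dim=-1, keepdim=True) * X32
        dW_rows = dY32 * (X32 * r)               # one dW contribution per row
        return dX.to(X.dtype), dW_rows           # full dW would be dW_rows.sum(dim=0)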
@@ -188,14 +188,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
- # r is to cache (1/rms) for each row
- # r is always computed/stored in fp32 if we are using Llama or Gemma casting mode
- r_dtype = (
+ # RSTD is to cache rstd for each row
+ # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+ rstd_dtype = (
torch.float32
if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
else X.dtype
)
- r = torch.empty(n_rows, dtype=r_dtype, device=X.device)
+ RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

# Check constraints.
assert (
X.stride(0),
W,
W.stride(0),
- r,
- r.stride(0),
+ RSTD,
+ RSTD.stride(0),
n_cols,
eps,
offset,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
- return Y.view(*shape), X, r, BLOCK_SIZE, num_warps, casting_mode
+ return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(dY, X, W, r, offset, casting_mode, BLOCK_SIZE, num_warps):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
Expand All @@ -239,8 +239,8 @@ def rms_norm_backward(dY, X, W, r, offset, casting_mode, BLOCK_SIZE, num_warps):
X.stride(0),
W,
W.stride(0),
- r,
- r.stride(0),
+ RSTD,
+ RSTD.stride(0),
dW,
dW.stride(0),
n_cols,
@@ -279,14 +279,14 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
X: (B, T, H) or (BxT, H)
W: (H,)
"""
- Y, X, r, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
X, W, eps, offset, casting_mode
)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
- ctx.save_for_backward(X, W, r)
+ ctx.save_for_backward(X, W, RSTD)
return Y

@staticmethod
@@ -295,12 +295,12 @@ def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""
- X, W, r = ctx.saved_tensors
+ X, W, RSTD = ctx.saved_tensors
dX, dW = rms_norm_backward(
dY,
X,
W,
- r,
+ RSTD,
ctx.offset,
ctx.casting_mode,
ctx.BLOCK_SIZE,
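Putting the two reference sketches above together gives a quick CPU-only sanity check that the cached-rstd backward formula matches autograd on the forward. The helper names are the hypothetical ones from my sketches, and the tolerances are loose because the reference math runs in fp32:

    import torch

    torch.manual_seed(0)
    X = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    W = torch.randn(8, dtype=torch.float64, requires_grad=True)
    eps, offset = 1e-6, 0.0

    Y, rstd = rms_norm_ref(X, W, eps, offset)   # forward sketch from above
    dY = torch.randn_like(Y)
    Y.backward(dY)

    dX, dW_rows = rms_norm_backward_ref(dY, X.detach(), W.detach(), rstd.detach(), offset)
    print(torch.allclose(X.grad, dX.double(), rtol=1e-4, atol=1e-6))
    print(torch.allclose(W.grad, dW_rows.sum(dim=0).double(), rtol=1e-4, atol=1e-6))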

0 comments on commit 3357a56
