Move info!=0 logic into lax.linalg.tridiagonal lowering rule. #26521

Merged 1 commit on Feb 13, 2025
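
Summary (paraphrasing the diff below): the info != 0 NaN-masking is removed from the Python tridiagonal wrapper in jax/_src/lax/linalg.py and reimplemented inside a shared CPU/GPU lowering rule, so the tridiagonal_p primitive now has exactly four outputs and no longer exposes info. The public lax.linalg.tridiagonal signature and return values are unchanged. A minimal usage sketch (the 4x4 symmetric matrix below is made up for illustration):

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# Illustrative symmetric input; a Hermitian (complex) matrix works the same way.
x = jnp.array([[1., 2., 3., 4.],
               [2., 5., 6., 7.],
               [3., 6., 8., 9.],
               [4., 7., 9., 10.]])

# Four outputs, before and after this change: the packed reflectors plus the
# tridiagonal matrix in `arr`, the diagonal `d`, the off-diagonal `e`, and the
# Householder scalar factors `taus`.
arr, d, e, taus = lax_linalg.tridiagonal(x, lower=True)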
59 changes: 31 additions & 28 deletions jax/_src/lax/linalg.py
@@ -619,17 +619,7 @@ def tridiagonal(
    superdiagonal. ``taus`` contains the scalar factors of the elementary
    Householder reflectors.
  """
-  arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
-  def nans_like(arr):
-    if dtypes.issubdtype(arr.dtype, np.complexfloating):
-      return lax.full_like(arr, np.nan + 1j * np.nan)
-    return lax.full_like(arr, np.nan)
-  mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
-  arr = lax.select(mask(arr), arr, nans_like(arr))
-  d = lax.select(mask(d), d, nans_like(d))
-  e = lax.select(mask(e), e, nans_like(e))
-  taus = lax.select(mask(taus), taus, nans_like(taus))
-  return arr, d, e, taus
+  return tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)


def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
@@ -2949,7 +2939,6 @@ def _tridiagonal_abstract_eval(a, *, lower):
      ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype),
      ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), real_dtype),
      ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
-      ShapedArray(a.shape[:-2], np.int32)
  ]

tridiagonal_p = Primitive("tridiagonal")
@@ -2961,34 +2950,48 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
-  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0, 0)
+  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0)

batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

-def _tridiagonal_cpu_hlo(ctx, a, *, lower):
+def _tridiagonal_cpu_gpu_lowering(ctx, a, *, lower, target_name_prefix):
  a_aval, = ctx.avals_in
-  real = a_aval.dtype == np.float32 or a_aval.dtype == np.float64
-  prefix = "sy" if real else "he"
-  target_name = lapack.prepare_lapack_call(f"{prefix}trd_ffi", a_aval.dtype)
-  rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0})
-  return rule(ctx, a, uplo=_matrix_uplo_attr(lower))
-
-def _tridiagonal_gpu_hlo(ctx, a, *, lower, target_name_prefix):
-  rule = _linalg_ffi_lowering(f"{target_name_prefix}solver_sytrd_ffi",
-                              operand_output_aliases={0: 0})
-  return rule(ctx, a, lower=lower)
+  arr_aval, d_aval, e_aval, taus_aval = ctx.avals_out
+  batch_dims = a_aval.shape[:-2]
+  if target_name_prefix == "cpu":
+    real = a_aval.dtype == np.float32 or a_aval.dtype == np.float64
+    prefix = "sy" if real else "he"
+    target_name = lapack.prepare_lapack_call(f"{prefix}trd_ffi", a_aval.dtype)
+    params = {"uplo": _matrix_uplo_attr(lower)}
+  else:
+    target_name = f"{target_name_prefix}solver_sytrd_ffi"
+    params = {"lower": lower}
+  info_aval = ShapedArray(batch_dims, np.int32)
+  rule = _linalg_ffi_lowering(
+      target_name, avals_out=(*ctx.avals_out, info_aval),
+      operand_output_aliases={0: 0})
+  arr, d, e, taus, info = rule(ctx, a, **params)
+  zeros = mlir.full_like_aval(ctx, 0, info_aval)
+  ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
+  arr = _replace_not_ok_with_nan(ctx, batch_dims, ok, arr, arr_aval)
+  d = _replace_not_ok_with_nan(ctx, batch_dims, ok, d, d_aval)
+  e = _replace_not_ok_with_nan(ctx, batch_dims, ok, e, e_aval)
+  taus = _replace_not_ok_with_nan(ctx, batch_dims, ok, taus, taus_aval)
+  return arr, d, e, taus

mlir.register_lowering(
-    tridiagonal_p, _tridiagonal_cpu_hlo, platform="cpu")
+    tridiagonal_p,
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="cpu"),
+    platform="cpu",
+)
mlir.register_lowering(
    tridiagonal_p,
-    partial(_tridiagonal_gpu_hlo, target_name_prefix="cu"),
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="cu"),
    platform="cuda",
)
mlir.register_lowering(
    tridiagonal_p,
-    partial(_tridiagonal_gpu_hlo, target_name_prefix="hip"),
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="hip"),
    platform="rocm",
)
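
For reference, a rough user-level sketch of the masking that now lives in the lowering (an illustrative jax.numpy approximation with a hypothetical mask_failed helper, not the MLIR actually emitted through mlir.compare_hlo and _replace_not_ok_with_nan): wherever a batch element's LAPACK/solver info status is nonzero, every output for that element is replaced with NaNs.

import jax.numpy as jnp

def mask_failed(x, info):
  # `info` holds one status per batch element; broadcast the info == 0 mask
  # over the trailing dimensions of `x` so a failed element masks its whole
  # output (sketch only, not the actual lowering code).
  ok = jnp.expand_dims(info == 0, tuple(range(info.ndim, x.ndim)))
  nan = jnp.nan + 1j * jnp.nan if jnp.iscomplexobj(x) else jnp.nan
  return jnp.where(ok, x, jnp.full_like(x, nan))

Because the abstract eval no longer lists an info output, this masking happens entirely in the backend-specific lowering and info never appears in user-level jaxprs.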
