Move info!=0 logic into lax.linalg.tridiagonal lowering rule.
PiperOrigin-RevId: 726617102
dfm authored and Google-ML-Automation committed Feb 13, 2025
1 parent 91c6e44 commit 14afb73
Showing 1 changed file with 31 additions and 28 deletions.
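The public contract of lax.linalg.tridiagonal is unchanged by this commit: callers still get four arrays back, and batch elements where the LAPACK/cuSOLVER call reports a nonzero `info` are still NaN-filled; only the place where that masking happens moves, from the traced Python wrapper into the lowering rule. A minimal sketch of the call surface (the matrix values are an illustrative assumption, not taken from the commit):

    import jax.numpy as jnp
    from jax.lax.linalg import tridiagonal

    # Reduce a symmetric matrix to tridiagonal form; `info` is consumed
    # internally, so the function still returns exactly four arrays.
    a = jnp.array([[4.0, 1.0, 2.0],
                   [1.0, 3.0, 0.5],
                   [2.0, 0.5, 5.0]])
    arr, d, e, taus = tridiagonal(a, lower=True)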
jax/_src/lax/linalg.py (31 additions, 28 deletions)
@@ -619,17 +619,7 @@ def tridiagonal(
     superdiagonal. ``taus`` contains the scalar factors of the elementary
     Householder reflectors.
   """
-  arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
-  def nans_like(arr):
-    if dtypes.issubdtype(arr.dtype, np.complexfloating):
-      return lax.full_like(arr, np.nan + 1j * np.nan)
-    return lax.full_like(arr, np.nan)
-  mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
-  arr = lax.select(mask(arr), arr, nans_like(arr))
-  d = lax.select(mask(d), d, nans_like(d))
-  e = lax.select(mask(e), e, nans_like(e))
-  taus = lax.select(mask(taus), taus, nans_like(taus))
-  return arr, d, e, taus
+  return tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)


 def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
@@ -2949,7 +2939,6 @@ def _tridiagonal_abstract_eval(a, *, lower):
       ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype),
       ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), real_dtype),
       ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
-      ShapedArray(a.shape[:-2], np.int32)
   ]

 tridiagonal_p = Primitive("tridiagonal")
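With the int32 `info` aval removed from `_tridiagonal_abstract_eval`, the primitive now traces with four outputs instead of five. A quick way to observe this, as a sketch assuming a JAX build that includes this change:

    import jax
    import jax.numpy as jnp
    from jax.lax.linalg import tridiagonal

    # The tridiagonal equation in the printed jaxpr lists exactly four
    # outputs; the int32 `info` array never appears at the trace level.
    print(jax.make_jaxpr(lambda x: tridiagonal(x, lower=True))(jnp.eye(3)))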
@@ -2961,34 +2950,48 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0, 0)
+  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0)

 batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

-def _tridiagonal_cpu_hlo(ctx, a, *, lower):
+def _tridiagonal_cpu_gpu_lowering(ctx, a, *, lower, target_name_prefix):
   a_aval, = ctx.avals_in
-  real = a_aval.dtype == np.float32 or a_aval.dtype == np.float64
-  prefix = "sy" if real else "he"
-  target_name = lapack.prepare_lapack_call(f"{prefix}trd_ffi", a_aval.dtype)
-  rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0})
-  return rule(ctx, a, uplo=_matrix_uplo_attr(lower))
-
-def _tridiagonal_gpu_hlo(ctx, a, *, lower, target_name_prefix):
-  rule = _linalg_ffi_lowering(f"{target_name_prefix}solver_sytrd_ffi",
-                              operand_output_aliases={0: 0})
-  return rule(ctx, a, lower=lower)
+  arr_aval, d_aval, e_aval, taus_aval = ctx.avals_out
+  batch_dims = a_aval.shape[:-2]
+  if target_name_prefix == "cpu":
+    real = a_aval.dtype == np.float32 or a_aval.dtype == np.float64
+    prefix = "sy" if real else "he"
+    target_name = lapack.prepare_lapack_call(f"{prefix}trd_ffi", a_aval.dtype)
+    params = {"uplo": _matrix_uplo_attr(lower)}
+  else:
+    target_name = f"{target_name_prefix}solver_sytrd_ffi"
+    params = {"lower": lower}
+  info_aval = ShapedArray(batch_dims, np.int32)
+  rule = _linalg_ffi_lowering(
+      target_name, avals_out=(*ctx.avals_out, info_aval),
+      operand_output_aliases={0: 0})
+  arr, d, e, taus, info = rule(ctx, a, **params)
+  zeros = mlir.full_like_aval(ctx, 0, info_aval)
+  ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
+  arr = _replace_not_ok_with_nan(ctx, batch_dims, ok, arr, arr_aval)
+  d = _replace_not_ok_with_nan(ctx, batch_dims, ok, d, d_aval)
+  e = _replace_not_ok_with_nan(ctx, batch_dims, ok, e, e_aval)
+  taus = _replace_not_ok_with_nan(ctx, batch_dims, ok, taus, taus_aval)
+  return arr, d, e, taus

 mlir.register_lowering(
-    tridiagonal_p, _tridiagonal_cpu_hlo, platform="cpu")
+    tridiagonal_p,
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="cpu"),
+    platform="cpu",
+)
 mlir.register_lowering(
     tridiagonal_p,
-    partial(_tridiagonal_gpu_hlo, target_name_prefix="cu"),
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="cu"),
     platform="cuda",
 )
 mlir.register_lowering(
     tridiagonal_p,
-    partial(_tridiagonal_gpu_hlo, target_name_prefix="hip"),
+    partial(_tridiagonal_cpu_gpu_lowering, target_name_prefix="hip"),
     platform="rocm",
 )
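Because the `info == 0` masking is now emitted inside the lowering rule, transformations like vmap only ever see the four real outputs, which is what the `(0, 0, 0, 0)` batch dims in the updated batching rule reflect. A small usage sketch (batch and matrix sizes are illustrative assumptions, not from the commit):

    import jax
    import jax.numpy as jnp
    from jax.lax.linalg import tridiagonal

    # A batch of four 5x5 symmetric matrices; tridiagonal maps over the
    # leading axis under vmap with no special handling for `info`.
    mats = jax.random.normal(jax.random.key(0), (4, 5, 5))
    sym = (mats + jnp.swapaxes(mats, -1, -2)) / 2  # symmetrize each matrix
    arr, d, e, taus = jax.vmap(tridiagonal)(sym)
    print(d.shape, e.shape, taus.shape)  # (4, 5) (4, 4) (4, 4)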

