Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

doc: improve docs for jax.lax trig functions #26342

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Operators
erfc
erf_inv
exp
exp2
expand_dims
expm1
fft
Expand Down
160 changes: 146 additions & 14 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def round(x: ArrayLike,
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)

@export
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.

Expand All @@ -478,6 +479,7 @@ def is_finite(x: ArrayLike) -> Array:
"""
return is_finite_p.bind(x)

@export
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`.

Expand All @@ -488,7 +490,7 @@ def exp(x: ArrayLike) -> Array:

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential.
exponential.

See also:
- :func:`jax.lax.exp2`: elementwise base-2 exponentional: :math:`2^x`.
Expand All @@ -509,7 +511,7 @@ def exp2(x: ArrayLike) -> Array:

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
base-2 exponential.
base-2 exponential.

See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -520,6 +522,7 @@ def exp2(x: ArrayLike) -> Array:
"""
return exp2_p.bind(x)

@export
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`.

Expand All @@ -532,7 +535,7 @@ def expm1(x: ArrayLike) -> Array:

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential minus 1.
exponential minus 1.

See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -542,6 +545,7 @@ def expm1(x: ArrayLike) -> Array:
"""
return expm1_p.bind(x)

@export
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.

Expand All @@ -552,7 +556,7 @@ def log(x: ArrayLike) -> Array:

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm.
natural logarithm.

See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -561,8 +565,9 @@ def log(x: ArrayLike) -> Array:
"""
return log_p.bind(x)

@export
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`..
r"""Elementwise :math:`\mathrm{log}(1 + x)`.

This function lowers directly to the `stablehlo.log_plus_one`_ operation.
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
Expand All @@ -573,7 +578,7 @@ def log1p(x: ArrayLike) -> Array:

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm of ``x + 1``.
natural logarithm of ``x + 1``.

See also:
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
Expand All @@ -591,17 +596,76 @@ def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
return logistic_p.bind(x)

@export
def sin(x: ArrayLike) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.

For floating-point inputs, this function lowers directly to the
`stablehlo.sine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex sine.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
sine.

See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.asin`: elementwise arc sine.

.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
"""
return sin_p.bind(x)

@export
def cos(x: ArrayLike) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.

For floating-point inputs, this function lowers directly to the
`stablehlo.cosine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex cosine.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
cosine.

See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arc cosine.

.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
"""
return cos_p.bind(x)

@export
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.

This function lowers directly to the `stablehlo.atan2`_ operation.

Args:
x, y: input arrays. Must have a matching floating-point or complex dtypes. If
neither is a scalar, the two arrays must have the same number of dimensions
and be broadcast-compatible.

Returns:
Array of the same shape and dtype as ``x`` and ``y`` containing the element-wise
arc tangent of :math:`x \over y`, respecting the quadrant indicated by the sign
of each input.

See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.atan`: elementwise one-term arc tangent.

.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
"""
return atan2_p.bind(x, y)

def real(x: ArrayLike) -> Array:
Expand Down Expand Up @@ -2473,20 +2537,88 @@ def reciprocal(x: ArrayLike) -> Array:
r"""Elementwise reciprocal: :math:`1 \over x`."""
return integer_pow(x, -1)

@export
def tan(x: ArrayLike) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.

This function lowers directly to the `stablehlo.tangent`_ operation.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
tangent.

See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.atan`: elementwise arc tangent.
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.

.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
"""
return tan_p.bind(x)

@export
def asin(x: ArrayLike) -> Array:
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.

This function lowers directly to the ``chlo.asin`` operation.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc sine.

See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.acos`: elementwise arc cosine.
- :func:`jax.lax.atan`: elementwise arc tangent.
"""
return asin_p.bind(x)

@export
def acos(x: ArrayLike) -> Array:
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.

This function lowers directly to the ``chlo.acos`` operation.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc cosine.

See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.asin`: elementwise arc sine.
- :func:`jax.lax.atan`: elementwise arc tangent.
"""
return acos_p.bind(x)

@export
def atan(x: ArrayLike) -> Array:
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.

This function lowers directly to the ``chlo.atan`` operation.

Args:
x: input array. Must have floating-point or complex type.

Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc tangent.

See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arc cosine.
- :func:`jax.lax.asin`: elementwise arc sine.
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
"""
return atan_p.bind(x)

def sinh(x: ArrayLike) -> Array:
Expand Down
Loading