Skip to content

Commit

Permalink
Merge pull request #26528 from jakevdp:lax-docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726971041
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
2 parents ca87f5f + 531443c commit 4b94665
Showing 1 changed file with 102 additions and 5 deletions.
107 changes: 102 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def exp(x: ArrayLike) -> Array:
"""
return exp_p.bind(x)

@export
def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`.
Expand Down Expand Up @@ -589,6 +590,7 @@ def log1p(x: ArrayLike) -> Array:
"""
return log1p_p.bind(x)

@export
def tanh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.
Expand Down Expand Up @@ -801,24 +803,114 @@ def abs(x: ArrayLike) -> Array:
"""
return abs_p.bind(x)

@export
def pow(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise power: :math:`x^y`."""
r"""Elementwise power: :math:`x^y`.
This function lowers directly to the `stablehlo.pow`_ operation, along with
a `stablehlo.convert`_ when the argument dtypes do not match.
Args:
x: Input array giving the base value. Must have floating or complex type.
y: Input array giving the exponent value. Must have integer, floating, or
complex type. Its dtype will be cast to that of ``x.dtype`` if necessary.
If neither ``x`` nor ``y`` is a scalar, then ``x`` and ``y`` must have
the same number of dimensions and be broadcast-compatible.
Returns:
An array of the same dtype as ``x`` containing the elementwise power.
See also:
:func:`jax.lax.integer_pow`: Elementwise power where ``y`` is a static integer.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
"""
return pow_p.bind(x, y)

@export
def integer_pow(x: ArrayLike, y: int) -> Array:
r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer."""
r"""Elementwise power: :math:`x^y`, where :math:`y` is a static integer.
This will lower to a sequence of :math:`O[\log_2(y)]` repetitions of
`stablehlo.multiply`_.
Args:
x: Input array giving the base value. Must have numerical dtype.
y: Static scalar integer giving the exponent.
Returns:
An array of the same shape and dtype as ``x`` containing the elementwise power.
See also:
:func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array.
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return integer_pow_p.bind(x, y=y)

@export
def sqrt(x: ArrayLike) -> Array:
r"""Elementwise square root: :math:`\sqrt{x}`."""
r"""Elementwise square root: :math:`\sqrt{x}`.
This function lowers directly to the `stablehlo.sqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the square root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.cbrt`: Elementwise cube root.
:func:`jax.lax.rsqrt`: Elementwise reciporical square root.
.. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
"""
return sqrt_p.bind(x)

@export
def rsqrt(x: ArrayLike) -> Array:
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.
This function lowers directly to the `stablehlo.rsqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the
reciporical square root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.sqrt`: Elementwise square root.
:func:`jax.lax.cbrt`: Elementwise cube root.
.. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
"""
return rsqrt_p.bind(x)

@export
def cbrt(x: ArrayLike) -> Array:
r"""Elementwise cube root: :math:`\sqrt[3]{x}`."""
r"""Elementwise cube root: :math:`\sqrt[3]{x}`.
This function lowers directly to the `stablehlo.cbrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the cube root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.sqrt`: Elementwise square root.
:func:`jax.lax.rsqrt`: Elementwise reciporical square root.
.. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
"""
return cbrt_p.bind(x)

def bitwise_not(x: ArrayLike) -> Array:
Expand Down Expand Up @@ -2880,6 +2972,7 @@ def atan(x: ArrayLike) -> Array:
"""
return atan_p.bind(x)

@export
def sinh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.
Expand All @@ -2899,6 +2992,7 @@ def sinh(x: ArrayLike) -> Array:
"""
return sinh_p.bind(x)

@export
def cosh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.
Expand All @@ -2918,6 +3012,7 @@ def cosh(x: ArrayLike) -> Array:
"""
return cosh_p.bind(x)

@export
def asinh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`.
Expand All @@ -2937,6 +3032,7 @@ def asinh(x: ArrayLike) -> Array:
"""
return asinh_p.bind(x)

@export
def acosh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`.
Expand All @@ -2956,6 +3052,7 @@ def acosh(x: ArrayLike) -> Array:
"""
return acosh_p.bind(x)

@export
def atanh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.
Expand Down

0 comments on commit 4b94665

Please sign in to comment.