Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.lax: improve docs for pow & related functions #26528

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 102 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def exp(x: ArrayLike) -> Array:
"""
return exp_p.bind(x)

@export
def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`.
Expand Down Expand Up @@ -589,6 +590,7 @@ def log1p(x: ArrayLike) -> Array:
"""
return log1p_p.bind(x)

@export
def tanh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.
Expand Down Expand Up @@ -801,24 +803,114 @@ def abs(x: ArrayLike) -> Array:
"""
return abs_p.bind(x)

@export
def pow(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise power: :math:`x^y`."""
r"""Elementwise power: :math:`x^y`.
This function lowers directly to the `stablehlo.pow`_ operation, along with
a `stablehlo.convert`_ when the argument dtypes do not match.
Args:
x: Input array giving the base value. Must have floating or complex type.
y: Input array giving the exponent value. Must have integer, floating, or
complex type. Its dtype will be cast to that of ``x.dtype`` if necessary.
If neither ``x`` nor ``y`` is a scalar, then ``x`` and ``y`` must have
the same number of dimensions and be broadcast-compatible.
Returns:
An array of the same dtype as ``x`` containing the elementwise power.
See also:
:func:`jax.lax.integer_pow`: Elementwise power where ``y`` is a static integer.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
"""
return pow_p.bind(x, y)

@export
def integer_pow(x: ArrayLike, y: int) -> Array:
r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer."""
r"""Elementwise power: :math:`x^y`, where :math:`y` is a static integer.
This will lower to a sequence of :math:`O[\log_2(y)]` repetitions of
`stablehlo.multiply`_.
Args:
x: Input array giving the base value. Must have numerical dtype.
y: Static scalar integer giving the exponent.
Returns:
An array of the same shape and dtype as ``x`` containing the elementwise power.
See also:
:func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array.
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return integer_pow_p.bind(x, y=y)

@export
def sqrt(x: ArrayLike) -> Array:
r"""Elementwise square root: :math:`\sqrt{x}`."""
r"""Elementwise square root: :math:`\sqrt{x}`.
This function lowers directly to the `stablehlo.sqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the square root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.cbrt`: Elementwise cube root.
:func:`jax.lax.rsqrt`: Elementwise reciporical square root.
.. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
"""
return sqrt_p.bind(x)

@export
def rsqrt(x: ArrayLike) -> Array:
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.
This function lowers directly to the `stablehlo.rsqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the
reciporical square root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.sqrt`: Elementwise square root.
:func:`jax.lax.cbrt`: Elementwise cube root.
.. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
"""
return rsqrt_p.bind(x)

@export
def cbrt(x: ArrayLike) -> Array:
r"""Elementwise cube root: :math:`\sqrt[3]{x}`."""
r"""Elementwise cube root: :math:`\sqrt[3]{x}`.
This function lowers directly to the `stablehlo.cbrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the cube root.
See also:
:func:`jax.lax.pow`: Elementwise power.
:func:`jax.lax.sqrt`: Elementwise square root.
:func:`jax.lax.rsqrt`: Elementwise reciporical square root.
.. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
"""
return cbrt_p.bind(x)

def bitwise_not(x: ArrayLike) -> Array:
Expand Down Expand Up @@ -2880,6 +2972,7 @@ def atan(x: ArrayLike) -> Array:
"""
return atan_p.bind(x)

@export
def sinh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.
Expand All @@ -2899,6 +2992,7 @@ def sinh(x: ArrayLike) -> Array:
"""
return sinh_p.bind(x)

@export
def cosh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.
Expand All @@ -2918,6 +3012,7 @@ def cosh(x: ArrayLike) -> Array:
"""
return cosh_p.bind(x)

@export
def asinh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`.
Expand All @@ -2937,6 +3032,7 @@ def asinh(x: ArrayLike) -> Array:
"""
return asinh_p.bind(x)

@export
def acosh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`.
Expand All @@ -2956,6 +3052,7 @@ def acosh(x: ArrayLike) -> Array:
"""
return acosh_p.bind(x)

@export
def atanh(x: ArrayLike) -> Array:
r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.
Expand Down
Loading