diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index be872dead1a7..f766cec24465 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -501,6 +501,7 @@ def exp(x: ArrayLike) -> Array: """ return exp_p.bind(x) +@export def exp2(x: ArrayLike) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. @@ -589,6 +590,7 @@ def log1p(x: ArrayLike) -> Array: """ return log1p_p.bind(x) +@export def tanh(x: ArrayLike) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. @@ -801,24 +803,114 @@ def abs(x: ArrayLike) -> Array: """ return abs_p.bind(x) +@export def pow(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise power: :math:`x^y`.""" + r"""Elementwise power: :math:`x^y`. + + This function lowers directly to the `stablehlo.pow`_ operation, along with + a `stablehlo.convert`_ when the argument dtypes do not match. + + Args: + x: Input array giving the base value. Must have floating or complex type. + y: Input array giving the exponent value. Must have integer, floating, or + complex type. Its dtype will be cast to that of ``x.dtype`` if necessary. + If neither ``x`` nor ``y`` is a scalar, then ``x`` and ``y`` must have + the same number of dimensions and be broadcast-compatible. + + Returns: + An array of the same dtype as ``x`` containing the elementwise power. + + See also: + :func:`jax.lax.integer_pow`: Elementwise power where ``y`` is a static integer. + + .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert + .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow + """ return pow_p.bind(x, y) +@export def integer_pow(x: ArrayLike, y: int) -> Array: - r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer.""" + r"""Elementwise power: :math:`x^y`, where :math:`y` is a static integer. + + This will lower to a sequence of :math:`O[\log_2(y)]` repetitions of + `stablehlo.multiply`_. + + Args: + x: Input array giving the base value. Must have numerical dtype. + y: Static scalar integer giving the exponent. + + Returns: + An array of the same shape and dtype as ``x`` containing the elementwise power. + + See also: + :func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array. + + .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply + """ return integer_pow_p.bind(x, y=y) +@export def sqrt(x: ArrayLike) -> Array: - r"""Elementwise square root: :math:`\sqrt{x}`.""" + r"""Elementwise square root: :math:`\sqrt{x}`. + + This function lowers directly to the `stablehlo.sqrt`_ operation. + + Args: + x: Input array. Must have floating or complex dtype. + + Returns: + An array of the same shape and dtype as ``x`` containing the square root. + + See also: + :func:`jax.lax.pow`: Elementwise power. + :func:`jax.lax.cbrt`: Elementwise cube root. + :func:`jax.lax.rsqrt`: Elementwise reciporical square root. + + .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt + """ return sqrt_p.bind(x) +@export def rsqrt(x: ArrayLike) -> Array: - r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.""" + r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. + + This function lowers directly to the `stablehlo.rsqrt`_ operation. + + Args: + x: Input array. Must have floating or complex dtype. + + Returns: + An array of the same shape and dtype as ``x`` containing the + reciporical square root. + + See also: + :func:`jax.lax.pow`: Elementwise power. + :func:`jax.lax.sqrt`: Elementwise square root. + :func:`jax.lax.cbrt`: Elementwise cube root. + + .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt + """ return rsqrt_p.bind(x) +@export def cbrt(x: ArrayLike) -> Array: - r"""Elementwise cube root: :math:`\sqrt[3]{x}`.""" + r"""Elementwise cube root: :math:`\sqrt[3]{x}`. + + This function lowers directly to the `stablehlo.cbrt`_ operation. + + Args: + x: Input array. Must have floating or complex dtype. + + Returns: + An array of the same shape and dtype as ``x`` containing the cube root. + + See also: + :func:`jax.lax.pow`: Elementwise power. + :func:`jax.lax.sqrt`: Elementwise square root. + :func:`jax.lax.rsqrt`: Elementwise reciporical square root. + + .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt + """ return cbrt_p.bind(x) def bitwise_not(x: ArrayLike) -> Array: @@ -2880,6 +2972,7 @@ def atan(x: ArrayLike) -> Array: """ return atan_p.bind(x) +@export def sinh(x: ArrayLike) -> Array: r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`. @@ -2899,6 +2992,7 @@ def sinh(x: ArrayLike) -> Array: """ return sinh_p.bind(x) +@export def cosh(x: ArrayLike) -> Array: r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`. @@ -2918,6 +3012,7 @@ def cosh(x: ArrayLike) -> Array: """ return cosh_p.bind(x) +@export def asinh(x: ArrayLike) -> Array: r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`. @@ -2937,6 +3032,7 @@ def asinh(x: ArrayLike) -> Array: """ return asinh_p.bind(x) +@export def acosh(x: ArrayLike) -> Array: r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`. @@ -2956,6 +3052,7 @@ def acosh(x: ArrayLike) -> Array: """ return acosh_p.bind(x) +@export def atanh(x: ArrayLike) -> Array: r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.