Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.lax: improve docs for real, imag, complex, conj, and abs. #26470

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 90 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,34 +685,119 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array:
"""
return atan2_p.bind(x, y)

@export
def real(x: ArrayLike) -> Array:
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.

Returns the real part of a complex number.
This function lowers directly to the `stablehlo.real`_ operation.

Args:
x: input array. Must have complex dtype.

Returns:
Array of the same shape as ``x`` containing its real part. Will have dtype
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.

See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.conj`: elementwise complex conjugate.

.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
"""
return real_p.bind(x)

@export
def imag(x: ArrayLike) -> Array:
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.

Returns the imaginary part of a complex number.
This function lowers directly to the `stablehlo.imag`_ operation.

Args:
x: input array. Must have complex dtype.

Returns:
Array of the same shape as ``x`` containing its imaginary part. Will have dtype
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.

See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.conj`: elementwise complex conjugate.

.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
"""
return imag_p.bind(x)

@export
def complex(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise make complex number: :math:`x + jy`.

Builds a complex number from real and imaginary parts.
This function lowers directly to the `stablehlo.complex`_ operation.

Args:
x, y: input arrays. Must have matching floating-point dtypes. If
neither is a scalar, the two arrays must have the same number
of dimensions and be broadcast-compatible.

Returns:
The complex array with the real part given by ``x``, and the imaginary
part given by ``y``. For inputs of dtype float32 or float64, the result
will have dtype complex64 or complex128 respectively.

See also:
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.conj`: elementwise complex conjugate.

.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
"""
return complex_p.bind(x, y)

@export
def conj(x: ArrayLike) -> Array:
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
r"""Elementwise complex conjugate function: :math:`\overline{x}`.

This function lowers to a combination of `stablehlo.real`_, `stablehlo.imag`_,
and `stablehlo.complex`_.

Args:
x: input array. Must have complex dtype.

Returns:
Array of the same shape and dtype as ``x`` containing its complex conjugate.

See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.abs`: elementwise absolute value / complex magnitude.

.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
"""
# TODO(mattjj): remove input_dtype, not needed anymore
return conj_p.bind(x, input_dtype=_dtype(x))

@export
def abs(x: ArrayLike) -> Array:
r"""Elementwise absolute value: :math:`|x|`."""
r"""Elementwise absolute value: :math:`|x|`.

This function lowers directly to the `stablehlo.abs`_ operation.

Args:
x: Input array. Must have signed integer, floating, or complex dtype.

Returns:
An array of the same shape as ``x`` containing the elementwise absolute value.
For complex valued input, :math:`a + ib`, ``abs(x)`` returns :math:`\sqrt{a^2+b^2}`.

See also:
- :func:`jax.numpy.abs`: a more flexible NumPy-style ``abs`` implementation.

.. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs
"""
return abs_p.bind(x)

def pow(x: ArrayLike, y: ArrayLike) -> Array:
Expand Down
Loading