Skip to content

Commit

Permalink
Complete dpnp.cov implementation (IntelPython#2303)
Browse files Browse the repository at this point in the history
This PR reworks the implementation of `dpnp.cov`: it adds support for all
previously missing keywords, replaces TODO notes with already supported
functionality, and removes any fallback to a NumPy call.
The PR also updates the documentation and adds extra tests to improve
coverage.
  • Loading branch information
antonwolfy authored Feb 13, 2025
1 parent 0bda96a commit 0403a7c
Show file tree
Hide file tree
Showing 6 changed files with 477 additions and 102 deletions.
192 changes: 153 additions & 39 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
to_supported_dtypes,
)

from .dpnp_utils import call_origin, get_usm_allocations
from .dpnp_utils import get_usm_allocations
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov, dpnp_median

Expand Down Expand Up @@ -748,61 +748,175 @@ def cov(
For full documentation refer to :obj:`numpy.cov`.
Parameters
----------
m : {dpnp.ndarray, usm_ndarray}
A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
observation of all those variables. Also see `rowvar` below.
y : {None, dpnp.ndarray, usm_ndarray}, optional
An additional set of variables and observations. `y` has the same form
as that of `m`.
Default: ``None``.
rowvar : bool, optional
If `rowvar` is ``True``, then each row represents a variable, with
observations in the columns. Otherwise, the relationship is transposed:
each column represents a variable, while the rows contain observations.
Default: ``True``.
bias : bool, optional
Default normalization is by ``(N - 1)``, where ``N`` is the number of
observations given (unbiased estimate). If `bias` is ``True``, then
normalization is by ``N``. These values can be overridden by using the
keyword `ddof`.
Default: ``False``.
ddof : {None, int}, optional
If not ``None`` the default value implied by `bias` is overridden. Note
that ``ddof=1`` will return the unbiased estimate, even if both
`fweights` and `aweights` are specified, and ``ddof=0`` will return the
simple average. See the notes for the details.
Default: ``None``.
fweights : {None, dpnp.ndarray, usm_ndarray}, optional
1-D array of integer frequency weights; the number of times each
observation vector should be repeated.
It is required that ``fweights >= 0``. However, the function will not
raise an error when ``fweights < 0`` for performance reasons.
Default: ``None``.
aweights : {None, dpnp.ndarray, usm_ndarray}, optional
1-D array of observation vector weights. These relative weights are
typically large for observations considered "important" and smaller for
observations considered less "important". If ``ddof=0`` the array of
weights can be used to assign probabilities to observation vectors.
It is required that ``aweights >= 0``. However, the function will not
raise an error when ``aweights < 0`` for performance reasons.
Default: ``None``.
dtype : {None, str, dtype object}, optional
Data-type of the result. By default, the returned data-type is at least
a floating-point type, chosen based on the capabilities of the device on
which the input arrays reside.
Default: ``None``.
Returns
-------
out : dpnp.ndarray
The covariance matrix of the variables.
Limitations
-----------
Input array ``m`` is supported as :obj:`dpnp.ndarray`.
Dimension of input array ``m`` is limited by ``m.ndim <= 2``.
Size and shape of input arrays are supported to be equal.
Parameter `y` is supported only with default value ``None``.
Parameter `bias` is supported only with default value ``False``.
Parameter `ddof` is supported only with default value ``None``.
Parameter `fweights` is supported only with default value ``None``.
Parameter `aweights` is supported only with default value ``None``.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`dpnp.corrcoef` : Normalized covariance matrix
:obj:`dpnp.corrcoef` : Normalized covariance matrix.
Notes
-----
Assume that the observations are in the columns of the observation array `m`
and let ``f = fweights`` and ``a = aweights`` for brevity. The steps to
compute the weighted covariance are as follows::
>>> import dpnp as np
>>> m = np.arange(10, dtype=np.float32)
>>> f = np.arange(10) * 2
>>> a = np.arange(10) ** 2.0
>>> ddof = 1
>>> w = f * a
>>> v1 = np.sum(w)
>>> v2 = np.sum(w * a)
>>> m -= np.sum(m * w, axis=None, keepdims=True) / v1
>>> cov = np.dot(m * w, m.T) * v1 / (v1**2 - ddof * v2)
Note that when ``a == 1``, the normalization factor
``v1 / (v1**2 - ddof * v2)`` goes over to ``1 / (np.sum(f) - ddof)``
as it should.
Examples
--------
>>> import dpnp as np
>>> x = np.array([[0, 2], [1, 1], [2, 0]]).T
>>> x.shape
(2, 3)
>>> [i for i in x]
[0, 1, 2, 2, 1, 0]
>>> out = np.cov(x)
>>> out.shape
(2, 2)
>>> [i for i in out]
[1.0, -1.0, -1.0, 1.0]
Consider two variables, :math:`x_0` and :math:`x_1`, which correlate
perfectly, but in opposite directions:
>>> x
array([[0, 1, 2],
[2, 1, 0]])
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
matrix shows this clearly:
>>> np.cov(x)
array([[ 1., -1.],
[-1., 1.]])
Note that element :math:`C_{0,1}`, which shows the correlation between
:math:`x_0` and :math:`x_1`, is negative.
Further, note how `x` and `y` are combined:
>>> x = np.array([-2.1, -1, 4.3])
>>> y = np.array([3, 1.1, 0.12])
>>> X = np.stack((x, y), axis=0)
>>> np.cov(X)
array([[11.71 , -4.286 ], # may vary
[-4.286 , 2.14413333]])
>>> np.cov(x, y)
array([[11.71 , -4.286 ], # may vary
[-4.286 , 2.14413333]])
>>> np.cov(x)
array(11.71)
"""

if not dpnp.is_supported_array_type(m):
pass
elif m.ndim > 2:
pass
elif bias:
pass
elif ddof is not None:
pass
elif fweights is not None:
pass
elif aweights is not None:
pass
arrays = [m]
if y is not None:
arrays.append(y)
dpnp.check_supported_arrays_type(*arrays)

if m.ndim > 2:
raise ValueError("m has more than 2 dimensions")

if y is not None:
if y.ndim > 2:
raise ValueError("y has more than 2 dimensions")

if ddof is not None:
if not isinstance(ddof, int):
raise ValueError("ddof must be integer")
else:
return dpnp_cov(m, y=y, rowvar=rowvar, dtype=dtype)
ddof = 0 if bias else 1

def_float = dpnp.default_float_type(m.sycl_queue)
if dtype is None:
dtype = dpnp.result_type(*arrays, def_float)

if fweights is not None:
dpnp.check_supported_arrays_type(fweights)
if not dpnp.issubdtype(fweights.dtype, numpy.integer):
raise TypeError("fweights must be integer")

if fweights.ndim > 1:
raise ValueError("cannot handle multidimensional fweights")

fweights = dpnp.astype(fweights, dtype=def_float)

if aweights is not None:
dpnp.check_supported_arrays_type(aweights)
if aweights.ndim > 1:
raise ValueError("cannot handle multidimensional aweights")

aweights = dpnp.astype(aweights, dtype=def_float)

return call_origin(
numpy.cov, m, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
return dpnp_cov(
m,
y=y,
rowvar=rowvar,
ddof=ddof,
dtype=dtype,
fweights=fweights,
aweights=aweights,
)


Expand Down
108 changes: 59 additions & 49 deletions dpnp/dpnp_utils/dpnp_utils_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import dpnp
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device

__all__ = ["dpnp_cov", "dpnp_median"]

Expand Down Expand Up @@ -119,66 +118,77 @@ def _flatten_array_along_axes(a, axes_to_flatten, overwrite_input):
return a_flatten, overwrite_input


def dpnp_cov(m, y=None, rowvar=True, dtype=None):
def dpnp_cov(
m, y=None, rowvar=True, ddof=1, dtype=None, fweights=None, aweights=None
):
"""
dpnp_cov(m, y=None, rowvar=True, dtype=None)
Estimate a covariance matrix based on passed data.
No support for given weights is provided now.
The implementation is done through existing dpnp and dpctl methods
instead of separate function call of dpnp backend.
The implementation is done through existing dpnp functions.
"""

def _get_2dmin_array(x, dtype):
"""
Transform an input array to a form required for building a covariance matrix.
If applicable, it reshapes the input array to have 2 dimensions or greater.
If applicable, it transposes the input array when 'rowvar' is False.
It casts to another dtype, if the input array differs from requested one.
"""
if x.ndim == 0:
x = x.reshape((1, 1))
elif x.ndim == 1:
x = x[dpnp.newaxis, :]

if not rowvar and x.ndim != 1:
x = x.T

if x.dtype != dtype:
x = dpnp.astype(x, dtype)
return x
# need to create a copy of input, since it will be modified in-place
x = dpnp.array(m, ndmin=2, dtype=dtype)
if not rowvar and m.ndim != 1:
x = x.T

# input arrays must follow CFD paradigm
_, queue = get_usm_allocations((m,) if y is None else (m, y))
if x.shape[0] == 0:
return dpnp.empty_like(
x, shape=(0, 0), dtype=dpnp.default_float_type(m.sycl_queue)
)

# calculate a type of result array if not passed explicitly
if dtype is None:
dtypes = [m.dtype, dpnp.default_float_type(sycl_queue=queue)]
if y is not None:
dtypes.append(y.dtype)
dtype = dpnp.result_type(*dtypes)
# TODO: remove when dpctl.result_type() is returned dtype based on fp64
dtype = map_dtype_to_device(dtype, queue.sycl_device)

X = _get_2dmin_array(m, dtype)
if y is not None:
y = _get_2dmin_array(y, dtype)

X = dpnp.concatenate((X, y), axis=0)

avg = X.mean(axis=1)
y_ndim = y.ndim
y = dpnp.array(y, copy=None, ndmin=2, dtype=dtype)
if not rowvar and y_ndim != 1:
y = y.T
x = dpnp.concatenate((x, y), axis=0)

# get the product of frequencies and weights
w = None
if fweights is not None:
if fweights.shape[0] != x.shape[1]:
raise ValueError("incompatible numbers of samples and fweights")

w = fweights

if aweights is not None:
if aweights.shape[0] != x.shape[1]:
raise ValueError("incompatible numbers of samples and aweights")

if w is None:
w = aweights
else:
w *= aweights

avg, w_sum = dpnp.average(x, axis=1, weights=w, returned=True)
w_sum = w_sum[0]

# determine the normalization
if w is None:
fact = x.shape[1] - ddof
elif ddof == 0:
fact = w_sum
elif aweights is None:
fact = w_sum - ddof
else:
fact = w_sum - ddof * dpnp.sum(w * aweights) / w_sum

fact = X.shape[1] - 1
X -= avg[:, None]
if fact <= 0:
warnings.warn(
"Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2
)
fact = 0.0

c = dpnp.dot(X, X.T.conj())
c *= 1 / fact if fact != 0 else dpnp.nan
x -= avg[:, None]
if w is None:
x_t = x.T
else:
x_t = (x * w).T

return dpnp.squeeze(c)
c = dpnp.dot(x, x_t.conj()) / fact
return c.squeeze()


def dpnp_median(
Expand Down
Loading

0 comments on commit 0403a7c

Please sign in to comment.