Skip to content

Commit

Permalink
minor changes based on uxlfoundation#2206, suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Dec 4, 2024
1 parent 8fca003 commit 164435d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sklearnex/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
# Copyright 2024 UXL Foundation Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
35 changes: 32 additions & 3 deletions sklearnex/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def _onedal_supported_format(X, xp=None):
# _onedal_supported_format is therefore conservative in verifying attributes and
# does not support array_api. This will block onedal_assert_all_finite from being
# used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
return X.dtype in [xp.float32, xp.float64] and hasattr(X, "flags")
# only check contiguous arrays to prevent unnecessary copying of data, even if
# non-contiguous arrays can now be converted to oneDAL tables.
return (
X.dtype in [xp.float32, xp.float64]
and hasattr(X, "flags")
and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])
)

else:
from daal4py.utils.validation import _assert_all_finite as _onedal_assert_all_finite
Expand Down Expand Up @@ -108,14 +114,37 @@ def validate_data(
y=y,
**kwargs,
)

check_x = not isinstance(X, str) or X != "no_validation"
check_y = not (y is None or isinstance(y, str) and y == "no_validation")

if ensure_all_finite:
# run local finite check
allow_nan = ensure_all_finite == "allow-nan"
arg = iter(out if isinstance(out, tuple) else (out,))
if not isinstance(X, str) or X != "no_validation":
if check_x:
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
if not (y is None or isinstance(y, str) and y == "no_validation"):
if check_y:
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")

if check_y and "dtype" in kwargs:
# validate_data does not do full dtype conversions, as it uses check_X_y.
# oneDAL can make tables from [int32, float32, float64], requiring
# a dtype check and conversion. This will query the array_namespace and
# convert y as necessary. This is done after assert_all_finite because
# integer y arrays do not need a finite check, which leads to a speedup
# in comparison to sklearn.
dtype = kwargs["dtype"]
if not isinstance(dtype, (tuple, list)):
dtype = tuple(dtype)

outx, outy = out if check_x else (None, out)
if outy.dtype not in dtype:
yp, _ = get_namespace(outy)
# use asarray rather than astype because of numpy support
outy = yp.asarray(outy, dtype=dtype[0])
out = (outx, outy) if check_x else outy

return out


Expand Down

0 comments on commit 164435d

Please sign in to comment.