-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[enhancement] add oneDAL finiteness_checker implementation to onedal #2126
Changes from all commits
32fe269
cdbf1b5
62674a2
c75c23b
6a20938
382d7a1
c8ffd9c
9aa13d5
84e15d5
63073c6
d915da5
3dddf2d
1e1213e
0531713
e831167
d6eb1d0
fb30d6e
63a18c2
76c0856
7deb2bb
ed46b29
67d6273
054f0a1
8abead9
47d0f8b
e48c2bd
c6751c4
f3e4a3a
39cdb5f
0f39613
b42cfe3
0ed615e
f101aff
24c0e94
3f96166
d985053
90ec48b
8c2c854
6fa38d7
9c1ca9c
4b67dbd
fa59a3c
3330b33
4895940
0c6dd5d
e2182fa
982ef2c
2fb52a8
28dc267
2f85fd4
8659248
3827d6f
55fa7d2
175cd78
7016ad0
1a01859
2fbcdd9
fb7375f
30816bf
abb3b16
97aef73
6e29651
59363a8
61da628
e3facab
5bb54a5
e8d8c71
1e09b11
afc76b8
edf0350
4efad2c
63e2fa8
48cafbc
085f8a7
cdb11f2
b539d23
5549f99
61ca3db
63d9566
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/******************************************************************************* | ||
* Copyright 2024 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
// fix error with missing headers | ||
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200 | ||
#include "oneapi/dal/algo/finiteness_checker.hpp" | ||
#else | ||
#include "oneapi/dal/algo/finiteness_checker/compute.hpp" | ||
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200 | ||
|
||
#include "onedal/common.hpp" | ||
#include "onedal/version.hpp" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace oneapi::dal::python { | ||
|
||
template <typename Task, typename Ops> | ||
struct method2t { | ||
method2t(const Task& task, const Ops& ops) : ops(ops) {} | ||
|
||
template <typename Float> | ||
auto operator()(const py::dict& params) { | ||
using namespace finiteness_checker; | ||
|
||
const auto method = params["method"].cast<std::string>(); | ||
|
||
ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense); | ||
ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default); | ||
ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method); | ||
} | ||
|
||
Ops ops; | ||
}; | ||
|
||
struct params2desc { | ||
template <typename Float, typename Method, typename Task> | ||
auto operator()(const pybind11::dict& params) { | ||
using namespace dal::finiteness_checker; | ||
|
||
auto desc = descriptor<Float, Method, Task>(); | ||
desc.set_allow_NaN(params["allow_nan"].cast<bool>()); | ||
return desc; | ||
} | ||
}; | ||
|
||
template <typename Policy, typename Task> | ||
void init_compute_ops(py::module_& m) { | ||
m.def("compute", | ||
[](const Policy& policy, | ||
const py::dict& params, | ||
const table& data) { | ||
using namespace finiteness_checker; | ||
using input_t = compute_input<Task>; | ||
|
||
compute_ops ops(policy, input_t{ data }, params2desc{}); | ||
return fptype2t{ method2t{ Task{}, ops } }(params); | ||
}); | ||
} | ||
|
||
template <typename Task> | ||
void init_compute_result(py::module_& m) { | ||
using namespace finiteness_checker; | ||
using result_t = compute_result<Task>; | ||
|
||
py::class_<result_t>(m, "compute_result") | ||
.def(py::init()) | ||
.DEF_ONEDAL_PY_PROPERTY(finite, result_t); | ||
} | ||
|
||
ONEDAL_PY_TYPE2STR(finiteness_checker::task::compute, "compute"); | ||
|
||
ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_ops); | ||
ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_result); | ||
|
||
ONEDAL_PY_INIT_MODULE(finiteness_checker) { | ||
using namespace dal::detail; | ||
using namespace finiteness_checker; | ||
using namespace dal::finiteness_checker; | ||
|
||
using task_list = types<task::compute>; | ||
auto sub = m.def_submodule("finiteness_checker"); | ||
|
||
#ifndef ONEDAL_DATA_PARALLEL_SPMD | ||
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list); | ||
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list); | ||
#endif | ||
} | ||
|
||
} // namespace oneapi::dal::python |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# ============================================================================== | ||
Vika-F marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Copyright 2024 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import time | ||
|
||
import numpy as np | ||
import numpy.random as rand | ||
import pytest | ||
import scipy.sparse as sp | ||
|
||
from onedal.tests.utils._dataframes_support import ( | ||
_convert_to_dataframe, | ||
get_dataframes_and_queues, | ||
) | ||
from onedal.utils.validation import assert_all_finite | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) | ||
@pytest.mark.parametrize( | ||
"shape", | ||
[ | ||
[16, 2048], | ||
[65539], # 2**16 + 3, | ||
[1000, 1000], | ||
[ | ||
3, | ||
], | ||
], | ||
) | ||
@pytest.mark.parametrize("allow_nan", [False, True]) | ||
@pytest.mark.parametrize( | ||
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are there issues with pandas? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pandas inputs are not to be encountered by this function, check_array should always convert them to numpy inputs, we also do not support heterogeneous tables yet, which will be done at a later point. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This issue is handled in the follow up PR: #2177 |
||
) | ||
def test_sum_infinite_actually_finite(dtype, shape, allow_nan, dataframe, queue): | ||
X = np.empty(shape, dtype=dtype) | ||
X.fill(np.finfo(dtype).max) | ||
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) | ||
@pytest.mark.parametrize( | ||
"shape", | ||
[ | ||
[16, 2048], | ||
[65539], # 2**16 + 3, | ||
[1000, 1000], | ||
[ | ||
3, | ||
], | ||
], | ||
) | ||
@pytest.mark.parametrize("allow_nan", [False, True]) | ||
@pytest.mark.parametrize("check", ["inf", "NaN", None]) | ||
@pytest.mark.parametrize("seed", [0, int(time.time())]) | ||
@pytest.mark.parametrize( | ||
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") | ||
) | ||
def test_assert_finite_random_location( | ||
dtype, shape, allow_nan, check, seed, dataframe, queue | ||
): | ||
rand.seed(seed) | ||
X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) | ||
|
||
if check: | ||
loc = rand.randint(0, X.size - 1) | ||
X.reshape((-1,))[loc] = float(check) | ||
|
||
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) | ||
|
||
if check is None or (allow_nan and check == "NaN"): | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
else: | ||
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." | ||
with pytest.raises(ValueError, match=msg_err): | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) | ||
@pytest.mark.parametrize("allow_nan", [False, True]) | ||
@pytest.mark.parametrize("check", ["inf", "NaN", None]) | ||
@pytest.mark.parametrize("seed", [0, int(time.time())]) | ||
@pytest.mark.parametrize( | ||
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") | ||
) | ||
def test_assert_finite_random_shape_and_location( | ||
dtype, allow_nan, check, seed, dataframe, queue | ||
): | ||
lb, ub = 2, 1048576 # ub is 2^20 | ||
rand.seed(seed) | ||
X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype) | ||
|
||
if check: | ||
loc = rand.randint(0, X.size - 1) | ||
X[loc] = float(check) | ||
|
||
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) | ||
|
||
if check is None or (allow_nan and check == "NaN"): | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
else: | ||
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." | ||
with pytest.raises(ValueError, match=msg_err): | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) | ||
@pytest.mark.parametrize("allow_nan", [False, True]) | ||
@pytest.mark.parametrize("check", ["inf", "NaN", None]) | ||
@pytest.mark.parametrize("seed", [0, int(time.time())]) | ||
def test_assert_finite_sparse(dtype, allow_nan, check, seed): | ||
lb, ub = 2, 2056 | ||
rand.seed(seed) | ||
X = sp.random( | ||
rand.randint(lb, ub), | ||
rand.randint(lb, ub), | ||
format="csr", | ||
dtype=dtype, | ||
random_state=rand.default_rng(seed), | ||
) | ||
|
||
if check: | ||
locx = rand.randint(0, X.data.shape[0] - 1) | ||
X.data[locx] = float(check) | ||
|
||
if check is None or (allow_nan and check == "NaN"): | ||
assert_all_finite(X, allow_nan=allow_nan) | ||
else: | ||
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." | ||
with pytest.raises(ValueError, match=msg_err): | ||
assert_all_finite(X, allow_nan=allow_nan) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,12 @@ | |
from sklearn.preprocessing import LabelEncoder | ||
from sklearn.utils.validation import check_array | ||
|
||
from daal4py.sklearn.utils.validation import _assert_all_finite | ||
from daal4py.sklearn.utils.validation import ( | ||
_assert_all_finite as _daal4py_assert_all_finite, | ||
) | ||
from onedal import _backend | ||
from onedal.common._policy import _get_policy | ||
from onedal.datatypes import _convert_to_supported, to_table | ||
|
||
|
||
class DataConversionWarning(UserWarning): | ||
|
@@ -135,10 +140,10 @@ def _check_array( | |
if force_all_finite: | ||
if sp.issparse(array): | ||
if hasattr(array, "data"): | ||
_assert_all_finite(array.data) | ||
_daal4py_assert_all_finite(array.data) | ||
force_all_finite = False | ||
else: | ||
_assert_all_finite(array) | ||
_daal4py_assert_all_finite(array) | ||
force_all_finite = False | ||
array = check_array( | ||
array=array, | ||
|
@@ -191,7 +196,7 @@ def _check_X_y( | |
if y_numeric and y.dtype.kind == "O": | ||
y = y.astype(np.float64) | ||
if force_all_finite: | ||
_assert_all_finite(y) | ||
_daal4py_assert_all_finite(y) | ||
|
||
lengths = [X.shape[0], y.shape[0]] | ||
uniques = np.unique(lengths) | ||
|
@@ -276,7 +281,7 @@ def _type_of_target(y): | |
# check float and contains non-integer float values | ||
if y.dtype.kind == "f" and np.any(y != y.astype(int)): | ||
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] | ||
_assert_all_finite(y) | ||
_daal4py_assert_all_finite(y) | ||
return "continuous" + suffix | ||
|
||
if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): | ||
|
@@ -429,3 +434,31 @@ def _is_csr(x): | |
return isinstance(x, sp.csr_matrix) or ( | ||
hasattr(sp, "csr_array") and isinstance(x, sp.csr_array) | ||
) | ||
|
||
|
||
def _assert_all_finite(X, allow_nan=False, input_name=""): | ||
policy = _get_policy(None, X) | ||
X_t = to_table(_convert_to_supported(policy, X)) | ||
params = { | ||
"fptype": X_t.dtype, | ||
"method": "dense", | ||
"allow_nan": allow_nan, | ||
} | ||
if not _backend.finiteness_checker.compute.compute(policy, params, X_t).finite: | ||
type_err = "infinity" if allow_nan else "NaN, infinity" | ||
padded_input_name = input_name + " " if input_name else "" | ||
msg_err = f"Input {padded_input_name}contains {type_err}." | ||
icfaust marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError(msg_err) | ||
|
||
|
||
def assert_all_finite( | ||
X, | ||
*, | ||
allow_nan=False, | ||
input_name="", | ||
): | ||
_assert_all_finite( | ||
X.data if sp.issparse(X) else X, | ||
allow_nan=allow_nan, | ||
input_name=input_name, | ||
) | ||
Comment on lines
+454
to
+464
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this main API for the feature? Is it used in the tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added test which uses the main function name, also testing sparse arrays as well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only instantiate if not data parallel spmd?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As of now there isn't an spmd implimentation in oneDAL. If you forsee this as a problem, we will have to go back to oneDAL and add it (we can make that available for next release if important). @ethanglaser let me know.