Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[experiment] ENH: using only raw inputs for onedal backend #2153

Draft
wants to merge 55 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
daed528
ENH: using only raw inputs for onedal backend
samir-nasibli Nov 5, 2024
1be2ffb
minor fix
samir-nasibli Nov 5, 2024
a23b677
lint
samir-nasibli Nov 5, 2024
664e140
fix use_raw_input True/False with dpctl tensor on device
ahuber21 Nov 5, 2024
518dceb
Add hacks to kmeans
ahuber21 Nov 5, 2024
df9d930
Basic statistics online
samir-nasibli Nov 5, 2024
2954913
Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci…
samir-nasibli Nov 5, 2024
3ef345c
Covariance support
ethanglaser Nov 5, 2024
f1c9233
Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci…
ethanglaser Nov 5, 2024
66d7b2d
DBSCAN support
samir-nasibli Nov 5, 2024
c5d26a4
Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci…
samir-nasibli Nov 5, 2024
1350c10
minor fix for dbscan
samir-nasibli Nov 5, 2024
8aaaa70
minor fix for DBSCAN
samir-nasibli Nov 5, 2024
f0d92ae
Apply raw input for batch linear and logistic regression
Alexsandruss Nov 5, 2024
3b58beb
Apply linters
Alexsandruss Nov 5, 2024
d7f2c3c
fix for DBSCAN
samir-nasibli Nov 5, 2024
1aca420
support for Random Forest
samir-nasibli Nov 5, 2024
362930a
PCA support (batch)
ethanglaser Nov 5, 2024
bc37391
Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci…
ethanglaser Nov 5, 2024
102dcae
minor fix for dbscan and rf
samir-nasibli Nov 5, 2024
6edab5b
fully fixed DBSCAN
samir-nasibli Nov 6, 2024
e153a28
Add Incremental Linear Regression
Alexsandruss Nov 6, 2024
37d32c9
Linting
Alexsandruss Nov 6, 2024
71c5135
add modification to knn
ahuber21 Nov 6, 2024
db9f021
minor update for RF
samir-nasibli Nov 6, 2024
bc353da
fix for RandomForestClassifier
samir-nasibli Nov 7, 2024
e873205
minor for RF
samir-nasibli Nov 7, 2024
fe3222a
Update online algos
olegkkruglov Nov 7, 2024
5b3ad17
Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci…
samir-nasibli Nov 7, 2024
eaaab32
fix for RF regressor
samir-nasibli Nov 7, 2024
a7f0c2d
fix workaround for knn
ahuber21 Nov 7, 2024
d9a2966
kmeans predict support
ethanglaser Nov 12, 2024
3562c69
Merge remote-tracking branch 'origin/main' into enh/raw_inputs
ahuber21 Dec 16, 2024
42c3614
fix merge errors
ahuber21 Dec 16, 2024
53bcc7b
fix some tests
ahuber21 Dec 17, 2024
9964c5a
fixup
ahuber21 Dec 17, 2024
84afb62
undo more changes that broke tests
ahuber21 Dec 17, 2024
cf5b736
format
ahuber21 Dec 17, 2024
92393b9
restore original behavior when running without raw inputs
ahuber21 Dec 18, 2024
13471e5
restore original behavior when running without raw inputs
ahuber21 Dec 18, 2024
a8f3f19
align code
ahuber21 Dec 18, 2024
2b07c00
restore original from_table
ahuber21 Dec 19, 2024
6104736
add use_raw_input tests for incremental covariance
ahuber21 Dec 19, 2024
df03233
Add basic statistics testing
ahuber21 Dec 19, 2024
8a166b7
add incremental basic statistics
ahuber21 Dec 19, 2024
fb5f5fa
add dbscan
ahuber21 Dec 19, 2024
7072041
Merge remote-tracking branch 'origin/main' into dev/ahuber/raw-inputs…
ahuber21 Dec 19, 2024
91384ed
add kmeans
ahuber21 Dec 20, 2024
6dec57d
add covariance
ahuber21 Dec 20, 2024
529a7b8
align get_config() import and use_raw_input retrieval
ahuber21 Dec 20, 2024
9f78cbd
add incremental_pca
ahuber21 Dec 20, 2024
658ccc1
add pca
ahuber21 Dec 20, 2024
5e74a54
add incremental linear
ahuber21 Dec 20, 2024
dfbf223
add linear_model
ahuber21 Dec 22, 2024
c4094fb
Merge branch 'dev/ahuber/raw-inputs-dispatching' into enh/raw_inputs
ahuber21 Dec 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onedal/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"target_offload": "auto",
"allow_fallback_to_host": False,
"allow_sklearn_after_onedal": True,
"use_raw_input": False,
}

_threadlocal = threading.local()
Expand Down
56 changes: 34 additions & 22 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,30 +180,42 @@ def support_input_format(freefunc=False, queue_param=True):

def decorator(func):
def wrapper_impl(obj, *args, **kwargs):
if len(args) == 0 and len(kwargs) == 0:
if _get_config()["use_raw_input"] is True:
if "queue" not in kwargs:
usm_iface = getattr(args[0], "__sycl_usm_array_interface__", None)
data_queue = (
usm_iface["syclobj"] if usm_iface is not None else data_queue
)
kwargs["queue"] = data_queue
return _run_on_device(func, obj, *args, **kwargs)

elif len(args) == 0 and len(kwargs) == 0:
return _run_on_device(func, obj, *args, **kwargs)
data = (*args, *kwargs.values())
data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
if queue_param and not (
"queue" in hostkwargs and hostkwargs["queue"] is not None
):
hostkwargs["queue"] = data_queue
result = _run_on_device(func, obj, *hostargs, **hostkwargs)
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
if usm_iface is not None:
result = _copy_to_usm(data_queue, result)
if dpnp_available and isinstance(data[0], dpnp.ndarray):
result = _convert_to_dpnp(result)

else:
data = (*args, *kwargs.values())
data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
if queue_param and not (
"queue" in hostkwargs and hostkwargs["queue"] is not None
):
hostkwargs["queue"] = data_queue
result = _run_on_device(func, obj, *hostargs, **hostkwargs)
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
if usm_iface is not None:
result = _copy_to_usm(data_queue, result)
if dpnp_available and isinstance(data[0], dpnp.ndarray):
result = _convert_to_dpnp(result)
return result
if not get_config().get("transform_output", False):
input_array_api = getattr(
data[0], "__array_namespace__", lambda: None
)()
if input_array_api:
input_array_api_device = data[0].device
result = _asarray(
result, input_array_api, device=input_array_api_device
)
return result
config = get_config()
if not ("transform_output" in config and config["transform_output"]):
input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
if input_array_api:
input_array_api_device = data[0].device
result = _asarray(
result, input_array_api, device=input_array_api_device
)
return result

if freefunc:

Expand Down
35 changes: 22 additions & 13 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
# limitations under the License.
# ==============================================================================

import warnings
from abc import ABCMeta, abstractmethod

import numpy as np

from .._config import _get_config
from ..common._base import BaseEstimator
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _is_csr
from ..utils._array_api import _get_sycl_namespace
from ..utils.validation import _check_array


Expand Down Expand Up @@ -72,27 +73,35 @@ def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)

def fit(self, data, sample_weight=None, queue=None):
policy = self._get_policy(queue, data, sample_weight)

is_csr = _is_csr(data)

if data is not None and not is_csr:
data = _check_array(data, ensure_2d=False)
if sample_weight is not None:
sample_weight = _check_array(sample_weight, ensure_2d=False)
use_raw_input = _get_config().get("use_raw_input", False) is True

# All data should use the same sycl queue
if use_raw_input and _get_sycl_namespace(data)[0] is not None:
queue = data.sycl_queue

if not use_raw_input:
if data is not None and not is_csr:
data = _check_array(data, ensure_2d=False)
if sample_weight is not None:
sample_weight = _check_array(sample_weight, ensure_2d=False)

# TODO
# use xp for dtype.
policy = self._get_policy(queue, data, sample_weight)
data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)

data_table = to_table(data, sua_iface=_get_sycl_namespace(data)[0])
weights_table = to_table(
sample_weight, sua_iface=_get_sycl_namespace(sample_weight)[0]
)

dtype = data.dtype
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
if is_single_dim:
setattr(self, opt, value[0])
else:
setattr(self, opt, value)
setattr(self, opt, value[0]) if data.ndim == 1 else setattr(self, opt, value)

return self

Expand Down
51 changes: 36 additions & 15 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

from daal4py.sklearn._utils import get_dtype

from .._config import _get_config
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from ..utils._array_api import _get_sycl_namespace
from .basic_statistics import BaseBasicStatistics


Expand Down Expand Up @@ -93,26 +95,39 @@ def partial_fit(self, X, weights=None, queue=None):
self : object
Returns the instance itself.
"""
use_raw_input = _get_config().get("use_raw_input", False) is True
sua_iface, xp, _ = _get_sycl_namespace(X)
# Saving input array namespace and sua_iface, that will be used in
# finalize_fit.
self._input_sua_iface = sua_iface
self._input_xp = xp

# All data should use the same sycl queue
if use_raw_input and sua_iface is not None:
queue = X.sycl_queue

self._queue = queue
policy = self._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

X = _check_array(
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
)
if weights is not None:
weights = _check_array(
weights,
dtype=[np.float64, np.float32],
ensure_2d=False,
force_all_finite=False,
if not use_raw_input:
X = _check_array(
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
)
if weights is not None:
weights = _check_array(
weights,
dtype=[np.float64, np.float32],
ensure_2d=False,
force_all_finite=False,
)

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)

X_table, weights_table = to_table(X, weights)
X_table = to_table(X, sua_iface=sua_iface)
weights_table = to_table(weights, sua_iface=_get_sycl_namespace(weights)[0])
self._partial_result = self._get_backend(
"basic_statistics",
None,
Expand Down Expand Up @@ -140,10 +155,8 @@ def finalize_fit(self, queue=None):
Returns the instance itself.
"""

if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)
queue = queue if queue is not None else self._queue
policy = self._get_policy(queue)

result = self._get_backend(
"basic_statistics",
Expand All @@ -155,6 +168,14 @@ def finalize_fit(self, queue=None):
)
options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())
opt_value = self._input_xp.ravel(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When running sklearnex example incremental_basic_statistics_dpctl.py leads to AttributeError: 'NoneType' object has no attribute 'ravel'

from_table(
getattr(result, opt),
sua_iface=self._input_sua_iface,
sycl_queue=queue,
xp=self._input_xp,
)
)
setattr(self, opt, opt_value)

return self
45 changes: 35 additions & 10 deletions onedal/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

from daal4py.sklearn._utils import get_dtype, make2d

from .._config import _get_config
from ..common._base import BaseEstimator
from ..common._mixin import ClusterMixin
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from ..utils._array_api import _get_sycl_namespace


class BaseDBSCAN(BaseEstimator, ClusterMixin):
Expand Down Expand Up @@ -57,27 +59,50 @@ def _get_onedal_params(self, dtype=np.float32):
}

def _fit(self, X, y, sample_weight, module, queue):
use_raw_input = _get_config().get("use_raw_input", False) is True
sua_iface, xp, _ = _get_sycl_namespace(X)

# All data should use the same sycl queue
if use_raw_input and sua_iface is not None:
queue = X.sycl_queue

policy = self._get_policy(queue, X)
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
sample_weight = make2d(sample_weight) if sample_weight is not None else None
X = make2d(X)

if not use_raw_input:
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
sample_weight = make2d(sample_weight) if sample_weight is not None else None
X = make2d(X)

types = [np.float32, np.float64]
if get_dtype(X) not in types:
X = X.astype(np.float64)
X = _convert_to_supported(policy, X)
dtype = get_dtype(X)
params = self._get_onedal_params(dtype)
result = module.compute(policy, params, to_table(X), to_table(sample_weight))

self.labels_ = from_table(result.responses).ravel()
X_table = to_table(X, sua_iface=sua_iface)
weights_table = to_table(
sample_weight, sua_iface=_get_sycl_namespace(sample_weight)[0]
)

result = module.compute(policy, params, X_table, weights_table)

self.labels_ = xp.reshape(
from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp), -1
)
if result.core_observation_indices is not None:
self.core_sample_indices_ = from_table(
result.core_observation_indices
).ravel()
self.core_sample_indices_ = xp.reshape(
from_table(
result.core_observation_indices,
sua_iface=sua_iface,
sycl_queue=queue,
xp=xp,
),
-1,
)
else:
self.core_sample_indices_ = np.array([], dtype=np.intc)
self.components_ = np.take(X, self.core_sample_indices_, axis=0)
self.core_sample_indices_ = xp.array([], dtype=xp.int32)
self.components_ = xp.take(X, self.core_sample_indices_, axis=0)
self.n_features_in_ = X.shape[1]
return self

Expand Down
37 changes: 27 additions & 10 deletions onedal/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.utils import check_random_state

from .._config import _get_config
from ..common._base import BaseEstimator as onedal_BaseEstimator
from ..common._mixin import ClusterMixin, TransformerMixin
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array, _is_arraylike_not_scalar, _is_csr
from ..utils._array_api import _get_sycl_namespace


class _BaseKMeans(onedal_BaseEstimator, TransformerMixin, ClusterMixin, ABC):
Expand Down Expand Up @@ -80,7 +82,7 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm):
def _get_basic_statistics_backend(self, result_options):
return BasicStatistics(result_options)

def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
def _tolerance(self, X_table, rtol, is_csr, policy, dtype, sua_iface):
"""Compute absolute tolerance from the relative tolerance"""
if rtol == 0.0:
return rtol
Expand All @@ -94,7 +96,7 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
return mean_var * rtol

def _check_params_vs_input(
self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32
self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32, sua_iface=None
):
# n_clusters
if X_table.shape[0] < self.n_clusters:
Expand All @@ -103,7 +105,7 @@ def _check_params_vs_input(
)

# tol
self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype)
self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype, sua_iface)

# n-init
# TODO(1.4): Remove
Expand Down Expand Up @@ -261,18 +263,33 @@ def _fit_backend(
)

def _fit(self, X, module, queue=None):
policy = self._get_policy(queue, X)
is_csr = _is_csr(X)
X = _check_array(
X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
)

use_raw_input = _get_config().get("use_raw_input") is True
if use_raw_input and _get_sycl_namespace(X)[0] is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if use_raw_input and _get_sycl_namespace(X)[0] is not None:
if use_raw_input and sua_iface is not None:
  • move line 284 above this

queue = X.sycl_queue

if not use_raw_input:
X = _check_array(
X,
dtype=[np.float64, np.float32],
accept_sparse="csr",
force_all_finite=False,
)

policy = self._get_policy(queue, X)

X = _convert_to_supported(policy, X)
dtype = get_dtype(X)
X_table = to_table(X)
sua_iface = _get_sycl_namespace(X)[0]
X_table = to_table(X, sua_iface=sua_iface)

self._check_params_vs_input(X_table, is_csr, policy, dtype=dtype)
self._check_params_vs_input(
X_table, is_csr, policy, dtype=dtype, sua_iface=sua_iface
)

params = self._get_onedal_params(is_csr, dtype)
# not used?
# params = self._get_onedal_params(is_csr, dtype)

self.n_features_in_ = X_table.column_count

Expand Down
Loading
Loading