diff --git a/onedal/_config.py b/onedal/_config.py
index 95a4af41b8..12292cc845 100644
--- a/onedal/_config.py
+++ b/onedal/_config.py
@@ -22,6 +22,7 @@
     "target_offload": "auto",
     "allow_fallback_to_host": False,
     "allow_sklearn_after_onedal": True,
+    "use_raw_input": False,
 }
 
 _threadlocal = threading.local()
diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py
index 4e46592bb2..6a89d43b81 100644
--- a/onedal/_device_offload.py
+++ b/onedal/_device_offload.py
@@ -162,30 +162,45 @@ def support_input_format(freefunc=False, queue_param=True):
 
     def decorator(func):
         def wrapper_impl(obj, *args, **kwargs):
-            if len(args) == 0 and len(kwargs) == 0:
+            # Check if the function is KNeighborsClassifier.fit
+            override_raw_input = (
+                obj.__class__.__name__ == "KNeighborsClassifier"
+                and func.__name__ == "fit"
+            )
+            if _get_config()["use_raw_input"] is True and not override_raw_input:
+                if "queue" not in kwargs:
+                    usm_iface = getattr(args[0], "__sycl_usm_array_interface__", None)
+                    data_queue = usm_iface["syclobj"] if usm_iface is not None else None
+                    kwargs["queue"] = data_queue
+                return _run_on_device(func, obj, *args, **kwargs)
+
+            elif len(args) == 0 and len(kwargs) == 0:
                 return _run_on_device(func, obj, *args, **kwargs)
-            data = (*args, *kwargs.values())
-            data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
-            if queue_param and not (
-                "queue" in hostkwargs and hostkwargs["queue"] is not None
-            ):
-                hostkwargs["queue"] = data_queue
-            result = _run_on_device(func, obj, *hostargs, **hostkwargs)
-            usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
-            if usm_iface is not None:
-                result = _copy_to_usm(data_queue, result)
-                if dpnp_available and isinstance(data[0], dpnp.ndarray):
-                    result = _convert_to_dpnp(result)
+
+            else:
+                data = (*args, *kwargs.values())
+                data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
+                if queue_param and not (
+                    "queue" in hostkwargs and hostkwargs["queue"] is not None
+                ):
+                    hostkwargs["queue"] = data_queue
+                result = _run_on_device(func, obj, *hostargs, **hostkwargs)
+                usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
+                if usm_iface is not None:
+                    result = _copy_to_usm(data_queue, result)
+                    if dpnp_available and isinstance(data[0], dpnp.ndarray):
+                        result = _convert_to_dpnp(result)
+                    return result
+                if not get_config().get("transform_output", False):
+                    input_array_api = getattr(
+                        data[0], "__array_namespace__", lambda: None
+                    )()
+                    if input_array_api:
+                        input_array_api_device = data[0].device
+                        result = _asarray(
+                            result, input_array_api, device=input_array_api_device
+                        )
                 return result
-            config = get_config()
-            if not ("transform_output" in config and config["transform_output"]):
-                input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
-                if input_array_api:
-                    input_array_api_device = data[0].device
-                    result = _asarray(
-                        result, input_array_api, device=input_array_api_device
-                    )
-            return result
 
         if freefunc:
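A note on the raw-input branch above: when `use_raw_input` is enabled, `wrapper_impl` skips host conversion entirely and derives the queue from the first positional argument's SYCL USM array interface. A minimal sketch of that lookup, assuming dpctl is installed (not part of the patch):

```python
import dpctl
import dpctl.tensor as dpt

x = dpt.ones((4, 2), dtype="float32")  # USM-backed array
# Same attribute the wrapper inspects on args[0]:
sua = getattr(x, "__sycl_usm_array_interface__", None)
queue = sua["syclobj"] if sua is not None else None
assert isinstance(queue, dpctl.SyclQueue)
```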
diff --git a/onedal/basic_statistics/basic_statistics.py b/onedal/basic_statistics/basic_statistics.py
index 56904adce2..d749a9cd28 100644
--- a/onedal/basic_statistics/basic_statistics.py
+++ b/onedal/basic_statistics/basic_statistics.py
@@ -14,14 +14,15 @@
 # limitations under the License.
 # ==============================================================================
 
-import warnings
 from abc import ABCMeta, abstractmethod
 
 import numpy as np
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..datatypes import from_table, to_table
 from ..utils import _is_csr
+from ..utils._array_api import _get_sycl_namespace
 from ..utils.validation import _check_array
 
 
@@ -76,10 +77,12 @@ def fit(self, data, sample_weight=None, queue=None):
         is_csr = _is_csr(data)
 
-        if data is not None and not is_csr:
-            data = _check_array(data, ensure_2d=False)
-        if sample_weight is not None:
-            sample_weight = _check_array(sample_weight, ensure_2d=False)
+        use_raw_input = _get_config().get("use_raw_input", False) is True
+        if not use_raw_input:
+            if data is not None and not is_csr:
+                data = _check_array(data, ensure_2d=False)
+            if sample_weight is not None:
+                sample_weight = _check_array(sample_weight, ensure_2d=False)
 
         is_single_dim = data.ndim == 1
         data_table, weights_table = to_table(data, sample_weight, queue=queue)
diff --git a/onedal/basic_statistics/incremental_basic_statistics.py b/onedal/basic_statistics/incremental_basic_statistics.py
index b98161ce59..896758cac0 100644
--- a/onedal/basic_statistics/incremental_basic_statistics.py
+++ b/onedal/basic_statistics/incremental_basic_statistics.py
@@ -18,8 +18,10 @@
 
 from daal4py.sklearn._utils import get_dtype
 
+from .._config import _get_config
 from ..datatypes import from_table, to_table
 from ..utils import _check_array
+from ..utils._array_api import _get_sycl_namespace
 from .basic_statistics import BaseBasicStatistics
 
 
@@ -82,6 +84,7 @@ def __getstate__(self):
         self.finalize_fit()
         data = self.__dict__.copy()
         data.pop("_queue", None)
+        data.pop("_input_xp", None)  # module cannot be pickled
 
         return data
 
@@ -104,19 +107,31 @@ def partial_fit(self, X, weights=None, queue=None):
         self : object
             Returns the instance itself.
         """
+        use_raw_input = _get_config().get("use_raw_input", False)
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        # Save the input array namespace and sua_iface; they will be used in
+        # finalize_fit.
+        self._input_sua_iface = sua_iface
+        self._input_xp = xp
+
+        # All data should use the same sycl queue
+        if use_raw_input and sua_iface:
+            queue = X.sycl_queue
+
         self._queue = queue
         policy = self._get_policy(queue, X)
 
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
-        )
-        if weights is not None:
-            weights = _check_array(
-                weights,
-                dtype=[np.float64, np.float32],
-                ensure_2d=False,
-                force_all_finite=False,
+        if not use_raw_input:
+            X = _check_array(
+                X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
             )
+            if weights is not None:
+                weights = _check_array(
+                    weights,
+                    dtype=[np.float64, np.float32],
+                    ensure_2d=False,
+                    force_all_finite=False,
+                )
 
         if not hasattr(self, "_onedal_params"):
             dtype = get_dtype(X)
diff --git a/onedal/cluster/dbscan.py b/onedal/cluster/dbscan.py
index 02dcfb6a58..7581fb922d 100644
--- a/onedal/cluster/dbscan.py
+++ b/onedal/cluster/dbscan.py
@@ -18,10 +18,12 @@
 
 from daal4py.sklearn._utils import get_dtype, make2d
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..common._mixin import ClusterMixin
 from ..datatypes import from_table, to_table
 from ..utils import _check_array
+from ..utils._array_api import _asarray, _get_sycl_namespace
 
 
 class BaseDBSCAN(BaseEstimator, ClusterMixin):
@@ -57,18 +59,26 @@ def _get_onedal_params(self, dtype=np.float32):
         }
 
     def _fit(self, X, y, sample_weight, module, queue):
+        use_raw_input = _get_config().get("use_raw_input", False) is True
         policy = self._get_policy(queue, X)
-        X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
-        sample_weight = make2d(sample_weight) if sample_weight is not None else None
-        X_table, sample_weight_table = to_table(X, sample_weight, queue=queue)
+        if not use_raw_input:
+            X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
+            sample_weight = make2d(sample_weight) if sample_weight is not None else None
+            X = make2d(X)
+        X_table, sample_weight_table = to_table(X, sample_weight, queue=queue)
 
         params = self._get_onedal_params(X_table.dtype)
         result = module.compute(policy, params, X_table, sample_weight_table)
 
-        self.labels_ = from_table(result.responses).ravel()
-        if result.core_observation_indices is not None:
+        _, xp, _ = _get_sycl_namespace(X)
+        self.labels_ = from_table(result.responses, sycl_queue=queue).ravel()
+        if (
+            result.core_observation_indices is not None
+            and not result.core_observation_indices.kind == "empty"
+        ):
             self.core_sample_indices_ = from_table(
-                result.core_observation_indices
+                result.core_observation_indices,
+                sycl_queue=queue,
             ).ravel()
         else:
             self.core_sample_indices_ = np.array([], dtype=np.intc)
diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py
index a40729841d..6d6a94324c 100644
--- a/onedal/cluster/kmeans.py
+++ b/onedal/cluster/kmeans.py
@@ -32,10 +32,12 @@
 from sklearn.metrics.pairwise import euclidean_distances
 from sklearn.utils import check_random_state
 
+from .._config import _get_config
 from ..common._base import BaseEstimator as onedal_BaseEstimator
 from ..common._mixin import ClusterMixin, TransformerMixin
 from ..datatypes import from_table, to_table
 from ..utils import _check_array, _is_arraylike_not_scalar, _is_csr
+from ..utils._array_api import _get_sycl_namespace
 
 
 class _BaseKMeans(onedal_BaseEstimator, TransformerMixin, ClusterMixin, ABC):
@@ -80,7 +82,7 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm):
     def _get_basic_statistics_backend(self, result_options):
         return BasicStatistics(result_options)
 
-    def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
+    def _tolerance(self, X_table, rtol, is_csr, policy, dtype, sua_iface):
         """Compute absolute tolerance from the relative tolerance"""
         if rtol == 0.0:
             return rtol
@@ -94,7 +96,7 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
         return mean_var * rtol
 
     def _check_params_vs_input(
-        self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32
+        self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32, sua_iface=None
     ):
         # n_clusters
         if X_table.shape[0] < self.n_clusters:
@@ -103,7 +105,7 @@ def _check_params_vs_input(
             )
 
         # tol
-        self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype)
+        self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype, sua_iface)
 
         # n-init
         # TODO(1.4): Remove
@@ -261,9 +263,14 @@ def _fit_backend(
     def _fit(self, X, module, queue=None):
         policy = self._get_policy(queue, X)
         is_csr = _is_csr(X)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
-        )
+
+        if _get_config()["use_raw_input"] is False:
+            X = _check_array(
+                X,
+                dtype=[np.float64, np.float32],
+                accept_sparse="csr",
+                force_all_finite=False,
+            )
         X_table = to_table(X, queue=queue)
 
         dtype = X_table.dtype
""" + use_raw_input = _get_config()["use_raw_input"] + sua_iface, xp, _ = _get_sycl_namespace(X) + if use_raw_input and sua_iface: + queue = X.sycl_queue + policy = self._get_policy(queue, X) - X = _check_array(X, dtype=[np.float64, np.float32]) + if not use_raw_input: + X = _check_array(X, dtype=[np.float64, np.float32]) X = to_table(X, queue=queue) params = self._get_onedal_params(X.dtype) hparams = get_hyperparameters("covariance", "compute") @@ -111,12 +119,14 @@ def fit(self, X, y=None, queue=None): else: result = self._get_backend("covariance", None, "compute", policy, params, X) if daal_check_version((2024, "P", 1)) or (not self.bias): - self.covariance_ = from_table(result.cov_matrix) + self.covariance_ = from_table(result.cov_matrix, sycl_queue=queue) else: self.covariance_ = ( - from_table(result.cov_matrix) * (X.shape[0] - 1) / X.shape[0] + from_table(result.cov_matrix, sycl_queue=queue) + * (X.shape[0] - 1) + / X.shape[0] ) - self.location_ = from_table(result.means).ravel() + self.location_ = from_table(result.means, sycl_queue=queue).ravel() return self diff --git a/onedal/covariance/incremental_covariance.py b/onedal/covariance/incremental_covariance.py index b0bfb04e22..51cc5d18d3 100644 --- a/onedal/covariance/incremental_covariance.py +++ b/onedal/covariance/incremental_covariance.py @@ -15,10 +15,12 @@ # =============================================================================== import numpy as np -from daal4py.sklearn._utils import daal_check_version, get_dtype +from daal4py.sklearn._utils import daal_check_version +from .._config import _get_config from ..datatypes import from_table, to_table from ..utils import _check_array +from ..utils._array_api import _get_sycl_namespace from .covariance import BaseEmpiricalCovariance @@ -70,6 +72,7 @@ def __getstate__(self): self.finalize_fit() data = self.__dict__.copy() data.pop("_queue", None) + data.pop("_input_xp", None) # module cannot be pickled return data @@ -95,10 +98,19 @@ def partial_fit(self, X, y=None, queue=None): self : object Returns the instance itself. """ - X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True) + # Saving input array namespace and sua_iface, that will be used in + # finalize_fit. + sua_iface, xp, _ = _get_sycl_namespace(X) + self._input_sua_iface = sua_iface + self._input_xp = xp + + use_raw_input = _get_config().get("use_raw_input", False) + if use_raw_input and sua_iface: + queue = X.sycl_queue + if not use_raw_input: + X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True) self._queue = queue - policy = self._get_policy(queue, X) X_table = to_table(X, queue=queue) diff --git a/onedal/decomposition/incremental_pca.py b/onedal/decomposition/incremental_pca.py index 58c852ed81..0513194a70 100644 --- a/onedal/decomposition/incremental_pca.py +++ b/onedal/decomposition/incremental_pca.py @@ -16,10 +16,10 @@ import numpy as np -from daal4py.sklearn._utils import get_dtype - +from .._config import _get_config from ..datatypes import from_table, to_table from ..utils import _check_array +from ..utils._array_api import _get_sycl_namespace from .pca import BasePCA @@ -113,7 +113,7 @@ def __getstate__(self): self.finalize_fit() data = self.__dict__.copy() data.pop("_queue", None) - + data.pop("_input_xp", None) # module cannot be pickled return data def partial_fit(self, X, queue): @@ -133,9 +133,21 @@ def partial_fit(self, X, queue): self : object Returns the instance itself. 
""" - X = _check_array(X) - n_samples, n_features = X.shape + use_raw_input = _get_config()["use_raw_input"] + sua_iface, xp, _ = _get_sycl_namespace(X) + # Saving input array namespace and sua_iface, that will be used in + # finalize_fit. + self._input_sua_iface = sua_iface + self._input_xp = xp + + # All data should use the same sycl queue + if use_raw_input and sua_iface: + queue = X.sycl_queue + if not use_raw_input: + X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True) + + n_samples, n_features = X.shape first_pass = not hasattr(self, "components_") if first_pass: self.components_ = None @@ -210,5 +222,5 @@ def finalize_fit(self, queue=None): self.noise_variance_ = self._compute_noise_variance( self.n_components_, min(self.n_samples_seen_, self.n_features_in_) ) - self._need_to_finalize = False + self._need_to_finalize = False return self diff --git a/onedal/decomposition/pca.py b/onedal/decomposition/pca.py index fe6f585ba5..a87b04c5bc 100644 --- a/onedal/decomposition/pca.py +++ b/onedal/decomposition/pca.py @@ -21,8 +21,10 @@ from sklearn.decomposition._pca import _infer_dimension from sklearn.utils.extmath import stable_cumsum +from .._config import _get_config from ..common._base import BaseEstimator from ..datatypes import from_table, to_table +from ..utils._array_api import _get_sycl_namespace class BasePCA(BaseEstimator, metaclass=ABCMeta): @@ -142,6 +144,11 @@ def predict(self, X, queue=None): class PCA(BasePCA): def fit(self, X, y=None, queue=None): + use_raw_input = _get_config()["use_raw_input"] + sua_iface, xp, _ = _get_sycl_namespace(X) + if use_raw_input and sua_iface: + queue = X.sycl_queue + n_samples, n_features = X.shape n_sf_min = min(n_samples, n_features) self._validate_n_components(self.n_components, n_samples, n_features) @@ -158,8 +165,13 @@ def fit(self, X, y=None, queue=None): "decomposition", "dim_reduction", "train", policy, params, X ) - self.mean_ = from_table(result.means).ravel() - self.variances_ = from_table(result.variances) + self.mean_ = xp.reshape( + from_table(result.means, sua_iface=sua_iface, sycl_queue=queue, xp=xp), -1 + ) + self.variances_ = from_table( + result.variances, sua_iface=sua_iface, sycl_queue=queue, xp=xp + ) + # TODO: why are there errors when using sua_iface and sycl_queue on following from_table calls? 
diff --git a/onedal/decomposition/pca.py b/onedal/decomposition/pca.py
index fe6f585ba5..a87b04c5bc 100644
--- a/onedal/decomposition/pca.py
+++ b/onedal/decomposition/pca.py
@@ -21,8 +21,10 @@
 from sklearn.decomposition._pca import _infer_dimension
 from sklearn.utils.extmath import stable_cumsum
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..datatypes import from_table, to_table
+from ..utils._array_api import _get_sycl_namespace
 
 
 class BasePCA(BaseEstimator, metaclass=ABCMeta):
@@ -142,6 +144,11 @@ def predict(self, X, queue=None):
 
 class PCA(BasePCA):
     def fit(self, X, y=None, queue=None):
+        use_raw_input = _get_config()["use_raw_input"]
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        if use_raw_input and sua_iface:
+            queue = X.sycl_queue
+
         n_samples, n_features = X.shape
         n_sf_min = min(n_samples, n_features)
         self._validate_n_components(self.n_components, n_samples, n_features)
@@ -158,8 +165,13 @@ def fit(self, X, y=None, queue=None):
             "decomposition", "dim_reduction", "train", policy, params, X
         )
 
-        self.mean_ = from_table(result.means).ravel()
-        self.variances_ = from_table(result.variances)
+        self.mean_ = xp.reshape(
+            from_table(result.means, sua_iface=sua_iface, sycl_queue=queue, xp=xp), -1
+        )
+        self.variances_ = from_table(
+            result.variances, sua_iface=sua_iface, sycl_queue=queue, xp=xp
+        )
+        # TODO: why are there errors when using sua_iface and sycl_queue on the following from_table calls?
         self.components_ = from_table(result.eigenvectors)
         self.singular_values_ = from_table(result.singular_values).ravel()
         self.explained_variance_ = np.maximum(from_table(result.eigenvalues).ravel(), 0)
diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py
index 0a006bf9b1..ddb214b010 100644
--- a/onedal/ensemble/forest.py
+++ b/onedal/ensemble/forest.py
@@ -26,6 +26,7 @@
 from daal4py.sklearn._utils import daal_check_version
 from sklearnex import get_hyperparameters
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..common._estimator_checks import _check_is_fitted
 from ..common._mixin import ClassifierMixin, RegressorMixin
@@ -37,6 +38,7 @@
     _column_or_1d,
     _validate_targets,
 )
+from ..utils._array_api import _get_sycl_namespace
 
 
 class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta):
@@ -289,22 +291,36 @@ def _get_sample_weight(self, sample_weight, X):
         return sample_weight
 
     def _fit(self, X, y, sample_weight, module, queue):
-        X, y = _check_X_y(
-            X,
-            y,
-            dtype=[np.float64, np.float32],
-            force_all_finite=True,
-            accept_sparse="csr",
-        )
-        y = self._validate_targets(y, X.dtype)
+        use_raw_input = _get_config()["use_raw_input"]
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+
+        # All data should use the same sycl queue
+        if use_raw_input and sua_iface is not None:
+            queue = X.sycl_queue
+
+        if not use_raw_input:
+            X, y = _check_X_y(
+                X,
+                y,
+                dtype=[np.float64, np.float32],
+                force_all_finite=True,
+                accept_sparse="csr",
+            )
+            y = self._validate_targets(y, X.dtype)
+        else:
+            # TODO:
+            # check it first.
+            self.classes_ = xp.unique_all(y).values
 
         self.n_features_in_ = X.shape[1]
 
         if sample_weight is not None and len(sample_weight) > 0:
-            sample_weight = self._get_sample_weight(sample_weight, X)
+            if not use_raw_input:
+                sample_weight = self._get_sample_weight(sample_weight, X)
             data = (X, y, sample_weight)
         else:
             data = (X, y)
+
         policy = self._get_policy(queue, *data)
         data = to_table(*data, queue=queue)
         params = self._get_onedal_params(data[0])
@@ -318,7 +334,7 @@ def _fit(self, X, y, sample_weight, module, queue):
             self.oob_decision_function_ = from_table(
                 train_result.oob_err_decision_function
             )
-            if np.any(self.oob_decision_function_ == 0):
+            if xp.any(self.oob_decision_function_ == 0):
                 warnings.warn(
                     "Some inputs do not have OOB scores. This probably means "
                     "too few trees were used to compute any reliable OOB "
                     "estimates.",
@@ -347,10 +363,21 @@ def _create_model(self, module):
 
     def _predict(self, X, module, queue, hparams=None):
         _check_is_fitted(self)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
-        )
-        _check_n_features(self, X, False)
+
+        use_raw_input = _get_config()["use_raw_input"]
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+
+        # All data should use the same sycl queue
+        if use_raw_input and sua_iface is not None:
+            queue = X.sycl_queue
+        if not use_raw_input:
+            X = _check_array(
+                X,
+                dtype=[np.float64, np.float32],
+                force_all_finite=True,
+                accept_sparse=False,
+            )
+            _check_n_features(self, X, False)
         policy = self._get_policy(queue, X)
 
         model = self._onedal_model
@@ -361,15 +388,26 @@ def _predict(self, X, module, queue, hparams=None):
         else:
             result = module.infer(policy, params, model, X)
 
-        y = from_table(result.responses)
+        y = from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp)
         return y
 
     def _predict_proba(self, X, module, queue, hparams=None):
         _check_is_fitted(self)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
-        )
-        _check_n_features(self, X, False)
+        use_raw_input = _get_config()["use_raw_input"]
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+
+        # All data should use the same sycl queue
+        if use_raw_input and sua_iface is not None:
+            queue = X.sycl_queue
+
+        if not use_raw_input:
+            X = _check_array(
+                X,
+                dtype=[np.float64, np.float32],
+                force_all_finite=True,
+                accept_sparse=False,
+            )
+            _check_n_features(self, X, False)
         policy = self._get_policy(queue, X)
         X = to_table(X, queue=queue)
         params = self._get_onedal_params(X)
@@ -381,8 +419,7 @@ def _predict_proba(self, X, module, queue, hparams=None):
         else:
             result = module.infer(policy, params, model, X)
 
-        y = from_table(result.probabilities)
-        return y
+        return from_table(result.probabilities)
 
 
 class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
@@ -465,15 +502,19 @@ def fit(self, X, y, sample_weight=None, queue=None):
         )
 
     def predict(self, X, queue=None):
+        _, xp, _ = _get_sycl_namespace(X)
         hparams = get_hyperparameters("decision_forest", "infer")
-        pred = super()._predict(
-            X,
-            self._get_backend("decision_forest", "classification", None),
-            queue,
-            hparams,
+        pred = xp.reshape(
+            super()._predict(
+                X,
+                self._get_backend("decision_forest", "classification", None),
+                queue,
+                hparams,
+            ),
+            -1,
         )
 
-        return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe"))
+        return xp.take(self.classes_, pred.astype(xp.int64, casting="unsafe"))
 
     def predict_proba(self, X, queue=None):
         hparams = get_hyperparameters("decision_forest", "infer")
@@ -545,10 +586,14 @@ def __init__(
         )
 
     def fit(self, X, y, sample_weight=None, queue=None):
-        if sample_weight is not None:
-            if hasattr(sample_weight, "__array__"):
-                sample_weight[sample_weight == 0.0] = 1.0
-            sample_weight = [sample_weight]
+        use_raw_input = _get_config()["use_raw_input"]
+        # TODO:
+        # check if required.
+        if not use_raw_input:
+            if sample_weight is not None:
+                if hasattr(sample_weight, "__array__"):
+                    sample_weight[sample_weight == 0.0] = 1.0
+                sample_weight = [sample_weight]
         return super()._fit(
             X,
             y,
@@ -558,10 +603,12 @@ def fit(self, X, y, sample_weight=None, queue=None):
         )
 
     def predict(self, X, queue=None):
-        return (
-            super()
-            ._predict(X, self._get_backend("decision_forest", "regression", None), queue)
-            .ravel()
+        _, xp, _ = _get_sycl_namespace(X)
+        return xp.reshape(
+            super()._predict(
+                X, self._get_backend("decision_forest", "regression", None), queue
+            ),
+            -1,
         )
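The classifier's `predict` above maps oneDAL's integer responses back to the original class labels through the input's array namespace. The same two steps with NumPy as the namespace (a sketch; `classes` and `pred` are made-up data):

```python
import numpy as np

xp = np  # _get_sycl_namespace falls back to NumPy for non-SYCL inputs
classes = np.array([3, 7, 9])            # classes_ learned during fit
pred = np.array([[0.0], [2.0], [1.0]])   # raw (n, 1) responses from oneDAL
pred = xp.reshape(pred, -1)              # flatten the response column
labels = xp.take(classes, pred.astype(xp.int64, casting="unsafe"))
print(labels)                            # [3 9 7]
```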
""" + if not hasattr(self, "_params"): + self._params = self._get_onedal_params(X.dtype) + module = self._get_backend("linear_model", "regression") + self._sua_iface, self._xp, _ = _get_sycl_namespace(X) + use_raw_input = _get_config().get("use_raw_input") is True + if use_raw_input and self._sua_iface is not None: + queue = X.sycl_queue + self._queue = queue policy = self._get_policy(queue, X) - X, y = _check_X_y( - X, y, dtype=[np.float64, np.float32], accept_2d_y=True, force_all_finite=False - ) - y = np.asarray(y, dtype=X.dtype) + if not use_raw_input: + X, y = _check_X_y( + X, + y, + dtype=[np.float64, np.float32], + accept_2d_y=True, + force_all_finite=False, + ) + y = np.asarray(y, dtype=X.dtype) self.n_features_in_ = _num_features(X, fallback_1d=True) X_table, y_table = to_table(X, y, queue=queue) - if not hasattr(self, "_dtype"): - self._dtype = X_table.dtype - self._params = self._get_onedal_params(self._dtype) - self._partial_result = module.partial_train( policy, self._params, self._partial_result, X_table, y_table ) @@ -247,10 +261,22 @@ def finalize_fit(self, queue=None): self._onedal_model = result.model - packed_coefficients = from_table(result.model.packed_coefficients) - self.coef_, self.intercept_ = ( - packed_coefficients[:, 1:].squeeze(), - packed_coefficients[:, 0].squeeze(), - ) + if _get_config().get("use_raw_input") is True: + packed_coefficients = from_table( + result.model.packed_coefficients, + sua_iface=self._sua_iface, + sycl_queue=self._queue, + xp=self._xp, + ) + self.coef_, self.intercept_ = ( + self._xp.squeeze(packed_coefficients[:, 1:]), + self._xp.squeeze(packed_coefficients[:, 0]), + ) + else: + packed_coefficients = from_table(result.model.packed_coefficients) + self.coef_, self.intercept_ = ( + packed_coefficients[:, 1:].squeeze(), + packed_coefficients[:, 0].squeeze(), + ) return self diff --git a/onedal/linear_model/linear_model.py b/onedal/linear_model/linear_model.py index 264a571de0..9b9c1b5a8b 100755 --- a/onedal/linear_model/linear_model.py +++ b/onedal/linear_model/linear_model.py @@ -21,11 +21,13 @@ from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d +from .._config import _get_config from ..common._base import BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common.hyperparameters import get_hyperparameters from ..datatypes import from_table, to_table from ..utils import _check_array, _check_n_features, _check_X_y, _num_features +from ..utils._array_api import _get_sycl_namespace class BaseLinearRegression(BaseEstimator, metaclass=ABCMeta): @@ -119,11 +121,17 @@ def predict(self, X, queue=None): _check_is_fitted(self) + sua_iface, xp, _ = _get_sycl_namespace(X) + use_raw_input = _get_config().get("use_raw_input") is True + policy = self._get_policy(queue, X) - X = _check_array( - X, dtype=[np.float64, np.float32], force_all_finite=False, ensure_2d=False - ) + if not use_raw_input: + X = _check_array( + X, dtype=[np.float64, np.float32], force_all_finite=False, ensure_2d=False + ) + X = make2d(X) + _check_n_features(self, X, False) if hasattr(self, "_onedal_model"): @@ -135,10 +143,10 @@ def predict(self, X, queue=None): params = self._get_onedal_params(X_table.dtype) result = module.infer(policy, params, model, X_table) - y = from_table(result.responses) + y = from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp) if y.shape[1] == 1 and self.coef_.ndim == 1: - return y.ravel() + return xp.reshape(y, (-1,)) else: return y @@ -192,26 +200,29 @@ def fit(self, X, y, 
diff --git a/onedal/linear_model/linear_model.py b/onedal/linear_model/linear_model.py
index 264a571de0..9b9c1b5a8b 100755
--- a/onedal/linear_model/linear_model.py
+++ b/onedal/linear_model/linear_model.py
@@ -21,11 +21,13 @@
 
 from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..common._estimator_checks import _check_is_fitted
 from ..common.hyperparameters import get_hyperparameters
 from ..datatypes import from_table, to_table
 from ..utils import _check_array, _check_n_features, _check_X_y, _num_features
+from ..utils._array_api import _get_sycl_namespace
 
 
 class BaseLinearRegression(BaseEstimator, metaclass=ABCMeta):
@@ -119,11 +121,17 @@ def predict(self, X, queue=None):
 
         _check_is_fitted(self)
 
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        use_raw_input = _get_config().get("use_raw_input") is True
+
         policy = self._get_policy(queue, X)
 
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], force_all_finite=False, ensure_2d=False
-        )
+        if not use_raw_input:
+            X = _check_array(
+                X, dtype=[np.float64, np.float32], force_all_finite=False, ensure_2d=False
+            )
+            X = make2d(X)
+
         _check_n_features(self, X, False)
 
         if hasattr(self, "_onedal_model"):
@@ -135,10 +143,10 @@ def predict(self, X, queue=None):
         params = self._get_onedal_params(X_table.dtype)
 
         result = module.infer(policy, params, model, X_table)
-        y = from_table(result.responses)
+        y = from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp)
 
         if y.shape[1] == 1 and self.coef_.ndim == 1:
-            return y.ravel()
+            return xp.reshape(y, (-1,))
         else:
             return y
 
@@ -192,26 +200,29 @@ def fit(self, X, y, queue=None):
         """
         module = self._get_backend("linear_model", "regression")
 
-        # TODO Fix _check_X_y to make sure this conversion is there
-        if not isinstance(X, np.ndarray):
-            X = np.asarray(X)
+        if _get_config()["use_raw_input"] is False:
+            if not isinstance(X, np.ndarray):
+                X = np.asarray(X)
 
-        dtype = get_dtype(X)
-        if dtype not in [np.float32, np.float64]:
-            dtype = np.float64
-            X = X.astype(dtype, copy=self.copy_X)
+            dtype = get_dtype(X)
+            if dtype not in [np.float32, np.float64]:
+                dtype = np.float64
+                X = X.astype(dtype, copy=self.copy_X)
 
-        y = np.asarray(y).astype(dtype=dtype)
+            y = np.asarray(y).astype(dtype=dtype)
 
-        X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True)
+            X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True)
 
         policy = self._get_policy(queue, X, y)
 
+        if _get_config()["use_raw_input"] is True:
+            # make sure we are using the queue from the on-device provided data
+            queue = getattr(policy, "_queue", queue)
+
+        X_table, y_table = to_table(X, y, queue=queue)
         self.n_features_in_ = _num_features(X, fallback_1d=True)
 
-        X_table, y_table = to_table(X, y, queue=queue)
         params = self._get_onedal_params(X_table.dtype)
-
         hparams = get_hyperparameters("linear_regression", "train")
         if hparams is not None and not hparams.is_default:
             result = module.train(policy, params, hparams.backend, X_table, y_table)
         else:
@@ -220,14 +231,16 @@ def fit(self, X, y, queue=None):
 
         self._onedal_model = result.model
 
-        packed_coefficients = from_table(result.model.packed_coefficients)
+        packed_coefficients = from_table(
+            result.model.packed_coefficients, sycl_queue=queue
+        )
         self.coef_, self.intercept_ = (
             packed_coefficients[:, 1:],
             packed_coefficients[:, 0],
         )
 
         if self.coef_.shape[0] == 1 and y.ndim == 1:
-            self.coef_ = self.coef_.ravel()
+            self.coef_ = np.reshape(self.coef_, (-1,))
             self.intercept_ = self.intercept_[0]
 
         return self
""" + sua_iface, xp, _ = _get_sycl_namespace(X) + module = self._get_backend("linear_model", "regression") + _, xp, _ = _get_sycl_namespace(X) if not isinstance(X, np.ndarray): X = np.asarray(X) @@ -298,9 +314,17 @@ def fit(self, X, y, queue=None): dtype = np.float64 X = X.astype(dtype, copy=self.copy_X) - y = np.asarray(y).astype(dtype=dtype) - - X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True) + use_raw_input = _get_config().get("use_raw_input") is True + if not use_raw_input: + X = _check_array( + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + ensure_2d=False, + copy=self.copy_X, + ) + X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True) + y = np.asarray(y).astype(dtype=get_dtype(X)) policy = self._get_policy(queue, X, y) @@ -312,14 +336,16 @@ def fit(self, X, y, queue=None): result = module.train(policy, params, X_table, y_table) self._onedal_model = result.model - packed_coefficients = from_table(result.model.packed_coefficients) + packed_coefficients = from_table( + result.model.packed_coefficients, sua_iface=sua_iface, sycl_queue=queue, xp=xp + ) self.coef_, self.intercept_ = ( packed_coefficients[:, 1:], packed_coefficients[:, 0], ) if self.coef_.shape[0] == 1 and y.ndim == 1: - self.coef_ = self.coef_.ravel() + self.coef_ = xp.reshape(self.coef_, (-1,)) self.intercept_ = self.intercept_[0] return self diff --git a/onedal/linear_model/logistic_regression.py b/onedal/linear_model/logistic_regression.py index 53e5f293ce..147a3686b7 100644 --- a/onedal/linear_model/logistic_regression.py +++ b/onedal/linear_model/logistic_regression.py @@ -21,6 +21,7 @@ from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d +from .._config import _get_config from ..common._base import BaseEstimator as onedal_BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin @@ -33,6 +34,8 @@ _num_features, _type_of_target, ) +from ..utils._array_api import _get_sycl_namespace +from ..utils._dpep_helpers import get_unique_values_with_dpep class BaseLogisticRegression(onedal_BaseEstimator, metaclass=ABCMeta): @@ -63,25 +66,34 @@ def _get_onedal_params(self, is_csr, dtype=np.float32): } def _fit(self, X, y, module, queue): + use_raw_input = _get_config().get("use_raw_input") is True + sua_iface = _get_sycl_namespace(X, y)[0] + if use_raw_input and sua_iface is not None: + queue = X.sycl_queue + sparsity_enabled = daal_check_version((2024, "P", 700)) - X, y = _check_X_y( - X, - y, - accept_sparse=sparsity_enabled, - force_all_finite=True, - accept_2d_y=False, - dtype=[np.float64, np.float32], - ) - is_csr = _is_csr(X) + if not use_raw_input: + X, y = _check_X_y( + X, + y, + accept_sparse=sparsity_enabled, + force_all_finite=True, + accept_2d_y=False, + dtype=[np.float64, np.float32], + ) + if _type_of_target(y) != "binary": + raise ValueError("Only binary classification is supported") + + self.classes_, y = np.unique(y, return_inverse=True) + y = y.astype(dtype=np.int32) + else: + self.classes_ = get_unique_values_with_dpep(y) + n_classes = len(self.classes_) + if n_classes != 2: + raise ValueError("Only binary classification is supported") self.n_features_in_ = _num_features(X, fallback_1d=True) - - if _type_of_target(y) != "binary": - raise ValueError("Only binary classification is supported") - - self.classes_, y = np.unique(y, return_inverse=True) - y = y.astype(dtype=np.int32) - + is_csr = _is_csr(X) policy = self._get_policy(queue, X, y) X_table, y_table = to_table(X, y, queue=queue) 
         params = self._get_onedal_params(is_csr, X_table.dtype)
@@ -151,22 +163,29 @@ def _create_model(self, module, policy):
 
         return m
 
-    def _infer(self, X, module, queue):
+    def _infer(self, X, module, queue, sua_iface):
         _check_is_fitted(self)
+
+        use_raw_input = _get_config().get("use_raw_input") is True
+        if use_raw_input and _get_sycl_namespace(X)[0] is not None:
+            queue = X.sycl_queue
+
         sparsity_enabled = daal_check_version((2024, "P", 700))
-        X = _check_array(
-            X,
-            dtype=[np.float64, np.float32],
-            accept_sparse=sparsity_enabled,
-            force_all_finite=True,
-            ensure_2d=False,
-            accept_large_sparse=sparsity_enabled,
-        )
-        is_csr = _is_csr(X)
+        if not use_raw_input:
+            X = _check_array(
+                X,
+                dtype=[np.float64, np.float32],
+                accept_sparse=sparsity_enabled,
+                force_all_finite=True,
+                ensure_2d=False,
+                accept_large_sparse=sparsity_enabled,
+            )
+            X = make2d(X)
+            _check_n_features(self, X, False)
+        is_csr = _is_csr(X)
 
-        X = make2d(X)
         policy = self._get_policy(queue, X)
 
         if hasattr(self, "_onedal_model"):
@@ -181,21 +200,32 @@ def _infer(self, X, module, queue):
         return result
 
     def _predict(self, X, module, queue):
-        result = self._infer(X, module, queue)
-        y = from_table(result.responses)
-        y = np.take(self.classes_, y.ravel(), axis=0)
+        use_raw_input = _get_config().get("use_raw_input") is True
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        if use_raw_input and sua_iface is not None:
+            queue = X.sycl_queue
+
+        result = self._infer(X, module, queue, sua_iface)
+        y = from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp)
+        y = xp.take(xp.asarray(self.classes_), xp.reshape(y, (-1,)), axis=0)
         return y
 
     def _predict_proba(self, X, module, queue):
-        result = self._infer(X, module, queue)
+        use_raw_input = _get_config().get("use_raw_input") is True
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        if use_raw_input and sua_iface is not None:
+            queue = X.sycl_queue
+
+        result = self._infer(X, module, queue, sua_iface)
 
-        y = from_table(result.probabilities)
+        y = from_table(result.probabilities, sua_iface=sua_iface, sycl_queue=queue, xp=xp)
         y = y.reshape(-1, 1)
-        return np.hstack([1 - y, y])
+        return xp.hstack([1 - y, y])
 
     def _predict_log_proba(self, X, module, queue):
+        _, xp, _ = _get_sycl_namespace(X)
         y_proba = self._predict_proba(X, module, queue)
-        return np.log(y_proba)
+        return xp.log(y_proba)
 
 
 class LogisticRegression(ClassifierMixin, BaseLogisticRegression):
diff --git a/onedal/neighbors/neighbors.py b/onedal/neighbors/neighbors.py
index b97706e49a..074a9c5ed4 100755
--- a/onedal/neighbors/neighbors.py
+++ b/onedal/neighbors/neighbors.py
@@ -28,6 +28,7 @@
     kdtree_knn_classification_training,
 )
 
+from .._config import _get_config
 from ..common._base import BaseEstimator
 from ..common._estimator_checks import _check_is_fitted, _is_classifier, _is_regressor
 from ..common._mixin import ClassifierMixin, RegressorMixin
@@ -462,7 +463,8 @@ def fit(self, X, y, queue=None):
         return super()._fit(X, y, queue=queue)
 
     def predict(self, X, queue=None):
-        X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
+        if not _get_config()["use_raw_input"]:
+            X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
         onedal_model = getattr(self, "_onedal_model", None)
         n_features = getattr(self, "n_features_in_", None)
         n_samples_fit_ = getattr(self, "n_samples_fit_", None)
diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py
index 47da103da9..6daa8123a9 100644
--- a/onedal/utils/_array_api.py
+++ b/onedal/utils/_array_api.py
@@ -18,6 +18,8 @@
 
 from collections.abc import Iterable
 
+import numpy as np
+
 from ._dpep_helpers import dpctl_available, dpnp_available
 
 if dpctl_available:
@@ -78,4 +80,4 @@ def _get_sycl_namespace(*arrays):
         else:
             raise ValueError(f"SYCL type not recognized: {sua_iface}")
 
-    return sua_iface, None, False
+    return sua_iface, np, False
diff --git a/onedal/utils/_dpep_helpers.py b/onedal/utils/_dpep_helpers.py
index 3494f71d6d..f9c4cbe56e 100644
--- a/onedal/utils/_dpep_helpers.py
+++ b/onedal/utils/_dpep_helpers.py
@@ -54,3 +54,18 @@ def is_dpnp_available(version=None):
 
 dpctl_available = is_dpctl_available()
 dpnp_available = is_dpnp_available()
+
+
+if dpnp_available:
+    import dpnp
+if dpctl_available:
+    import dpctl.tensor as dpt
+
+
+def get_unique_values_with_dpep(X):
+    if dpnp_available:
+        return dpnp.unique(X)
+    elif dpctl_available:
+        return dpt.unique_values(X)
+    else:
+        raise RuntimeError("No DPEP package available to provide `unique` function.")
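A quick usage sketch for the new helper, assuming dpnp is installed (when dpnp is present, the `dpnp.unique` branch is taken; with only dpctl, `dpt.unique_values` is used instead):

```python
import dpnp

from onedal.utils._dpep_helpers import get_unique_values_with_dpep

y = dpnp.asarray([0, 1, 1, 0, 1])
print(get_unique_values_with_dpep(y))  # [0 1] -- class labels stay on device
```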
diff --git a/sklearnex/_config.py b/sklearnex/_config.py
index 6589f77d85..683008a159 100644
--- a/sklearnex/_config.py
+++ b/sklearnex/_config.py
@@ -44,6 +44,7 @@ def set_config(
     target_offload=None,
     allow_fallback_to_host=None,
     allow_sklearn_after_onedal=None,
+    use_raw_input=None,
     **sklearn_configs,
 ):
     """Set global configuration
@@ -82,6 +83,8 @@ def set_config(
         local_config["allow_fallback_to_host"] = allow_fallback_to_host
     if allow_sklearn_after_onedal is not None:
         local_config["allow_sklearn_after_onedal"] = allow_sklearn_after_onedal
+    if use_raw_input is not None:
+        local_config["use_raw_input"] = use_raw_input
 
 
 @contextmanager
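With the plumbing above in place, the option toggles like any other sklearnex setting, globally or per block. A minimal sketch, assuming dpctl is installed and this patch is applied:

```python
import dpctl.tensor as dpt

from sklearnex import config_context, set_config
from sklearnex.cluster import DBSCAN

X = dpt.asarray([[1.0, 2.0], [2.0, 2.0], [8.0, 8.0]], dtype="float32")

# Per block: inputs go to oneDAL as-is; the queue is taken from X itself.
with config_context(use_raw_input=True):
    DBSCAN(eps=3, min_samples=2).fit(X)

# Or globally (default is False):
set_config(use_raw_input=True)
```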
diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py
index 7e299f07e0..5d6d5ba374 100644
--- a/sklearnex/_device_offload.py
+++ b/sklearnex/_device_offload.py
@@ -58,44 +58,49 @@ def _get_backend(obj, queue, method_name, *data):
 
 
 def dispatch(obj, method_name, branches, *args, **kwargs):
-    q = _get_global_queue()
-    has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
-    has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
-    hostkwargs = dict(zip(kwargs.keys(), hostvalues))
-
-    backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
-    has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
-    if backend == "onedal":
-        # Host args only used before onedal backend call.
-        # Device will be offloaded when onedal backend will be called.
-        patching_status.write_log(queue=q, transferred_to_host=False)
-        return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
-    if backend == "sklearn":
-        if (
-            "array_api_dispatch" in get_config()
-            and get_config()["array_api_dispatch"]
-            and "array_api_support" in obj._get_tags()
-            and obj._get_tags()["array_api_support"]
-            and not has_usm_data
-        ):
-            # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
-            # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant,
-            # except for the linalg module. There is no guarantee that stock scikit-learn will
-            # work with such input data. The condition will be updated after DPNP.ndarray and
-            # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance
-            # of the fallback cases.
-            # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
-            # then raw inputs are used for the fallback.
-            patching_status.write_log(transferred_to_host=False)
-            return branches[backend](obj, *args, **kwargs)
-        else:
-            patching_status.write_log()
-            return branches[backend](obj, *hostargs, **hostkwargs)
-    raise RuntimeError(
-        f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
-    )
+    if get_config()["use_raw_input"] is False:
+        q = _get_global_queue()
+        has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
+        has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
+        hostkwargs = dict(zip(kwargs.keys(), hostvalues))
+
+        backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
+        has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
+        if backend == "onedal":
+            # Host args only used before onedal backend call.
+            # Device will be offloaded when onedal backend will be called.
+            patching_status.write_log(queue=q, transferred_to_host=False)
+            return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
+        if backend == "sklearn":
+            if (
+                "array_api_dispatch" in get_config()
+                and get_config()["array_api_dispatch"]
+                and "array_api_support" in obj._get_tags()
+                and obj._get_tags()["array_api_support"]
+                and not has_usm_data
+            ):
+                # USM ndarrays are also excluded for the fallback Array API. Currently,
+                # DPNP.ndarray is not compliant with the Array API standard, and the DPCTL
+                # usm_ndarray Array API is compliant except for the linalg module. There is
+                # no guarantee that stock scikit-learn will work with such input data. The
+                # condition will be updated once DPNP.ndarray and DPCTL usm_ndarray are
+                # enabled for conformance testing and these arrays support the fallback
+                # cases.
+                # If `array_api_dispatch` is enabled and the array API is supported by
+                # stock scikit-learn, then raw inputs are used for the fallback.
+                patching_status.write_log(transferred_to_host=False)
+                return branches[backend](obj, *args, **kwargs)
+            else:
+                patching_status.write_log()
+                return branches[backend](obj, *hostargs, **hostkwargs)
+        raise RuntimeError(
+            f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
+        )
+    else:
+        return branches["onedal"](obj, *args, **kwargs)
 
 
+# TODO:
+# wrap output.
 def wrap_output_data(func):
     """
     Converts and moves the output arrays of the decorated function
diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py
index da82e3bd82..8caeb0f1ec 100644
--- a/sklearnex/basic_statistics/basic_statistics.py
+++ b/sklearnex/basic_statistics/basic_statistics.py
@@ -25,6 +25,7 @@
 from daal4py.sklearn._utils import sklearn_check_version
 from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
 
+from .._config import get_config
 from .._device_offload import dispatch
 from .._utils import IntelEstimator, PatchingConditionsChain
 
@@ -179,13 +180,16 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()
 
-        if sklearn_check_version("1.0"):
-            X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
-        else:
-            X = check_array(X, dtype=[np.float64, np.float32])
+        if get_config()["use_raw_input"] is False:
+            if sklearn_check_version("1.0"):
+                X = validate_data(
+                    self, X, dtype=[np.float64, np.float32], ensure_2d=False
+                )
+            else:
+                X = check_array(X, dtype=[np.float64, np.float32])
 
-        if sample_weight is not None:
-            sample_weight = _check_sample_weight(sample_weight, X)
+            if sample_weight is not None:
+                sample_weight = _check_sample_weight(sample_weight, X)
 
         onedal_params = {
             "result_options": self.result_options,
diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py
index d1ddcd55dc..cb77f15d4d 100644
--- a/sklearnex/basic_statistics/incremental_basic_statistics.py
+++ b/sklearnex/basic_statistics/incremental_basic_statistics.py
@@ -25,6 +25,7 @@
     IncrementalBasicStatistics as onedal_IncrementalBasicStatistics,
 )
 
+from .._config import get_config
 from .._device_offload import dispatch
 from .._utils import IntelEstimator, PatchingConditionsChain
 
@@ -194,6 +195,9 @@ def _onedal_finalize_fit(self, queue=None):
     def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=True):
         first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
 
+        use_raw_input = get_config()["use_raw_input"]
+        # never check input when using raw input
+        check_input &= use_raw_input is False
         if check_input:
             if sklearn_check_version("1.0"):
                 X = validate_data(
@@ -208,7 +212,7 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=True):
                     dtype=[np.float64, np.float32],
                 )
 
-        if sample_weight is not None:
+        if not use_raw_input and sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
 
         if first_pass:
diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py
index a5515f240d..e661a9fbc7 100644
--- a/sklearnex/basic_statistics/tests/test_basic_statistics.py
+++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py
@@ -24,6 +24,7 @@
     _convert_to_dataframe,
     get_dataframes_and_queues,
 )
+from sklearnex._config import config_context
 from sklearnex.basic_statistics import BasicStatistics
 
 
@@ -96,8 +97,17 @@ def test_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
 @pytest.mark.parametrize("column_count", [10, 100])
 @pytest.mark.parametrize("weighted", [True, False])
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("use_raw_input", [True, False])
 def test_single_option_on_random_data(
-    dataframe, queue, result_option, row_count, column_count, weighted, dtype
+    skip_unsupported_raw_input,
+    dataframe,
+    queue,
+    result_option,
+    row_count,
+    column_count,
+    weighted,
+    dtype,
+    use_raw_input,
 ):
     function, tols = options_and_tests[result_option]
     fp32tol, fp64tol = tols
@@ -112,10 +122,11 @@ def test_single_option_on_random_data(
     weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
 
     basicstat = BasicStatistics(result_options=result_option)
-    if weighted:
-        result = basicstat.fit(X_df, sample_weight=weights_df)
-    else:
-        result = basicstat.fit(X_df)
+    with config_context(use_raw_input=use_raw_input):
+        if weighted:
+            result = basicstat.fit(X_df, sample_weight=weights_df)
+        else:
+            result = basicstat.fit(X_df)
 
     res = getattr(result, result_option)
     if weighted:
diff --git a/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py b/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py
index c648537993..ffd1beb29a 100644
--- a/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py
+++ b/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py
@@ -24,13 +24,17 @@
     _convert_to_dataframe,
     get_dataframes_and_queues,
 )
+from sklearnex._config import config_context
 from sklearnex.basic_statistics import IncrementalBasicStatistics
 
 
 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
 @pytest.mark.parametrize("weighted", [True, False])
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
-def test_partial_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
+@pytest.mark.parametrize("use_raw_input", [True, False])
+def test_partial_fit_multiple_options_on_gold_data(
+    skip_unsupported_raw_input, dataframe, queue, weighted, dtype, use_raw_input
+):
     X = np.array([[0, 0], [1, 1]])
     X = X.astype(dtype=dtype)
     X_split = np.array_split(X, 2)
@@ -40,17 +44,18 @@ def test_partial_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
     weights_split = np.array_split(weights, 2)
 
     incbs = IncrementalBasicStatistics()
-    for i in range(2):
-        X_split_df = _convert_to_dataframe(
-            X_split[i], sycl_queue=queue, target_df=dataframe
-        )
-        if weighted:
-            weights_split_df = _convert_to_dataframe(
-                weights_split[i], sycl_queue=queue, target_df=dataframe
+    with config_context(use_raw_input=use_raw_input):
+        for i in range(2):
+            X_split_df = _convert_to_dataframe(
+                X_split[i], sycl_queue=queue, target_df=dataframe
             )
-            result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
-        else:
-            result = incbs.partial_fit(X_split_df)
+            if weighted:
+                weights_split_df = _convert_to_dataframe(
+                    weights_split[i], sycl_queue=queue, target_df=dataframe
+                )
+                result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
+            else:
+                result = incbs.partial_fit(X_split_df)
 
     if weighted:
         expected_weighted_mean = np.array([0.25, 0.25])
diff --git a/sklearnex/cluster/dbscan.py b/sklearnex/cluster/dbscan.py
index ef5f6b78d9..ee8a8a2963 100755
--- a/sklearnex/cluster/dbscan.py
+++ b/sklearnex/cluster/dbscan.py
@@ -25,6 +25,7 @@
 from daal4py.sklearn._utils import sklearn_check_version
 from onedal.cluster import DBSCAN as onedal_DBSCAN
 
+from .._config import get_config
 from .._device_offload import dispatch
 from .._utils import PatchingConditionsChain
 
@@ -89,8 +90,9 @@ def __init__(
         self.n_jobs = n_jobs
 
     def _onedal_fit(self, X, y, sample_weight=None, queue=None):
-        if sklearn_check_version("1.0"):
-            X = validate_data(self, X, force_all_finite=False)
+        if get_config()["use_raw_input"] is False:
+            if sklearn_check_version("1.0"):
+                X = validate_data(self, X, force_all_finite=False)
 
         onedal_params = {
             "eps": self.eps,
@@ -178,7 +180,8 @@ def fit(self, X, y=None, sample_weight=None):
         if self.eps <= 0.0:
             raise ValueError(f"eps == {self.eps}, must be > 0.0.")
 
-        if sample_weight is not None:
+        use_raw_input = get_config().get("use_raw_input", False) is True
+        if not use_raw_input and sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
         dispatch(
             self,
diff --git a/sklearnex/cluster/k_means.py b/sklearnex/cluster/k_means.py
index 4ba75ca5b8..6dddfd207f 100644
--- a/sklearnex/cluster/k_means.py
+++ b/sklearnex/cluster/k_means.py
@@ -17,6 +17,7 @@
 import logging
 
 from daal4py.sklearn._utils import daal_check_version
+from sklearnex._config import get_config
 
 if daal_check_version((2023, "P", 200)):
@@ -156,20 +157,21 @@ def fit(self, X, y=None, sample_weight=None):
         return self
 
     def _onedal_fit(self, X, _, sample_weight, queue=None):
-        X = validate_data(
-            self,
-            X,
-            accept_sparse="csr",
-            dtype=[np.float64, np.float32],
-            order="C",
-            copy=self.copy_x,
-            accept_large_sparse=False,
-        )
+        if get_config()["use_raw_input"] is False:
+            X = validate_data(
+                self,
+                X,
+                accept_sparse="csr",
+                dtype=[np.float64, np.float32],
+                order="C",
+                copy=self.copy_x,
+                accept_large_sparse=False,
+            )
 
-        if sklearn_check_version("1.2"):
-            self._check_params_vs_input(X)
-        else:
-            self._check_params(X)
+            if sklearn_check_version("1.2"):
+                self._check_params_vs_input(X)
+            else:
+                self._check_params(X)
 
         self._n_features_out = self.n_clusters
 
@@ -295,13 +297,14 @@ def predict(
     )
 
     def _onedal_predict(self, X, sample_weight=None, queue=None):
-        X = validate_data(
-            self,
-            X,
-            accept_sparse="csr",
-            reset=False,
-            dtype=[np.float64, np.float32],
-        )
+        if get_config()["use_raw_input"] is False:
+            X = validate_data(
+                self,
+                X,
+                accept_sparse="csr",
+                reset=False,
+                dtype=[np.float64, np.float32],
+            )
 
         if not hasattr(self, "_onedal_estimator"):
             self._initialize_onedal_estimator()
diff --git a/sklearnex/cluster/tests/test_dbscan.py b/sklearnex/cluster/tests/test_dbscan.py
index a83b5b7cec..4d7bd13171 100755
--- a/sklearnex/cluster/tests/test_dbscan.py
+++ b/sklearnex/cluster/tests/test_dbscan.py
@@ -22,15 +22,20 @@
     _convert_to_dataframe,
     get_dataframes_and_queues,
 )
+from sklearnex._config import config_context
 
 
 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
-def test_sklearnex_import_dbscan(dataframe, queue):
+@pytest.mark.parametrize("use_raw_input", [True, False])
+def test_sklearnex_import_dbscan(
+    skip_unsupported_raw_input, dataframe, queue, use_raw_input
+):
     from sklearnex.cluster import DBSCAN
 
-    X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
+    X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]], dtype=np.float32)
     X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
-    dbscan = DBSCAN(eps=3, min_samples=2).fit(X)
+    with config_context(use_raw_input=use_raw_input):
+        dbscan = DBSCAN(eps=3, min_samples=2).fit(X)
 
     assert "sklearnex" in dbscan.__module__
     result = dbscan.labels_
diff --git a/sklearnex/cluster/tests/test_kmeans.py b/sklearnex/cluster/tests/test_kmeans.py
index ca18a07114..389a84043b 100755
--- a/sklearnex/cluster/tests/test_kmeans.py
+++ b/sklearnex/cluster/tests/test_kmeans.py
@@ -92,15 +92,21 @@ def test_sklearnex_import_for_sparse_data(queue, algorithm, init):
 @pytest.mark.parametrize(
     "algorithm", ["lloyd" if sklearn_check_version("1.1") else "full", "elkan"]
 )
-def test_results_on_dense_gold_data(dataframe, queue, algorithm):
+@pytest.mark.parametrize("use_raw_input", [True, False]) +def test_results_on_dense_gold_data( + skip_unsupported_raw_input, dataframe, queue, algorithm, use_raw_input +): from sklearnex.cluster import KMeans - X_train = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]]) - X_test = np.array([[0, 0], [12, 3]]) + X_train = np.array( + [[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]], dtype=np.float32 + ) + X_test = np.array([[0, 0], [12, 3]], dtype=np.float32) X_train_df = _convert_to_dataframe(X_train, sycl_queue=queue, target_df=dataframe) X_test_df = _convert_to_dataframe(X_test, sycl_queue=queue, target_df=dataframe) - kmeans = KMeans(n_clusters=2, random_state=0, algorithm=algorithm).fit(X_train_df) + with config_context(use_raw_input=use_raw_input): + kmeans = KMeans(n_clusters=2, random_state=0, algorithm=algorithm).fit(X_train_df) if queue and queue.sycl_device.is_gpu: # KMeans Init Dense GPU implementation is different from CPU @@ -112,7 +118,10 @@ def test_results_on_dense_gold_data(dataframe, queue, algorithm): expected_cluster_centers = np.array([[10.0, 2.0], [1.0, 2.0]], dtype=np.float32) expected_inertia = 16.0 - assert_allclose(expected_cluster_labels, _as_numpy(kmeans.predict(X_test_df))) + with config_context(use_raw_input=use_raw_input): + result = kmeans.predict(X_test_df) + + assert_allclose(expected_cluster_labels, _as_numpy(result)) assert_allclose(expected_cluster_centers, _as_numpy(kmeans.cluster_centers_)) assert expected_inertia == kmeans.inertia_ diff --git a/sklearnex/conftest.py b/sklearnex/conftest.py index 4ecad5383b..bf35b6ad31 100644 --- a/sklearnex/conftest.py +++ b/sklearnex/conftest.py @@ -80,3 +80,23 @@ def with_array_api(): def without_allow_sklearn_after_onedal(): with config_context(allow_sklearn_after_onedal=False): yield + + +@pytest.fixture +def skip_unsupported_raw_input(request): + # lookup if use_raw_input and dataframe are used in the test + use_raw_input = ( + request.getfixturevalue("use_raw_input") + if "use_raw_input" in request.fixturenames + else False + ) + dataframe = ( + request.getfixturevalue("dataframe") + if "dataframe" in request.fixturenames + else None + ) + + # skip tests of unsupported dataframes when using use_raw_input=True + if use_raw_input is True and dataframe not in ["numpy", "dpnp", "dpctl"]: + pytest.skip(f"use_raw_input is not supported for {dataframe}") + yield diff --git a/sklearnex/covariance/incremental_covariance.py b/sklearnex/covariance/incremental_covariance.py index 89ed92b601..61e2cb7f51 100644 --- a/sklearnex/covariance/incremental_covariance.py +++ b/sklearnex/covariance/incremental_covariance.py @@ -184,7 +184,6 @@ def location_(self): ) def _onedal_partial_fit(self, X, queue=None, check_input=True): - first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0 # finite check occurs on onedal side diff --git a/sklearnex/covariance/tests/test_incremental_covariance.py b/sklearnex/covariance/tests/test_incremental_covariance.py index e42373cf84..f84fc46171 100644 --- a/sklearnex/covariance/tests/test_incremental_covariance.py +++ b/sklearnex/covariance/tests/test_incremental_covariance.py @@ -43,6 +43,7 @@ _convert_to_dataframe, get_dataframes_and_queues, ) +from sklearnex import config_context @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) @@ -126,8 +127,16 @@ def test_sklearnex_fit_on_gold_data(dataframe, queue, batch_size, dtype): @pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) 
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("use_raw_input", [True, False]) def test_sklearnex_partial_fit_on_random_data( - dataframe, queue, num_batches, row_count, column_count, dtype + skip_unsupported_raw_input, + dataframe, + queue, + num_batches, + row_count, + column_count, + dtype, + use_raw_input, ): from sklearnex.covariance import IncrementalEmpiricalCovariance @@ -138,17 +147,18 @@ def test_sklearnex_partial_fit_on_random_data( X_split = np.array_split(X, num_batches) inccov = IncrementalEmpiricalCovariance() - for i in range(num_batches): - X_split_df = _convert_to_dataframe( - X_split[i], sycl_queue=queue, target_df=dataframe - ) - result = inccov.partial_fit(X_split_df) + with config_context(use_raw_input=use_raw_input): + for i in range(num_batches): + X_split_df = _convert_to_dataframe( + X_split[i], sycl_queue=queue, target_df=dataframe + ) + result = inccov.partial_fit(X_split_df, check_input=not use_raw_input) - expected_covariance = np.cov(X.T, bias=1) - expected_means = np.mean(X, axis=0) + expected_covariance = np.cov(X.T, bias=1) + expected_means = np.mean(X, axis=0) - assert_allclose(expected_covariance, result.covariance_, atol=1e-6) - assert_allclose(expected_means, result.location_, atol=1e-6) + assert_allclose(expected_covariance, result.covariance_, atol=1e-6) + assert_allclose(expected_means, result.location_, atol=1e-6) @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) diff --git a/sklearnex/decomposition/pca.py b/sklearnex/decomposition/pca.py index 143587aa16..17f0f5fc23 100755 --- a/sklearnex/decomposition/pca.py +++ b/sklearnex/decomposition/pca.py @@ -17,6 +17,7 @@ import logging from daal4py.sklearn._utils import daal_check_version +from sklearnex._config import get_config if daal_check_version((2024, "P", 100)): import numbers @@ -138,13 +139,23 @@ def _fit(self, X): ) def _onedal_fit(self, X, queue=None): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - ensure_2d=True, - copy=self.copy, - ) + if get_config()["use_raw_input"] is True: + # With `use_raw_input=True`` we never check for oneDAL compatibility and instead + # always dispatch to oneDAL. + # For this algorithm this means that `_is_solver_compatible_with_onedal()`` + # never gets called, and therefore `self._fit_svd_solver` is not set. + # We therefore assert the solver compatibility here explictly to set all + # variables correctly. 
+ assert self._is_solver_compatible_with_onedal(X.shape) + else: + # Compatibility is already asserted, continue with checking the provided data + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + ensure_2d=True, + copy=self.copy, + ) onedal_params = { "n_components": self.n_components, @@ -182,18 +193,19 @@ def transform(self, X): ) def _onedal_transform(self, X, queue=None): - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - ) + if get_config()["use_raw_input"] is False: + if sklearn_check_version("1.0"): + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + reset=False, + ) + else: + X = check_array( + X, + dtype=[np.float64, np.float32], + ) self._validate_n_features_in_after_fitting(X) return self._onedal_estimator.predict(X, queue=queue) diff --git a/sklearnex/decomposition/tests/test_pca.py b/sklearnex/decomposition/tests/test_pca.py index 5f8270d80c..2aa2ce04f3 100755 --- a/sklearnex/decomposition/tests/test_pca.py +++ b/sklearnex/decomposition/tests/test_pca.py @@ -24,13 +24,15 @@ _convert_to_dataframe, get_dataframes_and_queues, ) +from sklearnex._config import config_context @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) -def test_sklearnex_import(dataframe, queue): +@pytest.mark.parametrize("use_raw_input", [True, False]) +def test_sklearnex_import(skip_unsupported_raw_input, dataframe, queue, use_raw_input): from sklearnex.decomposition import PCA - X = [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]] + X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32) X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) X_transformed_expected = [ [-1.38340578, -0.2935787], @@ -42,9 +44,12 @@ def test_sklearnex_import(dataframe, queue): ] pca = PCA(n_components=2, svd_solver="covariance_eigh") - pca.fit(X) - X_transformed = pca.transform(X) - X_fit_transformed = PCA(n_components=2, svd_solver="covariance_eigh").fit_transform(X) + with config_context(use_raw_input=use_raw_input): + pca.fit(X) + X_transformed = pca.transform(X) + X_fit_transformed = PCA( + n_components=2, svd_solver="covariance_eigh" + ).fit_transform(X) if daal_check_version((2024, "P", 100)): assert "sklearnex" in pca.__module__ diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 2a04962645..ab62c219a4 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -17,6 +17,7 @@ import numbers import warnings from abc import ABC +from collections.abc import Iterable import numpy as np from scipy import sparse as sp @@ -60,6 +61,7 @@ from sklearnex import get_hyperparameters from sklearnex._utils import register_hyperparameters +from .._config import get_config from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain from ..utils._array_api import get_namespace @@ -79,19 +81,25 @@ class BaseForest(ABC): _onedal_factory = None def _onedal_fit(self, X, y, sample_weight=None, queue=None): - X, y = validate_data( - self, - X, - y, - multi_output=True, - accept_sparse=False, - dtype=[np.float64, np.float32], - force_all_finite=False, - ensure_2d=True, - ) + use_raw_input = get_config()["use_raw_input"] + xp, _ = get_namespace(X) + if not use_raw_input: + X, y = validate_data( + self, + X, + y, + multi_output=True, + accept_sparse=False, + dtype=[np.float64, np.float32], + 
force_all_finite=False,
+                ensure_2d=True,
+            )

-        if sample_weight is not None:
-            sample_weight = _check_sample_weight(sample_weight, X)
+            if sample_weight is not None:
+                sample_weight = _check_sample_weight(sample_weight, X)
+        else:
+            self.classes_ = xp.unique_all(y).values
+            self.n_classes_ = len(self.classes_)

         if y.ndim == 2 and y.shape[1] == 1:
             warnings.warn(
@@ -102,22 +110,28 @@
                 stacklevel=2,
             )

-        if y.ndim == 1:
-            # reshape is necessary to preserve the data contiguity against vs
-            # [:, np.newaxis] that does not.
-            y = np.reshape(y, (-1, 1))
+        if not use_raw_input:
+            if y.ndim == 1:
+                # reshape is necessary to preserve data contiguity, which
+                # [:, np.newaxis] does not.
+                y = xp.reshape(y, (-1, 1))

-        self._n_samples, self.n_outputs_ = y.shape
+        if y.ndim == 1:
+            self._n_samples, self.n_outputs_ = y.shape[0], 1
+        else:
+            self._n_samples, self.n_outputs_ = y.shape

-        y, expanded_class_weight = self._validate_y_class_weight(y)
+        if not use_raw_input:
+            y, expanded_class_weight = self._validate_y_class_weight(y)

-        if expanded_class_weight is not None:
+            if expanded_class_weight is not None:
+                if sample_weight is not None:
+                    sample_weight = sample_weight * expanded_class_weight
+                else:
+                    sample_weight = expanded_class_weight
             if sample_weight is not None:
-                sample_weight = sample_weight * expanded_class_weight
-            else:
-                sample_weight = expanded_class_weight
-        if sample_weight is not None:
-            sample_weight = [sample_weight]
+                sample_weight = [sample_weight]
+        self.n_features_in_ = X.shape[1]

         onedal_params = {
             "n_estimators": self.n_estimators,
@@ -155,13 +169,20 @@

         # Compute
         self._onedal_estimator = self._onedal_factory(**onedal_params)
-        self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)
+        if use_raw_input:
+            self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
+        else:
+            self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)

         self._save_attributes()

         # Decapsulate classes_ attributes
         if hasattr(self, "classes_") and self.n_outputs_ == 1:
-            self.n_classes_ = self.n_classes_[0]
+            self.n_classes_ = (
+                self.n_classes_[0]
+                if isinstance(self.n_classes_, Iterable)
+                else self.n_classes_
+            )
             self.classes_ = self.classes_[0]

         return self
@@ -371,6 +392,7 @@ def _estimators_(self):
                 "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                 "values": tree_i_state_class.value_ar,
             }
+            # Note: this step runs on the host only.
est_i.tree_ = Tree( self.n_features_in_, np.array([n_classes_], dtype=np.intp), @@ -790,7 +812,6 @@ def _onedal_gpu_supported(self, method_name, *data): return patching_status def _onedal_predict(self, X, queue=None): - if sklearn_check_version("1.0"): X = validate_data( self, @@ -801,11 +822,12 @@ def _onedal_predict(self, X, queue=None): ensure_2d=True, ) else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - ) # Warning, order of dtype matters + if not get_config()["use_raw_input"]: + X = check_array( + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + ) # Warning, order of dtype matters if hasattr(self, "n_features_in_"): try: num_features = _num_features(X) @@ -825,23 +847,26 @@ def _onedal_predict(self, X, queue=None): return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe")) def _onedal_predict_proba(self, X, queue=None): - - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - reset=False, - ensure_2d=True, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - ) # Warning, order of dtype matters - self._check_n_features(X, reset=False) + xp, _ = get_namespace(X) + use_raw_input = get_config()["use_raw_input"] + if not use_raw_input: + if sklearn_check_version("1.0"): + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + reset=False, + ensure_2d=True, + ) + # sklearn version < 1.0 is not supported + # else: + # X = check_array( + # X, + # dtype=[np.float64, np.float32], + # force_all_finite=False, + # ) # Warning, order of dtype matters + # self._check_n_features(X, reset=False) return self._onedal_estimator.predict_proba(X, queue=queue) @@ -1130,24 +1155,29 @@ def _onedal_gpu_supported(self, method_name, *data): def _onedal_predict(self, X, queue=None): check_is_fitted(self, "_onedal_estimator") + use_raw_input = get_config()["use_raw_input"] - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - reset=False, - ensure_2d=True, - ) # Warning, order of dtype matters - else: - X = check_array( - X, dtype=[np.float64, np.float32], force_all_finite=False - ) # Warning, order of dtype matters + if not use_raw_input: + if sklearn_check_version("1.0"): + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + reset=False, + ensure_2d=True, + ) # Warning, order of dtype matters + # sklearn version < 1.0 is not supported + # else: + # X = check_array( + # X, dtype=[np.float64, np.float32], force_all_finite=False + # ) # Warning, order of dtype matters return self._onedal_estimator.predict(X, queue=queue) def _onedal_score(self, X, y, sample_weight=None, queue=None): + # TODO: + # should be checked for dpctl/dpnp inputs. 
return r2_score( y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight ) diff --git a/sklearnex/linear_model/incremental_linear.py b/sklearnex/linear_model/incremental_linear.py index db2d6549c0..98708b4821 100644 --- a/sklearnex/linear_model/incremental_linear.py +++ b/sklearnex/linear_model/incremental_linear.py @@ -28,6 +28,7 @@ from onedal.linear_model import ( IncrementalLinearRegression as onedal_IncrementalLinearRegression, ) +from sklearnex._config import get_config if sklearn_check_version("1.2"): from sklearn.utils._param_validation import Interval @@ -240,31 +241,31 @@ def _onedal_finalize_fit(self, queue=None): self._need_to_finalize = False def _onedal_fit(self, X, y, queue=None): - if sklearn_check_version("1.2"): - self._validate_params() - - if sklearn_check_version("1.0"): - X, y = validate_data( - self, - X, - y, - dtype=[np.float64, np.float32], - copy=self.copy_X, - multi_output=True, - ensure_2d=True, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - copy=self.copy_X, - ) - y = check_array( - y, - dtype=[np.float64, np.float32], - copy=False, - ensure_2d=False, - ) + if get_config()["use_raw_input"] is False: + if sklearn_check_version("1.2"): + self._validate_params() + if sklearn_check_version("1.0"): + X, y = validate_data( + self, + X, + y, + dtype=[np.float64, np.float32], + copy=self.copy_X, + multi_output=True, + ensure_2d=True, + ) + else: + X = check_array( + X, + dtype=[np.float64, np.float32], + copy=self.copy_X, + ) + y = check_array( + y, + dtype=[np.float64, np.float32], + copy=False, + ensure_2d=False, + ) n_samples, n_features = X.shape @@ -283,9 +284,6 @@ def _onedal_fit(self, X, y, queue=None): X_batch, y_batch = X[batch], y[batch] self._onedal_partial_fit(X_batch, y_batch, check_input=False, queue=queue) - if sklearn_check_version("1.2"): - self._validate_params() - # finite check occurs on onedal side self.n_features_in_ = n_features diff --git a/sklearnex/linear_model/linear.py b/sklearnex/linear_model/linear.py index fb7eca8cf1..7a5849d021 100644 --- a/sklearnex/linear_model/linear.py +++ b/sklearnex/linear_model/linear.py @@ -249,10 +249,12 @@ def _onedal_fit(self, X, y, sample_weight, queue=None): "y_numeric": True, "multi_output": supports_multi_output, } - if sklearn_check_version("1.0"): - X, y = validate_data(self, **check_params) - else: - X, y = check_X_y(**check_params) + + if get_config()["use_raw_input"] is False: + if sklearn_check_version("1.0"): + X, y = validate_data(self, **check_params) + else: + X, y = check_X_y(**check_params) if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): self._normalize = _deprecate_normalize( diff --git a/sklearnex/linear_model/tests/test_incremental_linear.py b/sklearnex/linear_model/tests/test_incremental_linear.py index e4ab649daf..4c446690d2 100644 --- a/sklearnex/linear_model/tests/test_incremental_linear.py +++ b/sklearnex/linear_model/tests/test_incremental_linear.py @@ -23,6 +23,7 @@ _convert_to_dataframe, get_dataframes_and_queues, ) +from sklearnex._config import config_context from sklearnex.linear_model import IncrementalLinearRegression from sklearnex.tests.utils import _IS_INTEL @@ -31,7 +32,16 @@ @pytest.mark.parametrize("fit_intercept", [True, False]) @pytest.mark.parametrize("macro_block", [None, 1024]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_sklearnex_fit_on_gold_data(dataframe, queue, fit_intercept, macro_block, dtype): +@pytest.mark.parametrize("use_raw_input", [True, False]) +def 
test_sklearnex_fit_on_gold_data( + skip_unsupported_raw_input, + dataframe, + queue, + fit_intercept, + macro_block, + dtype, + use_raw_input, +): X = np.array([[1], [2]]) X = X.astype(dtype=dtype) X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) @@ -44,7 +54,8 @@ def test_sklearnex_fit_on_gold_data(dataframe, queue, fit_intercept, macro_block hparams = IncrementalLinearRegression.get_hyperparameters("fit") hparams.cpu_macro_block = macro_block hparams.gpu_macro_block = macro_block - inclin.fit(X_df, y_df) + with config_context(use_raw_input=use_raw_input): + inclin.fit(X_df, y_df) y_pred = inclin.predict(X_df) np_y_pred = _as_numpy(y_pred) diff --git a/sklearnex/linear_model/tests/test_linear.py b/sklearnex/linear_model/tests/test_linear.py index 128f5110dc..f9b4be26a3 100644 --- a/sklearnex/linear_model/tests/test_linear.py +++ b/sklearnex/linear_model/tests/test_linear.py @@ -26,6 +26,7 @@ _convert_to_dataframe, get_dataframes_and_queues, ) +from sklearnex._config import config_context from sklearnex.tests.utils import _IS_INTEL @@ -34,8 +35,16 @@ @pytest.mark.parametrize("macro_block", [None, 1024]) @pytest.mark.parametrize("overdetermined", [False, True]) @pytest.mark.parametrize("multi_output", [False, True]) +@pytest.mark.parametrize("use_raw_input", [True, False]) def test_sklearnex_import_linear( - dataframe, queue, dtype, macro_block, overdetermined, multi_output + skip_unsupported_raw_input, + dataframe, + queue, + dtype, + macro_block, + overdetermined, + multi_output, + use_raw_input, ): if (not overdetermined or multi_output) and not daal_check_version((2025, "P", 1)): pytest.skip("Functionality introduced in later versions") @@ -71,7 +80,9 @@ def test_sklearnex_import_linear( y_list = y.tolist() X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe) - linreg.fit(X, y) + + with config_context(use_raw_input=use_raw_input): + linreg.fit(X, y) assert hasattr(linreg, "_onedal_estimator") assert "sklearnex" in linreg.__module__ diff --git a/sklearnex/preview/covariance/covariance.py b/sklearnex/preview/covariance/covariance.py index 04bdc0be8d..288cf3e353 100644 --- a/sklearnex/preview/covariance/covariance.py +++ b/sklearnex/preview/covariance/covariance.py @@ -28,6 +28,7 @@ from sklearnex import config_context from sklearnex.metrics import pairwise_distances +from ..._config import get_config from ..._device_offload import dispatch, wrap_output_data from ..._utils import PatchingConditionsChain, register_hyperparameters @@ -95,10 +96,11 @@ def _onedal_supported(self, method_name, *data): def fit(self, X, y=None): if sklearn_check_version("1.2"): self._validate_params() - if sklearn_check_version("0.23"): - X = validate_data(self, X, force_all_finite=False) - else: - X = check_array(X, force_all_finite=False) + if get_config()["use_raw_input"] is False: + if sklearn_check_version("0.23"): + X = validate_data(self, X, force_all_finite=False) + else: + X = check_array(X, force_all_finite=False) dispatch( self, diff --git a/sklearnex/preview/covariance/tests/test_covariance.py b/sklearnex/preview/covariance/tests/test_covariance.py index 71eb9235c3..39c87dbb47 100644 --- a/sklearnex/preview/covariance/tests/test_covariance.py +++ b/sklearnex/preview/covariance/tests/test_covariance.py @@ -23,22 +23,33 @@ _convert_to_dataframe, get_dataframes_and_queues, ) +from sklearnex._config import config_context @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) 
@pytest.mark.parametrize("macro_block", [None, 1024]) @pytest.mark.parametrize("assume_centered", [True, False]) -def test_sklearnex_import_covariance(dataframe, queue, macro_block, assume_centered): +@pytest.mark.parametrize("use_raw_input", [True, False]) +def test_sklearnex_import_covariance( + skip_unsupported_raw_input, + dataframe, + queue, + macro_block, + assume_centered, + use_raw_input, +): from sklearnex.preview.covariance import EmpiricalCovariance - X = np.array([[0, 1], [0, 1]]) + X = np.array([[0, 1], [0, 1]], dtype=np.float32) X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) empcov = EmpiricalCovariance(assume_centered=assume_centered) if daal_check_version((2024, "P", 0)) and macro_block is not None: hparams = EmpiricalCovariance.get_hyperparameters("fit") hparams.cpu_macro_block = macro_block - result = empcov.fit(X) + + with config_context(use_raw_input=use_raw_input): + result = empcov.fit(X) expected_covariance = np.array([[0, 0], [0, 0]]) expected_means = np.array([0, 0]) @@ -51,10 +62,12 @@ def test_sklearnex_import_covariance(dataframe, queue, macro_block, assume_cente assert_allclose(expected_covariance, result.covariance_) assert_allclose(expected_means, result.location_) - X = np.array([[1, 2], [3, 6]]) + X = np.array([[1, 2], [3, 6]], dtype=np.float32) X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) - result = empcov.fit(X) + + with config_context(use_raw_input=use_raw_input): + result = empcov.fit(X) if assume_centered: expected_covariance = np.array([[5, 10], [10, 20]]) diff --git a/sklearnex/preview/decomposition/incremental_pca.py b/sklearnex/preview/decomposition/incremental_pca.py index fdf13e0817..185a137087 100644 --- a/sklearnex/preview/decomposition/incremental_pca.py +++ b/sklearnex/preview/decomposition/incremental_pca.py @@ -22,6 +22,7 @@ from daal4py.sklearn._utils import sklearn_check_version from onedal.decomposition import IncrementalPCA as onedal_IncrementalPCA +from ..._config import get_config from ..._device_offload import dispatch, wrap_output_data from ..._utils import PatchingConditionsChain @@ -36,21 +37,22 @@ ) class IncrementalPCA(_sklearn_IncrementalPCA): + _need_to_finalize_attrs = { + "mean_", + "explained_variance_", + "explained_variance_ratio_", + "n_components_", + "components_", + "noise_variance_", + "singular_values_", + "var_", + } + def __init__(self, n_components=None, *, whiten=False, copy=True, batch_size=None): super().__init__( n_components=n_components, whiten=whiten, copy=copy, batch_size=batch_size ) self._need_to_finalize = False - self._need_to_finalize_attrs = { - "mean_", - "explained_variance_", - "explained_variance_ratio_", - "n_components_", - "components_", - "noise_variance_", - "singular_values_", - "var_", - } _onedal_incremental_pca = staticmethod(onedal_IncrementalPCA) @@ -68,6 +70,9 @@ def _onedal_fit_transform(self, X, queue=None): def _onedal_partial_fit(self, X, check_input=True, queue=None): first_pass = not hasattr(self, "_onedal_estimator") + use_raw_input = get_config()["use_raw_input"] + # never check input when using raw input + check_input &= use_raw_input is False if check_input: if sklearn_check_version("1.0"): X = validate_data( @@ -161,18 +166,20 @@ def _onedal_supported(self, method_name, *data): _onedal_gpu_supported = _onedal_supported def __getattr__(self, attr): - if attr in self._need_to_finalize_attrs: - if hasattr(self, "_onedal_estimator"): - if self._need_to_finalize: - self._onedal_finalize_fit() - return getattr(self._onedal_estimator, 
attr)
-            else:
+        # finalize the fit if the requested attribute requires it
+        if attr in IncrementalPCA._need_to_finalize_attrs:
+            if "_onedal_estimator" not in self.__dict__:
+                # _onedal_estimator is required to finalize the fit
                 raise AttributeError(
-                    f"'{self.__class__.__name__}' object has no attribute '{attr}'"
+                    f"Requested postfit attribute '{attr}' before fitting the model."
                 )
-        if attr in self.__dict__:
-            return self.__dict__[attr]
-
+            if self.__dict__["_need_to_finalize"]:
+                self._onedal_finalize_fit()
+        # join the attributes of the class and the onedal estimator to provide a common interface
+        joined = self.__dict__ | getattr(self.__dict__.get("_onedal_estimator", None), "__dict__", {})
+        if attr in joined:
+            return joined[attr]
+        # raise AttributeError if the attribute is neither in this class nor in _onedal_estimator
         raise AttributeError(
             f"'{self.__class__.__name__}' object has no attribute '{attr}'"
         )
diff --git a/sklearnex/preview/decomposition/tests/test_incremental_pca.py b/sklearnex/preview/decomposition/tests/test_incremental_pca.py
index c4c47c8adb..fb86a0d441 100644
--- a/sklearnex/preview/decomposition/tests/test_incremental_pca.py
+++ b/sklearnex/preview/decomposition/tests/test_incremental_pca.py
@@ -24,6 +24,7 @@
     _convert_to_dataframe,
     get_dataframes_and_queues,
 )
+from sklearnex import config_context
 from sklearnex.preview.decomposition import IncrementalPCA


@@ -245,8 +246,18 @@ def test_sklearnex_fit_transform_on_gold_data(
 @pytest.mark.parametrize("row_count", [100, 1000])
 @pytest.mark.parametrize("column_count", [10, 100])
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("use_raw_input", [True, False])
 def test_sklearnex_partial_fit_on_random_data(
-    dataframe, queue, n_components, whiten, num_blocks, row_count, column_count, dtype
+    skip_unsupported_raw_input,
+    dataframe,
+    queue,
+    n_components,
+    whiten,
+    num_blocks,
+    row_count,
+    column_count,
+    dtype,
+    use_raw_input,
 ):
     seed = 81
     gen = np.random.default_rng(seed)
@@ -254,12 +265,12 @@
     X = X.astype(dtype=dtype)
     X_split = np.array_split(X, num_blocks)
     incpca = IncrementalPCA(n_components=n_components, whiten=whiten)
-
-    for i in range(num_blocks):
-        X_split_df = _convert_to_dataframe(
-            X_split[i], sycl_queue=queue, target_df=dataframe
-        )
-        incpca.partial_fit(X_split_df)
+    with config_context(use_raw_input=use_raw_input):
+        for i in range(num_blocks):
+            X_split_df = _convert_to_dataframe(
+                X_split[i], sycl_queue=queue, target_df=dataframe
+            )
+            incpca.partial_fit(X_split_df)

     X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
     transformed_data = incpca.transform(X_df)
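
Taken together, the estimator-side changes in this patch reduce to one gating pattern: when `use_raw_input` is on, every `validate_data`/`check_array` call is skipped and the input is handed to the oneDAL backend exactly as the caller provided it. A minimal self-contained sketch of that contract (the coercion step is a simplified stand-in; the real estimators use `validate_data`):

import numpy as np

from sklearnex import config_context, get_config


def fit_input(X):
    """Return the array a sketch-estimator would hand to the oneDAL backend."""
    if get_config().get("use_raw_input", False):
        # raw path: no checks, no copies; dtype, layout and device are the
        # caller's responsibility
        return X
    # validated path: simplified stand-in for the usual sklearn-style coercion
    return np.ascontiguousarray(X, dtype=np.float64)


X = [[1.0, 2.0], [3.0, 4.0]]
print(fit_input(X).dtype)  # float64: lists are coerced on the validated path

with config_context(use_raw_input=True):
    X32 = np.asarray(X, dtype=np.float32)
    print(fit_input(X32) is X32)  # True: the raw path forwards X untouched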
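
In user code the flag is driven through `config_context`, exactly as the updated tests do. A minimal end-to-end sketch with a numpy input, pre-cast to float32 the way the tests now prepare their data:

import numpy as np

from sklearnex import config_context
from sklearnex.cluster import DBSCAN

# With use_raw_input=True the estimator skips validate_data()/check_array(),
# so the caller supplies data already in a oneDAL-friendly dtype and layout.
X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]], dtype=np.float32)

with config_context(use_raw_input=True):
    labels = DBSCAN(eps=3, min_samples=2).fit(X).labels_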
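
Because the raw path also bypasses the dispatch-time compatibility checks, `PCA._onedal_fit` now asserts solver compatibility itself. In practice that means picking a oneDAL-supported solver up front; a sketch mirroring the updated test:

import numpy as np

from sklearnex import config_context
from sklearnex.decomposition import PCA

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32)

with config_context(use_raw_input=True):
    # "covariance_eigh" is the solver the tests pair with the raw path; an
    # unsupported solver would trip the assertion in _onedal_fit rather than
    # fall back to stock scikit-learn.
    X_transformed = PCA(n_components=2, svd_solver="covariance_eigh").fit_transform(X)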
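
One behavioral consequence of the forest changes: on the raw path `classes_` is derived directly via `xp.unique_all(y)` and `y` is passed to oneDAL without `np.ravel`, so labels must already be a numeric array in the same namespace as `X`. A hedged sketch, assuming a dpnp-capable SYCL device (`RandomForestClassifier` here is the patched sklearnex estimator):

import dpnp
import numpy as np

from sklearnex import config_context
from sklearnex.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = dpnp.asarray(rng.random((100, 4)), dtype=dpnp.float32)
# numeric labels: the raw path performs no label encoding or reshaping
y = dpnp.asarray(rng.integers(0, 2, size=100))

with config_context(use_raw_input=True):
    clf = RandomForestClassifier(n_estimators=10).fit(X, y)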
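
For the incremental estimators the tests additionally disable per-batch checking (`check_input=not use_raw_input`), since re-validating every raw batch would defeat the point. The equivalent user-side pattern:

import numpy as np

from sklearnex import config_context
from sklearnex.covariance import IncrementalEmpiricalCovariance

X = np.random.default_rng(0).standard_normal((1000, 10)).astype(np.float32)

inccov = IncrementalEmpiricalCovariance()
with config_context(use_raw_input=True):
    for batch in np.array_split(X, 4):
        # check_input=False mirrors the tests: raw batches are not re-validated
        inccov.partial_fit(batch, check_input=False)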
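
The reworked `IncrementalPCA.__getattr__` makes the post-fit attributes lazy: the first access to anything in `_need_to_finalize_attrs` triggers `_onedal_finalize_fit()`, and accessing them before any `partial_fit` now raises the clearer error added above. Roughly:

import numpy as np

from sklearnex.preview.decomposition import IncrementalPCA

X = np.random.default_rng(81).standard_normal((100, 10)).astype(np.float32)

incpca = IncrementalPCA(n_components=2)
for batch in np.array_split(X, 4):
    incpca.partial_fit(batch)

# First access finalizes the fit on the oneDAL side, then returns the value.
print(incpca.components_.shape)

try:
    IncrementalPCA().components_
except AttributeError as exc:
    print(exc)  # Requested postfit attribute 'components_' before fitting the model.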
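
Finally, new tests opt into the raw-input matrix by parametrizing `use_raw_input` and requesting the `skip_unsupported_raw_input` fixture, which skips any combination the raw path cannot serve (inputs other than numpy/dpnp/dpctl). The shape of such a test (`test_with_raw_input` is a hypothetical name; the helper imports follow the test files above):

import numpy as np
import pytest
from onedal.tests.utils._dataframes_support import (
    _convert_to_dataframe,
    get_dataframes_and_queues,
)

from sklearnex import config_context


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("use_raw_input", [True, False])
def test_with_raw_input(skip_unsupported_raw_input, dataframe, queue, use_raw_input):
    from sklearnex.cluster import DBSCAN

    X = np.ones((10, 2), dtype=np.float32)
    X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    with config_context(use_raw_input=use_raw_input):
        DBSCAN(eps=0.5, min_samples=2).fit(X)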