[enhancement] make BasicStatistics and IncrementalBasicStatistics array_api-compliant #2189

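This draft has no written description, so the sketch below is a hypothetical illustration (not taken from the PR) of what array API compliance means for these estimators: array inputs such as dpctl.tensor arrays are fitted directly under array API dispatch, and results are read through the trailing-underscore attributes used in the updated examples. The config_context(array_api_dispatch=True) usage is an assumption borrowed from scikit-learn's array API support; the final behavior depends on the merged implementation.

# Hypothetical usage sketch (not part of this PR): assumes scikit-learn's
# array_api_dispatch config option is honored by sklearnex's config_context.
import dpctl.tensor as dpt
from sklearnex import config_context
from sklearnex.basic_statistics import BasicStatistics

X = dpt.asarray([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])

with config_context(array_api_dispatch=True):
    bs = BasicStatistics(result_options=["mean", "max"]).fit(X)

# Result attributes follow the trailing-underscore convention used in the
# updated examples; with array API dispatch they are expected to stay in
# the input's array namespace.
print(bs.mean_)
print(bs.max_)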
Status: Draft. Wants to merge 145 commits into base: main.

Changes from all commits (145 commits)
32fe269
add finiteness_checker pybind11 bindings
icfaust Oct 23, 2024
cdbf1b5
added finiteness checker
icfaust Oct 23, 2024
62674a2
Update finiteness_checker.cpp
icfaust Oct 23, 2024
c75c23b
Update finiteness_checker.cpp
icfaust Oct 23, 2024
6a20938
Update finiteness_checker.cpp
icfaust Oct 23, 2024
382d7a1
Update finiteness_checker.cpp
icfaust Oct 23, 2024
c8ffd9c
Update finiteness_checker.cpp
icfaust Oct 23, 2024
9aa13d5
Update finiteness_checker.cpp
icfaust Oct 23, 2024
84e15d5
Rename finiteness_checker.cpp to finiteness_checker.cpp
icfaust Oct 23, 2024
63073c6
Update finiteness_checker.cpp
icfaust Oct 24, 2024
d915da5
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Oct 28, 2024
3dddf2d
add next step
icfaust Oct 31, 2024
1e1213e
follow conventions
icfaust Oct 31, 2024
0531713
make xtable explicit
icfaust Oct 31, 2024
e831167
remove comment
icfaust Oct 31, 2024
d6eb1d0
Update validation.py
icfaust Oct 31, 2024
fb30d6e
Update __init__.py
icfaust Nov 1, 2024
63a18c2
Update validation.py
icfaust Nov 1, 2024
76c0856
Update __init__.py
icfaust Nov 1, 2024
7deb2bb
Update __init__.py
icfaust Nov 1, 2024
ed46b29
Update validation.py
icfaust Nov 1, 2024
67d6273
Update _data_conversion.py
icfaust Nov 1, 2024
054f0a1
Merge branch 'main' into dev/new_assert_all_fininte
icfaust Nov 1, 2024
8abead9
Update _data_conversion.py
icfaust Nov 1, 2024
47d0f8b
Update policy_common.cpp
icfaust Nov 1, 2024
e48c2bd
Update policy_common.cpp
icfaust Nov 1, 2024
c6751c4
Update _policy.py
icfaust Nov 1, 2024
f3e4a3a
Update policy_common.cpp
icfaust Nov 2, 2024
39cdb5f
Rename finiteness_checker.cpp to finiteness_checker.cpp
icfaust Nov 2, 2024
0f39613
Create finiteness_checker.py
icfaust Nov 2, 2024
b42cfe3
Update validation.py
icfaust Nov 2, 2024
0ed615e
Update __init__.py
icfaust Nov 2, 2024
f101aff
attempt at fixing circular imports again
icfaust Nov 2, 2024
24c0e94
fix isort
icfaust Nov 2, 2024
3f96166
remove __init__ changes
icfaust Nov 2, 2024
d985053
last move
icfaust Nov 2, 2024
90ec48b
Update policy_common.cpp
icfaust Nov 2, 2024
8c2c854
Update policy_common.cpp
icfaust Nov 2, 2024
6fa38d7
Update policy_common.cpp
icfaust Nov 2, 2024
9c1ca9c
Update policy_common.cpp
icfaust Nov 2, 2024
4b67dbd
Update validation.py
icfaust Nov 2, 2024
fa59a3c
add testing
icfaust Nov 2, 2024
3330b33
isort
icfaust Nov 2, 2024
4895940
attempt to fix module error
icfaust Nov 2, 2024
0c6dd5d
add fptype
icfaust Nov 2, 2024
e2182fa
fix typo
icfaust Nov 2, 2024
982ef2c
Update validation.py
icfaust Nov 2, 2024
2fb52a8
remove sua_ifcae from to_table
icfaust Nov 3, 2024
28dc267
isort and black
icfaust Nov 3, 2024
2f85fd4
Update test_memory_usage.py
icfaust Nov 3, 2024
8659248
format
icfaust Nov 3, 2024
3827d6f
Update _data_conversion.py
icfaust Nov 3, 2024
55fa7d2
Update _data_conversion.py
icfaust Nov 3, 2024
175cd78
Update test_validation.py
icfaust Nov 3, 2024
7016ad0
remove unnecessary code
icfaust Nov 3, 2024
1a01859
Merge branch 'main' into dev/new_assert_all_fininte
icfaust Nov 18, 2024
2fbcdd9
merge master
icfaust Nov 18, 2024
fb7375f
make reviewer changes
icfaust Nov 19, 2024
30816bf
make dtype check change
icfaust Nov 19, 2024
abb3b16
add sparse testing
icfaust Nov 19, 2024
97aef73
try again
icfaust Nov 19, 2024
6e29651
try again
icfaust Nov 19, 2024
59363a8
try again
icfaust Nov 19, 2024
12de703
temporary commit
icfaust Nov 20, 2024
07ec3d8
first attempt
icfaust Nov 20, 2024
32c565d
missing change?
icfaust Nov 20, 2024
a571a4e
Merge branch 'intel:main' into dev/sklearnex_assert_all_finite
icfaust Nov 20, 2024
5093ed7
modify DummyEstimator for testing
icfaust Nov 20, 2024
f04deba
generalize DummyEstimator
icfaust Nov 20, 2024
740a5e7
switch test
icfaust Nov 20, 2024
27050bd
further testing changes
icfaust Nov 20, 2024
53c8f7b
add initial validate_data test, will be refactored
icfaust Nov 20, 2024
90f59c4
fixes for CI
icfaust Nov 20, 2024
7f170e2
Update validation.py
icfaust Nov 20, 2024
81e2bbc
Update validation.py
icfaust Nov 20, 2024
116bdba
Update test_memory_usage.py
icfaust Nov 20, 2024
076ebc4
Update base.py
icfaust Nov 20, 2024
e1d0743
Update base.py
icfaust Nov 20, 2024
f59cdd3
improve tests
icfaust Nov 20, 2024
7f9ea25
fix logic
icfaust Nov 20, 2024
51247c0
fix logic
icfaust Nov 20, 2024
6e5c0ef
fix logic again
icfaust Nov 20, 2024
8d47744
rename file
icfaust Nov 20, 2024
1ae9af5
Revert "rename file"
icfaust Nov 20, 2024
bf9b46e
remove duplication
icfaust Nov 20, 2024
3101c3f
fix imports
icfaust Nov 20, 2024
6da176b
Merge branch 'intel:main' into dev/sklearnex_assert_all_finite
icfaust Nov 20, 2024
ee799f6
Rename test_finite.py to test_validation.py
icfaust Nov 20, 2024
db4a6c6
Revert "Rename test_finite.py to test_validation.py"
icfaust Nov 20, 2024
b5acbac
updates
icfaust Nov 21, 2024
ed57c15
Update validation.py
icfaust Nov 21, 2024
414f897
fixes for some test failures
icfaust Nov 21, 2024
83253b3
fix text
icfaust Nov 21, 2024
b22e23a
fixes for some failures
icfaust Nov 21, 2024
2f8ec16
make consistent
icfaust Nov 21, 2024
1fd9973
fix bad logic
icfaust Nov 21, 2024
c20c8cc
fix in string
icfaust Nov 21, 2024
1ce1b10
attempt tp see if dataframe conversion is causing the issue
icfaust Nov 21, 2024
5355039
fix iter problem
icfaust Nov 21, 2024
b5b8442
fix testing issues
icfaust Nov 21, 2024
d025c89
formatting
icfaust Nov 21, 2024
428bfb6
revert change
icfaust Nov 21, 2024
da23138
fixes for pandas
icfaust Nov 21, 2024
1d0c330
there is a slowdown with pandas that needs to be solved
icfaust Nov 21, 2024
f3f63a6
swap to transpose for speed
icfaust Nov 21, 2024
56c8054
more clarity
icfaust Nov 21, 2024
1580d77
add _check_sample_weight
icfaust Nov 22, 2024
ffc9f1f
add more testing'
icfaust Nov 22, 2024
d184ed0
rename
icfaust Nov 22, 2024
c68616f
remove unnecessary imports
icfaust Nov 22, 2024
e7ea94e
fix test slowness
icfaust Nov 22, 2024
dbe108d
focus get_dataframes_and_queues
icfaust Nov 22, 2024
7284b59
put config_context around
icfaust Nov 22, 2024
e1be91d
Update test_validation.py
icfaust Nov 24, 2024
8a0f9e9
Update base.py
icfaust Nov 24, 2024
5272207
Update test_validation.py
icfaust Nov 24, 2024
21a7896
Merge branch 'intel:main' into dev/sklearnex_assert_all_finite
icfaust Nov 24, 2024
56b5c4c
generalize regex
icfaust Nov 25, 2024
0d1b306
add fixes for sklearn 1.0 and input_name
icfaust Nov 25, 2024
8ff312e
fixes for test failures
icfaust Nov 25, 2024
87b7e3b
Update validation.py
icfaust Nov 25, 2024
29e8f8c
Update test_validation.py
icfaust Nov 25, 2024
527ce22
Merge branch 'intel:main' into dev/sklearnex_assert_all_finite
icfaust Nov 25, 2024
1175a98
don't have more time at the moment to do this.
icfaust Nov 26, 2024
50ba766
remove old code
icfaust Nov 28, 2024
05ef656
interim stop
icfaust Nov 28, 2024
68ffc45
attempt at fixing
icfaust Nov 28, 2024
cfeb2c5
remover abstractmethod
icfaust Nov 28, 2024
d3a69c6
fix issues
icfaust Nov 28, 2024
c74485d
fix sample weights
icfaust Nov 28, 2024
ee3c475
remove numpy
icfaust Nov 28, 2024
e135c47
try again
icfaust Nov 28, 2024
39257bb
reintroduce _compute_raw for kmeans
icfaust Nov 28, 2024
afed175
formatting
icfaust Nov 28, 2024
71cb39c
iterable fix
icfaust Nov 28, 2024
fcb543c
make stricter
icfaust Nov 28, 2024
11f3c76
attempt at fixing recursion issue
icfaust Nov 28, 2024
8c1981a
merge master
icfaust Nov 28, 2024
8581551
Update basic_statistics.py
icfaust Nov 29, 2024
5334b38
Update incremental_basic_statistics.py
icfaust Nov 29, 2024
5f353c6
remove todo
icfaust Nov 29, 2024
b3ece1e
Update basic_statistics.py
icfaust Dec 1, 2024
2ebf71b
Update basic_statistics.py
icfaust Dec 1, 2024
8e4cde0
Merge branch 'main' into dev/bs_zero
icfaust Dec 3, 2024
60aeaa6
warning removal from BS examples
icfaust Dec 4, 2024
4 changes: 2 additions & 2 deletions examples/sklearnex/basic_statistics_spmd.py
@@ -60,5 +60,5 @@ def generate_data(par, size, seed=777):
bss = BasicStatisticsSpmd(["mean", "standard_deviation"])
bss.fit(dpt_data, dpt_weights)

print(f"Computed mean on rank {rank}:\n", bss.mean)
print(f"Computed std on rank {rank}:\n", bss.standard_deviation)
print(f"Computed mean on rank {rank}:\n", bss.mean_)
print(f"Computed std on rank {rank}:\n", bss.standard_deviation_)
12 changes: 6 additions & 6 deletions examples/sklearnex/incremental_basic_statistics.py
@@ -30,16 +30,16 @@
X_3 = np.array([[1, 1], [1, 2], [2, 3]])
result = incbs.partial_fit(X_3)

print(f"Mean:\n{result.mean}")
print(f"Max:\n{result.max}")
print(f"Sum:\n{result.sum}")
print(f"Mean:\n{result.mean_}")
print(f"Max:\n{result.max_}")
print(f"Sum:\n{result.sum_}")

# We pass the whole dataset to the fit method; it is split automatically and
# partial_fit is called for each batch.
incbs = IncrementalBasicStatistics(result_options=["mean", "max", "sum"], batch_size=3)
X = np.array([[0, 1], [0, 1], [1, 2], [1, 1], [1, 2], [2, 3]])
result = incbs.fit(X)

print(f"Mean:\n{result.mean}")
print(f"Max:\n{result.max}")
print(f"Sum:\n{result.sum}")
print(f"Mean:\n{result.mean_}")
print(f"Max:\n{result.max_}")
print(f"Sum:\n{result.sum_}")
12 changes: 6 additions & 6 deletions examples/sklearnex/incremental_basic_statistics_dpctl.py
@@ -36,16 +36,16 @@
X_3 = dpt.asarray([[1, 1], [1, 2], [2, 3]], sycl_queue=queue)
result = incbs.partial_fit(X_3)

print(f"Mean:\n{result.mean}")
print(f"Max:\n{result.max}")
print(f"Sum:\n{result.sum}")
print(f"Mean:\n{result.mean_}")
print(f"Max:\n{result.max_}")
print(f"Sum:\n{result.sum_}")

# We pass the whole dataset to the fit method; it is split automatically and
# partial_fit is called for each batch.
incbs = IncrementalBasicStatistics(result_options=["mean", "max", "sum"], batch_size=3)
X = dpt.asarray([[0, 1], [0, 1], [1, 2], [1, 1], [1, 2], [2, 3]], sycl_queue=queue)
result = incbs.fit(X)

print(f"Mean:\n{result.mean}")
print(f"Max:\n{result.max}")
print(f"Sum:\n{result.sum}")
print(f"Mean:\n{result.mean_}")
print(f"Max:\n{result.max_}")
print(f"Sum:\n{result.sum_}")
73 changes: 30 additions & 43 deletions onedal/basic_statistics/basic_statistics.py
@@ -17,17 +17,17 @@
import warnings
from abc import ABCMeta, abstractmethod

import numpy as np

from ..common._base import BaseEstimator
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _is_csr
from ..utils.validation import _check_array


class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
@abstractmethod
def __init__(self, result_options, algorithm):
class BasicStatistics(BaseEstimator, metaclass=ABCMeta):
"""
Basic Statistics oneDAL implementation.
"""

def __init__(self, result_options="all", algorithm="by_default"):
self.options = result_options
self.algorithm = algorithm

@@ -46,62 +46,49 @@ def get_all_result_options():
"second_order_raw_moment",
]

def _get_result_options(self, options):
if options == "all":
options = self.get_all_result_options()
if isinstance(options, list):
options = "|".join(options)
assert isinstance(options, str)
return options
@property
def options(self):
if self._options == ["all"]:
return self.get_all_result_options()
return self._options

@options.setter
def options(self, opts):
# options are always stored internally as an iterable
self._options = opts.split("|") if isinstance(opts, str) else opts

def _get_onedal_params(self, is_csr, dtype=np.float32):
options = self._get_result_options(self.options)
def _get_onedal_params(self, is_csr, dtype=None):
return {
"fptype": dtype,
"method": "sparse" if is_csr else self.algorithm,
"result_option": options,
"result_option": "|".join(self.options),
}


class BasicStatistics(BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""

def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)

def fit(self, data, sample_weight=None, queue=None):
policy = self._get_policy(queue, data, sample_weight)

is_csr = _is_csr(data)

if data is not None and not is_csr:
data = _check_array(data, ensure_2d=False)
if sample_weight is not None:
sample_weight = _check_array(sample_weight, ensure_2d=False)

data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)
data, sample_weight = to_table(
*_convert_to_supported(policy, data, sample_weight)
)

dtype = data.dtype
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
result = self._compute_raw(data, sample_weight, policy, data.dtype, is_csr)

for opt in self.options:
value = from_table(getattr(result, opt))[0] # two-dimensional table [1, n]
if is_single_dim:
setattr(self, opt, value[0])
else:
setattr(self, opt, value)

return self

def _compute_raw(
self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
):
def _compute_raw(self, data_table, weights_table, policy, dtype=None, is_csr=False):
# This function is maintained for internal use by KMeans tolerance
# calculations, but is otherwise considered legacy code and is not
# to be used externally in any circumstance
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, dtype)
result = module.compute(policy, params, data_table, weights_table)
options = self._get_result_options(self.options).split("|")

return {opt: getattr(result, opt) for opt in options}
return module.compute(policy, params, data_table, weights_table)
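For readability, here is a small standalone sketch of the option-normalization pattern introduced above; the class name and the abbreviated option list are illustrative only, and the real estimator also wires these options into the oneDAL backend parameters.

# Standalone sketch of the new options property/setter from the diff above.
# _OptionsDemo is illustrative; the real class is BasicStatistics and its
# full option list comes from get_all_result_options().
class _OptionsDemo:
    _all_options = ["min", "max", "sum", "mean", "variance", "standard_deviation"]

    def __init__(self, result_options="all"):
        self.options = result_options

    @property
    def options(self):
        # "all" expands to the full list of result options on read
        if self._options == ["all"]:
            return self._all_options
        return self._options

    @options.setter
    def options(self, opts):
        # strings, including "|"-separated ones, are stored as lists
        self._options = opts.split("|") if isinstance(opts, str) else opts

    def result_option_param(self):
        # oneDAL expects a single "|"-separated string (see _get_onedal_params)
        return "|".join(self.options)


demo = _OptionsDemo("mean|max")
print(demo.options)                          # ['mean', 'max']
print(_OptionsDemo().result_option_param())  # all options joined with "|"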
41 changes: 11 additions & 30 deletions onedal/basic_statistics/incremental_basic_statistics.py
@@ -14,16 +14,11 @@
# limitations under the License.
# ==============================================================================

import numpy as np

from daal4py.sklearn._utils import get_dtype

from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from .basic_statistics import BaseBasicStatistics
from .basic_statistics import BasicStatistics


class IncrementalBasicStatistics(BaseBasicStatistics):
class IncrementalBasicStatistics(BasicStatistics):
"""
Incremental estimator for basic statistics based on the oneDAL implementation.
Allows computing basic statistics when data are split into batches.
@@ -65,8 +60,8 @@ class IncrementalBasicStatistics(BaseBasicStatistics):
Second order moment of each feature over all samples.
"""

def __init__(self, result_options="all"):
super().__init__(result_options, algorithm="by_default")
def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)
self._reset()

def _reset(self):
@@ -85,7 +80,7 @@ def __getstate__(self):

return data

def partial_fit(self, X, weights=None, queue=None):
def partial_fit(self, X, sample_weight=None, queue=None):
"""
Computes partial data for basic statistics
from data batch X and saves it to `_partial_result`.
@@ -106,33 +101,20 @@
"""
self._queue = queue
policy = self._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

X = _check_array(
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
)
if weights is not None:
weights = _check_array(
weights,
dtype=[np.float64, np.float32],
ensure_2d=False,
force_all_finite=False,
)
X, sample_weight = to_table(*_convert_to_supported(policy, X, sample_weight))

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)
self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = self._get_backend(
"basic_statistics",
None,
"partial_compute",
policy,
self._onedal_params,
self._partial_result,
X_table,
weights_table,
X,
sample_weight,
)

self._need_to_finalize = True
@@ -167,9 +149,8 @@ def finalize_fit(self, queue=None):
self._onedal_params,
self._partial_result,
)
options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())
for opt in self.options:
setattr(self, opt, from_table(getattr(result, opt))[0])

self._need_to_finalize = False

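As context for the partial_fit/finalize_fit changes above, a hedged sketch of the batch flow at the onedal level; the import path, NumPy inputs, and the default queue are assumptions, and at this layer the result attributes use the plain names (mean, max, ...) rather than the sklearnex trailing-underscore ones.

# Hypothetical batch loop over the onedal-level incremental estimator.
import numpy as np

from onedal.basic_statistics import IncrementalBasicStatistics

incbs = IncrementalBasicStatistics(result_options=["mean", "max", "sum"])

batches = [
    np.array([[0.0, 1.0], [0.0, 1.0]]),
    np.array([[1.0, 2.0], [1.0, 1.0]]),
    np.array([[1.0, 2.0], [2.0, 3.0]]),
]
for X_batch in batches:
    # each call accumulates into _partial_result on the oneDAL side
    incbs.partial_fit(X_batch)

# finalize_fit collapses the accumulated partial results into the statistics
incbs.finalize_fit()
print(incbs.mean, incbs.max, incbs.sum)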
2 changes: 1 addition & 1 deletion onedal/cluster/kmeans.py
@@ -89,7 +89,7 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
bs = self._get_basic_statistics_backend("variance")

res = bs._compute_raw(X_table, dummy, policy, dtype, is_csr)
mean_var = from_table(res["variance"]).mean()
mean_var = from_table(res.variance).mean()

return mean_var * rtol

8 changes: 1 addition & 7 deletions onedal/spmd/basic_statistics/basic_statistics.py
@@ -21,10 +21,4 @@


class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
@support_input_format()
def compute(self, data, weights=None, queue=None):
return super().compute(data, weights=weights, queue=queue)

@support_input_format()
def fit(self, data, sample_weight=None, queue=None):
return super().fit(data, sample_weight=sample_weight, queue=queue)
pass
12 changes: 5 additions & 7 deletions onedal/spmd/basic_statistics/incremental_basic_statistics.py
@@ -30,7 +30,7 @@ def _reset(self):
"basic_statistics", None, "partial_compute_result"
)

def partial_fit(self, X, weights=None, queue=None):
def partial_fit(self, X, sample_weight=None, queue=None):
"""
Computes partial data for basic statistics
from data batch X and saves it to `_partial_result`.
@@ -51,22 +51,20 @@ def partial_fit(self, X, sample_weight=None, queue=None):
"""
self._queue = queue
policy = super(base_IncrementalBasicStatistics, self)._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)
X, sample_weight = to_table(*_convert_to_supported(policy, X, sample_weight))

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)
self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend(
"basic_statistics",
None,
"partial_compute",
policy,
self._onedal_params,
self._partial_result,
X_table,
weights_table,
X,
sample_weight,
)

self._need_to_finalize = True
45 changes: 14 additions & 31 deletions sklearnex/basic_statistics/basic_statistics.py
@@ -16,22 +16,17 @@

import warnings

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics

from .._device_offload import dispatch
from .._utils import IntelEstimator, PatchingConditionsChain

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
else:
validate_data = BaseEstimator._validate_data
from ..utils._array_api import get_namespace
from ..utils.validation import _check_sample_weight, validate_data

if sklearn_check_version("1.2"):
from sklearn.utils._param_validation import StrOptions
@@ -130,30 +125,15 @@ def __init__(self, result_options="all"):

def _save_attributes(self):
assert hasattr(self, "_onedal_estimator")

if self.result_options == "all":
result_options = onedal_BasicStatistics.get_all_result_options()
else:
result_options = self.result_options

if isinstance(result_options, str):
setattr(
self,
result_options + "_",
getattr(self._onedal_estimator, result_options),
)
elif isinstance(result_options, list):
for option in result_options:
setattr(self, option + "_", getattr(self._onedal_estimator, option))
for option in self._onedal_estimator.options:
setattr(self, option + "_", getattr(self._onedal_estimator, option))

def __getattr__(self, attr):
if self.result_options == "all":
result_options = onedal_BasicStatistics.get_all_result_options()
else:
result_options = self.result_options
is_deprecated_attr = (
isinstance(result_options, str) and (attr == result_options)
) or (isinstance(result_options, list) and (attr in result_options))
attr in self._onedal_estimator.options
if "_onedal_estimator" in self.__dict__
else False
)
if is_deprecated_attr:
warnings.warn(
"Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
@@ -179,13 +159,16 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

xp, _ = get_namespace(X)
if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_2d=False)
else:
X = check_array(X, dtype=[np.float64, np.float32])
X = check_array(X, dtype=[xp.float64, xp.float32])

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)

onedal_params = {
"result_options": self.result_options,
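The _save_attributes/__getattr__ rework above keeps the 2025.1 deprecation path for result attributes without a trailing underscore; below is a short sketch of the expected user-visible behavior. The exact return value of the deprecated alias is not shown in this hunk, so that part is an assumption.

# Sketch of the attribute deprecation behavior implied by the diff above.
import warnings

import numpy as np

from sklearnex.basic_statistics import BasicStatistics

X = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])
bs = BasicStatistics(result_options=["mean"]).fit(X)

print(bs.mean_)  # supported, trailing-underscore attribute

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = bs.mean  # deprecated alias, routed through __getattr__ (assumed to warn)
assert any("deprecated" in str(w.message) for w in caught)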