Skip to content

Commit

Permalink
Fix failure with squared in scikit-learn>=1.6.0
Browse files Browse the repository at this point in the history
The `squared` parameter of functions
`mean_squared_error` and `mean_squared_log_error`
was removed in scikit-learn>=1.6.0.
  • Loading branch information
vnmabus committed Feb 2, 2025
1 parent 15c2360 commit 6183d26
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions skfda/misc/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,12 +750,17 @@ def mean_squared_error(
multioutput = 'raw_values', ndarray.
"""
return sklearn.metrics.mean_squared_error( # type: ignore [no-any-return]
function = (
sklearn.metrics.mean_squared_error
if squared
else sklearn.metrics.root_mean_squared_error
)

return function( # type: ignore [no-any-return]
y_true,
y_pred,
sample_weight=sample_weight,
multioutput=multioutput,
squared=squared,
)


Expand Down Expand Up @@ -919,13 +924,18 @@ def mean_squared_log_error(
multioutput = 'raw_values', ndarray.
"""
function = (
sklearn.metrics.mean_squared_log_error
if squared
else sklearn.metrics.root_mean_squared_log_error
)

return ( # type: ignore [no-any-return]
sklearn.metrics.mean_squared_log_error(
function(
y_true,
y_pred,
sample_weight=sample_weight,
multioutput=multioutput,
squared=squared,
)
)

Expand Down

0 comments on commit 6183d26

Please sign in to comment.