Skip to content

Commit ddebb21

Browse files
committed
fixing bad push
1 parent ff9b82c commit ddebb21

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/metrics/_regression.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,15 @@ def mean_absolute_percentage_error(
397397
"""
398398
input_arrays = [y_true, y_pred, sample_weight, multioutput]
399399
xp, _ = get_namespace(*input_arrays)
400+
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
400401

401402
y_type, y_true, y_pred, multioutput = _check_reg_targets(
402403
y_true, y_pred, multioutput
403404
)
404405
check_consistent_length(y_true, y_pred, sample_weight)
405-
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=xp.asarray(0.0).dtype)
406-
y_true_abs = xp.asarray(xp.abs(y_true), dtype=xp.asarray(0.0).dtype)
407-
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=xp.asarray(0.0).dtype) / xp.where(
406+
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
407+
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
408+
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.where(
408409
epsilon < y_true_abs, y_true_abs, epsilon
409410
)
410411
output_errors = _average(mape, weights=sample_weight, axis=0)

0 commit comments

Comments
 (0)