@@ -397,14 +397,15 @@ def mean_absolute_percentage_error(
397
397
"""
398
398
input_arrays = [y_true , y_pred , sample_weight , multioutput ]
399
399
xp , _ = get_namespace (* input_arrays )
400
+ dtype = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
400
401
401
402
y_type , y_true , y_pred , multioutput = _check_reg_targets (
402
403
y_true , y_pred , multioutput
403
404
)
404
405
check_consistent_length (y_true , y_pred , sample_weight )
405
- epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = xp . asarray ( 0.0 ). dtype )
406
- y_true_abs = xp .asarray (xp .abs (y_true ), dtype = xp . asarray ( 0.0 ). dtype )
407
- mape = xp .asarray (xp .abs (y_pred - y_true ), dtype = xp . asarray ( 0.0 ). dtype ) / xp .where (
406
+ epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = dtype )
407
+ y_true_abs = xp .asarray (xp .abs (y_true ), dtype = dtype )
408
+ mape = xp .asarray (xp .abs (y_pred - y_true ), dtype = dtype ) / xp .where (
408
409
epsilon < y_true_abs , y_true_abs , epsilon
409
410
)
410
411
output_errors = _average (mape , weights = sample_weight , axis = 0 )
0 commit comments