Skip to content

Accept column vectors when having binary or multiclass targets #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 2, 2020
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ var/
*.egg-info/
.installed.cfg
*.egg
Pipfile
Pipfile.lock

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
3 changes: 3 additions & 0 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.base import clone
from sklearn.neighbors._base import KNeighborsMixin
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import type_of_target

from ..exceptions import raise_isinstance_error
Expand Down Expand Up @@ -96,6 +97,8 @@ def check_target_type(y, indicate_one_vs_all=False):
"multioutput targets are not supported."
)
y = y.argmax(axis=1)
else:
y = column_or_1d(y)

return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y

Expand Down
15 changes: 15 additions & 0 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _yield_sampler_checks(name, Estimator):
yield check_samplers_multiclass_ova
yield check_samplers_preserve_dtype
yield check_samplers_sample_indices
yield check_samplers_2d_target


def _yield_classifier_checks(name, Estimator):
Expand Down Expand Up @@ -283,6 +284,20 @@ def check_samplers_multiclass_ova(name, Sampler):
assert_allclose(y_res, y_res_ova.argmax(axis=1))


def check_samplers_2d_target(name, Sampler):
X, y = make_classification(
n_samples=100,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)

y = y.reshape(-1, 1) # Make the target 2d
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add another check when having pandas Series and DataFrame

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to make a follow up PR to address #666. So, this check could be added in that same PR. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK good then

sampler = Sampler()
sampler.fit_resample(X, y)


def check_samplers_preserve_dtype(name, Sampler):
X, y = make_classification(
n_samples=1000,
Expand Down