[MRG] Deprecate random_state in OneClassSVM and add clarifications in docstrings and doc #9703

Merged
merged 6 commits into from
Sep 7, 2017
39 changes: 26 additions & 13 deletions doc/modules/svm.rst
@@ -212,13 +212,12 @@ Then ``dual_coef_`` looks like this:
Scores and probabilities
------------------------

-The :class:`SVC` method ``decision_function`` gives per-class scores
-for each sample (or a single score per sample in the binary case).
-When the constructor option ``probability`` is set to ``True``,
-class membership probability estimates
-(from the methods ``predict_proba`` and ``predict_log_proba``) are enabled.
-In the binary case, the probabilities are calibrated using Platt scaling:
-logistic regression on the SVM's scores,
+The ``decision_function`` method of :class:`SVC` and :class:`NuSVC` gives
+per-class scores for each sample (or a single score per sample in the binary
+case). When the constructor option ``probability`` is set to ``True``,
+class membership probability estimates (from the methods ``predict_proba`` and
+``predict_log_proba``) are enabled. In the binary case, the probabilities are
+calibrated using Platt scaling: logistic regression on the SVM's scores,
fit by an additional cross-validation on the training data.
In the multiclass case, this is extended as per Wu et al. (2004).

@@ -245,7 +244,7 @@ and use ``decision_function`` instead of ``predict_proba``.

* Platt
`"Probabilistic outputs for SVMs and comparisons to regularized likelihood methods"
-<http://www.cs.colorado.edu/~mozer/Teaching/syllabi/6622/papers/Platt1999.pdf>`.
+<http://www.cs.colorado.edu/~mozer/Teaching/syllabi/6622/papers/Platt1999.pdf>`_.

Unbalanced problems
--------------------
@@ -399,7 +398,7 @@ Tips on Practical Use
function can be configured to be almost the same as the :class:`LinearSVC`
model.

-* **Kernel cache size**: For :class:`SVC`, :class:`SVR`, :class:`nuSVC` and
+* **Kernel cache size**: For :class:`SVC`, :class:`SVR`, :class:`NuSVC` and
:class:`NuSVR`, the size of the kernel cache has a strong impact on run
times for larger problems. If you have enough RAM available, it is
recommended to set ``cache_size`` to a higher value than the default of
@@ -423,10 +422,24 @@ Tips on Practical Use
positive and few negative), set ``class_weight='balanced'`` and/or try
different penalty parameters ``C``.

-* The underlying :class:`LinearSVC` implementation uses a random
-number generator to select features when fitting the model. It is
-thus not uncommon, to have slightly different results for the same
-input data. If that happens, try with a smaller tol parameter.
+* **Randomness of the underlying implementations**: The underlying
+implementations of :class:`SVC` and :class:`NuSVC` use a random number
+generator only to shuffle the data for probability estimation (when
+``probability`` is set to ``True``). This randomness can be controlled
+with the ``random_state`` parameter. If ``probability`` is set to ``False``
+these estimators are not random and ``random_state`` has no effect on the
+results. The underlying :class:`OneClassSVM` implementation is similar to
+the ones of :class:`SVC` and :class:`NuSVC`. As no probability estimation
+is provided for :class:`OneClassSVM`, it is not random.
+
+The underlying :class:`LinearSVC` implementation uses a random number
+generator to select features when fitting the model with a dual coordinate
+descent (i.e when ``dual`` is set to ``True``). It is thus not uncommon,
+to have slightly different results for the same input data. If that
+happens, try with a smaller tol parameter. This randomness can also be
+controlled with the ``random_state`` parameter. When ``dual`` is
+set to ``False`` the underlying implementation of :class:`LinearSVC` is
+not random and ``random_state`` has no effect on the results.

* Using L1 penalization as provided by ``LinearSVC(loss='l2', penalty='l1',
dual=False)`` yields a sparse solution, i.e. only a subset of feature
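
The randomness notes added in this hunk can be checked empirically. The sketch below is not part of the PR: the toy data and the helper names `svc_scores` and `linear_svc_coef` are made up for illustration, and it assumes a scikit-learn version in which `SVC` and `LinearSVC` accept `random_state`. It fits each estimator with two different seeds and compares the results, leaving `probability` at its default of `False` for `SVC` and setting `dual=False` for `LinearSVC`.

```python
import numpy as np
from sklearn.svm import SVC, LinearSVC

# Small synthetic binary problem (illustrative data only).
rng = np.random.RandomState(0)
X = rng.randn(60, 4)
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)


def svc_scores(seed):
    """Fit an SVC without probability estimates and return its scores."""
    # probability is left at its default (False), so no shuffling is done.
    clf = SVC(kernel="rbf", gamma=0.5, random_state=seed)
    return clf.fit(X, y).decision_function(X)


def linear_svc_coef(seed):
    """Fit a LinearSVC with the primal solver and return its coefficients."""
    clf = LinearSVC(dual=False, random_state=seed)
    return clf.fit(X, y).coef_


# Different seeds should give the same model in both settings.
svc_same = np.allclose(svc_scores(0), svc_scores(123))
linear_same = np.allclose(linear_svc_coef(0), linear_svc_coef(123))
print(svc_same, linear_same)
```

If the documentation change is accurate, both comparisons should hold. The same comparison with `dual=True` is the case where the new text says small differences between seeds may appear.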
15 changes: 15 additions & 0 deletions doc/whats_new.rst
@@ -60,6 +60,12 @@ Model evaluation and meta-estimators
- A scorer based on :func:`metrics.brier_score_loss` is also available.
:issue:`9521` by :user:`Hanmin Qin <qinhanmin2014>`.

+Linear, kernelized and related models
+
+- Deprecate ``random_state`` parameter in :class:`svm.OneClassSVM` as the
+underlying implementation is not random.
+:issue:`9497` by :user:`Albert Thomas <albertcthomas>`.
+
Bug fixes
.........

@@ -82,6 +88,15 @@ Decomposition, manifold learning and clustering
where all samples had equal similarity.
:issue:`9612`. By :user:`Jonatan Samoocha <jsamoocha>`.

+API changes summary
+-------------------
+
+Linear, kernelized and related models
+
+- Deprecate ``random_state`` parameter in :class:`svm.OneClassSVM` as the
+underlying implementation is not random.
+:issue:`9497` by :user:`Albert Thomas <albertcthomas>`.
+
Version 0.19
============

46 changes: 27 additions & 19 deletions sklearn/svm/classes.py
@@ -88,10 +88,13 @@ class LinearSVC(BaseEstimator, LinearClassifierMixin,

random_state : int, RandomState instance or None, optional (default=None)
The seed of the pseudo random number generator to use when shuffling
-the data. If int, random_state is the seed used by the random number
-generator; If RandomState instance, random_state is the random number
-generator; If None, the random number generator is the RandomState
-instance used by `np.random`.
+the data for the dual coordinate descent (if ``dual=True``). When
+``dual=False`` the underlying implementation of :class:`LinearSVC`
+is not random and ``random_state`` has no effect on the results. If
+int, random_state is the seed used by the random number generator; If
+RandomState instance, random_state is the random number generator; If
+None, the random number generator is the RandomState instance used by
+`np.random`.

max_iter : int, (default=1000)
The maximum number of iterations to be run.
@@ -509,11 +512,11 @@ class SVC(BaseSVC):
Deprecated *decision_function_shape='ovo' and None*.

random_state : int, RandomState instance or None, optional (default=None)
-The seed of the pseudo random number generator to use when shuffling
-the data. If int, random_state is the seed used by the random number
-generator; If RandomState instance, random_state is the random number
-generator; If None, the random number generator is the RandomState
-instance used by `np.random`.
+The seed of the pseudo random number generator used when shuffling
+the data for probability estimates. If int, random_state is the
+seed used by the random number generator; If RandomState instance,
+random_state is the random number generator; If None, the random
+number generator is the RandomState instance used by `np.random`.

Attributes
----------
@@ -665,11 +668,11 @@ class NuSVC(BaseSVC):
Deprecated *decision_function_shape='ovo' and None*.

random_state : int, RandomState instance or None, optional (default=None)
-The seed of the pseudo random number generator to use when shuffling
-the data. If int, random_state is the seed used by the random number
-generator; If RandomState instance, random_state is the random number
-generator; If None, the random number generator is the RandomState
-instance used by `np.random`.
+The seed of the pseudo random number generator used when shuffling
+the data for probability estimates. If int, random_state is the seed
+used by the random number generator; If RandomState instance,
+random_state is the random number generator; If None, the random
+number generator is the RandomState instance used by `np.random`.

Attributes
----------
@@ -1019,11 +1022,11 @@ class OneClassSVM(BaseLibSVM):
Hard limit on iterations within solver, or -1 for no limit.

random_state : int, RandomState instance or None, optional (default=None)
-The seed of the pseudo random number generator to use when shuffling
-the data. If int, random_state is the seed used by the random number
-generator; If RandomState instance, random_state is the random number
-generator; If None, the random number generator is the RandomState
-instance used by `np.random`.
+Ignored.
+
+.. deprecated:: 0.20
+``random_state`` has been deprecated in 0.20 and will be removed in
+0.22.

Attributes
----------
@@ -1080,6 +1083,11 @@ def fit(self, X, y=None, sample_weight=None, **params):
If X is not a C-ordered contiguous array it is copied.

"""

if self.random_state is not None:
warnings.warn("The random_state parameter is deprecated and will"
" be removed in version 0.22.", DeprecationWarning)

super(OneClassSVM, self).fit(X, np.ones(_num_samples(X)),
sample_weight=sample_weight, **params)
return self
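
The deprecation check added to `fit` follows a common pattern: warn only when the caller passed a non-default value. The sketch below reproduces that pattern in isolation so the warning behaviour can be inspected without scikit-learn. `ToyEstimator` and `fit_and_collect_warnings` are hypothetical names used only for illustration; they are not part of the PR or of scikit-learn.

```python
import warnings


class ToyEstimator:
    """Hypothetical stand-in for an estimator with a deprecated parameter."""

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit(self, X):
        # Same check as in the hunk above: warn only if the user set the
        # parameter to something other than its default of None.
        if self.random_state is not None:
            warnings.warn("The random_state parameter is deprecated and will"
                          " be removed in version 0.22.", DeprecationWarning)
        return self


def fit_and_collect_warnings(estimator, X):
    """Fit the estimator and return the DeprecationWarnings it raised."""
    with warnings.catch_warnings(record=True) as caught:
        # Make sure the warning is recorded even if it was already shown.
        warnings.simplefilter("always")
        estimator.fit(X)
    return [w for w in caught if issubclass(w.category, DeprecationWarning)]


with_seed = fit_and_collect_warnings(ToyEstimator(random_state=0), [[0.0]])
without_seed = fit_and_collect_warnings(ToyEstimator(), [[0.0]])
print(len(with_seed), len(without_seed))
```

Keeping the default at `None` means existing code that never set `random_state` sees no warning, while code that passed a seed is told the parameter is going away.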