DEP PassiveAggressiveClassifier and PassiveAggressiveRegressor #29097

Merged · 36 commits · Aug 12, 2025

Commits
d26659a
DEP PassiveAggressiveClassifier and PassiveAggressiveRegressor
lorentzenchr May 24, 2024
fa3d842
ENH add C to SGD init
lorentzenchr May 25, 2024
7bdf688
DOC equivalent estimator
lorentzenchr May 25, 2024
0111071
MNT redundant parameter validation
lorentzenchr May 25, 2024
3051404
DOC whatsnew 1.6
lorentzenchr May 25, 2024
3af1f2f
Merge branch 'main' into dep_passive_aggressive
lorentzenchr May 25, 2024
4f23bc0
MNT after merging main
lorentzenchr May 25, 2024
eb1c83e
Merge branch 'main' into dep_passive_aggressive
lorentzenchr Oct 22, 2024
df3825e
DOC add new whatsnew
lorentzenchr Oct 22, 2024
c24c95f
CLN make C private again
lorentzenchr Oct 22, 2024
306a5fa
FIX signature of deprecated classes
lorentzenchr Oct 23, 2024
c367698
FIX _parameter_constraints in SGD classes
lorentzenchr Oct 23, 2024
67f6404
FIX numpydoc GL09 placement of deprecation in docstring
lorentzenchr Oct 24, 2024
246cd04
TST remove PA from other tests
lorentzenchr Oct 24, 2024
05b7bbd
FIX tests by using learning_rate="pa1"
lorentzenchr Oct 25, 2024
28e59b1
Merge branch 'main' into dep_passive_aggressive
lorentzenchr Jul 29, 2025
5d02438
MNT remove type: ignore
lorentzenchr Jul 29, 2025
dc1369a
MNT rename C to PA_C and improve docstring
lorentzenchr Jul 29, 2025
aea16e6
MNT remove PA_C from method signatures and use self.PA_C instead
lorentzenchr Jul 29, 2025
260efc4
Merge branch 'main' into dep_passive_aggressive
lorentzenchr Jul 29, 2025
3f8268f
MNT C instead of PA_C in PassiveAggressiveClassifier
lorentzenchr Jul 29, 2025
f339ea1
ENH add PA to SGD
lorentzenchr Jul 29, 2025
1e260bf
DOC deprecate in 1.8 remove in 1.10 and whatsnew
lorentzenchr Jul 31, 2025
e5b01ab
MNT parameter constraint for PA_C
lorentzenchr Jul 31, 2025
28eb112
DOC use .. code-block::
lorentzenchr Jul 31, 2025
15c7b8f
TST add filterwarnings ignore of PA estimators in pyproject.toml
lorentzenchr Jul 31, 2025
64fde84
DOC add filterwarnings to doc/conf.py
lorentzenchr Jul 31, 2025
16da746
DOC fix typo in docstring
lorentzenchr Jul 31, 2025
4e69d78
DOC/MNT remove/replace PA from user guide and place TODO(1.10) for la…
lorentzenchr Jul 31, 2025
06e40e5
Merge branch 'main' into dep_passive_aggressive
lorentzenchr Jul 31, 2025
66280c7
MNT add PassiveAggressive in _get_warnings_filters_info_list
lorentzenchr Jul 31, 2025
1e70f8c
DOC fix typo in user guide
lorentzenchr Jul 31, 2025
d0290fc
DOC fix docstring of _plain_sgd
lorentzenchr Aug 1, 2025
08ca849
MNT remove filterwarnings from pyproject.toml
lorentzenchr Aug 1, 2025
414fa30
Merge branch 'main' into dep_passive_aggressive
OmarManzoor Aug 12, 2025
b2144ee
DOC correct docstring about available loss for pa2
lorentzenchr Aug 12, 2025
4 changes: 2 additions & 2 deletions doc/api_reference.py
@@ -587,7 +587,7 @@ def _get_submodule(module_name, submodule_name):
"autosummary": [
"LogisticRegression",
"LogisticRegressionCV",
"PassiveAggressiveClassifier",
"PassiveAggressiveClassifier", # TODO(1.10): remove
"Perceptron",
"RidgeClassifier",
"RidgeClassifierCV",
@@ -672,7 +672,7 @@ def _get_submodule(module_name, submodule_name):
{
"title": "Miscellaneous",
"autosummary": [
"PassiveAggressiveRegressor",
"PassiveAggressiveRegressor", # TODO(1.10): remove
"enet_path",
"lars_path",
"lars_path_gram",
7 changes: 3 additions & 4 deletions doc/computing/computational_performance.rst
@@ -154,10 +154,9 @@ prediction latency too much. We will now review this idea for different
families of supervised models.

For :mod:`sklearn.linear_model` (e.g. Lasso, ElasticNet,
SGDClassifier/Regressor, Ridge & RidgeClassifier,
PassiveAggressiveClassifier/Regressor, LinearSVC, LogisticRegression...) the
decision function that is applied at prediction time is the same (a dot product)
, so latency should be equivalent.
SGDClassifier/Regressor, Ridge & RidgeClassifier, LinearSVC, LogisticRegression...) the
decision function that is applied at prediction time is the same (a dot product), so
latency should be equivalent.

Here is an example using
:class:`~linear_model.SGDClassifier` with the
6 changes: 2 additions & 4 deletions doc/computing/scaling_strategies.rst
@@ -63,11 +63,9 @@ Here is a list of incremental estimators for different tasks:
+ :class:`sklearn.naive_bayes.BernoulliNB`
+ :class:`sklearn.linear_model.Perceptron`
+ :class:`sklearn.linear_model.SGDClassifier`
+ :class:`sklearn.linear_model.PassiveAggressiveClassifier`
+ :class:`sklearn.neural_network.MLPClassifier`
- Regression
+ :class:`sklearn.linear_model.SGDRegressor`
+ :class:`sklearn.linear_model.PassiveAggressiveRegressor`
+ :class:`sklearn.neural_network.MLPRegressor`
- Clustering
+ :class:`sklearn.cluster.MiniBatchKMeans`
Expand All @@ -91,7 +89,7 @@ classes to the first ``partial_fit`` call using the ``classes=`` parameter.
Another aspect to consider when choosing a proper algorithm is that not all of
them put the same importance on each example over time. Namely, the
``Perceptron`` is still sensitive to badly labeled examples even after many
examples whereas the ``SGD*`` and ``PassiveAggressive*`` families are more
examples whereas the ``SGD*`` family is more
robust to this kind of artifacts. Conversely, the latter also tend to give less
importance to remarkably different, yet properly labeled examples when they
come late in the stream as their learning rate decreases over time.
@@ -130,7 +128,7 @@ Notes
......

.. [1] Depending on the algorithm the mini-batch size can influence results or
not. SGD*, PassiveAggressive*, and discrete NaiveBayes are truly online
not. SGD* and discrete NaiveBayes are truly online
and are not affected by batch size. Conversely, MiniBatchKMeans
convergence rate is affected by the batch size. Also, its memory
footprint can vary dramatically with batch size.
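
To ground the ``partial_fit`` discussion above (including the requirement to pass all
classes to the first call via ``classes=``), here is a minimal editorial sketch, not
part of the diff; the random data and batch layout are assumptions:

.. code-block:: python

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.RandomState(0)
    classes = np.array([0, 1])
    clf = SGDClassifier(loss="hinge", random_state=0)

    for _ in range(10):  # stream of mini-batches
        X = rng.randn(100, 5)
        y = rng.randint(0, 2, size=100)
        # All classes must be declared on the first partial_fit call.
        clf.partial_fit(X, y, classes=classes)
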
2 changes: 2 additions & 0 deletions doc/conf.py
@@ -866,6 +866,8 @@ def setup(app):
" non-GUI backend, so cannot show the figure."
),
)
# TODO(1.10): remove PassiveAggressive
warnings.filterwarnings("ignore", category=FutureWarning, message="PassiveAggressive")
if os.environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0":
turn_warnings_into_errors()
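
An editorial note on the filter added above: the ``message`` argument of
``warnings.filterwarnings`` is a regular expression matched against the start of the
warning text, so ``"PassiveAggressive"`` silences exactly the warnings whose message
begins with that prefix. A self-contained sketch:

.. code-block:: python

    import warnings

    warnings.filterwarnings(
        "ignore", category=FutureWarning, message="PassiveAggressive"
    )
    # Suppressed: the message starts with the pattern.
    warnings.warn("PassiveAggressiveClassifier was deprecated", FutureWarning)
    # Still shown: the pattern only occurs mid-message.
    warnings.warn("Deprecated: PassiveAggressiveRegressor", FutureWarning)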

2 changes: 1 addition & 1 deletion doc/modules/feature_extraction.rst
@@ -846,7 +846,7 @@ text classification tasks.

Note that the dimensionality does not affect the CPU training time of
algorithms which operate on CSR matrices (``LinearSVC(dual=True)``,
``Perceptron``, ``SGDClassifier``, ``PassiveAggressive``) but it does for
``Perceptron``, ``SGDClassifier``) but it does for
algorithms that work with CSC matrices (``LinearSVC(dual=False)``, ``Lasso()``,
etc.).
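
An editorial aside on why dimensionality does not affect CSR training time: with the
hashing trick, raising ``n_features`` widens the matrix but leaves the number of stored
non-zeros, which is what CSR-based solvers iterate over, essentially unchanged. A
minimal sketch (the toy corpus is an assumption):

.. code-block:: python

    from sklearn.feature_extraction.text import HashingVectorizer

    docs = ["the quick brown fox", "jumps over the lazy dog"]
    for n in (2**10, 2**20):
        X = HashingVectorizer(n_features=n).fit_transform(docs)  # CSR output
        # The width grows, but the non-zero count stays (about) the same.
        print(X.shape, X.nnz)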

28 changes: 14 additions & 14 deletions doc/modules/linear_model.rst
@@ -1335,10 +1335,10 @@ You can refer to the dedicated :ref:`sgd` documentation section for more details
.. _perceptron:

Perceptron
==========
----------

The :class:`Perceptron` is another simple classification algorithm suitable for
large scale learning. By default:
large scale learning and derives from SGD. By default:

- It does not require a learning rate.

Expand All @@ -1358,18 +1358,18 @@ for more details.
.. _passive_aggressive:

Passive Aggressive Algorithms
=============================

The passive-aggressive algorithms are a family of algorithms for large-scale
learning. They are similar to the Perceptron in that they do not require a
learning rate. However, contrary to the Perceptron, they include a
regularization parameter ``C``.

For classification, :class:`PassiveAggressiveClassifier` can be used with
``loss='hinge'`` (PA-I) or ``loss='squared_hinge'`` (PA-II). For regression,
:class:`PassiveAggressiveRegressor` can be used with
``loss='epsilon_insensitive'`` (PA-I) or
``loss='squared_epsilon_insensitive'`` (PA-II).
-----------------------------

The passive-aggressive (PA) algorithms are another family of two algorithms (PA-I and
PA-II) for large-scale online learning, both derived from SGD. They are similar to the
Perceptron in that they do not require a learning rate. However, contrary to the
Perceptron, they include a regularization parameter ``PA_C``.

For classification,
:class:`SGDClassifier(loss="hinge", penalty=None, learning_rate="pa1", PA_C=1.0)` can
be used for PA-I or with ``learning_rate="pa2"`` for PA-II. For regression,
:class:`SGDRegressor(loss="epsilon_insensitive", penalty=None, learning_rate="pa1",
PA_C=1.0)` can be used for PA-I or with ``learning_rate="pa2"`` for PA-II.
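
To make the equivalence concrete, a minimal editorial sketch, not part of the diff (the
toy data is an assumption):

.. code-block:: python

    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier

    X, y = make_classification(n_samples=200, random_state=0)
    pa1 = SGDClassifier(
        loss="hinge",
        penalty=None,         # PA algorithms have no penalty term
        learning_rate="pa1",  # "pa2" selects PA-II
        PA_C=1.0,             # aggressiveness parameter, the former `C`
    )
    pa1.fit(X, y)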

.. dropdown:: References

1 change: 0 additions & 1 deletion doc/modules/multiclass.rst
@@ -90,7 +90,6 @@ can provide additional strategies beyond what is built-in:
- :class:`linear_model.LogisticRegressionCV` (most solvers)
- :class:`linear_model.SGDClassifier`
- :class:`linear_model.Perceptron`
- :class:`linear_model.PassiveAggressiveClassifier`


- **Support multilabel:**
@@ -0,0 +1,6 @@
- `PassiveAggressiveClassifier` and `PassiveAggressiveRegressor` are deprecated
  and will be removed in 1.10. Equivalent estimators are available with `SGDClassifier`
  and `SGDRegressor`, both of which expose the options `learning_rate="pa1"` and
  `"pa2"` as well as the new parameter `PA_C`, the aggressiveness parameter of the
  passive-aggressive algorithms.
  By :user:`Christian Lorentzen <lorentzenchr>`.
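
A hedged before/after sketch of the regression migration (the parameter values are
illustrative):

.. code-block:: python

    # Before (deprecated in 1.8, removed in 1.10):
    #     reg = PassiveAggressiveRegressor(C=0.5, loss="squared_epsilon_insensitive")
    # After: PA-II is selected via the learning rate, not the loss.
    from sklearn.linear_model import SGDRegressor

    reg = SGDRegressor(
        loss="epsilon_insensitive",
        penalty=None,
        learning_rate="pa2",  # "pa1" for PA-I
        PA_C=0.5,             # the former aggressiveness parameter `C`
    )
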
6 changes: 4 additions & 2 deletions examples/applications/plot_out_of_core_classification.py
@@ -33,7 +33,7 @@

from sklearn.datasets import get_data_home
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import PassiveAggressiveClassifier, Perceptron, SGDClassifier
from sklearn.linear_model import Perceptron, SGDClassifier
from sklearn.naive_bayes import MultinomialNB


@@ -208,7 +208,9 @@ def progress(blocknum, bs, size):
"SGD": SGDClassifier(max_iter=5),
"Perceptron": Perceptron(),
"NB Multinomial": MultinomialNB(alpha=0.01),
"Passive-Aggressive": PassiveAggressiveClassifier(),
"Passive-Aggressive": SGDClassifier(
loss="hinge", penalty=None, learning_rate="pa1", PA_C=1.0
),
}


5 changes: 2 additions & 3 deletions sklearn/feature_selection/tests/test_from_model.py
@@ -20,7 +20,6 @@
LassoCV,
LinearRegression,
LogisticRegression,
PassiveAggressiveClassifier,
SGDClassifier,
)
from sklearn.pipeline import make_pipeline
@@ -393,8 +392,8 @@ def test_2d_coef():


def test_partial_fit():
est = PassiveAggressiveClassifier(
random_state=0, shuffle=False, max_iter=5, tol=None
est = SGDClassifier(
random_state=0, shuffle=False, max_iter=5, tol=None, learning_rate="pa1"
)
transformer = SelectFromModel(estimator=est)
transformer.partial_fit(data, y, classes=np.unique(y))
65 changes: 55 additions & 10 deletions sklearn/linear_model/_passive_aggressive.py
@@ -9,18 +9,41 @@
BaseSGDClassifier,
BaseSGDRegressor,
)
from sklearn.utils import deprecated
from sklearn.utils._param_validation import Interval, StrOptions


# TODO(1.10): Remove
@deprecated(
"this is deprecated in version 1.8 and will be removed in 1.10. "
"Use `SGDClassifier(loss='hinge', penalty=None, learning_rate='pa1', PA_C=1.0)` "
"instead."
)
class PassiveAggressiveClassifier(BaseSGDClassifier):
"""Passive Aggressive Classifier.

.. deprecated:: 1.8
The whole class `PassiveAggressiveClassifier` was deprecated in version 1.8
and will be removed in 1.10. Instead use:

.. code-block:: python

clf = SGDClassifier(
loss="hinge",
penalty=None,
learning_rate="pa1", # or "pa2"
PA_C=1.0, # for parameter C
)

Read more in the :ref:`User Guide <passive_aggressive>`.

Parameters
----------
C : float, default=1.0
Maximum step size (regularization). Defaults to 1.0.
Aggressiveness parameter for the passive-aggressive algorithm, see [1].
For PA-I it is the maximum step size. For PA-II it regularizes the
step size (the smaller `PA_C` the more it regularizes).
As a general rule-of-thumb, `PA_C` should be small when the data is noisy.

fit_intercept : bool, default=True
Whether the intercept should be estimated or not. If False, the
@@ -154,9 +177,9 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):

References
----------
Online Passive-Aggressive Algorithms
<http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
.. [1] Online Passive-Aggressive Algorithms
<http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

Examples
--------
@@ -212,6 +235,7 @@ def __init__(
verbose=verbose,
random_state=random_state,
eta0=1.0,
PA_C=C,
warm_start=warm_start,
class_weight=class_weight,
average=average,
@@ -262,12 +286,13 @@ def partial_fit(self, X, y, classes=None):
"parameter."
)

# For an explanation, see
# https://github.com/scikit-learn/scikit-learn/pull/1259#issuecomment-9818044
lr = "pa1" if self.loss == "hinge" else "pa2"
return self._partial_fit(
X,
y,
alpha=1.0,
C=self.C,
loss="hinge",
learning_rate=lr,
max_iter=1,
@@ -307,24 +332,45 @@ def fit(self, X, y, coef_init=None, intercept_init=None):
X,
y,
alpha=1.0,
C=self.C,
loss="hinge",
learning_rate=lr,
coef_init=coef_init,
intercept_init=intercept_init,
)


# TODO(1.10): Remove
@deprecated(
"this is deprecated in version 1.8 and will be removed in 1.10. "
"Use `SGDRegressor(loss='epsilon_insensitive', penalty=None, learning_rate='pa1', "
"PA_C = 1.0)` instead."
)
class PassiveAggressiveRegressor(BaseSGDRegressor):
"""Passive Aggressive Regressor.

.. deprecated:: 1.8
The whole class `PassiveAggressiveRegressor` was deprecated in version 1.8
and will be removed in 1.10. Instead use:

.. code-block:: python

reg = SGDRegressor(
loss="epsilon_insensitive",
penalty=None,
learning_rate="pa1", # or "pa2"
PA_C=1.0, # for parameter C
)

Read more in the :ref:`User Guide <passive_aggressive>`.

Parameters
----------

C : float, default=1.0
Maximum step size (regularization). Defaults to 1.0.
Aggressiveness parameter for the passive-aggressive algorithm, see [1].
For PA-I it is the maximum step size. For PA-II it regularizes the
step size (the smaller `PA_C` the more it regularizes).
As a general rule-of-thumb, `PA_C` should be small when the data is noisy.

fit_intercept : bool, default=True
Whether the intercept should be estimated or not. If False, the
@@ -486,10 +532,12 @@ def __init__(
average=False,
):
super().__init__(
loss=loss,
penalty=None,
l1_ratio=0,
epsilon=epsilon,
eta0=1.0,
PA_C=C,
fit_intercept=fit_intercept,
max_iter=max_iter,
tol=tol,
Expand All @@ -503,7 +551,6 @@ def __init__(
average=average,
)
self.C = C
self.loss = loss

@_fit_context(prefer_skip_nested_validation=True)
def partial_fit(self, X, y):
Expand All @@ -530,7 +577,6 @@ def partial_fit(self, X, y):
X,
y,
alpha=1.0,
C=self.C,
loss="epsilon_insensitive",
learning_rate=lr,
max_iter=1,
@@ -569,7 +615,6 @@ def fit(self, X, y, coef_init=None, intercept_init=None):
X,
y,
alpha=1.0,
C=self.C,
loss="epsilon_insensitive",
learning_rate=lr,
coef_init=coef_init,
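
An editorial note on the ``@deprecated`` decorator used in this file: applied to a
class, it makes instantiation emit a ``FutureWarning`` carrying the given message. A
minimal sketch (the ``Toy`` class is a made-up example):

.. code-block:: python

    import warnings

    from sklearn.utils import deprecated

    @deprecated("use NewToy instead.")
    class Toy:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Toy()  # instantiation triggers the deprecation warning
    print(caught[0].category.__name__, caught[0].message)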