MNT black → ruff format #31015

Merged
merged 6 commits on Apr 15, 2025

Conversation

DimitriPapadopoulos (Contributor) commented Mar 18, 2025

Reference Issues/PRs

See #30695 (comment).

What does this implement/fix? Explain your changes.

black → ruff format

Any other comments?

Not sure how to get this past CI tests in a single move.

This PR changes black to ruff format, but the CI tests still use black.
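For context, the most common disagreement in the bot report below is how the two formatters wrap an assert whose message overflows the line limit. A minimal sketch of the two styles (names taken from one of the hunks below; the values are invented so the snippet runs):

    n_clusters = 4
    unique_signs = [1, -1, 2, -2]

    # ruff format keeps the condition on one line and parenthesizes the message:
    assert len(unique_signs) == n_clusters, (
        "Wrong number of clusters, or not in distinct quadrants"
    )

    # black 24.3.0 (the version pinned in CI) parenthesizes the condition instead:
    assert (
        len(unique_signs) == n_clusters
    ), "Wrong number of clusters, or not in distinct quadrants"

Both forms are valid Python; most of the diff below is this pattern repeated across the test suite.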

github-actions bot commented Mar 18, 2025

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here.


black

black detected issues. Please run black . locally and push the changes. Here you can see the detected issues. Note that running black might also fix some of the issues which might be detected by ruff. Note that the installed black version is black=24.3.0.
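Aside from the assert pattern, the first hunk below shows a second recurring disagreement: a long %-formatted expression used as a dictionary key. ruff format breaks inside the value's call arguments, while black 24.3.0 breaks the %-expression itself. A condensed, hypothetical sketch (power and values are invented stand-ins for the real variables):

    power = 1.9
    values = [1.0, 2.0, 3.0]

    # ruff format output:
    row = {
        "predicted, tweedie, power=%.2f" % power: sum(
            2.0 * v for v in values
        ),
    }

    # black 24.3.0 output:
    row = {
        "predicted, tweedie, power=%.2f"
        % power: sum(2.0 * v for v in values),
    }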


--- /home/runner/work/scikit-learn/scikit-learn/examples/linear_model/plot_tweedie_regression_insurance_claims.py	2025-03-24 12:03:52.909617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/examples/linear_model/plot_tweedie_regression_insurance_claims.py	2025-03-24 12:04:01.771112+00:00
@@ -604,13 +604,12 @@
             "subset": subset_label,
             "observed": df["ClaimAmount"].values.sum(),
             "predicted, frequency*severity model": np.sum(
                 exposure * glm_freq.predict(X) * glm_sev.predict(X)
             ),
-            "predicted, tweedie, power=%.2f" % glm_pure_premium.power: np.sum(
-                exposure * glm_pure_premium.predict(X)
-            ),
+            "predicted, tweedie, power=%.2f"
+            % glm_pure_premium.power: np.sum(exposure * glm_pure_premium.predict(X)),
         }
     )
 
 print(pd.DataFrame(res).set_index("subset").T)
 
would reformat /home/runner/work/scikit-learn/scikit-learn/examples/linear_model/plot_tweedie_regression_insurance_claims.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_samples_generator.py	2025-03-24 12:03:52.932617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_samples_generator.py	2025-03-24 12:04:04.923679+00:00
@@ -136,21 +136,21 @@
             # Cluster by sign, viewed as strings to allow uniquing
             signs = np.sign(X)
             signs = signs.view(dtype="|S{0}".format(signs.strides[0])).ravel()
             unique_signs, cluster_index = np.unique(signs, return_inverse=True)
 
-            assert len(unique_signs) == n_clusters, (
-                "Wrong number of clusters, or not in distinct quadrants"
-            )
+            assert (
+                len(unique_signs) == n_clusters
+            ), "Wrong number of clusters, or not in distinct quadrants"
 
             clusters_by_class = defaultdict(set)
             for cluster, cls in zip(cluster_index, y):
                 clusters_by_class[cls].add(cluster)
             for clusters in clusters_by_class.values():
-                assert len(clusters) == n_clusters_per_class, (
-                    "Wrong number of clusters per class"
-                )
+                assert (
+                    len(clusters) == n_clusters_per_class
+                ), "Wrong number of clusters per class"
             assert len(clusters_by_class) == n_classes, "Wrong number of classes"
 
             assert_array_almost_equal(
                 np.bincount(y) / len(y) // weights,
                 [1] * n_classes,
@@ -410,13 +410,13 @@
 def test_make_blobs_n_samples_list():
     n_samples = [50, 30, 20]
     X, y = make_blobs(n_samples=n_samples, n_features=2, random_state=0)
 
     assert X.shape == (sum(n_samples), 2), "X shape mismatch"
-    assert all(np.bincount(y, minlength=len(n_samples)) == n_samples), (
-        "Incorrect number of samples per blob"
-    )
+    assert all(
+        np.bincount(y, minlength=len(n_samples)) == n_samples
+    ), "Incorrect number of samples per blob"
 
 
 def test_make_blobs_n_samples_list_with_centers():
     n_samples = [20, 20, 20]
     centers = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
@@ -424,13 +424,13 @@
     X, y = make_blobs(
         n_samples=n_samples, centers=centers, cluster_std=cluster_stds, random_state=0
     )
 
     assert X.shape == (sum(n_samples), 2), "X shape mismatch"
-    assert all(np.bincount(y, minlength=len(n_samples)) == n_samples), (
-        "Incorrect number of samples per blob"
-    )
+    assert all(
+        np.bincount(y, minlength=len(n_samples)) == n_samples
+    ), "Incorrect number of samples per blob"
     for i, (ctr, std) in enumerate(zip(centers, cluster_stds)):
         assert_almost_equal((X[y == i] - ctr).std(), std, 1, "Unexpected std")
 
 
 @pytest.mark.parametrize(
@@ -439,13 +439,13 @@
 def test_make_blobs_n_samples_centers_none(n_samples):
     centers = None
     X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=0)
 
     assert X.shape == (sum(n_samples), 2), "X shape mismatch"
-    assert all(np.bincount(y, minlength=len(n_samples)) == n_samples), (
-        "Incorrect number of samples per blob"
-    )
+    assert all(
+        np.bincount(y, minlength=len(n_samples)) == n_samples
+    ), "Incorrect number of samples per blob"
 
 
 def test_make_blobs_return_centers():
     n_samples = [10, 20]
     n_features = 3
@@ -679,13 +679,13 @@
         )
 
 
 def test_make_moons_unbalanced():
     X, y = make_moons(n_samples=(7, 5))
-    assert np.sum(y == 0) == 7 and np.sum(y == 1) == 5, (
-        "Number of samples in a moon is wrong"
-    )
+    assert (
+        np.sum(y == 0) == 7 and np.sum(y == 1) == 5
+    ), "Number of samples in a moon is wrong"
     assert X.shape == (12, 2), "X shape mismatch"
     assert y.shape == (12,), "y shape mismatch"
 
     with pytest.raises(
         ValueError,
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_samples_generator.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_openml.py	2025-03-24 12:03:52.932617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_openml.py	2025-03-24 12:04:05.291777+00:00
@@ -103,13 +103,13 @@
             .replace("-deactivated", "-dact")
             .replace("-active", "-act")
         )
 
     def _mock_urlopen_shared(url, has_gzip_header, expected_prefix, suffix):
-        assert url.startswith(expected_prefix), (
-            f"{expected_prefix!r} does not match {url!r}"
-        )
+        assert url.startswith(
+            expected_prefix
+        ), f"{expected_prefix!r} does not match {url!r}"
 
         data_file_name = _file_name(url, suffix)
         data_file_path = resources.files(data_module) / data_file_name
 
         with data_file_path.open("rb") as f:
@@ -154,13 +154,13 @@
             expected_prefix=url_prefix_download_data,
             suffix=".arff",
         )
 
     def _mock_urlopen_data_list(url, has_gzip_header):
-        assert url.startswith(url_prefix_data_list), (
-            f"{url_prefix_data_list!r} does not match {url!r}"
-        )
+        assert url.startswith(
+            url_prefix_data_list
+        ), f"{url_prefix_data_list!r} does not match {url!r}"
 
         data_file_name = _file_name(url, ".json")
         data_file_path = resources.files(data_module) / data_file_name
 
         # load the file itself, to simulate a http error
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/datasets/tests/test_openml.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/__init__.py	2025-03-24 12:03:52.941617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/__init__.py	2025-03-24 12:04:06.680726+00:00
@@ -1,5 +1,4 @@
-
 """
 External, bundled dependencies.
 
 """
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/__init__.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/_structures.py	2025-03-24 12:03:52.941617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/_structures.py	2025-03-24 12:04:06.727013+00:00
@@ -1,8 +1,9 @@
 """Vendoered from
 https://github.com/pypa/packaging/blob/main/packaging/_structures.py
 """
+
 # Copyright (c) Donald Stufft and individual contributors.
 # All rights reserved.
 
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/_structures.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/version.py	2025-03-24 12:03:52.941617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/version.py	2025-03-24 12:04:07.056664+00:00
@@ -1,8 +1,9 @@
 """Vendoered from
 https://github.com/pypa/packaging/blob/main/packaging/version.py
 """
+
 # Copyright (c) Donald Stufft and individual contributors.
 # All rights reserved.
 
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_packaging/version.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_arff.py	2025-03-24 12:03:52.941617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_arff.py	2025-03-24 12:04:07.250223+00:00
@@ -22,11 +22,11 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 # =============================================================================
 
-'''
+"""
 The liac-arff module implements functions to read and write ARFF files in
 Python. It was created in the Connectionist Artificial Intelligence Laboratory
 (LIAC), which takes place at the Federal University of Rio Grande do Sul
 (UFRGS), in Brazil.
 
@@ -138,37 +138,39 @@
 - Supports missing values and names with spaces;
 - Supports unicode values and names;
 - Fully compatible with Python 2.7+, Python 3.5+, pypy and pypy3;
 - Under `MIT License <http://opensource.org/licenses/MIT>`_
 
-'''
-__author__ = 'Renato de Pontes Pereira, Matthias Feurer, Joel Nothman'
-__author_email__ = ('renato.ppontes@gmail.com, '
-                    'feurerm@informatik.uni-freiburg.de, '
-                    'joel.nothman@gmail.com')
-__version__ = '2.4.0'
+"""
+__author__ = "Renato de Pontes Pereira, Matthias Feurer, Joel Nothman"
+__author_email__ = (
+    "renato.ppontes@gmail.com, "
+    "feurerm@informatik.uni-freiburg.de, "
+    "joel.nothman@gmail.com"
+)
+__version__ = "2.4.0"
 
 import re
 import csv
 from typing import TYPE_CHECKING
 from typing import Optional, List, Dict, Any, Iterator, Union, Tuple
 
 # CONSTANTS ===================================================================
-_SIMPLE_TYPES = ['NUMERIC', 'REAL', 'INTEGER', 'STRING']
-
-_TK_DESCRIPTION = '%'
-_TK_COMMENT     = '%'
-_TK_RELATION    = '@RELATION'
-_TK_ATTRIBUTE   = '@ATTRIBUTE'
-_TK_DATA        = '@DATA'
-
-_RE_RELATION     = re.compile(r'^([^\{\}%,\s]*|\".*\"|\'.*\')$', re.UNICODE)
-_RE_ATTRIBUTE    = re.compile(r'^(\".*\"|\'.*\'|[^\{\}%,\s]*)\s+(.+)$', re.UNICODE)
+_SIMPLE_TYPES = ["NUMERIC", "REAL", "INTEGER", "STRING"]
+
+_TK_DESCRIPTION = "%"
+_TK_COMMENT = "%"
+_TK_RELATION = "@RELATION"
+_TK_ATTRIBUTE = "@ATTRIBUTE"
+_TK_DATA = "@DATA"
+
+_RE_RELATION = re.compile(r"^([^\{\}%,\s]*|\".*\"|\'.*\')$", re.UNICODE)
+_RE_ATTRIBUTE = re.compile(r"^(\".*\"|\'.*\'|[^\{\}%,\s]*)\s+(.+)$", re.UNICODE)
 _RE_QUOTE_CHARS = re.compile(r'["\'\\\s%,\000-\031]', re.UNICODE)
 _RE_ESCAPE_CHARS = re.compile(r'(?=["\'\\%])|[\n\r\t\000-\031]')
-_RE_SPARSE_LINE = re.compile(r'^\s*\{.*\}\s*$', re.UNICODE)
-_RE_NONTRIVIAL_DATA = re.compile('["\'{}\\s]', re.UNICODE)
+_RE_SPARSE_LINE = re.compile(r"^\s*\{.*\}\s*$", re.UNICODE)
+_RE_NONTRIVIAL_DATA = re.compile("[\"'{}\\s]", re.UNICODE)
 
 ArffDenseDataType = Iterator[List]
 ArffSparseDataType = Tuple[List, ...]
 
 
@@ -185,11 +187,11 @@
 else:
     ArffContainerType = Dict[str, Any]
 
 
 def _build_re_values():
-    quoted_re = r'''
+    quoted_re = r"""
                     "      # open quote followed by zero or more of:
                     (?:
                         (?<!\\)    # no additional backslash
                         (?:\\\\)*  # maybe escaped backslashes
                         \\"        # escaped quote
@@ -197,127 +199,132 @@
                         \\[^"]     # escaping a non-quote
                     |
                         [^"\\]     # non-quote char
                     )*
                     "      # close quote
-                    '''
+                    """
     # a value is surrounded by " or by ' or contains no quotables
-    value_re = r'''(?:
+    value_re = r"""(?:
         %s|          # a value may be surrounded by "
         %s|          # or by '
         [^,\s"'{}]+  # or may contain no characters requiring quoting
-        )''' % (quoted_re,
-                quoted_re.replace('"', "'"))
+        )""" % (
+        quoted_re,
+        quoted_re.replace('"', "'"),
+    )
 
     # This captures (value, error) groups. Because empty values are allowed,
     # we cannot just look for empty values to handle syntax errors.
     # We presume the line has had ',' prepended...
-    dense = re.compile(r'''(?x)
+    dense = re.compile(
+        r"""(?x)
         ,                # may follow ','
         \s*
         ((?=,)|$|{value_re})  # empty or value
         |
         (\S.*)           # error
-        '''.format(value_re=value_re))
+        """.format(
+            value_re=value_re
+        )
+    )
 
     # This captures (key, value) groups and will have an empty key/value
     # in case of syntax errors.
     # It does not ensure that the line starts with '{' or ends with '}'.
-    sparse = re.compile(r'''(?x)
+    sparse = re.compile(
+        r"""(?x)
         (?:^\s*\{|,)   # may follow ',', or '{' at line start
         \s*
         (\d+)          # attribute key
         \s+
         (%(value_re)s) # value
         |
         (?!}\s*$)      # not an error if it's }$
         (?!^\s*{\s*}\s*$)  # not an error if it's ^{}$
         \S.*           # error
-        ''' % {'value_re': value_re})
+        """
+        % {"value_re": value_re}
+    )
     return dense, sparse
 
 
-
 _RE_DENSE_VALUES, _RE_SPARSE_KEY_VALUES = _build_re_values()
 
 
 _ESCAPE_SUB_MAP = {
-    '\\\\': '\\',
+    "\\\\": "\\",
     '\\"': '"',
     "\\'": "'",
-    '\\t': '\t',
-    '\\n': '\n',
-    '\\r': '\r',
-    '\\b': '\b',
-    '\\f': '\f',
-    '\\%': '%',
+    "\\t": "\t",
+    "\\n": "\n",
+    "\\r": "\r",
+    "\\b": "\b",
+    "\\f": "\f",
+    "\\%": "%",
 }
-_UNESCAPE_SUB_MAP = {chr(i): '\\%03o' % i for i in range(32)}
+_UNESCAPE_SUB_MAP = {chr(i): "\\%03o" % i for i in range(32)}
 _UNESCAPE_SUB_MAP.update({v: k for k, v in _ESCAPE_SUB_MAP.items()})
-_UNESCAPE_SUB_MAP[''] = '\\'
-_ESCAPE_SUB_MAP.update({'\\%d' % i: chr(i) for i in range(10)})
+_UNESCAPE_SUB_MAP[""] = "\\"
+_ESCAPE_SUB_MAP.update({"\\%d" % i: chr(i) for i in range(10)})
 
 
 def _escape_sub_callback(match):
     s = match.group()
     if len(s) == 2:
         try:
             return _ESCAPE_SUB_MAP[s]
         except KeyError:
-            raise ValueError('Unsupported escape sequence: %s' % s)
-    if s[1] == 'u':
+            raise ValueError("Unsupported escape sequence: %s" % s)
+    if s[1] == "u":
         return chr(int(s[2:], 16))
     else:
         return chr(int(s[1:], 8))
 
 
 def _unquote(v):
     if v[:1] in ('"', "'"):
-        return re.sub(r'\\([0-9]{1,3}|u[0-9a-f]{4}|.)', _escape_sub_callback,
-                      v[1:-1])
-    elif v in ('?', ''):
+        return re.sub(r"\\([0-9]{1,3}|u[0-9a-f]{4}|.)", _escape_sub_callback, v[1:-1])
+    elif v in ("?", ""):
         return None
     else:
         return v
 
 
 def _parse_values(s):
-    '''(INTERNAL) Split a line into a list of values'''
+    """(INTERNAL) Split a line into a list of values"""
     if not _RE_NONTRIVIAL_DATA.search(s):
         # Fast path for trivial cases (unfortunately we have to handle missing
         # values because of the empty string case :(.)
-        return [None if s in ('?', '') else s
-                for s in next(csv.reader([s]))]
+        return [None if s in ("?", "") else s for s in next(csv.reader([s]))]
 
     # _RE_DENSE_VALUES tokenizes despite quoting, whitespace, etc.
-    values, errors = zip(*_RE_DENSE_VALUES.findall(',' + s))
+    values, errors = zip(*_RE_DENSE_VALUES.findall("," + s))
     if not any(errors):
         return [_unquote(v) for v in values]
     if _RE_SPARSE_LINE.match(s):
         try:
-            return {int(k): _unquote(v)
-                    for k, v in _RE_SPARSE_KEY_VALUES.findall(s)}
+            return {int(k): _unquote(v) for k, v in _RE_SPARSE_KEY_VALUES.findall(s)}
         except ValueError:
             # an ARFF syntax error in sparse data
             for match in _RE_SPARSE_KEY_VALUES.finditer(s):
                 if not match.group(1):
-                    raise BadLayout('Error parsing %r' % match.group())
-            raise BadLayout('Unknown parsing error')
+                    raise BadLayout("Error parsing %r" % match.group())
+            raise BadLayout("Unknown parsing error")
     else:
         # an ARFF syntax error
         for match in _RE_DENSE_VALUES.finditer(s):
             if match.group(2):
-                raise BadLayout('Error parsing %r' % match.group())
-        raise BadLayout('Unknown parsing error')
-
-
-DENSE = 0     # Constant value representing a dense matrix
-COO = 1       # Constant value representing a sparse matrix in coordinate format
-LOD = 2       # Constant value representing a sparse matrix in list of
-              # dictionaries format
-DENSE_GEN = 3 # Generator of dictionaries
-LOD_GEN = 4   # Generator of dictionaries
+                raise BadLayout("Error parsing %r" % match.group())
+        raise BadLayout("Unknown parsing error")
+
+
+DENSE = 0  # Constant value representing a dense matrix
+COO = 1  # Constant value representing a sparse matrix in coordinate format
+LOD = 2  # Constant value representing a sparse matrix in list of
+# dictionaries format
+DENSE_GEN = 3  # Generator of dictionaries
+LOD_GEN = 4  # Generator of dictionaries
 _SUPPORTED_DATA_STRUCTURES = [DENSE, COO, LOD, DENSE_GEN, LOD_GEN]
 
 
 # EXCEPTIONS ==================================================================
 class ArffException(Exception):
@@ -325,95 +332,111 @@
 
     def __init__(self):
         self.line = -1
 
     def __str__(self):
-        return self.message%self.line
+        return self.message % self.line
+
 
 class BadRelationFormat(ArffException):
-    '''Error raised when the relation declaration is in an invalid format.'''
-    message = 'Bad @RELATION format, at line %d.'
+    """Error raised when the relation declaration is in an invalid format."""
+
+    message = "Bad @RELATION format, at line %d."
+
 
 class BadAttributeFormat(ArffException):
-    '''Error raised when some attribute declaration is in an invalid format.'''
-    message = 'Bad @ATTRIBUTE format, at line %d.'
+    """Error raised when some attribute declaration is in an invalid format."""
+
+    message = "Bad @ATTRIBUTE format, at line %d."
+
 
 class BadDataFormat(ArffException):
-    '''Error raised when some data instance is in an invalid format.'''
+    """Error raised when some data instance is in an invalid format."""
+
+    def __init__(self, value):
+        super().__init__()
+        self.message = "Bad @DATA instance format in line %d: " + ("%s" % value)
+
+
+class BadAttributeType(ArffException):
+    """Error raised when some invalid type is provided into the attribute
+    declaration."""
+
+    message = "Bad @ATTRIBUTE type, at line %d."
+
+
+class BadAttributeName(ArffException):
+    """Error raised when an attribute name is provided twice the attribute
+    declaration."""
+
+    def __init__(self, value, value2):
+        super().__init__()
+        self.message = (
+            ("Bad @ATTRIBUTE name %s at line" % value)
+            + " %d, this name is already in use in line"
+            + (" %d." % value2)
+        )
+
+
+class BadNominalValue(ArffException):
+    """Error raised when a value in used in some data instance but is not
+    declared into it respective attribute declaration."""
+
     def __init__(self, value):
         super().__init__()
         self.message = (
-            'Bad @DATA instance format in line %d: ' +
-            ('%s' % value)
-        )
-
-class BadAttributeType(ArffException):
-    '''Error raised when some invalid type is provided into the attribute
-    declaration.'''
-    message = 'Bad @ATTRIBUTE type, at line %d.'
-
-class BadAttributeName(ArffException):
-    '''Error raised when an attribute name is provided twice the attribute
-    declaration.'''
-
-    def __init__(self, value, value2):
-        super().__init__()
-        self.message = (
-            ('Bad @ATTRIBUTE name %s at line' % value) +
-            ' %d, this name is already in use in line' +
-            (' %d.' % value2)
-        )
-
-class BadNominalValue(ArffException):
-    '''Error raised when a value in used in some data instance but is not
-    declared into it respective attribute declaration.'''
+            "Data value %s not found in nominal declaration, " % value
+        ) + "at line %d."
+
+
+class BadNominalFormatting(ArffException):
+    """Error raised when a nominal value with space is not properly quoted."""
 
     def __init__(self, value):
         super().__init__()
         self.message = (
-            ('Data value %s not found in nominal declaration, ' % value)
-            + 'at line %d.'
-        )
-
-class BadNominalFormatting(ArffException):
-    '''Error raised when a nominal value with space is not properly quoted.'''
-    def __init__(self, value):
-        super().__init__()
-        self.message = (
-            ('Nominal data value "%s" not properly quoted in line ' % value) +
-            '%d.'
-        )
+            'Nominal data value "%s" not properly quoted in line ' % value
+        ) + "%d."
+
 
 class BadNumericalValue(ArffException):
-    '''Error raised when and invalid numerical value is used in some data
-    instance.'''
-    message = 'Invalid numerical value, at line %d.'
+    """Error raised when and invalid numerical value is used in some data
+    instance."""
+
+    message = "Invalid numerical value, at line %d."
+
 
 class BadStringValue(ArffException):
-    '''Error raise when a string contains space but is not quoted.'''
-    message = 'Invalid string value at line %d.'
+    """Error raise when a string contains space but is not quoted."""
+
+    message = "Invalid string value at line %d."
+
 
 class BadLayout(ArffException):
-    '''Error raised when the layout of the ARFF file has something wrong.'''
-    message = 'Invalid layout of the ARFF file, at line %d.'
-
-    def __init__(self, msg=''):
+    """Error raised when the layout of the ARFF file has something wrong."""
+
+    message = "Invalid layout of the ARFF file, at line %d."
+
+    def __init__(self, msg=""):
         super().__init__()
         if msg:
-            self.message = BadLayout.message + ' ' + msg.replace('%', '%%')
+            self.message = BadLayout.message + " " + msg.replace("%", "%%")
 
 
 class BadObject(ArffException):
-    '''Error raised when the object representing the ARFF file has something
-    wrong.'''
-    def __init__(self, msg='Invalid object.'):
+    """Error raised when the object representing the ARFF file has something
+    wrong."""
+
+    def __init__(self, msg="Invalid object."):
         self.msg = msg
 
     def __str__(self):
-        return '%s' % self.msg
+        return "%s" % self.msg
+
 
 # =============================================================================
+
 
 # INTERNAL ====================================================================
 def _unescape_sub_callback(match):
     return _UNESCAPE_SUB_MAP[match.group()]
 
@@ -452,73 +475,76 @@
             raise BadNominalValue(value)
         return str(value)
 
 
 class DenseGeneratorData:
-    '''Internal helper class to allow for different matrix types without
-    making the code a huge collection of if statements.'''
+    """Internal helper class to allow for different matrix types without
+    making the code a huge collection of if statements."""
 
     def decode_rows(self, stream, conversors):
         for row in stream:
             values = _parse_values(row)
 
             if isinstance(values, dict):
                 if values and max(values) >= len(conversors):
                     raise BadDataFormat(row)
                 # XXX: int 0 is used for implicit values, not '0'
-                values = [values[i] if i in values else 0 for i in
-                          range(len(conversors))]
+                values = [
+                    values[i] if i in values else 0 for i in range(len(conversors))
+                ]
             else:
                 if len(values) != len(conversors):
                     raise BadDataFormat(row)
 
             yield self._decode_values(values, conversors)
 
     @staticmethod
     def _decode_values(values, conversors):
         try:
-            values = [None if value is None else conversor(value)
-                      for conversor, value
-                      in zip(conversors, values)]
+            values = [
+                None if value is None else conversor(value)
+                for conversor, value in zip(conversors, values)
+            ]
         except ValueError as exc:
-            if 'float: ' in str(exc):
+            if "float: " in str(exc):
                 raise BadNumericalValue()
         return values
 
     def encode_data(self, data, attributes):
-        '''(INTERNAL) Encodes a line of data.
+        """(INTERNAL) Encodes a line of data.
 
         Data instances follow the csv format, i.e, attribute values are
         delimited by commas. After converted from csv.
 
         :param data: a list of values.
         :param attributes: a list of attributes. Used to check if data is valid.
         :return: a string with the encoded data line.
-        '''
+        """
         current_row = 0
 
         for inst in data:
             if len(inst) != len(attributes):
                 raise BadObject(
-                    'Instance %d has %d attributes, expected %d' %
-                     (current_row, len(inst), len(attributes))
+                    "Instance %d has %d attributes, expected %d"
+                    % (current_row, len(inst), len(attributes))
                 )
 
             new_data = []
             for value in inst:
-                if value is None or value == '' or value != value:
-                    s = '?'
+                if value is None or value == "" or value != value:
+                    s = "?"
                 else:
                     s = encode_string(str(value))
                 new_data.append(s)
 
             current_row += 1
-            yield ','.join(new_data)
+            yield ",".join(new_data)
 
 
 class _DataListMixin:
     """Mixin to return a list from decode_rows instead of a generator"""
+
     def decode_rows(self, stream, conversors):
         return list(super().decode_rows(stream, conversors))
 
 
 class Data(_DataListMixin, DenseGeneratorData):
@@ -534,14 +560,16 @@
                 raise BadLayout()
             if not values:
                 continue
             row_cols, values = zip(*sorted(values.items()))
             try:
-                values = [value if value is None else conversors[key](value)
-                          for key, value in zip(row_cols, values)]
+                values = [
+                    value if value is None else conversors[key](value)
+                    for key, value in zip(row_cols, values)
+                ]
             except ValueError as exc:
-                if 'float: ' in str(exc):
+                if "float: " in str(exc):
                     raise BadNumericalValue()
                 raise
             except IndexError:
                 # conversor out of range
                 raise BadDataFormat(row)
@@ -561,47 +589,51 @@
         col = data.col
         data = data.data
 
         # Check if the rows are sorted
         if not all(row[i] <= row[i + 1] for i in range(len(row) - 1)):
-            raise ValueError("liac-arff can only output COO matrices with "
-                             "sorted rows.")
+            raise ValueError(
+                "liac-arff can only output COO matrices with " "sorted rows."
+            )
 
         for v, col, row in zip(data, col, row):
             if row > current_row:
                 # Add empty rows if necessary
                 while current_row < row:
-                    yield " ".join(["{", ','.join(new_data), "}"])
+                    yield " ".join(["{", ",".join(new_data), "}"])
                     new_data = []
                     current_row += 1
 
             if col >= num_attributes:
                 raise BadObject(
-                    'Instance %d has at least %d attributes, expected %d' %
-                    (current_row, col + 1, num_attributes)
+                    "Instance %d has at least %d attributes, expected %d"
+                    % (current_row, col + 1, num_attributes)
                 )
 
-            if v is None or v == '' or v != v:
-                s = '?'
+            if v is None or v == "" or v != v:
+                s = "?"
             else:
                 s = encode_string(str(v))
             new_data.append("%d %s" % (col, s))
 
-        yield " ".join(["{", ','.join(new_data), "}"])
+        yield " ".join(["{", ",".join(new_data), "}"])
+
 
 class LODGeneratorData:
     def decode_rows(self, stream, conversors):
         for row in stream:
             values = _parse_values(row)
 
             if not isinstance(values, dict):
                 raise BadLayout()
             try:
-                yield {key: None if value is None else conversors[key](value)
-                       for key, value in values.items()}
+                yield {
+                    key: None if value is None else conversors[key](value)
+                    for key, value in values.items()
+                }
             except ValueError as exc:
-                if 'float: ' in str(exc):
+                if "float: " in str(exc):
                     raise BadNumericalValue()
                 raise
             except IndexError:
                 # conversor out of range
                 raise BadDataFormat(row)
@@ -613,24 +645,25 @@
         for row in data:
             new_data = []
 
             if len(row) > 0 and max(row) >= num_attributes:
                 raise BadObject(
-                    'Instance %d has %d attributes, expected %d' %
-                    (current_row, max(row) + 1, num_attributes)
+                    "Instance %d has %d attributes, expected %d"
+                    % (current_row, max(row) + 1, num_attributes)
                 )
 
             for col in sorted(row):
                 v = row[col]
-                if v is None or v == '' or v != v:
-                    s = '?'
+                if v is None or v == "" or v != v:
+                    s = "?"
                 else:
                     s = encode_string(str(v))
                 new_data.append("%d %s" % (col, s))
 
             current_row += 1
-            yield " ".join(["{", ','.join(new_data), "}"])
+            yield " ".join(["{", ",".join(new_data), "}"])
+
 
 class LODData(_DataListMixin, LODGeneratorData):
     pass
 
 
@@ -646,51 +679,54 @@
     elif matrix_type == LOD_GEN:
         return LODGeneratorData()
     else:
         raise ValueError("Matrix type %s not supported." % str(matrix_type))
 
+
 def _get_data_object_for_encoding(matrix):
     # Probably a scipy.sparse
-    if hasattr(matrix, 'format'):
-        if matrix.format == 'coo':
+    if hasattr(matrix, "format"):
+        if matrix.format == "coo":
             return COOData()
         else:
-            raise ValueError('Cannot guess matrix format!')
+            raise ValueError("Cannot guess matrix format!")
     elif isinstance(matrix[0], dict):
         return LODData()
     else:
         return Data()
 
+
 # =============================================================================
+
 
 # ADVANCED INTERFACE ==========================================================
 class ArffDecoder:
-    '''An ARFF decoder.'''
+    """An ARFF decoder."""
 
     def __init__(self):
-        '''Constructor.'''
+        """Constructor."""
         self._conversors = []
         self._current_line = 0
 
     def _decode_comment(self, s):
-        '''(INTERNAL) Decodes a comment line.
+        """(INTERNAL) Decodes a comment line.
 
         Comments are single line strings starting, obligatorily, with the ``%``
         character, and can have any symbol, including whitespaces or special
         characters.
 
         This method must receive a normalized string, i.e., a string without
         padding, including the "\r\n" characters.
 
         :param s: a normalized string.
         :return: a string with the decoded comment.
-        '''
-        res = re.sub(r'^\%( )?', '', s)
+        """
+        res = re.sub(r"^\%( )?", "", s)
         return res
 
     def _decode_relation(self, s):
-        '''(INTERNAL) Decodes a relation line.
+        """(INTERNAL) Decodes a relation line.
 
         The relation declaration is a line with the format ``@RELATION
         <relation-name>``, where ``relation-name`` is a string. The string must
         start with alphabetic character and must be quoted if the name includes
         spaces, otherwise this method will raise a `BadRelationFormat` exception.
@@ -698,22 +734,22 @@
         This method must receive a normalized string, i.e., a string without
         padding, including the "\r\n" characters.
 
         :param s: a normalized string.
         :return: a string with the decoded relation name.
-        '''
-        _, v = s.split(' ', 1)
+        """
+        _, v = s.split(" ", 1)
         v = v.strip()
 
         if not _RE_RELATION.match(v):
             raise BadRelationFormat()
 
-        res = str(v.strip('"\''))
+        res = str(v.strip("\"'"))
         return res
 
     def _decode_attribute(self, s):
-        '''(INTERNAL) Decodes an attribute line.
+        """(INTERNAL) Decodes an attribute line.
 
         The attribute is the most complex declaration in an arff file. All
         attributes must follow the template::
 
              @attribute <attribute-name> <datatype>
@@ -734,12 +770,12 @@
         This method must receive a normalized string, i.e., a string without
         padding, including the "\r\n" characters.
 
         :param s: a normalized string.
         :return: a tuple (ATTRIBUTE_NAME, TYPE_OR_VALUES).
-        '''
-        _, v = s.split(' ', 1)
+        """
+        _, v = s.split(" ", 1)
         v = v.strip()
 
         # Verify the general structure of declaration
         m = _RE_ATTRIBUTE.match(v)
         if not m:
@@ -747,45 +783,45 @@
 
         # Extracts the raw name and type
         name, type_ = m.groups()
 
         # Extracts the final name
-        name = str(name.strip('"\''))
+        name = str(name.strip("\"'"))
 
         # Extracts the final type
         if type_[:1] == "{" and type_[-1:] == "}":
             try:
-                type_ = _parse_values(type_.strip('{} '))
+                type_ = _parse_values(type_.strip("{} "))
             except Exception:
                 raise BadAttributeType()
             if isinstance(type_, dict):
                 raise BadAttributeType()
 
         else:
             # If not nominal, verify the type name
             type_ = str(type_).upper()
-            if type_ not in ['NUMERIC', 'REAL', 'INTEGER', 'STRING']:
+            if type_ not in ["NUMERIC", "REAL", "INTEGER", "STRING"]:
                 raise BadAttributeType()
 
         return (name, type_)
 
     def _decode(self, s, encode_nominal=False, matrix_type=DENSE):
-        '''Do the job the ``encode``.'''
+        """Do the job the ``encode``."""
 
         # Make sure this method is idempotent
         self._current_line = 0
 
         # If string, convert to a list of lines
         if isinstance(s, str):
-            s = s.strip('\r\n ').replace('\r\n', '\n').split('\n')
+            s = s.strip("\r\n ").replace("\r\n", "\n").split("\n")
 
         # Create the return object
         obj: ArffContainerType = {
-            'description': '',
-            'relation': '',
-            'attributes': [],
-            'data': []
+            "description": "",
+            "relation": "",
+            "attributes": [],
+            "data": [],
         }
         attribute_names = {}
 
         # Create the data helper object
         data = _get_data_object_for_decoding(matrix_type)
@@ -794,27 +830,28 @@
         STATE = _TK_DESCRIPTION
         s = iter(s)
         for row in s:
             self._current_line += 1
             # Ignore empty lines
-            row = row.strip(' \r\n')
-            if not row: continue
+            row = row.strip(" \r\n")
+            if not row:
+                continue
 
             u_row = row.upper()
 
             # DESCRIPTION -----------------------------------------------------
             if u_row.startswith(_TK_DESCRIPTION) and STATE == _TK_DESCRIPTION:
-                obj['description'] += self._decode_comment(row) + '\n'
+                obj["description"] += self._decode_comment(row) + "\n"
             # -----------------------------------------------------------------
 
             # RELATION --------------------------------------------------------
             elif u_row.startswith(_TK_RELATION):
                 if STATE != _TK_DESCRIPTION:
                     raise BadLayout()
 
                 STATE = _TK_RELATION
-                obj['relation'] = self._decode_relation(row)
+                obj["relation"] = self._decode_relation(row)
             # -----------------------------------------------------------------
 
             # ATTRIBUTE -------------------------------------------------------
             elif u_row.startswith(_TK_ATTRIBUTE):
                 if STATE != _TK_RELATION and STATE != _TK_ATTRIBUTE:
@@ -825,22 +862,24 @@
                 attr = self._decode_attribute(row)
                 if attr[0] in attribute_names:
                     raise BadAttributeName(attr[0], attribute_names[attr[0]])
                 else:
                     attribute_names[attr[0]] = self._current_line
-                obj['attributes'].append(attr)
+                obj["attributes"].append(attr)
 
                 if isinstance(attr[1], (list, tuple)):
                     if encode_nominal:
                         conversor = EncodedNominalConversor(attr[1])
                     else:
                         conversor = NominalConversor(attr[1])
                 else:
-                    CONVERSOR_MAP = {'STRING': str,
-                                     'INTEGER': lambda x: int(float(x)),
-                                     'NUMERIC': float,
-                                     'REAL': float}
+                    CONVERSOR_MAP = {
+                        "STRING": str,
+                        "INTEGER": lambda x: int(float(x)),
+                        "NUMERIC": float,
+                        "REAL": float,
+                    }
                     conversor = CONVERSOR_MAP[attr[1]]
 
                 self._conversors.append(conversor)
             # -----------------------------------------------------------------
 
@@ -867,18 +906,18 @@
                 # Ignore empty lines and comment lines.
                 if row and not row.startswith(_TK_COMMENT):
                     yield row
 
         # Alter the data object
-        obj['data'] = data.decode_rows(stream(), self._conversors)
-        if obj['description'].endswith('\n'):
-            obj['description'] = obj['description'][:-1]
+        obj["data"] = data.decode_rows(stream(), self._conversors)
+        if obj["description"].endswith("\n"):
+            obj["description"] = obj["description"][:-1]
 
         return obj
 
     def decode(self, s, encode_nominal=False, return_type=DENSE):
-        '''Returns the Python representation of a given ARFF file.
+        """Returns the Python representation of a given ARFF file.
 
         When a file object is passed as an argument, this method reads lines
         iteratively, avoiding to load unnecessary information to the memory.
 
         :param s: a string or file object with the ARFF file.
@@ -887,57 +926,58 @@
         :param return_type: determines the data structure used to store the
             dataset. Can be one of `arff.DENSE`, `arff.COO`, `arff.LOD`,
             `arff.DENSE_GEN` or `arff.LOD_GEN`.
             Consult the sections on `working with sparse data`_ and `loading
             progressively`_.
-        '''
+        """
         try:
-            return self._decode(s, encode_nominal=encode_nominal,
-                                matrix_type=return_type)
+            return self._decode(
+                s, encode_nominal=encode_nominal, matrix_type=return_type
+            )
         except ArffException as e:
             e.line = self._current_line
             raise e
 
 
 class ArffEncoder:
-    '''An ARFF encoder.'''
-
-    def _encode_comment(self, s=''):
-        '''(INTERNAL) Encodes a comment line.
+    """An ARFF encoder."""
+
+    def _encode_comment(self, s=""):
+        """(INTERNAL) Encodes a comment line.
 
         Comments are single line strings starting, obligatorily, with the ``%``
         character, and can have any symbol, including whitespaces or special
         characters.
 
         If ``s`` is None, this method will simply return an empty comment.
 
         :param s: (OPTIONAL) string.
         :return: a string with the encoded comment line.
-        '''
+        """
         if s:
-            return '%s %s'%(_TK_COMMENT, s)
+            return "%s %s" % (_TK_COMMENT, s)
         else:
-            return '%s' % _TK_COMMENT
+            return "%s" % _TK_COMMENT
 
     def _encode_relation(self, name):
-        '''(INTERNAL) Decodes a relation line.
+        """(INTERNAL) Decodes a relation line.
 
         The relation declaration is a line with the format ``@RELATION
         <relation-name>``, where ``relation-name`` is a string.
 
         :param name: a string.
         :return: a string with the encoded relation declaration.
-        '''
-        for char in ' %{},':
+        """
+        for char in " %{},":
             if char in name:
-                name = '"%s"'%name
+                name = '"%s"' % name
                 break
 
-        return '%s %s'%(_TK_RELATION, name)
+        return "%s %s" % (_TK_RELATION, name)
 
     def _encode_attribute(self, name, type_):
-        '''(INTERNAL) Encodes an attribute line.
+        """(INTERNAL) Encodes an attribute line.
 
         The attribute follow the template::
 
              @attribute <attribute-name> <datatype>
 
@@ -954,98 +994,104 @@
         the attribute type is nominal, ``type`` must be a list of values.
 
         :param name: a string.
         :param type_: a string or a list of string.
         :return: a string with the encoded attribute declaration.
-        '''
-        for char in ' %{},':
+        """
+        for char in " %{},":
             if char in name:
-                name = '"%s"'%name
+                name = '"%s"' % name
                 break
 
         if isinstance(type_, (tuple, list)):
-            type_tmp = ['%s' % encode_string(type_k) for type_k in type_]
-            type_ = '{%s}'%(', '.join(type_tmp))
-
-        return '%s %s %s'%(_TK_ATTRIBUTE, name, type_)
+            type_tmp = ["%s" % encode_string(type_k) for type_k in type_]
+            type_ = "{%s}" % (", ".join(type_tmp))
+
+        return "%s %s %s" % (_TK_ATTRIBUTE, name, type_)
 
     def encode(self, obj):
-        '''Encodes a given object to an ARFF file.
+        """Encodes a given object to an ARFF file.
 
         :param obj: the object containing the ARFF information.
         :return: the ARFF file as an string.
-        '''
+        """
         data = [row for row in self.iter_encode(obj)]
 
-        return '\n'.join(data)
+        return "\n".join(data)
 
     def iter_encode(self, obj):
-        '''The iterative version of `arff.ArffEncoder.encode`.
+        """The iterative version of `arff.ArffEncoder.encode`.
 
         This encodes iteratively a given object and return, one-by-one, the
         lines of the ARFF file.
 
         :param obj: the object containing the ARFF information.
         :return: (yields) the ARFF file as strings.
-        '''
+        """
         # DESCRIPTION
-        if obj.get('description', None):
-            for row in obj['description'].split('\n'):
+        if obj.get("description", None):
+            for row in obj["description"].split("\n"):
                 yield self._encode_comment(row)
 
         # RELATION
-        if not obj.get('relation'):
-            raise BadObject('Relation name not found or with invalid value.')
-
-        yield self._encode_relation(obj['relation'])
-        yield ''
+        if not obj.get("relation"):
+            raise BadObject("Relation name not found or with invalid value.")
+
+        yield self._encode_relation(obj["relation"])
+        yield ""
 
         # ATTRIBUTES
-        if not obj.get('attributes'):
-            raise BadObject('Attributes not found.')
+        if not obj.get("attributes"):
+            raise BadObject("Attributes not found.")
 
         attribute_names = set()
-        for attr in obj['attributes']:
+        for attr in obj["attributes"]:
             # Verify for bad object format
-            if not isinstance(attr, (tuple, list)) or \
-               len(attr) != 2 or \
-               not isinstance(attr[0], str):
-                raise BadObject('Invalid attribute declaration "%s"'%str(attr))
+            if (
+                not isinstance(attr, (tuple, list))
+                or len(attr) != 2
+                or not isinstance(attr[0], str)
+            ):
+                raise BadObject('Invalid attribute declaration "%s"' % str(attr))
 
             if isinstance(attr[1], str):
                 # Verify for invalid types
                 if attr[1] not in _SIMPLE_TYPES:
-                    raise BadObject('Invalid attribute type "%s"'%str(attr))
+                    raise BadObject('Invalid attribute type "%s"' % str(attr))
 
             # Verify for bad object format
             elif not isinstance(attr[1], (tuple, list)):
-                raise BadObject('Invalid attribute type "%s"'%str(attr))
+                raise BadObject('Invalid attribute type "%s"' % str(attr))
 
             # Verify attribute name is not used twice
             if attr[0] in attribute_names:
-                raise BadObject('Trying to use attribute name "%s" for the '
-                                'second time.' % str(attr[0]))
+                raise BadObject(
+                    'Trying to use attribute name "%s" for the '
+                    "second time." % str(attr[0])
+                )
             else:
                 attribute_names.add(attr[0])
 
             yield self._encode_attribute(attr[0], attr[1])
-        yield ''
-        attributes = obj['attributes']
+        yield ""
+        attributes = obj["attributes"]
 
         # DATA
         yield _TK_DATA
-        if 'data' in obj:
-            data = _get_data_object_for_encoding(obj.get('data'))
-            yield from data.encode_data(obj.get('data'), attributes)
-
-        yield ''
+        if "data" in obj:
+            data = _get_data_object_for_encoding(obj.get("data"))
+            yield from data.encode_data(obj.get("data"), attributes)
+
+        yield ""
+
 
 # =============================================================================
+
 
 # BASIC INTERFACE =============================================================
 def load(fp, encode_nominal=False, return_type=DENSE):
-    '''Load a file-like object containing the ARFF document and convert it into
+    """Load a file-like object containing the ARFF document and convert it into
     a Python object.
 
     :param fp: a file-like object.
     :param encode_nominal: boolean, if True perform a label encoding
         while reading the .arff file.
@@ -1053,17 +1099,17 @@
         dataset. Can be one of `arff.DENSE`, `arff.COO`, `arff.LOD`,
         `arff.DENSE_GEN` or `arff.LOD_GEN`.
         Consult the sections on `working with sparse data`_ and `loading
         progressively`_.
     :return: a dictionary.
-     '''
+    """
     decoder = ArffDecoder()
-    return decoder.decode(fp, encode_nominal=encode_nominal,
-                          return_type=return_type)
+    return decoder.decode(fp, encode_nominal=encode_nominal, return_type=return_type)
+
 
 def loads(s, encode_nominal=False, return_type=DENSE):
-    '''Convert a string instance containing the ARFF document into a Python
+    """Convert a string instance containing the ARFF document into a Python
     object.
 
     :param s: a string object.
     :param encode_nominal: boolean, if True perform a label encoding
         while reading the .arff file.
@@ -1071,37 +1117,40 @@
         dataset. Can be one of `arff.DENSE`, `arff.COO`, `arff.LOD`,
         `arff.DENSE_GEN` or `arff.LOD_GEN`.
         Consult the sections on `working with sparse data`_ and `loading
         progressively`_.
     :return: a dictionary.
-    '''
+    """
     decoder = ArffDecoder()
-    return decoder.decode(s, encode_nominal=encode_nominal,
-                          return_type=return_type)
+    return decoder.decode(s, encode_nominal=encode_nominal, return_type=return_type)
+
 
 def dump(obj, fp):
-    '''Serialize an object representing the ARFF document to a given file-like
+    """Serialize an object representing the ARFF document to a given file-like
     object.
 
     :param obj: a dictionary.
     :param fp: a file-like object.
-    '''
+    """
     encoder = ArffEncoder()
     generator = encoder.iter_encode(obj)
 
     last_row = next(generator)
     for row in generator:
-        fp.write(last_row + '\n')
+        fp.write(last_row + "\n")
         last_row = row
     fp.write(last_row)
 
     return fp
 
+
 def dumps(obj):
-    '''Serialize an object representing the ARFF document, returning a string.
+    """Serialize an object representing the ARFF document, returning a string.
 
     :param obj: a dictionary.
     :return: a string with the ARFF document.
-    '''
+    """
     encoder = ArffEncoder()
     return encoder.encode(obj)
+
+
 # =============================================================================
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/externals/_arff.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/ensemble/tests/test_forest.py	2025-03-24 12:03:52.940617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/ensemble/tests/test_forest.py	2025-03-24 12:04:07.795107+00:00
@@ -166,16 +166,15 @@
     ForestRegressor = FOREST_REGRESSORS[name]
 
     reg = ForestRegressor(n_estimators=5, criterion=criterion, random_state=1)
     reg.fit(X_reg, y_reg)
     score = reg.score(X_reg, y_reg)
-    assert score > 0.93, (
-        "Failed with max_features=None, criterion %s and score = %f"
-        % (
-            criterion,
-            score,
-        )
+    assert (
+        score > 0.93
+    ), "Failed with max_features=None, criterion %s and score = %f" % (
+        criterion,
+        score,
     )
 
     reg = ForestRegressor(
         n_estimators=5, criterion=criterion, max_features=6, random_state=1
     )
@@ -1067,14 +1066,14 @@
         est.fit(X, y, sample_weight=weights)
         out = est.estimators_[0].tree_.apply(X)
         node_weights = np.bincount(out, weights=weights)
         # drop inner nodes
         leaf_weights = node_weights[node_weights != 0]
-        assert np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf, (
-            "Failed with {0} min_weight_fraction_leaf={1}".format(
-                name, est.min_weight_fraction_leaf
-            )
+        assert (
+            np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf
+        ), "Failed with {0} min_weight_fraction_leaf={1}".format(
+            name, est.min_weight_fraction_leaf
         )
 
 
 @pytest.mark.parametrize("name", FOREST_ESTIMATORS)
 @pytest.mark.parametrize(
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/ensemble/tests/test_forest.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py	2025-03-24 12:03:52.947617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py	2025-03-24 12:04:09.484740+00:00
@@ -1184,25 +1184,25 @@
         pd_line_kw=pd_line_kw,
         ice_lines_kw=ice_lines_kw,
     )
 
     line = disp.lines_[0, 0, -1]
-    assert line.get_color() == expected_colors[0], (
-        f"{line.get_color()}!={expected_colors[0]}\n{line_kw} and {pd_line_kw}"
-    )
+    assert (
+        line.get_color() == expected_colors[0]
+    ), f"{line.get_color()}!={expected_colors[0]}\n{line_kw} and {pd_line_kw}"
     if pd_line_kw is not None:
         if "linestyle" in pd_line_kw:
             assert line.get_linestyle() == pd_line_kw["linestyle"]
         elif "ls" in pd_line_kw:
             assert line.get_linestyle() == pd_line_kw["ls"]
     else:
         assert line.get_linestyle() == "--"
 
     line = disp.lines_[0, 0, 0]
-    assert line.get_color() == expected_colors[1], (
-        f"{line.get_color()}!={expected_colors[1]}"
-    )
+    assert (
+        line.get_color() == expected_colors[1]
+    ), f"{line.get_color()}!={expected_colors[1]}"
     if ice_lines_kw is not None:
         if "linestyle" in ice_lines_kw:
             assert line.get_linestyle() == ice_lines_kw["linestyle"]
         elif "ls" in ice_lines_kw:
             assert line.get_linestyle() == ice_lines_kw["ls"]
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_linear_loss.py	2025-03-24 12:03:52.950617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_linear_loss.py	2025-03-24 12:04:09.635155+00:00
@@ -535,13 +535,13 @@
 
             if l2_reg_strength > 0:
                 # The L2 penalty enters the Hessian on the diagonal only. To add those
                 # terms, we use a flattened view of the array.
                 order = "C" if hess.flags.c_contiguous else "F"
-                hess.reshape(-1, order=order)[: (n_features * n_dof) : (n_dof + 1)] += (
-                    l2_reg_strength
-                )
+                hess.reshape(-1, order=order)[
+                    : (n_features * n_dof) : (n_dof + 1)
+                ] += l2_reg_strength
 
             if self.fit_intercept:
                 # With intercept included as added column to X, the hessian becomes
                 # hess = (X, 1)' @ diag(h) @ (X, 1)
                 #      = (X' @ diag(h) @ X, X' @ h)
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_linear_loss.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/manifold/_t_sne.py	2025-03-24 12:03:52.955617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/manifold/_t_sne.py	2025-03-24 12:04:11.231932+00:00
@@ -962,13 +962,13 @@
 
             # compute the joint probability distribution for the input space
             P = _joint_probabilities(distances, self.perplexity, self.verbose)
             assert np.all(np.isfinite(P)), "All probabilities should be finite"
             assert np.all(P >= 0), "All probabilities should be non-negative"
-            assert np.all(P <= 1), (
-                "All probabilities should be less or then equal to one"
-            )
+            assert np.all(
+                P <= 1
+            ), "All probabilities should be less or then equal to one"
 
         else:
             # Compute the number of nearest neighbors to find.
             # LvdM uses 3 * perplexity as the number of neighbors.
             # In the event that we have very small # of points
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/manifold/_t_sne.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/tests/test_ridge.py	2025-03-24 12:03:52.953617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/tests/test_ridge.py	2025-03-24 12:04:12.307061+00:00
@@ -857,13 +857,13 @@
     gcv_ridge = RidgeCV(fit_intercept=True, alphas=alphas, scoring=scoring)
 
     loo_ridge.fit(X, y)
     gcv_ridge.fit(X, y)
 
-    assert gcv_ridge.alpha_ == pytest.approx(loo_ridge.alpha_), (
-        f"{gcv_ridge.alpha_=}, {loo_ridge.alpha_=}"
-    )
+    assert gcv_ridge.alpha_ == pytest.approx(
+        loo_ridge.alpha_
+    ), f"{gcv_ridge.alpha_=}, {loo_ridge.alpha_=}"
     assert_allclose(gcv_ridge.coef_, loo_ridge.coef_, rtol=1e-3)
     assert_allclose(gcv_ridge.intercept_, loo_ridge.intercept_, rtol=1e-3)
 
 
 @pytest.mark.parametrize("gcv_mode", ["svd", "eigen"])
@@ -1517,13 +1517,13 @@
     else:
         y = rng.randint(0, 2, n_samples)
     X = rng.randn(n_samples, n_features)
 
     ridge_est = Estimator(alphas=alphas)
-    assert ridge_est.alphas is alphas, (
-        f"`alphas` was mutated in `{Estimator.__name__}.__init__`"
-    )
+    assert (
+        ridge_est.alphas is alphas
+    ), f"`alphas` was mutated in `{Estimator.__name__}.__init__`"
 
     ridge_est.fit(X, y)
     assert_array_equal(ridge_est.alphas, np.asarray(alphas))
 
 
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/tests/test_ridge.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_common.py	2025-03-24 12:03:52.960617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_common.py	2025-03-24 12:04:13.467445+00:00
@@ -1001,12 +1001,11 @@
 
 
 @pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values())
 @pytest.mark.parametrize(
     "y_true, y_score",
-    invalids_nan_inf
-    +
+    invalids_nan_inf +
     # Add an additional case for classification only
     # non-regression test for:
     # https://github.com/scikit-learn/scikit-learn/issues/6809
     [
         ([np.nan, 1, 2], [1, 2, 3]),
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_common.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_pairwise_distances_reduction.py	2025-03-24 12:03:52.961617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_pairwise_distances_reduction.py	2025-03-24 12:04:13.635919+00:00
@@ -226,13 +226,13 @@
     # Find a non-trivial radius using a small subsample of the pairwise
     # distances between X and Y: we want to return around expected_n_neighbors
     # on average. Yielding too many results would make the test slow (because
     # checking the results is expensive for large result sets), yielding 0 most
     # of the time would make the test useless.
-    assert precomputed_dists is not None or metric is not None, (
-        "Either metric or precomputed_dists must be provided."
-    )
+    assert (
+        precomputed_dists is not None or metric is not None
+    ), "Either metric or precomputed_dists must be provided."
 
     if precomputed_dists is None:
         assert X is not None
         assert Y is not None
         sampled_dists = pairwise_distances(X, Y, metric=metric, **metric_kwargs)
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/metrics/tests/test_pairwise_distances_reduction.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_split.py	2025-03-24 12:03:52.965617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_split.py	2025-03-24 12:04:16.074445+00:00
@@ -883,13 +883,13 @@
         # per index is close enough to a binomial
         threshold = 0.05 / n_splits
         bf = stats.binom(n_splits, p)
         for count in idx_counts:
             prob = bf.pmf(count)
-            assert prob > threshold, (
-                "An index is not drawn with chance corresponding to even draws"
-            )
+            assert (
+                prob > threshold
+            ), "An index is not drawn with chance corresponding to even draws"
 
     for n_samples in (6, 22):
         groups = np.array((n_samples // 2) * [0, 1])
         splits = StratifiedShuffleSplit(
             n_splits=n_splits, test_size=1.0 / n_folds, random_state=0
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_split.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_search.py	2025-03-24 12:03:52.964617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_search.py	2025-03-24 12:04:16.499634+00:00
@@ -2417,13 +2417,13 @@
     attr_message = "BaseSearchCV _pairwise property must match estimator"
 
     for _pairwise_setting in [True, False]:
         est.set_params(pairwise=_pairwise_setting)
         cv = GridSearchCV(est, {"n_neighbors": [10]})
-        assert _pairwise_setting == cv.__sklearn_tags__().input_tags.pairwise, (
-            attr_message
-        )
+        assert (
+            _pairwise_setting == cv.__sklearn_tags__().input_tags.pairwise
+        ), attr_message
 
 
 def test_search_cv_pairwise_property_equivalence_of_precomputed():
     """
     Test implementation of BaseSearchCV has the pairwise tag
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/model_selection/tests/test_search.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/preprocessing/tests/test_function_transformer.py	2025-03-24 12:03:52.972617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/preprocessing/tests/test_function_transformer.py	2025-03-24 12:04:17.256757+00:00
@@ -34,17 +34,17 @@
         X,
         "transform should have returned X unchanged",
     )
 
     # The function should only have received X.
-    assert args_store == [X], (
-        "Incorrect positional arguments passed to func: {args}".format(args=args_store)
-    )
-
-    assert not kwargs_store, (
-        "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
-    )
+    assert args_store == [
+        X
+    ], "Incorrect positional arguments passed to func: {args}".format(args=args_store)
+
+    assert (
+        not kwargs_store
+    ), "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
 
     # reset the argument stores.
     args_store[:] = []
     kwargs_store.clear()
     transformed = FunctionTransformer(
@@ -54,17 +54,17 @@
     assert_array_equal(
         transformed, X, err_msg="transform should have returned X unchanged"
     )
 
     # The function should have received X
-    assert args_store == [X], (
-        "Incorrect positional arguments passed to func: {args}".format(args=args_store)
-    )
-
-    assert not kwargs_store, (
-        "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
-    )
+    assert args_store == [
+        X
+    ], "Incorrect positional arguments passed to func: {args}".format(args=args_store)
+
+    assert (
+        not kwargs_store
+    ), "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
 
 
 def test_np_log():
     X = np.arange(10).reshape((5, 2))
 
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/preprocessing/tests/test_function_transformer.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/metadata_routing_common.py	2025-03-24 12:03:52.975617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/metadata_routing_common.py	2025-03-24 12:04:17.965981+00:00
@@ -72,26 +72,26 @@
         getattr(obj, "_records", dict()).get(method, dict()).get(parent, list())
     )
     for record in all_records:
         # first check that the names of the metadata passed are the same as
         # expected. The names are stored as keys in `record`.
-        assert set(kwargs.keys()) == set(record.keys()), (
-            f"Expected {kwargs.keys()} vs {record.keys()}"
-        )
+        assert set(kwargs.keys()) == set(
+            record.keys()
+        ), f"Expected {kwargs.keys()} vs {record.keys()}"
         for key, value in kwargs.items():
             recorded_value = record[key]
             # The following condition is used to check for any specified parameters
             # being a subset of the original values
             if key in split_params and recorded_value is not None:
                 assert np.isin(recorded_value, value).all()
             else:
                 if isinstance(recorded_value, np.ndarray):
                     assert_array_equal(recorded_value, value)
                 else:
-                    assert recorded_value is value, (
-                        f"Expected {recorded_value} vs {value}. Method: {method}"
-                    )
+                    assert (
+                        recorded_value is value
+                    ), f"Expected {recorded_value} vs {value}. Method: {method}"
 
 
 record_metadata_not_default = partial(record_metadata, record_default=False)
 
 
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/metadata_routing_common.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_metaestimators.py	2025-03-24 12:03:52.977617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_metaestimators.py	2025-03-24 12:04:18.439122+00:00
@@ -155,16 +155,15 @@
         delegator = delegator_data.construct(delegate)
         for method in methods:
             if method in delegator_data.skip_methods:
                 continue
             assert hasattr(delegate, method)
-            assert hasattr(delegator, method), (
-                "%s does not have method %r when its delegate does"
-                % (
-                    delegator_data.name,
-                    method,
-                )
+            assert hasattr(
+                delegator, method
+            ), "%s does not have method %r when its delegate does" % (
+                delegator_data.name,
+                method,
             )
             # delegation before fit raises a NotFittedError
             if method == "score":
                 with pytest.raises(NotFittedError):
                     getattr(delegator, method)(
@@ -190,16 +189,15 @@
             if method in delegator_data.skip_methods:
                 continue
             delegate = SubEstimator(hidden_method=method)
             delegator = delegator_data.construct(delegate)
             assert not hasattr(delegate, method)
-            assert not hasattr(delegator, method), (
-                "%s has method %r when its delegate does not"
-                % (
-                    delegator_data.name,
-                    method,
-                )
+            assert not hasattr(
+                delegator, method
+            ), "%s has method %r when its delegate does not" % (
+                delegator_data.name,
+                method,
             )
 
 
 def _get_instance_with_pipeline(meta_estimator, init_params):
     """Given a single meta-estimator instance, generate an instance with a pipeline"""
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_metaestimators.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_discriminant_analysis.py	2025-03-24 12:03:52.976617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_discriminant_analysis.py	2025-03-24 12:04:18.572824+00:00
@@ -303,20 +303,20 @@
     y = state.randint(0, 3, size=(40,))
 
     clf_lda_eigen = LinearDiscriminantAnalysis(solver="eigen")
     clf_lda_eigen.fit(X, y)
     assert_almost_equal(clf_lda_eigen.explained_variance_ratio_.sum(), 1.0, 3)
-    assert clf_lda_eigen.explained_variance_ratio_.shape == (2,), (
-        "Unexpected length for explained_variance_ratio_"
-    )
+    assert clf_lda_eigen.explained_variance_ratio_.shape == (
+        2,
+    ), "Unexpected length for explained_variance_ratio_"
 
     clf_lda_svd = LinearDiscriminantAnalysis(solver="svd")
     clf_lda_svd.fit(X, y)
     assert_almost_equal(clf_lda_svd.explained_variance_ratio_.sum(), 1.0, 3)
-    assert clf_lda_svd.explained_variance_ratio_.shape == (2,), (
-        "Unexpected length for explained_variance_ratio_"
-    )
+    assert clf_lda_svd.explained_variance_ratio_.shape == (
+        2,
+    ), "Unexpected length for explained_variance_ratio_"
 
     assert_array_almost_equal(
         clf_lda_svd.explained_variance_ratio_, clf_lda_eigen.explained_variance_ratio_
     )
 
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/tests/test_discriminant_analysis.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_monotonic_tree.py	2025-03-24 12:03:52.979617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_monotonic_tree.py	2025-03-24 12:04:19.322092+00:00
@@ -78,13 +78,13 @@
     if sparse_splitter:
         X_train = csc_container(X_train)
     est.fit(X_train, y_train)
     proba_test = est.predict_proba(X_test)
 
-    assert np.logical_and(proba_test >= 0.0, proba_test <= 1.0).all(), (
-        "Probability should always be in [0, 1] range."
-    )
+    assert np.logical_and(
+        proba_test >= 0.0, proba_test <= 1.0
+    ).all(), "Probability should always be in [0, 1] range."
     assert_allclose(proba_test.sum(axis=1), 1.0)
 
     # Monotonic increase constraint, it applies to the positive class
     assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= proba_test[:, 1])
     assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= proba_test[:, 1])
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_monotonic_tree.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_multiclass.py	2025-03-24 12:03:52.987617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_multiclass.py	2025-03-24 12:04:21.564223+00:00
@@ -364,21 +364,21 @@
                         + DOK_CONTAINERS
                         + LIL_CONTAINERS
                     )
                 ]
                 for exmpl_sparse in examples_sparse:
-                    assert sparse_exp == is_multilabel(exmpl_sparse), (
-                        f"is_multilabel({exmpl_sparse!r}) should be {sparse_exp}"
-                    )
+                    assert sparse_exp == is_multilabel(
+                        exmpl_sparse
+                    ), f"is_multilabel({exmpl_sparse!r}) should be {sparse_exp}"
 
             # Densify sparse examples before testing
             if issparse(example):
                 example = example.toarray()
 
-            assert dense_exp == is_multilabel(example), (
-                f"is_multilabel({example!r}) should be {dense_exp}"
-            )
+            assert dense_exp == is_multilabel(
+                example
+            ), f"is_multilabel({example!r}) should be {dense_exp}"
 
 
 @pytest.mark.parametrize(
     "array_namespace, device, dtype_name",
     yield_namespace_device_dtype_combinations(),
@@ -394,13 +394,13 @@
             else:
                 example = np.asarray(example)
             example = xp.asarray(example, device=device)
 
             with config_context(array_api_dispatch=True):
-                assert dense_exp == is_multilabel(example), (
-                    f"is_multilabel({example!r}) should be {dense_exp}"
-                )
+                assert dense_exp == is_multilabel(
+                    example
+                ), f"is_multilabel({example!r}) should be {dense_exp}"
 
 
 def test_check_classification_targets():
     for y_type in EXAMPLES.keys():
         if y_type in ["unknown", "continuous", "continuous-multioutput"]:
@@ -414,17 +414,16 @@
 
 
 def test_type_of_target():
     for group, group_examples in EXAMPLES.items():
         for example in group_examples:
-            assert type_of_target(example) == group, (
-                "type_of_target(%r) should be %r, got %r"
-                % (
-                    example,
-                    group,
-                    type_of_target(example),
-                )
+            assert (
+                type_of_target(example) == group
+            ), "type_of_target(%r) should be %r, got %r" % (
+                example,
+                group,
+                type_of_target(example),
             )
 
     for example in NON_ARRAY_LIKE_EXAMPLES:
         msg_regex = r"Expected array-like \(array or non-string sequence\).*"
         with pytest.raises(ValueError, match=msg_regex):
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_multiclass.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_tree.py	2025-03-24 12:03:52.980617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_tree.py	2025-03-24 12:04:21.831342+00:00
@@ -196,14 +196,14 @@
     "zeros": {"X": np.zeros((20, 3)), "y": y_random},
 }
 
 
 def assert_tree_equal(d, s, message):
-    assert s.node_count == d.node_count, (
-        "{0}: inequal number of node ({1} != {2})".format(
-            message, s.node_count, d.node_count
-        )
+    assert (
+        s.node_count == d.node_count
+    ), "{0}: inequal number of node ({1} != {2})".format(
+        message, s.node_count, d.node_count
     )
 
     assert_array_equal(
         d.children_right, s.children_right, message + ": inequal children_right"
     )
@@ -328,13 +328,13 @@
     # check consistency of overfitted trees on the diabetes dataset
     # since the trees will overfit, we expect an MSE of 0
     reg = Tree(criterion=criterion, random_state=0)
     reg.fit(diabetes.data, diabetes.target)
     score = mean_squared_error(diabetes.target, reg.predict(diabetes.data))
-    assert score == pytest.approx(0), (
-        f"Failed with {name}, criterion = {criterion} and score = {score}"
-    )
+    assert score == pytest.approx(
+        0
+    ), f"Failed with {name}, criterion = {criterion} and score = {score}"
 
 
 @skip_if_32bit
 @pytest.mark.parametrize("name, Tree", REG_TREES.items())
 @pytest.mark.parametrize(
@@ -695,14 +695,14 @@
             out = est.tree_.apply(X)
 
         node_weights = np.bincount(out, weights=weights)
         # drop inner nodes
         leaf_weights = node_weights[node_weights != 0]
-        assert np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf, (
-            "Failed with {0} min_weight_fraction_leaf={1}".format(
-                name, est.min_weight_fraction_leaf
-            )
+        assert (
+            np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf
+        ), "Failed with {0} min_weight_fraction_leaf={1}".format(
+            name, est.min_weight_fraction_leaf
         )
 
     # test case with no weights passed in
     total_weight = X.shape[0]
 
@@ -718,14 +718,14 @@
             out = est.tree_.apply(X)
 
         node_weights = np.bincount(out)
         # drop inner nodes
         leaf_weights = node_weights[node_weights != 0]
-        assert np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf, (
-            "Failed with {0} min_weight_fraction_leaf={1}".format(
-                name, est.min_weight_fraction_leaf
-            )
+        assert (
+            np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf
+        ), "Failed with {0} min_weight_fraction_leaf={1}".format(
+            name, est.min_weight_fraction_leaf
         )
 
 
 @pytest.mark.parametrize("name", ALL_TREES)
 def test_min_weight_fraction_leaf_on_dense_input(name):
@@ -843,14 +843,14 @@
             (est1, 1e-7),
             (est2, 0.05),
             (est3, 0.0001),
             (est4, 0.1),
         ):
-            assert est.min_impurity_decrease <= expected_decrease, (
-                "Failed, min_impurity_decrease = {0} > {1}".format(
-                    est.min_impurity_decrease, expected_decrease
-                )
+            assert (
+                est.min_impurity_decrease <= expected_decrease
+            ), "Failed, min_impurity_decrease = {0} > {1}".format(
+                est.min_impurity_decrease, expected_decrease
             )
             est.fit(X, y)
             for node in range(est.tree_.node_count):
                 # If current node is a not leaf node, check if the split was
                 # justified w.r.t the min_impurity_decrease
@@ -877,14 +877,14 @@
 
                     actual_decrease = fractional_node_weight * (
                         imp_parent - wtd_avg_left_right_imp
                     )
 
-                    assert actual_decrease >= expected_decrease, (
-                        "Failed with {0} expected min_impurity_decrease={1}".format(
-                            actual_decrease, expected_decrease
-                        )
+                    assert (
+                        actual_decrease >= expected_decrease
+                    ), "Failed with {0} expected min_impurity_decrease={1}".format(
+                        actual_decrease, expected_decrease
                     )
 
 
 def test_pickle():
     """Test pickling preserves Tree properties and performance."""
@@ -921,13 +921,13 @@
         serialized_object = pickle.dumps(est)
         est2 = pickle.loads(serialized_object)
         assert type(est2) == est.__class__
 
         score2 = est2.score(X, y)
-        assert score == score2, (
-            "Failed to generate same score  after pickling with {0}".format(name)
-        )
+        assert (
+            score == score2
+        ), "Failed to generate same score  after pickling with {0}".format(name)
         for attribute in fitted_attribute:
             assert_array_equal(
                 getattr(est2.tree_, attribute),
                 fitted_attribute[attribute],
                 err_msg=(
@@ -2612,13 +2612,13 @@
     tree = Tree(random_state=global_random_seed)
 
     # Check that the tree can learn the predictive feature
     # over an average of cross-validation fits.
     tree_cv_score = cross_val_score(tree, X, y, cv=5).mean()
-    assert tree_cv_score >= expected_score, (
-        f"Expected CV score: {expected_score} but got {tree_cv_score}"
-    )
+    assert (
+        tree_cv_score >= expected_score
+    ), f"Expected CV score: {expected_score} but got {tree_cv_score}"
 
 
 @pytest.mark.parametrize(
     "make_data, Tree",
     [
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/tree/tests/test_tree.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/validation.py	2025-03-24 12:03:52.988617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/validation.py	2025-03-24 12:04:22.701356+00:00
@@ -1545,11 +1545,12 @@
     return (
         # This is used during test collection in common tests. The
         # hasattr(estimator, "fit") makes it so that we don't fail for an estimator
         # that does not have a `fit` method during collection of checks. The right
         # checks will fail later.
-        hasattr(estimator, "fit") and parameter in signature(estimator.fit).parameters
+        hasattr(estimator, "fit")
+        and parameter in signature(estimator.fit).parameters
     )
 
 
 def check_symmetric(array, *, tol=1e-10, raise_warning=True, raise_exception=False):
     """Make sure that array is 2D, square and symmetric.
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/validation.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/estimator_checks.py	2025-03-24 12:03:52.984617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/estimator_checks.py	2025-03-24 12:04:23.304065+00:00
@@ -4767,13 +4767,13 @@
     if isinstance(X_transform, tuple):
         n_features_out = X_transform[0].shape[1]
     else:
         n_features_out = X_transform.shape[1]
 
-    assert len(feature_names_out) == n_features_out, (
-        f"Expected {n_features_out} feature names, got {len(feature_names_out)}"
-    )
+    assert (
+        len(feature_names_out) == n_features_out
+    ), f"Expected {n_features_out} feature names, got {len(feature_names_out)}"
 
 
 def check_transformer_get_feature_names_out_pandas(name, transformer_orig):
     try:
         import pandas as pd
@@ -4824,13 +4824,13 @@
     if isinstance(X_transform, tuple):
         n_features_out = X_transform[0].shape[1]
     else:
         n_features_out = X_transform.shape[1]
 
-    assert len(feature_names_out_default) == n_features_out, (
-        f"Expected {n_features_out} feature names, got {len(feature_names_out_default)}"
-    )
+    assert (
+        len(feature_names_out_default) == n_features_out
+    ), f"Expected {n_features_out} feature names, got {len(feature_names_out_default)}"
 
 
 def check_param_validation(name, estimator_orig):
     # Check that an informative error is raised when the value of a constructor
     # parameter does not have an appropriate type or value.
@@ -5337,11 +5337,13 @@
         if y_type != 'binary':
             raise ValueError(
                 'Only binary classification is supported. The type of the target '
                 f'is {{y_type}}.'
         )
-    """.format(name=name)
+    """.format(
+        name=name
+    )
     err_msg = textwrap.dedent(err_msg)
 
     with raises(
         ValueError, match="Only binary classification is supported.", err_msg=err_msg
     ):
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/estimator_checks.py
--- /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_validation.py	2025-03-24 12:03:52.988617+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_validation.py	2025-03-24 12:04:23.444773+00:00
@@ -847,13 +847,13 @@
     class TestClassWithDeprecatedFitMethod:
         @deprecated("Deprecated for the purpose of testing has_fit_parameter")
         def fit(self, X, y, sample_weight=None):
             pass
 
-    assert has_fit_parameter(TestClassWithDeprecatedFitMethod, "sample_weight"), (
-        "has_fit_parameter fails for class with deprecated fit method."
-    )
+    assert has_fit_parameter(
+        TestClassWithDeprecatedFitMethod, "sample_weight"
+    ), "has_fit_parameter fails for class with deprecated fit method."
 
 
 def test_check_symmetric():
     arr_sym = np.array([[0, 1], [1, 2]])
     arr_bad = np.ones(2)
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/utils/tests/test_validation.py

Oh no! 💥 💔 💥
26 files would be reformatted, 904 files would be left unchanged.

Generated for commit: 3e6767b. Link to the linter CI: here

@DimitriPapadopoulos DimitriPapadopoulos force-pushed the black_ruff_format branch 2 times, most recently from 40e80b0 to a04d7be Compare March 18, 2025 15:37
@DimitriPapadopoulos DimitriPapadopoulos force-pushed the black_ruff_format branch 2 times, most recently from b87f48e to 63e9224 Compare March 18, 2025 17:18
@DimitriPapadopoulos DimitriPapadopoulos marked this pull request as ready for review March 18, 2025 17:25
@DimitriPapadopoulos DimitriPapadopoulos force-pushed the black_ruff_format branch 4 times, most recently from e2d9310 to f6d2b41 Compare March 20, 2025 08:06
Copy link
Member

@adrinjalali adrinjalali left a comment

I'm mostly okay with this. We need to check what the difference is now between the --preview mode of black and ruff, especially on long string lines, which black was mostly handling nicely.

cc @scikit-learn/core-devs

This reduces our developer dependencies by one, which is nice.

@DimitriPapadopoulos
Copy link
Contributor Author

Rebased to solve conflicts. Please consider merging #30976 first, to update ruff, so that we start from a recent and more stable version of the formatter.

Copy link
Member

@adrinjalali adrinjalali left a comment

Otherwise LGTM.

@@ -10,26 +10,25 @@ set -o pipefail

global_status=0

echo -e "### Running black ###\n"
black --check --diff .
echo -e "### Running ruff check ###\n"
Copy link
Member

These messages must be identical to what's in get_comment.py for it to find the right section.

Copy link
Contributor Author

linting.sh and get_comment.py should be in sync now.
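
For reference, a minimal sketch of what the reworked block in linting.sh could look like once black is gone — the heading text here is an assumption; the only hard requirement, as noted above, is that it match what get_comment.py searches for:

    echo -e "### Running ruff format ###\n"
    ruff format --diff .
    status=$?
    if [[ $status -ne 0 ]]; then
        # remember the failure but keep running the remaining linters
        global_status=1
    fi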

@DimitriPapadopoulos
Copy link
Contributor Author

  • Rebased to solve conflicts and take into account recent ruff 0.11 from MNT Update ruff #30976.
  • Pinned ruff to latest 0.11.2.
  • Rerun ruff format with ruff 0.11.2.
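
For illustration, the pin in a pre-commit setup might look like the sketch below (hypothetical location — the repository may instead pin ruff in its lint requirements):

    - repo: https://github.com/astral-sh/ruff-pre-commit
      rev: v0.11.2
      hooks:
        - id: ruff          # linter first ...
        - id: ruff-format   # ... then the formatter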

@DimitriPapadopoulos DimitriPapadopoulos force-pushed the black_ruff_format branch 3 times, most recently from 907d6d5 to fd4b36f Compare March 24, 2025 16:34
@DimitriPapadopoulos
Copy link
Contributor Author

DimitriPapadopoulos commented Mar 24, 2025

The linter / lint error is because CI runs the previous version of build_tools/shared.sh:

  curl https://raw.githubusercontent.com/scikit-learn/scikit-learn/main/build_tools/shared.sh --retry 5 -o ./build_tools/shared.sh
  source build_tools/shared.sh

- name: Install dependencies
  run: |
    curl https://raw.githubusercontent.com/${{ github.repository }}/main/build_tools/shared.sh --retry 5 -o ./build_tools/shared.sh
    source build_tools/shared.sh

@DimitriPapadopoulos
Copy link
Contributor Author

Not sure about the linter / comment error:

Downloading single artifact
Error: Unable to download artifact(s): Artifact not found for name: lint-log
        Please ensure that your artifact is not expired and the artifact was uploaded using a compatible version of toolkit/upload-artifact.
        For more information, visit the GitHub Artifacts FAQ: https://github.com/actions/toolkit/blob/main/packages/artifact/docs/faq.md

@DimitriPapadopoulos
Copy link
Contributor Author

I think the linter / comment error is a result of the linter / lint error and the lack of a /tmp/versions.txt file.

- name: Upload Artifact
  if: always()
  uses: actions/upload-artifact@v4
  with:
    name: lint-log
    path: |
      /tmp/linting_output.txt
      /tmp/versions.txt
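
If the upload can legitimately end up empty, a hedged sketch of one mitigation — not necessarily the right fix here, since the root cause is main lagging behind — is to let the comment job tolerate the missing artifact:

    - name: Download artifact
      continue-on-error: true  # don't fail the job if lint-log was never uploaded
      uses: actions/download-artifact@v4
      with:
        name: lint-log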

Copy link
Member

@adrinjalali adrinjalali left a comment

There's one "comma" issue left, and I think you've forgotten to push your fixes for the other ones?

Otherwise, I'm happy to merge this and see if it works on main since that's where the bot gets its data.

@DimitriPapadopoulos
Copy link
Contributor Author

Rebased and fixed the spurious commas.

@adrinjalali
Copy link
Member

Please avoid force pushing. It makes reviews much harder since I can't check what's changed. We don't care about number of commits in a PR. Everything's squashed and merged anyway.

@thomasjpfan
Copy link
Member

We need to check what the difference is now between the --preview mode of black and ruff, especially on long string lines, which black was mostly handling nicely.

Although long string handling is nice with black --preview, I feel like it will never stabilize (psf/black#4208). So I'm okay with just switching to ruff without needing black's string handling.

@DimitriPapadopoulos
Copy link
Contributor Author

DimitriPapadopoulos commented Mar 25, 2025

I'm not sure which part of black --preview you're interested in. From Preview style / Improved string processing:

Black will split long string literals and merge short ones.

It's true that ruff format avoids splitting long strings. The feature is tracked in astral-sh/ruff#6936.
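
A contrived illustration of the difference (not taken from this PR's diff):

    # black --preview splits an overlong literal into implicitly
    # concatenated fragments:
    message = (
        "this diagnostic message is far too long for the 88-character limit"
        " and black's preview style breaks it across lines"
    )
    # ruff format leaves the literal in one piece and only reflows the
    # surrounding code, so the line may stay overlong.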

@thomasjpfan
Copy link
Member

Yup, I'm referring to the "Improve string processing" in black. For this PR, if it means removing another tool, I'm happy with not having it.

@adrinjalali
Copy link
Member

Sorry I didn't get to merge this in time, @DimitriPapadopoulos would you mind fixing the merge conflicts?

@DimitriPapadopoulos
Copy link
Contributor Author

I will:

  • delete the last commit (ruff format),
  • rebase,
  • rerun ruff format and commit the result.
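
In git terms, the plan above might look roughly like this (remote name and commit message are assumptions):

    git reset --hard HEAD~1        # drop the previous ruff format commit
    git fetch upstream
    git rebase upstream/main
    ruff format .                  # rerun the formatter on the rebased tree
    git add -u
    git commit -m "ruff format"
    git push --force-with-lease    # history was rewritten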

Additionally, although it doesn't matter here, run the linter first,
then the formatter, as suggested by the documentation:
	Ruff's formatter is designed to be used alongside the linter.
	However, the linter includes some rules that, when enabled,
	can cause conflicts with the formatter, leading to unexpected
	behavior. When configured appropriately, the goal of Ruff's
	formatter-linter compatibility is such that running the formatter
	should never introduce new lint errors.
From the ruff documentation:
	By default, Ruff enables Flake8's `F` rules, along with a subset of the `E`
	rules, omitting any stylistic rules that overlap with the use of a formatter,
	like `ruff format` or Black.

Therefore let ruff select the default ruleset, to avoid `E` rules that overlap
with the formatter. Only document additional rules.
I001 Import block is un-sorted or un-formatted
They would be interpreted as a tuple by `ruff format`.
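
For readers unfamiliar with the pitfall behind that note, an illustrative snippet (names hypothetical, not from the PR):

    # Parentheses around an assert message are plain grouping:
    assert result == expected, (
        "helpful failure message"
    )
    # A spurious trailing comma turns the message into a 1-tuple, which
    # ruff format would then, quite correctly, keep formatting as a tuple:
    assert result == expected, (
        "helpful failure message",
    )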
@DimitriPapadopoulos
Copy link
Contributor Author

It would be great if this could be merged before new conflicts appear.

@lesteve
Copy link
Member

lesteve commented Apr 15, 2025

2 approvals already, so let's merge this and see what happens (probably many PRs will have conflicts, but 😅).

The linter GHA is red but this is expected because the bot comment thingy is partly driven by main as mentioned in #31015 (review).

@lesteve lesteve merged commit ff78e25 into scikit-learn:main Apr 15, 2025
37 of 40 checks passed
@DimitriPapadopoulos DimitriPapadopoulos deleted the black_ruff_format branch April 15, 2025 11:25
DimitriPapadopoulos added a commit to DimitriPapadopoulos/scikit-learn that referenced this pull request Apr 15, 2025
* Enforce ruff rules (RUF) (scikit-learn#30694)
* Apply ruff/flake8-implicit-str-concat rules (ISC) (scikit-learn#30695)
* black → ruff format (scikit-learn#31015)
DimitriPapadopoulos added a commit to DimitriPapadopoulos/scikit-learn that referenced this pull request Apr 15, 2025
Introduced by this commit in scikit-learn#31015:
620b0de MNT black → ruff format
lucyleeow pushed a commit to EmilyXinyi/scikit-learn that referenced this pull request Apr 23, 2025