FIX _safe_indexing for pyarrow #31040

Open · wants to merge 5 commits into main

Conversation

@lorentzenchr (Member) commented Mar 20, 2025

Reference Issues/PRs

Partially addresses #25896 (comment).

What does this implement/fix? Explain your changes.

_safe_indexing(..., axis=1) is used in the ColumnTransformer and raises an error if a pyarrow.Table is passed, even though it implements the dataframe interchange protocol:

import pyarrow as pa
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.compose import ColumnTransformer

X, y = load_iris(as_frame=True, return_X_y=True)
sepal_cols = ["sepal length (cm)", "sepal width (cm)"]
petal_cols = ["petal length (cm)", "petal width (cm)"]

preprocessor = ColumnTransformer(
    [
        ("scaler", StandardScaler(), sepal_cols),
        ("kbin", KBinsDiscretizer(encode="ordinal"), petal_cols),
    ],
    verbose_feature_names_out=False,
)

X_pa = pa.table(X)
preprocessor.fit_transform(X_pa)

results in

python3.10/site-packages/sklearn/utils/_indexing.py:270, in _safe_indexing(X, indices, axis)
    268     return _polars_indexing(X, indices, indices_dtype, axis=axis)
    269 elif hasattr(X, "shape"):
--> 270     return _array_indexing(X, indices, indices_dtype, axis=axis)
    271 else:
    272     return _list_indexing(X, indices, indices_dtype)

File python3.10/site-packages/sklearn/utils/_indexing.py:36, in _array_indexing(array, key, key_dtype, axis)
     34 if isinstance(key, tuple):
     35     key = list(key)
---> 36 return array[key, ...] if axis == 0 else array[:, key]

...
python3.10/site-packages/pyarrow/table.pxi:1725, in pyarrow.lib._Tabular._ensure_integer_index()

TypeError: Index must either be string or integer

which shows that the wrong branch (elif hasattr(X, "shape"):) is taken.
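The root cause is that pyarrow.Table exposes a NumPy-like shape attribute, so the hasattr(X, "shape") check matches before any pyarrow-specific handling. A minimal demonstration:

import pyarrow as pa

table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# A pyarrow.Table has a NumPy-like shape attribute (rows, columns) ...
print(table.shape)  # (3, 2)

# ... so _safe_indexing dispatches to _array_indexing, whose NumPy-style
# slicing a Table does not support:
try:
    table[:, [1]]
except TypeError as exc:
    print(exc)  # Index must either be string or integer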

Any other comments?

There is no general solution with __dataframe__ because of data-apis/dataframe-api#85.

Therefore, a less general (and admittedly dirtier) solution is taken: dedicated pyarrow indexing is implemented.

narwhals would be much cleaner, but needs its own dedicated issue for discussion. This PR just fixes a bug when a pyarrow.Table is passed around.
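For reference, a rough sketch of what such dedicated pyarrow indexing can look like; the helper name _pyarrow_indexing and its exact shape here are illustrative assumptions, not the actual diff:

import numpy as np
import pyarrow as pa

def _pyarrow_indexing(X, key, key_dtype, axis):
    """Illustrative sketch: index a pyarrow.Table along rows (axis=0) or columns (axis=1)."""
    if np.isscalar(key):
        key = [key]
    elif isinstance(key, tuple):
        key = list(key)
    if axis == 1:
        if key_dtype == "bool":
            # Table.select wants column names or integer positions, not masks
            key = np.flatnonzero(key).tolist()
        return X.select(key)
    if key_dtype == "bool":
        return X.filter(pa.array(key))  # boolean row mask
    return X.take(key)  # integer row positions

Dispatched before the elif hasattr(X, "shape") branch in _safe_indexing, a helper like this makes the ColumnTransformer example above run through.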

github-actions bot commented Mar 20, 2025

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


ruff format

ruff detected issues. Please run ruff format locally and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.5.1.


--- examples/applications/plot_species_distribution_modeling.py
+++ examples/applications/plot_species_distribution_modeling.py
@@ -109,7 +109,7 @@
 
 
 def plot_species_distribution(
-    species=("bradypus_variegatus_0", "microryzomys_minutus_0")
+    species=("bradypus_variegatus_0", "microryzomys_minutus_0"),
 ):
     """
     Plot the species distribution.

--- examples/ensemble/plot_bias_variance.py
+++ examples/ensemble/plot_bias_variance.py
@@ -177,8 +177,8 @@
 
     plt.subplot(2, n_estimators, n_estimators + n + 1)
     plt.plot(X_test, y_error, "r", label="$error(x)$")
-    plt.plot(X_test, y_bias, "b", label="$bias^2(x)$"),
-    plt.plot(X_test, y_var, "g", label="$variance(x)$"),
+    (plt.plot(X_test, y_bias, "b", label="$bias^2(x)$"),)
+    (plt.plot(X_test, y_var, "g", label="$variance(x)$"),)
     plt.plot(X_test, y_noise, "c", label="$noise(x)$")
 
     plt.xlim([-5, 5])

--- examples/linear_model/plot_tweedie_regression_insurance_claims.py
+++ examples/linear_model/plot_tweedie_regression_insurance_claims.py
@@ -606,8 +606,9 @@
             "predicted, frequency*severity model": np.sum(
                 exposure * glm_freq.predict(X) * glm_sev.predict(X)
             ),
-            "predicted, tweedie, power=%.2f"
-            % glm_pure_premium.power: np.sum(exposure * glm_pure_premium.predict(X)),
+            "predicted, tweedie, power=%.2f" % glm_pure_premium.power: np.sum(
+                exposure * glm_pure_premium.predict(X)
+            ),
         }
     )
 

--- examples/manifold/plot_lle_digits.py
+++ examples/manifold/plot_lle_digits.py
@@ -10,7 +10,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 # %%
 # Load digits dataset
 # -------------------

--- examples/manifold/plot_manifold_sphere.py
+++ examples/manifold/plot_manifold_sphere.py
@@ -50,7 +50,7 @@
 t = random_state.rand(n_samples) * np.pi
 
 # Sever the poles from the sphere.
-indices = (t < (np.pi - (np.pi / 8))) & (t > ((np.pi / 8)))
+indices = (t < (np.pi - (np.pi / 8))) & (t > (np.pi / 8))
 colors = p[indices]
 x, y, z = (
     np.sin(t[indices]) * np.cos(p[indices]),

--- sklearn/_loss/tests/test_loss.py
+++ sklearn/_loss/tests/test_loss.py
@@ -215,7 +215,8 @@
 
 
 @pytest.mark.parametrize(
-    "loss, y_pred_success, y_pred_fail", Y_COMMON_PARAMS + Y_PRED_PARAMS  # type: ignore
+    "loss, y_pred_success, y_pred_fail",
+    Y_COMMON_PARAMS + Y_PRED_PARAMS,  # type: ignore
 )
 def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail):
     """Test boundaries of y_pred for loss functions."""
@@ -501,12 +502,14 @@
         sample_weight=sample_weight,
         loss_out=out_l1,
     )
-    loss.closs.loss(
-        y_true=y_true,
-        raw_prediction=raw_prediction,
-        sample_weight=sample_weight,
-        loss_out=out_l2,
-    ),
+    (
+        loss.closs.loss(
+            y_true=y_true,
+            raw_prediction=raw_prediction,
+            sample_weight=sample_weight,
+            loss_out=out_l2,
+        ),
+    )
     assert_allclose(out_l1, out_l2)
     loss.gradient(
         y_true=y_true,

--- sklearn/cluster/_feature_agglomeration.py
+++ sklearn/cluster/_feature_agglomeration.py
@@ -6,7 +6,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import numpy as np
 from scipy.sparse import issparse
 

--- sklearn/cross_decomposition/tests/test_pls.py
+++ sklearn/cross_decomposition/tests/test_pls.py
@@ -404,12 +404,12 @@
 
     X_orig = X.copy()
     with pytest.raises(AssertionError):
-        pls.transform(X, Y, copy=False),
+        (pls.transform(X, Y, copy=False),)
         assert_array_almost_equal(X, X_orig)
 
     X_orig = X.copy()
     with pytest.raises(AssertionError):
-        pls.predict(X, copy=False),
+        (pls.predict(X, copy=False),)
         assert_array_almost_equal(X, X_orig)
 
     # Make sure copy=True gives same transform and predictions as predict=False

--- sklearn/ensemble/_bagging.py
+++ sklearn/ensemble/_bagging.py
@@ -3,7 +3,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import itertools
 import numbers
 from abc import ABCMeta, abstractmethod

--- sklearn/ensemble/_forest.py
+++ sklearn/ensemble/_forest.py
@@ -35,7 +35,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import threading
 from abc import ABCMeta, abstractmethod
 from numbers import Integral, Real

--- sklearn/ensemble/tests/test_forest.py
+++ sklearn/ensemble/tests/test_forest.py
@@ -168,11 +168,12 @@
     reg = ForestRegressor(n_estimators=5, criterion=criterion, random_state=1)
     reg.fit(X_reg, y_reg)
     score = reg.score(X_reg, y_reg)
-    assert (
-        score > 0.93
-    ), "Failed with max_features=None, criterion %s and score = %f" % (
-        criterion,
-        score,
+    assert score > 0.93, (
+        "Failed with max_features=None, criterion %s and score = %f"
+        % (
+            criterion,
+            score,
+        )
     )
 
     reg = ForestRegressor(

--- sklearn/experimental/enable_hist_gradient_boosting.py
+++ sklearn/experimental/enable_hist_gradient_boosting.py
@@ -13,7 +13,6 @@
 # Don't remove this file, we don't want to break users code just because the
 # feature isn't experimental anymore.
 
-
 import warnings
 
 warnings.warn(

--- sklearn/feature_selection/_univariate_selection.py
+++ sklearn/feature_selection/_univariate_selection.py
@@ -3,7 +3,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from numbers import Integral, Real
 

--- sklearn/gaussian_process/tests/test_gpc.py
+++ sklearn/gaussian_process/tests/test_gpc.py
@@ -147,8 +147,9 @@
     # Define a dummy optimizer that simply tests 10 random hyperparameters
     def optimizer(obj_func, initial_theta, bounds):
         rng = np.random.RandomState(global_random_seed)
-        theta_opt, func_min = initial_theta, obj_func(
-            initial_theta, eval_gradient=False
+        theta_opt, func_min = (
+            initial_theta,
+            obj_func(initial_theta, eval_gradient=False),
         )
         for _ in range(10):
             theta = np.atleast_1d(

--- sklearn/gaussian_process/tests/test_gpr.py
+++ sklearn/gaussian_process/tests/test_gpr.py
@@ -394,8 +394,9 @@
     # Define a dummy optimizer that simply tests 50 random hyperparameters
     def optimizer(obj_func, initial_theta, bounds):
         rng = np.random.RandomState(0)
-        theta_opt, func_min = initial_theta, obj_func(
-            initial_theta, eval_gradient=False
+        theta_opt, func_min = (
+            initial_theta,
+            obj_func(initial_theta, eval_gradient=False),
         )
         for _ in range(50):
             theta = np.atleast_1d(

--- sklearn/linear_model/_linear_loss.py
+++ sklearn/linear_model/_linear_loss.py
@@ -537,9 +537,9 @@
                 # The L2 penalty enters the Hessian on the diagonal only. To add those
                 # terms, we use a flattened view of the array.
                 order = "C" if hess.flags.c_contiguous else "F"
-                hess.reshape(-1, order=order)[
-                    : (n_features * n_dof) : (n_dof + 1)
-                ] += l2_reg_strength
+                hess.reshape(-1, order=order)[: (n_features * n_dof) : (n_dof + 1)] += (
+                    l2_reg_strength
+                )
 
             if self.fit_intercept:
                 # With intercept included as added column to X, the hessian becomes

--- sklearn/linear_model/_ridge.py
+++ sklearn/linear_model/_ridge.py
@@ -5,7 +5,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import numbers
 import warnings
 from abc import ABCMeta, abstractmethod

--- sklearn/linear_model/_theil_sen.py
+++ sklearn/linear_model/_theil_sen.py
@@ -5,7 +5,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from itertools import combinations
 from numbers import Integral, Real

--- sklearn/manifold/_spectral_embedding.py
+++ sklearn/manifold/_spectral_embedding.py
@@ -3,7 +3,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from numbers import Integral, Real
 

--- sklearn/metrics/_ranking.py
+++ sklearn/metrics/_ranking.py
@@ -10,7 +10,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from functools import partial
 from numbers import Integral, Real

--- sklearn/metrics/cluster/_supervised.py
+++ sklearn/metrics/cluster/_supervised.py
@@ -7,7 +7,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from math import log
 from numbers import Real

--- sklearn/metrics/cluster/_unsupervised.py
+++ sklearn/metrics/cluster/_unsupervised.py
@@ -3,7 +3,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import functools
 from numbers import Integral
 

--- sklearn/metrics/tests/test_common.py
+++ sklearn/metrics/tests/test_common.py
@@ -976,7 +976,8 @@
 @pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values())
 @pytest.mark.parametrize(
     "y_true, y_score",
-    invalids_nan_inf +
+    invalids_nan_inf
+    +
     # Add an additional case for classification only
     # non-regression test for:
     # https://github.com/scikit-learn/scikit-learn/issues/6809
@@ -2075,7 +2076,6 @@
 
 
 def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name):
-
     X_np = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=dtype_name)
     Y_np = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], dtype=dtype_name)
 

--- sklearn/model_selection/_validation.py
+++ sklearn/model_selection/_validation.py
@@ -6,7 +6,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import numbers
 import time
 import warnings

--- sklearn/multioutput.py
+++ sklearn/multioutput.py
@@ -8,7 +8,6 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 import warnings
 from abc import ABCMeta, abstractmethod
 from numbers import Integral
@@ -687,7 +686,6 @@
             )
 
         if self.base_estimator != "deprecated":
-
             warning_msg = (
                 "`base_estimator` as an argument was deprecated in 1.7 and will be"
                 " removed in 1.9. Use `estimator` instead."

--- sklearn/neighbors/tests/test_neighbors.py
+++ sklearn/neighbors/tests/test_neighbors.py
@@ -655,10 +655,12 @@
             assert_allclose(np.concatenate(list(ind)), np.concatenate(list(ind1)))
 
         for i in range(len(results) - 1):
-            assert_allclose(
-                np.concatenate(list(results[i][0])),
-                np.concatenate(list(results[i + 1][0])),
-            ),
+            (
+                assert_allclose(
+                    np.concatenate(list(results[i][0])),
+                    np.concatenate(list(results[i + 1][0])),
+                ),
+            )
             assert_allclose(
                 np.concatenate(list(results[i][1])),
                 np.concatenate(list(results[i + 1][1])),

--- sklearn/tests/test_common.py
+++ sklearn/tests/test_common.py
@@ -298,7 +298,6 @@
     "transformer", GET_FEATURES_OUT_ESTIMATORS, ids=_get_check_estimator_ids
 )
 def test_transformers_get_feature_names_out(transformer):
-
     with ignore_warnings(category=(FutureWarning)):
         check_transformer_get_feature_names_out(
             transformer.__class__.__name__, transformer

--- sklearn/tests/test_metaestimators.py
+++ sklearn/tests/test_metaestimators.py
@@ -157,11 +157,12 @@
             if method in delegator_data.skip_methods:
                 continue
             assert hasattr(delegate, method)
-            assert hasattr(
-                delegator, method
-            ), "%s does not have method %r when its delegate does" % (
-                delegator_data.name,
-                method,
+            assert hasattr(delegator, method), (
+                "%s does not have method %r when its delegate does"
+                % (
+                    delegator_data.name,
+                    method,
+                )
             )
             # delegation before fit raises a NotFittedError
             if method == "score":
@@ -191,11 +192,12 @@
             delegate = SubEstimator(hidden_method=method)
             delegator = delegator_data.construct(delegate)
             assert not hasattr(delegate, method)
-            assert not hasattr(
-                delegator, method
-            ), "%s has method %r when its delegate does not" % (
-                delegator_data.name,
-                method,
+            assert not hasattr(delegator, method), (
+                "%s has method %r when its delegate does not"
+                % (
+                    delegator_data.name,
+                    method,
+                )
             )
 
 

--- sklearn/utils/_metadata_requests.py
+++ sklearn/utils/_metadata_requests.py
@@ -1098,8 +1098,9 @@
             method_mapping = MethodMapping()
             for method in METHODS:
                 method_mapping.add(caller=method, callee=method)
-            yield "$self_request", RouterMappingPair(
-                mapping=method_mapping, router=self._self_request
+            yield (
+                "$self_request",
+                RouterMappingPair(mapping=method_mapping, router=self._self_request),
             )
         for name, route_mapping in self._route_mappings.items():
             yield (name, route_mapping)

--- sklearn/utils/estimator_checks.py
+++ sklearn/utils/estimator_checks.py
@@ -5339,9 +5339,7 @@
                 'Only binary classification is supported. The type of the target '
                 f'is {{y_type}}.'
         )
-    """.format(
-        name=name
-    )
+    """.format(name=name)
     err_msg = textwrap.dedent(err_msg)
 
     with raises(

--- sklearn/utils/tests/test_indexing.py
+++ sklearn/utils/tests/test_indexing.py
@@ -609,7 +609,6 @@
 
 
 def test_notimplementederror():
-
     with pytest.raises(
         NotImplementedError,
         match="Resampling with sample_weight is only implemented for replace=True.",

--- sklearn/utils/tests/test_multiclass.py
+++ sklearn/utils/tests/test_multiclass.py
@@ -416,12 +416,13 @@
 def test_type_of_target():
     for group, group_examples in EXAMPLES.items():
         for example in group_examples:
-            assert (
-                type_of_target(example) == group
-            ), "type_of_target(%r) should be %r, got %r" % (
-                example,
-                group,
-                type_of_target(example),
+            assert type_of_target(example) == group, (
+                "type_of_target(%r) should be %r, got %r"
+                % (
+                    example,
+                    group,
+                    type_of_target(example),
+                )
             )
 
     for example in NON_ARRAY_LIKE_EXAMPLES:

--- sklearn/utils/tests/test_seq_dataset.py
+++ sklearn/utils/tests/test_seq_dataset.py
@@ -154,30 +154,34 @@
 
 def test_buffer_dtype_mismatch_error():
     with pytest.raises(ValueError, match="Buffer dtype mismatch"):
-        ArrayDataset64(X32, y32, sample_weight32, seed=42),
+        (ArrayDataset64(X32, y32, sample_weight32, seed=42),)
 
     with pytest.raises(ValueError, match="Buffer dtype mismatch"):
-        ArrayDataset32(X64, y64, sample_weight64, seed=42),
+        (ArrayDataset32(X64, y64, sample_weight64, seed=42),)
 
     for csr_container in CSR_CONTAINERS:
         X_csr32 = csr_container(X32)
         X_csr64 = csr_container(X64)
         with pytest.raises(ValueError, match="Buffer dtype mismatch"):
-            CSRDataset64(
-                X_csr32.data,
-                X_csr32.indptr,
-                X_csr32.indices,
-                y32,
-                sample_weight32,
-                seed=42,
-            ),
+            (
+                CSRDataset64(
+                    X_csr32.data,
+                    X_csr32.indptr,
+                    X_csr32.indices,
+                    y32,
+                    sample_weight32,
+                    seed=42,
+                ),
+            )
 
         with pytest.raises(ValueError, match="Buffer dtype mismatch"):
-            CSRDataset32(
-                X_csr64.data,
-                X_csr64.indptr,
-                X_csr64.indices,
-                y64,
-                sample_weight64,
-                seed=42,
-            ),
+            (
+                CSRDataset32(
+                    X_csr64.data,
+                    X_csr64.indptr,
+                    X_csr64.indices,
+                    y64,
+                    sample_weight64,
+                    seed=42,
+                ),
+            )

--- sklearn/utils/tests/test_tags.py
+++ sklearn/utils/tests/test_tags.py
@@ -565,7 +565,6 @@
     assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags
 
     class MyClass:
-
         def fit(self, X, y=None):
             return self  # pragma: no cover
 

--- sklearn/utils/validation.py
+++ sklearn/utils/validation.py
@@ -1547,8 +1547,7 @@
         # hasattr(estimator, "fit") makes it so that we don't fail for an estimator
         # that does not have a `fit` method during collection of checks. The right
         # checks will fail later.
-        hasattr(estimator, "fit")
-        and parameter in signature(estimator.fit).parameters
+        hasattr(estimator, "fit") and parameter in signature(estimator.fit).parameters
     )
 
 

35 files would be reformatted, 885 files already formatted

Generated for commit: b170b47. Link to the linter CI: here

@ogrisel (Member) left a comment

+1 for supporting Arrow, and I'm fine with doing it this way as a pragmatic and incremental improvement, pending a potential refactoring to make it more general in the future (to be discussed in a follow-up issue or PR).

However I am worried that this code is untested on our CI. Could you add PyArrow as a soft dependency on one of the CI workers (like we do for pandas / polars / pytorch / array-api-strict...)?

@lorentzenchr (Member, Author) commented Mar 22, 2025

However I am worried that this code is untested on our CI.

I don't follow. scikit-learn/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml and pylatest_conda_forge_cuda_array-api_linux-64_environment.yml both specify pyarrow as a requirement, so the code path of this PR is tested.
Maybe our codecov is not configured cleverly enough.

@lorentzenchr (Member, Author)

Should the changelog entry go under sklearn.compose for the ColumnTransformer or under sklearn.utils for _safe_indexing?

@ogrisel (Member) commented Mar 27, 2025

Maybe our codecov is not configured cleverly enough.

Actually, everything is fine: the coverage data is collected on the CI config that runs with pylatest_conda_forge_mkl_linux-64_environment.yml and pyarrow. It's just that we do not actually cover the lines reported by codecov in the tests. For instance, the NotImplementedError case is tricky to cover.

Some lines, such as the NotImplementedError branch, are legitimately fine to leave uncovered in the tests.

The others could maybe benefit from expanding the existing tests a bit, but it's no big deal either way. Still +1 for merge on my side.

@lorentzenchr (Member, Author)

@ogrisel Do we need a 2nd reviewer?

lorentzenchr added this to the 1.7 milestone Apr 14, 2025
@jeremiedbb (Member) left a comment

Hey, given that there seems to be a consensus to use narwhals for our dataframe support tools, do you think that this PR is still relevant?
We can easily fit it into 1.7, while that's unlikely for narwhals. But do we want to introduce this if we're going to rewrite it soon (ref #31049)?

@lorentzenchr (Member, Author)

do you think that this PR is still relevant?

Yes, it’s a simple fix for a bug. If we later decide to go with narwhals, there is no cost to having this PR.

lorentzenchr changed the title from "FIX _safe_indexing for __dataframe__ interchange protocol" to "FIX _safe_indexing for pyarrow" Apr 21, 2025
@lorentzenchr (Member, Author)

@ogrisel @jeremiedbb I still would like to have this PR in 1.7, but I changed my mind about the implementation: I replaced the "dataframe interchange protocol" indexing with proper pyarrow indexing. This is much cleaner.

I'm now convinced that the dataframe interchange protocol is not usable for anything but interchange, which we don't really do or need; what we need is an API for indexing and assigning. Therefore, I defer this topic to #31049.
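To illustrate why native indexing is cleaner than going through __dataframe__, both axes map to direct pyarrow calls (reusing X_pa from the snippet above):

X_pa.select(["sepal length (cm)", "sepal width (cm)"])  # columns by name
X_pa.take([0, 1, 2])  # rows by integer position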

@ogrisel (Member) left a comment

LGTM.
