From db5bbc3ab5eeb87c33c1b854e759cfbdcb5e48c7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 18 Oct 2021 10:16:12 +0200 Subject: [PATCH 1/3] wip --- sklearn/_build_utils/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/_build_utils/__init__.py b/sklearn/_build_utils/__init__.py index be83b4c4d8baf..67b5f2c662eb0 100644 --- a/sklearn/_build_utils/__init__.py +++ b/sklearn/_build_utils/__init__.py @@ -76,7 +76,14 @@ def cythonize_extensions(top_path, config): compile_time_env={ "SKLEARN_OPENMP_PARALLELISM_ENABLED": sklearn._OPENMP_SUPPORTED }, - compiler_directives={"language_level": 3}, + compiler_directives={ + "language_level": 3, + "boundscheck": False, + "wraparound": False, + "initializedcheck": False, + "nonecheck": False, + "cdivision": True, + }, ) From ea6ef6e0ca2c167ebbb70ac9e2619ea1976642f0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Sat, 30 Oct 2021 05:11:07 +0200 Subject: [PATCH 2/3] global cython compiler directives --- sklearn/_isotonic.pyx | 2 -- sklearn/cluster/_dbscan_inner.pyx | 2 -- sklearn/cluster/_hierarchical_fast.pyx | 26 +++---------------- sklearn/cluster/_k_means_common.pxd | 3 --- sklearn/cluster/_k_means_common.pyx | 4 --- sklearn/cluster/_k_means_elkan.pyx | 2 -- sklearn/cluster/_k_means_lloyd.pyx | 2 -- sklearn/cluster/_k_means_minibatch.pyx | 2 -- sklearn/datasets/_svmlight_format_fast.pyx | 2 -- sklearn/decomposition/_cdnmf_fast.pyx | 4 --- sklearn/decomposition/_online_lda_fast.pyx | 4 --- sklearn/ensemble/_gradient_boosting.pyx | 4 --- .../_hist_gradient_boosting/_binning.pyx | 6 ----- .../_hist_gradient_boosting/_bitset.pxd | 1 - .../_hist_gradient_boosting/_bitset.pyx | 4 --- .../_gradient_boosting.pyx | 5 ---- .../_hist_gradient_boosting/_loss.pyx | 5 ---- .../_hist_gradient_boosting/_predictor.pyx | 5 ---- .../_hist_gradient_boosting/common.pxd | 1 - .../_hist_gradient_boosting/histogram.pyx | 4 --- .../_hist_gradient_boosting/splitting.pyx | 6 ----- .../_hist_gradient_boosting/utils.pyx | 4 --- sklearn/feature_extraction/_hashing_fast.pyx | 4 +-- sklearn/linear_model/_cd_fast.pyx | 2 -- sklearn/linear_model/_sag_fast.pyx.tp | 4 --- sklearn/linear_model/_sgd_fast.pyx | 4 --- sklearn/manifold/_barnes_hut_tsne.pyx | 4 --- sklearn/manifold/_utils.pyx | 2 -- sklearn/metrics/_dist_metrics.pxd | 5 ---- sklearn/metrics/_dist_metrics.pyx | 5 ---- .../cluster/_expected_mutual_info_fast.pyx | 4 --- sklearn/neighbors/_quad_tree.pxd | 4 --- sklearn/neighbors/_quad_tree.pyx | 4 --- .../_csr_polynomial_expansion.pyx | 4 --- sklearn/svm/_libsvm.pyx | 5 +--- sklearn/tree/_criterion.pyx | 4 --- sklearn/tree/_splitter.pyx | 4 --- sklearn/tree/_tree.pyx | 4 --- sklearn/tree/_utils.pyx | 4 --- sklearn/utils/_fast_dict.pyx | 2 -- sklearn/utils/_random.pyx | 5 +--- sklearn/utils/_seq_dataset.pyx.tp | 3 --- sklearn/utils/_weight_vector.pyx.tp | 5 ---- sklearn/utils/murmurhash.pyx | 2 -- sklearn/utils/sparsefuncs_fast.pyx | 1 - 45 files changed, 6 insertions(+), 177 deletions(-) diff --git a/sklearn/_isotonic.pyx b/sklearn/_isotonic.pyx index 7f60b889fa284..1f9364cd92940 100644 --- a/sklearn/_isotonic.pyx +++ b/sklearn/_isotonic.pyx @@ -3,8 +3,6 @@ # Uses the pool adjacent violators algorithm (PAVA), with the # enhancement of searching for the longest decreasing subsequence to # pool at each step. -# -# cython: boundscheck=False, wraparound=False, cdivision=True import numpy as np cimport numpy as np diff --git a/sklearn/cluster/_dbscan_inner.pyx b/sklearn/cluster/_dbscan_inner.pyx index b9a80686a76f8..63125f60d3af1 100644 --- a/sklearn/cluster/_dbscan_inner.pyx +++ b/sklearn/cluster/_dbscan_inner.pyx @@ -1,8 +1,6 @@ # Fast inner loop for DBSCAN. # Author: Lars Buitinck # License: 3-clause BSD -# -# cython: boundscheck=False, wraparound=False cimport cython from libcpp.vector cimport vector diff --git a/sklearn/cluster/_hierarchical_fast.pyx b/sklearn/cluster/_hierarchical_fast.pyx index 11ea3294c086a..b7d4343d7fdd2 100644 --- a/sklearn/cluster/_hierarchical_fast.pyx +++ b/sklearn/cluster/_hierarchical_fast.pyx @@ -8,9 +8,6 @@ ctypedef np.float64_t DOUBLE ctypedef np.npy_intp INTP ctypedef np.int8_t INT8 -# Numpy must be initialized. When using numpy from C or Cython you must -# _always_ do that, or you will have segfaults - np.import_array() from ..metrics._dist_metrics cimport DistanceMetric @@ -32,9 +29,6 @@ from numpy.math cimport INFINITY ############################################################################### # Utilities for computing the ward momentum -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) def compute_ward_dist(np.ndarray[DOUBLE, ndim=1, mode='c'] m_1, np.ndarray[DOUBLE, ndim=2, mode='c'] m_2, np.ndarray[INTP, ndim=1, mode='c'] coord_row, @@ -101,8 +95,6 @@ def _hc_get_descendent(INTP node, children, INTP n_leaves): return descendent -@cython.boundscheck(False) -@cython.wraparound(False) def hc_get_heads(np.ndarray[INTP, ndim=1] parents, copy=True): """Returns the heads of the forest, as defined by parents. @@ -135,8 +127,6 @@ def hc_get_heads(np.ndarray[INTP, ndim=1] parents, copy=True): return parents -@cython.boundscheck(False) -@cython.wraparound(False) def _get_parents(nodes, heads, np.ndarray[INTP, ndim=1] parents, np.ndarray[INT8, ndim=1, mode='c'] not_visited): """Returns the heads of the given nodes, as defined by parents. @@ -176,8 +166,6 @@ def _get_parents(nodes, heads, np.ndarray[INTP, ndim=1] parents, # as keys and edge weights as values. -@cython.boundscheck(False) -@cython.wraparound(False) def max_merge(IntFloatDict a, IntFloatDict b, np.ndarray[ITYPE_t, ndim=1] mask, ITYPE_t n_a, ITYPE_t n_b): @@ -231,8 +219,6 @@ def max_merge(IntFloatDict a, IntFloatDict b, return out_obj -@cython.boundscheck(False) -@cython.wraparound(False) def average_merge(IntFloatDict a, IntFloatDict b, np.ndarray[ITYPE_t, ndim=1] mask, ITYPE_t n_a, ITYPE_t n_b): @@ -302,7 +288,6 @@ cdef class WeightedEdge: self.a = a self.b = b - @cython.nonecheck(False) def __richcmp__(self, WeightedEdge other, int op): """Cython-specific comparison method. @@ -348,8 +333,6 @@ cdef class UnionFind(object): self.size = np.hstack((np.ones(N, dtype=ITYPE), np.zeros(N - 1, dtype=ITYPE))) - @cython.boundscheck(False) - @cython.nonecheck(False) cdef void union(self, ITYPE_t m, ITYPE_t n): self.parent[m] = self.next_label self.parent[n] = self.next_label @@ -358,8 +341,7 @@ cdef class UnionFind(object): return - @cython.boundscheck(False) - @cython.nonecheck(False) + @cython.wraparound(True) cdef ITYPE_t fast_find(self, ITYPE_t n): cdef ITYPE_t p p = n @@ -371,8 +353,7 @@ cdef class UnionFind(object): p, self.parent[p] = self.parent[p], n return n -@cython.boundscheck(False) -@cython.nonecheck(False) + cpdef np.ndarray[DTYPE_t, ndim=2] _single_linkage_label( np.ndarray[DTYPE_t, ndim=2] L): """ @@ -423,6 +404,7 @@ cpdef np.ndarray[DTYPE_t, ndim=2] _single_linkage_label( return result_arr +@cython.wraparound(True) def single_linkage_label(L): """ Convert an linkage array or MST to a tree by labelling clusters at merges. @@ -452,8 +434,6 @@ def single_linkage_label(L): # Implements MST-LINKAGE-CORE from https://arxiv.org/abs/1109.2378 -@cython.boundscheck(False) -@cython.nonecheck(False) def mst_linkage_core( const DTYPE_t [:, ::1] raw_data, DistanceMetric dist_metric): diff --git a/sklearn/cluster/_k_means_common.pxd b/sklearn/cluster/_k_means_common.pxd index db70278860097..8eefa10e64e78 100644 --- a/sklearn/cluster/_k_means_common.pxd +++ b/sklearn/cluster/_k_means_common.pxd @@ -1,6 +1,3 @@ -# cython: language_level=3 - - from cython cimport floating cimport numpy as np diff --git a/sklearn/cluster/_k_means_common.pyx b/sklearn/cluster/_k_means_common.pyx index 9e8f81c9f2625..327a7ed60cb84 100644 --- a/sklearn/cluster/_k_means_common.pyx +++ b/sklearn/cluster/_k_means_common.pyx @@ -1,7 +1,3 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True -# Profiling is enabled by default as the overhead does not seem to be -# measurable on this specific use case. - # Author: Peter Prettenhofer # Olivier Grisel # Lars Buitinck diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 9459d5e9fc316..4b1ca35c15db2 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -1,5 +1,3 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True -# # Author: Andreas Mueller # # Licence: BSD 3 clause diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index e3526888c82ab..9611614c6239f 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -1,5 +1,3 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True -# # Licence: BSD 3 clause # TODO: We still need to use ndarrays instead of typed memoryviews when using diff --git a/sklearn/cluster/_k_means_minibatch.pyx b/sklearn/cluster/_k_means_minibatch.pyx index ffae55e3b3b46..c88c0d3c40828 100644 --- a/sklearn/cluster/_k_means_minibatch.pyx +++ b/sklearn/cluster/_k_means_minibatch.pyx @@ -1,5 +1,3 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True - # TODO: We still need to use ndarrays instead of typed memoryviews when using # fused types and when the array may be read-only (for instance when it's # provided by the user). This will be fixed in cython >= 0.3. diff --git a/sklearn/datasets/_svmlight_format_fast.pyx b/sklearn/datasets/_svmlight_format_fast.pyx index 9644ecbbd20a5..12d222f8cf581 100644 --- a/sklearn/datasets/_svmlight_format_fast.pyx +++ b/sklearn/datasets/_svmlight_format_fast.pyx @@ -4,8 +4,6 @@ # Lars Buitinck # Olivier Grisel # License: BSD 3 clause -# -# cython: boundscheck=False, wraparound=False import array from cpython cimport array diff --git a/sklearn/decomposition/_cdnmf_fast.pyx b/sklearn/decomposition/_cdnmf_fast.pyx index 9c6b171096ced..c50e09e1632c7 100644 --- a/sklearn/decomposition/_cdnmf_fast.pyx +++ b/sklearn/decomposition/_cdnmf_fast.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Author: Mathieu Blondel, Tom Dupre la Tour # License: BSD 3 clause diff --git a/sklearn/decomposition/_online_lda_fast.pyx b/sklearn/decomposition/_online_lda_fast.pyx index 1c00af02d2375..446232a57f084 100644 --- a/sklearn/decomposition/_online_lda_fast.pyx +++ b/sklearn/decomposition/_online_lda_fast.pyx @@ -1,6 +1,3 @@ -# -# cython: boundscheck=False, wraparound=False - cimport cython cimport numpy as np import numpy as np @@ -91,7 +88,6 @@ def _dirichlet_expectation_2d(np.ndarray[ndim=2, dtype=np.float64_t] arr): # # After: J. Bernardo (1976). Algorithm AS 103: Psi (Digamma) Function. # https://www.uv.es/~bernardo/1976AppStatist.pdf -@cython.cdivision(True) cdef double psi(double x) nogil: if x <= 1e-6: # psi(x) = -EULER - 1/x + O(x) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 2e335fec62705..5942e30c701ce 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# # Author: Peter Prettenhofer # # License: BSD 3 clause diff --git a/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx index 5f5dd68935fd4..33cf0dadae011 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx @@ -1,9 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: nonecheck=False -# cython: language_level=3 - # Author: Nicolas Hug cimport cython diff --git a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd index 0ea642df3ddcf..4aea8276c4398 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd +++ b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd @@ -1,4 +1,3 @@ -# cython: language_level=3 from .common cimport X_BINNED_DTYPE_C from .common cimport BITSET_DTYPE_C from .common cimport BITSET_INNER_DTYPE_C diff --git a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx index 2df03b047aad1..0d3b630f3314f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 from .common cimport BITSET_INNER_DTYPE_C from .common cimport BITSET_DTYPE_C from .common cimport X_DTYPE_C diff --git a/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx b/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx index f684ca57e560d..8170c8dc462e9 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx @@ -1,8 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 - # Author: Nicolas Hug cimport cython diff --git a/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx index da900e28c6457..23e7d2841443b 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx @@ -1,8 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 - # Author: Nicolas Hug cimport cython diff --git a/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx index a6b2f8b90de8e..5aee8620e34d1 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx @@ -1,8 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 - # Author: Nicolas Hug cimport cython diff --git a/sklearn/ensemble/_hist_gradient_boosting/common.pxd b/sklearn/ensemble/_hist_gradient_boosting/common.pxd index 6122b961fb91c..16ff7645aa740 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/common.pxd +++ b/sklearn/ensemble/_hist_gradient_boosting/common.pxd @@ -1,4 +1,3 @@ -# cython: language_level=3 import numpy as np cimport numpy as np diff --git a/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx b/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx index 54cfdcc077dc7..e6cdafd2d46bf 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 """This module contains routines for building histograms.""" # Author: Nicolas Hug diff --git a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx index 232cf094876cb..3ba6d7a0ce9df 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx @@ -1,8 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 - """This module contains routines and data structures to: - Find the best possible split of a node. For a given node, a split is @@ -791,7 +786,6 @@ cdef class Splitter: split_info.sum_gradient_right, split_info.sum_hessian_right, lower_bound, upper_bound, self.l2_regularization) - @cython.initializedcheck(False) cdef void _find_best_bin_to_split_category( self, unsigned int feature_idx, diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx index 53aaa450c90ce..b2de7614fe499 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: language_level=3 """This module contains utility routines.""" # Author: Nicolas Hug diff --git a/sklearn/feature_extraction/_hashing_fast.pyx b/sklearn/feature_extraction/_hashing_fast.pyx index 3a3102444af98..722538fe166d3 100644 --- a/sklearn/feature_extraction/_hashing_fast.pyx +++ b/sklearn/feature_extraction/_hashing_fast.pyx @@ -1,7 +1,5 @@ # Author: Lars Buitinck # License: BSD 3 clause -# -# cython: boundscheck=False, cdivision=True import sys import array @@ -92,7 +90,7 @@ def transform(raw_X, Py_ssize_t n_features, dtype, indices_a = np.frombuffer(indices, dtype=np.int32) indptr_a = np.frombuffer(indptr, dtype=indices_np_dtype) - if indptr[-1] > np.iinfo(np.int32).max: # = 2**31 - 1 + if indptr[len(indptr) - 1] > np.iinfo(np.int32).max: # = 2**31 - 1 # both indices and indptr have the same dtype in CSR arrays indices_a = indices_a.astype(np.int64, copy=False) else: diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 4841809ac7aa7..338ca0f0cae7e 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -5,8 +5,6 @@ # Manoj Kumar # # License: BSD 3 clause -# -# cython: boundscheck=False, wraparound=False, cdivision=True from libc.math cimport fabs cimport numpy as np diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index b6493f5f32f96..756a048eea999 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -27,10 +27,6 @@ dtypes = [('64', 'double', 'np.float64'), #------------------------------------------------------------------------------ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# # Authors: Danny Sullivan # Tom Dupre la Tour # Arthur Mensch # Mathieu Blondel (partial_fit support) # Rob Zinkov (passive-aggressive) diff --git a/sklearn/manifold/_barnes_hut_tsne.pyx b/sklearn/manifold/_barnes_hut_tsne.pyx index 936a74373e735..2d314c0ccf3a5 100644 --- a/sklearn/manifold/_barnes_hut_tsne.pyx +++ b/sklearn/manifold/_barnes_hut_tsne.pyx @@ -1,7 +1,3 @@ -# cython: boundscheck=False -# cython: wraparound=False -# cython: cdivision=True -# # Author: Christopher Moody # Author: Nick Travers # Implementation by Chris Moody & Nick Travers diff --git a/sklearn/manifold/_utils.pyx b/sklearn/manifold/_utils.pyx index cd6ade795ae91..985aa3388d34c 100644 --- a/sklearn/manifold/_utils.pyx +++ b/sklearn/manifold/_utils.pyx @@ -1,5 +1,3 @@ -# cython: boundscheck=False - from libc cimport math cimport cython import numpy as np diff --git a/sklearn/metrics/_dist_metrics.pxd b/sklearn/metrics/_dist_metrics.pxd index 61bb4fb2fe011..611f6759e2c8b 100644 --- a/sklearn/metrics/_dist_metrics.pxd +++ b/sklearn/metrics/_dist_metrics.pxd @@ -1,8 +1,3 @@ -# cython: boundscheck=False -# cython: cdivision=True -# cython: initializedcheck=False -# cython: wraparound=False - cimport numpy as np from libc.math cimport sqrt, exp diff --git a/sklearn/metrics/_dist_metrics.pyx b/sklearn/metrics/_dist_metrics.pyx index a8fb4c45ddd0c..6bb279012e518 100644 --- a/sklearn/metrics/_dist_metrics.pyx +++ b/sklearn/metrics/_dist_metrics.pyx @@ -1,8 +1,3 @@ -# cython: boundscheck=False -# cython: cdivision=True -# cython: initializedcheck=False -# cython: wraparound=False - # By Jake Vanderplas (2013) # written for the scikit-learn project # License: BSD diff --git a/sklearn/metrics/cluster/_expected_mutual_info_fast.pyx b/sklearn/metrics/cluster/_expected_mutual_info_fast.pyx index d2f9cd8578b12..fbc910cb23b8c 100644 --- a/sklearn/metrics/cluster/_expected_mutual_info_fast.pyx +++ b/sklearn/metrics/cluster/_expected_mutual_info_fast.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# # Authors: Robert Layton # Corey Lynch # License: BSD 3 clause diff --git a/sklearn/neighbors/_quad_tree.pxd b/sklearn/neighbors/_quad_tree.pxd index 6f61b60cc0ab3..7287d5c420ca7 100644 --- a/sklearn/neighbors/_quad_tree.pxd +++ b/sklearn/neighbors/_quad_tree.pxd @@ -1,7 +1,3 @@ -# cython: boundscheck=False -# cython: wraparound=False -# cython: cdivision=True -# # Author: Thomas Moreau # Author: Olivier Grisel diff --git a/sklearn/neighbors/_quad_tree.pyx b/sklearn/neighbors/_quad_tree.pyx index 619467e69dd0c..6af7d1f547303 100644 --- a/sklearn/neighbors/_quad_tree.pyx +++ b/sklearn/neighbors/_quad_tree.pyx @@ -1,7 +1,3 @@ -# cython: boundscheck=False -# cython: wraparound=False -# cython: cdivision=True -# # Author: Thomas Moreau # Author: Olivier Grisel diff --git a/sklearn/preprocessing/_csr_polynomial_expansion.pyx b/sklearn/preprocessing/_csr_polynomial_expansion.pyx index 84fef3f042dc7..ef958b12266e1 100644 --- a/sklearn/preprocessing/_csr_polynomial_expansion.pyx +++ b/sklearn/preprocessing/_csr_polynomial_expansion.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Author: Andrew nystrom from scipy.sparse import csr_matrix diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 9488bda4ccf58..9186f0fcf7e29 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -18,10 +18,7 @@ where no sort of memory checks are done. Notes ----- -Maybe we could speed it a bit further by decorating functions with -@cython.boundscheck(False), but probably it is not worth since all -work is done in lisvm_helper.c -Also, the signature mode='c' is somewhat superficial, since we already +The signature mode='c' is somewhat superficial, since we already check that arrays are C-contiguous in svm.py Authors diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index db8a3cb821df3..2c115d0bd6ea1 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index d7b5e81c7b8f7..35ce58dce26ac 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index c8c58f12ffd3a..84fb808318a49 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index e6552debd3149..b80e7825ee6ab 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -1,7 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - # Authors: Gilles Louppe # Peter Prettenhofer # Arnaud Joly diff --git a/sklearn/utils/_fast_dict.pyx b/sklearn/utils/_fast_dict.pyx index 719cafc3cc8c1..2bbf2dcfa667c 100644 --- a/sklearn/utils/_fast_dict.pyx +++ b/sklearn/utils/_fast_dict.pyx @@ -38,8 +38,6 @@ np.import_array() cdef class IntFloatDict: - @cython.boundscheck(False) - @cython.wraparound(False) def __init__(self, np.ndarray[ITYPE_t, ndim=1] keys, np.ndarray[DTYPE_t, ndim=1] values): cdef int i diff --git a/sklearn/utils/_random.pyx b/sklearn/utils/_random.pyx index 3caf062079211..be8ef0752ddd3 100644 --- a/sklearn/utils/_random.pyx +++ b/sklearn/utils/_random.pyx @@ -1,6 +1,3 @@ -# cython: boundscheck=False -# cython: wraparound=False -# # Author: Arnaud Joly # # License: BSD 3 clause @@ -278,7 +275,7 @@ cpdef sample_without_replacement(np.int_t n_population, all_methods = ("auto", "tracking_selection", "reservoir_sampling", "pool") - ratio = n_samples / n_population if n_population != 0.0 else 1.0 + ratio = n_samples / n_population if n_population != 0.0 else 1.0 # Check ratio and use permutation unless ratio < 0.01 or ratio > 0.99 if method == "auto" and ratio > 0.01 and ratio < 0.99: diff --git a/sklearn/utils/_seq_dataset.pyx.tp b/sklearn/utils/_seq_dataset.pyx.tp index 8bc901194a24e..9115f80c5265d 100644 --- a/sklearn/utils/_seq_dataset.pyx.tp +++ b/sklearn/utils/_seq_dataset.pyx.tp @@ -1,6 +1,3 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False {{py: """ diff --git a/sklearn/utils/_weight_vector.pyx.tp b/sklearn/utils/_weight_vector.pyx.tp index 186f391f8a955..ca552a1dff29e 100644 --- a/sklearn/utils/_weight_vector.pyx.tp +++ b/sklearn/utils/_weight_vector.pyx.tp @@ -18,11 +18,6 @@ dtypes = [('64', 'double', 1e-9), }} -# cython: language_level=3 -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: initializedcheck=False # cython: binding=False # # Author: Peter Prettenhofer diff --git a/sklearn/utils/murmurhash.pyx b/sklearn/utils/murmurhash.pyx index 0bce17737a090..dc9c3da08906f 100644 --- a/sklearn/utils/murmurhash.pyx +++ b/sklearn/utils/murmurhash.pyx @@ -55,7 +55,6 @@ cpdef np.int32_t murmurhash3_bytes_s32(bytes key, unsigned int seed): return out -@cython.boundscheck(False) cpdef np.ndarray[np.uint32_t, ndim=1] murmurhash3_bytes_array_u32( np.ndarray[np.int32_t] key, unsigned int seed): """Compute 32bit murmurhash3 hashes of a key int array at seed.""" @@ -67,7 +66,6 @@ cpdef np.ndarray[np.uint32_t, ndim=1] murmurhash3_bytes_array_u32( return out -@cython.boundscheck(False) cpdef np.ndarray[np.int32_t, ndim=1] murmurhash3_bytes_array_s32( np.ndarray[np.int32_t] key, unsigned int seed): """Compute 32bit murmurhash3 hashes of a key int array at seed.""" diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx index 09677600cbbe4..ee12730d02b2d 100644 --- a/sklearn/utils/sparsefuncs_fast.pyx +++ b/sklearn/utils/sparsefuncs_fast.pyx @@ -7,7 +7,6 @@ # License: BSD 3 clause #!python -# cython: boundscheck=False, wraparound=False, cdivision=True from libc.math cimport fabs, sqrt, pow cimport numpy as np From ba20b7ca86b53aadf6a554f1d638040fe717aa88 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Nov 2021 09:56:22 +0100 Subject: [PATCH 3/3] missing files --- sklearn/metrics/_pairwise_fast.pyx | 4 ---- sklearn/neighbors/_ball_tree.pyx | 6 ------ sklearn/neighbors/_kd_tree.pyx | 6 ------ sklearn/utils/_logistic_sigmoid.pyx | 4 ---- 4 files changed, 20 deletions(-) diff --git a/sklearn/metrics/_pairwise_fast.pyx b/sklearn/metrics/_pairwise_fast.pyx index f122972a15f89..76973529de818 100644 --- a/sklearn/metrics/_pairwise_fast.pyx +++ b/sklearn/metrics/_pairwise_fast.pyx @@ -1,7 +1,3 @@ -#cython: boundscheck=False -#cython: cdivision=True -#cython: wraparound=False -# # Author: Andreas Mueller # Lars Buitinck # Paolo Toccaceli diff --git a/sklearn/neighbors/_ball_tree.pyx b/sklearn/neighbors/_ball_tree.pyx index f8f1bd9e95f96..b5ac18365631a 100644 --- a/sklearn/neighbors/_ball_tree.pyx +++ b/sklearn/neighbors/_ball_tree.pyx @@ -1,9 +1,3 @@ -#!python -#cython: boundscheck=False -#cython: wraparound=False -#cython: cdivision=True -#cython: initializedcheck=False - # Author: Jake Vanderplas # License: BSD 3 clause diff --git a/sklearn/neighbors/_kd_tree.pyx b/sklearn/neighbors/_kd_tree.pyx index 5cdc071c38250..59199c41f2e85 100644 --- a/sklearn/neighbors/_kd_tree.pyx +++ b/sklearn/neighbors/_kd_tree.pyx @@ -1,9 +1,3 @@ -#!python -#cython: boundscheck=False -#cython: wraparound=False -#cython: cdivision=True -#cython: initializedcheck=False - # By Jake Vanderplas (2013) # written for the scikit-learn project # License: BSD diff --git a/sklearn/utils/_logistic_sigmoid.pyx b/sklearn/utils/_logistic_sigmoid.pyx index 3531d99bc4f44..c2ba685dbfcbd 100644 --- a/sklearn/utils/_logistic_sigmoid.pyx +++ b/sklearn/utils/_logistic_sigmoid.pyx @@ -1,7 +1,3 @@ -#cython: boundscheck=False -#cython: cdivision=True -#cython: wraparound=False - from libc.math cimport log, exp import numpy as np