Skip to content

MAINT Refactor vector sentinel into utils #22728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 1 addition & 65 deletions sklearn/metrics/_pairwise_distances_reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ from ..utils._heap cimport heap_push
from ..utils._sorting cimport simultaneous_sort
from ..utils._openmp_helpers cimport _openmp_thread_num
from ..utils._typedefs cimport ITYPE_t, DTYPE_t
from ..utils._typedefs cimport ITYPECODE, DTYPECODE
from ..utils._vector_sentinel cimport vector_to_nd_array

from numbers import Integral, Real
from typing import List
Expand Down Expand Up @@ -76,70 +76,6 @@ ctypedef fused vector_vector_DITYPE_t:
vector[vector[DTYPE_t]]


cdef class StdVectorSentinel:
"""Wraps a reference to a vector which will be deallocated with this object.

When created, the StdVectorSentinel swaps the reference of its internal
vectors with the provided one (vec_ptr), thus making the StdVectorSentinel
manage the provided one's lifetime.
"""
pass


# We necessarily need to define two extension types extending StdVectorSentinel
# because we need to provide the dtype of the vector but can't use numeric fused types.
cdef class StdVectorSentinelDTYPE(StdVectorSentinel):
cdef vector[DTYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[DTYPE_t] * vec_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelDTYPE sentinel = StdVectorSentinelDTYPE.__new__(StdVectorSentinelDTYPE)
sentinel.vec.swap(deref(vec_ptr))
return sentinel


cdef class StdVectorSentinelITYPE(StdVectorSentinel):
cdef vector[ITYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[ITYPE_t] * vec_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelITYPE sentinel = StdVectorSentinelITYPE.__new__(StdVectorSentinelITYPE)
sentinel.vec.swap(deref(vec_ptr))
return sentinel


cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr):
"""Create a numpy ndarray given a C++ vector.

The numpy array buffer is the one of the C++ vector.
A StdVectorSentinel is registered as the base object for the numpy array,
freeing the C++ vector it encapsulates when the numpy array is freed.
"""
typenum = DTYPECODE if vector_DITYPE_t is vector[DTYPE_t] else ITYPECODE
cdef:
np.npy_intp size = deref(vect_ptr).size()
np.ndarray arr = np.PyArray_SimpleNewFromData(1, &size, typenum,
deref(vect_ptr).data())
StdVectorSentinel sentinel

if vector_DITYPE_t is vector[DTYPE_t]:
sentinel = StdVectorSentinelDTYPE.create_for(vect_ptr)
else:
sentinel = StdVectorSentinelITYPE.create_for(vect_ptr)

# Makes the numpy array responsible of the life-cycle of its buffer.
# A reference to the StdVectorSentinel will be stolen by the call to
# `PyArray_SetBaseObject` below, so we increase its reference counter.
# See: https://docs.python.org/3/c-api/intro.html#reference-count-details
Py_INCREF(sentinel)
np.PyArray_SetBaseObject(arr, sentinel)
return arr


cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays(
shared_ptr[vector_vector_DITYPE_t] vecs
):
Expand Down
10 changes: 10 additions & 0 deletions sklearn/utils/_vector_sentinel.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
cimport numpy as np

from libcpp.vector cimport vector
from ..utils._typedefs cimport ITYPE_t, DTYPE_t

ctypedef fused vector_typed:
vector[DTYPE_t]
vector[ITYPE_t]

cdef np.ndarray vector_to_nd_array(vector_typed * vect_ptr)
80 changes: 80 additions & 0 deletions sklearn/utils/_vector_sentinel.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from cython.operator cimport dereference as deref
from cpython.ref cimport Py_INCREF
cimport numpy as np

from ._typedefs cimport DTYPECODE, ITYPECODE

np.import_array()


cdef StdVectorSentinel _create_sentinel(vector_typed * vect_ptr):
if vector_typed is vector[DTYPE_t]:
return StdVectorSentinelFloat64.create_for(vect_ptr)
else:
return StdVectorSentinelIntP.create_for(vect_ptr)


cdef class StdVectorSentinel:
"""Wraps a reference to a vector which will be deallocated with this object.

When created, the StdVectorSentinel swaps the reference of its internal
vectors with the provided one (vect_ptr), thus making the StdVectorSentinel
manage the provided one's lifetime.
"""
cdef void* get_data(self):
"""Return pointer to data."""

cdef int get_typenum(self):
"""Get typenum for PyArray_SimpleNewFromData."""


cdef class StdVectorSentinelFloat64(StdVectorSentinel):
cdef vector[DTYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[DTYPE_t] * vect_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelFloat64 sentinel = StdVectorSentinelFloat64.__new__(StdVectorSentinelFloat64)
sentinel.vec.swap(deref(vect_ptr))
return sentinel

cdef void* get_data(self):
return self.vec.data()

cdef int get_typenum(self):
return DTYPECODE


cdef class StdVectorSentinelIntP(StdVectorSentinel):
cdef vector[ITYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[ITYPE_t] * vect_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelIntP sentinel = StdVectorSentinelIntP.__new__(StdVectorSentinelIntP)
sentinel.vec.swap(deref(vect_ptr))
return sentinel

cdef void* get_data(self):
return self.vec.data()

cdef int get_typenum(self):
return ITYPECODE


cdef np.ndarray vector_to_nd_array(vector_typed * vect_ptr):
cdef:
np.npy_intp size = deref(vect_ptr).size()
StdVectorSentinel sentinel = _create_sentinel(vect_ptr)
np.ndarray arr = np.PyArray_SimpleNewFromData(
1, &size, sentinel.get_typenum(), sentinel.get_data())
Comment on lines +70 to +72
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is slightly more clear than the current implementation. The get_data makes it clear that arr points to the data owned by the sentinel.

The implementation on main defines arr with the buffer from vect_ptr and then the sentinel would set the internal pointer in sentinel.vec to vect_ptr. Because only the pointers were swapped, arr is pointing to the correct place in memory.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.


# Makes the numpy array responsible of the life-cycle of its buffer.
# A reference to the StdVectorSentinel will be stolen by the call to
# `PyArray_SetBaseObject` below, so we increase its reference counter.
# See: https://docs.python.org/3/c-api/intro.html#reference-count-details
Py_INCREF(sentinel)
np.PyArray_SetBaseObject(arr, sentinel)
return arr
8 changes: 8 additions & 0 deletions sklearn/utils/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ def configuration(parent_package="", top_path=None):
libraries=libraries,
)

config.add_extension(
"_vector_sentinel",
sources=["_vector_sentinel.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
language="c++",
)

config.add_subpackage("tests")

return config
Expand Down