diff --git a/numpy/__init__.cython-30.pxd b/numpy/__init__.cython-30.pxd index 0270f0ee988f..28e34ba12cec 100644 --- a/numpy/__init__.cython-30.pxd +++ b/numpy/__init__.cython-30.pxd @@ -740,7 +740,7 @@ cdef extern from "numpy/arrayobject.h": object PyArray_Arange (double, double, double, int) #object PyArray_ArangeObj (object, object, object, dtype) int PyArray_SortkindConverter (object, NPY_SORTKIND *) except 0 - object PyArray_LexSort (object, int) + object PyArray_LexSort (object, int, NPY_SORTKIND) object PyArray_Round (ndarray, int, ndarray) unsigned char PyArray_EquivTypenums (int, int) int PyArray_RegisterDataType (dtype) except -1 diff --git a/numpy/__init__.pxd b/numpy/__init__.pxd index aebb71fffa9c..b8f6a9d8cd89 100644 --- a/numpy/__init__.pxd +++ b/numpy/__init__.pxd @@ -655,7 +655,7 @@ cdef extern from "numpy/arrayobject.h": object PyArray_Arange (double, double, double, int) #object PyArray_ArangeObj (object, object, object, dtype) int PyArray_SortkindConverter (object, NPY_SORTKIND *) except 0 - object PyArray_LexSort (object, int) + object PyArray_LexSort (object, int, NPY_SORTKIND) object PyArray_Round (ndarray, int, ndarray) unsigned char PyArray_EquivTypenums (int, int) int PyArray_RegisterDataType (dtype) except -1 diff --git a/numpy/_core/multiarray.py b/numpy/_core/multiarray.py index 77e249a85828..3e77ca2dba40 100644 --- a/numpy/_core/multiarray.py +++ b/numpy/_core/multiarray.py @@ -434,11 +434,12 @@ def where(condition, x=None, y=None): @array_function_from_c_func_and_dispatcher(_multiarray_umath.lexsort) -def lexsort(keys, axis=None): +def lexsort(keys, axis=None, kind=None): """ lexsort(keys, axis=-1) - Perform an indirect stable sort using a sequence of keys. + Perform an indirect sort (stable by default - see "kind" parameter) using a + sequence of keys. Given multiple sorting keys, lexsort returns an array of integer indices that describes the sort order by multiple keys. The last key in the @@ -456,6 +457,8 @@ def lexsort(keys, axis=None): Axis to be indirectly sorted. By default, sort over the last axis of each sequence. Separate slices along `axis` sorted over independently; see last example. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + Sorting algorithm. The default is 'stable'. See argsort for details. Returns ------- diff --git a/numpy/_core/multiarray.pyi b/numpy/_core/multiarray.pyi index 74cc86e64e79..97420d0629e6 100644 --- a/numpy/_core/multiarray.pyi +++ b/numpy/_core/multiarray.pyi @@ -42,6 +42,7 @@ from numpy import ( _OrderCF, _CastingKind, _ModeKind, + _SortKind, _SupportsBuffer, _IOProtocol, _CopyMode, @@ -390,6 +391,7 @@ def where( def lexsort( keys: ArrayLike, axis: None | SupportsIndex = ..., + kind: None | _SortKind = ..., ) -> Any: ... def can_cast( diff --git a/numpy/_core/src/multiarray/item_selection.c b/numpy/_core/src/multiarray/item_selection.c index 656688bda2fc..09a7dd1d8077 100644 --- a/numpy/_core/src/multiarray/item_selection.c +++ b/numpy/_core/src/multiarray/item_selection.c @@ -1793,7 +1793,7 @@ PyArray_ArgPartition(PyArrayObject *op, PyArrayObject *ktharray, int axis, *the given axis. */ NPY_NO_EXPORT PyObject * -PyArray_LexSort(PyObject *sort_keys, int axis) +PyArray_LexSort(PyObject *sort_keys, int axis, NPY_SORTKIND which) { PyArrayObject **mps; PyArrayIterObject **its; @@ -1849,7 +1849,7 @@ PyArray_LexSort(PyObject *sort_keys, int axis) goto fail; } } - if (!PyDataType_GetArrFuncs(PyArray_DESCR(mps[i]))->argsort[NPY_STABLESORT] + if (!PyDataType_GetArrFuncs(PyArray_DESCR(mps[i]))->argsort[which] && !PyDataType_GetArrFuncs(PyArray_DESCR(mps[i]))->compare) { PyErr_Format(PyExc_TypeError, "item %zd type does not have compare function", i); @@ -1966,9 +1966,22 @@ PyArray_LexSort(PyObject *sort_keys, int axis) int rcode; elsize = PyArray_ITEMSIZE(mps[j]); astride = PyArray_STRIDES(mps[j])[axis]; - argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[NPY_STABLESORT]; - if(argsort == NULL) { - argsort = npy_atimsort; + argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[which]; + if (argsort == NULL) { + if (PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->compare) { + switch (which) { + default: + case NPY_QUICKSORT: + argsort = npy_aquicksort; + break; + case NPY_HEAPSORT: + argsort = npy_aheapsort; + break; + case NPY_STABLESORT: + argsort = npy_atimsort; + break; + } + } } _unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize, its[j]->dataptr, astride, N, elsize); @@ -2001,9 +2014,22 @@ PyArray_LexSort(PyObject *sort_keys, int axis) } for (j = 0; j < n; j++) { int rcode; - argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[NPY_STABLESORT]; - if(argsort == NULL) { - argsort = npy_atimsort; + argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[which]; + if (argsort == NULL) { + if (PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->compare) { + switch (which) { + default: + case NPY_QUICKSORT: + argsort = npy_aquicksort; + break; + case NPY_HEAPSORT: + argsort = npy_aheapsort; + break; + case NPY_STABLESORT: + argsort = npy_atimsort; + break; + } + } } rcode = argsort(its[j]->dataptr, (npy_intp *)rit->dataptr, N, mps[j]); diff --git a/numpy/_core/src/multiarray/multiarraymodule.c b/numpy/_core/src/multiarray/multiarraymodule.c index 4946465617bc..6abcc5d13600 100644 --- a/numpy/_core/src/multiarray/multiarraymodule.c +++ b/numpy/_core/src/multiarray/multiarraymodule.c @@ -3471,16 +3471,23 @@ array_lexsort(PyObject *NPY_UNUSED(ignored), PyObject *const *args, Py_ssize_t l PyObject *kwnames) { int axis = -1; + NPY_SORTKIND sortkind = _NPY_SORT_UNDEFINED; PyObject *obj; NPY_PREPARE_ARGPARSER; if (npy_parse_arguments("lexsort", args, len_args, kwnames, "keys", NULL, &obj, "|axis", PyArray_PythonPyIntFromInt, &axis, + "|kind", &PyArray_SortkindConverter, &sortkind, NULL, NULL, NULL) < 0) { return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_LexSort(obj, axis)); + + if (sortkind == _NPY_SORT_UNDEFINED) { + sortkind = NPY_STABLESORT; + } + + return PyArray_Return((PyArrayObject *)PyArray_LexSort(obj, axis, sortkind)); } static PyObject *