From cd47a3fc7e326ba0a7c8a0654644e77a288538e8 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sat, 12 Feb 2022 10:51:35 -0600 Subject: [PATCH 01/17] ENH: Implement string comparison ufuncs (or almost) This makes all comparison operators and ufuncs work on strings using the ufunc machinery. It requires a half-manual "ufunc" to keep supporting void comparisons and especially `np.compare_chararrays` (that one may have a bit more overhead now). In general the new code should be much faster, and has a lot of easier optimization potential. It is also much simpler since it can outsource some complexities to the ufunc/iterator machinery. This further fixes a couple of bugs with byte-swapped strings. The backward compatibility related change is that using the normal ufunc machinery means that string comparisons between string and unicode now give a `FutureWarning` (instead of just False). --- .../include/numpy/experimental_dtype_api.h | 2 +- numpy/core/setup.py | 1 + numpy/core/src/common/numpyos.h | 8 + numpy/core/src/multiarray/array_method.h | 9 +- numpy/core/src/multiarray/arrayobject.c | 402 +--------------- numpy/core/src/multiarray/common_dtype.h | 8 + numpy/core/src/multiarray/convert_datatype.h | 12 +- numpy/core/src/multiarray/dtypemeta.h | 7 + .../experimental_public_dtype_api.c | 32 +- numpy/core/src/multiarray/multiarraymodule.c | 12 +- numpy/core/src/umath/dispatching.c | 32 ++ numpy/core/src/umath/dispatching.h | 9 + numpy/core/src/umath/string_ufuncs.cpp | 451 ++++++++++++++++++ numpy/core/src/umath/umathmodule.c | 10 + numpy/core/tests/test_deprecations.py | 2 +- numpy/core/tests/test_unicode.py | 9 +- 16 files changed, 572 insertions(+), 434 deletions(-) create mode 100644 numpy/core/src/umath/string_ufuncs.cpp diff --git a/numpy/core/include/numpy/experimental_dtype_api.h b/numpy/core/include/numpy/experimental_dtype_api.h index 1dd6215e6221..23e9a8d2160c 100644 --- a/numpy/core/include/numpy/experimental_dtype_api.h +++ b/numpy/core/include/numpy/experimental_dtype_api.h @@ -214,7 +214,7 @@ typedef struct { } PyArrayMethod_Spec; -typedef PyObject *_ufunc_addloop_fromspec_func( +typedef int _ufunc_addloop_fromspec_func( PyObject *ufunc, PyArrayMethod_Spec *spec); /* * The main ufunc registration function. This adds a new implementation/loop diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 6454e641c4a4..a89faafea3f0 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -1082,6 +1082,7 @@ def generate_umath_doc_header(ext, build_dir): join('src', 'umath', 'scalarmath.c.src'), join('src', 'umath', 'ufunc_type_resolution.c'), join('src', 'umath', 'override.c'), + join('src', 'umath', 'string_ufuncs.cpp'), # For testing. Eventually, should use public API and be separate: join('src', 'umath', '_scaled_float_dtype.c'), ] diff --git a/numpy/core/src/common/numpyos.h b/numpy/core/src/common/numpyos.h index ce49cbea7f6e..6e526af17899 100644 --- a/numpy/core/src/common/numpyos.h +++ b/numpy/core/src/common/numpyos.h @@ -1,6 +1,10 @@ #ifndef NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ #define NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ +#ifdef __cplusplus +extern "C" { +#endif + NPY_NO_EXPORT char* NumPyOS_ascii_formatd(char *buffer, size_t buf_size, const char *format, @@ -39,4 +43,8 @@ NumPyOS_strtoll(const char *str, char **endptr, int base); NPY_NO_EXPORT npy_ulonglong NumPyOS_strtoull(const char *str, char **endptr, int base); +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ */ diff --git a/numpy/core/src/multiarray/array_method.h b/numpy/core/src/multiarray/array_method.h index 30dd94a80b7d..6e6a026bc963 100644 --- a/numpy/core/src/multiarray/array_method.h +++ b/numpy/core/src/multiarray/array_method.h @@ -7,6 +7,9 @@ #include #include +#ifdef __cplusplus +extern "C" { +#endif typedef enum { /* Flag for whether the GIL is required */ @@ -249,6 +252,10 @@ PyArrayMethod_FromSpec(PyArrayMethod_Spec *spec); * need better tests when a public version is exposed. */ NPY_NO_EXPORT PyBoundArrayMethodObject * -PyArrayMethod_FromSpec_int(PyArrayMethod_Spec *spec, int private); +PyArrayMethod_FromSpec_int(PyArrayMethod_Spec *spec, int priv); + +#ifdef __cplusplus +} +#endif #endif /* NUMPY_CORE_SRC_MULTIARRAY_ARRAY_METHOD_H_ */ diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index 4c20fc1619b5..def2751dcbfa 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -645,375 +645,11 @@ PyArray_FailUnlessWriteable(PyArrayObject *obj, const char *name) return 0; } -/* This also handles possibly mis-aligned data */ -/* Compare s1 and s2 which are not necessarily NULL-terminated. - s1 is of length len1 - s2 is of length len2 - If they are NULL terminated, then stop comparison. -*/ -static int -_myunincmp(npy_ucs4 const *s1, npy_ucs4 const *s2, int len1, int len2) -{ - npy_ucs4 const *sptr; - npy_ucs4 *s1t = NULL; - npy_ucs4 *s2t = NULL; - int val; - npy_intp size; - int diff; - - /* Replace `s1` and `s2` with aligned copies if needed */ - if ((npy_intp)s1 % sizeof(npy_ucs4) != 0) { - size = len1*sizeof(npy_ucs4); - s1t = malloc(size); - memcpy(s1t, s1, size); - s1 = s1t; - } - if ((npy_intp)s2 % sizeof(npy_ucs4) != 0) { - size = len2*sizeof(npy_ucs4); - s2t = malloc(size); - memcpy(s2t, s2, size); - s2 = s1t; - } - - val = PyArray_CompareUCS4(s1, s2, PyArray_MIN(len1,len2)); - if ((val != 0) || (len1 == len2)) { - goto finish; - } - if (len2 > len1) { - sptr = s2+len1; - val = -1; - diff = len2-len1; - } - else { - sptr = s1+len2; - val = 1; - diff=len1-len2; - } - while (diff--) { - if (*sptr != 0) { - goto finish; - } - sptr++; - } - val = 0; - - finish: - /* Cleanup the aligned copies */ - if (s1t) { - free(s1t); - } - if (s2t) { - free(s2t); - } - return val; -} - - - - -/* - * Compare s1 and s2 which are not necessarily NULL-terminated. - * s1 is of length len1 - * s2 is of length len2 - * If they are NULL terminated, then stop comparison. - */ -static int -_mystrncmp(char const *s1, char const *s2, int len1, int len2) -{ - char const *sptr; - int val; - int diff; - - val = memcmp(s1, s2, PyArray_MIN(len1, len2)); - if ((val != 0) || (len1 == len2)) { - return val; - } - if (len2 > len1) { - sptr = s2 + len1; - val = -1; - diff = len2 - len1; - } - else { - sptr = s1 + len2; - val = 1; - diff = len1 - len2; - } - while (diff--) { - if (*sptr != 0) { - return val; - } - sptr++; - } - return 0; /* Only happens if NULLs are everywhere */ -} - -/* Borrowed from Numarray */ - -#define SMALL_STRING 2048 - -static void _rstripw(char *s, int n) -{ - int i; - for (i = n - 1; i >= 1; i--) { /* Never strip to length 0. */ - int c = s[i]; - - if (!c || NumPyOS_ascii_isspace((int)c)) { - s[i] = 0; - } - else { - break; - } - } -} - -static void _unistripw(npy_ucs4 *s, int n) -{ - int i; - for (i = n - 1; i >= 1; i--) { /* Never strip to length 0. */ - npy_ucs4 c = s[i]; - if (!c || NumPyOS_ascii_isspace((int)c)) { - s[i] = 0; - } - else { - break; - } - } -} - - -static char * -_char_copy_n_strip(char const *original, char *temp, int nc) -{ - if (nc > SMALL_STRING) { - temp = malloc(nc); - if (!temp) { - PyErr_NoMemory(); - return NULL; - } - } - memcpy(temp, original, nc); - _rstripw(temp, nc); - return temp; -} - -static void -_char_release(char *ptr, int nc) -{ - if (nc > SMALL_STRING) { - free(ptr); - } -} - -static char * -_uni_copy_n_strip(char const *original, char *temp, int nc) -{ - if (nc*sizeof(npy_ucs4) > SMALL_STRING) { - temp = malloc(nc*sizeof(npy_ucs4)); - if (!temp) { - PyErr_NoMemory(); - return NULL; - } - } - memcpy(temp, original, nc*sizeof(npy_ucs4)); - _unistripw((npy_ucs4 *)temp, nc); - return temp; -} - -static void -_uni_release(char *ptr, int nc) -{ - if (nc*sizeof(npy_ucs4) > SMALL_STRING) { - free(ptr); - } -} - - -/* End borrowed from numarray */ - -#define _rstrip_loop(CMP) { \ - void *aptr, *bptr; \ - char atemp[SMALL_STRING], btemp[SMALL_STRING]; \ - while(size--) { \ - aptr = stripfunc(iself->dataptr, atemp, N1); \ - if (!aptr) return -1; \ - bptr = stripfunc(iother->dataptr, btemp, N2); \ - if (!bptr) { \ - relfunc(aptr, N1); \ - return -1; \ - } \ - val = compfunc(aptr, bptr, N1, N2); \ - *dptr = (val CMP 0); \ - PyArray_ITER_NEXT(iself); \ - PyArray_ITER_NEXT(iother); \ - dptr += 1; \ - relfunc(aptr, N1); \ - relfunc(bptr, N2); \ - } \ - } - -#define _reg_loop(CMP) { \ - while(size--) { \ - val = compfunc((void *)iself->dataptr, \ - (void *)iother->dataptr, \ - N1, N2); \ - *dptr = (val CMP 0); \ - PyArray_ITER_NEXT(iself); \ - PyArray_ITER_NEXT(iother); \ - dptr += 1; \ - } \ - } - -static int -_compare_strings(PyArrayObject *result, PyArrayMultiIterObject *multi, - int cmp_op, void *func, int rstrip) -{ - PyArrayIterObject *iself, *iother; - npy_bool *dptr; - npy_intp size; - int val; - int N1, N2; - int (*compfunc)(void *, void *, int, int); - void (*relfunc)(char *, int); - char* (*stripfunc)(char const *, char *, int); - - compfunc = func; - dptr = (npy_bool *)PyArray_DATA(result); - iself = multi->iters[0]; - iother = multi->iters[1]; - size = multi->size; - N1 = PyArray_DESCR(iself->ao)->elsize; - N2 = PyArray_DESCR(iother->ao)->elsize; - if ((void *)compfunc == (void *)_myunincmp) { - N1 >>= 2; - N2 >>= 2; - stripfunc = _uni_copy_n_strip; - relfunc = _uni_release; - } - else { - stripfunc = _char_copy_n_strip; - relfunc = _char_release; - } - switch (cmp_op) { - case Py_EQ: - if (rstrip) { - _rstrip_loop(==); - } else { - _reg_loop(==); - } - break; - case Py_NE: - if (rstrip) { - _rstrip_loop(!=); - } else { - _reg_loop(!=); - } - break; - case Py_LT: - if (rstrip) { - _rstrip_loop(<); - } else { - _reg_loop(<); - } - break; - case Py_LE: - if (rstrip) { - _rstrip_loop(<=); - } else { - _reg_loop(<=); - } - break; - case Py_GT: - if (rstrip) { - _rstrip_loop(>); - } else { - _reg_loop(>); - } - break; - case Py_GE: - if (rstrip) { - _rstrip_loop(>=); - } else { - _reg_loop(>=); - } - break; - default: - PyErr_SetString(PyExc_RuntimeError, "bad comparison operator"); - return -1; - } - return 0; -} - -#undef _reg_loop -#undef _rstrip_loop -#undef SMALL_STRING +/* From umath/string_ufuncs.cpp/h */ NPY_NO_EXPORT PyObject * -_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, - int rstrip) -{ - PyArrayObject *result; - PyArrayMultiIterObject *mit; - int val; - - if (PyArray_TYPE(self) != PyArray_TYPE(other)) { - /* - * Comparison between Bytes and Unicode is not defined in Py3K; - * we follow. - */ - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - if (PyArray_ISNOTSWAPPED(self) != PyArray_ISNOTSWAPPED(other)) { - /* Cast `other` to the same byte order as `self` (both unicode here) */ - PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self)); - if (unicode == NULL) { - return NULL; - } - unicode->elsize = PyArray_DESCR(other)->elsize; - PyObject *new = PyArray_FromAny((PyObject *)other, - unicode, 0, 0, 0, NULL); - if (new == NULL) { - return NULL; - } - other = (PyArrayObject *)new; - } - else { - Py_INCREF(other); - } - - /* Broad-cast the arrays to a common shape */ - mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other); - Py_DECREF(other); - if (mit == NULL) { - return NULL; - } - - result = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type, - PyArray_DescrFromType(NPY_BOOL), - mit->nd, - mit->dimensions, - NULL, NULL, 0, - NULL); - if (result == NULL) { - goto finish; - } - - if (PyArray_TYPE(self) == NPY_UNICODE) { - val = _compare_strings(result, mit, cmp_op, _myunincmp, rstrip); - } - else { - val = _compare_strings(result, mit, cmp_op, _mystrncmp, rstrip); - } - - if (val < 0) { - Py_DECREF(result); - result = NULL; - } - - finish: - Py_DECREF(mit); - return (PyObject *)result; -} +_umath_strings_richcompare( + PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip); /* * VOID-type arrays can only be compared equal and not-equal @@ -1144,7 +780,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op) } else { /* compare as a string. Assumes self and other have same descr->type */ - return _strings_richcompare(self, other, cmp_op, 0); + return _umath_strings_richcompare(self, other, cmp_op, 0); } } @@ -1278,36 +914,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) PyObject *obj_self = (PyObject *)self; PyObject *result = NULL; - /* Special case for string arrays (which don't and currently can't have - * ufunc loops defined, so there's no point in trying). - */ - if (PyArray_ISSTRING(self)) { - array_other = (PyArrayObject *)PyArray_FromObject(other, - NPY_NOTYPE, 0, 0); - if (array_other == NULL) { - PyErr_Clear(); - /* Never mind, carry on, see what happens */ - } - else if (!PyArray_ISSTRING(array_other)) { - Py_DECREF(array_other); - /* Never mind, carry on, see what happens */ - } - else { - result = _strings_richcompare(self, array_other, cmp_op, 0); - Py_DECREF(array_other); - return result; - } - /* If we reach this point, it means that we are not comparing - * string-to-string. It's possible that this will still work out, - * e.g. if the other array is an object array, then both will be cast - * to object or something? I don't know how that works actually, but - * it does, b/c this works: - * l = ["a", "b"] - * assert np.array(l, dtype="S1") == np.array(l, dtype="O") - * So we fall through and see what happens. - */ - } - switch (cmp_op) { case Py_LT: RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other); diff --git a/numpy/core/src/multiarray/common_dtype.h b/numpy/core/src/multiarray/common_dtype.h index 13d38ddf816a..9f25fc14ee3c 100644 --- a/numpy/core/src/multiarray/common_dtype.h +++ b/numpy/core/src/multiarray/common_dtype.h @@ -7,6 +7,10 @@ #include #include "dtypemeta.h" +#ifdef __cplusplus +extern "C" { +#endif + NPY_NO_EXPORT PyArray_DTypeMeta * PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2); @@ -14,4 +18,8 @@ NPY_NO_EXPORT PyArray_DTypeMeta * PyArray_PromoteDTypeSequence( npy_intp length, PyArray_DTypeMeta **dtypes_in); +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_MULTIARRAY_COMMON_DTYPE_H_ */ diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index d1865d1c247e..af6d790cf254 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -3,6 +3,10 @@ #include "array_method.h" +#ifdef __cplusplus +extern "C" { +#endif + extern NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[]; NPY_NO_EXPORT PyObject * @@ -34,7 +38,7 @@ dtype_kind_to_ordering(char kind); /* Used by PyArray_CanCastArrayTo and in the legacy ufunc type resolution */ NPY_NO_EXPORT npy_bool can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data, - PyArray_Descr *to, NPY_CASTING casting); + PyArray_Descr *to, NPY_CASTING casting); NPY_NO_EXPORT int should_use_min_scalar(npy_intp narrs, PyArrayObject **arr, @@ -59,7 +63,7 @@ NPY_NO_EXPORT int PyArray_AddCastingImplementation(PyBoundArrayMethodObject *meth); NPY_NO_EXPORT int -PyArray_AddCastingImplementation_FromSpec(PyArrayMethod_Spec *spec, int private); +PyArray_AddCastingImplementation_FromSpec(PyArrayMethod_Spec *spec, int private_); NPY_NO_EXPORT NPY_CASTING PyArray_MinCastSafety(NPY_CASTING casting1, NPY_CASTING casting2); @@ -99,4 +103,8 @@ simple_cast_resolve_descriptors( NPY_NO_EXPORT int PyArray_InitializeCasts(void); +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_MULTIARRAY_CONVERT_DATATYPE_H_ */ diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h index e7d5505d851e..618491c98371 100644 --- a/numpy/core/src/multiarray/dtypemeta.h +++ b/numpy/core/src/multiarray/dtypemeta.h @@ -1,6 +1,9 @@ #ifndef NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ #define NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ +#ifdef __cplusplus +extern "C" { +#endif /* DType flags, currently private, since we may just expose functions */ #define NPY_DT_LEGACY 1 << 0 @@ -126,4 +129,8 @@ python_builtins_are_known_scalar_types( NPY_NO_EXPORT int dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem); +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ */ diff --git a/numpy/core/src/multiarray/experimental_public_dtype_api.c b/numpy/core/src/multiarray/experimental_public_dtype_api.c index cf5f152abe73..441dbdc1ff64 100644 --- a/numpy/core/src/multiarray/experimental_public_dtype_api.c +++ b/numpy/core/src/multiarray/experimental_public_dtype_api.c @@ -300,37 +300,13 @@ PyArrayInitDTypeMeta_FromSpec( } -/* Function is defined in umath/dispatching.c (same/one compilation unit) */ +/* Functions defined in umath/dispatching.c (same/one compilation unit) */ NPY_NO_EXPORT int PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); -static int -PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec) -{ - if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) { - PyErr_SetString(PyExc_TypeError, - "ufunc object passed is not a ufunc!"); - return -1; - } - PyBoundArrayMethodObject *bmeth = - (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec(spec); - if (bmeth == NULL) { - return -1; - } - int nargs = bmeth->method->nin + bmeth->method->nout; - PyObject *dtypes = PyArray_TupleFromItems( - nargs, (PyObject **)bmeth->dtypes, 1); - if (dtypes == NULL) { - return -1; - } - PyObject *info = PyTuple_Pack(2, dtypes, bmeth->method); - Py_DECREF(bmeth); - Py_DECREF(dtypes); - if (info == NULL) { - return -1; - } - return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0); -} +NPY_NO_EXPORT int +PyUFunc_AddLoopFromSpec(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); + /* * Function is defined in umath/wrapping_array_method.c diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 97ed0ba2a66b..609446fd6749 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -85,6 +85,10 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; NPY_NO_EXPORT int initscalarmath(PyObject *); NPY_NO_EXPORT int set_matmul_flags(PyObject *d); /* in ufunc_object.c */ +/* From umath/string_ufuncs.cpp/h */ +NPY_NO_EXPORT PyObject * +_umath_strings_richcompare( + PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip); /* * global variable to determine if legacy printing is enabled, accessible from @@ -3726,6 +3730,12 @@ format_longfloat(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) TrimMode_LeaveOneZero, -1, -1); } + +/* + * The only purpose of this function is that it allows the "rstrip". + * From my (@seberg's) perspective, this function should be deprecated + * and I do not think it matters if it is not particularly fast. + */ static PyObject * compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) { @@ -3791,7 +3801,7 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) return NULL; } if (PyArray_ISSTRING(newarr) && PyArray_ISSTRING(newoth)) { - res = _strings_richcompare(newarr, newoth, cmp_op, rstrip != 0); + res = _umath_strings_richcompare(newarr, newoth, cmp_op, rstrip != 0); } else { PyErr_SetString(PyExc_TypeError, diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index b8f102b3dff2..620335d88b3b 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -145,6 +145,38 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate) } +/* + * Add loop directly to a ufunc from a given ArrayMethod spec. + */ +NPY_NO_EXPORT int +PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec) +{ + if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) { + PyErr_SetString(PyExc_TypeError, + "ufunc object passed is not a ufunc!"); + return -1; + } + PyBoundArrayMethodObject *bmeth = + (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec(spec); + if (bmeth == NULL) { + return -1; + } + int nargs = bmeth->method->nin + bmeth->method->nout; + PyObject *dtypes = PyArray_TupleFromItems( + nargs, (PyObject **)bmeth->dtypes, 1); + if (dtypes == NULL) { + return -1; + } + PyObject *info = PyTuple_Pack(2, dtypes, bmeth->method); + Py_DECREF(bmeth); + Py_DECREF(dtypes); + if (info == NULL) { + return -1; + } + return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0); +} + + /** * Resolves the implementation to use, this uses typical multiple dispatching * methods of finding the best matching implementation or resolver. diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h index a7e9e88d0d73..f2ab0be2ed35 100644 --- a/numpy/core/src/umath/dispatching.h +++ b/numpy/core/src/umath/dispatching.h @@ -6,6 +6,9 @@ #include #include "array_method.h" +#ifdef __cplusplus +extern "C" { +#endif typedef int promoter_function(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], @@ -14,6 +17,9 @@ typedef int promoter_function(PyUFuncObject *ufunc, NPY_NO_EXPORT int PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); +NPY_NO_EXPORT int +PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec); + NPY_NO_EXPORT PyArrayMethodObject * promote_and_get_ufuncimpl(PyUFuncObject *ufunc, PyArrayObject *const ops[], @@ -41,5 +47,8 @@ object_only_ufunc_promoter(PyUFuncObject *ufunc, NPY_NO_EXPORT int install_logical_ufunc_promoter(PyObject *ufunc); +#ifdef __cplusplus +} +#endif #endif /*_NPY_DISPATCHING_H */ diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp new file mode 100644 index 000000000000..e3960a85ce58 --- /dev/null +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -0,0 +1,451 @@ +#include + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define _MULTIARRAYMODULE +#define _UMATHMODULE + +#include "numpy/ndarraytypes.h" + +#include "numpyos.h" +#include "dispatching.h" +#include "dtypemeta.h" +#include "common_dtype.h" +#include "convert_datatype.h" + + +template +static NPY_INLINE int +character_cmp(character a, character b) +{ + if (a == b) { + return 0; + } + else if (a < b) { + return -1; + } + else { + return 1; + } +} + + +/* + * Compare two strings of different length. Note that either string may be + * zero padded (trailing zeros are ignored in other words, the shorter word + * is always padded with zeros). + */ +template +static NPY_INLINE int +string_cmp(int len1, character *str1, int len2, character *str2) +{ + if (rstrip) { + /* + * Ignore/"trim" trailing whitespace (and 0s). Note that this function + * does not support unicode whitespace (and never has). + */ + while (len1 > 0) { + character c = str1[len1-1]; + if (c != (character)0 && !NumPyOS_ascii_isspace(c)) { + break; + } + len1--; + } + while (len2 > 0) { + character c = str2[len2-1]; + if (c != (character)0 && !NumPyOS_ascii_isspace(c)) { + break; + } + len2--; + } + } + + int n = PyArray_MIN(len1, len2); + + for (int i = 0; i < n; i++) { + int cmp = character_cmp(*str1, *str2); + if (cmp != 0) { + return cmp; + } + str1++; + str2++; + } + if (len1 > len2) { + for (int i = n; i < len1; i++) { + int cmp = character_cmp(*str1, (character)0); + if (cmp != 0) { + return cmp; + } + str1++; + } + } + else if (len2 > len1) { + for (int i = n; i < len2; i++) { + int cmp = character_cmp((character)0, *str2); + if (cmp != 0) { + return cmp; + } + str2++; + } + } + return 0; +} + + +template +static int +string_comparison_loop(PyArrayMethod_Context *context, + char *const data[], npy_intp const dimensions[], + npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) +{ + /* + * Note, this works in CPython even without the GIL, however it may be that + * this will have to be moved into `auxdata` eventually, which may be + * slightly faster/cleaner (but also slightly more involved) in any case. + */ + int len1 = context->descriptors[0]->elsize / sizeof(character); + int len2 = context->descriptors[1]->elsize / sizeof(character); + + char *in1 = data[0]; + char *in2 = data[1]; + char *out = data[2]; + + npy_intp N = dimensions[0]; + + while (N--) { + int cmp = string_cmp( + len1, (character *)in1, len2, (character *)in2); + npy_bool res; + if (comp == Py_EQ) { + res = cmp == 0; + } + else if (comp == Py_NE) { + res = cmp != 0; + } + else if (comp == Py_LT) { + res = cmp < 0; + } + else if (comp == Py_LE) { + res = cmp <= 0; + } + else if (comp == Py_GT) { + res = cmp > 0; + } + else if (comp == Py_GE) { + res = cmp >= 0; + } + else { + assert(0); + } + *(npy_bool *)out = res; + + in1 += strides[0]; + in2 += strides[1]; + out += strides[2]; + } + return 0; +} + + +/* + * Machinery to add the string loops to the existing ufuncs. + */ + +/* + * This function replaces the strided loop with the passed in one, + * and registers it with the given ufunc. + */ +static int +add_loop( + PyObject *umath, const char *ufunc_name, + PyArrayMethod_Spec *spec, PyArrayMethod_StridedLoop *loop) +{ + PyObject *name = PyUnicode_FromString(ufunc_name); + if (name == nullptr) { + return -1; + } + PyObject *ufunc = PyObject_GetItem(umath, name); + Py_DECREF(name); + if (ufunc == nullptr) { + printf("%d\n", PyErr_Occurred() == nullptr); + PyObject_Print(PyErr_Occurred(), stdout, 0); + printf("\n"); + return -1; + } + spec->slots[0].pfunc = (void *)loop; + + int res = PyUFunc_AddLoopFromSpec(ufunc, spec); + Py_DECREF(ufunc); + return res; +} + + +extern "C" { + NPY_NO_EXPORT int + init_string_ufuncs(PyObject *umath); +} + +NPY_NO_EXPORT int +init_string_ufuncs(PyObject *umath) +{ + int res = -1; + /* NOTE: This should recieve global symbols? */ + PyArray_DTypeMeta *String = PyArray_DTypeFromTypeNum(NPY_STRING); + PyArray_DTypeMeta *Unicode = PyArray_DTypeFromTypeNum(NPY_UNICODE); + PyArray_DTypeMeta *Bool = PyArray_DTypeFromTypeNum(NPY_BOOL); + + /* We start with the string loops: */ + PyArray_DTypeMeta *dtypes[] = {String, String, Bool}; + /* + * We only have one loop right now, the strided one, the default type + * resolver ensures native byte order/canonical representation. + */ + PyType_Slot slots[] = { + {NPY_METH_strided_loop, nullptr}, + {0, nullptr} + }; + + PyArrayMethod_Spec spec = { + .name = "templated_string_comparison", + .nin = 2, + .nout = 1, + .dtypes = dtypes, + .slots = slots, + }; + + /* Use this loop variable for typing more explicitly */ + PyArrayMethod_StridedLoop *loop; + + /* TODO: It would be nice to condense the below */ + /* All String loops */ + loop = string_comparison_loop; + if (add_loop(umath, "equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "not_equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "less", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "less_equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "greater", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "greater_equal", &spec, loop) < 0) { + goto finish; + } + + /* All Unicode loops */ + dtypes[0] = Unicode; + dtypes[1] = Unicode; + + loop = string_comparison_loop; + if (add_loop(umath, "equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "not_equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "less", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "less_equal", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "greater", &spec, loop) < 0) { + goto finish; + } + loop = string_comparison_loop; + if (add_loop(umath, "greater_equal", &spec, loop) < 0) { + goto finish; + } + + res = 0; + finish: + Py_DECREF(String); + Py_DECREF(Unicode); + Py_DECREF(Bool); + return res; +} + + +template +static PyArrayMethod_StridedLoop * +get_strided_loop(int comp) +{ + if (comp == Py_EQ) { + return string_comparison_loop; + } + else if (comp == Py_NE) { + return string_comparison_loop; + } + else if (comp == Py_LT) { + return string_comparison_loop; + } + else if (comp == Py_LE) { + return string_comparison_loop; + } + else if (comp == Py_GT) { + return string_comparison_loop; + } + else if (comp == Py_GE) { + return string_comparison_loop; + } + assert(0); + return nullptr; +} + + +/* + * This function is used for `compare_chararrays` (and void comparisons + * currently). The first could probably be deprecated. + * + * The `rstrip` mechanism is presumably for some fortran compat, but the + * question is whether it would not be better to have/use `rstrip` on such + * an array first... + * + * NOTE: This function is also used for unstructured voids, this works because + * `npy_byte` works for it. + */ +extern "C" { + NPY_NO_EXPORT PyObject * + _umath_strings_richcompare( + PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip); +} + +NPY_NO_EXPORT PyObject * +_umath_strings_richcompare( + PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip) +{ + NpyIter *iter = nullptr; + PyObject *result = nullptr; + + char **dataptr = nullptr; + npy_intp *strides = nullptr; + npy_intp *countptr = nullptr; + npy_intp size = 0; + + PyArrayMethod_Context context = { + .caller = nullptr, + .method = nullptr, + .descriptors = nullptr, + }; + NpyIter_IterNextFunc *iternext = nullptr; + + npy_uint32 it_flags = ( + NPY_ITER_EXTERNAL_LOOP | NPY_ITER_ZEROSIZE_OK | + NPY_ITER_BUFFERED | NPY_ITER_GROWINNER); + npy_uint32 op_flags[3] = { + NPY_ITER_READONLY | NPY_ITER_ALIGNED, + NPY_ITER_READONLY | NPY_ITER_ALIGNED, + NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_ALIGNED}; + + PyArrayMethod_StridedLoop *strided_loop = nullptr; + NPY_BEGIN_THREADS_DEF; + + if (PyArray_TYPE(self) != PyArray_TYPE(other)) { + /* + * Comparison between Bytes and Unicode is not defined in Py3K; + * we follow. + * TODO: This makes no sense at all for `compare_chararrays`, kept + * only under the assumption that we are more likely to deprecate + * than fix it to begin with. + */ + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + PyArrayObject *ops[3] = {self, other, nullptr}; + PyArray_Descr *descrs[3] = {nullptr, nullptr, PyArray_DescrFromType(NPY_BOOL)}; + /* ensure_dtype_nbo is in principle not necessary for == and !=: */ + descrs[0] = ensure_dtype_nbo(PyArray_DESCR(self)); + if (descrs[0] == nullptr) { + goto finish; + } + descrs[1] = ensure_dtype_nbo(PyArray_DESCR(other)); + if (descrs[1] == nullptr) { + goto finish; + } + + /* + * Create the iterator: + */ + iter = NpyIter_AdvancedNew( + 3, ops, it_flags, NPY_KEEPORDER, NPY_SAFE_CASTING, op_flags, descrs, + -1, nullptr, nullptr, 0); + if (iter == nullptr) { + goto finish; + } + + size = NpyIter_GetIterSize(iter); + if (size == 0) { + result = (PyObject *)NpyIter_GetOperandArray(iter)[2]; + Py_INCREF(result); + goto finish; + } + + iternext = NpyIter_GetIterNext(iter, nullptr); + if (iternext == nullptr) { + goto finish; + } + + /* + * Prepare the inner-loop and execute it (we only need descriptors to be + * passed in). + */ + context.descriptors = descrs; + + dataptr = NpyIter_GetDataPtrArray(iter); + strides = NpyIter_GetInnerStrideArray(iter); + countptr = NpyIter_GetInnerLoopSizePtr(iter); + + if (rstrip == 0) { + /* NOTE: Also used for VOID, so can be STRING, UNICODE, or VOID: */ + if (descrs[0]->type != NPY_UNICODE) { + strided_loop = get_strided_loop(cmp_op); + } + else { + strided_loop = get_strided_loop(cmp_op); + } + } + else { + if (descrs[0]->type != NPY_UNICODE) { + strided_loop = get_strided_loop(cmp_op); + } + else { + strided_loop = get_strided_loop(cmp_op); + } + } + + NPY_BEGIN_THREADS_THRESHOLDED(size); + + do { + /* We know the loop cannot fail */ + strided_loop(&context, dataptr, countptr, strides, nullptr); + } while (iternext(iter) != 0); + + NPY_END_THREADS; + + result = (PyObject *)NpyIter_GetOperandArray(iter)[2]; + Py_INCREF(result); + + finish: + if (NpyIter_Deallocate(iter) < 0) { + Py_CLEAR(result); + } + Py_XDECREF(descrs[0]); + Py_XDECREF(descrs[1]); + Py_XDECREF(descrs[2]); + return result; +} diff --git a/numpy/core/src/umath/umathmodule.c b/numpy/core/src/umath/umathmodule.c index e9b84df06e2d..754e22c79602 100644 --- a/numpy/core/src/umath/umathmodule.c +++ b/numpy/core/src/umath/umathmodule.c @@ -28,6 +28,11 @@ #include "funcs.inc" #include "__umath_generated.c" + +/* From string_ufuncs.cpp */ +NPY_NO_EXPORT int +init_string_ufuncs(PyObject *umath); + static PyUFuncGenericFunction pyfunc_functions[] = {PyUFunc_On_Om}; static int @@ -342,5 +347,10 @@ int initumath(PyObject *m) if (install_logical_ufunc_promoter(s) < 0) { return -1; } + + if (init_string_ufuncs(d) < 0) { + return -1; + } + return 0; } diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index c46b294ebcb3..6d800de76506 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -166,7 +166,7 @@ def test_string(self): # For two string arrays, strings always raised the broadcasting error: a = np.array(['a', 'b']) b = np.array(['a', 'b', 'c']) - assert_raises(ValueError, lambda x, y: x == y, a, b) + assert_warns(FutureWarning, lambda x, y: x == y, a, b) # The empty list is not cast to string, and this used to pass due # to dtype mismatch; now (2018-06-21) it correctly leads to a diff --git a/numpy/core/tests/test_unicode.py b/numpy/core/tests/test_unicode.py index 8e0dd47cb077..12de25771dbc 100644 --- a/numpy/core/tests/test_unicode.py +++ b/numpy/core/tests/test_unicode.py @@ -1,3 +1,5 @@ +import pytest + import numpy as np from numpy.testing import assert_, assert_equal, assert_array_equal @@ -33,8 +35,11 @@ def test_string_cast(): uni_arr1 = str_arr.astype('>U') uni_arr2 = str_arr.astype(' Date: Sun, 13 Feb 2022 08:39:26 -0600 Subject: [PATCH 02/17] MAINT: Do not use C99 tagged struct init in C++ C++ does not like it (at least not before C++20)... GCC and clang don't seem to mind, but MSVC seems to. --- numpy/core/src/umath/string_ufuncs.cpp | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index e3960a85ce58..e2a97f35bcc9 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -166,9 +166,6 @@ add_loop( PyObject *ufunc = PyObject_GetItem(umath, name); Py_DECREF(name); if (ufunc == nullptr) { - printf("%d\n", PyErr_Occurred() == nullptr); - PyObject_Print(PyErr_Occurred(), stdout, 0); - printf("\n"); return -1; } spec->slots[0].pfunc = (void *)loop; @@ -196,7 +193,7 @@ init_string_ufuncs(PyObject *umath) /* We start with the string loops: */ PyArray_DTypeMeta *dtypes[] = {String, String, Bool}; /* - * We only have one loop right now, the strided one, the default type + * We only have one loop right now, the strided one. The default type * resolver ensures native byte order/canonical representation. */ PyType_Slot slots[] = { @@ -204,13 +201,13 @@ init_string_ufuncs(PyObject *umath) {0, nullptr} }; - PyArrayMethod_Spec spec = { - .name = "templated_string_comparison", - .nin = 2, - .nout = 1, - .dtypes = dtypes, - .slots = slots, - }; + PyArrayMethod_Spec spec = {}; + spec.name = "templated_string_comparison"; + spec.nin = 2; + spec.nout = 1; + spec.dtypes = dtypes; + spec.slots = slots; + spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; /* Use this loop variable for typing more explicitly */ PyArrayMethod_StridedLoop *loop; @@ -336,11 +333,7 @@ _umath_strings_richcompare( npy_intp *countptr = nullptr; npy_intp size = 0; - PyArrayMethod_Context context = { - .caller = nullptr, - .method = nullptr, - .descriptors = nullptr, - }; + PyArrayMethod_Context context = {}; NpyIter_IterNextFunc *iternext = nullptr; npy_uint32 it_flags = ( From 1f3c0fd97fa4b6df223be9ee6f4b0cc2b72a46bb Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 12:17:48 -0600 Subject: [PATCH 03/17] BENCH: Add basic string comparison benchmarks --- benchmarks/benchmarks/bench_strings.py | 45 ++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 benchmarks/benchmarks/bench_strings.py diff --git a/benchmarks/benchmarks/bench_strings.py b/benchmarks/benchmarks/bench_strings.py new file mode 100644 index 000000000000..e500d7f3f20c --- /dev/null +++ b/benchmarks/benchmarks/bench_strings.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import, division, print_function + +from .common import Benchmark + +import numpy as np +import operator + + +_OPERATORS = { + '==': operator.eq, + '!=': operator.ne, + '<': operator.lt, + '<=': operator.le, + '>': operator.gt, + '>=': operator.ge, +} + + +class StringComparisons(Benchmark): + # Basic string comparison speed tests + params = [ + [100, 10000, (1000, 20)], + ['U', 'S'], + [True, False], + ['==', '!=', '<', '<=', '>', '>=']] + param_names = ['shape', 'dtype', 'contig', 'operator'] + int64 = np.dtype(np.int64) + + def setup(self, shape, dtype, contig, operator): + self.arr = np.arange(np.prod(shape)).astype(dtype).reshape(shape) + self.arr_identical = self.arr.copy() + self.arr_different = self.arr[::-1].copy() + + if not contig: + self.arr = self.arr[..., ::2] + self.arr_identical = self.arr_identical[..., ::2] + self.arr_different = self.arr_different[..., ::2] + + self.operator = _OPERATORS[operator] + + def time_compare_identical(self, shape, dtype, contig, operator): + self.operator(self.arr, self.arr_identical) + + def time_compare_different(self, shape, dtype, contig, operator): + self.operator(self.arr, self.arr_different) From 0dbed94cda6bc30b561a0d360cb437f452a3c70b Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 12:27:39 -0600 Subject: [PATCH 04/17] DOC,STY: Fixup string-comparisons comments based on review Thanks to Marten's comments, a few clarfications and slight fixups. --- numpy/core/src/umath/string_ufuncs.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index e2a97f35bcc9..f87218c02c12 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -98,9 +98,9 @@ string_comparison_loop(PyArrayMethod_Context *context, npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) { /* - * Note, this works in CPython even without the GIL, however it may be that - * this will have to be moved into `auxdata` eventually, which may be - * slightly faster/cleaner (but also slightly more involved) in any case. + * Note, fetching `elsize` from the descriptor is OK even without the GIL, + * however it may be that this should be moved into `auxdata` eventually, + * which may also be slightly faster/cleaner (but more involved). */ int len1 = context->descriptors[0]->elsize / sizeof(character); int len2 = context->descriptors[1]->elsize / sizeof(character); @@ -155,9 +155,8 @@ string_comparison_loop(PyArrayMethod_Context *context, * and registers it with the given ufunc. */ static int -add_loop( - PyObject *umath, const char *ufunc_name, - PyArrayMethod_Spec *spec, PyArrayMethod_StridedLoop *loop) +add_loop(PyObject *umath, const char *ufunc_name, + PyArrayMethod_Spec *spec, PyArrayMethod_StridedLoop *loop) { PyObject *name = PyUnicode_FromString(ufunc_name); if (name == nullptr) { @@ -305,15 +304,17 @@ get_strided_loop(int comp) /* - * This function is used for `compare_chararrays` (and void comparisons - * currently). The first could probably be deprecated. + * This function is used for `compare_chararrays` and currently also void + * comparisons (unstructured voids). The first could probably be deprecated + * and removed but is used by `np.char.chararray` the latter should also be + * moved to the ufunc probably (removing the need for manual looping). * * The `rstrip` mechanism is presumably for some fortran compat, but the * question is whether it would not be better to have/use `rstrip` on such * an array first... * * NOTE: This function is also used for unstructured voids, this works because - * `npy_byte` works for it. + * `npy_byte` is correct. */ extern "C" { NPY_NO_EXPORT PyObject * @@ -361,7 +362,7 @@ _umath_strings_richcompare( PyArrayObject *ops[3] = {self, other, nullptr}; PyArray_Descr *descrs[3] = {nullptr, nullptr, PyArray_DescrFromType(NPY_BOOL)}; - /* ensure_dtype_nbo is in principle not necessary for == and !=: */ + /* TODO: ensure_dtype_nbo is in principle not necessary for == and != */ descrs[0] = ensure_dtype_nbo(PyArray_DESCR(self)); if (descrs[0] == nullptr) { goto finish; From 813e09442c498a6164bfc8e44bd778efd4a8fb7e Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 12:49:20 -0600 Subject: [PATCH 05/17] ENH: Use `memcmp` because it may be faster for the byte case --- numpy/core/src/umath/string_ufuncs.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index f87218c02c12..e43e5f40d5a4 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -61,13 +61,27 @@ string_cmp(int len1, character *str1, int len2, character *str2) int n = PyArray_MIN(len1, len2); - for (int i = 0; i < n; i++) { - int cmp = character_cmp(*str1, *str2); + if (sizeof(character) == 1) { + /* + * TODO: `memcmp` makes things 2x faster for longer words that match + * exactly, but at least 2x slower for short or mismatching ones. + */ + int cmp = memcmp(str1, str2, n); if (cmp != 0) { return cmp; } - str1++; - str2++; + str1 += n; + str2 += n; + } + else { + for (int i = 0; i < n; i++) { + int cmp = character_cmp(*str1, *str2); + if (cmp != 0) { + return cmp; + } + str1++; + str2++; + } } if (len1 > len2) { for (int i = n; i < len1; i++) { From ae8db174df67aa60e4867fc4a41476c457e2cbb7 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 14:35:15 -0600 Subject: [PATCH 06/17] TST: Improve string and unicode comparison tests. --- numpy/core/tests/test_strings.py | 68 ++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 numpy/core/tests/test_strings.py diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py new file mode 100644 index 000000000000..35ae823f2192 --- /dev/null +++ b/numpy/core/tests/test_strings.py @@ -0,0 +1,68 @@ +import pytest + +import operator +import numpy as np + +from numpy.testing import assert_array_equal + + +COMPARISONS = [ + (operator.eq, np.equal), + (operator.ne, np.not_equal), + (operator.lt, np.less), + (operator.le, np.less_equal), + (operator.gt, np.greater), + (operator.ge, np.greater_equal), +] + + +@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) +def test_mixed_string_comparison_ufuncs_fail(op, ufunc): + arr_string = np.array(["a", "b"], dtype="S") + arr_unicode = np.array(["a", "c"], dtype="U") + + with pytest.raises(TypeError, match="did not contain a loop"): + ufunc(arr_string, arr_unicode) + + with pytest.raises(TypeError, match="did not contain a loop"): + ufunc(arr_unicode, arr_string) + +@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) +def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc): + arr_string = np.array(["a", "b"], dtype="S") + arr_unicode = np.array(["a", "c"], dtype="U") + + # While there is no loop, manual casting is acceptable: + res1 = ufunc(arr_string, arr_unicode, signature="UU->?", casting="unsafe") + res2 = ufunc(arr_string, arr_unicode, signature="SS->?", casting="unsafe") + + expected = [op("a", "a"), op("a", "c")] + assert_array_equal(res1, expected) + assert_array_equal(res2, expected) + + +@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) +@pytest.mark.parametrize("dtypes", [ + ("S2", "S2"), ("S2", "S10"), + ("U1", "U1"), ("U1", ">U1"), + ("U1", "U10")]) +@pytest.mark.parametrize("aligned", [True, False]) +def test_string_comparisons(op, ufunc, dtypes, aligned): + arr = np.arange(2**15).view(dtypes[0]) + if not aligned: + # Make `arr` unaligned: + new = np.zeros(arr.nbytes + 1, dtype=np.uint8)[1:].view(dtypes[0]) + new[...] = arr + arr = new + + arr2 = arr.astype(dtypes[1], copy=True) + np.random.shuffle(arr2) + arr[0] = arr2[0] # make sure one matches + + expected = [op(d1, d2) for d1, d2 in zip(arr.tolist(), arr2.tolist())] + assert_array_equal(op(arr, arr2), expected) + assert_array_equal(ufunc(arr, arr2), expected) + + expected = [op(d2, d1) for d1, d2 in zip(arr.tolist(), arr2.tolist())] + assert_array_equal(op(arr2, arr), expected) + assert_array_equal(ufunc(arr2, arr), expected) From 5ac55243a690d8e69c37826d3bac07536d19dafd Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 14:42:43 -0600 Subject: [PATCH 07/17] MAINT: Use switch statement based on review As suggested be Serge. Co-authored-by: Serge Guelton --- numpy/core/src/umath/string_ufuncs.cpp | 74 +++++++++++++------------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index e43e5f40d5a4..051db14c5afe 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -129,26 +129,27 @@ string_comparison_loop(PyArrayMethod_Context *context, int cmp = string_cmp( len1, (character *)in1, len2, (character *)in2); npy_bool res; - if (comp == Py_EQ) { - res = cmp == 0; - } - else if (comp == Py_NE) { - res = cmp != 0; - } - else if (comp == Py_LT) { - res = cmp < 0; - } - else if (comp == Py_LE) { - res = cmp <= 0; - } - else if (comp == Py_GT) { - res = cmp > 0; - } - else if (comp == Py_GE) { - res = cmp >= 0; - } - else { - assert(0); + switch (comp) { + case Py_EQ: + res = cmp == 0; + break; + case Py_NE: + res = cmp != 0; + break; + case Py_LT: + res = cmp < 0; + break; + case Py_LE: + res = cmp <= 0; + break; + case Py_GT: + res = cmp > 0; + break; + case Py_GE: + res = cmp >= 0; + break; + default: + assert(false); } *(npy_bool *)out = res; @@ -294,25 +295,22 @@ template static PyArrayMethod_StridedLoop * get_strided_loop(int comp) { - if (comp == Py_EQ) { - return string_comparison_loop; - } - else if (comp == Py_NE) { - return string_comparison_loop; - } - else if (comp == Py_LT) { - return string_comparison_loop; - } - else if (comp == Py_LE) { - return string_comparison_loop; - } - else if (comp == Py_GT) { - return string_comparison_loop; - } - else if (comp == Py_GE) { - return string_comparison_loop; + switch (comp) { + case Py_EQ: + return string_comparison_loop; + case Py_NE: + return string_comparison_loop; + case Py_LT: + return string_comparison_loop; + case Py_LE: + return string_comparison_loop; + case Py_GT: + return string_comparison_loop; + case Py_GE: + return string_comparison_loop; + default: + assert(false); } - assert(0); return nullptr; } From 78e4a609d4c530d381aee8714a0f69b7eeb43f9b Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sun, 13 Feb 2022 15:11:18 -0600 Subject: [PATCH 08/17] TST: Make unicode byte-swap test slightly more concrete The issue is that the `view` needs to use native byte-order, so just ensure native byte-order for the view, and then do another cast to get it right. --- numpy/core/tests/test_strings.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py index 35ae823f2192..283883b15928 100644 --- a/numpy/core/tests/test_strings.py +++ b/numpy/core/tests/test_strings.py @@ -44,11 +44,13 @@ def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc): @pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) @pytest.mark.parametrize("dtypes", [ ("S2", "S2"), ("S2", "S10"), - ("U1", "U1"), ("U1", ">U1"), - ("U1", "U10")]) + ("U1"), (">U1", ">U1"), + ("U10")]) @pytest.mark.parametrize("aligned", [True, False]) def test_string_comparisons(op, ufunc, dtypes, aligned): - arr = np.arange(2**15).view(dtypes[0]) + # ensure native byte-order for the first view to stay within unicode range + native_dt = np.dtype(dtypes[0]).newbyteorder("=") + arr = np.arange(2**15).view(native_dt).astype(dtypes[0]) if not aligned: # Make `arr` unaligned: new = np.zeros(arr.nbytes + 1, dtype=np.uint8)[1:].view(dtypes[0]) From c5ffbc58422b63bdb618bc0ad4f5c6c1b537c303 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 14 Feb 2022 11:19:50 -0600 Subject: [PATCH 09/17] BUG: Add `np.compare_chararrays` to test and fix typo --- numpy/core/src/umath/string_ufuncs.cpp | 4 ++-- numpy/core/tests/test_strings.py | 26 ++++++++++++++------------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index 051db14c5afe..d473a50bd16a 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -418,7 +418,7 @@ _umath_strings_richcompare( if (rstrip == 0) { /* NOTE: Also used for VOID, so can be STRING, UNICODE, or VOID: */ - if (descrs[0]->type != NPY_UNICODE) { + if (descrs[0]->type_num != NPY_UNICODE) { strided_loop = get_strided_loop(cmp_op); } else { @@ -426,7 +426,7 @@ _umath_strings_richcompare( } } else { - if (descrs[0]->type != NPY_UNICODE) { + if (descrs[0]->type_num != NPY_UNICODE) { strided_loop = get_strided_loop(cmp_op); } else { diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py index 283883b15928..02f97812efa2 100644 --- a/numpy/core/tests/test_strings.py +++ b/numpy/core/tests/test_strings.py @@ -7,17 +7,17 @@ COMPARISONS = [ - (operator.eq, np.equal), - (operator.ne, np.not_equal), - (operator.lt, np.less), - (operator.le, np.less_equal), - (operator.gt, np.greater), - (operator.ge, np.greater_equal), + (operator.eq, np.equal, "=="), + (operator.ne, np.not_equal, "!="), + (operator.lt, np.less, "<"), + (operator.le, np.less_equal, "<="), + (operator.gt, np.greater, ">"), + (operator.ge, np.greater_equal, ">="), ] -@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) -def test_mixed_string_comparison_ufuncs_fail(op, ufunc): +@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS) +def test_mixed_string_comparison_ufuncs_fail(op, ufunc, sym): arr_string = np.array(["a", "b"], dtype="S") arr_unicode = np.array(["a", "c"], dtype="U") @@ -27,8 +27,8 @@ def test_mixed_string_comparison_ufuncs_fail(op, ufunc): with pytest.raises(TypeError, match="did not contain a loop"): ufunc(arr_unicode, arr_string) -@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) -def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc): +@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS) +def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc, sym): arr_string = np.array(["a", "b"], dtype="S") arr_unicode = np.array(["a", "c"], dtype="U") @@ -41,13 +41,13 @@ def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc): assert_array_equal(res2, expected) -@pytest.mark.parametrize(["op", "ufunc"], COMPARISONS) +@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS) @pytest.mark.parametrize("dtypes", [ ("S2", "S2"), ("S2", "S10"), ("U1"), (">U1", ">U1"), ("U10")]) @pytest.mark.parametrize("aligned", [True, False]) -def test_string_comparisons(op, ufunc, dtypes, aligned): +def test_string_comparisons(op, ufunc, sym, dtypes, aligned): # ensure native byte-order for the first view to stay within unicode range native_dt = np.dtype(dtypes[0]).newbyteorder("=") arr = np.arange(2**15).view(native_dt).astype(dtypes[0]) @@ -64,7 +64,9 @@ def test_string_comparisons(op, ufunc, dtypes, aligned): expected = [op(d1, d2) for d1, d2 in zip(arr.tolist(), arr2.tolist())] assert_array_equal(op(arr, arr2), expected) assert_array_equal(ufunc(arr, arr2), expected) + assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected) expected = [op(d2, d1) for d1, d2 in zip(arr.tolist(), arr2.tolist())] assert_array_equal(op(arr2, arr), expected) assert_array_equal(ufunc(arr2, arr), expected) + assert_array_equal(np.compare_chararrays(arr2, arr, sym, False), expected) From e8c473796644d98bf3c542f7aa20eedf0cba6a57 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 14 Feb 2022 13:16:57 -0600 Subject: [PATCH 10/17] TST: Add test for empty string comparisons --- numpy/core/tests/test_strings.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py index 02f97812efa2..d64b06290715 100644 --- a/numpy/core/tests/test_strings.py +++ b/numpy/core/tests/test_strings.py @@ -70,3 +70,16 @@ def test_string_comparisons(op, ufunc, sym, dtypes, aligned): assert_array_equal(op(arr2, arr), expected) assert_array_equal(ufunc(arr2, arr), expected) assert_array_equal(np.compare_chararrays(arr2, arr, sym, False), expected) + + +@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS) +@pytest.mark.parametrize("dtypes", [ + ("S2", "S2"), ("S2", "S10"), ("U10")]) +def test_string_comparisons_empty(op, ufunc, sym, dtypes): + arr = np.empty((1, 0, 1, 5), dtype=dtypes[0]) + arr2 = np.empty((100, 1, 0, 1), dtype=dtypes[1]) + + expected = np.empty(np.broadcast_shapes(arr.shape, arr2.shape), dtype=bool) + assert_array_equal(op(arr, arr2), expected) + assert_array_equal(ufunc(arr, arr2), expected) + assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected) From 525955cf0d33513b42279fc105799b15e612b6d7 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 14 Feb 2022 14:40:57 -0600 Subject: [PATCH 11/17] TST: Fixup string test based on martens review --- numpy/core/tests/test_strings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py index d64b06290715..2b87ed654a0b 100644 --- a/numpy/core/tests/test_strings.py +++ b/numpy/core/tests/test_strings.py @@ -36,7 +36,7 @@ def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc, sym): res1 = ufunc(arr_string, arr_unicode, signature="UU->?", casting="unsafe") res2 = ufunc(arr_string, arr_unicode, signature="SS->?", casting="unsafe") - expected = [op("a", "a"), op("a", "c")] + expected = op(arr_string.astype('U'), arr_unicode) assert_array_equal(res1, expected) assert_array_equal(res2, expected) From d64fd7698bce3c83628b607b88eb68fe609fb8d5 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 14 Feb 2022 14:55:12 -0600 Subject: [PATCH 12/17] MAINT: Move definitions back into string_ufuncs.h --- numpy/core/src/umath/string_ufuncs.cpp | 13 ++----------- numpy/core/src/umath/string_ufuncs.h | 19 +++++++++++++++++++ numpy/core/src/umath/umathmodule.c | 5 +---- 3 files changed, 22 insertions(+), 15 deletions(-) create mode 100644 numpy/core/src/umath/string_ufuncs.h diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index d473a50bd16a..bb1b6d9c3fc9 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -12,6 +12,8 @@ #include "common_dtype.h" #include "convert_datatype.h" +#include "string_ufuncs.h" + template static NPY_INLINE int @@ -190,11 +192,6 @@ add_loop(PyObject *umath, const char *ufunc_name, } -extern "C" { - NPY_NO_EXPORT int - init_string_ufuncs(PyObject *umath); -} - NPY_NO_EXPORT int init_string_ufuncs(PyObject *umath) { @@ -328,12 +325,6 @@ get_strided_loop(int comp) * NOTE: This function is also used for unstructured voids, this works because * `npy_byte` is correct. */ -extern "C" { - NPY_NO_EXPORT PyObject * - _umath_strings_richcompare( - PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip); -} - NPY_NO_EXPORT PyObject * _umath_strings_richcompare( PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip) diff --git a/numpy/core/src/umath/string_ufuncs.h b/numpy/core/src/umath/string_ufuncs.h new file mode 100644 index 000000000000..aa17199541e7 --- /dev/null +++ b/numpy/core/src/umath/string_ufuncs.h @@ -0,0 +1,19 @@ +#ifndef _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_ +#define _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT int +init_string_ufuncs(PyObject *umath); + +NPY_NO_EXPORT PyObject * +_umath_strings_richcompare( + PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip); + +#ifdef __cplusplus +} +#endif + +#endif /* _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_ */ \ No newline at end of file diff --git a/numpy/core/src/umath/umathmodule.c b/numpy/core/src/umath/umathmodule.c index 754e22c79602..5002a907314b 100644 --- a/numpy/core/src/umath/umathmodule.c +++ b/numpy/core/src/umath/umathmodule.c @@ -23,16 +23,13 @@ #include "numpy/npy_math.h" #include "number.h" #include "dispatching.h" +#include "string_ufuncs.h" /* Automatically generated code to define all ufuncs: */ #include "funcs.inc" #include "__umath_generated.c" -/* From string_ufuncs.cpp */ -NPY_NO_EXPORT int -init_string_ufuncs(PyObject *umath); - static PyUFuncGenericFunction pyfunc_functions[] = {PyUFunc_On_Om}; static int From 2cc3474ee135c7609c5b6b36021347915c166fc5 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 14 Feb 2022 18:12:13 -0600 Subject: [PATCH 13/17] MAINT: Use enum class for comparison operator templating This removes the need for a dynamic (or static) assert in the switch statement. --- numpy/core/src/umath/string_ufuncs.cpp | 62 ++++++++++++++------------ 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index bb1b6d9c3fc9..fde81360a896 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -107,7 +107,15 @@ string_cmp(int len1, character *str1, int len2, character *str2) } -template +/* + * Helper for templating, avoids warnings about uncovered switch paths. + */ +enum class COMP { + EQ, NE, LT, LE, GT, GE, +}; + + +template static int string_comparison_loop(PyArrayMethod_Context *context, char *const data[], npy_intp const dimensions[], @@ -132,26 +140,24 @@ string_comparison_loop(PyArrayMethod_Context *context, len1, (character *)in1, len2, (character *)in2); npy_bool res; switch (comp) { - case Py_EQ: + case COMP::EQ: res = cmp == 0; break; - case Py_NE: + case COMP::NE: res = cmp != 0; break; - case Py_LT: + case COMP::LT: res = cmp < 0; break; - case Py_LE: + case COMP::LE: res = cmp <= 0; break; - case Py_GT: + case COMP::GT: res = cmp > 0; break; - case Py_GE: + case COMP::GE: res = cmp >= 0; break; - default: - assert(false); } *(npy_bool *)out = res; @@ -225,27 +231,27 @@ init_string_ufuncs(PyObject *umath) /* TODO: It would be nice to condense the below */ /* All String loops */ - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "not_equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "less", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "less_equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "greater", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "greater_equal", &spec, loop) < 0) { goto finish; } @@ -254,27 +260,27 @@ init_string_ufuncs(PyObject *umath) dtypes[0] = Unicode; dtypes[1] = Unicode; - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "not_equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "less", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "less_equal", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "greater", &spec, loop) < 0) { goto finish; } - loop = string_comparison_loop; + loop = string_comparison_loop; if (add_loop(umath, "greater_equal", &spec, loop) < 0) { goto finish; } @@ -294,19 +300,19 @@ get_strided_loop(int comp) { switch (comp) { case Py_EQ: - return string_comparison_loop; + return string_comparison_loop; case Py_NE: - return string_comparison_loop; + return string_comparison_loop; case Py_LT: - return string_comparison_loop; + return string_comparison_loop; case Py_LE: - return string_comparison_loop; + return string_comparison_loop; case Py_GT: - return string_comparison_loop; + return string_comparison_loop; case Py_GE: - return string_comparison_loop; + return string_comparison_loop; default: - assert(false); + assert(false); /* caller ensures this */ } return nullptr; } From 77c49102e8295839701681e8f9cfd8bcce1f6183 Mon Sep 17 00:00:00 2001 From: serge-sans-paille Date: Tue, 15 Feb 2022 08:36:52 +0100 Subject: [PATCH 14/17] Template version of add_loop to avoid redundant code --- numpy/core/src/umath/string_ufuncs.cpp | 87 +++++++++++--------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index fde81360a896..87491052dad7 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -114,6 +114,17 @@ enum class COMP { EQ, NE, LT, LE, GT, GE, }; +static char const* comp_name(COMP comp) { + switch(comp) { + case COMP::EQ: return "equal"; + case COMP::NE: return "not_equal"; + case COMP::LT: return "less"; + case COMP::LE: return "less_equal"; + case COMP::GT: return "greater"; + case COMP::GE: return "greater_equal"; + } +} + template static int @@ -197,6 +208,29 @@ add_loop(PyObject *umath, const char *ufunc_name, return res; } +template +struct add_loops; + +template +struct add_loops { + bool operator()(PyObject*, PyArrayMethod_Spec*) { + return false; + } +}; + +template +struct add_loops { + bool operator()(PyObject* umath, PyArrayMethod_Spec* spec) { + PyArrayMethod_StridedLoop* loop = string_comparison_loop; + if(add_loop(umath, comp_name(comp), spec, loop) < 0) { + return true; + } + else { + return add_loops()(umath, spec); + } + } +}; + NPY_NO_EXPORT int init_string_ufuncs(PyObject *umath) @@ -226,62 +260,17 @@ init_string_ufuncs(PyObject *umath) spec.slots = slots; spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; - /* Use this loop variable for typing more explicitly */ - PyArrayMethod_StridedLoop *loop; - - /* TODO: It would be nice to condense the below */ /* All String loops */ - loop = string_comparison_loop; - if (add_loop(umath, "equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "not_equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "less", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "less_equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "greater", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "greater_equal", &spec, loop) < 0) { + using string_looper = add_loops; + if(string_looper()(umath, &spec)) { goto finish; } /* All Unicode loops */ + using ucs_looper = add_loops; dtypes[0] = Unicode; dtypes[1] = Unicode; - - loop = string_comparison_loop; - if (add_loop(umath, "equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "not_equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "less", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "less_equal", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "greater", &spec, loop) < 0) { - goto finish; - } - loop = string_comparison_loop; - if (add_loop(umath, "greater_equal", &spec, loop) < 0) { + if(ucs_looper()(umath, &spec)) { goto finish; } From 28f8a188c57369f7228a6387981ef7d9d53c3324 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 15 Feb 2022 12:40:47 -0600 Subject: [PATCH 15/17] STY: Fixup style, two spaces, error is -1 --- numpy/core/src/umath/string_ufuncs.cpp | 49 ++++++++++++++------------ 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index 87491052dad7..f3f0297b2692 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -114,15 +114,18 @@ enum class COMP { EQ, NE, LT, LE, GT, GE, }; -static char const* comp_name(COMP comp) { - switch(comp) { - case COMP::EQ: return "equal"; - case COMP::NE: return "not_equal"; - case COMP::LT: return "less"; - case COMP::LE: return "less_equal"; - case COMP::GT: return "greater"; - case COMP::GE: return "greater_equal"; - } +static char const * +comp_name(COMP comp) { + switch(comp) { + case COMP::EQ: return "equal"; + case COMP::NE: return "not_equal"; + case COMP::LT: return "less"; + case COMP::LE: return "less_equal"; + case COMP::GT: return "greater"; + case COMP::GE: return "greater_equal"; + } + assert(0); + return nullptr; } @@ -208,27 +211,29 @@ add_loop(PyObject *umath, const char *ufunc_name, return res; } + template struct add_loops; template struct add_loops { - bool operator()(PyObject*, PyArrayMethod_Spec*) { - return false; - } + int operator()(PyObject*, PyArrayMethod_Spec*) { + return 0; + } }; template struct add_loops { - bool operator()(PyObject* umath, PyArrayMethod_Spec* spec) { - PyArrayMethod_StridedLoop* loop = string_comparison_loop; - if(add_loop(umath, comp_name(comp), spec, loop) < 0) { - return true; - } - else { - return add_loops()(umath, spec); + int operator()(PyObject* umath, PyArrayMethod_Spec* spec) { + PyArrayMethod_StridedLoop* loop = string_comparison_loop; + + if (add_loop(umath, comp_name(comp), spec, loop) < 0) { + return -1; + } + else { + return add_loops()(umath, spec); + } } - } }; @@ -262,7 +267,7 @@ init_string_ufuncs(PyObject *umath) /* All String loops */ using string_looper = add_loops; - if(string_looper()(umath, &spec)) { + if (string_looper()(umath, &spec) < 0) { goto finish; } @@ -270,7 +275,7 @@ init_string_ufuncs(PyObject *umath) using ucs_looper = add_loops; dtypes[0] = Unicode; dtypes[1] = Unicode; - if(ucs_looper()(umath, &spec)) { + if (ucs_looper()(umath, &spec) < 0) { goto finish; } From 8458c6021aff4a262547710ec130f4e9e18ae645 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 3 May 2022 10:09:40 +0200 Subject: [PATCH 16/17] STY: Small `string_ufuncs.cpp` fixups based on Serge's review --- numpy/core/src/umath/string_ufuncs.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index f3f0297b2692..c434eae48744 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -36,9 +36,9 @@ character_cmp(character a, character b) * zero padded (trailing zeros are ignored in other words, the shorter word * is always padded with zeros). */ -template +template static NPY_INLINE int -string_cmp(int len1, character *str1, int len2, character *str2) +string_cmp(int len1, const character *str1, int len2, const character *str2) { if (rstrip) { /* @@ -123,9 +123,10 @@ comp_name(COMP comp) { case COMP::LE: return "less_equal"; case COMP::GT: return "greater"; case COMP::GE: return "greater_equal"; + default: + assert(0); + return nullptr; } - assert(0); - return nullptr; } From da5503e5100eac1d8a7649d168873f54744f6a14 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 3 May 2022 11:12:14 +0200 Subject: [PATCH 17/17] MAINT: Fix merge conflict (ensure_dtype_nbo was removed) --- numpy/core/src/umath/string_ufuncs.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp index c434eae48744..1b45ad71fa1e 100644 --- a/numpy/core/src/umath/string_ufuncs.cpp +++ b/numpy/core/src/umath/string_ufuncs.cpp @@ -366,12 +366,12 @@ _umath_strings_richcompare( PyArrayObject *ops[3] = {self, other, nullptr}; PyArray_Descr *descrs[3] = {nullptr, nullptr, PyArray_DescrFromType(NPY_BOOL)}; - /* TODO: ensure_dtype_nbo is in principle not necessary for == and != */ - descrs[0] = ensure_dtype_nbo(PyArray_DESCR(self)); + /* TODO: ensuring native byte order is not really necessary for == and != */ + descrs[0] = NPY_DT_CALL_ensure_canonical(PyArray_DESCR(self)); if (descrs[0] == nullptr) { goto finish; } - descrs[1] = ensure_dtype_nbo(PyArray_DESCR(other)); + descrs[1] = NPY_DT_CALL_ensure_canonical(PyArray_DESCR(other)); if (descrs[1] == nullptr) { goto finish; }