Skip to content

Commit 62cbfc3

Browse files
committed
MAINT: Use enum class for comparison operator templating
This removes the need for a dynamic (or static) assert in the switch statement.
1 parent ff35854 commit 62cbfc3

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

numpy/core/src/umath/string_ufuncs.cpp

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,20 @@ string_cmp(int len1, character *str1, int len2, character *str2)
107107
}
108108

109109

110-
template <bool rstrip, int comp, typename character>
110+
/*
111+
* Helper for templating, avoids warnings about uncovered switch paths.
112+
*/
113+
enum class COMP {
114+
EQ = Py_EQ,
115+
NE = Py_NE,
116+
LT = Py_LT,
117+
LE = Py_LE,
118+
GT = Py_GT,
119+
GE = Py_GE,
120+
};
121+
122+
123+
template <bool rstrip, COMP comp, typename character>
111124
static int
112125
string_comparison_loop(PyArrayMethod_Context *context,
113126
char *const data[], npy_intp const dimensions[],
@@ -132,26 +145,24 @@ string_comparison_loop(PyArrayMethod_Context *context,
132145
len1, (character *)in1, len2, (character *)in2);
133146
npy_bool res;
134147
switch (comp) {
135-
case Py_EQ:
148+
case COMP::EQ:
136149
res = cmp == 0;
137150
break;
138-
case Py_NE:
151+
case COMP::NE:
139152
res = cmp != 0;
140153
break;
141-
case Py_LT:
154+
case COMP::LT:
142155
res = cmp < 0;
143156
break;
144-
case Py_LE:
157+
case COMP::LE:
145158
res = cmp <= 0;
146159
break;
147-
case Py_GT:
160+
case COMP::GT:
148161
res = cmp > 0;
149162
break;
150-
case Py_GE:
163+
case COMP::GE:
151164
res = cmp >= 0;
152165
break;
153-
default:
154-
assert(false);
155166
}
156167
*(npy_bool *)out = res;
157168

@@ -225,27 +236,27 @@ init_string_ufuncs(PyObject *umath)
225236

226237
/* TODO: It would be nice to condense the below */
227238
/* All String loops */
228-
loop = string_comparison_loop<false, Py_EQ, npy_byte>;
239+
loop = string_comparison_loop<false, COMP::EQ, npy_byte>;
229240
if (add_loop(umath, "equal", &spec, loop) < 0) {
230241
goto finish;
231242
}
232-
loop = string_comparison_loop<false, Py_NE, npy_byte>;
243+
loop = string_comparison_loop<false, COMP::NE, npy_byte>;
233244
if (add_loop(umath, "not_equal", &spec, loop) < 0) {
234245
goto finish;
235246
}
236-
loop = string_comparison_loop<false, Py_LT, npy_byte>;
247+
loop = string_comparison_loop<false, COMP::LT, npy_byte>;
237248
if (add_loop(umath, "less", &spec, loop) < 0) {
238249
goto finish;
239250
}
240-
loop = string_comparison_loop<false, Py_LE, npy_byte>;
251+
loop = string_comparison_loop<false, COMP::LE, npy_byte>;
241252
if (add_loop(umath, "less_equal", &spec, loop) < 0) {
242253
goto finish;
243254
}
244-
loop = string_comparison_loop<false, Py_GT, npy_byte>;
255+
loop = string_comparison_loop<false, COMP::GT, npy_byte>;
245256
if (add_loop(umath, "greater", &spec, loop) < 0) {
246257
goto finish;
247258
}
248-
loop = string_comparison_loop<false, Py_GE, npy_byte>;
259+
loop = string_comparison_loop<false, COMP::GE, npy_byte>;
249260
if (add_loop(umath, "greater_equal", &spec, loop) < 0) {
250261
goto finish;
251262
}
@@ -254,27 +265,27 @@ init_string_ufuncs(PyObject *umath)
254265
dtypes[0] = Unicode;
255266
dtypes[1] = Unicode;
256267

257-
loop = string_comparison_loop<false, Py_EQ, npy_ucs4>;
268+
loop = string_comparison_loop<false, COMP::EQ, npy_ucs4>;
258269
if (add_loop(umath, "equal", &spec, loop) < 0) {
259270
goto finish;
260271
}
261-
loop = string_comparison_loop<false, Py_NE, npy_ucs4>;
272+
loop = string_comparison_loop<false, COMP::NE, npy_ucs4>;
262273
if (add_loop(umath, "not_equal", &spec, loop) < 0) {
263274
goto finish;
264275
}
265-
loop = string_comparison_loop<false, Py_LT, npy_ucs4>;
276+
loop = string_comparison_loop<false, COMP::LT, npy_ucs4>;
266277
if (add_loop(umath, "less", &spec, loop) < 0) {
267278
goto finish;
268279
}
269-
loop = string_comparison_loop<false, Py_LE, npy_ucs4>;
280+
loop = string_comparison_loop<false, COMP::LE, npy_ucs4>;
270281
if (add_loop(umath, "less_equal", &spec, loop) < 0) {
271282
goto finish;
272283
}
273-
loop = string_comparison_loop<false, Py_GT, npy_ucs4>;
284+
loop = string_comparison_loop<false, COMP::GT, npy_ucs4>;
274285
if (add_loop(umath, "greater", &spec, loop) < 0) {
275286
goto finish;
276287
}
277-
loop = string_comparison_loop<false, Py_GE, npy_ucs4>;
288+
loop = string_comparison_loop<false, COMP::GE, npy_ucs4>;
278289
if (add_loop(umath, "greater_equal", &spec, loop) < 0) {
279290
goto finish;
280291
}
@@ -294,19 +305,19 @@ get_strided_loop(int comp)
294305
{
295306
switch (comp) {
296307
case Py_EQ:
297-
return string_comparison_loop<rstrip, Py_EQ, character>;
308+
return string_comparison_loop<rstrip, COMP::EQ, character>;
298309
case Py_NE:
299-
return string_comparison_loop<rstrip, Py_NE, character>;
310+
return string_comparison_loop<rstrip, COMP::NE, character>;
300311
case Py_LT:
301-
return string_comparison_loop<rstrip, Py_LT, character>;
312+
return string_comparison_loop<rstrip, COMP::LT, character>;
302313
case Py_LE:
303-
return string_comparison_loop<rstrip, Py_LE, character>;
314+
return string_comparison_loop<rstrip, COMP::LE, character>;
304315
case Py_GT:
305-
return string_comparison_loop<rstrip, Py_GT, character>;
316+
return string_comparison_loop<rstrip, COMP::GT, character>;
306317
case Py_GE:
307-
return string_comparison_loop<rstrip, Py_GE, character>;
318+
return string_comparison_loop<rstrip, COMP::GE, character>;
308319
default:
309-
assert(false);
320+
assert(false); /* caller ensures this */
310321
}
311322
return nullptr;
312323
}

0 commit comments

Comments
 (0)