Skip to content

Commit 5ae53e9

Browse files
authored
Merge pull request #19715 from yashasvimisra2798/casting_patch1
BUG: Casting bool_ to float16
2 parents a90677a + 580d83f commit 5ae53e9

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

numpy/core/src/multiarray/lowlevel_strided_loops.c.src

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,10 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
819819
# define _CONVERT_FN(x) npy_floatbits_to_halfbits(x)
820820
# elif @is_double1@
821821
# define _CONVERT_FN(x) npy_doublebits_to_halfbits(x)
822+
# elif @is_half1@
823+
# define _CONVERT_FN(x) (x)
824+
# elif @is_bool1@
825+
# define _CONVERT_FN(x) npy_float_to_half((float)(x!=0))
822826
# else
823827
# define _CONVERT_FN(x) npy_float_to_half((float)x)
824828
# endif

numpy/core/tests/test_casting_unittests.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,13 @@ def test_object_casts_NULL_None_equivalence(self, dtype):
695695
expected = arr_normal.astype(dtype)
696696
except TypeError:
697697
with pytest.raises(TypeError):
698-
arr_NULLs.astype(dtype)
698+
arr_NULLs.astype(dtype),
699699
else:
700700
assert_array_equal(expected, arr_NULLs.astype(dtype))
701+
702+
def test_float_to_bool(self):
703+
# test case corresponding to gh-19514
704+
# simple test for casting bool_ to float16
705+
res = np.array([0, 3, -7], dtype=np.int8).view(bool)
706+
expected = [0, 1, 1]
707+
assert_array_equal(res, expected)

0 commit comments

Comments
 (0)