Skip to content

Commit 1cf953c

Browse files
committed
API: Introduce numpy.astype
1 parent 29cbb1f commit 1cf953c

File tree

7 files changed

+92
-12
lines changed

7 files changed

+92
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
`numpy.astype` was added to provide Array API compatible alternative to
2+
`numpy.ndarray.astype` method.

doc/source/reference/routines.array-creation.rst

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ From existing data
3333
asanyarray
3434
ascontiguousarray
3535
asmatrix
36+
astype
3637
copy
3738
frombuffer
3839
from_dlpack

numpy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@
126126
arctan2, arctanh, argmax, argmin, argpartition, argsort, argwhere,
127127
around, array, array2string, array_equal, array_equiv, array_repr,
128128
array_str, asanyarray, asarray, ascontiguousarray, asfortranarray,
129-
atleast_1d, atleast_2d, atleast_3d, base_repr, binary_repr,
129+
astype, atleast_1d, atleast_2d, atleast_3d, base_repr, binary_repr,
130130
bitwise_and, bitwise_count, bitwise_not, bitwise_or, bitwise_xor,
131131
block, bool_, broadcast, busday_count, busday_offset, busdaycalendar,
132132
byte, bytes_, can_cast, cbrt, cdouble, ceil, character, choose, clip,

numpy/__init__.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ from numpy._core.numeric import (
376376
isclose as isclose,
377377
array_equal as array_equal,
378378
array_equiv as array_equiv,
379+
astype as astype,
379380
)
380381

381382
from numpy._core.numerictypes import (

numpy/_core/numeric.py

+64-11
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,15 @@
99
import numpy as np
1010
from . import multiarray
1111
from .multiarray import (
12-
ALLOW_THREADS,
13-
BUFSIZE, CLIP, MAXDIMS, MAY_SHARE_BOUNDS, MAY_SHARE_EXACT, RAISE,
14-
WRAP, arange, array, asarray, asanyarray, ascontiguousarray,
15-
asfortranarray, broadcast, can_cast,
16-
concatenate, copyto, dot, dtype, empty,
17-
empty_like, flatiter, frombuffer, from_dlpack, fromfile, fromiter,
18-
fromstring, inner, lexsort, matmul, may_share_memory,
19-
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
20-
putmask, result_type, shares_memory, vdot, where,
21-
zeros, normalize_axis_index, _get_promotion_state, _set_promotion_state)
12+
ALLOW_THREADS, BUFSIZE, CLIP, MAXDIMS, MAY_SHARE_BOUNDS, MAY_SHARE_EXACT,
13+
RAISE, WRAP, arange, array, asarray, asanyarray, ascontiguousarray,
14+
asfortranarray, broadcast, can_cast, concatenate, copyto, dot, dtype,
15+
empty, empty_like, flatiter, frombuffer, from_dlpack, fromfile, fromiter,
16+
fromstring, inner, lexsort, matmul, may_share_memory, min_scalar_type,
17+
ndarray, nditer, nested_iters, promote_types, putmask, result_type,
18+
shares_memory, vdot, where, zeros, normalize_axis_index,
19+
_get_promotion_state, _set_promotion_state
20+
)
2221

2322
from . import overrides
2423
from . import umath
@@ -43,7 +42,7 @@
4342
'arange', 'array', 'asarray', 'asanyarray', 'ascontiguousarray',
4443
'asfortranarray', 'zeros', 'count_nonzero', 'empty', 'broadcast', 'dtype',
4544
'fromstring', 'fromfile', 'frombuffer', 'from_dlpack', 'where',
46-
'argwhere', 'copyto', 'concatenate', 'lexsort',
45+
'argwhere', 'copyto', 'concatenate', 'lexsort', 'astype',
4746
'can_cast', 'promote_types', 'min_scalar_type',
4847
'result_type', 'isfortran', 'empty_like', 'zeros_like', 'ones_like',
4948
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot', 'roll',
@@ -2508,6 +2507,60 @@ def array_equiv(a1, a2):
25082507
return bool(asarray(a1 == a2).all())
25092508

25102509

2510+
def _astype_dispatcher(x, dtype, /, *, copy=None):
2511+
return (x, dtype)
2512+
2513+
2514+
@array_function_dispatch(_astype_dispatcher)
2515+
def astype(x, dtype, /, *, copy = True):
2516+
"""
2517+
Copies an array to a specified data type.
2518+
2519+
This function is an Array API compatible alternative to
2520+
`numpy.ndarray.astype`.
2521+
2522+
Parameters
2523+
----------
2524+
x : array_like
2525+
Input array to cast.
2526+
dtype : dtype
2527+
Data type of the result.
2528+
copy : bool, optional
2529+
Specifies whether to copy an array when the specified dtype matches
2530+
the data type of the input array ``x``. If ``True``, a newly allocated
2531+
array must always be returned. If ``False`` and the specified dtype
2532+
matches the data type of the input array, the input array must be
2533+
returned; otherwise, a newly allocated array must be returned.
2534+
Defaults to ``True``.
2535+
2536+
Returns
2537+
-------
2538+
out : ndarray
2539+
An array having the specified data type.
2540+
2541+
See Also
2542+
--------
2543+
ndarray.astype
2544+
2545+
Examples
2546+
--------
2547+
>>> arr = np.array([1, 2, 3]); arr
2548+
array([1, 2, 3])
2549+
>>> np.astype(arr, np.float64)
2550+
array([1., 2., 3.])
2551+
2552+
Non-copy case:
2553+
2554+
>>> arr = np.array([1, 2, 3])
2555+
>>> arr_noncpy = np.astype(arr, arr.dtype, copy=False)
2556+
>>> np.shares_memory(arr, arr_noncpy)
2557+
True
2558+
2559+
"""
2560+
x = asarray(x)
2561+
return x.astype(dtype, copy=copy)
2562+
2563+
25112564
inf = PINF
25122565
nan = NAN
25132566
False_ = bool_(False)

numpy/_core/numeric.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -658,3 +658,9 @@ def isclose(
658658
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = ...) -> bool: ...
659659

660660
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> bool: ...
661+
662+
def astype(
663+
x: ArrayLike,
664+
dtype: DTypeLike,
665+
copy: bool = ...,
666+
) -> NDArray[Any]: ...

numpy/_core/tests/test_numeric.py

+17
Original file line numberDiff line numberDiff line change
@@ -3981,3 +3981,20 @@ def test_zero_dimensional(self):
39813981
arr_0d = np.array(1)
39823982
ret = np.tensordot(arr_0d, arr_0d, ([], [])) # contracting no axes is well defined
39833983
assert_array_equal(ret, arr_0d)
3984+
3985+
3986+
class TestAsType:
3987+
3988+
def test_astype(self):
3989+
data = [[1, 2], [3, 4]]
3990+
arr = np.astype(
3991+
np.array(data, dtype=np.int64), np.uint32
3992+
)
3993+
expected = np.array(data, dtype=np.uint32)
3994+
3995+
assert_array_equal(arr, expected)
3996+
assert_equal(arr.dtype, expected.dtype)
3997+
3998+
assert np.shares_memory(
3999+
arr, np.astype(arr, arr.dtype, copy=False)
4000+
)

0 commit comments

Comments
 (0)