diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 4b0ceb6..c888826 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from typing import Union, Optional, Literal - from ._typing import Device + from ._typing import Device, Dtype as DType from collections.abc import Sequence from ._dtypes import ( @@ -251,7 +251,14 @@ def ihfft( return res @requires_extension('fft') -def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: +def fftfreq( + n: int, + /, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -259,10 +266,23 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar """ if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.fftfreq(n, d=d), device=device) + if dtype and not dtype in _real_floating_dtypes: + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") + + np_result = np.fft.fftfreq(n, d=d) + if dtype: + np_result = np_result.astype(dtype._np_dtype) + return Array._new(np_result, device=device) @requires_extension('fft') -def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: +def rfftfreq( + n: int, + /, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -270,7 +290,13 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A """ if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.rfftfreq(n, d=d), device=device) + if dtype and not dtype in _real_floating_dtypes: + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") + + np_result = np.fft.rfftfreq(n, d=d) + if dtype: + np_result = np_result.astype(dtype._np_dtype) + return Array._new(np_result, device=device) @requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: