diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py
index 056be28..a518a0b 100644
--- a/db_dtypes/__init__.py
+++ b/db_dtypes/__init__.py
@@ -22,13 +22,7 @@
 import numpy
 import packaging.version
 import pandas
-import pandas.compat.numpy.function
-import pandas.core.algorithms
-import pandas.core.arrays
-import pandas.core.dtypes.base
-import pandas.core.dtypes.dtypes
-import pandas.core.dtypes.generic
-import pandas.core.nanops
+import pandas.api.extensions
 import pyarrow
 import pyarrow.compute
 
@@ -44,7 +38,7 @@
 pandas_release = packaging.version.parse(pandas.__version__).release
 
 
-@pandas.core.dtypes.dtypes.register_extension_dtype
+@pandas.api.extensions.register_extension_dtype
 class TimeDtype(core.BaseDatetimeDtype):
     """
     Extension dtype for time data.
@@ -113,7 +107,7 @@ def _datetime(
                 .as_py()
             )
 
-        if scalar is None:
+        if pandas.isna(scalar):
             return None
         if isinstance(scalar, datetime.time):
             return pandas.Timestamp(
@@ -194,7 +188,7 @@ def __arrow_array__(self, type=None):
         )
 
 
-@pandas.core.dtypes.dtypes.register_extension_dtype
+@pandas.api.extensions.register_extension_dtype
 class DateDtype(core.BaseDatetimeDtype):
     """
     Extension dtype for time data.
@@ -238,7 +232,7 @@ def _datetime(
         if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)):
             scalar = scalar.as_py()
 
-        if scalar is None:
+        if pandas.isna(scalar):
             return None
         elif isinstance(scalar, datetime.date):
             return pandas.Timestamp(
diff --git a/db_dtypes/core.py b/db_dtypes/core.py
index 3ade198..05daf37 100644
--- a/db_dtypes/core.py
+++ b/db_dtypes/core.py
@@ -12,20 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional, Sequence
+from typing import Optional
 
 import numpy
 import pandas
-from pandas._libs import NaT
+from pandas import NaT
 import pandas.api.extensions
-import pandas.compat.numpy.function
-import pandas.core.algorithms
-import pandas.core.arrays
-import pandas.core.dtypes.base
-from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype
-import pandas.core.dtypes.dtypes
-import pandas.core.dtypes.generic
-import pandas.core.nanops
+from pandas.api.types import is_dtype_equal, is_list_like, pandas_dtype
 
 from db_dtypes import pandas_backports
 
@@ -107,42 +100,11 @@ def isna(self):
         return pandas.isna(self._ndarray)
 
     def _validate_scalar(self, value):
-        if pandas.isna(value):
-            return None
-
-        if not isinstance(value, self.dtype.type):
-            raise ValueError(value)
-
-        return value
-
-    def take(
-        self,
-        indices: Sequence[int],
-        *,
-        allow_fill: bool = False,
-        fill_value: Any = None,
-    ):
-        indices = numpy.asarray(indices, dtype=numpy.intp)
-        data = self._ndarray
-        if allow_fill:
-            fill_value = self._validate_scalar(fill_value)
-            fill_value = (
-                numpy.datetime64() if fill_value is None else self._datetime(fill_value)
-            )
-            if (indices < -1).any():
-                raise ValueError(
-                    "take called with negative indexes other than -1,"
-                    " when a fill value is provided."
-                )
-        out = data.take(indices)
-        if allow_fill:
-            out[indices == -1] = fill_value
-
-        return self.__class__(out)
-
-    # TODO: provide implementations of dropna, fillna, unique,
-    # factorize, argsort, searchsoeted for better performance over
-    # abstract implementations.
+        """
+        Validate and convert a scalar value to datetime64[ns] for storage in
+        backing NumPy array.
+        """
+        return self._datetime(value)
 
     def any(
         self,
         *,
         axis: Optional[int] = None,
         out=None,
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_any(
-            (), {"out": out, "keepdims": keepdims}
-        )
-        result = pandas.core.nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
+        pandas_backports.numpy_validate_any((), {"out": out, "keepdims": keepdims})
+        result = pandas_backports.nanany(self._ndarray, axis=axis, skipna=skipna)
         return result
 
     def all(
         self,
         *,
         axis: Optional[int] = None,
         out=None,
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_all(
-            (), {"out": out, "keepdims": keepdims}
-        )
-        result = pandas.core.nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
+        pandas_backports.numpy_validate_all((), {"out": out, "keepdims": keepdims})
+        result = pandas_backports.nanall(self._ndarray, axis=axis, skipna=skipna)
         return result
 
     def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
-        pandas.compat.numpy.function.validate_min((), kwargs)
-        result = pandas.core.nanops.nanmin(
+        pandas_backports.numpy_validate_min((), kwargs)
+        result = pandas_backports.nanmin(
             values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
         )
         return self._box_func(result)
 
     def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
-        pandas.compat.numpy.function.validate_max((), kwargs)
-        result = pandas.core.nanops.nanmax(
+        pandas_backports.numpy_validate_max((), kwargs)
+        result = pandas_backports.nanmax(
             values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
         )
         return self._box_func(result)
@@ -197,11 +155,9 @@ def median(
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_median(
+        pandas_backports.numpy_validate_median(
             (), {"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
         )
-        result = pandas.core.nanops.nanmedian(
-            self._ndarray, axis=axis, skipna=skipna
-        )
+        result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
         return self._box_func(result)
 
diff --git a/db_dtypes/pandas_backports.py b/db_dtypes/pandas_backports.py
index 003224f..4b733cc 100644
--- a/db_dtypes/pandas_backports.py
+++ b/db_dtypes/pandas_backports.py
@@ -20,15 +20,32 @@
 """
 
 import operator
+from typing import Any
 
 import numpy
 import packaging.version
 import pandas
-from pandas._libs.lib import is_integer
+from pandas.api.types import is_integer
+import pandas.compat.numpy.function
+import pandas.core.nanops
 
 
 pandas_release = packaging.version.parse(pandas.__version__).release
 
+# Create aliases for private methods in case they move in a future version.
+nanall = pandas.core.nanops.nanall
+nanany = pandas.core.nanops.nanany
+nanmax = pandas.core.nanops.nanmax
+nanmin = pandas.core.nanops.nanmin
+numpy_validate_all = pandas.compat.numpy.function.validate_all
+numpy_validate_any = pandas.compat.numpy.function.validate_any
+numpy_validate_max = pandas.compat.numpy.function.validate_max
+numpy_validate_min = pandas.compat.numpy.function.validate_min
+
+if pandas_release >= (1, 2):
+    nanmedian = pandas.core.nanops.nanmedian
+    numpy_validate_median = pandas.compat.numpy.function.validate_median
+
 
 def import_default(module_name, force=False, default=None):
     """
@@ -55,6 +72,10 @@ def import_default(module_name, force=False, default=None):
     return getattr(module, name, default)
 
 
+# pandas.core.arraylike.OpsMixin is private, but the related public API
+# "ExtensionScalarOpsMixin" is not sufficient for adding dates to times.
+# It results in unsupported operand type(s) for +: 'datetime.time' and
+# 'datetime.date'
 @import_default("pandas.core.arraylike")
 class OpsMixin:
     def _cmp_method(self, other, op):  # pragma: NO COVER
@@ -81,6 +102,8 @@ def __ge__(self, other):
     __add__ = __radd__ = __sub__ = lambda self, other: NotImplemented
 
 
+# TODO: use public API once pandas 1.5 / 2.x is released.
+# See: https://github.com/pandas-dev/pandas/pull/45544
 @import_default("pandas.core.arrays._mixins", pandas_release < (1, 3))
 class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray):
 
@@ -130,6 +153,28 @@ def copy(self):
 
     def repeat(self, n):
         return self.__class__(self._ndarray.repeat(n), self._dtype)
 
+    def take(
+        self,
+        indices,
+        *,
+        allow_fill: bool = False,
+        fill_value: Any = None,
+        axis: int = 0,
+    ):
+        from pandas.core.algorithms import take
+
+        if allow_fill:
+            fill_value = self._validate_scalar(fill_value)
+
+        new_data = take(
+            self._ndarray,
+            indices,
+            allow_fill=allow_fill,
+            fill_value=fill_value,
+            axis=axis,
+        )
+        return self._from_backing_data(new_data)
+
     @classmethod
     def _concat_same_type(cls, to_concat, axis=0):
         dtypes = {str(x.dtype) for x in to_concat}
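
Usage sketch (not part of the patch): the changes above only swap private pandas imports for public API calls and backport aliases, so the observable behavior of the arrays should be unchanged. The snippet below is an illustration under assumptions rather than something taken from this diff; it assumes that importing db_dtypes registers the extension dtypes under the names "dbdate" and "dbtime", and that date + time addition is the case the OpsMixin fallback in pandas_backports.py is kept around for.

    import pandas

    import db_dtypes  # noqa: F401 -- import registers the "dbdate"/"dbtime" dtypes (assumed names)

    # Construction now relies only on pandas.api.extensions.register_extension_dtype.
    dates = pandas.Series(["2021-09-17", "2021-09-18"], dtype="dbdate")
    times = pandas.Series(["01:32:00", "23:59:59.999999"], dtype="dbtime")

    # Adding a date column to a time column combines them into datetime64[ns]
    # timestamps; per the comment above, ExtensionScalarOpsMixin cannot do this.
    print(dates + times)

    # Reductions on the arrays route through the pandas_backports aliases
    # (numpy_validate_min, nanmin, and friends) introduced in this patch.
    print(dates.array.min(), dates.array.max())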