diff --git a/numpy/core/src/multiarray/_datetime.h b/numpy/core/src/multiarray/_datetime.h index 3c216d85690f..5f04eb017254 100644 --- a/numpy/core/src/multiarray/_datetime.h +++ b/numpy/core/src/multiarray/_datetime.h @@ -1,6 +1,12 @@ #ifndef _NPY_PRIVATE__DATETIME_H_ #define _NPY_PRIVATE__DATETIME_H_ +/* + * Set to non-zero to make np.array() automatically convert python datetime.date + * and datetime.datetime to np.datetime64. + */ +#define NPY_ARRAY_AUTOCONVERT_PYDATETIME 0 + NPY_NO_EXPORT char *_datetime_strings[NPY_DATETIME_NUMUNITS]; NPY_NO_EXPORT int _days_per_month_table[2][12]; @@ -20,6 +26,12 @@ is_leapyear(npy_int64 year); NPY_NO_EXPORT npy_int64 get_datetimestruct_days(const npy_datetimestruct *dts); +/* + * Gets a descr for a Date or Datetime + */ +NPY_NO_EXPORT PyArray_Descr * +get_datetime_dtype_for_obj(PyObject *obj); + /* * Creates a datetime or timedelta dtype using a copy of the provided metadata. */ diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 89d4c5f63ac7..299935db990a 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -14,10 +14,14 @@ #include "common.h" #include "buffer.h" +#include "_datetime.h" NPY_NO_EXPORT PyArray_Descr * _array_find_python_scalar_type(PyObject *op) { +#if NPY_ARRAY_AUTOCONVERT_PYDATETIME + PyArray_Descr *res; +#endif if (PyFloat_Check(op)) { return PyArray_DescrFromType(PyArray_DOUBLE); } @@ -41,6 +45,14 @@ _array_find_python_scalar_type(PyObject *op) } return PyArray_DescrFromType(PyArray_LONGLONG); } +#if NPY_ARRAY_AUTOCONVERT_PYDATETIME + else { + res = get_datetime_dtype_for_obj(op); + if (res != NULL) { + return res; + } + } +#endif return NULL; } diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c index 7a0bafbf28c7..3c5f72a8c531 100644 --- a/numpy/core/src/multiarray/datetime.c +++ b/numpy/core/src/multiarray/datetime.c @@ -25,6 +25,9 @@ #include "_datetime.h" #include "datetime_strings.h" +static PyObject *dtype_us; +static PyObject *dtype_D; + /* * Imports the PyDateTime functions so we can create these objects. * This is called during module initialization @@ -32,7 +35,17 @@ NPY_NO_EXPORT void numpy_pydatetime_import() { + PyArray_DatetimeMetaData meta; + PyDateTime_IMPORT; + + meta.num = 1; + + meta.base = NPY_FR_us; + dtype_us = create_datetime_dtype(PyArray_DATETIME, &meta); + + meta.base = NPY_FR_D; + dtype_D = create_datetime_dtype(PyArray_DATETIME, &meta); } /* Exported as DATETIMEUNITS in multiarraymodule.c */ @@ -683,6 +696,23 @@ PyArray_TimedeltaToTimedeltaStruct(npy_timedelta val, NPY_DATETIMEUNIT fr, memset(result, -1, sizeof(npy_timedeltastruct)); } +/* + * Gets a descr for a Date or Datetime + */ +NPY_NO_EXPORT PyArray_Descr * +get_datetime_dtype_for_obj(PyObject *datetime) +{ + if (PyDateTime_Check(datetime)) { + Py_INCREF(dtype_us); + return dtype_us; + } else if (PyDate_Check(datetime)) { + Py_INCREF(dtype_D); + return dtype_D; + } else { + return NULL; + } +} + /* * Creates a datetime or timedelta dtype using a copy of the provided metadata. */