14
14
#include < vector>
15
15
16
16
#include " pybind11/numpy.h"
17
- #include " pybind11_backport.hpp"
18
17
19
18
#include " xtensor/xexpression.hpp"
20
19
#include " xtensor/xsemantic.hpp"
21
20
#include " xtensor/xiterator.hpp"
22
21
23
22
namespace xt
24
23
{
24
+ template <class T , int ExtraFlags>
25
+ class pyarray ;
26
+ }
25
27
26
- using pybind_array = pybind11::backport::array;
27
- using buffer_info = pybind11::buffer_info;
28
+ namespace pybind11
29
+ {
30
+ namespace detail
31
+ {
32
+ template <typename T, int ExtraFlags>
33
+ struct pyobject_caster <xt::pyarray<T, ExtraFlags>>
34
+ {
35
+ using type = xt::pyarray<T, ExtraFlags>;
36
+
37
+ bool load (handle src, bool )
38
+ {
39
+ value = type::ensure (src);
40
+ return static_cast <bool >(value);
41
+ }
42
+
43
+ static handle cast (const handle &src, return_value_policy, handle)
44
+ {
45
+ return src.inc_ref ();
46
+ }
47
+
48
+ PYBIND11_TYPE_CASTER (type, handle_type_name<type>::name());
49
+ };
50
+ }
51
+ }
52
+
53
+ namespace xt
54
+ {
55
+
56
+ using pybind_array = pybind11::array;
28
57
29
58
/* **********************
30
59
* pyarray declaration *
@@ -95,11 +124,11 @@ namespace xt
95
124
96
125
using closure_type = const self_type&;
97
126
98
- PYBIND11_OBJECT_CVT (pyarray, pybind_array, is_non_null, m_ptr = ensure_(m_ptr));
99
-
100
127
pyarray ();
101
128
102
- explicit pyarray (const buffer_info& info);
129
+ pyarray (pybind11::handle h, borrowed_t );
130
+ pyarray (pybind11::handle h, stolen_t );
131
+ pyarray (const pybind11::object &o);
103
132
104
133
pyarray (const shape_type& shape,
105
134
const strides_type& strides,
@@ -188,6 +217,9 @@ namespace xt
188
217
template <class E >
189
218
pyarray& operator =(const xexpression<E>& e);
190
219
220
+ static pyarray ensure (pybind11::handle h);
221
+ static bool _check (pybind11::handle h);
222
+
191
223
private:
192
224
193
225
template <typename ... Args>
@@ -199,11 +231,10 @@ namespace xt
199
231
200
232
static bool is_non_null (PyObject* ptr);
201
233
202
- static PyObject *ensure_ (PyObject* ptr);
203
-
204
234
mutable shape_type m_shape;
205
235
mutable strides_type m_strides;
206
236
237
+ static PyObject* raw_array_t (PyObject* ptr);
207
238
};
208
239
209
240
/* *************************************
@@ -230,16 +261,29 @@ namespace xt
230
261
231
262
template <class T , int ExtraFlags>
232
263
inline pyarray<T, ExtraFlags>::pyarray()
233
- : pybind_array()
264
+ : pybind_array(0 , static_cast <const_pointer>(nullptr ))
265
+ {
266
+ }
267
+
268
+ template <class T , int ExtraFlags>
269
+ inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, borrowed_t ) : pybind_array(h, borrowed)
234
270
{
235
271
}
236
272
237
273
template <class T , int ExtraFlags>
238
- inline pyarray<T, ExtraFlags>::pyarray(const buffer_info& info)
239
- : pybind_array(info)
274
+ inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, stolen_t ) : pybind_array(h, stolen)
240
275
{
241
276
}
242
277
278
+ template <class T , int ExtraFlags>
279
+ inline pyarray<T, ExtraFlags>::pyarray(const pybind11::object &o) : pybind_array(raw_array_t (o.ptr()), stolen)
280
+ {
281
+ if (!m_ptr)
282
+ {
283
+ throw pybind11::error_already_set ();
284
+ }
285
+ }
286
+
243
287
template <class T , int ExtraFlags>
244
288
inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
245
289
const strides_type& strides,
@@ -512,7 +556,7 @@ namespace xt
512
556
template <class T , int ExtraFlags>
513
557
inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
514
558
{
515
- return reinterpret_cast <storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
559
+ return reinterpret_cast <storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
516
560
}
517
561
518
562
template <class T , int ExtraFlags>
@@ -524,7 +568,7 @@ namespace xt
524
568
template <class T , int ExtraFlags>
525
569
inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
526
570
{
527
- return reinterpret_cast <const_storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
571
+ return reinterpret_cast <const_storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
528
572
}
529
573
530
574
template <class T , int ExtraFlags>
@@ -536,7 +580,7 @@ namespace xt
536
580
template <class T , int ExtraFlags>
537
581
inline auto pyarray<T, ExtraFlags>::storage_cbegin() const -> const_storage_iterator
538
582
{
539
- return reinterpret_cast <const_storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
583
+ return reinterpret_cast <const_storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
540
584
}
541
585
542
586
template <class T , int ExtraFlags>
@@ -560,6 +604,25 @@ namespace xt
560
604
return semantic_base::operator =(e);
561
605
}
562
606
607
+ template <class T , int ExtraFlags>
608
+ inline pyarray<T, ExtraFlags> pyarray<T, ExtraFlags>::ensure(pybind11::handle h)
609
+ {
610
+ auto result = pybind11::reinterpret_steal<pyarray>(raw_array_t (h.ptr ()));
611
+ if (!pybind11::handle (result))
612
+ {
613
+ PyErr_Clear ();
614
+ }
615
+ return result;
616
+ }
617
+
618
+ template <class T , int ExtraFlags>
619
+ inline bool pyarray<T, ExtraFlags>::_check(pybind11::handle h)
620
+ {
621
+ const auto &api = pybind11::detail::npy_api::get ();
622
+ return api.PyArray_Check_ (h.ptr ())
623
+ && api.PyArray_EquivTypes_ (pybind11::detail::array_proxy (h.ptr ())->descr , pybind11::dtype::of<T>().ptr ());
624
+ }
625
+
563
626
// Private methods
564
627
565
628
template <class T , int ExtraFlags>
@@ -591,23 +654,17 @@ namespace xt
591
654
}
592
655
593
656
template <class T , int ExtraFlags>
594
- inline PyObject* pyarray<T, ExtraFlags>::ensure_ (PyObject* ptr)
657
+ inline PyObject* pyarray<T, ExtraFlags>::raw_array_t (PyObject* ptr)
595
658
{
596
659
if (ptr == nullptr )
597
660
{
598
661
return nullptr ;
599
662
}
600
- API& api = lookup_api ();
601
- PyObject* descr = api.PyArray_DescrFromType_ (pybind11::detail::npy_format_descriptor<T>::value);
602
- PyObject* result = api.PyArray_FromAny_ (ptr, descr, 0 , 0 , API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
603
- if (!result)
604
- {
605
- PyErr_Clear ();
606
- }
607
- Py_DECREF (ptr);
608
- return result;
663
+ return pybind11::detail::npy_api::get ().PyArray_FromAny_ (
664
+ ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
665
+ pybind11::detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr
666
+ );
609
667
}
610
-
611
668
}
612
669
613
670
#endif
0 commit comments