@@ -2349,6 +2349,30 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
2349
2349
return cls .__new__ (cls )
2350
2350
2351
2351
2352
+ def _is_torch_array (x ):
2353
+ """Check if 'x' is a PyTorch Tensor."""
2354
+ try :
2355
+ # we're intentionally not attempting to import torch. If somebody
2356
+ # has created a torch array, torch should already be in sys.modules
2357
+ return isinstance (x , sys .modules ['torch' ].Tensor )
2358
+ except Exception : # TypeError, KeyError, AttributeError, maybe others?
2359
+ # we're attempting to access attributes on imported modules which
2360
+ # may have arbitrary user code, so we deliberately catch all exceptions
2361
+ return False
2362
+
2363
+
2364
+ def _is_jax_array (x ):
2365
+ """Check if 'x' is a JAX Array."""
2366
+ try :
2367
+ # we're intentionally not attempting to import jax. If somebody
2368
+ # has created a jax array, jax should already be in sys.modules
2369
+ return isinstance (x , sys .modules ['jax' ].Array )
2370
+ except Exception : # TypeError, KeyError, AttributeError, maybe others?
2371
+ # we're attempting to access attributes on imported modules which
2372
+ # may have arbitrary user code, so we deliberately catch all exceptions
2373
+ return False
2374
+
2375
+
2352
2376
def _unpack_to_numpy (x ):
2353
2377
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2354
2378
if isinstance (x , np .ndarray ):
@@ -2363,6 +2387,12 @@ def _unpack_to_numpy(x):
2363
2387
# so in this case we do not want to return a function
2364
2388
if isinstance (xtmp , np .ndarray ):
2365
2389
return xtmp
2390
+ if _is_torch_array (x ) or _is_jax_array (x ):
2391
+ xtmp = x .__array__ ()
2392
+
2393
+ # In case __array__() method does not return a numpy array in future
2394
+ if isinstance (xtmp , np .ndarray ):
2395
+ return xtmp
2366
2396
return x
2367
2397
2368
2398
0 commit comments