Skip to content

Commit 3323ae8

Browse files
authored
Merge pull request #25887 from patel-zeel/speed_up_jax_torch_plots
Update `_unpack_to_numpy` function to convert JAX and PyTorch arrays to NumPy
2 parents 7fde77e + 9acfb5b commit 3323ae8

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

lib/matplotlib/cbook.py

+30
Original file line numberDiff line numberDiff line change
@@ -2349,6 +2349,30 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23492349
return cls.__new__(cls)
23502350

23512351

2352+
def _is_torch_array(x):
2353+
"""Check if 'x' is a PyTorch Tensor."""
2354+
try:
2355+
# we're intentionally not attempting to import torch. If somebody
2356+
# has created a torch array, torch should already be in sys.modules
2357+
return isinstance(x, sys.modules['torch'].Tensor)
2358+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2359+
# we're attempting to access attributes on imported modules which
2360+
# may have arbitrary user code, so we deliberately catch all exceptions
2361+
return False
2362+
2363+
2364+
def _is_jax_array(x):
2365+
"""Check if 'x' is a JAX Array."""
2366+
try:
2367+
# we're intentionally not attempting to import jax. If somebody
2368+
# has created a jax array, jax should already be in sys.modules
2369+
return isinstance(x, sys.modules['jax'].Array)
2370+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2371+
# we're attempting to access attributes on imported modules which
2372+
# may have arbitrary user code, so we deliberately catch all exceptions
2373+
return False
2374+
2375+
23522376
def _unpack_to_numpy(x):
23532377
"""Internal helper to extract data from e.g. pandas and xarray objects."""
23542378
if isinstance(x, np.ndarray):
@@ -2363,6 +2387,12 @@ def _unpack_to_numpy(x):
23632387
# so in this case we do not want to return a function
23642388
if isinstance(xtmp, np.ndarray):
23652389
return xtmp
2390+
if _is_torch_array(x) or _is_jax_array(x):
2391+
xtmp = x.__array__()
2392+
2393+
# In case __array__() method does not return a numpy array in future
2394+
if isinstance(xtmp, np.ndarray):
2395+
return xtmp
23662396
return x
23672397

23682398

lib/matplotlib/tests/test_cbook.py

+44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import itertools
45
import pickle
56

@@ -16,6 +17,7 @@
1617
from matplotlib import _api, cbook
1718
import matplotlib.colors as mcolors
1819
from matplotlib.cbook import delete_masked_points, strip_math
20+
from types import ModuleType
1921

2022

2123
class Test_delete_masked_points:
@@ -938,3 +940,45 @@ def test_auto_format_str(fmt, value, result):
938940
"""Apply *value* to the format string *fmt*."""
939941
assert cbook._auto_format_str(fmt, value) == result
940942
assert cbook._auto_format_str(fmt, np.float64(value)) == result
943+
944+
945+
def test_unpack_to_numpy_from_torch():
946+
"""Test that torch tensors are converted to numpy arrays.
947+
We don't want to create a dependency on torch in the test suite, so we mock it.
948+
"""
949+
class Tensor:
950+
def __init__(self, data):
951+
self.data = data
952+
def __array__(self):
953+
return self.data
954+
torch = ModuleType('torch')
955+
torch.Tensor = Tensor
956+
sys.modules['torch'] = torch
957+
958+
data = np.arange(10)
959+
torch_tensor = torch.Tensor(data)
960+
961+
result = cbook._unpack_to_numpy(torch_tensor)
962+
assert result is torch_tensor.__array__()
963+
964+
965+
def test_unpack_to_numpy_from_jax():
966+
"""Test that jax arrays are converted to numpy arrays.
967+
We don't want to create a dependency on jax in the test suite, so we mock it.
968+
"""
969+
class Array:
970+
def __init__(self, data):
971+
self.data = data
972+
def __array__(self):
973+
return self.data
974+
975+
jax = ModuleType('jax')
976+
jax.Array = Array
977+
978+
sys.modules['jax'] = jax
979+
980+
data = np.arange(10)
981+
jax_array = jax.Array(data)
982+
983+
result = cbook._unpack_to_numpy(jax_array)
984+
assert result is jax_array.__array__()

0 commit comments

Comments
 (0)