
[Bug]: ValueError when plotting 2D pytorch tensor using matplotlib==3.8.0 #26806


Closed
CMGeldenhuys opened this issue Sep 18, 2023 · 1 comment · Fixed by #26807

@CMGeldenhuys

Bug summary

A ValueError is raised when trying to plot a 2D PyTorch tensor with matplotlib==3.8.0. The error does not occur with matplotlib==3.7.3.

Code for reproduction

# Using matplotlib==3.8.0
>>> import torch
>>> import matplotlib as mplt
>>> mplt.__version__
'3.8.0'
>>> import matplotlib.pyplot as plt
>>> a = torch.randn(185,5)
>>> plt.plot(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../lib/python3.11/site-packages/matplotlib/pyplot.py", line 3578, in plot
    return gca().plot(
           ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_axes.py", line 1721, in plot
    lines = [*self._get_lines(self, *args, data=data, **kwargs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_base.py", line 303, in __call__
    yield from self._plot_args(
               ^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_base.py", line 496, in _plot_args
    axes.yaxis.update_units(y)
  File ".../lib/python3.11/site-packages/matplotlib/axis.py", line 1706, in update_units
    converter = munits.registry.get_converter(data)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/units.py", line 183, in get_converter
    first = cbook._safe_first_finite(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/cbook.py", line 1730, in _safe_first_finite
    if safe_isfinite(val):
       ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/cbook.py", line 1699, in safe_isfinite
    return math.isfinite(val)
           ^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars

>>> b = torch.randn(185)
>>> plt.plot(b)
[<matplotlib.lines.Line2D object at 0x7f27d7ccc190>]

# Using matplotlib==3.7.3
>>> import torch
>>> import matplotlib as mplt
>>> mplt.__version__
'3.7.3'
>>> import matplotlib.pyplot as plt
>>> a = torch.randn(185,5)
>>> plt.plot(a)
[<matplotlib.lines.Line2D object at 0x7f19762d7910>, <matplotlib.lines.Line2D object at 0x7f1976684250>, <matplotlib.lines.Line2D object at 0x7f19764d5150>, <matplotlib.lines.Line2D object at 0x7f197598f9d0>, <matplotlib.lines.Line2D object at 0x7f19762bff50>]
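The traceback ends in `math.isfinite(val)`, where `val` is the first element of the plotted data. For non-float objects, `math.isfinite` converts its argument through the object's `__float__` method. Iterating over a 2D tensor yields whole rows, and PyTorch refuses to convert a multi-element row to a Python float, whereas the first element of a 1D tensor holds a single value and converts fine. The stdlib-only sketch below reproduces that mechanism; `FakeTensor` is a hypothetical stand-in for `torch.Tensor` (an assumption, so that torch is not needed) and only mimics its float conversion.

```python
import math


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only mimics its float conversion."""

    def __init__(self, values):
        self.values = list(values)

    def __float__(self):
        # Same error message as in the traceback above for multi-element tensors.
        if len(self.values) != 1:
            raise ValueError("only one element tensors can be converted to Python scalars")
        return float(self.values[0])


# First element of a 1D tensor: a single value, so the conversion succeeds.
first_of_1d = FakeTensor([0.5])
# First element of a (185, 5) tensor: a whole row of 5 values.
first_of_2d = FakeTensor([0.1, 0.2, 0.3, 0.4, 0.5])

print(math.isfinite(first_of_1d))  # True
try:
    math.isfinite(first_of_2d)
except ValueError as exc:
    # Prints: ValueError: only one element tensors can be converted to Python scalars
    print("ValueError:", exc)
```

This matches the report: `plt.plot(b)` with a 1D tensor works, while `plt.plot(a)` with a 2D tensor fails inside the finiteness check.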

Actual outcome

(See the REPL output in the example above.)

Expected outcome

I expect 5 line series, one per column (index along the second dimension) of the tensor, as produced with matplotlib 3.7.3.
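For a 2D input, `plt.plot` draws one line per column. The pure-Python sketch below illustrates that grouping for data shaped like `torch.randn(185, 5)`; the helper `split_columns` and the sample values are hypothetical and are not Matplotlib code.

```python
def split_columns(rows):
    """Return one list per column of a row-major 2D nested list."""
    # zip(*rows) groups the i-th entry of every row, i.e. column i.
    return [list(column) for column in zip(*rows)]


# Hypothetical data with the same shape as torch.randn(185, 5).
rows = [[float(r * 5 + c) for c in range(5)] for r in range(185)]
series = split_columns(rows)
print(len(series), len(series[0]))  # 5 185
```

Each of the 5 resulting series has 185 points, which corresponds to the 5 `Line2D` objects returned by matplotlib 3.7.3 above.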

Additional information

The bug appears in 3.8.0 but not in 3.7.3. In both cases I was using torch==2.0.1.

Operating system

Ubuntu

Matplotlib Version

3.8.0 and 3.7.3

Matplotlib Backend

TkAgg

Python version

Python 3.11.5

Jupyter version

N/A

Installation

pip

@oscargus
Member

We do not formally support plotting PyTorch tensors. Nevertheless, I have suggested a fix in #26807 that solves the problem, although I am not sure it will be merged.

For now, you can do

plt.plot(a.numpy())

@oscargus oscargus added this to the v3.8.1 milestone Sep 18, 2023
timhoffm added a commit to timhoffm/matplotlib that referenced this issue Sep 18, 2023
This does not change functionality. The code path for
`safe_first_element` is `_safe_first_finite(skip_nonfinite=False)`, which
is a separate code block and does not interact with the
skip_nonfinite=True case. IMHO this is more readable.

Also add a comment on the exception handling recently modified in matplotlib#26806.