Skip to content

Add the ability to change the focal length of the camera for 3D plots #22046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions doc/users/next_whats_new/3d_plot_focal_length.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Give the 3D camera a custom focal length
----------------------------------------

Users can now better mimic real-world cameras by specifying the focal length of
the virtual camera in 3D plots. The default focal length of 1 corresponds to a
Field of View (FOV) of 90 deg, and is backwards-compatible with existing 3D
plots. An increased focal length between 1 and infinity "flattens" the image,
while a decreased focal length between 1 and 0 exaggerates the perspective and
gives the image more apparent depth.

The focal length can be calculated from a desired FOV via the equation:

.. mathmpl::

focal\_length = 1/\tan(FOV/2)

.. plot::
:include-source: true

from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, subplot_kw={'projection': '3d'},
constrained_layout=True)
X, Y, Z = axes3d.get_test_data(0.05)
focal_lengths = [0.25, 1, 4]
for ax, fl in zip(axs, focal_lengths):
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
ax.set_proj_type('persp', focal_length=fl)
ax.set_title(f"focal_length = {fl}")
plt.tight_layout()
plt.show()
55 changes: 55 additions & 0 deletions examples/mplot3d/projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
========================
3D plot projection types
========================

Demonstrates the different camera projections for 3D plots, and the effects of
changing the focal length for a perspective projection. Note that Matplotlib
corrects for the 'zoom' effect of changing the focal length.

The default focal length of 1 corresponds to a Field of View (FOV) of 90 deg.
An increased focal length between 1 and infinity "flattens" the image, while a
decreased focal length between 1 and 0 exaggerates the perspective and gives
the image more apparent depth. In the limiting case, a focal length of
infinity corresponds to an orthographic projection after correction of the
zoom effect.

You can calculate focal length from a FOV via the equation:

.. mathmpl::

1 / \tan (FOV / 2)

Or vice versa:

.. mathmpl::

FOV = 2 * \atan (1 / focal length)

"""

from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt


fig, axs = plt.subplots(1, 3, subplot_kw={'projection': '3d'})

# Get the test data
X, Y, Z = axes3d.get_test_data(0.05)

# Plot the data
for ax in axs:
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)

# Set the orthographic projection.
axs[0].set_proj_type('ortho') # FOV = 0 deg
axs[0].set_title("'ortho'\nfocal_length = ∞", fontsize=10)

# Set the perspective projections
axs[1].set_proj_type('persp') # FOV = 90 deg
axs[1].set_title("'persp'\nfocal_length = 1 (default)", fontsize=10)

axs[2].set_proj_type('persp', focal_length=0.2) # FOV = 157.4 deg
axs[2].set_title("'persp'\nfocal_length = 0.2", fontsize=10)

plt.show()
57 changes: 46 additions & 11 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Axes3D(Axes):
def __init__(
self, fig, rect=None, *args,
elev=30, azim=-60, roll=0, sharez=None, proj_type='persp',
box_aspect=None, computed_zorder=True,
box_aspect=None, computed_zorder=True, focal_length=None,
**kwargs):
"""
Parameters
Expand Down Expand Up @@ -104,6 +104,13 @@ def __init__(
This behavior is deprecated in 3.4, the default will
change to False in 3.5. The keyword will be undocumented
and a non-False value will be an error in 3.6.
focal_length : float, default: None
For a projection type of 'persp', the focal length of the virtual
camera. Must be > 0. If None, defaults to 1.
For a projection type of 'ortho', must be set to either None
or infinity (numpy.inf). If None, defaults to infinity.
The focal length can be computed from a desired Field Of View via
the equation: focal_length = 1/tan(FOV/2)

**kwargs
Other optional keyword arguments:
Expand All @@ -117,7 +124,7 @@ def __init__(
self.initial_azim = azim
self.initial_elev = elev
self.initial_roll = roll
self.set_proj_type(proj_type)
self.set_proj_type(proj_type, focal_length)
self.computed_zorder = computed_zorder

self.xy_viewLim = Bbox.unit()
Expand Down Expand Up @@ -989,18 +996,33 @@ def view_init(self, elev=None, azim=None, roll=None, vertical_axis="z"):
dict(x=0, y=1, z=2), vertical_axis=vertical_axis
)

def set_proj_type(self, proj_type):
def set_proj_type(self, proj_type, focal_length=None):
"""
Set the projection type.

Parameters
----------
proj_type : {'persp', 'ortho'}
"""
self._projection = _api.check_getitem({
'persp': proj3d.persp_transformation,
'ortho': proj3d.ortho_transformation,
}, proj_type=proj_type)
The projection type.
focal_length : float, default: None
For a projection type of 'persp', the focal length of the virtual
camera. Must be > 0. If None, defaults to 1.
The focal length can be computed from a desired Field Of View via
the equation: focal_length = 1/tan(FOV/2)
"""
_api.check_in_list(['persp', 'ortho'], proj_type=proj_type)
if proj_type == 'persp':
if focal_length is None:
focal_length = 1
elif focal_length <= 0:
raise ValueError(f"focal_length = {focal_length} must be "
"greater than 0")
self._focal_length = focal_length
elif proj_type == 'ortho':
if focal_length not in (None, np.inf):
raise ValueError(f"focal_length = {focal_length} must be "
f"None for proj_type = {proj_type}")
self._focal_length = np.inf

def _roll_to_vertical(self, arr):
"""Roll arrays to match the different vertical axis."""
Expand Down Expand Up @@ -1056,8 +1078,21 @@ def get_proj(self):
V = np.zeros(3)
V[self._vertical_axis] = -1 if abs(elev_rad) > 0.5 * np.pi else 1

viewM = proj3d.view_transformation(eye, R, V, roll_rad)
projM = self._projection(-self._dist, self._dist)
# Generate the view and projection transformation matrices
if self._focal_length == np.inf:
# Orthographic projection
viewM = proj3d.view_transformation(eye, R, V, roll_rad)
projM = proj3d.ortho_transformation(-self._dist, self._dist)
else:
# Perspective projection
# Scale the eye dist to compensate for the focal length zoom effect
eye_focal = R + self._dist * ps * self._focal_length
viewM = proj3d.view_transformation(eye_focal, R, V, roll_rad)
projM = proj3d.persp_transformation(-self._dist,
self._dist,
self._focal_length)

# Combine all the transformation matrices to get the final projection
M0 = np.dot(viewM, worldM)
M = np.dot(projM, M0)
return M
Expand Down Expand Up @@ -1120,7 +1155,7 @@ def cla(self):
pass

self._autoscaleZon = True
if self._projection is proj3d.ortho_transformation:
if self._focal_length == np.inf:
self._zmargin = rcParams['axes.zmargin']
else:
self._zmargin = 0.
Expand Down
26 changes: 15 additions & 11 deletions lib/mpl_toolkits/mplot3d/proj3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,27 @@ def view_transformation(E, R, V, roll):
return np.dot(Mr, Mt)


def persp_transformation(zfront, zback):
a = (zfront+zback)/(zfront-zback)
b = -2*(zfront*zback)/(zfront-zback)
return np.array([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, a, b],
[0, 0, -1, 0]])
def persp_transformation(zfront, zback, focal_length):
e = focal_length
a = 1 # aspect ratio
b = (zfront+zback)/(zfront-zback)
c = -2*(zfront*zback)/(zfront-zback)
proj_matrix = np.array([[e, 0, 0, 0],
[0, e/a, 0, 0],
[0, 0, b, c],
[0, 0, -1, 0]])
return proj_matrix


def ortho_transformation(zfront, zback):
# note: w component in the resulting vector will be (zback-zfront), not 1
a = -(zfront + zback)
b = -(zfront - zback)
return np.array([[2, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, -2, 0],
[0, 0, a, b]])
proj_matrix = np.array([[2, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, -2, 0],
[0, 0, a, b]])
return proj_matrix


def _proj_transform_vec(vec, M):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 17 additions & 1 deletion lib/mpl_toolkits/tests/test_mplot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def _test_proj_make_M():
V = np.array([0, 0, 1])
roll = 0
viewM = proj3d.view_transformation(E, R, V, roll)
perspM = proj3d.persp_transformation(100, -100)
perspM = proj3d.persp_transformation(100, -100, 1)
M = np.dot(perspM, viewM)
return M

Expand Down Expand Up @@ -1036,6 +1036,22 @@ def test_unautoscale(axis, auto):
np.testing.assert_array_equal(get_lim(), (-0.5, 0.5))


def test_axes3d_focal_length_checks():
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
with pytest.raises(ValueError):
ax.set_proj_type('persp', focal_length=0)
with pytest.raises(ValueError):
ax.set_proj_type('ortho', focal_length=1)


@mpl3d_image_comparison(['axes3d_focal_length.png'], remove_text=False)
def test_axes3d_focal_length():
fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
axs[0].set_proj_type('persp', focal_length=np.inf)
axs[1].set_proj_type('persp', focal_length=0.15)


@mpl3d_image_comparison(['axes3d_ortho.png'], remove_text=False)
def test_axes3d_ortho():
fig = plt.figure()
Expand Down