Skip to content

Commit 731b26f

Browse files
check if all points lie on a plane
1 parent 3ef2340 commit 731b26f

File tree

3 files changed

+74
-17
lines changed

3 files changed

+74
-17
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,41 @@ def _zalpha(colors, zs):
11851185
return np.column_stack([rgba[:, :3], rgba[:, 3] * sats])
11861186

11871187

1188+
def _all_points_on_plane(xs, ys, zs, atol=1e-8):
1189+
"""
1190+
Check if all points are on the same plane. Note that NaN values are
1191+
ignored.
1192+
1193+
Parameters
1194+
----------
1195+
xs, ys, zs : array-like
1196+
The x, y, and z coordinates of the points.
1197+
atol : float, default: 1e-8
1198+
The tolerance for the equality check.
1199+
"""
1200+
xs, ys, zs = np.asarray(xs), np.asarray(ys), np.asarray(zs)
1201+
points = np.column_stack([xs, ys, zs])
1202+
points = points[~np.isnan(points).any(axis=1)]
1203+
# Check for the case where we have less than 3 unique points
1204+
points = np.unique(points, axis=0)
1205+
if len(points) <= 3:
1206+
return True
1207+
# Calculate the vectors from the first point to all other points
1208+
vs = (points - points[0])[1:]
1209+
vs = vs / np.linalg.norm(vs, axis=1)[:, np.newaxis]
1210+
# Check for the case where all points lie on a line
1211+
vs = np.unique(vs, axis=0)
1212+
if len(vs) <= 2:
1213+
return True
1214+
# Calculate the normal vector from the first three points
1215+
n = np.cross(vs[0], vs[1])
1216+
n = n / np.linalg.norm(n)
1217+
# If the dot product of the normal vector and all other vectors is zero,
1218+
# all points are on the same plane
1219+
dots = np.dot(n, vs.transpose())
1220+
return np.allclose(dots, 0, atol=atol)
1221+
1222+
11881223
def _generate_normals(polygons):
11891224
"""
11901225
Compute the normals of a list of polygons, one normal per polygon.

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,9 +1994,8 @@ def fill_between(self, x1, y1, z1, x2, y2, z2, *,
19941994
- 'polygon': The two lines are connected to form a single polygon.
19951995
This is faster and can render more cleanly for simple shapes
19961996
(e.g. for filling between two lines that lie within a plane).
1997-
- 'auto': If the lines are in a plane parallel to a coordinate axis
1998-
(one of *x*, *y*, *z* are constant and equal for both lines),
1999-
'polygon' is used. Otherwise, 'quad' is used.
1997+
- 'auto': If the points all lie on the same 3D plane, 'polygon' is
1998+
used. Otherwise, 'quad' is used.
20001999
20012000
facecolors : list of :mpltype:`color`, default: None
20022001
Colors of each individual patch, or a single color to be used for
@@ -2019,19 +2018,6 @@ def fill_between(self, x1, y1, z1, x2, y2, z2, *,
20192018

20202019
had_data = self.has_data()
20212020
x1, y1, z1, x2, y2, z2 = cbook._broadcast_with_masks(x1, y1, z1, x2, y2, z2)
2022-
if mode == 'auto':
2023-
if ((np.all(x1 == x1[0]) and np.all(x2 == x1[0]))
2024-
or (np.all(y1 == y1[0]) and np.all(y2 == y1[0]))
2025-
or (np.all(z1 == z1[0]) and np.all(z2 == z1[0]))):
2026-
mode = 'polygon'
2027-
else:
2028-
mode = 'quad'
2029-
2030-
if shade is None:
2031-
if mode == 'quad':
2032-
shade = True
2033-
else:
2034-
shade = False
20352021

20362022
if facecolors is None:
20372023
facecolors = [self._get_patches_for_fill.get_next_color()]
@@ -2046,6 +2032,21 @@ def fill_between(self, x1, y1, z1, x2, y2, z2, *,
20462032
f"size ({x1.size})")
20472033
where = where & ~np.isnan(x1) # NaNs were broadcast in _broadcast_with_masks
20482034

2035+
if mode == 'auto':
2036+
if art3d._all_points_on_plane(np.concatenate((x1[where], x2[where])),
2037+
np.concatenate((y1[where], y2[where])),
2038+
np.concatenate((z1[where], z2[where])),
2039+
atol=1e-12):
2040+
mode = 'polygon'
2041+
else:
2042+
mode = 'quad'
2043+
2044+
if shade is None:
2045+
if mode == 'quad':
2046+
shade = True
2047+
else:
2048+
shade = False
2049+
20492050
polys = []
20502051
for idx0, idx1 in cbook.contiguous_regions(where):
20512052
x1i = x1[idx0:idx1]

lib/mpl_toolkits/mplot3d/tests/test_art3d.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import matplotlib.pyplot as plt
44

55
from matplotlib.backend_bases import MouseEvent
6-
from mpl_toolkits.mplot3d.art3d import Line3DCollection
6+
from mpl_toolkits.mplot3d.art3d import Line3DCollection, _all_points_on_plane
77

88

99
def test_scatter_3d_projection_conservation():
@@ -54,3 +54,24 @@ def test_zordered_error():
5454
ax.add_collection(Line3DCollection(lc))
5555
ax.scatter(*pc, visible=False)
5656
plt.draw()
57+
58+
def test_all_points_on_plane():
59+
# Non-coplanar points
60+
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
61+
assert not _all_points_on_plane(*points.T)
62+
63+
# Duplicate points
64+
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]])
65+
assert _all_points_on_plane(*points.T)
66+
67+
# NaN values
68+
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, np.nan]])
69+
assert _all_points_on_plane(*points.T)
70+
71+
# Less than 3 unique points
72+
points = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
73+
assert _all_points_on_plane(*points.T)
74+
75+
# All points lie on a line
76+
points = np.array([[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0]])
77+
assert _all_points_on_plane(*points.T)

0 commit comments

Comments
 (0)