Skip to content

Commit 2b5b925

Browse files
authored
Merge pull request #24757 from chahak13/scatter_offsets_masked
Allow using masked in `set_offsets`
2 parents 4bb1538 + 5267067 commit 2b5b925

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

lib/matplotlib/collections.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,11 @@ def set_offsets(self, offsets):
545545
offsets = np.asanyarray(offsets)
546546
if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else.
547547
offsets = offsets[None, :]
548-
self._offsets = np.column_stack(
549-
(np.asarray(self.convert_xunits(offsets[:, 0]), float),
550-
np.asarray(self.convert_yunits(offsets[:, 1]), float)))
548+
cstack = (np.ma.column_stack if isinstance(offsets, np.ma.MaskedArray)
549+
else np.column_stack)
550+
self._offsets = cstack(
551+
(np.asanyarray(self.convert_xunits(offsets[:, 0]), float),
552+
np.asanyarray(self.convert_yunits(offsets[:, 1]), float)))
551553
self.stale = True
552554

553555
def get_offsets(self):

lib/matplotlib/tests/test_collections.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,3 +1149,36 @@ def test_check_masked_offsets():
11491149

11501150
fig, ax = plt.subplots()
11511151
ax.scatter(unmasked_x, masked_y)
1152+
1153+
1154+
@check_figures_equal(extensions=["png"])
1155+
def test_masked_set_offsets(fig_ref, fig_test):
1156+
x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0])
1157+
y = np.arange(1, 6)
1158+
1159+
ax_test = fig_test.add_subplot()
1160+
scat = ax_test.scatter(x, y)
1161+
scat.set_offsets(np.ma.column_stack([x, y]))
1162+
ax_test.set_xticks([])
1163+
ax_test.set_yticks([])
1164+
1165+
ax_ref = fig_ref.add_subplot()
1166+
ax_ref.scatter([1, 2, 5], [1, 2, 5])
1167+
ax_ref.set_xticks([])
1168+
ax_ref.set_yticks([])
1169+
1170+
1171+
def test_check_offsets_dtype():
1172+
# Check that setting offsets doesn't change dtype
1173+
x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0])
1174+
y = np.arange(1, 6)
1175+
1176+
fig, ax = plt.subplots()
1177+
scat = ax.scatter(x, y)
1178+
masked_offsets = np.ma.column_stack([x, y])
1179+
scat.set_offsets(masked_offsets)
1180+
assert isinstance(scat.get_offsets(), type(masked_offsets))
1181+
1182+
unmasked_offsets = np.column_stack([x, y])
1183+
scat.set_offsets(unmasked_offsets)
1184+
assert isinstance(scat.get_offsets(), type(unmasked_offsets))

0 commit comments

Comments
 (0)