From eaf830fa0265ccac16b4996cd95a5fe0eb3e1a6f Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Sat, 17 Dec 2022 16:36:42 +0530 Subject: [PATCH 1/5] Allow using masked array as offsets --- lib/matplotlib/collections.py | 6 +++--- lib/matplotlib/tests/test_collections.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 29667ff13922..dd7a09c922f0 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -545,9 +545,9 @@ def set_offsets(self, offsets): offsets = np.asanyarray(offsets) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - self._offsets = np.column_stack( - (np.asarray(self.convert_xunits(offsets[:, 0]), float), - np.asarray(self.convert_yunits(offsets[:, 1]), float))) + self._offsets = np.ma.column_stack( + (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) self.stale = True def get_offsets(self): diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 782df21c5985..659e821029fc 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1149,3 +1149,14 @@ def test_check_masked_offsets(): fig, ax = plt.subplots() ax.scatter(unmasked_x, masked_y) + + +def test_masked_set_offsets(): + x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0]) + y = np.arange(1, 6) + + fig, ax = plt.subplots() + scat = ax.scatter(x, y) + x += 1 + scat.set_offsets(np.ma.column_stack([x, y])) + assert np.ma.is_masked(scat.get_offsets()) From bf73d48765f793ee1a7e2b8a6e44d5441d92a0d6 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Mon, 19 Dec 2022 12:20:05 +0530 Subject: [PATCH 2/5] Use np.ma functions only when input is masked --- lib/matplotlib/collections.py | 11 +++++--- lib/matplotlib/tests/test_collections.py | 34 +++++++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index dd7a09c922f0..9404b3f87137 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -545,9 +545,14 @@ def set_offsets(self, offsets): offsets = np.asanyarray(offsets) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - self._offsets = np.ma.column_stack( - (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), - np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) + if isinstance(offsets, np.ma.MaskedArray): + self._offsets = np.ma.column_stack( + (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) + else: + self._offsets = np.column_stack( + (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) self.stale = True def get_offsets(self): diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 659e821029fc..6bdd83ddc2f9 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1151,12 +1151,40 @@ def test_check_masked_offsets(): ax.scatter(unmasked_x, masked_y) -def test_masked_set_offsets(): +@check_figures_equal(extensions=["png"]) +def test_masked_set_offsets(fig_ref, fig_test): + x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0]) + y = np.arange(1, 6) + + ax_test = fig_test.add_subplot() + scat = ax_test.scatter(x, y) + x += 1 + scat.set_offsets(np.ma.column_stack([x, y])) + ax_test.set_xticks([]) + ax_test.set_yticks([]) + ax_test.set_xlim(0, 7) + ax_test.set_ylim(0, 6) + + ax_ref = fig_ref.add_subplot() + ax_ref.scatter([2, 3, 6], [1, 2, 5]) + ax_ref.set_xticks([]) + ax_ref.set_yticks([]) + ax_ref.set_xlim(0, 7) + ax_ref.set_ylim(0, 6) + + +def test_check_offsets_dtype(): + # Check that setting offsets doesn't change dtype x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0]) y = np.arange(1, 6) fig, ax = plt.subplots() scat = ax.scatter(x, y) x += 1 - scat.set_offsets(np.ma.column_stack([x, y])) - assert np.ma.is_masked(scat.get_offsets()) + masked_offsets = np.ma.column_stack([x, y]) + scat.set_offsets(masked_offsets) + assert isinstance(scat.get_offsets(), type(masked_offsets)) + + unmasked_offsets = np.column_stack([x, y]) + scat.set_offsets(unmasked_offsets) + assert isinstance(scat.get_offsets(), type(unmasked_offsets)) From a05712180552ffebda33d6a49d8d1120293d4e68 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Tue, 20 Dec 2022 13:28:34 +0530 Subject: [PATCH 3/5] Refactor function selection and test --- lib/matplotlib/collections.py | 13 +++++-------- lib/matplotlib/tests/test_collections.py | 4 +--- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 9404b3f87137..44888e469104 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -545,14 +545,11 @@ def set_offsets(self, offsets): offsets = np.asanyarray(offsets) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - if isinstance(offsets, np.ma.MaskedArray): - self._offsets = np.ma.column_stack( - (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), - np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) - else: - self._offsets = np.column_stack( - (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), - np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) + cstack = (np.ma.column_stack if isinstance(offsets, np.ma.MaskedArray) + else np.column_stack) + self._offsets = cstack( + (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) self.stale = True def get_offsets(self): diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 6bdd83ddc2f9..8bb53a6fe0fe 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1158,7 +1158,6 @@ def test_masked_set_offsets(fig_ref, fig_test): ax_test = fig_test.add_subplot() scat = ax_test.scatter(x, y) - x += 1 scat.set_offsets(np.ma.column_stack([x, y])) ax_test.set_xticks([]) ax_test.set_yticks([]) @@ -1166,7 +1165,7 @@ def test_masked_set_offsets(fig_ref, fig_test): ax_test.set_ylim(0, 6) ax_ref = fig_ref.add_subplot() - ax_ref.scatter([2, 3, 6], [1, 2, 5]) + ax_ref.scatter([1, 2, 5], [1, 2, 5]) ax_ref.set_xticks([]) ax_ref.set_yticks([]) ax_ref.set_xlim(0, 7) @@ -1180,7 +1179,6 @@ def test_check_offsets_dtype(): fig, ax = plt.subplots() scat = ax.scatter(x, y) - x += 1 masked_offsets = np.ma.column_stack([x, y]) scat.set_offsets(masked_offsets) assert isinstance(scat.get_offsets(), type(masked_offsets)) From 6605026a64574cc0688a8681ddd900298630f9a2 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Wed, 21 Dec 2022 14:31:05 +0530 Subject: [PATCH 4/5] Remove indentation. --- lib/matplotlib/collections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 44888e469104..c75ac103886b 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -549,7 +549,7 @@ def set_offsets(self, offsets): else np.column_stack) self._offsets = cstack( (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), - np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) self.stale = True def get_offsets(self): From 5267067330465f5636eefe1bd50a4be33f4b6b87 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Wed, 21 Dec 2022 14:39:53 +0530 Subject: [PATCH 5/5] Remove explicit limits --- lib/matplotlib/tests/test_collections.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 8bb53a6fe0fe..445249fae525 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1161,15 +1161,11 @@ def test_masked_set_offsets(fig_ref, fig_test): scat.set_offsets(np.ma.column_stack([x, y])) ax_test.set_xticks([]) ax_test.set_yticks([]) - ax_test.set_xlim(0, 7) - ax_test.set_ylim(0, 6) ax_ref = fig_ref.add_subplot() ax_ref.scatter([1, 2, 5], [1, 2, 5]) ax_ref.set_xticks([]) ax_ref.set_yticks([]) - ax_ref.set_xlim(0, 7) - ax_ref.set_ylim(0, 6) def test_check_offsets_dtype():