From f34eb2605afac335b75561a8c1e56642ef440958 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 17 Jul 2025 14:47:47 +0200 Subject: [PATCH 1/2] Do not fail with zero-sized arrays in dataset_to_point_list Numpy does not support reshape(-1, ...) when size is zero --- pymc/backends/arviz.py | 4 +++- tests/backends/test_arviz.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 71f08da826..ba89282155 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -618,11 +618,13 @@ def dataset_to_point_list( for vn in var_names: if not isinstance(vn, str): raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") + num_sample_dims = len(sample_dims) stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims} transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()} + stacked_size = np.prod(transposed_dict[var_names[0]].shape[:num_sample_dims], dtype=int) stacked_dict = { - vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) + vn: da.values.reshape((stacked_size, *da.shape[num_sample_dims:])) for vn, da in transposed_dict.items() } points = [ diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 3c06288b35..f09a3a9539 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -837,3 +837,14 @@ def test_dataset_to_point_list_str_key(self): ds[3] = xarray.DataArray([1, 2, 3]) with pytest.raises(ValueError, match="must be str"): dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + + def test_zero_size(self): + ds = xarray.Dataset() + ds["x"] = xarray.DataArray( + np.zeros((4, 10, 0, 5)), dims=("chain", "draw", "dim_0", "dim_5") + ) + pl, _ = dataset_to_point_list(ds, sample_dims=("chain", "draw")) + assert len(pl) == 40 + assert tuple(pl[0]) == ("x",) + assert pl[0]["x"].shape == (0, 5) + assert pl[0]["x"].dtype == np.float64 From 9629256bfa419df6784e0681d5d2e80faa0cdffc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 20 Jul 2025 19:29:04 +0800 Subject: [PATCH 2/2] Add dprint method to PointFunc and pickle regression test --- pymc/pytensorf.py | 6 ++---- tests/test_pytensorf.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f1d69c9282..572241e49a 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -567,10 +567,8 @@ def __init__(self, f): def __call__(self, state): return self.f(**state) - def __getattr__(self, item): - """Allow access to the original function attributes.""" - # This is only reached if `__getattribute__` fails. - return getattr(self.f, item) + def dprint(self, **kwrags): + return self.f.dprint(**kwrags) class CallableTensor: diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index f7efd5f6d4..de7f7f1cbe 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -746,7 +746,7 @@ def test_hessian_sign_change_warning(func): assert equal_computations([res_neg], [-res]) -def test_point_func(): +def test_point_func(capsys): x, y = pt.vectors("x", "y") outs = x * 2 + y**2 f = compile([x, y], outs) @@ -758,3 +758,30 @@ def test_point_func(): dprint_res = point_f.dprint(file="str") expected_dprint_res = point_f.f.dprint(file="str") assert dprint_res == expected_dprint_res + + point_f.dprint(print_shape=True) + captured = capsys.readouterr() + + # The shape=(?,) arises because the inputs are dvector. This checks that the dprint works, and the print_shape + # kwargs was correctly forwarded + assert "shape=(?,)" in captured.out + + +def test_pickle_point_func(): + """ + Regression test for https://github.com/pymc-devs/pymc/issues/7857 + """ + import cloudpickle + + x, y = pt.vectors("x", "y") + outs = x * 2 + y**2 + f = compile([x, y], outs) + + point_f = PointFunc(f) + point_f_pickled = cloudpickle.dumps(point_f) + point_f_unpickled = cloudpickle.loads(point_f_pickled) + + # Check that the function survived the round-trip + np.testing.assert_allclose( + point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]}) + )