This library contains a PyTorch implementation of the Power Spherical distribution, as presented in [1](http://arxiv.org/abs/2006.TBA).
Required dependencies:
- python>=3.6
- pytorch>=1.5: https://pytorch.org

Older versions may work but have not been tested.
Optional dependencies, needed only by the examples for plotting and numerical checks (again, older versions may work but have not been tested):
- numpy>=1.18.1: https://numpy.org
- matplotlib>=3.1.1: https://matplotlib.org
- quadpy>=0.14.11: https://pypi.org/project/quadpy
To install, run:

```bash
$ python setup.py install
```
The repository contains:
- distributions: PyTorch implementation of the Power Spherical and hyperspherical uniform distributions. Both inherit from `torch.distributions.Distribution`.
- examples: Example code for using the library within a PyTorch project.
Please have a look at the examples. We adapted our implementation to follow the structure of the PyTorch probability distributions (`torch.distributions`).
Here is a minimal example that demonstrates differentiable sampling:
```python
>>> import torch
>>> from power_spherical import PowerSpherical
>>> p = PowerSpherical(
...     loc=torch.tensor([0., 1.], requires_grad=True),
...     scale=torch.tensor(4., requires_grad=True),
... )
>>> p.rsample()  # random: your values will differ
tensor([-0.1786, 0.9839], grad_fn=<SubBackward0>)
```
and computing the KL divergence with the uniform distribution:

```python
>>> from power_spherical import HypersphericalUniform
>>> q = HypersphericalUniform(dim=2)
>>> torch.distributions.kl_divergence(p, q)
tensor(1.2486, grad_fn=<AddBackward0>)
```
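For readers who want to see what sampling and the closed-form KL compute, here is a minimal, dependency-free sketch in plain Python. It is an illustration under stated assumptions, not the library's implementation: it assumes the density p(x) = (1 + mu·x)^kappa / N(kappa, d) on the unit sphere in R^d, with a unit-norm `loc` (mu) and `scale` kappa >= 0. Sampling draws the marginal t = mu·x from a scaled Beta distribution, adds a uniform tangent direction, and maps e1 onto mu with a Householder reflection; the KL to the uniform distribution is computed as the log surface area minus the entropy. All function names (`sample`, `log_prob`, `entropy`, `kl_to_uniform`, ...) are hypothetical. Unlike the library, nothing here is differentiable or batched, and it is not meant to reproduce the outputs printed above.

```python
import math
import random

# Hypothetical, dependency-free sketch (not the power_spherical code) of
#   p(x) = (1 + mu.x)^kappa / N(kappa, d)  on S^{d-1} in R^d, d >= 2,
# with a = (d-1)/2 + kappa and b = (d-1)/2. Assumes |mu| = 1 and kappa >= 0.


def _digamma(x):
    """Digamma for x > 0: shift with psi(x) = psi(x+1) - 1/x, then asymptotic series."""
    acc = 0.0
    while x < 6.0:
        acc -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return acc + math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))


def _alpha_beta(kappa, d):
    b = (d - 1) / 2
    return b + kappa, b


def log_normalizer(kappa, d):
    """log N(kappa, d) = (a+b) log 2 + b log(pi) + lgamma(a) - lgamma(a+b)."""
    a, b = _alpha_beta(kappa, d)
    return (a + b) * math.log(2) + b * math.log(math.pi) + math.lgamma(a) - math.lgamma(a + b)


def log_prob(x, mu, kappa):
    """Log-density of a unit vector x (w.r.t. the surface measure)."""
    d = len(mu)
    t = sum(xi * mi for xi, mi in zip(x, mu))
    t = max(-1.0, min(1.0, t))  # guard against rounding outside [-1, 1]
    if t == -1.0:  # antipode of mu: (1 + t)^kappa is 0 (or 1 when kappa == 0)
        return -math.inf if kappa > 0 else -log_normalizer(kappa, d)
    return kappa * math.log1p(t) - log_normalizer(kappa, d)


def entropy(kappa, d):
    """H = log N - kappa * E[log(1 + t)], where (1 + t) / 2 ~ Beta(a, b)."""
    a, b = _alpha_beta(kappa, d)
    return log_normalizer(kappa, d) - kappa * (math.log(2) + _digamma(a) - _digamma(a + b))


def log_surface_area(d):
    """log |S^{d-1}| = log(2 pi^{d/2} / Gamma(d/2))."""
    return math.log(2) + (d / 2) * math.log(math.pi) - math.lgamma(d / 2)


def kl_to_uniform(kappa, d):
    """KL(p || Uniform(S^{d-1})) = log |S^{d-1}| - H(p); independent of mu."""
    return log_surface_area(d) - entropy(kappa, d)


def sample(mu, kappa, rng=random):
    """Draw one point (not reparameterized; no gradients)."""
    d = len(mu)
    a, b = _alpha_beta(kappa, d)
    # 1) Marginal t = mu.x, with (1 + t) / 2 ~ Beta(a, b).
    t = 2.0 * rng.betavariate(a, b) - 1.0
    # 2) Uniform direction on S^{d-2}: a normalized Gaussian vector in R^{d-1}.
    v = [rng.gauss(0.0, 1.0) for _ in range(d - 1)]
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    s = math.sqrt(max(0.0, 1.0 - t * t))
    y = [t] + [s * vi / norm_v for vi in v]
    # 3) Householder reflection H = I - 2 u u^T with u = (e1 - mu) / |e1 - mu|, so H e1 = mu.
    u = [(1.0 if i == 0 else 0.0) - mi for i, mi in enumerate(mu)]
    norm_u = math.sqrt(sum(ui * ui for ui in u))
    if norm_u < 1e-12:  # mu is already e1: no reflection needed
        return y
    u = [ui / norm_u for ui in u]
    uy = sum(ui * yi for ui, yi in zip(u, y))
    return [yi - 2.0 * uy * ui for yi, ui in zip(y, u)]


if __name__ == "__main__":
    x = sample([0.0, 1.0], 4.0)
    print(x, math.exp(log_prob(x, [0.0, 1.0], 4.0)))
    print(kl_to_uniform(4.0, 2))
```

Because the reflection is symmetric and maps e1 to mu, mu·x equals the sampled t, which is why the entropy and the KL depend only on kappa and d. With kappa = 0 the density is constant and `kl_to_uniform` returns 0, a quick sanity check for the formulas.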
Examples of 2D and 3D plots are shown in `examples`.
Please cite [1] if you use this library in your experiments.
For questions and comments, feel free to contact Nicola De Cao.
This library is released under the MIT license.
[1] De Cao, N., Aziz, W. (2020). The Power Spherical distribution. arXiv preprint arXiv:2006.TBA.
BibTeX format:

```bibtex
@article{decao2020power,
  title={The Power Spherical distribution},
  author={
    De Cao, Nicola and
    Aziz, Wilker},
  journal={arXiv preprint arXiv:2006.TBA},
  year={2020}
}
```