jax-healpy: A JAX-based implementation of HEALPix functions for high-performance scientific computing.
This project provides a comprehensive JAX-native implementation of HEALPix (Hierarchical Equal Area isoLatitude Pixelization) functions, designed for modern scientific computing with GPU acceleration, automatic differentiation, and seamless integration with machine learning workflows.
⚠️ WARNING: BETA STAGE - This project is in active development. APIs may change, and some features are not yet complete.
- 🚀 GPU Acceleration: Leverage JAX's XLA compilation for high-performance computing on CPUs and GPUs
- 🔄 Automatic Differentiation: Full support for forward- and reverse-mode automatic differentiation
- 📊 Vectorized Operations: Efficient batch processing of HEALPix operations
- 🔧 HEALPix Compatibility: Drop-in replacement for many healpy functions
- 🌐 Spherical Harmonics: Integration with s2fft for spherical harmonic transforms
- 🎯 Clustering Tools: Advanced clustering algorithms for astronomical data analysis
First, install JAX following the official documentation for your target architecture (CPU/GPU).
Install via PyPI:
pip install jax-healpy
For spherical harmonics functionality, install with the recommended extra dependencies (the quotes stop shells such as zsh from expanding the brackets):
pip install "jax-healpy[recommended]"
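Since the spherical-harmonic transforms depend on s2fft, a script can check whether that optional package is importable before calling them. This is a minimal sketch using only the standard library; the variable name `HAS_S2FFT` is hypothetical.

```python
import importlib.util

# True when the optional s2fft package (needed by map2alm/alm2map) is installed.
HAS_S2FFT = importlib.util.find_spec("s2fft") is not None
print("s2fft available:", HAS_S2FFT)
```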
Clone the repository and install in editable mode:
git clone https://github.com/pchanial/jax-healpy.git
cd jax-healpy
pip install -e .
import jax.numpy as jnp
import jax_healpy as hp
# Choose the HEALPix resolution
nside = 64
npix = hp.nside2npix(nside)  # 12 * nside**2 pixels
# Convert pixel indices to sky coordinates (colatitude theta and longitude phi, in radians)
pixels = jnp.arange(npix)
theta, phi = hp.pix2ang(nside, pixels)
# Convert sky coordinates back to pixels (nest=False selects the RING ordering)
recovered_pixels = hp.ang2pix(nside, theta, phi, nest=False)
# Build a simple test map: a dipole aligned with the z axis
skymap = jnp.cos(theta)
# Spherical harmonics transform (requires s2fft)
alm = hp.map2alm(skymap, lmax=128)
reconstructed_map = hp.alm2map(alm, nside=nside)
Execution times were measured on the following high-performance computing system:
Test System:
- CPU: Intel(R) Xeon(R) Gold 6248 @ 2.50GHz
- GPU: NVIDIA Tesla V100-SXM2-16GB
jax-healpy demonstrates significant performance improvements, especially for GPU-accelerated workloads and batch operations.
Complete documentation is available at jax-healpy.readthedocs.io
Install development dependencies:
pip install -e ".[test]"
Execute the test suite:
pytest
This project uses pre-commit hooks for code quality:
pip install pre-commit
pre-commit install
On HPC systems, load the required modules (module names and versions vary between sites):
module load python/3.10
python -m venv venv
source venv/bin/activate
pip install jax-healpy
Ensure JAX is properly configured for your GPU architecture. See the JAX GPU installation guide for details.
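One quick way to confirm the configuration is to ask JAX which backend and devices it sees; with a working CUDA installation, the default backend should be `gpu` and the device list should include the GPUs. A minimal check:

```python
import jax

# Platform JAX uses by default, e.g. "cpu" or "gpu".
print("default backend:", jax.default_backend())
# Devices visible to JAX on this node.
print("devices:", jax.devices())
```

If only CPU devices are listed on a GPU node, the installed JAX wheel most likely lacks CUDA support.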
We welcome contributions! Please see our Contributing Guide for details on:
- Setting up the development environment
- Code style and testing requirements
- Submitting pull requests
- Reporting issues
If you use jax-healpy in your research, please cite:
@software{jax_healpy,
author = {Chanial, Pierre and Biquard, Simon and Kabalan, Wassim},
title = {jax-healpy: JAX-based HEALPix implementation},
url = {https://github.com/pchanial/jax-healpy},
year = {2024}
}
This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details.