Installation
Requirements
Python 3.9+
JAX 0.4.20+
NumPy 1.22+
Optax 0.1.7+
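The minimums above can be checked against the current environment before installing. The following is a minimal sketch using only the standard library; the helper names `parse` and `check_requirements` are hypothetical and not part of JAX-HDC, and the version parsing is deliberately simple (it keeps only the leading numeric components and ignores pre-release ordering).

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from the requirements list above.
MINIMUMS = {"jax": "0.4.20", "numpy": "1.22", "optax": "0.1.7"}


def parse(v):
    """Keep the leading numeric components, e.g. '0.4.20.dev1' -> (0, 4, 20)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    # Pad to three components so '1.22' compares equal to '1.22.0'.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def check_requirements(minimums=MINIMUMS):
    """Map each package name to (installed_version_or_None, meets_minimum)."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, parse(installed) >= parse(minimum))
    return report


if __name__ == "__main__":
    print("python", sys.version.split()[0], sys.version_info >= (3, 9))
    for name, (installed, ok) in check_requirements().items():
        print(name, installed or "not installed", "OK" if ok else "needs install/upgrade")
```

A package reported as "not installed" is not necessarily a problem at this stage, since installing JAX-HDC from source is expected to pull in its dependencies.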
JAX-HDC is currently in alpha and is not yet published to PyPI.
Installing from Source
git clone https://github.com/rlogger/jax-hdc.git
cd jax-hdc
pip install -e .
For development (testing, linting, type checking):
pip install -e ".[dev]"
For running examples (matplotlib, scikit-learn):
pip install -e ".[examples]"
Using Nix
For reproducible development environments:
nix develop # Enter development shell
nix build # Build the package
nix run .#basic-operations
nix run .#classification-simple
GPU/TPU
JAX uses a GPU automatically when one is available and a GPU-enabled build of JAX is installed. For NVIDIA GPUs with CUDA 12:
pip install --upgrade "jax[cuda12]"
For TPU, install JAX with TPU support by following the JAX installation guide.
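After installing, it can be useful to confirm which backend JAX actually selected. The sketch below relies on `jax.default_backend()` and `jax.devices()`, which are part of the public JAX API; the wrapper `describe_backend` is a hypothetical helper for illustration and returns `None` when JAX is not importable.

```python
import importlib.util


def describe_backend():
    """Return a summary of the active JAX backend, or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # platform name chosen by JAX
        "device_count": len(devices),
        "devices": [str(d) for d in devices],
    }


info = describe_backend()
if info is None:
    print("JAX is not installed in this environment.")
else:
    print(f"backend={info['backend']}, devices={info['devices']}")
```

If the reported backend is the CPU on a machine that has an accelerator, the installed JAX build most likely lacks GPU or TPU support and should be reinstalled with the appropriate extra.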