Utilities Module
The jax_hdc.utils module provides utility functions for memory and device configuration, random number generation, benchmarking, and validation.
Configuration
Random Number Generation
Benchmarking
- jax_hdc.utils.benchmark_function(fn: Callable[..., Any], *args: Any, num_trials: int = 100, warmup: int = 10, **kwargs: Any) -> dict[str, float | int]
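Benchmark a JAX function. Untimed warmup calls run first so that one-off costs such as JIT compilation are excluded, and each timed call waits for its result so that JAX's asynchronous dispatch does not make calls look faster than they are.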
- Parameters:
fn – Function to benchmark
*args – Positional arguments passed to fn
num_trials – Number of timed calls (default: 100)
warmup – Number of untimed warmup calls made before timing starts (default: 10)
**kwargs – Keyword arguments passed to fn
- Returns:
Dictionary of timing statistics over the timed calls, in milliseconds: mean, standard deviation, minimum, maximum, and median (for example, stats['mean_ms'])
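To make the warmup and asynchronous-dispatch handling concrete, here is a minimal sketch of how a benchmark with this signature could be written. It is not the jax_hdc implementation: the name benchmark_sketch is hypothetical, the use of the population standard deviation is an assumption, and apart from mean_ms (which the benchmarking example at the end of this page uses) the dictionary keys are guesses.

```python
from __future__ import annotations

import statistics
import time
from typing import Any, Callable

try:
    import jax

    # Waits for every array in a (possibly nested) result to finish computing.
    _block = jax.block_until_ready
except ImportError:  # lets the sketch run on plain Python callables too
    def _block(result: Any) -> Any:
        return result


def benchmark_sketch(fn: Callable[..., Any], *args: Any, num_trials: int = 100,
                     warmup: int = 10, **kwargs: Any) -> dict[str, float | int]:
    """Hypothetical stand-in for benchmark_function; not the jax_hdc code."""
    # Untimed warmup calls absorb one-off costs such as JIT compilation.
    for _ in range(warmup):
        _block(fn(*args, **kwargs))

    times_ms = []
    for _ in range(num_trials):
        start = time.perf_counter()
        # JAX dispatches work asynchronously, so block before stopping the clock.
        _block(fn(*args, **kwargs))
        times_ms.append((time.perf_counter() - start) * 1e3)

    return {
        "mean_ms": statistics.mean(times_ms),
        "std_ms": statistics.pstdev(times_ms),  # population std (assumed)
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
        "median_ms": statistics.median(times_ms),
        "num_trials": num_trials,
    }


stats = benchmark_sketch(sorted, list(range(1000)), num_trials=20, warmup=2)
print(f"Mean time: {stats['mean_ms']:.3f} ms")
```

Without the blocking step, a timed call on a JAX array would return as soon as the work is queued, so the measured time would mostly reflect dispatch overhead rather than computation.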
Validation
Helpers
Example Usage
Memory configuration:
from jax_hdc.utils import configure_memory
# Allocate GPU memory on demand instead of preallocating, using at most 90% of it
configure_memory(preallocate=False, memory_fraction=0.9)
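JAX reads its GPU memory settings from the environment variables XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION when its backend first initializes. The sketch below shows one plausible way a helper with configure_memory's parameters could map onto those variables. This is an assumption, not the jax_hdc source, and the name configure_memory_sketch is hypothetical. Settings of this kind generally only take effect if applied before the first JAX computation runs.

```python
import os


def configure_memory_sketch(preallocate: bool = True, memory_fraction: float = 0.75) -> None:
    """Hypothetical helper: set JAX's GPU memory environment variables.

    Defaults mirror JAX's own documented defaults, not configure_memory's.
    Must run before JAX initializes its GPU backend, because the variables
    are only read at that point.
    """
    if not 0.0 < memory_fraction <= 1.0:
        raise ValueError("memory_fraction must be in (0, 1]")
    # "false" makes JAX allocate GPU memory on demand instead of up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    # Fraction of total GPU memory JAX may claim (preallocated when preallocation is on).
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{memory_fraction:.2f}"


configure_memory_sketch(preallocate=False, memory_fraction=0.9)
# Run JAX computations only after this point so the settings take effect.
```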
Device management:
import jax.numpy as jnp
from jax_hdc.utils import get_device, to_device
device = get_device('gpu', 0)
data = jnp.ones((1000, 10000))  # any JAX array; the shape is illustrative
data_on_gpu = to_device(data, device)
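In plain JAX, selecting a device and moving an array onto it are done with jax.devices and jax.device_put. The sketch below assumes get_device and to_device wrap these calls roughly like this, which has not been verified against the jax_hdc source. The helper names are hypothetical, and the sketch targets the CPU backend so that it also runs on machines without a GPU.

```python
import jax
import jax.numpy as jnp


def get_device_sketch(backend: str = "cpu", index: int = 0) -> jax.Device:
    """Hypothetical: return device `index` of a backend such as "cpu" or "gpu"."""
    # jax.devices raises RuntimeError if the requested backend is unavailable.
    return jax.devices(backend)[index]


def to_device_sketch(x: jax.Array, device: jax.Device) -> jax.Array:
    """Hypothetical: place an array on a specific device."""
    return jax.device_put(x, device)


device = get_device_sketch("cpu", 0)
data = jnp.ones((4, 8))
data_on_device = to_device_sketch(data, device)
print(data_on_device.devices())  # the set of devices holding the array
```

On a GPU machine, passing "gpu" instead of "cpu" selects a GPU in the same way.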
Benchmarking:
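import jax
import jax.numpy as jnp
from jax_hdc.utils import benchmark_function
from jax_hdc.functional import bind_map
# Two random bipolar (+1/-1) hypervectors as example inputs; the dimension is illustrative
x, y = jax.random.rademacher(jax.random.PRNGKey(0), (2, 10000), dtype=jnp.float32)
stats = benchmark_function(bind_map, x, y, num_trials=100)
print(f"Mean time: {stats['mean_ms']:.3f} ms")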