Utilities Module

The jax_hdc.utils module provides utility functions for configuration, random number generation, benchmarking, validation, and general helpers such as vector normalization.

Configuration

Random Number Generation

Benchmarking

jax_hdc.utils.benchmark_function(fn: Callable[..., Any], *args: Any, num_trials: int = 100, warmup: int = 10, **kwargs: Any) → dict[str, float | int]

Benchmark a JAX function, running warmup calls before timing starts and accounting for JAX's asynchronous dispatch.

Parameters:
  • fn – Function to benchmark

  • *args – Positional arguments to fn

  • num_trials – Number of timed trials (default: 100)

  • warmup – Number of warmup calls run before timing starts (default: 10)

  • **kwargs – Keyword arguments to fn

Returns:

Dictionary of timing statistics in milliseconds (mean, standard deviation, min, max, and median); for example, stats['mean_ms'] holds the mean time.
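To make the warmup and asynchronous-dispatch handling concrete, here is a minimal sketch of how such a benchmark can be written in plain JAX. It is not the jax_hdc source: the name benchmark_sketch is hypothetical, and every result key except mean_ms (which the usage example on this page reads) is an assumption, as is the num_trials entry.

```python
from __future__ import annotations

import statistics
import time
from typing import Any, Callable

import jax
import jax.numpy as jnp


def benchmark_sketch(fn: Callable[..., Any], *args: Any, num_trials: int = 100,
                     warmup: int = 10, **kwargs: Any) -> dict[str, float | int]:
    """Hypothetical benchmark helper (a sketch, not the jax_hdc implementation)."""
    # Warmup calls absorb one-time costs such as tracing and XLA compilation.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args, **kwargs))

    times_ms = []
    for _ in range(num_trials):
        start = time.perf_counter()
        # JAX dispatches asynchronously; wait for the result so the timer
        # covers the computation itself, not just the dispatch.
        jax.block_until_ready(fn(*args, **kwargs))
        times_ms.append((time.perf_counter() - start) * 1e3)

    # Key names other than "mean_ms" are assumptions made for this sketch.
    return {
        "mean_ms": statistics.mean(times_ms),
        "std_ms": statistics.pstdev(times_ms),
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
        "median_ms": statistics.median(times_ms),
        "num_trials": num_trials,
    }


# Usage: time a jitted matrix product on 256x256 arrays.
a = jnp.ones((256, 256))
stats = benchmark_sketch(jax.jit(jnp.dot), a, a, num_trials=20, warmup=3)
print(f"Mean time: {stats['mean_ms']:.3f} ms")
```

Blocking on the result inside the timed region matters because a JAX call normally returns as soon as the work is queued; without it, the loop would mostly measure dispatch overhead rather than computation.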

Validation

Helpers

jax_hdc.utils.normalize(x: Array, axis: int = -1, eps: float = 1e-08) → Array

Normalize vectors to unit length.

Parameters:
  • x – Input array

  • axis – Axis along which to normalize (default: -1)

  • eps – Small constant to avoid division by zero (default: 1e-08)

Returns:

Array of the same shape as x, scaled to unit length along axis
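A minimal sketch of unit-length normalization with an eps guard follows. The exact formula jax_hdc uses (for example x / (norm + eps) versus x / max(norm, eps)) is not documented here; this hypothetical normalize_sketch assumes the L2 norm and the first form.

```python
import jax.numpy as jnp
from jax import Array


def normalize_sketch(x: Array, axis: int = -1, eps: float = 1e-8) -> Array:
    """Hypothetical sketch: scale x to unit L2 norm along `axis`."""
    # keepdims=True so the norms broadcast back against x
    norms = jnp.linalg.norm(x, axis=axis, keepdims=True)
    # eps keeps an all-zero vector from producing 0/0 = NaN
    return x / (norms + eps)


x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
out = normalize_sketch(x)
print(out)  # first row is approximately [0.6, 0.8]; the zero row stays zero
```

Adding eps to the denominator leaves an all-zero vector at zero instead of turning it into NaNs, at the cost of norms very slightly below 1 for nonzero inputs.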

Example Usage

Memory configuration:

from jax_hdc.utils import configure_memory

# Use 90% of GPU memory with flexible allocation
configure_memory(preallocate=False, memory_fraction=0.9)
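JAX itself controls GPU memory allocation through the XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION environment variables. Whether configure_memory sets exactly these variables is an assumption, and how it applies memory_fraction when preallocation is off is not documented here. The hypothetical helper below only shows what the two settings mean and that they must be in place before JAX initializes its GPU backend.

```python
import os


def set_jax_memory_env(preallocate: bool = False, memory_fraction: float = 0.9) -> None:
    """Hypothetical helper: set JAX's GPU memory environment variables.

    Must run before JAX initializes its GPU backend; setting the variables
    before the first `import jax` in the process is the safe choice.
    """
    if not 0.0 < memory_fraction <= 1.0:
        raise ValueError("memory_fraction must be in (0, 1]")
    # "false" makes JAX allocate GPU memory on demand instead of up front
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    # JAX documents this as the fraction of total GPU memory to preallocate
    # when preallocation is enabled (JAX's default is 0.75)
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(memory_fraction)


set_jax_memory_env(preallocate=False, memory_fraction=0.9)
# import jax  # import JAX only after the environment is configured
```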

Device management:

import jax.numpy as jnp
from jax_hdc.utils import get_device, to_device

data = jnp.ones((4, 10000))  # example array to move
device = get_device('gpu', 0)  # first GPU
data_on_gpu = to_device(data, device)
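In plain JAX, the same step uses jax.devices and jax.device_put. The helper pick_device below is hypothetical (it is not part of jax_hdc), and falling back to the default backend when no GPU is present is a design choice of this sketch, not documented behavior of get_device.

```python
import jax
import jax.numpy as jnp


def pick_device(platform: str = "gpu", index: int = 0):
    """Hypothetical helper: return the index-th device of a platform,
    falling back to the default backend if that platform is unavailable."""
    try:
        devices = jax.devices(platform)
    except RuntimeError:
        devices = jax.devices()  # e.g. on a CPU-only machine
    return devices[index]


data = jnp.arange(8.0)
device = pick_device("gpu", 0)
data_on_device = jax.device_put(data, device)  # copy the array to the chosen device
print(device.platform, data_on_device)
```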

Benchmarking:

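import jax
from jax_hdc.utils import benchmark_function
from jax_hdc.functional import bind_map

# Example inputs: two random 10,000-dimensional vectors
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x, y = jax.random.normal(kx, (10000,)), jax.random.normal(ky, (10000,))

stats = benchmark_function(bind_map, x, y, num_trials=100)
print(f"Mean time: {stats['mean_ms']:.3f} ms")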