Source code for jax_hdc.functional

"""Core functional operations for Hyperdimensional Computing.

This module provides the fundamental operations for manipulating hypervectors:
- Binding: Combines two hypervectors into a dissimilar result
- Bundling: Aggregates multiple hypervectors into a similar result
- Permutation: Reorders elements to encode sequences
- Similarity: Measures relatedness between hypervectors
"""

import functools
from typing import Callable, Union

import jax
import jax.numpy as jnp

from jax_hdc.constants import EPS


[docs] @jax.jit def bind_bsc(x: jax.Array, y: jax.Array) -> jax.Array: """Bind two hypervectors using XOR for Binary Spatter Codes. Binding creates a new hypervector that is dissimilar to both inputs. Args: x: Binary hypervector of shape (..., d) y: Binary hypervector of shape (..., d) Returns: Bound hypervector of shape (..., d), dissimilar to both x and y """ return jnp.logical_xor(x, y)
[docs] def bundle_bsc(vectors: jax.Array, axis: int = 0) -> jax.Array: """Bundle hypervectors using majority rule for Binary Spatter Codes. Bundling creates a new hypervector similar to all inputs by taking the majority vote at each dimension. Args: vectors: Binary hypervectors of shape with axis containing vectors to bundle axis: Axis along which to bundle (default: 0) Returns: Bundled hypervector, similar to all inputs """ counts = jnp.sum(vectors, axis=axis) shape_size = vectors.shape[axis] threshold = shape_size / 2.0 return counts > threshold
[docs] @jax.jit def inverse_bsc(x: jax.Array) -> jax.Array: """Compute inverse for BSC (identity since XOR is self-inverse).""" return x # XOR is self-inverse
[docs] @jax.jit def hamming_similarity(x: jax.Array, y: jax.Array) -> jax.Array: """Compute normalized Hamming similarity between binary hypervectors. Returns the fraction of matching bits between two binary vectors. Random vectors have similarity ≈ 0.5. Args: x: Binary hypervector of shape (..., d) y: Binary hypervector of shape (..., d) Returns: Similarity score in [0, 1], where 1 is identical and 0.5 is random """ matches = jnp.logical_not(jnp.logical_xor(x, y)) return jnp.mean(matches.astype(jnp.float32), axis=-1)
[docs] @jax.jit def bind_map(x: jax.Array, y: jax.Array) -> jax.Array: """Bind two hypervectors using element-wise multiplication for MAP. For real-valued vectors (MAP model), binding is element-wise multiplication. The result is dissimilar to both inputs. Args: x: Real-valued hypervector of shape (..., d) y: Real-valued hypervector of shape (..., d) Returns: Bound hypervector of shape (..., d) """ return x * y
[docs] def bundle_map(vectors: jax.Array, axis: int = 0) -> jax.Array: """Bundle hypervectors using normalized sum for MAP. For real-valued vectors, bundling is the normalized sum. The result is similar to all inputs (high cosine similarity). Args: vectors: Real-valued hypervectors with axis containing vectors to bundle axis: Axis along which to bundle (default: 0) Returns: Bundled and normalized hypervector """ summed = jnp.sum(vectors, axis=axis) norm = jnp.linalg.norm(summed, axis=-1, keepdims=True) return summed / (norm + EPS)
[docs] @jax.jit def inverse_map(x: jax.Array, eps: float = EPS) -> jax.Array: """Compute inverse for MAP using element-wise reciprocal. For MAP binding (element-wise multiplication), the inverse is element-wise reciprocal: bind(bind(x, y), inverse(y)) = x. Near-zero elements return 0 (no inverse; bind with 0 destroys information). Args: x: Real-valued hypervector of shape (..., d) eps: Small constant for numerical stability (default: EPS) Returns: Inverse hypervector """ safe_inv = jnp.where(jnp.abs(x) > eps, 1.0 / x, 0.0) return safe_inv
[docs] @jax.jit def cosine_similarity(x: jax.Array, y: jax.Array) -> jax.Array: """Compute cosine similarity between real-valued hypervectors. Returns the cosine of the angle between two vectors. Random unit vectors have similarity ≈ 0. Args: x: Real-valued hypervector of shape (..., d) y: Real-valued hypervector of shape (..., d) Returns: Similarity score in [-1, 1], where 1 is identical, -1 is opposite, and 0 is orthogonal """ x_norm = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + EPS) y_norm = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + EPS) return jnp.clip(jnp.sum(x_norm * y_norm, axis=-1), -1.0, 1.0)
[docs] @jax.jit def permute(x: jax.Array, shifts: int = 1) -> jax.Array: """Cyclically permute a hypervector to encode sequence information. Permutation reorders elements to represent positional or sequential information. Cyclic shifts preserve the distribution of values. Args: x: Hypervector of shape (..., d) shifts: Number of positions to shift (default: 1) Returns: Permuted hypervector of shape (..., d) """ return jnp.roll(x, shifts, axis=-1)
[docs] @functools.partial(jax.jit, static_argnames=("return_similarity",)) def cleanup( query: jax.Array, memory: jax.Array, similarity_fn: Callable[[jax.Array, jax.Array], jax.Array] = cosine_similarity, return_similarity: bool = False, ) -> Union[jax.Array, tuple[jax.Array, jax.Array]]: """Find the most similar vector in memory to the query. Cleanup (or resonator) is used to retrieve the closest known hypervector from memory, useful for error correction and symbol retrieval. Args: query: Query hypervector of shape (..., d) memory: Memory hypervectors of shape (n, d) similarity_fn: Function to compute similarity (default: cosine_similarity) return_similarity: Whether to return similarity scores (default: False) Returns: Most similar vector from memory, or (vector, similarity) if return_similarity=True """ similarities = jax.vmap(lambda m: similarity_fn(query, m))(memory) best_idx = jnp.argmax(similarities) best_vector = memory[best_idx] if return_similarity: return best_vector, similarities[best_idx] return best_vector
# Batch versions for common operations batch_bind_bsc = jax.vmap(bind_bsc, in_axes=(0, 0)) batch_bind_map = jax.vmap(bind_map, in_axes=(0, 0)) batch_hamming_similarity = jax.vmap(hamming_similarity, in_axes=(0, None)) batch_cosine_similarity = jax.vmap(cosine_similarity, in_axes=(0, None))
[docs] @jax.jit def bind_hrr(x: jax.Array, y: jax.Array) -> jax.Array: """Bind two hypervectors using circular convolution for HRR. Circular convolution in the spatial domain is equivalent to element-wise multiplication in the Fourier domain, making it efficient to compute. Args: x: Real-valued hypervector of shape (..., d) y: Real-valued hypervector of shape (..., d) Returns: Bound hypervector via circular convolution """ x_fft = jnp.fft.fft(x, axis=-1) y_fft = jnp.fft.fft(y, axis=-1) result_fft = x_fft * y_fft return jnp.fft.ifft(result_fft, axis=-1).real
[docs] @jax.jit def inverse_hrr(x: jax.Array) -> jax.Array: """Compute inverse for HRR (reverse the circular convolution). For HRR, the inverse reverses the order of elements (except the first). Args: x: Real-valued hypervector of shape (..., d) Returns: Inverse hypervector """ return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
# HRR bundle reuses MAP bundle bundle_hrr = bundle_map @jax.jit def bind_cgr(x: jax.Array, y: jax.Array, q: int) -> jax.Array: """Bind using modular addition for Cyclic Group Representation. Args: x: Integer hypervector with values in {0, ..., q-1}, shape (..., d) y: Integer hypervector with values in {0, ..., q-1}, shape (..., d) q: Size of the cyclic group Returns: Bound hypervector: (x + y) mod q """ return (x + y) % q def bundle_cgr(vectors: jax.Array, q: int, axis: int = 0) -> jax.Array: """Bundle using component-wise mode for CGR. Selects the most frequent value at each dimension. Args: vectors: Integer hypervectors with values in {0, ..., q-1} q: Size of the cyclic group axis: Axis along which to bundle (default: 0) Returns: Bundled hypervector with mode value at each dimension """ one_hot = jax.nn.one_hot(vectors, q) counts = jnp.sum(one_hot, axis=axis) return jnp.argmax(counts, axis=-1).astype(jnp.int32) @jax.jit def inverse_cgr(x: jax.Array, q: int) -> jax.Array: """Inverse via modular negation for CGR. Args: x: Integer hypervector with values in {0, ..., q-1} q: Size of the cyclic group Returns: Inverse: (q - x) mod q """ return (q - x) % q @jax.jit def matching_similarity(x: jax.Array, y: jax.Array) -> jax.Array: """Fraction of matching elements between integer hypervectors. Random vectors with q levels have expected similarity 1/q. Args: x: Integer hypervector of shape (..., d) y: Integer hypervector of shape (..., d) Returns: Similarity in [0, 1] """ return jnp.mean((x == y).astype(jnp.float32), axis=-1) # MCR bind reuses CGR bind bind_mcr = bind_cgr # MCR inverse reuses CGR inverse inverse_mcr = inverse_cgr def bundle_mcr(vectors: jax.Array, q: int, axis: int = 0) -> jax.Array: """Bundle using phasor sum and snap-to-grid for MCR. Converts indices to complex phasors (q-th roots of unity), sums them, and snaps back to the nearest discrete phase level. Args: vectors: Integer hypervectors with values in {0, ..., q-1} q: Number of phase levels axis: Axis along which to bundle (default: 0) Returns: Bundled hypervector with values snapped to nearest phase index """ phases = 2 * jnp.pi * vectors.astype(jnp.float32) / q phasors = jnp.exp(1j * phases) summed = jnp.sum(phasors, axis=axis) result_angles = jnp.angle(summed) % (2 * jnp.pi) result_indices = jnp.round(result_angles * q / (2 * jnp.pi)) % q return result_indices.astype(jnp.int32) @jax.jit def phasor_similarity(x: jax.Array, y: jax.Array, q: int) -> jax.Array: """Similarity via phasor inner product for MCR. Converts to phasors and computes the real part of the normalized inner product. Random vectors have expected similarity ~0. Args: x: Integer hypervector with values in {0, ..., q-1} y: Integer hypervector with values in {0, ..., q-1} q: Number of phase levels Returns: Similarity in [-1, 1] """ phases_x = 2 * jnp.pi * x.astype(jnp.float32) / q phases_y = 2 * jnp.pi * y.astype(jnp.float32) / q phasors_x = jnp.exp(1j * phases_x) phasors_y = jnp.exp(1j * phases_y) return jnp.real(jnp.mean(phasors_x * jnp.conj(phasors_y), axis=-1)) @jax.jit def bind_vtb(x: jax.Array, y: jax.Array) -> jax.Array: """Bind using matrix multiplication for VTB. Reshapes d-dimensional vectors into sqrt(d) x sqrt(d) matrices and multiplies them. Requires d to be a perfect square. Args: x: Real-valued hypervector of shape (..., d) y: Real-valued hypervector of shape (..., d) Returns: Bound hypervector of shape (..., d) """ d = x.shape[-1] n = round(d**0.5) X = x.reshape(*x.shape[:-1], n, n) Y = y.reshape(*y.shape[:-1], n, n) return (X @ Y).reshape(*x.shape[:-1], d) @jax.jit def inverse_vtb(x: jax.Array) -> jax.Array: """Inverse via matrix pseudoinverse for VTB.""" d = x.shape[-1] n = round(d**0.5) X = x.reshape(*x.shape[:-1], n, n) return jnp.linalg.pinv(X).reshape(*x.shape[:-1], d) # VTB bundle reuses MAP bundle bundle_vtb = bundle_map __all__ = [ # BSC operations "bind_bsc", "bundle_bsc", "inverse_bsc", "hamming_similarity", # MAP operations "bind_map", "bundle_map", "inverse_map", "cosine_similarity", # HRR operations "bind_hrr", "bundle_hrr", "inverse_hrr", # CGR operations "bind_cgr", "bundle_cgr", "inverse_cgr", "matching_similarity", # MCR operations "bind_mcr", "bundle_mcr", "inverse_mcr", "phasor_similarity", # VTB operations "bind_vtb", "bundle_vtb", "inverse_vtb", # Universal operations "permute", "cleanup", # Batch operations "batch_bind_bsc", "batch_bind_map", "batch_hamming_similarity", "batch_cosine_similarity", ]