# Source code for jax_hdc.vsa

"""Vector Symbolic Architecture (VSA) model implementations.

This module provides different VSA models, each with their own binding,
bundling, and similarity operations. All models follow a consistent API.
"""

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

from jax_hdc import functional as F
from jax_hdc._compat import register_dataclass
from jax_hdc.constants import EPS


@register_dataclass
@dataclass
class VSAModel:
    """Common interface implemented by every VSA model in this module.

    Concrete models override ``bind``, ``bundle``, ``inverse``, ``similarity``
    and ``random``; calling any of them on this base class raises
    ``NotImplementedError``.
    """

    # Short identifier of the model (the subclasses use e.g. "bsc", "map").
    # Both fields carry ``static=True`` metadata, presumably so that
    # ``register_dataclass`` treats them as static pytree metadata — confirm
    # against jax_hdc._compat.
    name: str = field(metadata={"static": True})
    # Number of components in each hypervector.
    dimensions: int = field(metadata={"static": True})

    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine two hypervectors into one bound hypervector."""
        raise NotImplementedError()

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose a stack of hypervectors along ``axis``."""
        raise NotImplementedError()

    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the hypervector that undoes binding with ``x``."""
        raise NotImplementedError()

    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score how similar two hypervectors are."""
        raise NotImplementedError()

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random hypervectors of the given ``shape`` from ``key``."""
        raise NotImplementedError()
@register_dataclass
@dataclass
class BSC(VSAModel):
    """Binary Spatter Codes (BSC).

    Hypervectors are binary. Binding is XOR, bundling is a majority vote and
    similarity is Hamming-based.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "BSC":
        """Build a BSC model.

        Args:
            dimensions: Length of each hypervector (default: 10000)

        Returns:
            A BSC model whose ``name`` is "bsc"
        """
        return BSC(dimensions=dimensions, name="bsc")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` with element-wise XOR."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` by majority rule."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the XOR inverse of ``x`` (XOR is self-inverse)."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` with Hamming similarity."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample binary hypervectors.

        Args:
            key: JAX random key
            shape: Shape of the returned array

        Returns:
            Binary hypervectors where each entry is set with probability 0.5
        """
        return jax.random.bernoulli(key, p=0.5, shape=shape)
@register_dataclass
@dataclass
class MAP(VSAModel):
    """Multiply-Add-Permute (MAP).

    Real-valued hypervectors. Binding is element-wise multiplication,
    bundling is a normalized sum and similarity is cosine similarity.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "MAP":
        """Build a MAP model.

        Args:
            dimensions: Length of each hypervector (default: 10000)

        Returns:
            A MAP model whose ``name`` is "map"
        """
        return MAP(dimensions=dimensions, name="map")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` by element-wise multiplication."""
        return F.bind_map(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` as a normalized sum."""
        return F.bundle_map(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the element-wise reciprocal of ``x``."""
        return F.inverse_map(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` with cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample real-valued hypervectors of (approximately) unit length.

        Args:
            key: JAX random key
            shape: Shape of the returned array

        Returns:
            Standard-normal samples divided by their L2 norm over the last axis
        """
        samples = jax.random.normal(key, shape=shape)
        lengths = jnp.linalg.norm(samples, axis=-1, keepdims=True)
        # EPS is added to the norm before dividing, presumably to avoid a
        # division by zero.
        return samples / (lengths + EPS)
@register_dataclass
@dataclass
class HRR(VSAModel):
    """Holographic Reduced Representations (HRR).

    Real-valued hypervectors. Binding is circular convolution, bundling is a
    normalized sum and similarity is cosine similarity.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "HRR":
        """Build an HRR model.

        Args:
            dimensions: Length of each hypervector (default: 10000)

        Returns:
            An HRR model whose ``name`` is "hrr"
        """
        return HRR(dimensions=dimensions, name="hrr")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` by circular convolution."""
        return F.bind_hrr(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` as a normalized sum."""
        return F.bundle_hrr(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the approximate inverse of ``x`` via element reversal."""
        return F.inverse_hrr(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` with cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample real-valued hypervectors of (approximately) unit length.

        Args:
            key: JAX random key
            shape: Shape of the returned array

        Returns:
            Standard-normal samples divided by their L2 norm over the last axis
        """
        samples = jax.random.normal(key, shape=shape)
        lengths = jnp.linalg.norm(samples, axis=-1, keepdims=True)
        # EPS is added to the norm before dividing, presumably to avoid a
        # division by zero.
        return samples / (lengths + EPS)
@register_dataclass
@dataclass
class FHRR(VSAModel):
    """Fourier Holographic Reduced Representations (FHRR).

    Complex-valued hypervectors. Binding is element-wise multiplication,
    bundling is a normalized sum and similarity is the real part of the
    normalized complex inner product.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "FHRR":
        """Build an FHRR model.

        Args:
            dimensions: Length of each hypervector (default: 10000)

        Returns:
            An FHRR model whose ``name`` is "fhrr"
        """
        return FHRR(dimensions=dimensions, name="fhrr")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` by element-wise (complex) multiplication."""
        return jnp.multiply(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Sum ``vectors`` along ``axis``, then scale to unit L2 norm over the last axis."""
        total = jnp.sum(vectors, axis=axis)
        magnitude = jnp.linalg.norm(total, axis=-1, keepdims=True)
        return total / (magnitude + EPS)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the complex conjugate of ``x``."""
        return jnp.conjugate(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Cosine similarity of complex hypervectors, clipped to [-1, 1].

        Both inputs are scaled to unit L2 norm over the last axis; the score
        is the real part of ``sum(x_unit * conj(y_unit))``.
        """
        x_unit = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + EPS)
        y_unit = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + EPS)
        inner = jnp.sum(x_unit * jnp.conjugate(y_unit), axis=-1)
        # Clipping absorbs floating-point overshoot just outside [-1, 1].
        return jnp.clip(jnp.real(inner), -1.0, 1.0)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample complex hypervectors whose entries lie on the unit circle.

        Args:
            key: JAX random key
            shape: Shape of the returned array

        Returns:
            ``exp(1j * phase)`` with phases drawn uniformly from [0, 2*pi)
        """
        angles = jax.random.uniform(key, shape=shape, minval=0, maxval=2 * jnp.pi)
        return jnp.exp(1j * angles)
@register_dataclass
@dataclass
class BSBC(VSAModel):
    """Binary Sparse Block Codes (B-SBC).

    Block-sparse binary vectors with k_active ones per block, XOR binding,
    majority bundling.
    """

    # Number of components per block; ``dimensions`` must be a multiple of it.
    block_size: int = field(metadata=dict(static=True), default=100)
    # Number of True entries generated in every block by ``random``.
    k_active: int = field(metadata=dict(static=True), default=5)

    @staticmethod
    def create(
        dimensions: int = 10000,
        block_size: int = 100,
        k_active: int = 5,
    ) -> "BSBC":
        """Create a B-SBC model.

        Args:
            dimensions: Total dimensionality (must be divisible by block_size)
            block_size: Size of each block (must be >= 1)
            k_active: Number of ones per block (sparsity), in [1, block_size]

        Returns:
            Initialized BSBC model

        Raises:
            ValueError: If block_size < 1, dimensions is not divisible by
                block_size, or k_active is outside [1, block_size]
        """
        # Checked first so the modulo below cannot raise ZeroDivisionError.
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        if dimensions % block_size != 0:
            raise ValueError(
                f"dimensions ({dimensions}) must be divisible by block_size ({block_size})"
            )
        if k_active > block_size or k_active < 1:
            raise ValueError(f"k_active must be in [1, block_size], got {k_active}")
        return BSBC(
            name="bsbc",
            dimensions=dimensions,
            block_size=block_size,
            k_active=k_active,
        )

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using XOR (same as BSC)."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using majority rule."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse is identity for XOR."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute Hamming similarity."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random block-sparse binary hypervectors.

        Each hypervector is split into ``dimensions // block_size`` contiguous
        blocks; in every block exactly ``k_active`` positions, chosen by a
        random permutation, are set to True.

        Args:
            key: JAX random key. Both legacy ``uint32[2]`` keys
                (``jax.random.PRNGKey``) and typed keys (``jax.random.key``)
                are accepted.
            shape: Output shape. The last axis must equal ``dimensions``;
                any leading axes are batch axes.

        Returns:
            Boolean array of exactly ``shape``

        Raises:
            ValueError: If ``shape`` is empty or ``shape[-1] != dimensions``
        """
        import math

        shape = tuple(shape)
        if not shape or shape[-1] != self.dimensions:
            raise ValueError(
                f"BSBC.random requires shape[-1] == dimensions ({self.dimensions}), "
                f"got shape {shape}"
            )
        num_blocks = self.dimensions // self.block_size
        batch_size = math.prod(shape[:-1])  # 1 when shape == (dimensions,)
        if batch_size == 0:
            return jnp.zeros(shape, dtype=jnp.bool_)

        def gen_block(key_b: jax.Array) -> jax.Array:
            # Mark the first k_active indices of a random permutation.
            perm = jax.random.permutation(key_b, self.block_size)
            block = jnp.zeros(self.block_size, dtype=jnp.bool_)
            return block.at[perm[: self.k_active]].set(True)

        # One key per block. The first split key is dropped, matching the
        # previous implementation so existing seeds reproduce the same vectors.
        keys = jax.random.split(key, batch_size * num_blocks + 1)[1:]
        # Keep any trailing key axes generic: legacy keys carry a trailing
        # axis of size 2, typed key arrays carry none.
        block_keys = keys.reshape((batch_size, num_blocks) + keys.shape[1:])

        def make_hv(hv_keys: jax.Array) -> jax.Array:
            blocks = jax.vmap(gen_block)(hv_keys)
            return jnp.reshape(blocks, (self.dimensions,))

        hvs = jax.vmap(make_hv)(block_keys)
        return jnp.reshape(hvs, shape)
@register_dataclass
@dataclass
class CGR(VSAModel):
    """Cyclic Group Representation (CGR).

    Integer hypervectors in Z_q. Binding is modular addition, bundling is the
    component-wise mode and similarity is the fraction of matching entries.
    """

    # Group modulus; ``random`` draws entries from {0, ..., q-1}.
    q: int = field(metadata={"static": True}, default=8)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 8) -> "CGR":
        """Build a CGR model.

        Args:
            dimensions: Length of each hypervector (default: 10000)
            q: Modulus of the cyclic group (default: 8)

        Returns:
            A CGR model whose ``name`` is "cgr"

        Raises:
            ValueError: If ``q < 2``
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        return CGR(dimensions=dimensions, q=q, name="cgr")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` by addition modulo ``q``."""
        return F.bind_cgr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` by taking the per-component mode."""
        return F.bundle_cgr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the additive inverse of ``x`` modulo ``q``."""
        return F.inverse_cgr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` as the fraction of equal components."""
        return F.matching_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample integer hypervectors uniformly from {0, ..., q-1}."""
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
@register_dataclass
@dataclass
class MCR(VSAModel):
    """Modular Composite Representation (MCR).

    Integer phase hypervectors in {0, ..., q-1}. Binding is modular addition
    (phase addition), bundling is a phasor sum snapped back to the grid and
    similarity is phasor-based.
    """

    # Number of discrete phases; ``random`` draws entries from {0, ..., q-1}.
    q: int = field(metadata={"static": True}, default=64)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 64) -> "MCR":
        """Build an MCR model.

        Args:
            dimensions: Length of each hypervector (default: 10000)
            q: Number of discrete phase values (default: 64)

        Returns:
            An MCR model whose ``name`` is "mcr"

        Raises:
            ValueError: If ``q < 2``
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        return MCR(dimensions=dimensions, q=q, name="mcr")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` by adding phases modulo ``q``."""
        return F.bind_mcr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` via a phasor sum snapped to the grid."""
        return F.bundle_mcr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the phase conjugate of ``x`` (negation modulo ``q``)."""
        return F.inverse_mcr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` with phasor similarity."""
        return F.phasor_similarity(x, y, self.q)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample integer phase hypervectors uniformly from {0, ..., q-1}."""
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
@register_dataclass
@dataclass
class VTB(VSAModel):
    """Vector-Derived Transformation Binding (VTB).

    Real-valued hypervectors. Binding is matrix multiplication, bundling is a
    normalized sum and similarity is cosine similarity.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "VTB":
        """Build a VTB model.

        Args:
            dimensions: Length of each hypervector; must be a perfect square
                (default: 10000). Presumably this is because binding reshapes
                vectors into square matrices — confirm against F.bind_vtb.

        Returns:
            A VTB model whose ``name`` is "vtb"

        Raises:
            ValueError: If ``dimensions`` is not a perfect square
        """
        side = round(dimensions**0.5)
        if side * side != dimensions:
            raise ValueError(f"VTB requires dimensions to be a perfect square, got {dimensions}")
        return VTB(dimensions=dimensions, name="vtb")

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Combine ``x`` and ``y`` via matrix multiplication."""
        return F.bind_vtb(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Superpose ``vectors`` along ``axis`` as a normalized sum."""
        return F.bundle_vtb(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the inverse of ``x`` via a matrix pseudoinverse."""
        return F.inverse_vtb(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Score ``x`` against ``y`` with cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample real-valued hypervectors normalized over the last axis."""
        samples = jax.random.normal(key, shape=shape)
        lengths = jnp.linalg.norm(samples, axis=-1, keepdims=True)
        # EPS is added to the norm before dividing, presumably to avoid a
        # division by zero.
        return samples / (lengths + EPS)
def create_vsa_model(
    model_type: str = "map", dimensions: int = 10000, **kwargs: int
) -> VSAModel:
    """Factory function to create VSA models.

    Args:
        model_type: Type of VSA model ('bsc', 'map', 'hrr', 'fhrr', 'bsbc',
            'cgr', 'mcr', 'vtb')
        dimensions: Dimensionality of hypervectors (default: 10000)
        **kwargs: Optional model-specific settings forwarded unchanged to the
            model's ``create`` method, e.g. ``block_size``/``k_active`` for
            'bsbc' or ``q`` for 'cgr'/'mcr'. Models whose ``create`` accepts
            only ``dimensions`` raise ``TypeError`` if any are given.

    Returns:
        Initialized VSA model

    Raises:
        ValueError: If ``model_type`` is not one of the available models
    """
    models = {
        "bsc": BSC,
        "map": MAP,
        "hrr": HRR,
        "fhrr": FHRR,
        "bsbc": BSBC,
        "cgr": CGR,
        "mcr": MCR,
        "vtb": VTB,
    }
    if model_type not in models:
        raise ValueError(
            f"Unknown VSA model: {model_type}. Available models: {list(models.keys())}"
        )
    return models[model_type].create(dimensions=dimensions, **kwargs)  # type: ignore[attr-defined]
# Public API of this module: the base interface, each concrete model and the
# factory function.
__all__ = [
    "VSAModel",
    "BSC",
    "MAP",
    "HRR",
    "FHRR",
    "BSBC",
    "CGR",
    "MCR",
    "VTB",
    "create_vsa_model",
]