Embeddings Module
The jax_hdc.embeddings module provides encoders for transforming data into hypervectors.
RandomEncoder
- class jax_hdc.embeddings.RandomEncoder(codebook: Array, num_features: int, num_values: int, dimensions: int, vsa_model_name: str = 'map')
Encoder using random hypervectors for discrete features.
Each unique feature value is mapped to a random hypervector from a codebook. Multiple features are bundled together to form the final representation.
- static create(num_features: int, num_values: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) -> RandomEncoder
Create a random encoder.
- Parameters:
num_features – Number of features to encode
num_values – Number of possible values per feature
dimensions – Dimensionality of hypervectors (default: 10000)
vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance
key – JAX random key (default: PRNGKey(0))
- Returns:
Initialized RandomEncoder
- encode(indices: Array) -> Array
Encode discrete features as hypervectors.
- Parameters:
indices – Feature value indices of shape (num_features,), one per feature, each in [0, num_values). Out-of-bounds indices are clamped to the valid range.
- Returns:
Encoded hypervector of shape (dimensions,)
Example:
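from jax_hdc import MAP, RandomEncoder
import jax

model = MAP.create(dimensions=10000)
# Use separate keys for the codebook and the sample data
key = jax.random.PRNGKey(42)
encoder_key, data_key = jax.random.split(key)

encoder = RandomEncoder.create(
    num_features=20,
    num_values=10,
    dimensions=10000,
    vsa_model=model,
    key=encoder_key,
)

# Encode discrete features: one index per feature, each in [0, num_values)
data = jax.random.randint(data_key, (20,), 0, 10)
encoded = encoder.encode(data)  # shape: (10000,)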
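The lookup-and-bundle idea described for RandomEncoder can be sketched in plain JAX. This is only an illustration, not the library's implementation: the codebook layout (one bipolar hypervector per feature/value pair) and the sign-of-sum bundling are assumptions, and the local names `codebook` and `encode` are hypothetical.

```python
import jax
import jax.numpy as jnp

# An odd number of features keeps the bundled sum away from zero (no ties).
num_features, num_values, dimensions = 5, 10, 1000

# Assumed layout: one random bipolar (+1/-1) hypervector per (feature, value) pair.
key = jax.random.PRNGKey(0)
codebook = jax.random.rademacher(
    key, (num_features, num_values, dimensions), dtype=jnp.int32
)

def encode(indices):
    # Clamp out-of-range indices into [0, num_values - 1].
    idx = jnp.clip(indices, 0, num_values - 1)
    # Select one hypervector per feature: shape (num_features, dimensions).
    selected = codebook[jnp.arange(num_features), idx]
    # Bundle by summing over features and taking the sign (majority vote).
    return jnp.sign(jnp.sum(selected, axis=0))

# 12 and -1 are out of range and are clamped to 9 and 0 respectively.
encoded = encode(jnp.array([3, 0, 9, 12, -1]))
```

Because the indices are clamped, an out-of-range input produces the same hypervector as its nearest valid index, which is worth keeping in mind when validating input data.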
LevelEncoder
- class jax_hdc.embeddings.LevelEncoder(level_hvs: Array, num_levels: int, dimensions: int, min_value: float, max_value: float, vsa_model_name: str = 'map', encoding_type: str = 'linear')
Encoder for continuous values using level hypervectors.
Continuous values are encoded by interpolating between level hypervectors, creating a smooth representation where similar values map to similar hypervectors.
- static create(num_levels: int = 100, dimensions: int = 10000, min_value: float = 0.0, max_value: float = 1.0, vsa_model: str | VSAModel = 'map', encoding_type: str = 'linear', key: Array | None = None) -> LevelEncoder
Create a level encoder.
- Parameters:
num_levels – Number of levels for discretization (default: 100)
dimensions – Dimensionality of hypervectors (default: 10000)
min_value – Minimum value of the range (default: 0.0)
max_value – Maximum value of the range (default: 1.0)
vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance
encoding_type – ‘linear’ or ‘circular’ (default: ‘linear’)
key – JAX random key
- Returns:
Initialized LevelEncoder
- encode(value: float | Array) -> Array
Encode a continuous value as a hypervector.
- Parameters:
value – Continuous value to encode (scalar or array)
- Returns:
Encoded hypervector of shape (dimensions,) or batch shape + (dimensions,)
Example:
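import jax
from jax_hdc import MAP, LevelEncoder

model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)

encoder = LevelEncoder.create(
    num_levels=100,
    dimensions=10000,
    min_value=0.0,
    max_value=1.0,
    vsa_model=model,
    key=key,
)

# Encode a continuous value in [min_value, max_value]
encoded = encoder.encode(0.75)  # shape: (10000,)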
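To illustrate why similar values map to similar hypervectors, here is a small plain-JAX sketch of one common way to build level hypervectors. It is an assumption-laden illustration, not the library's method: it flips a growing share of a base vector's components as the level increases, and it picks the nearest level instead of interpolating, to keep the code short. All names (`base`, `flip_at`, `level_hvs`, `encode`, `cosine`) are hypothetical.

```python
import jax
import jax.numpy as jnp

num_levels, dimensions = 100, 10000
min_value, max_value = 0.0, 1.0

key_base, key_flip = jax.random.split(jax.random.PRNGKey(0))
base = jax.random.rademacher(key_base, (dimensions,), dtype=jnp.int32).astype(jnp.float32)

# Each dimension gets a random position in [0, 1) along the level axis.
flip_at = jax.random.uniform(key_flip, (dimensions,))
levels = jnp.linspace(0.0, 1.0, num_levels)[:, None]  # (num_levels, 1)

# Level i flips the dimensions whose position is below 0.5 * i / (num_levels - 1),
# so the last level differs from the first in about half of its components.
level_hvs = jnp.where(flip_at[None, :] < 0.5 * levels, -base, base)

def encode(value):
    # Normalise to [0, 1], clip, and look up the nearest level.
    t = jnp.clip((value - min_value) / (max_value - min_value), 0.0, 1.0)
    idx = jnp.round(t * (num_levels - 1)).astype(jnp.int32)
    return level_hvs[idx]

def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

near = cosine(encode(0.75), encode(0.80))  # close values: high similarity
far = cosine(encode(0.0), encode(1.0))     # range endpoints: near-orthogonal
batch = encode(jnp.array([0.1, 0.5, 0.9])) # array input: shape (3, dimensions)
```

In this construction, similarity falls off roughly linearly with the distance between two values, and the two ends of the range are nearly orthogonal.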
ProjectionEncoder
- class jax_hdc.embeddings.ProjectionEncoder(projection_matrix: Array, input_dim: int, dimensions: int, vsa_model_name: str = 'map')
Encoder using random projection for high-dimensional data.
Projects high-dimensional input data into hypervector space using a random projection matrix. Useful for images, text embeddings, etc.
- static create(input_dim: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) -> ProjectionEncoder
Create a projection encoder.
- Parameters:
input_dim – Dimensionality of input data
dimensions – Dimensionality of hypervectors (default: 10000)
vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance
key – JAX random key
- Returns:
Initialized ProjectionEncoder
- encode(x: Array) -> Array
Encode input data as a hypervector.
- Parameters:
x – Input data of shape (input_dim,)
- Returns:
Encoded hypervector of shape (dimensions,)
Example:
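import jax
from jax_hdc import MAP, ProjectionEncoder

model = MAP.create(dimensions=10000)
# Use separate keys for the projection matrix and the sample input
encoder_key, image_key = jax.random.split(jax.random.PRNGKey(42))

encoder = ProjectionEncoder.create(
    input_dim=784,
    dimensions=10000,
    vsa_model=model,
    key=encoder_key,
)

# Encode a high-dimensional input (e.g. a flattened 28x28 image)
image = jax.random.normal(image_key, (784,))
encoded = encoder.encode(image)  # shape: (10000,)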
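The random-projection idea can also be sketched without the library. The snippet below assumes a Gaussian projection matrix followed by a sign nonlinearity, which is one common choice for bipolar hypervectors; the actual ProjectionEncoder may scale or quantize differently. The names `projection`, `encode`, and `similarity` are hypothetical.

```python
import jax
import jax.numpy as jnp

input_dim, dimensions = 784, 10000
key_proj, key_x, key_noise, key_other = jax.random.split(jax.random.PRNGKey(0), 4)

# Assumed form: a fixed Gaussian random matrix of shape (dimensions, input_dim).
projection = jax.random.normal(key_proj, (dimensions, input_dim))

def encode(x):
    # Project into hypervector space, then keep only the sign of each component.
    return jnp.sign(projection @ x)

def similarity(a, b):
    # Normalised dot product of two bipolar hypervectors, in [-1, 1].
    return jnp.mean(a * b)

x = jax.random.normal(key_x, (input_dim,))
x_near = x + 0.1 * jax.random.normal(key_noise, (input_dim,))  # small perturbation
x_other = jax.random.normal(key_other, (input_dim,))           # unrelated input

sim_near = similarity(encode(x), encode(x_near))
sim_other = similarity(encode(x), encode(x_other))
```

Under these assumptions the fraction of differing signs grows with the angle between the two inputs, so a slightly perturbed input stays close to the original while an unrelated input lands near zero similarity.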
KernelEncoder
- class jax_hdc.embeddings.KernelEncoder(omega: Array, bias: Array, input_dim: int, dimensions: int, gamma: float, vsa_model_name: str = 'map')
Encoder using RBF kernel approximation (Random Fourier Features).
Approximates the RBF kernel k(x,y) = exp(-gamma ||x-y||^2) via random Fourier features, mapping input to a hypervector space that preserves kernel similarity.
- static create(input_dim: int, dimensions: int = 10000, gamma: float = 1.0, vsa_model: str | VSAModel = 'map', key: Array | None = None) -> KernelEncoder
Create a kernel encoder.
- Parameters:
input_dim – Dimensionality of input data
dimensions – Dimensionality of output hypervectors
gamma – RBF kernel scale parameter, gamma = 1 / (2 * sigma^2)
vsa_model – VSA model (‘map’, ‘hrr’, ‘fhrr’ for real-valued)
key – JAX random key
- Returns:
Initialized KernelEncoder
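As a rough illustration of the random Fourier feature technique named above, the following plain-JAX sketch builds the feature map directly. It is not the library's code: the names `omega` and `bias` mirror the class fields, but the sampling distributions and the sqrt(2 / dimensions) scaling are the textbook construction and are assumptions here, as is the local `encode` function.

```python
import jax
import jax.numpy as jnp

input_dim, dimensions, gamma = 16, 10000, 0.5
key_w, key_b, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 4)

# For k(x, y) = exp(-gamma * ||x - y||^2), frequencies are drawn from N(0, 2 * gamma * I).
omega = jnp.sqrt(2.0 * gamma) * jax.random.normal(key_w, (dimensions, input_dim))
# Phases are drawn uniformly from [0, 2 * pi).
bias = jax.random.uniform(key_b, (dimensions,), minval=0.0, maxval=2.0 * jnp.pi)

def encode(x):
    # Random Fourier feature map: z(x) = sqrt(2 / D) * cos(omega @ x + bias).
    return jnp.sqrt(2.0 / dimensions) * jnp.cos(omega @ x + bias)

x = 0.3 * jax.random.normal(key_x, (input_dim,))
y = 0.3 * jax.random.normal(key_y, (input_dim,))

approx = jnp.dot(encode(x), encode(y))           # inner product in feature space
exact = jnp.exp(-gamma * jnp.sum((x - y) ** 2))  # true RBF kernel value
```

The inner product of two encoded vectors approximates the kernel value, with an error that shrinks as `dimensions` grows. A real-valued output like this suits the real-valued VSA models listed for `vsa_model`.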
GraphEncoder
- class jax_hdc.embeddings.GraphEncoder(node_embeddings: Array, num_nodes: int, dimensions: int, vsa_model_name: str = 'map')
Encoder for graph structures (nodes and edges).
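Encodes a graph by assigning a random hypervector to each node, binding the hypervectors of the two endpoint nodes to represent each edge, and bundling the resulting edge hypervectors. The graph hypervector is therefore the bundle of its edge hypervectors.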
- static create(num_nodes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) -> GraphEncoder
Create a graph encoder.
- Parameters:
num_nodes – Maximum number of nodes
dimensions – Hypervector dimensionality
vsa_model – VSA model for real-valued graphs
key – JAX random key
- Returns:
Initialized GraphEncoder
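This section does not document an encode method for GraphEncoder, so the sketch below shows only the underlying idea in plain JAX. It assumes MAP-style operations (elementwise multiplication for binding, summation for bundling); the names `node_hvs`, `encode_graph`, and `edge_score` are hypothetical.

```python
import jax
import jax.numpy as jnp

num_nodes, dimensions = 6, 10000

# One random bipolar hypervector per node.
key = jax.random.PRNGKey(0)
node_hvs = jax.random.rademacher(
    key, (num_nodes, dimensions), dtype=jnp.int32
).astype(jnp.float32)

def encode_graph(edges):
    # Bind the two endpoint hypervectors of each edge (elementwise product).
    edge_hvs = node_hvs[edges[:, 0]] * node_hvs[edges[:, 1]]
    # Bundle all edge hypervectors by summation.
    return jnp.sum(edge_hvs, axis=0)

def edge_score(graph_hv, i, j):
    # Normalised similarity between the graph and a candidate edge (i, j).
    return jnp.dot(graph_hv, node_hvs[i] * node_hvs[j]) / dimensions

# A small path graph: 0 - 1 - 2 - 3.
edges = jnp.array([[0, 1], [1, 2], [2, 3]])
graph_hv = encode_graph(edges)

present = edge_score(graph_hv, 1, 2)  # edge is in the graph: score near 1
absent = edge_score(graph_hv, 0, 3)   # edge is not in the graph: score near 0
```

Because elementwise multiplication is commutative, this sketch treats edges as undirected. Encoding direction would need an extra step, such as permuting one endpoint before binding.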