Embeddings Module

The jax_hdc.embeddings module provides encoders for transforming data into hypervectors.

RandomEncoder

class jax_hdc.embeddings.RandomEncoder(codebook: Array, num_features: int, num_values: int, dimensions: int, vsa_model_name: str = 'map')

Encoder using random hypervectors for discrete features.

Each unique feature value is mapped to a random hypervector from a codebook. Multiple features are bundled together to form the final representation.
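A minimal standalone sketch of this idea follows. It assumes a MAP-style bipolar codebook with one row of value hypervectors per feature (shape (num_features, num_values, dimensions)) and bundling by plain element-wise addition. The helpers make_codebook, encode_indices and cosine are hypothetical and not part of jax_hdc; the library's actual codebook layout and bundling normalisation may differ.

```python
import jax
import jax.numpy as jnp


def make_codebook(key, num_features, num_values, dimensions):
    # Assumed layout: one random bipolar (+1/-1) hypervector per (feature, value) pair.
    bits = jax.random.bernoulli(key, 0.5, (num_features, num_values, dimensions))
    return jnp.where(bits, 1.0, -1.0)


def encode_indices(codebook, indices):
    num_features, num_values, _ = codebook.shape
    # Clamp out-of-range values into [0, num_values - 1].
    idx = jnp.clip(indices, 0, num_values - 1)
    # Look up each feature's value hypervector: shape (num_features, dimensions).
    selected = codebook[jnp.arange(num_features), idx]
    # MAP-style bundling: element-wise sum over features (no normalisation).
    return selected.sum(axis=0)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


k_code, k_data = jax.random.split(jax.random.PRNGKey(0))
codebook = make_codebook(k_code, num_features=20, num_values=10, dimensions=10000)
indices = jax.random.randint(k_data, (20,), 0, 10)
hv = encode_indices(codebook, indices)
```

The bundle stays measurably similar to every value hypervector it contains (cosine of roughly 1/sqrt(20) for 20 features) and close to orthogonal to value hypervectors it does not contain. This is what makes the encoding queryable.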

codebook: Array
num_features: int
num_values: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_features: int, num_values: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → RandomEncoder

Create a random encoder.

Parameters:
  • num_features – Number of features to encode

  • num_values – Number of possible values per feature

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance

  • key – JAX random key (default: PRNGKey(0))

Returns:

Initialized RandomEncoder

encode(indices: Array) → Array

Encode discrete features as hypervectors.

Parameters:

indices – Feature indices of shape (num_features,) with values in [0, num_values). Out-of-bounds indices are clamped to the valid range.

Returns:

Encoded hypervector of shape (dimensions,)

encode_batch(indices: Array) → Array

Encode a batch of samples.

Parameters:

indices – Batch of feature indices of shape (batch_size, num_features)

Returns:

Encoded hypervectors of shape (batch_size, dimensions)
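In terms of shapes, batch encoding is the per-sample encoder mapped over a leading batch axis. The sketch below shows that relationship with jax.vmap. It uses a hypothetical encode_one that looks up and sums codebook rows under an assumed (num_features, num_values, dimensions) layout, and it illustrates only the shape contract, not how encode_batch is implemented internally.

```python
import jax
import jax.numpy as jnp


def encode_one(codebook, indices):
    # Hypothetical per-sample encoder: (num_features,) -> (dimensions,)
    num_features, num_values, _ = codebook.shape
    idx = jnp.clip(indices, 0, num_values - 1)
    return codebook[jnp.arange(num_features), idx].sum(axis=0)


# Vectorise over the batch axis of `indices`; the codebook (in_axes=None) is shared.
encode_many = jax.vmap(encode_one, in_axes=(None, 0))

k_code, k_data = jax.random.split(jax.random.PRNGKey(1))
codebook = jnp.where(jax.random.bernoulli(k_code, 0.5, (20, 10, 10000)), 1.0, -1.0)
batch = jax.random.randint(k_data, (8, 20), 0, 10)  # (batch_size, num_features)
hvs = encode_many(codebook, batch)                  # (batch_size, dimensions)
```

Wrapping encode_many in jax.jit compiles the whole batch into a single XLA computation.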

__init__(codebook: Array, num_features: int, num_values: int, dimensions: int, vsa_model_name: str = 'map') → None

Example:

from jax_hdc import MAP, RandomEncoder
import jax

model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)

encoder = RandomEncoder.create(
    num_features=20,
    num_values=10,
    dimensions=10000,
    vsa_model=model,
    key=key
)

# Encode discrete features
data = jax.random.randint(key, (20,), 0, 10)
encoded = encoder.encode(data)

LevelEncoder

class jax_hdc.embeddings.LevelEncoder(level_hvs: Array, num_levels: int, dimensions: int, min_value: float, max_value: float, vsa_model_name: str = 'map', encoding_type: str = 'linear')

Encoder for continuous values using level hypervectors.

Continuous values are encoded by interpolating between level hypervectors, creating a smooth representation where similar values map to similar hypervectors.
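One common way to build level hypervectors with this property is sketched below. Treat it as an assumption, not as a description of what LevelEncoder.create does. It starts from a random bipolar vector and flips a growing, nested set of coordinates, so that neighbouring levels share most coordinates and the two extreme levels are orthogonal. make_levels and cosine are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def make_levels(key, num_levels, dimensions):
    # Hypothetical linear level hypervectors (MAP-style, bipolar).
    k_base, k_perm = jax.random.split(key)
    base = jnp.where(jax.random.bernoulli(k_base, 0.5, (dimensions,)), 1.0, -1.0)
    # A fixed random order in which coordinates get flipped.
    order = jax.random.permutation(k_perm, dimensions)
    rank = jnp.argsort(order)  # rank[j] = position of coordinate j in that order
    # Flipped-coordinate count grows linearly from 0 (level 0) to d/2 (last level).
    n_flip = jnp.round(jnp.linspace(0, dimensions // 2, num_levels)).astype(jnp.int32)
    flip = rank[None, :] < n_flip[:, None]  # (num_levels, dimensions), nested sets
    return jnp.where(flip, -base, base)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


levels = make_levels(jax.random.PRNGKey(0), num_levels=100, dimensions=10000)
```

Flipping k of d coordinates gives a cosine of (d - 2k)/d with the base level. Similarity therefore falls off linearly with level distance and reaches exactly 0 at the last level.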

level_hvs: Array
num_levels: int
dimensions: int
min_value: float
max_value: float
vsa_model_name: str = 'map'
encoding_type: str = 'linear'
static create(num_levels: int = 100, dimensions: int = 10000, min_value: float = 0.0, max_value: float = 1.0, vsa_model: str | VSAModel = 'map', encoding_type: str = 'linear', key: Array | None = None) → LevelEncoder

Create a level encoder.

Parameters:
  • num_levels – Number of levels for discretization (default: 100)

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • min_value – Minimum value of the range (default: 0.0)

  • max_value – Maximum value of the range (default: 1.0)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance

  • encoding_type – ‘linear’ or ‘circular’ (default: ‘linear’)

  • key – JAX random key

Returns:

Initialized LevelEncoder

encode(value: float | Array) → Array

Encode a continuous value as a hypervector.

Parameters:

value – Continuous value to encode (scalar or array)

Returns:

Encoded hypervector of shape (dimensions,) or batch shape + (dimensions,)
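The sketch below shows interpolation for the 'linear' encoding type. It assumes that values are first clipped to [min_value, max_value] and then mapped to a fractional level position between 0 and num_levels - 1, and that the result is the linear blend of the two nearest levels. The clipping and the exact blend are assumptions, and encode_value is hypothetical. It accepts a scalar or an array and returns batch shape + (dimensions,), matching the return shape documented above. For brevity, the toy levels are independent random vectors rather than correlated level hypervectors.

```python
import jax
import jax.numpy as jnp


def encode_value(levels, value, min_value=0.0, max_value=1.0):
    num_levels = levels.shape[0]
    # Assumed: clip into range, then map to a fractional level position t.
    v = jnp.clip(jnp.asarray(value, dtype=jnp.float32), min_value, max_value)
    t = (v - min_value) / (max_value - min_value) * (num_levels - 1)
    lo = jnp.floor(t).astype(jnp.int32)
    hi = jnp.minimum(lo + 1, num_levels - 1)
    w = (t - lo)[..., None]  # blend weight, broadcast over dimensions
    # Linear blend of the two nearest level hypervectors.
    return (1.0 - w) * levels[lo] + w * levels[hi]


key = jax.random.PRNGKey(0)
levels = jnp.where(jax.random.bernoulli(key, 0.5, (5, 1000)), 1.0, -1.0)
hv = encode_value(levels, 0.125)                        # halfway between levels 0 and 1
hvs = encode_value(levels, jnp.array([0.0, 0.5, 1.0]))  # shape (3, 1000)
```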

encode_batch(values: Array) → Array

Encode a batch of continuous values.

Parameters:

values – Batch of values of shape (batch_size,) or (batch_size, num_features)

Returns:

Encoded hypervectors

__init__(level_hvs: Array, num_levels: int, dimensions: int, min_value: float, max_value: float, vsa_model_name: str = 'map', encoding_type: str = 'linear') → None

Example:

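from jax_hdc import MAP, LevelEncoder
import jax

model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)

encoder = LevelEncoder.create(
    num_levels=100,
    dimensions=10000,
    min_value=0.0,
    max_value=1.0,
    vsa_model=model,
    key=key
)

# Encode a continuous value in [min_value, max_value]
encoded = encoder.encode(0.75)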

ProjectionEncoder

class jax_hdc.embeddings.ProjectionEncoder(projection_matrix: Array, input_dim: int, dimensions: int, vsa_model_name: str = 'map')

Encoder using random projection for high-dimensional data.

Projects high-dimensional input data into hypervector space using a random projection matrix. Useful for images, text embeddings, etc.
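The sketch below assumes a standard-normal projection matrix and sign quantisation to a bipolar (MAP-style) hypervector. Both are assumptions for illustration, not descriptions of ProjectionEncoder. With sign quantisation, the expected fraction of agreeing coordinates between two encodings is 1 - θ/π, where θ is the angle between the inputs. Nearby inputs therefore stay similar and unrelated inputs become close to orthogonal. make_projection, project and cosine are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def make_projection(key, input_dim, dimensions):
    # Assumed: i.i.d. standard-normal matrix of shape (dimensions, input_dim).
    return jax.random.normal(key, (dimensions, input_dim))


def project(P, x):
    # Random projection followed by sign quantisation to {-1, +1}.
    return jnp.sign(P @ x)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


k_p, k_x, k_n, k_y = jax.random.split(jax.random.PRNGKey(0), 4)
P = make_projection(k_p, input_dim=784, dimensions=10000)
x = jax.random.normal(k_x, (784,))
x_near = x + 0.1 * jax.random.normal(k_n, (784,))  # small perturbation of x
y = jax.random.normal(k_y, (784,))                 # unrelated input

sim_near = cosine(project(P, x), project(P, x_near))
sim_far = cosine(project(P, x), project(P, y))
```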

projection_matrix: Array
input_dim: int
dimensions: int
vsa_model_name: str = 'map'
static create(input_dim: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → ProjectionEncoder

Create a projection encoder.

Parameters:
  • input_dim – Dimensionality of input data

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance

  • key – JAX random key

Returns:

Initialized ProjectionEncoder

encode(x: Array) → Array

Encode input data as a hypervector.

Parameters:

x – Input data of shape (input_dim,)

Returns:

Encoded hypervector of shape (dimensions,)

encode_batch(x: Array) → Array

Encode a batch of inputs.

Parameters:

x – Batch of inputs of shape (batch_size, input_dim)

Returns:

Encoded hypervectors of shape (batch_size, dimensions)

__init__(projection_matrix: Array, input_dim: int, dimensions: int, vsa_model_name: str = 'map') → None

Example:

from jax_hdc import MAP, ProjectionEncoder
import jax

model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)

encoder = ProjectionEncoder.create(
    input_dim=784,
    dimensions=10000,
    vsa_model=model,
    key=key
)

# Encode a high-dimensional input (e.g. a flattened 28x28 image)
image = jax.random.normal(key, (784,))
encoded = encoder.encode(image)

KernelEncoder

class jax_hdc.embeddings.KernelEncoder(omega: Array, bias: Array, input_dim: int, dimensions: int, gamma: float, vsa_model_name: str = 'map')

Encoder using RBF kernel approximation (Random Fourier Features).

Approximates the RBF kernel k(x,y) = exp(-gamma ||x-y||^2) via random Fourier features, mapping input to a hypervector space that preserves kernel similarity.
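The sketch below shows the standard random Fourier feature construction for this kernel. It draws omega from N(0, 2·gamma·I), the spectral density of exp(-gamma ||x - y||^2), draws bias uniformly from [0, 2π), and sets z(x) = sqrt(2/D) cos(omega x + bias), so that z(x) · z(y) approximates k(x, y). The names omega and bias mirror the attributes listed below, but whether KernelEncoder uses exactly this scaling, or post-processes z(x) further, is an assumption. make_rff, rff and rbf are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def make_rff(key, input_dim, dimensions, gamma):
    k_w, k_b = jax.random.split(key)
    # omega ~ N(0, 2 * gamma * I): Fourier transform of exp(-gamma ||x - y||^2).
    omega = jnp.sqrt(2.0 * gamma) * jax.random.normal(k_w, (dimensions, input_dim))
    # bias ~ Uniform[0, 2*pi)
    bias = jax.random.uniform(k_b, (dimensions,), minval=0.0, maxval=2.0 * jnp.pi)
    return omega, bias


def rff(omega, bias, x):
    # z(x) = sqrt(2 / D) * cos(omega x + bias), so z(x) . z(y) ~= k(x, y).
    D = omega.shape[0]
    return jnp.sqrt(2.0 / D) * jnp.cos(omega @ x + bias)


def rbf(x, y, gamma):
    # Exact RBF kernel, for comparison.
    return jnp.exp(-gamma * jnp.sum((x - y) ** 2))


gamma = 0.5
k_f, k_x, k_y = jax.random.split(jax.random.PRNGKey(0), 3)
omega, bias = make_rff(k_f, input_dim=8, dimensions=20000, gamma=gamma)
x = jax.random.normal(k_x, (8,))
y = x + 0.3 * jax.random.normal(k_y, (8,))

approx = jnp.dot(rff(omega, bias, x), rff(omega, bias, y))
exact = rbf(x, y, gamma)
```

The approximation error shrinks roughly as 1/sqrt(D), which is one reason to favour large hypervector dimensions here.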

omega: Array
bias: Array
input_dim: int
dimensions: int
gamma: float
vsa_model_name: str = 'map'
static create(input_dim: int, dimensions: int = 10000, gamma: float = 1.0, vsa_model: str | VSAModel = 'map', key: Array | None = None) → KernelEncoder

Create a kernel encoder.

Parameters:
  • input_dim – Dimensionality of input data

  • dimensions – Dimensionality of output hypervectors (default: 10000)

  • gamma – RBF kernel scale parameter, gamma = 1 / (2 sigma^2) (default: 1.0)

  • vsa_model – VSA model to use: one of the real-valued options (‘map’, ‘hrr’, ‘fhrr’) or a VSAModel instance

  • key – JAX random key

Returns:

Initialized KernelEncoder

encode(x: Array) → Array

Encode input using RBF kernel approximation.

encode_batch(x: Array) → Array

Encode a batch of inputs.

__init__(omega: Array, bias: Array, input_dim: int, dimensions: int, gamma: float, vsa_model_name: str = 'map') → None

GraphEncoder

class jax_hdc.embeddings.GraphEncoder(node_embeddings: Array, num_nodes: int, dimensions: int, vsa_model_name: str = 'map')

Encoder for graph structures (nodes and edges).

Encodes a graph by assigning a random hypervector to each node and binding the two endpoint hypervectors of each edge. The graph hypervector is the bundle of all edge hypervectors.
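The sketch below uses MAP-style assumptions: bipolar node hypervectors, binding as element-wise multiplication and bundling as element-wise addition. Because multiplication is commutative, this particular sketch treats edges as undirected. How GraphEncoder handles edge direction is not documented here. encode_graph and cosine are hypothetical helpers. Querying the graph hypervector with a bound node pair reveals whether that edge is present.

```python
import jax
import jax.numpy as jnp


def encode_graph(nodes, edges):
    num_nodes = nodes.shape[0]
    # Clamp node indices into [0, num_nodes - 1].
    edges = jnp.clip(edges, 0, num_nodes - 1)
    # Bind each edge's endpoints (MAP: element-wise product) -> (num_edges, dimensions).
    edge_hvs = nodes[edges[:, 0]] * nodes[edges[:, 1]]
    # Bundle all edge hypervectors into one graph hypervector.
    return edge_hvs.sum(axis=0)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


key = jax.random.PRNGKey(0)
nodes = jnp.where(jax.random.bernoulli(key, 0.5, (5, 10000)), 1.0, -1.0)
edges = jnp.array([[0, 1], [1, 2], [2, 3]])
graph = encode_graph(nodes, edges)

present = cosine(graph, nodes[0] * nodes[1])  # edge (0, 1) is in the graph
absent = cosine(graph, nodes[0] * nodes[4])   # edge (0, 4) is not
```

For a three-edge graph, present comes out near 1/sqrt(3) and absent near 0.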

node_embeddings: Array
num_nodes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_nodes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → GraphEncoder

Create a graph encoder.

Parameters:
  • num_nodes – Maximum number of nodes

  • dimensions – Hypervector dimensionality

  • vsa_model – VSA model to use for the real-valued node hypervectors (default: ‘map’), or a VSAModel instance

  • key – JAX random key

Returns:

Initialized GraphEncoder

encode_edges(edges: Array) → Array

Encode a graph as the bundle of its bound edge pairs.

Parameters:

edges – Array of shape (num_edges, 2) with node indices in [0, num_nodes). Out-of-bounds indices are clamped to the valid range.

Returns:

Graph hypervector of shape (dimensions,)

__init__(node_embeddings: Array, num_nodes: int, dimensions: int, vsa_model_name: str = 'map') → None