Models Module

The jax_hdc.models module provides classification and learning algorithms.

CentroidClassifier

class jax_hdc.models.CentroidClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Centroid-based classifier for HDC.

Stores one prototype hypervector per class. A query is assigned to the class whose prototype is most similar to it.

prototypes: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', initial_prototypes: Array | None = None, key: Array | None = None) → CentroidClassifier

Create a centroid classifier.

Parameters:
  • num_classes – Number of classes

  • dimensions – Dimensionality of hypervectors

  • vsa_model – VSA model name or instance

  • initial_prototypes – Optional initial prototypes of shape (num_classes, dimensions)

  • key – JAX random key for initialization

similarity(query: Array) → Array

Compute similarity between query and all class prototypes.
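A minimal plain-JAX sketch of what similarity and predict compute, assuming cosine similarity (a common choice for MAP hypervectors) followed by an argmax over classes. The helper name cosine_similarity is hypothetical and not part of jax_hdc; the library's metric may differ per VSA model.

```python
import jax.numpy as jnp

def cosine_similarity(query, prototypes):
    # Hypothetical helper (assumed metric: cosine similarity).
    # Normalise the query (d,) and each prototype row (k, d) to unit length,
    # then a matrix-vector product yields one score per class.
    q = query / jnp.linalg.norm(query)
    p = prototypes / jnp.linalg.norm(prototypes, axis=-1, keepdims=True)
    return p @ q

prototypes = jnp.array([[1.0, 1.0, -1.0, -1.0],
                        [1.0, -1.0, 1.0, -1.0]])
query = jnp.array([1.0, 1.0, -1.0, 1.0])
sims = cosine_similarity(query, prototypes)  # one score per class
predicted = jnp.argmax(sims)                 # index of the most similar prototype
```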

predict(queries: Array) → Array

Predict class labels for queries.

Parameters:

queries – Shape (batch_size, dimensions) or (dimensions,)

Returns:

Predicted class indices

predict_proba(queries: Array) → Array

Predict class probabilities using softmax of similarities.
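As a sketch of the softmax step, assuming the similarities are already arranged as a (batch_size, num_classes) array and no temperature scaling is applied (the helper name predict_proba_sketch is hypothetical):

```python
import jax
import jax.numpy as jnp

def predict_proba_sketch(sims):
    # Hypothetical helper: softmax over the class axis of a
    # (batch_size, num_classes) similarity matrix.
    return jax.nn.softmax(sims, axis=-1)

sims = jnp.array([[0.5, -0.5, 0.0],
                  [0.1, 0.2, 0.3]])
probs = predict_proba_sketch(sims)
```

Because softmax preserves ordering, the argmax of these probabilities matches predict. Note that cosine similarities lie in [-1, 1], so an unscaled softmax over them tends to produce fairly flat distributions.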

fit(train_hvs: Array, train_labels: Array) → CentroidClassifier

Train classifier by computing class centroids.

Parameters:
  • train_hvs – Training hypervectors of shape (n_samples, dimensions)

  • train_labels – Training labels of shape (n_samples,)

Returns:

Trained CentroidClassifier (new instance)
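A sketch of centroid training in plain JAX, assuming each class prototype is the sum (bundle) of that class's training hypervectors. The helper fit_centroids is hypothetical; the library may normalise or average instead.

```python
import jax
import jax.numpy as jnp

def fit_centroids(train_hvs, train_labels, num_classes):
    # Hypothetical helper: sum all training hypervectors that share a label,
    # producing one (dimensions,) row per class.
    return jax.ops.segment_sum(train_hvs, train_labels, num_segments=num_classes)

train_hvs = jnp.array([[1.0, -1.0, 1.0],
                       [1.0, 1.0, 1.0],
                       [-1.0, -1.0, 1.0]])
train_labels = jnp.array([0, 0, 1])
prototypes = fit_centroids(train_hvs, train_labels, num_classes=2)
```

Summing and averaging differ only by a positive per-class factor, so under cosine similarity they give the same predictions. Returning a fresh array rather than mutating state matches the documented behaviour of fit returning a new instance.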

update_online(sample_hv: Array, label: int, learning_rate: float = 0.1) → CentroidClassifier

Update classifier online with a single sample.
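One plausible online rule, shown as a sketch only: add learning_rate times the sample to the prototype of its labelled class and leave the others untouched. The update rule and the helper name update_online_sketch are assumptions, not the documented behaviour of update_online.

```python
import jax.numpy as jnp

def update_online_sketch(prototypes, sample_hv, label, learning_rate=0.1):
    # Assumed rule: nudge only the labelled class prototype toward the sample.
    # .at[...].add returns a new array; the input prototypes are not mutated.
    return prototypes.at[label].add(learning_rate * sample_hv)

prototypes = jnp.zeros((3, 4))
sample = jnp.array([1.0, -1.0, 1.0, -1.0])
updated = update_online_sketch(prototypes, sample, label=2, learning_rate=0.5)
```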

score(test_hvs: Array, test_labels: Array) → Array

Compute accuracy on test data.

replace(**updates: Any) → CentroidClassifier

__init__(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

Example:

from jax_hdc import MAP, CentroidClassifier
import jax
import jax.numpy as jnp

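model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)
# Split the key so training/test data and labels are drawn independently
key_train, key_train_labels, key_test, key_test_labels = jax.random.split(key, 4)

# Create classifier
classifier = CentroidClassifier.create(
    num_classes=10,
    dimensions=10000,
    vsa_model=model
)

# Train (random data, for illustration only)
train_hvs = model.random(key_train, (100, 10000))
train_labels = jax.random.randint(key_train_labels, (100,), 0, 10)
classifier = classifier.fit(train_hvs, train_labels)

# Predict
test_hvs = model.random(key_test, (20, 10000))
predictions = classifier.predict(test_hvs)

# Evaluate: labels are random, so expect roughly chance-level accuracy (about 0.1)
test_labels = jax.random.randint(key_test_labels, (20,), 0, 10)
accuracy = classifier.score(test_hvs, test_labels)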

LVQClassifier

class jax_hdc.models.LVQClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Learning Vector Quantization classifier.

During training, the prototype most similar to a sample (the winner) is moved toward the sample if its class matches the sample's label, and away from it otherwise.

prototypes: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → LVQClassifier

predict(queries: Array) → Array

Predict class labels by nearest prototype.

fit(train_hvs: Array, train_labels: Array, epochs: int = 10, lr: float = 0.1) → LVQClassifier

Train with LVQ updates (winner-take-all, move toward/away).
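A sketch of a classic LVQ1 update in plain JAX, assuming one prototype per class (so the winner's index is its class) and dot-product similarity. The helpers lvq_step and lvq_fit are hypothetical and may not match the library's exact rule.

```python
import jax
import jax.numpy as jnp

def lvq_step(prototypes, sample_hv, label, lr=0.1):
    # Winner-take-all: only the most similar prototype is updated.
    winner = jnp.argmax(prototypes @ sample_hv)
    # +1 pulls the winner toward the sample if its class is correct, -1 pushes it away.
    sign = jnp.where(winner == label, 1.0, -1.0)
    delta = sign * lr * (sample_hv - prototypes[winner])
    return prototypes.at[winner].add(delta)

def lvq_fit(prototypes, train_hvs, train_labels, epochs=10, lr=0.1):
    def body(protos, xy):
        x, y = xy
        return lvq_step(protos, x, y, lr), None
    for _ in range(epochs):
        # One pass over the training set, sample by sample.
        prototypes, _ = jax.lax.scan(body, prototypes, (train_hvs, train_labels))
    return prototypes

prototypes = jnp.array([[1.0, 0.0], [0.0, 1.0]])
train_hvs = jnp.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.9]])
train_labels = jnp.array([0, 1, 1])
trained = lvq_fit(prototypes, train_hvs, train_labels, epochs=5, lr=0.1)
```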

score(test_hvs: Array, test_labels: Array) → Array

replace(**updates: Any) → LVQClassifier

__init__(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

RegularizedLSClassifier

class jax_hdc.models.RegularizedLSClassifier(weights: Array, dimensions: int, num_classes: int, reg: float)

Regularized Least Squares classifier in HV space.

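Solves (X^T X + λI) W = X^T Y for the weights W, where X is the (n_samples, dimensions) matrix of training hypervectors, Y holds the training targets (one column per class), and λ is the reg parameter.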

weights: Array
dimensions: int
num_classes: int
reg: float
static create(dimensions: int, num_classes: int, reg: float = 0.0001) → RegularizedLSClassifier

fit(train_hvs: Array, train_labels: Array) → RegularizedLSClassifier

Fit by solving regularized least squares.
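A sketch of the regularized least-squares fit in plain JAX, assuming one-hot targets Y and prediction by argmax over the class scores. The helpers fit_rls and predict_rls are hypothetical.

```python
import jax
import jax.numpy as jnp

def fit_rls(train_hvs, train_labels, num_classes, reg=1e-4):
    X = train_hvs                                  # (n_samples, d)
    Y = jax.nn.one_hot(train_labels, num_classes)  # (n_samples, num_classes), assumed one-hot
    d = X.shape[1]
    A = X.T @ X + reg * jnp.eye(d, dtype=X.dtype)  # (d, d) regularized Gram matrix
    return jnp.linalg.solve(A, X.T @ Y)            # W: (d, num_classes)

def predict_rls(W, queries):
    # Class with the largest linear score.
    return jnp.argmax(queries @ W, axis=-1)

X = jnp.eye(3)
labels = jnp.array([0, 1, 2])
W = fit_rls(X, labels, num_classes=3)
preds = predict_rls(W, X)
```

With dimensions = 10000 the (d, d) system is large. When n_samples < d, the equivalent form W = X^T (X X^T + λI)^(-1) Y yields the same solution from an n_samples × n_samples system.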

predict(queries: Array) → Array

score(test_hvs: Array, test_labels: Array) → Array

replace(**updates: Any) → RegularizedLSClassifier

__init__(weights: Array, dimensions: int, num_classes: int, reg: float) → None

AdaptiveHDC

class jax_hdc.models.AdaptiveHDC(prototypes: Array, num_updates: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Adaptive HDC classifier with iterative prototype refinement.

prototypes: Array
num_updates: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → AdaptiveHDC

predict(queries: Array) → Array

Predict class labels.

fit(train_hvs: Array, train_labels: Array, epochs: int = 1, learning_rate: float = 0.1) → AdaptiveHDC

Train with iterative refinement.

Parameters:
  • train_hvs – Training hypervectors

  • train_labels – Training labels

  • epochs – Number of training epochs

  • learning_rate – Learning rate for updates
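The exact refinement rule is not documented here. A common choice in HDC retraining is a perceptron-style, mistake-driven update: when a sample is misclassified, add learning_rate times the sample to its true class prototype and subtract it from the predicted one. The sketch below assumes that rule and dot-product similarity; refine_epoch is a hypothetical helper, and it does not model the num_updates field.

```python
import jax
import jax.numpy as jnp

def refine_epoch(prototypes, train_hvs, train_labels, learning_rate=0.1):
    # One pass of assumed perceptron-style retraining.
    def body(protos, xy):
        x, y = xy
        pred = jnp.argmax(protos @ x)
        wrong = (pred != y).astype(protos.dtype)                  # 1.0 on a mistake, else 0.0
        protos = protos.at[y].add(wrong * learning_rate * x)      # reinforce the true class
        protos = protos.at[pred].add(-wrong * learning_rate * x)  # penalise the wrong class
        return protos, wrong
    return jax.lax.scan(body, prototypes, (train_hvs, train_labels))

prototypes = jnp.array([[1.0, 0.0], [0.0, 1.0]])
train_hvs = jnp.array([[1.0, 0.2], [0.3, 1.0], [0.9, 0.8]])
train_labels = jnp.array([0, 1, 1])
for epoch in range(3):
    prototypes, mistakes = refine_epoch(prototypes, train_hvs, train_labels)
```

Correctly classified samples leave the prototypes unchanged, which is what distinguishes this iterative scheme from the one-shot centroid fit.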

score(test_hvs: Array, test_labels: Array) → Array

Compute accuracy.

replace(**updates: Any) → AdaptiveHDC

__init__(prototypes: Array, num_updates: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

Example:

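from jax_hdc import MAP, AdaptiveHDC
import jax

model = MAP.create(dimensions=10000)
key_data, key_labels = jax.random.split(jax.random.PRNGKey(0))
train_hvs = model.random(key_data, (100, 10000))
train_labels = jax.random.randint(key_labels, (100,), 0, 10)

classifier = AdaptiveHDC.create(
    num_classes=10,
    dimensions=10000,
    vsa_model=model
)

# Iterative training: 10 passes over the training set
classifier = classifier.fit(
    train_hvs,
    train_labels,
    epochs=10,
    learning_rate=0.1
)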