Models Module
The jax_hdc.models module provides classifiers that operate on hypervectors: centroid, learning vector quantization (LVQ), regularized least squares, and adaptive HDC.
CentroidClassifier
- class jax_hdc.models.CentroidClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')
Centroid-based classifier for HDC.
Stores one prototype hypervector per class. A query is assigned the class of its most similar prototype.
- static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', initial_prototypes: Array | None = None, key: Array | None = None) → CentroidClassifier
Create a centroid classifier.
- Parameters:
num_classes – Number of classes
dimensions – Dimensionality of hypervectors
vsa_model – VSA model name or instance
initial_prototypes – Optional initial prototypes of shape (num_classes, dimensions)
key – JAX random key for initialization
- predict(queries: Array) → Array
Predict class labels for queries.
- Parameters:
queries – Query hypervectors of shape (batch_size, dimensions) or (dimensions,)
- Returns:
Predicted class indices
- predict_proba(queries: Array) → Array
Predict class probabilities by applying a softmax to the query-prototype similarities.
- fit(train_hvs: Array, train_labels: Array) → CentroidClassifier
Train the classifier by computing one centroid (prototype) per class from the training hypervectors.
- Parameters:
train_hvs – Training hypervectors of shape (n_samples, dimensions)
train_labels – Training labels of shape (n_samples,)
- Returns:
Trained CentroidClassifier (new instance)
- update_online(sample_hv: Array, label: int, learning_rate: float = 0.1) → CentroidClassifier
Update the classifier online with a single labeled sample.
- replace(**updates: Any) → CentroidClassifier
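The centroid rule described above can be sketched in plain jax.numpy. This is a minimal illustration, not the jax_hdc implementation: the helper names (fit_centroids, similarities, predict, predict_proba) are hypothetical, and it assumes prototypes are formed by summing (bundling) each class's hypervectors, cosine similarity as the similarity measure, and a softmax temperature of 1, none of which the reference above specifies.

```python
import jax
import jax.numpy as jnp


def fit_centroids(train_hvs, train_labels, num_classes):
    """Hypothetical helper: bundle (sum) each class's hypervectors into one prototype."""
    # Row i of the result is the sum of all rows of train_hvs whose label is i.
    return jax.ops.segment_sum(train_hvs, train_labels, num_segments=num_classes)


def _normalize(x):
    # Unit-normalize along the last axis; the floor avoids 0/0 for empty classes.
    return x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-12)


def similarities(queries, prototypes):
    """Cosine similarity (assumed metric) of each query to each prototype.

    Works for a single (dimensions,) query or a (batch_size, dimensions) batch.
    """
    return _normalize(queries) @ _normalize(prototypes).T


def predict(queries, prototypes):
    # Index of the most similar prototype for each query.
    return jnp.argmax(similarities(queries, prototypes), axis=-1)


def predict_proba(queries, prototypes):
    # Softmax over similarities with an assumed temperature of 1.
    return jax.nn.softmax(similarities(queries, prototypes), axis=-1)


# Toy data (hypothetical): one random bipolar base vector per class, and
# training samples that are noisy copies with about 10% of components flipped.
key_base, key_noise = jax.random.split(jax.random.PRNGKey(0))
num_classes, dims = 3, 1000
base = jnp.where(jax.random.bernoulli(key_base, 0.5, (num_classes, dims)), 1.0, -1.0)
labels = jnp.array([0, 0, 1, 1, 2, 2])
flip = jax.random.bernoulli(key_noise, 0.1, (labels.shape[0], dims))
train_hvs = jnp.where(flip, -base[labels], base[labels])

prototypes = fit_centroids(train_hvs, labels, num_classes)
preds = predict(train_hvs, prototypes)
probs = predict_proba(train_hvs, prototypes)
```

Because random high-dimensional hypervectors are nearly orthogonal, each noisy sample stays far more similar to its own class prototype than to the others, which is why a single bundled centroid per class is enough here.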
Example:
from jax_hdc import MAP, CentroidClassifier
import jax
import jax.numpy as jnp
model = MAP.create(dimensions=10000)
key = jax.random.PRNGKey(42)
# Split the key so every random draw below is independent
k_train, k_train_labels, k_test, k_test_labels = jax.random.split(key, 4)
# Create classifier
classifier = CentroidClassifier.create(
    num_classes=10,
    dimensions=10000,
    vsa_model=model
)
# Train (random data and labels, so expect near-chance accuracy)
train_hvs = model.random(k_train, (100, 10000))
train_labels = jax.random.randint(k_train_labels, (100,), 0, 10)
classifier = classifier.fit(train_hvs, train_labels)
# Predict
test_hvs = model.random(k_test, (20, 10000))
predictions = classifier.predict(test_hvs)
# Evaluate: fraction of test samples classified correctly
test_labels = jax.random.randint(k_test_labels, (20,), 0, 10)
accuracy = jnp.mean(predictions == test_labels)
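update_online is documented above only as a single-sample update, so its exact rule is not stated. The sketch below assumes an additive rule that bundles learning_rate times the sample into that class's prototype, which stays consistent with a summed centroid; update_online_sketch is a hypothetical name, not part of jax_hdc.

```python
import jax.numpy as jnp


def update_online_sketch(prototypes, sample_hv, label, learning_rate=0.1):
    """Hypothetical online rule: add a scaled copy of the sample to its class prototype.

    JAX arrays are immutable, so .at[...].add returns a new array and the input
    prototypes are left unchanged, in the same spirit as fit returning a new instance.
    """
    return prototypes.at[label].add(learning_rate * sample_hv)


prototypes = jnp.zeros((3, 4))
sample = jnp.array([1.0, -1.0, 1.0, 1.0])
updated = update_online_sketch(prototypes, sample, label=2, learning_rate=0.5)
```

Only the row for the given label changes; streaming data can be handled by calling the function once per incoming sample.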
LVQClassifier
- class jax_hdc.models.LVQClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')
Learning Vector Quantization classifier.
Prototypes are updated winner-take-all: the most similar prototype (the winner) moves toward the sample if its class matches the label, and away from it otherwise.
- static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → LVQClassifier
- fit(train_hvs: Array, train_labels: Array, epochs: int = 10, lr: float = 0.1) → LVQClassifier
Train with winner-take-all LVQ updates for the given number of epochs, using lr as the step size.
- replace(**updates: Any) → LVQClassifier
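The winner-take-all update described above corresponds to the classic LVQ1 rule. The following is a sketch under stated assumptions, not the jax_hdc code: lvq_step and lvq_fit are hypothetical helpers, cosine similarity is assumed for choosing the winner, and samples are visited in order each epoch via jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def lvq_step(prototypes, sample_hv, label, lr=0.1):
    """One LVQ1-style update for a single sample (hypothetical helper)."""
    # The winner is the prototype with the highest (assumed cosine) similarity.
    norms = jnp.linalg.norm(prototypes, axis=-1) * jnp.linalg.norm(sample_hv)
    sims = prototypes @ sample_hv / jnp.maximum(norms, 1e-12)
    winner = jnp.argmax(sims)
    # +1 pulls the winner toward the sample (correct), -1 pushes it away (wrong).
    sign = jnp.where(winner == label, 1.0, -1.0)
    return prototypes.at[winner].add(sign * lr * (sample_hv - prototypes[winner]))


def lvq_fit(prototypes, train_hvs, train_labels, epochs=10, lr=0.1):
    """Apply lvq_step to every sample in order, once per epoch (hypothetical helper)."""

    def body(protos, xy):
        x, y = xy
        return lvq_step(protos, x, y, lr), None

    for _ in range(epochs):
        prototypes, _ = jax.lax.scan(body, prototypes, (train_hvs, train_labels))
    return prototypes
```

Only the winning prototype changes per sample, so a wrong prediction weakens the prototype that caused it instead of strengthening the correct class directly.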
RegularizedLSClassifier
- class jax_hdc.models.RegularizedLSClassifier(weights: Array, dimensions: int, num_classes: int, reg: float)
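Regularized least-squares (ridge regression) classifier in hypervector space.
Solves (X^T X + λI) W = X^T Y for the weight matrix W, where X holds the training hypervectors (one per row), Y encodes the training labels (one row per sample, one column per class), and λ is reg.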
- static create(dimensions: int, num_classes: int, reg: float = 0.0001) → RegularizedLSClassifier
- fit(train_hvs: Array, train_labels: Array) → RegularizedLSClassifier
Fit by solving regularized least squares.
- replace(**updates: Any) → RegularizedLSClassifier
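The normal equations above can be solved directly with jnp.linalg.solve. This sketch assumes one-hot targets for Y and prediction by the highest linear score; fit_ridge and predict_ridge are hypothetical helpers, not jax_hdc functions.

```python
import jax
import jax.numpy as jnp


def fit_ridge(train_hvs, train_labels, num_classes, reg=1e-4):
    """Solve (X^T X + reg * I) W = X^T Y for W (hypothetical helper)."""
    X = train_hvs                                   # (n_samples, dimensions)
    Y = jax.nn.one_hot(train_labels, num_classes)   # assumed one-hot targets
    gram = X.T @ X + reg * jnp.eye(X.shape[1])      # (dimensions, dimensions)
    # gram is symmetric positive definite when reg > 0, so the system is solvable.
    return jnp.linalg.solve(gram, X.T @ Y)          # (dimensions, num_classes)


def predict_ridge(queries, W):
    # Score every class linearly and pick the highest-scoring one.
    return jnp.argmax(queries @ W, axis=-1)
```

For 10,000-dimensional hypervectors this primal system is 10,000 × 10,000. When there are fewer samples than dimensions, the algebraically equivalent dual form W = X^T (X X^T + λI)^(-1) Y only requires solving an n_samples × n_samples system.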
AdaptiveHDC
- class jax_hdc.models.AdaptiveHDC(prototypes: Array, num_updates: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')
Adaptive HDC classifier with iterative prototype refinement.
- static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → AdaptiveHDC
- fit(train_hvs: Array, train_labels: Array, epochs: int = 1, learning_rate: float = 0.1) → AdaptiveHDC
Train with iterative refinement.
- Parameters:
train_hvs – Training hypervectors
train_labels – Training labels
epochs – Number of training epochs
learning_rate – Learning rate for updates
- replace(**updates: Any) → AdaptiveHDC
Example:
# Continues the CentroidClassifier example above (reuses model, train_hvs, train_labels)
from jax_hdc import AdaptiveHDC
classifier = AdaptiveHDC.create(
    num_classes=10,
    dimensions=10000,
    vsa_model=model
)
# Iterative training
classifier = classifier.fit(
    train_hvs,
    train_labels,
    epochs=10,
    learning_rate=0.1
)
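The reference does not say which refinement rule AdaptiveHDC.fit applies. A common choice in HDC retraining is an error-driven, perceptron-style rule: when a sample is misclassified, add it to the true class prototype and subtract it from the wrongly predicted one. The sketch below implements that assumed rule for one epoch; refine_epoch and cosine_sims are hypothetical helpers, and cosine similarity is an assumption.

```python
import jax
import jax.numpy as jnp


def cosine_sims(prototypes, x):
    # Cosine similarity of one hypervector x to every prototype row.
    norms = jnp.linalg.norm(prototypes, axis=-1) * jnp.linalg.norm(x)
    return prototypes @ x / jnp.maximum(norms, 1e-12)


def refine_epoch(prototypes, train_hvs, train_labels, learning_rate=0.1):
    """One pass of assumed error-driven refinement (hypothetical helper).

    Returns the refined prototypes and the number of misclassified samples.
    """

    def body(protos, xy):
        x, y = xy
        pred = jnp.argmax(cosine_sims(protos, x))
        # 1.0 if this sample is misclassified, 0.0 otherwise.
        wrong = (pred != y).astype(protos.dtype)
        # Reinforce the true class and weaken the wrongly predicted one.
        protos = protos.at[y].add(wrong * learning_rate * x)
        protos = protos.at[pred].add(-wrong * learning_rate * x)
        return protos, wrong

    prototypes, mistakes = jax.lax.scan(body, prototypes, (train_hvs, train_labels))
    return prototypes, mistakes.sum()
```

Correctly classified samples leave the prototypes untouched, so repeating refine_epoch for several epochs concentrates updates on the hard examples; the returned mistake count can serve as a stopping signal.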