Source code for jax_hdc.models

"""Classification and learning models for Hyperdimensional Computing."""

from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp

from jax_hdc import functional as F
from jax_hdc._compat import register_dataclass
from jax_hdc.constants import EPS
from jax_hdc.vsa import VSAModel, create_vsa_model


@register_dataclass
@dataclass
class CentroidClassifier:
    """Centroid-based classifier for HDC.

    Stores one prototype hypervector per class. Classification finds the
    most similar prototype to the query.

    Training methods (``fit``, ``update_online``) do not modify ``self``;
    they return a new classifier built with :meth:`replace`.
    """

    prototypes: jax.Array  # (num_classes, dimensions)
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        initial_prototypes: Optional[jax.Array] = None,
        key: Optional[jax.Array] = None,
    ) -> "CentroidClassifier":
        """Create a centroid classifier.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            initial_prototypes: Optional initial prototypes of shape
                (num_classes, dimensions)
            key: JAX random key for initialization. Only used when
                ``initial_prototypes`` is None; defaults to ``PRNGKey(0)``.

        Returns:
            A new CentroidClassifier.

        Raises:
            ValueError: If ``initial_prototypes`` is given but its shape is
                not ``(num_classes, dimensions)``.
        """
        if isinstance(vsa_model, str):
            vsa_model_name = vsa_model
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa_model_name = vsa_model.name
            vsa = vsa_model

        if initial_prototypes is not None:
            # The static num_classes/dimensions fields are set independently
            # of the array, so reject a mismatch instead of building an
            # internally inconsistent classifier.
            expected_shape = (num_classes, dimensions)
            actual_shape = tuple(jnp.shape(initial_prototypes))
            if actual_shape != expected_shape:
                raise ValueError(
                    f"initial_prototypes must have shape {expected_shape}, "
                    f"got {actual_shape}"
                )
            prototypes = initial_prototypes
        else:
            if key is None:
                key = jax.random.PRNGKey(0)
            prototypes = vsa.random(key, shape=(num_classes, dimensions))

        return CentroidClassifier(
            prototypes=prototypes,
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa_model_name,
        )

    @jax.jit
    def similarity(self, query: jax.Array) -> jax.Array:
        """Compute similarity between query and all class prototypes.

        Uses ``F.hamming_similarity`` when the VSA model is "bsc" and
        ``F.cosine_similarity`` otherwise, mapped over the prototype rows.

        Args:
            query: A single query hypervector.

        Returns:
            One similarity value per class prototype.
        """
        # vsa_model_name is a static field, so this branch is resolved at
        # trace time rather than inside the compiled computation.
        if self.vsa_model_name == "bsc":
            return jax.vmap(lambda p: F.hamming_similarity(query, p))(self.prototypes)
        return jax.vmap(lambda p: F.cosine_similarity(query, p))(self.prototypes)

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels for queries.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Predicted class indices; a scalar for a single 1-D query,
            otherwise one index per query.
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]

        similarities = jax.vmap(self.similarity)(queries)
        predictions = jnp.argmax(similarities, axis=-1)

        if is_single:
            return predictions[0]
        return predictions

    @jax.jit
    def predict_proba(self, queries: jax.Array) -> jax.Array:
        """Predict class probabilities using softmax of similarities.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Softmax over the per-class similarities; the leading batch axis
            is dropped for a single 1-D query.
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]

        similarities = jax.vmap(self.similarity)(queries)
        probs = jax.nn.softmax(similarities, axis=-1)

        if is_single:
            return probs[0]
        return probs

    def fit(self, train_hvs: jax.Array, train_labels: jax.Array) -> "CentroidClassifier":
        """Train classifier by computing class centroids.

        For "bsc" the centroid is the element-wise strict majority of the
        class members; otherwise it is the sum of the class members divided
        by its norm (plus ``EPS``). Classes with no training samples keep
        their current prototype.

        Args:
            train_hvs: Training hypervectors of shape (n_samples, dimensions)
            train_labels: Training labels of shape (n_samples,)

        Returns:
            Trained CentroidClassifier (new instance)

        Raises:
            ValueError: If the training data is empty, or if the number of
                labels does not match the number of hypervectors.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit CentroidClassifier: training data is empty")
        # A length-1 label array would otherwise broadcast silently against
        # every sample and produce wrong centroids.
        if jnp.shape(train_labels)[:1] != (train_hvs.shape[0],):
            raise ValueError(
                "Cannot fit CentroidClassifier: got "
                f"{train_hvs.shape[0]} hypervectors but labels of shape "
                f"{tuple(jnp.shape(train_labels))}"
            )

        new_prototypes_list = []
        for class_idx in range(self.num_classes):
            class_mask = train_labels == class_idx
            num_samples = jnp.sum(class_mask)

            if num_samples > 0:
                # 1.0 for rows belonging to this class, 0.0 elsewhere.
                weights = jnp.where(class_mask[:, None], 1.0, 0.0)
                if self.vsa_model_name == "bsc":
                    # Cast before arithmetic, then take a strict majority vote.
                    weighted_hvs = train_hvs.astype(jnp.float32) * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed > (num_samples / 2.0)
                else:
                    weighted_hvs = train_hvs * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    # EPS guards against division by a zero norm.
                    centroid = summed / (jnp.linalg.norm(summed) + EPS)
                new_prototypes_list.append(centroid)
            else:
                new_prototypes_list.append(self.prototypes[class_idx])

        return self.replace(prototypes=jnp.stack(new_prototypes_list))

    def update_online(
        self, sample_hv: jax.Array, label: int, learning_rate: float = 0.1
    ) -> "CentroidClassifier":
        """Update classifier online with a single sample.

        Args:
            sample_hv: The new sample hypervector.
            label: Index of the class prototype to update.
            learning_rate: Interpolation weight given to the sample in the
                non-"bsc" update. Not used for "bsc", where the old
                prototype and the sample are combined with ``F.bundle_bsc``.

        Returns:
            A new CentroidClassifier with the updated prototype.
        """
        old_prototype = self.prototypes[label]

        if self.vsa_model_name == "bsc":
            combined = jnp.stack([old_prototype, sample_hv])
            new_prototype = F.bundle_bsc(combined, axis=0)
        else:
            new_prototype = (1 - learning_rate) * old_prototype + learning_rate * sample_hv
            new_prototype = new_prototype / (jnp.linalg.norm(new_prototype) + EPS)

        return self.replace(prototypes=self.prototypes.at[label].set(new_prototype))

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy on test data.

        Returns:
            Mean of ``predict(test_hvs) == test_labels``.
        """
        predictions = self.predict(test_hvs)
        return jnp.mean(predictions == test_labels)

    def replace(self, **updates: Any) -> "CentroidClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class AdaptiveHDC:
    """Adaptive HDC classifier with iterative prototype refinement.

    ``fit`` first initializes prototypes from class centroids and then makes
    per-sample corrections whenever a training sample is misclassified.
    Training does not modify ``self``; a new classifier is returned.
    """

    prototypes: jax.Array  # (num_classes, dimensions)
    num_updates: jax.Array  # (num_classes,) count of corrective updates per class
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        key: Optional[jax.Array] = None,
    ) -> "AdaptiveHDC":
        """Create an adaptive classifier with random prototypes.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            key: JAX random key for initialization; defaults to
                ``PRNGKey(0)``.

        Returns:
            A new AdaptiveHDC with all update counters at zero.
        """
        if isinstance(vsa_model, str):
            vsa_model_name = vsa_model
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa_model_name = vsa_model.name
            vsa = vsa_model

        if key is None:
            key = jax.random.PRNGKey(0)

        return AdaptiveHDC(
            prototypes=vsa.random(key, shape=(num_classes, dimensions)),
            num_updates=jnp.zeros(num_classes, dtype=jnp.int32),
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa_model_name,
        )

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Index of the most similar prototype per query (a scalar for a
            single 1-D query). Similarity is ``F.hamming_similarity`` for
            "bsc" and ``F.cosine_similarity`` otherwise.
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]

        if self.vsa_model_name == "bsc":
            similarities = jax.vmap(
                lambda q: jax.vmap(lambda p: F.hamming_similarity(q, p))(self.prototypes)
            )(queries)
        else:
            similarities = jax.vmap(
                lambda q: jax.vmap(lambda p: F.cosine_similarity(q, p))(self.prototypes)
            )(queries)

        predictions = jnp.argmax(similarities, axis=-1)

        if is_single:
            return predictions[0]
        return predictions

    def fit(
        self,
        train_hvs: jax.Array,
        train_labels: jax.Array,
        epochs: int = 1,
        learning_rate: float = 0.1,
    ) -> "AdaptiveHDC":
        """Train with iterative refinement.

        Args:
            train_hvs: Training hypervectors
            train_labels: Training labels
            epochs: Number of training epochs
            learning_rate: Learning rate for updates

        Returns:
            Trained AdaptiveHDC (new instance).

        Raises:
            ValueError: If the training data is empty.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit AdaptiveHDC: training data is empty")

        classifier = self

        # Phase 1: start every class that has samples from its centroid.
        for class_idx in range(self.num_classes):
            class_mask = train_labels == class_idx
            num_samples = jnp.sum(class_mask)

            if num_samples > 0:
                weights = jnp.where(class_mask[:, None], 1.0, 0.0)
                if self.vsa_model_name == "bsc":
                    # Strict element-wise majority over the class members.
                    weighted_hvs = train_hvs.astype(jnp.float32) * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed > (num_samples / 2.0)
                else:
                    weighted_hvs = train_hvs * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed / (jnp.linalg.norm(summed) + EPS)

                classifier = classifier.replace(
                    prototypes=classifier.prototypes.at[class_idx].set(centroid)
                )

        # Phase 2: sequential corrections; each update is visible to the
        # prediction made for the next sample.
        for _epoch in range(epochs):
            for i in range(len(train_hvs)):
                pred = classifier.predict(train_hvs[i])
                true_label = train_labels[i]

                if pred != true_label:
                    classifier = classifier._update_prototypes(
                        train_hvs[i], true_label, pred, learning_rate
                    )

        return classifier

    def _update_prototypes(
        self,
        sample_hv: jax.Array,
        true_label: Union[int, jax.Array],
        pred_label: Union[int, jax.Array],
        learning_rate: float,
    ) -> "AdaptiveHDC":
        """Pull the true class prototype toward a misclassified sample.

        Args:
            sample_hv: The misclassified sample.
            true_label: Index of the correct class; its prototype is updated.
            pred_label: Index of the wrongly predicted class. Currently not
                used by the update; kept so the call signature is stable.
            learning_rate: Scale applied to the sample in the non-"bsc"
                update. Not used for "bsc".

        Returns:
            A new AdaptiveHDC with the updated prototype and with
            ``num_updates[true_label]`` incremented by one.
        """
        true_proto = self.prototypes[true_label]

        if self.vsa_model_name != "bsc":
            new_true_proto = true_proto + learning_rate * sample_hv
            new_true_proto = new_true_proto / (jnp.linalg.norm(new_true_proto) + EPS)
        else:
            new_true_proto = F.bundle_bsc(jnp.stack([true_proto, sample_hv]), axis=0)

        return self.replace(
            prototypes=self.prototypes.at[true_label].set(new_true_proto),
            # Record the correction; this counter was previously never written.
            num_updates=self.num_updates.at[true_label].add(1),
        )

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy.

        Returns:
            Mean of ``predict(test_hvs) == test_labels``.
        """
        predictions = self.predict(test_hvs)
        return jnp.mean(predictions == test_labels)

    def replace(self, **updates: Any) -> "AdaptiveHDC":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class LVQClassifier:
    """Learning Vector Quantization classifier.

    Prototypes are updated: move winner toward sample if correct, away if wrong.
    Training does not modify ``self``; a new classifier is returned.
    """

    prototypes: jax.Array  # (num_classes, dimensions)
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        key: Optional[jax.Array] = None,
    ) -> "LVQClassifier":
        """Create an LVQ classifier with random prototypes.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            key: JAX random key for initialization; defaults to
                ``PRNGKey(0)``.

        Returns:
            A new LVQClassifier whose ``vsa_model_name`` is taken from the
            resolved model's ``name`` attribute.
        """
        if isinstance(vsa_model, str):
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa = vsa_model

        if key is None:
            key = jax.random.PRNGKey(0)

        return LVQClassifier(
            prototypes=vsa.random(key, (num_classes, dimensions)),
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa.name,
        )

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels by nearest prototype.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Index of the most similar prototype per query (a scalar for a
            single 1-D query). Similarity is ``F.hamming_similarity`` for
            "bsc" and ``F.cosine_similarity`` otherwise.
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]

        if self.vsa_model_name == "bsc":
            sims = jax.vmap(
                lambda q: jax.vmap(lambda p: F.hamming_similarity(q, p))(self.prototypes)
            )(queries)
        else:
            sims = jax.vmap(
                lambda q: jax.vmap(lambda p: F.cosine_similarity(q, p))(self.prototypes)
            )(queries)

        preds = jnp.argmax(sims, axis=-1)
        return preds[0] if is_single else preds

    def fit(
        self,
        train_hvs: jax.Array,
        train_labels: jax.Array,
        epochs: int = 10,
        lr: float = 0.1,
    ) -> "LVQClassifier":
        """Train with LVQ updates (winner-take-all, move toward/away).

        For every sample, only the winning (predicted) prototype is changed:
        it moves toward the sample when the prediction is correct and away
        from it otherwise.

        Args:
            train_hvs: Training hypervectors
            train_labels: Training labels
            epochs: Number of passes over the training data
            lr: Step size of each prototype move

        Returns:
            Trained LVQClassifier (new instance).

        Raises:
            ValueError: If the training data is empty.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit LVQClassifier: training data is empty")

        clf = self
        is_bsc = self.vsa_model_name == "bsc"

        for _ in range(epochs):
            for i in range(len(train_hvs)):
                x, y_true = train_hvs[i], int(train_labels[i])
                pred = int(clf.predict(x))
                winner = clf.prototypes[pred]

                if not is_bsc:
                    if pred == y_true:
                        delta = lr * (x - winner)
                    else:
                        delta = -lr * (x - winner)
                    new_p = winner + delta
                    new_p = new_p / (jnp.linalg.norm(new_p) + EPS)
                else:
                    # Do the move in float32: JAX rejects subtraction on
                    # boolean arrays, so subtracting the raw hypervectors
                    # fails when they are stored as bool.
                    winner_f = winner.astype(jnp.float32)
                    diff = x.astype(jnp.float32) - winner_f
                    if pred == y_true:
                        delta = lr * diff
                    else:
                        delta = -lr * diff
                    # Threshold the moved prototype back to binary, then
                    # bundle it with the current winner.
                    moved = (winner_f + delta) > 0.5
                    new_p = F.bundle_bsc(
                        jnp.stack([winner, moved.astype(winner.dtype)]),
                        axis=0,
                    )

                clf = clf.replace(prototypes=clf.prototypes.at[pred].set(new_p))

        return clf

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy as the mean of ``predict(test_hvs) == test_labels``."""
        preds = self.predict(test_hvs)
        return jnp.mean(preds == test_labels)

    def replace(self, **updates: Any) -> "LVQClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class RegularizedLSClassifier:
    """Regularized Least Squares classifier in HV space.

    Solves (X^T X + lambda I) W = X^T Y for weights W, where Y is the
    one-hot encoding of the training labels and lambda is ``reg``.
    """

    weights: jax.Array  # (dimensions, num_classes)
    dimensions: int = field(metadata=dict(static=True))
    num_classes: int = field(metadata=dict(static=True))
    reg: float = field(metadata=dict(static=True))

    @staticmethod
    def create(
        dimensions: int,
        num_classes: int,
        reg: float = 1e-4,
    ) -> "RegularizedLSClassifier":
        """Create an untrained classifier with all-zero weights.

        Args:
            dimensions: Dimensionality of hypervectors
            num_classes: Number of classes
            reg: Regularization strength (lambda)

        Returns:
            A new RegularizedLSClassifier.
        """
        return RegularizedLSClassifier(
            weights=jnp.zeros((dimensions, num_classes)),
            dimensions=dimensions,
            num_classes=num_classes,
            reg=reg,
        )

    def fit(self, train_hvs: jax.Array, train_labels: jax.Array) -> "RegularizedLSClassifier":
        """Fit by solving regularized least squares.

        Args:
            train_hvs: Training hypervectors of shape (n_samples, dimensions).
                Non-floating inputs (e.g. bool or integer) are promoted to
                float32 before the normal equations are formed.
            train_labels: Integer class labels, one per training sample.

        Returns:
            Trained RegularizedLSClassifier (new instance).

        Raises:
            ValueError: If the training data is empty.
        """
        n = train_hvs.shape[0]
        if n == 0:
            raise ValueError("Cannot fit RegularizedLSClassifier: training data is empty")

        # Build X^T X in floating point: a bool matmul is a logical
        # reduction and a narrow-integer matmul can overflow, so either
        # would give a wrong Gram matrix.
        features = train_hvs
        if not jnp.issubdtype(features.dtype, jnp.inexact):
            features = features.astype(jnp.float32)

        Y = jax.nn.one_hot(train_labels, self.num_classes)
        XtX = features.T @ features + self.reg * jnp.eye(self.dimensions)
        XtY = features.T @ Y
        weights, *_ = jnp.linalg.lstsq(XtX, XtY, rcond=None)
        return self.replace(weights=weights)

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels as the argmax of ``queries @ weights``.

        Args:
            queries: Query hypervectors whose last axis has size ``dimensions``.

        Returns:
            Predicted class indices along the last axis of the logits.
        """
        logits = queries @ self.weights
        return jnp.argmax(logits, axis=-1)

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy as the mean of ``predict(test_hvs) == test_labels``."""
        preds = self.predict(test_hvs)
        return jnp.mean(preds == test_labels)

    def replace(self, **updates: Any) -> "RegularizedLSClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
# Public API of this module: the classifier classes defined above.
__all__ = [
    "CentroidClassifier",
    "AdaptiveHDC",
    "LVQClassifier",
    "RegularizedLSClassifier",
]