API Reference

Main package exports:

JAX-HDC: Hyperdimensional Computing with JAX.

jax_hdc.bind_bsc(x: Array, y: Array) → Array

Bind two hypervectors using XOR for Binary Spatter Codes.

Binding creates a new hypervector that is dissimilar to both inputs.

Parameters:
  • x – Binary hypervector of shape (…, d)

  • y – Binary hypervector of shape (…, d)

Returns:

Bound hypervector of shape (…, d), dissimilar to both x and y
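A minimal sketch of XOR binding written directly in jax.numpy, assuming hypervectors are stored as boolean arrays. The helpers bind_xor and hamming_sim are hypothetical stand-ins, not the library's bind_bsc and hamming_similarity. The sketch shows that the bound vector sits at chance similarity (about 0.5) to its input, and that binding again with y recovers x exactly. This self-inverse property is why inverse_bsc can be the identity.

```python
import jax
import jax.numpy as jnp


def bind_xor(x, y):
    # BSC binding: element-wise XOR of two boolean hypervectors
    return jnp.logical_xor(x, y)


def hamming_sim(x, y):
    # Fraction of positions where the bits agree (random pairs give about 0.5)
    return jnp.mean(x == y, axis=-1)


kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.bernoulli(kx, 0.5, (10_000,))
y = jax.random.bernoulli(ky, 0.5, (10_000,))

z = bind_xor(x, y)
x_back = bind_xor(z, y)  # XOR is self-inverse, so unbinding reuses bind

print(float(hamming_sim(z, x)))  # near 0.5: z is dissimilar to x
print(bool(jnp.all(x_back == x)))  # True
```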

jax_hdc.bundle_bsc(vectors: Array, axis: int = 0) → Array

Bundle hypervectors using majority rule for Binary Spatter Codes.

Bundling creates a new hypervector similar to all inputs by taking the majority vote at each dimension.

Parameters:
  • vectors – Binary hypervectors; the dimension given by axis indexes the vectors to bundle

  • axis – Axis along which to bundle (default: 0)

Returns:

Bundled hypervector, similar to all inputs
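The majority rule can be sketched as follows, assuming boolean hypervectors stacked along axis 0. bundle_majority is a hypothetical helper. How the library breaks ties for an even number of inputs is not documented here, so this sketch simply resolves ties to 0.

```python
import jax
import jax.numpy as jnp


def bundle_majority(vectors, axis=0):
    # Count the ones along the bundling axis; keep a bit if more than half the inputs set it.
    # Assumption: ties (possible for an even number of inputs) resolve to 0 in this sketch.
    n = vectors.shape[axis]
    ones = jnp.sum(vectors, axis=axis)
    return 2 * ones > n


keys = jax.random.split(jax.random.PRNGKey(1), 5)
vectors = jnp.stack([jax.random.bernoulli(k, 0.5, (10_000,)) for k in keys])  # (5, d)
bundled = bundle_majority(vectors, axis=0)  # (d,)

# Hamming similarity of each input to the bundle
sims = jnp.mean(vectors == bundled, axis=-1)
print(sims)
```

With five inputs, a bundle bit matches a given input whenever at least two of the other four inputs agree with it. That happens with probability 11/16, so each similarity comes out near 0.69 rather than the chance level of 0.5.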

jax_hdc.inverse_bsc(x: Array) → Array

Compute inverse for BSC (identity since XOR is self-inverse).

jax_hdc.hamming_similarity(x: Array, y: Array) → Array

Compute normalized Hamming similarity between binary hypervectors.

Returns the fraction of matching bits between two binary vectors. Random vectors have similarity ≈ 0.5.

Parameters:
  • x – Binary hypervector of shape (…, d)

  • y – Binary hypervector of shape (…, d)

Returns:

Similarity score in [0, 1], where 1 is identical and 0.5 is random

jax_hdc.bind_map(x: Array, y: Array) → Array

Bind two hypervectors using element-wise multiplication for MAP.

For real-valued vectors (MAP model), binding is element-wise multiplication. The result is dissimilar to both inputs.

Parameters:
  • x – Real-valued hypervector of shape (…, d)

  • y – Real-valued hypervector of shape (…, d)

Returns:

Bound hypervector of shape (…, d)
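A short sketch of MAP binding in jax.numpy. It assumes bipolar (+1/-1) hypervectors, a common choice for MAP. MAP.random is documented to sample normalized Gaussian vectors instead, and in that case unbinding needs the reciprocal-based inverse_map. bipolar and cos_sim are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def bipolar(key, shape):
    # Random +1/-1 entries (illustrative; MAP.random is documented to sample Gaussians)
    return jnp.where(jax.random.bernoulli(key, 0.5, shape), 1.0, -1.0)


def cos_sim(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


kx, ky = jax.random.split(jax.random.PRNGKey(2))
x = bipolar(kx, (10_000,))
y = bipolar(ky, (10_000,))

z = x * y  # MAP binding: element-wise product
print(float(cos_sim(z, x)))  # near 0: z is dissimilar to x
print(bool(jnp.all(z * y == x)))  # True: each +/-1 entry is its own inverse
```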

jax_hdc.bundle_map(vectors: Array, axis: int = 0) → Array

Bundle hypervectors using normalized sum for MAP.

For real-valued vectors, bundling is the normalized sum. The result is similar to all inputs (high cosine similarity).

Parameters:
  • vectors – Real-valued hypervectors; the dimension given by axis indexes the vectors to bundle

  • axis – Axis along which to bundle (default: 0)

Returns:

Bundled and normalized hypervector
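A sketch of normalized-sum bundling. It assumes that "normalized" means rescaling the sum to unit L2 norm, since the exact normalization used by bundle_map is not stated here. bundle_normalized and cos_sim are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def bundle_normalized(vectors, axis=0):
    # Sum along the bundling axis, then rescale to unit L2 norm
    # (assumption: "normalized" means unit Euclidean length)
    total = jnp.sum(vectors, axis=axis)
    return total / jnp.linalg.norm(total, axis=-1, keepdims=True)


def cos_sim(a, b):
    return jnp.sum(a * b, axis=-1) / (
        jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    )


keys = jax.random.split(jax.random.PRNGKey(3), 4)
vectors = jnp.stack([jax.random.normal(k, (10_000,)) for k in keys])  # (4, d)
bundled = bundle_normalized(vectors)

print(cos_sim(vectors, bundled))  # each entry roughly 1 / sqrt(4) = 0.5
```

For n roughly orthogonal inputs of equal norm, each input's cosine with the bundle is about 1/sqrt(n). The bundle therefore stays recognizably similar to every input, but that similarity shrinks as more vectors are added.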

jax_hdc.inverse_map(x: Array, eps: float = 1e-08) → Array

Compute inverse for MAP using element-wise reciprocal.

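For MAP binding (element-wise multiplication), the inverse is the element-wise reciprocal, so bind(bind(x, y), inverse(y)) = x wherever y has no near-zero elements. Near-zero elements are returned as 0 because they have no usable reciprocal; binding with 0 destroys the information at that position.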

Parameters:
  • x – Real-valued hypervector of shape (…, d)

  • eps – Small constant for numerical stability (default: 1e-08)

Returns:

Inverse hypervector
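The reciprocal inverse with a near-zero guard can be sketched as below. The threshold rule (entries with |x| <= eps map to 0) is an assumption about how eps is used, and inverse_reciprocal is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def inverse_reciprocal(x, eps=1e-8):
    # Element-wise reciprocal; entries with |x| <= eps map to 0 instead of +/-inf
    usable = jnp.abs(x) > eps
    return jnp.where(usable, 1.0 / jnp.where(usable, x, 1.0), 0.0)


kx, ky = jax.random.split(jax.random.PRNGKey(4))
x = jax.random.normal(kx, (8,))
y = jax.random.normal(ky, (8,)).at[0].set(0.0)  # force one element with no inverse

z = x * y                          # bind (element-wise product)
x_hat = z * inverse_reciprocal(y)  # unbind with the reciprocal
print(x_hat - x)  # ~0 except at index 0, where the information was destroyed
```

The inner jnp.where replaces unusable entries before dividing. This is a common JAX idiom that avoids evaluating 1/0, and it also keeps gradients finite.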

jax_hdc.cosine_similarity(x: Array, y: Array) → Array

Compute cosine similarity between real-valued hypervectors.

Returns the cosine of the angle between two vectors. Random unit vectors have similarity ≈ 0.

Parameters:
  • x – Real-valued hypervector of shape (…, d)

  • y – Real-valued hypervector of shape (…, d)

Returns:

Similarity score in [-1, 1], where 1 is identical, -1 is opposite, and 0 is orthogonal

jax_hdc.permute(x: Array, shifts: int = 1) → Array

Cyclically permute a hypervector to encode sequence information.

Permutation reorders elements to represent positional or sequential information. Cyclic shifts preserve the distribution of values.

Parameters:
  • x – Hypervector of shape (…, d)

  • shifts – Number of positions to shift (default: 1)

Returns:

Permuted hypervector of shape (…, d)
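A sketch of permutation-based sequence encoding. It assumes permute behaves like jnp.roll along the last axis; the shift direction is an assumption. Each item is shifted by its position and the results are summed, so the same items in a different order produce a different hypervector. The item vectors a, b and c are hypothetical.

```python
import jax
import jax.numpy as jnp


def permute(x, shifts=1):
    # Cyclic shift along the hypervector (last) axis
    return jnp.roll(x, shifts, axis=-1)


def cos_sim(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


# Hypothetical item vectors for three symbols
ka, kb, kc = jax.random.split(jax.random.PRNGKey(5), 3)
a, b, c = (jax.random.normal(k, (10_000,)) for k in (ka, kb, kc))

# Shift each item by its position in the sequence, then bundle by summing
seq_abc = permute(a, 0) + permute(b, 1) + permute(c, 2)
seq_cba = permute(c, 0) + permute(b, 1) + permute(a, 2)

print(float(cos_sim(seq_abc, seq_cba)))  # roughly 1/3
print(bool(jnp.all(permute(permute(a, 3), -3) == a)))  # True: shifts are invertible
```

Only b occupies the same position in both sequences, so the two encodings share roughly one third of their content.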

jax_hdc.cleanup(query: Array, memory: Array, similarity_fn: Callable[[Array, Array], Array] = cosine_similarity, return_similarity: bool = False) → Array | tuple[Array, Array]

Find the most similar vector in memory to the query.

Cleanup retrieves the closest known hypervector from memory. It is used for error correction and symbol retrieval.

Parameters:
  • query – Query hypervector of shape (…, d)

  • memory – Memory hypervectors of shape (n, d)

  • similarity_fn – Function to compute similarity (default: cosine_similarity)

  • return_similarity – Whether to return similarity scores (default: False)

Returns:

Most similar vector from memory, or (vector, similarity) if return_similarity=True
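A minimal cleanup sketch for a single query of shape (d,). It computes the cosine similarity against every row of memory and returns the best match. cleanup_cosine is a hypothetical helper. The library's cleanup also accepts batched queries and a custom similarity_fn, which this sketch omits.

```python
import jax
import jax.numpy as jnp


def cleanup_cosine(query, memory):
    # Cosine similarity of a single query (d,) against every memory row (n, d)
    sims = memory @ query / (jnp.linalg.norm(memory, axis=-1) * jnp.linalg.norm(query))
    best = jnp.argmax(sims)
    return memory[best], sims[best]


k_mem, k_noise = jax.random.split(jax.random.PRNGKey(6))
memory = jax.random.normal(k_mem, (100, 10_000))
noisy = memory[42] + 1.5 * jax.random.normal(k_noise, (10_000,))  # corrupted item 42

vec, sim = cleanup_cosine(noisy, memory)
print(bool(jnp.all(vec == memory[42])))  # True
print(float(sim))  # about 1 / sqrt(1 + 1.5**2), i.e. roughly 0.55
```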

class jax_hdc.BSC(name: str, dimensions: int)

Bases: VSAModel

Binary Spatter Codes (BSC).

Binary hypervectors with XOR binding, majority bundling, Hamming similarity.

static create(dimensions: int = 10000) → BSC

Create a BSC model.

Parameters:

dimensions – Dimensionality of hypervectors (default: 10000)

Returns:

Initialized BSC model

bind(x: Array, y: Array) → Array

Bind using XOR.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using majority rule.

inverse(x: Array) → Array

Inverse is identity for XOR.

similarity(x: Array, y: Array) → Array

Compute Hamming similarity.

random(key: Array, shape: tuple) → Array

Generate random binary hypervectors.

Parameters:
  • key – JAX random key

  • shape – Shape of output array

Returns:

Random binary hypervectors with ~50% ones

__init__(name: str, dimensions: int) → None

class jax_hdc.BSBC(name: str, dimensions: int, block_size: int = 100, k_active: int = 5)

Bases: VSAModel

Binary Sparse Block Codes (B-SBC).

Block-sparse binary vectors with k_active ones per block, XOR binding, majority bundling.
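One way to generate block-sparse vectors with exactly k_active ones per block. It assumes the active positions in each block are drawn uniformly without replacement; the library's sampling scheme is not documented here. random_block_sparse is a hypothetical helper that ranks uniform scores within each block and keeps the top k_active.

```python
import jax
import jax.numpy as jnp


def random_block_sparse(key, num_vectors, dimensions=10_000, block_size=100, k_active=5):
    # Assumption: within each block, k_active positions are chosen uniformly without replacement
    num_blocks = dimensions // block_size
    scores = jax.random.uniform(key, (num_vectors, num_blocks, block_size))
    # Rank the scores within each block (0 = largest); the top k_active ranks become ones
    ranks = jnp.argsort(jnp.argsort(-scores, axis=-1), axis=-1)
    return (ranks < k_active).reshape(num_vectors, dimensions)


hvs = random_block_sparse(jax.random.PRNGKey(7), num_vectors=3)
ones_per_block = hvs.reshape(3, -1, 100).sum(axis=-1)
print(hvs.shape, bool(jnp.all(ones_per_block == 5)))  # (3, 10000) True
```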

block_size: int = 100
k_active: int = 5
static create(dimensions: int = 10000, block_size: int = 100, k_active: int = 5) → BSBC

Create a B-SBC model.

Parameters:
  • dimensions – Total dimensionality, which must be divisible by block_size (default: 10000)

  • block_size – Size of each block (default: 100)

  • k_active – Number of ones per block, which sets the sparsity (default: 5)

Returns:

Initialized BSBC model

bind(x: Array, y: Array) → Array

Bind using XOR (same as BSC).

bundle(vectors: Array, axis: int = 0) → Array

Bundle using majority rule.

inverse(x: Array) → Array

Inverse is identity for XOR.

similarity(x: Array, y: Array) → Array

Compute Hamming similarity.

random(key: Array, shape: tuple) → Array

Generate random block-sparse binary hypervectors.

__init__(name: str, dimensions: int, block_size: int = 100, k_active: int = 5) → None

class jax_hdc.MAP(name: str, dimensions: int)

Bases: VSAModel

Multiply-Add-Permute (MAP).

Real-valued vectors with element-wise multiply binding, normalized sum bundling, cosine similarity.

static create(dimensions: int = 10000) → MAP

Create a MAP model.

Parameters:

dimensions – Dimensionality of hypervectors (default: 10000)

Returns:

Initialized MAP model

bind(x: Array, y: Array) → Array

Bind using element-wise multiplication.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using normalized sum.

inverse(x: Array) → Array

Inverse via element-wise reciprocal.

similarity(x: Array, y: Array) → Array

Compute cosine similarity.

random(key: Array, shape: tuple) → Array

Generate random real-valued hypervectors.

Parameters:
  • key – JAX random key

  • shape – Shape of output array

Returns:

Random normalized hypervectors sampled from a normal distribution

__init__(name: str, dimensions: int) → None

class jax_hdc.HRR(name: str, dimensions: int)

Bases: VSAModel

Holographic Reduced Representations (HRR).

Real-valued vectors with circular convolution binding, normalized sum bundling, cosine similarity.
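A sketch of HRR binding via the FFT, which computes circular convolution in O(d log d). It assumes that "element reversal" refers to the involution x_inv[i] = x[-i mod d] and that elements are drawn from N(0, 1/d). circ_conv, involution and cos_sim are hypothetical helpers, not the library's code.

```python
import jax
import jax.numpy as jnp


def circ_conv(x, y):
    # Circular convolution via FFT: conv(x, y)[k] = sum_i x[i] * y[(k - i) mod d]
    d = x.shape[-1]
    return jnp.fft.irfft(jnp.fft.rfft(x) * jnp.fft.rfft(y), n=d)


def involution(x):
    # Approximate HRR inverse: x_inv[i] = x[-i mod d] (keep x[0], reverse the rest)
    return jnp.roll(jnp.flip(x, axis=-1), 1, axis=-1)


def cos_sim(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


d = 10_000
kx, ky = jax.random.split(jax.random.PRNGKey(8))
x = jax.random.normal(kx, (d,)) / jnp.sqrt(d)  # elements ~ N(0, 1/d)
y = jax.random.normal(ky, (d,)) / jnp.sqrt(d)

z = circ_conv(x, y)                  # bind
x_hat = circ_conv(z, involution(y))  # unbind (approximate)

print(float(cos_sim(z, x)))      # near 0
print(float(cos_sim(x_hat, x)))  # roughly 0.7: a noisy copy of x
```

Unbinding with the involution is only approximate: the recovered vector is a noisy copy of x. This is why HRR unbinding is typically followed by cleanup.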

static create(dimensions: int = 10000) → HRR

Create an HRR model.

Parameters:

dimensions – Dimensionality of hypervectors (default: 10000)

Returns:

Initialized HRR model

bind(x: Array, y: Array) → Array

Bind using circular convolution.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using normalized sum.

inverse(x: Array) → Array

Approximate inverse via element reversal; unbinding with it recovers the original vector only approximately.

similarity(x: Array, y: Array) → Array

Compute cosine similarity.

random(key: Array, shape: tuple) → Array

Generate random real-valued hypervectors.

Parameters:
  • key – JAX random key

  • shape – Shape of output array

Returns:

Random normalized hypervectors sampled from a normal distribution

__init__(name: str, dimensions: int) → None

class jax_hdc.FHRR(name: str, dimensions: int)

Bases: VSAModel

Fourier Holographic Reduced Representations (FHRR).

Complex-valued vectors with element-wise multiply binding, normalized sum bundling.
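A sketch of FHRR with unit-magnitude complex entries (phasors). It assumes the random vectors have uniformly distributed phases, matching the description of random below. Binding multiplies element-wise, so phases add; multiplying by the complex conjugate subtracts them again, which makes unbinding exact. random_phasors and fhrr_sim are hypothetical helpers. fhrr_sim uses the mean of Re(x * conj(y)), which equals the cosine similarity for unit-magnitude vectors.

```python
import jax
import jax.numpy as jnp


def random_phasors(key, shape):
    # Unit-magnitude complex entries with phases uniform in [-pi, pi)
    theta = jax.random.uniform(key, shape, minval=-jnp.pi, maxval=jnp.pi)
    return jnp.exp(1j * theta)


def fhrr_sim(x, y):
    # Mean of Re(x * conj(y)); equals cosine similarity for unit-magnitude entries
    return jnp.mean(jnp.real(x * jnp.conj(y)), axis=-1)


kx, ky = jax.random.split(jax.random.PRNGKey(9))
x = random_phasors(kx, (10_000,))
y = random_phasors(ky, (10_000,))

z = x * y                # bind: element-wise product, phases add
x_hat = z * jnp.conj(y)  # unbind with the complex conjugate, phases subtract

print(float(fhrr_sim(z, x)))      # near 0
print(float(fhrr_sim(x_hat, x)))  # 1 up to rounding error
```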

static create(dimensions: int = 10000) → FHRR

Create an FHRR model.

Parameters:

dimensions – Dimensionality of hypervectors (default: 10000)

Returns:

Initialized FHRR model

bind(x: Array, y: Array) → Array

Bind using element-wise multiplication.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using normalized sum.

inverse(x: Array) → Array

Inverse via complex conjugate.

similarity(x: Array, y: Array) → Array

Compute cosine similarity of complex vectors.

random(key: Array, shape: tuple) → Array

Generate random complex hypervectors with entries on the unit circle.

Parameters:
  • key – JAX random key

  • shape – Shape of output array

Returns:

Random unit complex hypervectors

__init__(name: str, dimensions: int) → None

class jax_hdc.CGR(name: str, dimensions: int, q: int = 8)

Bases: VSAModel

Cyclic Group Representation (CGR).

Integer hypervectors in Z_q with modular addition binding, component-wise mode bundling.
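A sketch of CGR operations on integers in {0, …, q-1}. Binding adds modulo q, the inverse is the modular negation, and bundling takes the most frequent symbol at each position. bundle_mode and match_sim are hypothetical helpers. The tie-breaking rule (the smallest symbol wins, because jnp.argmax returns the first maximum) is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp

q, d = 8, 10_000
kx, ky, kw = jax.random.split(jax.random.PRNGKey(10), 3)
x = jax.random.randint(kx, (d,), 0, q)
y = jax.random.randint(ky, (d,), 0, q)
w = jax.random.randint(kw, (d,), 0, q)


def match_sim(a, b):
    # Fraction of positions where the symbols agree (chance level is 1/q)
    return jnp.mean(a == b, axis=-1)


def bundle_mode(vectors, q):
    # Component-wise mode: count each symbol per position and keep the most frequent.
    # Assumption: ties go to the smallest symbol (jnp.argmax returns the first maximum).
    counts = jnp.sum(jax.nn.one_hot(vectors, q, dtype=jnp.int32), axis=0)  # (d, q)
    return jnp.argmax(counts, axis=-1)


z = (x + y) % q             # bind: addition modulo q
x_hat = (z + (-y) % q) % q  # unbind with the modular negation of y (exact)
bundled = bundle_mode(jnp.stack([x, y, w]), q)

print(float(match_sim(z, x)))        # near 1/q = 0.125
print(bool(jnp.all(x_hat == x)))     # True
print(float(match_sim(bundled, x)))  # well above 1/q
```

With three inputs and q = 8, each input matches the bundle at about 29/64 (roughly 0.45) of positions under this tie rule, compared with 1/8 for an unrelated vector.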

q: int = 8
static create(dimensions: int = 10000, q: int = 8) → CGR
bind(x: Array, y: Array) → Array

Bind using modular addition.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using component-wise mode.

inverse(x: Array) → Array

Inverse via modular negation.

similarity(x: Array, y: Array) → Array

Compute fraction of matching elements.

random(key: Array, shape: tuple) → Array

Generate random integer hypervectors in {0, …, q-1}.

__init__(name: str, dimensions: int, q: int = 8) → None

class jax_hdc.MCR(name: str, dimensions: int, q: int = 64)

Bases: VSAModel

Modular Composite Representation (MCR).

Integer phase vectors with modular addition binding, phasor sum bundling.

q: int = 64
static create(dimensions: int = 10000, q: int = 64) → MCR
bind(x: Array, y: Array) → Array

Bind using modular addition (phase addition).

bundle(vectors: Array, axis: int = 0) → Array

Bundle using phasor sum with snap-to-grid.

inverse(x: Array) → Array

Inverse via modular negation (phase conjugate).

similarity(x: Array, y: Array) → Array

Compute phasor similarity.

random(key: Array, shape: tuple) → Array

Generate random integer hypervectors in {0, …, q-1}.

__init__(name: str, dimensions: int, q: int = 64) → None

class jax_hdc.VTB(name: str, dimensions: int)

Bases: VSAModel

Vector-Derived Transformation Binding (VTB).

Real-valued vectors with matrix multiplication binding, normalized sum bundling.

static create(dimensions: int = 10000) → VTB
bind(x: Array, y: Array) → Array

Bind using matrix multiplication.

bundle(vectors: Array, axis: int = 0) → Array

Bundle using normalized sum.

inverse(x: Array) → Array

Inverse via matrix pseudoinverse.

similarity(x: Array, y: Array) → Array

Compute cosine similarity.

random(key: Array, shape: tuple) → Array

Generate random normalized real-valued hypervectors.

__init__(name: str, dimensions: int) → None

class jax_hdc.RandomEncoder(codebook: Array, num_features: int, num_values: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Encoder using random hypervectors for discrete features.

Each unique feature value is mapped to a random hypervector from a codebook. Multiple features are bundled together to form the final representation.
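A sketch of the codebook-plus-bundling idea. It assumes the codebook holds one random vector per (feature, value) pair, with shape (num_features, num_values, d), and uses MAP-style normalized-sum bundling. encode and encode_batch mirror the method names below but are hypothetical stand-ins. Index clamping follows the documented behavior for out-of-bounds indices.

```python
import jax
import jax.numpy as jnp

num_features, num_values, d = 4, 10, 10_000
# Assumed layout: one random vector per (feature, value) pair
codebook = jax.random.normal(jax.random.PRNGKey(11), (num_features, num_values, d))


def encode(indices):
    # Clamp out-of-range indices, then look up one vector per feature
    idx = jnp.clip(indices, 0, num_values - 1)
    hvs = codebook[jnp.arange(num_features), idx]  # (num_features, d)
    total = hvs.sum(axis=0)
    return total / jnp.linalg.norm(total)  # normalized-sum bundling


encode_batch = jax.vmap(encode)  # (batch, num_features) -> (batch, d)

a = encode(jnp.array([1, 2, 3, 4]))
b = encode(jnp.array([1, 2, 3, 9]))  # differs from a in one feature
c = encode(jnp.array([5, 6, 7, 8]))  # differs from a in every feature
print(float(a @ b), float(a @ c))  # about 0.75, then about 0
```

Two samples that agree on three of four features share three of four bundled components, so their cosine is about 0.75.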

codebook: Array
num_features: int
num_values: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_features: int, num_values: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → RandomEncoder

Create a random encoder.

Parameters:
  • num_features – Number of features to encode

  • num_values – Number of possible values per feature

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance (default: ‘map’)

  • key – JAX random key (default: PRNGKey(0))

Returns:

Initialized RandomEncoder

encode(indices: Array) → Array

Encode discrete features as hypervectors.

Parameters:

indices – Feature indices of shape (num_features,) with values in [0, num_values). Out-of-bounds indices are clamped to valid range.

Returns:

Encoded hypervector of shape (dimensions,)

encode_batch(indices: Array) → Array

Encode a batch of samples.

Parameters:

indices – Batch of feature indices of shape (batch_size, num_features)

Returns:

Encoded hypervectors of shape (batch_size, dimensions)

__init__(codebook: Array, num_features: int, num_values: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.LevelEncoder(level_hvs: Array, num_levels: int, dimensions: int, min_value: float, max_value: float, vsa_model_name: str = 'map', encoding_type: str = 'linear')

Bases: object

Encoder for continuous values using level hypervectors.

Continuous values are encoded by interpolating between level hypervectors, creating a smooth representation where similar values map to similar hypervectors.
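A sketch of linear level encoding under two assumptions. First, level vectors are bipolar and built by flipping a fresh slice of positions at each step, so the first and last levels are orthogonal. Second, "interpolating" means a weighted blend of the two neighbouring levels. make_levels and encode_level are hypothetical helpers; the library's construction may differ, and circular encoding is not shown.

```python
import jax
import jax.numpy as jnp


def make_levels(key, num_levels, d):
    # Assumption: bipolar levels where each step flips a fresh set of positions,
    # so the first and last levels differ in d/2 positions (orthogonal).
    k_base, k_perm = jax.random.split(key)
    base = jnp.where(jax.random.bernoulli(k_base, 0.5, (d,)), 1.0, -1.0)
    flip_rank = jnp.argsort(jax.random.permutation(k_perm, d))  # when each position flips
    n_flipped = (jnp.arange(num_levels) * (d // 2)) // (num_levels - 1)
    flipped = flip_rank[None, :] < n_flipped[:, None]  # (num_levels, d)
    return jnp.where(flipped, -base, base)


def encode_level(value, levels, min_value=0.0, max_value=1.0):
    # Map the value to a fractional level index and blend the two neighbouring levels
    num_levels = levels.shape[0]
    t = jnp.clip((value - min_value) / (max_value - min_value), 0.0, 1.0)
    pos = t * (num_levels - 1)
    lo = jnp.floor(pos).astype(jnp.int32)
    hi = jnp.minimum(lo + 1, num_levels - 1)
    frac = pos - lo
    return (1.0 - frac) * levels[lo] + frac * levels[hi]


def cos_sim(a, b):
    return a @ b / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


levels = make_levels(jax.random.PRNGKey(12), num_levels=100, d=10_000)
near = cos_sim(encode_level(0.10, levels), encode_level(0.12, levels))
far = cos_sim(encode_level(0.0, levels), encode_level(1.0, levels))
print(float(near), float(far))  # close to 1, then 0
```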

level_hvs: Array
num_levels: int
dimensions: int
min_value: float
max_value: float
vsa_model_name: str = 'map'
encoding_type: str = 'linear'
static create(num_levels: int = 100, dimensions: int = 10000, min_value: float = 0.0, max_value: float = 1.0, vsa_model: str | VSAModel = 'map', encoding_type: str = 'linear', key: Array | None = None) → LevelEncoder

Create a level encoder.

Parameters:
  • num_levels – Number of levels for discretization (default: 100)

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • min_value – Minimum value of the range (default: 0.0)

  • max_value – Maximum value of the range (default: 1.0)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance (default: ‘map’)

  • encoding_type – ‘linear’ or ‘circular’ (default: ‘linear’)

  • key – JAX random key

Returns:

Initialized LevelEncoder

encode(value: float | Array) → Array

Encode a continuous value as a hypervector.

Parameters:

value – Continuous value to encode (scalar or array)

Returns:

Encoded hypervector of shape (dimensions,) or batch shape + (dimensions,)

encode_batch(values: Array) → Array

Encode a batch of continuous values.

Parameters:

values – Batch of values of shape (batch_size,) or (batch_size, num_features)

Returns:

Encoded hypervectors

__init__(level_hvs: Array, num_levels: int, dimensions: int, min_value: float, max_value: float, vsa_model_name: str = 'map', encoding_type: str = 'linear') → None

class jax_hdc.ProjectionEncoder(projection_matrix: Array, input_dim: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Encoder using random projection for high-dimensional data.

Projects high-dimensional input data into hypervector space using a random projection matrix. Useful for images, text embeddings, etc.

projection_matrix: Array
input_dim: int
dimensions: int
vsa_model_name: str = 'map'
static create(input_dim: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → ProjectionEncoder

Create a projection encoder.

Parameters:
  • input_dim – Dimensionality of input data

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • vsa_model – VSA model to use (‘bsc’, ‘map’, ‘hrr’, ‘fhrr’) or VSAModel instance (default: ‘map’)

  • key – JAX random key

Returns:

Initialized ProjectionEncoder

encode(x: Array) → Array

Encode input data as a hypervector.

Parameters:

x – Input data of shape (input_dim,)

Returns:

Encoded hypervector of shape (dimensions,)

encode_batch(x: Array) → Array

Encode a batch of inputs.

Parameters:

x – Batch of inputs of shape (batch_size, input_dim)

Returns:

Encoded hypervectors of shape (batch_size, dimensions)

__init__(projection_matrix: Array, input_dim: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.KernelEncoder(omega: Array, bias: Array, input_dim: int, dimensions: int, gamma: float, vsa_model_name: str = 'map')

Bases: object

Encoder using RBF kernel approximation (Random Fourier Features).

Approximates the RBF kernel k(x,y) = exp(-gamma ||x-y||^2) via random Fourier features, mapping input to a hypervector space that preserves kernel similarity.
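The standard random Fourier feature construction for this kernel draws frequencies from N(0, 2·gamma·I) and phases from U[0, 2π), and maps x to sqrt(2/D)·cos(omega x + b). The dot product of two such feature vectors approximates exp(-gamma ||x - y||^2). The sketch below assumes omega has shape (dimensions, input_dim). make_rff and encode are hypothetical helpers, not the library's code.

```python
import jax
import jax.numpy as jnp


def make_rff(key, input_dim, dimensions, gamma):
    # For k(x, y) = exp(-gamma * ||x - y||^2), frequencies follow N(0, 2 * gamma * I)
    k_w, k_b = jax.random.split(key)
    omega = jax.random.normal(k_w, (dimensions, input_dim)) * jnp.sqrt(2.0 * gamma)
    bias = jax.random.uniform(k_b, (dimensions,), minval=0.0, maxval=2.0 * jnp.pi)
    return omega, bias


def encode(x, omega, bias):
    # z(x) = sqrt(2 / D) * cos(omega @ x + bias), so that z(x) . z(y) approximates k(x, y)
    D = omega.shape[0]
    return jnp.sqrt(2.0 / D) * jnp.cos(omega @ x + bias)


gamma = 0.5
omega, bias = make_rff(jax.random.PRNGKey(13), input_dim=3, dimensions=10_000, gamma=gamma)
x = jnp.array([0.0, 0.5, 1.0])
y = jnp.array([0.3, 0.1, 0.9])

approx = encode(x, omega, bias) @ encode(y, omega, bias)
exact = jnp.exp(-gamma * jnp.sum((x - y) ** 2))
print(float(approx), float(exact))  # close; the error shrinks as D grows
```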

omega: Array
bias: Array
input_dim: int
dimensions: int
gamma: float
vsa_model_name: str = 'map'
static create(input_dim: int, dimensions: int = 10000, gamma: float = 1.0, vsa_model: str | VSAModel = 'map', key: Array | None = None) → KernelEncoder

Create a kernel encoder.

Parameters:
  • input_dim – Dimensionality of input data

  • dimensions – Dimensionality of output hypervectors (default: 10000)

  • gamma – RBF kernel scale parameter, gamma = 1 / (2 sigma^2) (default: 1.0)

  • vsa_model – VSA model name (‘map’, ‘hrr’ or ‘fhrr’) or VSAModel instance (default: ‘map’)

  • key – JAX random key

Returns:

Initialized KernelEncoder

encode(x: Array) → Array

Encode input using RBF kernel approximation.

encode_batch(x: Array) → Array

Encode a batch of inputs.

__init__(omega: Array, bias: Array, input_dim: int, dimensions: int, gamma: float, vsa_model_name: str = 'map') → None

class jax_hdc.GraphEncoder(node_embeddings: Array, num_nodes: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Encoder for graph structures (nodes and edges).

Encodes a graph by assigning a random hypervector to each node, binding the two node hypervectors of each edge, and bundling all edge hypervectors into a single graph hypervector.
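A sketch of edge-list graph encoding, assuming bipolar node vectors and element-wise (MAP-style) binding. Because that binding is commutative, (u, v) and (v, u) encode the same edge, so this sketch represents undirected graphs. encode_edges and has_edge are hypothetical helpers. has_edge tests edge membership by taking the dot product of the graph vector with a bound candidate edge.

```python
import jax
import jax.numpy as jnp

num_nodes, d = 6, 10_000
# Hypothetical bipolar node vectors, one row per node
node_hvs = jnp.where(
    jax.random.bernoulli(jax.random.PRNGKey(14), 0.5, (num_nodes, d)), 1.0, -1.0
)


def encode_edges(edges):
    edges = jnp.clip(edges, 0, num_nodes - 1)                 # clamp node ids
    edge_hvs = node_hvs[edges[:, 0]] * node_hvs[edges[:, 1]]  # bind each (u, v)
    return edge_hvs.sum(axis=0)                               # bundle all edges


def has_edge(graph_hv, u, v):
    # About 1 if (u, v) was encoded, about 0 otherwise
    return graph_hv @ (node_hvs[u] * node_hvs[v]) / d


graph = encode_edges(jnp.array([[0, 1], [1, 2], [2, 3]]))
print(float(has_edge(graph, 1, 2)), float(has_edge(graph, 4, 5)))
```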

node_embeddings: Array
num_nodes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_nodes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → GraphEncoder

Create a graph encoder.

Parameters:
  • num_nodes – Maximum number of nodes

  • dimensions – Hypervector dimensionality (default: 10000)

  • vsa_model – VSA model name or VSAModel instance for the node hypervectors (default: ‘map’)

  • key – JAX random key

Returns:

Initialized GraphEncoder

encode_edges(edges: Array) → Array

Encode graph as bundle of bound edge pairs.

Parameters:

edges – Array of shape (num_edges, 2) with node indices in [0, num_nodes). Out-of-bounds indices are clamped to valid range.

Returns:

Graph hypervector of shape (dimensions,)

__init__(node_embeddings: Array, num_nodes: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.CentroidClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Centroid-based classifier for HDC.

Stores one prototype hypervector per class. Classification finds the most similar prototype to the query.
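A sketch of centroid training and prediction. It assumes prototypes are per-class sums of training hypervectors and that prediction uses cosine similarity. fit_centroids and predict are hypothetical helpers; jax.ops.segment_sum groups the rows by label. The data are synthetic noisy copies of three hidden class vectors, and accuracy is measured on the training set only.

```python
import jax
import jax.numpy as jnp


def fit_centroids(train_hvs, train_labels, num_classes):
    # Sum the hypervectors of each class: one prototype row per class
    return jax.ops.segment_sum(train_hvs, train_labels, num_segments=num_classes)


def predict(prototypes, queries):
    # Cosine similarity against every prototype, then argmax over classes
    p = prototypes / jnp.linalg.norm(prototypes, axis=-1, keepdims=True)
    q = queries / jnp.linalg.norm(queries, axis=-1, keepdims=True)
    return jnp.argmax(q @ p.T, axis=-1)


# Synthetic data: noisy copies of three hidden class hypervectors
k_cls, k_lbl, k_noise = jax.random.split(jax.random.PRNGKey(15), 3)
class_hvs = jax.random.normal(k_cls, (3, 1_000))
labels = jax.random.randint(k_lbl, (60,), 0, 3)
samples = class_hvs[labels] + 2.0 * jax.random.normal(k_noise, (60, 1_000))

protos = fit_centroids(samples, labels, num_classes=3)
acc = jnp.mean(predict(protos, samples) == labels)
print(float(acc))
```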

prototypes: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', initial_prototypes: Array | None = None, key: Array | None = None) → CentroidClassifier

Create a centroid classifier.

Parameters:
  • num_classes – Number of classes

  • dimensions – Dimensionality of hypervectors (default: 10000)

  • vsa_model – VSA model name or instance (default: ‘map’)

  • initial_prototypes – Optional initial prototypes of shape (num_classes, dimensions)

  • key – JAX random key for initialization

similarity(query: Array) → Array

Compute similarity between query and all class prototypes.

predict(queries: Array) → Array

Predict class labels for queries.

Parameters:

queries – Shape (batch_size, dimensions) or (dimensions,)

Returns:

Predicted class indices

predict_proba(queries: Array) → Array

Predict class probabilities by applying a softmax to the similarities.

fit(train_hvs: Array, train_labels: Array) → CentroidClassifier

Train classifier by computing class centroids.

Parameters:
  • train_hvs – Training hypervectors of shape (n_samples, dimensions)

  • train_labels – Training labels of shape (n_samples,)

Returns:

Trained CentroidClassifier (new instance)

update_online(sample_hv: Array, label: int, learning_rate: float = 0.1) → CentroidClassifier

Update classifier online with a single sample.

score(test_hvs: Array, test_labels: Array) → Array

Compute accuracy on test data.

replace(**updates: Any) → CentroidClassifier
__init__(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.AdaptiveHDC(prototypes: Array, num_updates: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Adaptive HDC classifier with iterative prototype refinement.

prototypes: Array
num_updates: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → AdaptiveHDC
predict(queries: Array) → Array

Predict class labels.

fit(train_hvs: Array, train_labels: Array, epochs: int = 1, learning_rate: float = 0.1) → AdaptiveHDC

Train with iterative refinement.

Parameters:
  • train_hvs – Training hypervectors

  • train_labels – Training labels

  • epochs – Number of training epochs (default: 1)

  • learning_rate – Learning rate for updates (default: 0.1)

score(test_hvs: Array, test_labels: Array) → Array

Compute accuracy.

replace(**updates: Any) → AdaptiveHDC
__init__(prototypes: Array, num_updates: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.LVQClassifier(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map')

Bases: object

Learning Vector Quantization classifier.

Prototypes are updated during training: the winning (nearest) prototype moves toward the sample if its class matches the sample's label and away from the sample otherwise.
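A sketch of the LVQ1-style update described above. It assumes one prototype per class, so the winning row index is also its class, and uses cosine similarity to pick the winner. lvq_step and lvq_epoch are hypothetical helpers; lvq_epoch uses jax.lax.scan to apply one update per sample, in order.

```python
import jax
import jax.numpy as jnp


def lvq_step(prototypes, x, label, lr=0.1):
    # Assumption: one prototype per class, so the row index doubles as the class label
    sims = prototypes @ x / (jnp.linalg.norm(prototypes, axis=-1) * jnp.linalg.norm(x))
    winner = jnp.argmax(sims)
    # +1 pulls the winner toward x (correct class), -1 pushes it away (wrong class)
    sign = jnp.where(winner == label, 1.0, -1.0)
    return prototypes.at[winner].add(sign * lr * (x - prototypes[winner]))


def lvq_epoch(prototypes, xs, labels, lr=0.1):
    # One pass over the data, one update per sample
    def body(p, sample):
        x, label = sample
        return lvq_step(p, x, label, lr), None

    prototypes, _ = jax.lax.scan(body, prototypes, (xs, labels))
    return prototypes


protos = jnp.array([[1.0, 0.0], [0.0, 1.0]])
# The sample is nearest to prototype 1 but labelled 0, so prototype 1 moves away from it
updated = lvq_step(protos, jnp.array([0.2, 1.0]), label=0)
print(updated)  # row 0 unchanged, row 1 becomes [-0.02, 1.0]
```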

prototypes: Array
num_classes: int
dimensions: int
vsa_model_name: str = 'map'
static create(num_classes: int, dimensions: int = 10000, vsa_model: str | VSAModel = 'map', key: Array | None = None) → LVQClassifier
predict(queries: Array) → Array

Predict class labels by nearest prototype.

fit(train_hvs: Array, train_labels: Array, epochs: int = 10, lr: float = 0.1) → LVQClassifier

Train with LVQ updates (winner-take-all, move toward/away).

score(test_hvs: Array, test_labels: Array) → Array
replace(**updates: Any) → LVQClassifier
__init__(prototypes: Array, num_classes: int, dimensions: int, vsa_model_name: str = 'map') → None

class jax_hdc.RegularizedLSClassifier(weights: Array, dimensions: int, num_classes: int, reg: float)

Bases: object

Regularized Least Squares classifier in HV space.

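Solves (X^T X + lambda I) W = X^T Y for the weights W, where X holds the training hypervectors (one per row), Y holds the training targets, and lambda is the reg parameter.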
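A sketch of the closed-form fit. It assumes Y is the one-hot encoding of the integer labels and that prediction takes the class with the highest linear score. fit_rls and predict_rls are hypothetical helpers. Solving the d × d system costs O(d^3), so at the default dimensionality of 10000 this is a substantial computation; the demo uses d = 200.

```python
import jax
import jax.numpy as jnp


def fit_rls(X, labels, num_classes, reg=1e-4):
    # Assumption: Y is the one-hot encoding of the integer labels
    Y = jax.nn.one_hot(labels, num_classes)  # (n, num_classes)
    A = X.T @ X + reg * jnp.eye(X.shape[1])  # (d, d)
    return jnp.linalg.solve(A, X.T @ Y)      # W: (d, num_classes)


def predict_rls(W, queries):
    # The class with the highest linear score wins
    return jnp.argmax(queries @ W, axis=-1)


k_cls, k_lbl, k_noise = jax.random.split(jax.random.PRNGKey(16), 3)
class_hvs = jax.random.normal(k_cls, (3, 200))
labels = jax.random.randint(k_lbl, (300,), 0, 3)
X = class_hvs[labels] + jax.random.normal(k_noise, (300, 200))

W = fit_rls(X, labels, num_classes=3)
train_acc = jnp.mean(predict_rls(W, X) == labels)
print(float(train_acc))
```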

weights: Array
dimensions: int
num_classes: int
reg: float
static create(dimensions: int, num_classes: int, reg: float = 0.0001) → RegularizedLSClassifier
fit(train_hvs: Array, train_labels: Array) → RegularizedLSClassifier

Fit by solving regularized least squares.

predict(queries: Array) → Array
score(test_hvs: Array, test_labels: Array) → Array
replace(**updates: Any) → RegularizedLSClassifier
__init__(weights: Array, dimensions: int, num_classes: int, reg: float) → None

class jax_hdc.SparseDistributedMemory(locations: Array, contents: Array, dimensions: int, radius: float)

Bases: object

Sparse Distributed Memory (SDM) for content-addressable storage.

locations: Array
contents: Array
dimensions: int
radius: float
static create(num_locations: int, dimensions: int, radius: float = 0.0, key: Array | None = None) → SparseDistributedMemory
write(address: Array, value: Array) → SparseDistributedMemory
read(address: Array) → Array
__init__(locations: Array, contents: Array, dimensions: int, radius: float) → None

class jax_hdc.HopfieldMemory(patterns: Array, dimensions: int, beta: float = 1.0)

Bases: object

Modern Hopfield network for associative memory.
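A sketch of one retrieval step of a modern Hopfield network, in which the query attends to the stored patterns through a softmax scaled by beta. hopfield_retrieve is a hypothetical helper; the library's retrieve may run more steps or normalize differently. A large beta gives near-exact recall of the closest pattern, which is why the demo uses beta = 50 rather than the default 1.0. With unit-norm patterns and beta = 1.0, the weights stay soft and retrieval returns a blend of patterns.

```python
import jax
import jax.numpy as jnp


def hopfield_retrieve(patterns, query, beta=1.0):
    # One update of a modern (continuous) Hopfield network:
    # softmax over pattern similarities, then a weighted sum of the stored patterns
    weights = jax.nn.softmax(beta * (patterns @ query))
    return weights @ patterns


k_pat, k_noise = jax.random.split(jax.random.PRNGKey(17))
patterns = jax.random.normal(k_pat, (5, 1_000))
patterns = patterns / jnp.linalg.norm(patterns, axis=-1, keepdims=True)  # unit rows
query = patterns[2] + 0.5 * jax.random.normal(k_noise, (1_000,)) / jnp.sqrt(1_000.0)

out = hopfield_retrieve(patterns, query, beta=50.0)
print(float(out @ patterns[2] / jnp.linalg.norm(out)))  # close to 1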

patterns: Array
dimensions: int
beta: float = 1.0
static create(dimensions: int, beta: float = 1.0) → HopfieldMemory
add(pattern: Array) → HopfieldMemory
retrieve(query: Array) → Array
__init__(patterns: Array, dimensions: int, beta: float = 1.0) → None

class jax_hdc.AttentionMemory(keys: Array, values: Array, dimensions: int, temperature: float = 1.0, num_heads: int = 1)

Bases: object

Attention-based retrieval with key-value storage and multi-head support.
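A single-head sketch of attention-based retrieval: a softmax over key similarities divided by temperature, followed by a weighted sum of the stored values. attention_retrieve is a hypothetical helper. The multi-head variant (num_heads > 1) is not shown, and how the library splits heads is not documented here.

```python
import jax
import jax.numpy as jnp


def attention_retrieve(keys, values, query, temperature=1.0):
    # Similarity to every key -> softmax weights -> weighted sum of values
    weights = jax.nn.softmax((keys @ query) / temperature)
    return weights @ values, weights


k_keys, k_vals = jax.random.split(jax.random.PRNGKey(18))
keys = jax.random.normal(k_keys, (4, 1_000))
keys = keys / jnp.linalg.norm(keys, axis=-1, keepdims=True)  # unit-norm keys
values = jax.random.normal(k_vals, (4, 1_000))

out, weights = attention_retrieve(keys, values, keys[1], temperature=0.05)
print(int(jnp.argmax(weights)))  # 1
```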

keys: Array
values: Array
dimensions: int
temperature: float = 1.0
num_heads: int = 1
static create(dimensions: int, temperature: float = 1.0, num_heads: int = 1) → AttentionMemory
write(key: Array, value: Array) → AttentionMemory
write_batch(keys: Array, values: Array) → AttentionMemory
retrieve(query: Array) → Array
retrieve_with_weights(query: Array) → tuple
__init__(keys: Array, values: Array, dimensions: int, temperature: float = 1.0, num_heads: int = 1) → None