How can this code be improved? I'm a novice programmer trying to learn ML by building things from scratch. This code is part of a transformer model that I'm working on. Do you have any ideas about how to improve it for better performance and easier readability?
import jax
import jax.numpy as jnp
class Embedding():
    def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.global_step = 0

    def __call__(self, x):
        return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

    def weights_init(self):
        key = jax.random.PRNGKey(0)
        self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
        key = jax.random.PRNGKey(1)
        self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

    def forward(self, x):
        # Embedding layer
        # x is a vector of size (vocab_size, 1)
        # embedding_matrix is a matrix of size (vocab_size, d_model)
        # hidden is a vector of size (d_model, 1)
        hidden = jnp.dot(self.embedding_matrix.T, x)
        # Context layer
        # context_matrix is a matrix of size (d_model, vocab_size)
        # output is a vector of size (vocab_size, 1)
        output = jnp.dot(self.context_matrix.T, hidden)
        # Using softmax as activation function
        prediction = jax.nn.softmax(output)
        return hidden, prediction

    def backward(self, hidden, prediction, label):
        # Calculate error
        error = jnp.array(label) - prediction
        # Calculate cross-entropy loss
        loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))
        # Calculate gradient
        grad_context = jnp.dot(hidden, error.T)
        grad_embedding = jnp.dot(error, self.context_matrix.T)
        self.loss = loss
        self.update_weights(grad_context, grad_embedding)

    def update_weights(self, grad_context, grad_embedding):
        # Update weights
        self.context_matrix += self.learning_rate * grad_context
        self.embedding_matrix += self.learning_rate * grad_embedding
        # Update learning rate
        self.global_step += 1
        if self.global_step % self.decay_steps == 0:
            self.learning_rate = self.learning_rate * self.decay_rate
1 Answer
As @J_H pointed out in a comment, there is no information other than your code on which to base these improvements, so take the following with a grain of salt.
Use Direct Embedding Lookups
Right now, your __call__ method returns a one-hot vector, and then later you do a dot product with it. This is not very efficient. The typical approach for embedding lookups is direct indexing:
def __call__(self, x):
    return self.embedding_matrix[x]
This avoids creating large one-hot vectors and performing costly dot products, but you'll need to ensure that x is an integer index (or a batch of indices).
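For instance, assuming weights_init() has already been called and __call__ has been changed as above (the sizes here are made up purely for illustration):

emb = Embedding(vocab_size=1000, d_model=64)
emb.weights_init()
tokens = jnp.array([3, 17, 42])   # a batch of integer token ids
vectors = emb(tokens)             # shape (3, 64); no (3, 1000) one-hot matrix is built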
Initialize Weights in __init__
You currently have a separate weights_init method that must be called before the object can be used. It's often cleaner to initialize the weights in the constructor, as this ensures the object is always in a ready-to-use state.
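A minimal sketch of what that could look like, assuming the caller supplies the PRNG key (the key parameter is my addition, not part of your original signature):

def __init__(self, vocab_size, d_model, key, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.learning_rate = learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.global_step = 0
    # Split one key so both matrices get independent values
    # without hard-coding a second seed.
    k1, k2 = jax.random.split(key)
    self.embedding_matrix = jax.random.normal(k1, (vocab_size, d_model))
    self.context_matrix = jax.random.normal(k2, (d_model, vocab_size))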
Clarify the Shapes and Flow
Your forward method is a bit confusing because it still expects x as if it were a one-hot vector. If you change your embedding lookup as recommended, then x should be an integer. If you want to handle a single token at a time:
def forward(self, x):
    hidden = self.embedding_matrix[x]
    output = jnp.dot(hidden, self.context_matrix)
    prediction = jax.nn.softmax(output)
    return hidden, prediction
If you want to handle a batch of tokens (say x with shape [batch_size]), your code would need to use broadcasting or a batch dimension:
def forward(self, x):
    hidden = self.embedding_matrix[x]
    output = jnp.einsum('bd,dv->bv', hidden, self.context_matrix)
    prediction = jax.nn.softmax(output, axis=-1)
    return hidden, prediction
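As a quick sanity check on the shapes, continuing the hypothetical sizes from above and assuming forward has been replaced with this batched version:

emb = Embedding(vocab_size=1000, d_model=64)
emb.weights_init()
batch = jnp.array([5, 9, 2])              # shape (3,)
hidden, prediction = emb.forward(batch)   # hidden: (3, 64)
print(prediction.shape)                   # (3, 1000); each row sums to 1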
Use Standard Libraries
Instead of manually implementing learning rate decay, consider using an optimizer from optax (the recommended optimization library for JAX). For example, optax.exponential_decay can handle the learning rate schedule for you.
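A sketch of what that might look like with your hyperparameters (0.01, 100 steps, 0.9); the parameter names, sizes, and the plain SGD optimizer are my assumptions, not anything from your code:

import jax
import jax.numpy as jnp
import optax

vocab_size, d_model = 1000, 64
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "embedding": jax.random.normal(k1, (vocab_size, d_model)),
    "context": jax.random.normal(k2, (d_model, vocab_size)),
}

schedule = optax.exponential_decay(
    init_value=0.01,       # learning_rate
    transition_steps=100,  # decay_steps
    decay_rate=0.9,        # decay_rate
    staircase=True,        # decay in discrete jumps, like the manual % check
)
optimizer = optax.sgd(learning_rate=schedule)
opt_state = optimizer.init(params)

# In the training loop, `grads` is a pytree with the same structure as
# `params` (zeros here as a placeholder):
grads = jax.tree_util.tree_map(jnp.zeros_like, params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

This removes global_step, decay_steps, decay_rate, and update_weights from your class entirely and keeps the update rule in one place.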
In weights_init() you make the curious suggestion that the PRNG is in a non-deterministic state and must therefore be re-seeded prior to creating context_matrix. Other than that, the code reads like it was faithfully typed in from a nice textbook. "How can this code be improved?" You gave us barely a smidgen of Review Context. We don't know your use case, your automated test suite, nor your evaluation rubric, so absent any obvious stack-trace bugs we can't nudge this code in an "improved" direction, since you haven't described what's sub-optimal about it.