How can this code be improved? I'm a novice programmer trying to learn ML by implementing it from scratch. This code is part of a transformer model that I'm working on. Do you have any ideas on how to improve it for better performance and easier readability?

import jax
import jax.numpy as jnp


class Embedding():
    def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.global_step = 0

    def __call__(self, x):
        return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

    def weights_init(self):
        key = jax.random.PRNGKey(0)
        self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
        key = jax.random.PRNGKey(1)
        self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

    def forward(self, x):
        # Embedding layer
        # x is a vector of size (vocab_size, 1)
        # embedding_matrix is a matrix of size (vocab_size, d_model)
        # hidden is a vector of size (d_model, 1)
        hidden = jnp.dot(self.embedding_matrix.T, x)
        # Context layer
        # context_matrix is a matrix of size (d_model, vocab_size)
        # output is a vector of size (vocab_size, 1)
        output = jnp.dot(self.context_matrix.T, hidden)
        # Using softmax as activation function
        prediction = jax.nn.softmax(output)
        return hidden, prediction

    def backward(self, hidden, prediction, label):
        # Calculate error
        error = jnp.array(label) - prediction
        # Calculate cross-entropy loss
        loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))
        # Calculate gradient
        grad_context = jnp.dot(hidden, error.T)
        grad_embedding = jnp.dot(error, self.context_matrix.T)
        self.loss = loss
        self.update_weights(grad_context, grad_embedding)

    def update_weights(self, grad_context, grad_embedding):
        # Update weights
        self.context_matrix += self.learning_rate * grad_context
        self.embedding_matrix += self.learning_rate * grad_embedding
        # Update learning rate
        self.global_step += 1
        if self.global_step % self.decay_steps == 0:
            self.learning_rate = self.learning_rate * self.decay_rate
asked Jan 31, 2024 at 20:24
In weights_init() you make the curious suggestion that the PRNG is in a non-deterministic state and must therefore be re-seeded prior to creating context_matrix. Other than that, the code reads like it was faithfully typed in from a nice textbook. "How can this code be improved?" You gave us barely a smidgen of Review Context. We don't know your use case, your automated test suite, nor your evaluation rubric, so absent any obvious stack trace bugs we can't nudge this code in an "improved" direction, since you haven't described what's sub-optimal about it. – J_H, Commented Jan 31, 2024 at 23:15

1 Answer


As @J_H pointed out in a comment, there is no information beyond the code itself to base these improvements on, so take them with a grain of salt.


Use Direct Embedding Lookups

Right now, your __call__ method returns a one-hot vector, and then later you do a dot product. This is not very efficient. The typical approach for embedding lookups is direct indexing:

def __call__(self, x):
    return self.embedding_matrix[x]

This avoids creating large one-hot vectors and performing costly dot products, but you'll need to ensure that x is an integer index (or a batch of indices).
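As a quick sanity check of the shapes, here is a minimal sketch of the lookup (the vocab_size, d_model, and token values are only illustrative):

import jax
import jax.numpy as jnp

# Hypothetical sizes, only for illustration
vocab_size, d_model = 10, 4
key = jax.random.PRNGKey(0)
embedding_matrix = jax.random.normal(key, (vocab_size, d_model))

tokens = jnp.array([1, 5, 7])       # a batch of integer token ids
hidden = embedding_matrix[tokens]   # shape (3, d_model), no one-hot vectors needed
print(hidden.shape)                 # (3, 4)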

Initialize Weights in __init__

You currently have a separate method that needs to be called. It's often cleaner to initialize the weights in the constructor, as this ensures the object is always in a ready-to-use state.
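A minimal sketch of what that could look like, keeping your attribute names; the key parameter is an addition on my part, not something in your original code, and splitting one key also addresses the re-seeding point raised in the comment:

import jax


class Embedding:
    def __init__(self, vocab_size, d_model, key=None,
                 learning_rate=0.01, decay_steps=100, decay_rate=0.9):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.global_step = 0
        # Split one key instead of hard-coding two separate seeds
        if key is None:
            key = jax.random.PRNGKey(0)
        k_embed, k_context = jax.random.split(key)
        self.embedding_matrix = jax.random.normal(k_embed, (vocab_size, d_model))
        self.context_matrix = jax.random.normal(k_context, (d_model, vocab_size))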

Clarify the Shapes and Flow

Your forward method is a bit confusing because it still expects x as if it were a one-hot vector. If you change your embedding lookup as recommended, then x should be an integer. If you want to handle a single token at a time:

def forward(self, x):
    hidden = self.embedding_matrix[x]
    output = jnp.dot(hidden, self.context_matrix)
    prediction = jax.nn.softmax(output)
    return hidden, prediction

If you want to handle a batch of tokens (say x with shape [batch_size]), your code would need to use broadcasting or a batch dimension:

def forward(self, x):
    hidden = self.embedding_matrix[x]
    output = jnp.einsum('bd,dv->bv', hidden, self.context_matrix)
    prediction = jax.nn.softmax(output, axis=-1)
    return hidden, prediction
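In this 2-D case the einsum is equivalent to a plain matrix product, hidden @ self.context_matrix; the einsum spelling just makes the batch and feature axes explicit, and note that softmax is applied along the last axis so each row of the batch is normalized independently.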

Use Standard Libraries

Instead of manually implementing learning-rate decay, consider using an optimizer from optax (the recommended optimization library for JAX). For example, optax.exponential_decay can handle the learning-rate schedule for you.
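A rough sketch of how that could look; the sizes, the params dict, and loss_fn here are placeholders for illustration, not part of your code, and the schedule values mirror your constructor defaults:

import jax
import jax.numpy as jnp
import optax

# Schedule mirroring the hand-rolled decay: multiply lr by 0.9 every 100 steps
schedule = optax.exponential_decay(init_value=0.01,
                                   transition_steps=100,
                                   decay_rate=0.9,
                                   staircase=True)
optimizer = optax.sgd(learning_rate=schedule)

# Hypothetical parameter shapes, only for illustration
params = {
    "embedding": jax.random.normal(jax.random.PRNGKey(0), (1000, 64)),
    "context":   jax.random.normal(jax.random.PRNGKey(1), (64, 1000)),
}
opt_state = optimizer.init(params)

def loss_fn(params, token, label):
    hidden = params["embedding"][token]
    logits = hidden @ params["context"]
    return -jnp.sum(label * jax.nn.log_softmax(logits))

# One training step: jax.grad computes the gradients, optax applies them
grads = jax.grad(loss_fn)(params, jnp.array(3), jax.nn.one_hot(5, 1000))
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

This also removes the need to track global_step and mutate learning_rate by hand, since the schedule is driven by the optimizer state.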

answered Dec 10, 2024 at 8:32