\$\begingroup\$

One instance of the following module uses almost 75% of my vram. How can I reduce that without slowing down the runtime too much? The code is below:

NUM_OF_IMUS = 13
NUM_OF_NOISE_PARAMS = 9
class mod(nn.Module):
    def __init__(self, d_model, device):
        super(Noise_Regressor, self).__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.hidden_state_to_noise_params = nn.Linear(d_model, NUM_OF_IMUS * NUM_OF_NOISE_PARAMS)
        self.eps = 1e-5
        self.device = device
        self.t_step_init_mat = torch.triu(torch.arange(10000, device=self.device) - torch.arange(10000, device=self.device)[:, None])
        self.MASK = torch.triu(torch.ones((10000, 10000), device=self.device), diagonal=0)

    """
    hidden_states should be of dimension (Batch, Sequence Len, Dim)
    B should always be 1
    Sequence Length can be up to 10000
    The dimension can be 512
    """
    def forward(self, hidden_states, min_orig_accel_norm):
        seq_len = hidden_states.shape[1]

        t_step_init_mat = self.t_step_init_mat[:seq_len, :seq_len]
        MASK = self.MASK[:seq_len, :seq_len]

        hidden_normed = self.norm1(hidden_states)
        noise_params = self.hidden_state_to_noise_params(hidden_normed).view(seq_len, NUM_OF_NOISE_PARAMS, NUM_OF_IMUS)

        c = noise_params[:, 4, :].view(seq_len, 1, NUM_OF_IMUS)
        c_theta = noise_params[:, 5, :].view(seq_len, 1, NUM_OF_IMUS)
        phi = noise_params[:, 6, :].view(seq_len, 1, NUM_OF_IMUS)
        phi_theta = noise_params[:, 7, :].view(seq_len, 1, NUM_OF_IMUS)
        d = torch.sqrt((noise_params[:, 1, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k = (d**2) / 4 + F.softplus(noise_params[:, 0, :]).view(seq_len, 1, NUM_OF_IMUS)
        d_theta = torch.sqrt((noise_params[:, 3, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k_theta = (d_theta**2) / 4 + F.softplus(noise_params[:, 2, :]).view(seq_len, 1, NUM_OF_IMUS)

        noise_bias = noise_params[:, 8, :].T

        dynamics_list = []
        for imu_num in range(NUM_OF_IMUS):
            omega1 = torch.sqrt(4 * k[:, :, imu_num] - (d[:, :, imu_num] ** 2)) / 2
            linear_dynamics = c[:, :, imu_num] * torch.exp((-d[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi[:, :, imu_num] + t_step_init_mat * omega1)
            omega1_theta = torch.sqrt(4 * k_theta[:, :, imu_num] - (d_theta[:, :, imu_num] ** 2)) / 2
            angular_dynamics = c_theta[:, :, imu_num] * torch.exp((-d_theta[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi_theta[:, :, imu_num] + t_step_init_mat * omega1_theta)
            spring_damper_dynamics_per_step = (linear_dynamics + angular_dynamics) * MASK
            dynamics_list.append(torch.sum(spring_damper_dynamics_per_step, dim=0, keepdim=True))

        return torch.cat(dynamics_list, dim=0) + min_orig_accel_norm + noise_bias
pacmaninbw
asked Dec 5, 2024 at 20:55
\$\endgroup\$
  • \$\begingroup\$ Please do not edit the question, especially the code, after an answer has been posted. Changing the question may invalidate existing answers. Everyone needs to be able to see what the reviewer was referring to. See "What to do after the question has been answered". \$\endgroup\$ Commented Dec 8, 2024 at 13:20

1 Answer

\$\begingroup\$

computing a discarded result

Please don't write code like this:

def greet(name):
    42
    name + " is cool."
    print(f"Hello {name}!")

Yes, you can compute a literal or an expression and then discard the result; the Python interpreter will let you do that. But it doesn't help the readability of your code.

Rather than

    """
    hidden_states should be of dimension (Batch, Sequence Len, Dim)
    B should always be 1
    Sequence Length can be up to 10000
    The dimension can be 512
    """

you meant to write

    # hidden_states should be of dimension (Batch, Sequence Len, Dim)
    # B should always be 1
    # Sequence Length can be up to 10000
    # The dimension can be 512

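Please note that the OP code does not contain any docstrings, despite the presence of a discarded triple-quoted string. A string literal counts as a docstring only when it is the first statement in a module, class, or function body; this one sits between two methods.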
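To make the distinction concrete, here is a minimal sketch using two hypothetical classes (WithStrayString and WithDocstring are illustrative names, not part of the OP code). It checks `__doc__` to show which triple-quoted strings Python actually records:

```python
class WithStrayString:
    def __init__(self):
        self.x = 1

    """
    This string sits between two methods, so it is evaluated and discarded.
    """
    def forward(self):
        return self.x


class WithDocstring:
    def forward(self):
        """This string is the first statement in the body, so it is a docstring."""
        return 1


# The stray string is attached to neither the class nor the method.
stray_method_doc = WithStrayString.forward.__doc__
stray_class_doc = WithStrayString.__doc__
real_doc = WithDocstring.forward.__doc__

print(stray_method_doc, stray_class_doc)
print(real_doc)
```

If the shape notes are meant to show up in `help()`, moving the string to the first line inside `forward` would turn it into a real docstring; otherwise comments are the clearer choice.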

inheritance

class mod(nn.Module):
    ...
        super(Noise_Regressor, self).__init__()

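The class is named mod, yet the super() call names Noise_Regressor. Maybe you'd prefer for the class MRO to inherit from both those classes? As written, constructing mod fails unless a Noise_Regressor class exists and self is an instance of it. The zero-argument form, super().__init__(), avoids the mismatch.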
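A minimal sketch of the mismatch, assuming Noise_Regressor is not defined anywhere in the module (the OP code does not show it). Base is a hypothetical stand-in for nn.Module so the example runs without PyTorch:

```python
class Base:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self):
        self.initialized = True


class Mismatched(Base):
    def __init__(self):
        # Names a class that is not defined here, as in the reviewed code.
        super(Noise_Regressor, self).__init__()  # noqa: F821


class Consistent(Base):
    def __init__(self):
        # Zero-argument form: Python supplies the enclosing class and self.
        super().__init__()


try:
    Mismatched()
    mismatch_error = None
except NameError as exc:
    mismatch_error = type(exc).__name__

consistent_ok = Consistent().initialized
print(mismatch_error, consistent_ok)
```

The zero-argument form also survives a later rename of the class, which is likely how the two names drifted apart here.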

helpful names

The various d, k, c, and phi local variables are admirably clear. Thank you for spelling out the meaning of what's at those indices.

Do try to cite your references. As written, it's unclear which Wikipedia page or other textbook resource mod might be trying to implement.

performance

75% of my vram

You didn't tell us the business problem you're trying to solve, the problem size, your VRAM size, or any elapsed timings. I'm willing to believe we're computing some figures which do not impinge directly on the business problem and could be discarded, but the OP doesn't help us understand what aspects of the computation are most important to the use case.
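Even without measurements, the two 10000 x 10000 buffers built in `__init__` allow a rough back-of-the-envelope estimate. The sketch below is plain arithmetic, not a profile. It assumes PyTorch's default dtypes (int64 for `torch.arange` on integer arguments, float32 for `torch.ones`) and a sequence length of 10000; the helper name square_matrix_bytes is made up for illustration:

```python
MAX_LEN = 10_000       # side length of the precomputed matrices
INT64_BYTES = 8        # assumed dtype of t_step_init_mat (torch.arange default)
FLOAT32_BYTES = 4      # assumed dtype of MASK (torch.ones default)


def square_matrix_bytes(n, bytes_per_element):
    """Storage for a dense n x n matrix, ignoring allocator overhead."""
    return n * n * bytes_per_element


t_step_bytes = square_matrix_bytes(MAX_LEN, INT64_BYTES)
mask_bytes = square_matrix_bytes(MAX_LEN, FLOAT32_BYTES)
resident_bytes = t_step_bytes + mask_bytes

# Each full-size float32 temporary inside the loop (exp, sin, products)
# costs this much at seq_len == MAX_LEN.
temp_bytes = square_matrix_bytes(MAX_LEN, FLOAT32_BYTES)

for label, nbytes in [
    ("t_step_init_mat", t_step_bytes),
    ("MASK", mask_bytes),
    ("resident total", resident_bytes),
    ("one loop temporary", temp_bytes),
]:
    print(f"{label}: {nbytes / 2**30:.2f} GiB")
```

How many loop temporaries are alive at once depends on whether autograd is recording the graph, so actual numbers from `torch.cuda.max_memory_allocated()` would be more trustworthy than this estimate.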

answered Dec 6, 2024 at 3:33
\$\endgroup\$
