One instance of the following module uses almost 75% of my VRAM, so I was wondering how I could reduce that without slowing down the runtime too much. The code is below:
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_OF_IMUS = 13
NUM_OF_NOISE_PARAMS = 9

class mod(nn.Module):
    def __init__(self, d_model, device):
        super(Noise_Regressor, self).__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.hidden_state_to_noise_params = nn.Linear(d_model, NUM_OF_IMUS * NUM_OF_NOISE_PARAMS)
        self.eps = 1e-5
        self.device = device
        self.t_step_init_mat = torch.triu(torch.arange(10000, device=self.device) - torch.arange(10000, device=self.device)[:, None])
        self.MASK = torch.triu(torch.ones((10000, 10000), device=self.device), diagonal=0)
    """
    hidden_states should be of dimension (Batch, Sequence Len, Dim)
    B should always be 1
    Sequence Length can be up to 10000
    The dimension can be 512
    """
    def forward(self, hidden_states, min_orig_accel_norm):
        seq_len = hidden_states.shape[1]
        t_step_init_mat = self.t_step_init_mat[:seq_len, :seq_len]
        MASK = self.MASK[:seq_len, :seq_len]
        hidden_normed = self.norm1(hidden_states)
        noise_params = self.hidden_state_to_noise_params(hidden_normed).view(seq_len, NUM_OF_NOISE_PARAMS, NUM_OF_IMUS)
        c = noise_params[:, 4, :].view(seq_len, 1, NUM_OF_IMUS)
        c_theta = noise_params[:, 5, :].view(seq_len, 1, NUM_OF_IMUS)
        phi = noise_params[:, 6, :].view(seq_len, 1, NUM_OF_IMUS)
        phi_theta = noise_params[:, 7, :].view(seq_len, 1, NUM_OF_IMUS)
        d = torch.sqrt((noise_params[:, 1, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k = (d**2) / 4 + F.softplus(noise_params[:, 0, :]).view(seq_len, 1, NUM_OF_IMUS)
        d_theta = torch.sqrt((noise_params[:, 3, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k_theta = (d_theta**2) / 4 + F.softplus(noise_params[:, 2, :]).view(seq_len, 1, NUM_OF_IMUS)
        noise_bias = noise_params[:, 8, :].T
        dynamics_list = []
        for imu_num in range(NUM_OF_IMUS):
            omega1 = torch.sqrt(4 * k[:, :, imu_num] - (d[:, :, imu_num] ** 2)) / 2
            linear_dynamics = c[:, :, imu_num] * torch.exp((-d[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi[:, :, imu_num] + t_step_init_mat * omega1)
            omega1_theta = torch.sqrt(4 * k_theta[:, :, imu_num] - (d_theta[:, :, imu_num] ** 2)) / 2
            angular_dynamics = c_theta[:, :, imu_num] * torch.exp((-d_theta[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi_theta[:, :, imu_num] + t_step_init_mat * omega1_theta)
            spring_damper_dynamics_per_step = (linear_dynamics + angular_dynamics) * MASK
            dynamics_list.append(torch.sum(spring_damper_dynamics_per_step, dim=0, keepdim=True))
        return torch.cat(dynamics_list, dim=0) + min_orig_accel_norm + noise_bias
Please do not edit the question, especially the code, after an answer has been posted. Changing the question may cause answer invalidation. Everyone needs to be able to see what the reviewer was referring to. See What to do after the question has been answered. – pacmaninbw ♦, commented Dec 8, 2024 at 13:20
1 Answer
computing a discarded result
Please don't write code like this:
def greet(name):
    42
    name + " is cool."
    print(f"Hello {name}!")
Yes, you can compute a literal or an expression and then discard the result; the Python interpreter will let you do that. But it doesn't help the readability of your code.
Rather than
"""
hidden_states should be of dimension (Batch, Sequence Len, Dim)
B should always be 1
Sequence Length can be up to 10000
The dimension can be 512
"""
you meant to write
# hidden_states should be of dimension (Batch, Sequence Len, Dim)
# B should always be 1
# Sequence Length can be up to 10000
# The dimension can be 512
Please note that the OP code does not contain any docstrings, despite the presence of a discarded triple-quoted string.
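A string literal becomes a docstring only when it is the first statement in a module, class, or function body. The minimal sketch below uses a hypothetical Mod class (no PyTorch needed) to show that a string placed after other statements is simply evaluated and thrown away, while the same kind of text as the first statement of forward ends up in `__doc__`, where help() and IDEs can find it. That is an alternative to the # comments above if you do want documentation.

```python
class Mod:
    def __init__(self):
        self.eps = 1e-5
        """Not the first statement of __init__: evaluated, then discarded."""

    def forward(self, hidden_states):
        """hidden_states should be of dimension (Batch, Sequence Len, Dim)."""
        return hidden_states


print(Mod.__doc__)           # None: the class body starts with a def, not a string
print(Mod.__init__.__doc__)  # None: the string is not the first statement
print(Mod.forward.__doc__)   # the first-statement string is a real docstring
```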
inheritance
class mod(nn.Module):
    ...
        super(Noise_Regressor, self).__init__()
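Maybe you'd prefer for the class MRO to inherit from both those classes? As written, super() names Noise_Regressor while the class is called mod, so construction raises NameError if no Noise_Regressor is in scope, or TypeError if mod is not a subclass of it. The zero-argument form super().__init__() cannot get out of sync with the class name.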
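A dependency-free illustration, using a hypothetical Base class as a stand-in for nn.Module: the zero-argument super() resolves through the class it is written in, whereas naming an unrelated class fails as soon as an instance is constructed.

```python
class Base:  # stand-in for nn.Module so the sketch runs without PyTorch
    def __init__(self):
        self.ready = True


class Unrelated(Base):
    pass


class Mod(Base):
    def __init__(self):
        super().__init__()  # zero-argument form: follows Mod's own MRO


class BadMod(Base):
    def __init__(self):
        # Same mistake as super(Noise_Regressor, self) inside class mod:
        # self is a BadMod, which is not an instance of Unrelated.
        super(Unrelated, self).__init__()


print(Mod().ready)  # True
try:
    BadMod()
except TypeError as exc:
    print("TypeError:", exc)
```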
helpful names
The various d, k, c, and phi local variables are admirably clear.
Thank you for spelling out the meaning of what's at those indices.
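If you'd like the indices themselves to carry those names, one option (my suggestion, not something in the question) is an IntEnum. The class and member names below are hypothetical; they simply mirror the local variables forward already uses for each slot of noise_params.

```python
from enum import IntEnum


class NoiseParam(IntEnum):
    """Hypothetical names for the slots of noise_params[:, i, :]."""
    K = 0          # softplus(...) term added to d**2 / 4 to form k
    D = 1          # d = sqrt(x**2 + eps)
    K_THETA = 2
    D_THETA = 3
    C = 4
    C_THETA = 5
    PHI = 6
    PHI_THETA = 7
    BIAS = 8       # becomes noise_bias


NUM_OF_NOISE_PARAMS = len(NoiseParam)  # derived instead of hard-coded

# IntEnum members are ints, so they can replace the literal indices,
# e.g. noise_params[:, NoiseParam.C, :] instead of noise_params[:, 4, :].
slots = list(range(NUM_OF_NOISE_PARAMS))
print(slots[NoiseParam.C], slots[NoiseParam.BIAS])  # 4 8
```

Deriving NUM_OF_NOISE_PARAMS from the enum keeps the count and the names from drifting apart.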
Do try to cite your references. As written, it's unclear which Wikipedia page or other textbook resource mod might be trying to implement.
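For what it's worth, my reading (not stated in the question) is that each loop iteration evaluates the free response of an underdamped, unit-mass spring-damper started at step i, with d the damping coefficient and k the stiffness, and then sums those responses over all start steps i ≤ j. In the code's notation, with t = j - i steps elapsed, for one IMU:

```latex
x_i(t) = c_i \, e^{-d_i t / 2} \sin\left(\phi_i + \omega_i t\right),
\qquad \omega_i = \tfrac{1}{2}\sqrt{4k_i - d_i^2}

\ddot{x} + d\,\dot{x} + k\,x = 0
\;\Rightarrow\;
r = -\tfrac{d}{2} \pm \tfrac{i}{2}\sqrt{4k - d^2} \qquad (4k > d^2)

k_i = \tfrac{d_i^2}{4} + \operatorname{softplus}(\cdot)
\;\Rightarrow\;
4k_i - d_i^2 = 4\operatorname{softplus}(\cdot) > 0

y_j = \sum_{i \le j} \left[ x_i(j - i) + x_i^{\theta}(j - i) \right]
```

Here x^θ uses c_theta, d_theta, k_theta and phi_theta, and the returned tensor adds min_orig_accel_norm and noise_bias to y. The softplus term is what keeps the square root in omega1 real, so a reference deriving this form would be the natural thing to cite.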
performance
75% of my vram
You didn't tell us the business problem you're trying to solve, the problem size, your VRAM size, or any elapsed timings. I'm willing to believe we're computing some figures which do not impinge directly on the business problem and could be discarded, but the OP doesn't help us understand what aspects of the computation are most important to the use case.
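For a sense of scale, assuming the default dtypes (torch.arange yields int64, torch.ones yields float32): the two tensors built in __init__ occupy 10000² × 8 B = 800 MB (t_step_init_mat) plus 10000² × 4 B = 400 MB (MASK), regardless of the sequence length forward actually sees. Inside the loop, each torch.exp(...), torch.sin(...) and surrounding product creates another (seq_len, seq_len) float32 tensor, 400 MB at seq_len = 10000, and when autograd is recording, several of those per iteration stay alive until backward. So it matters a great deal whether this runs during training or under torch.no_grad().

Below is a minimal sketch of one way to bound the peak, not a drop-in replacement. It computes the output in column blocks, builds t = j - i on the fly instead of storing the 10000×10000 tensors, and optionally wraps each block in torch.utils.checkpoint.checkpoint so that training recomputes the block during backward instead of keeping its intermediates. The names _block and chunked_dynamics and the chunk size are hypothetical. c, d, omega and phi are assumed to be (seq_len, NUM_OF_IMUS) tensors holding the question's per-IMU values with the singleton middle dimension dropped, with omega computed for all IMUs at once as torch.sqrt(4 * k - d**2) / 2.

```python
import torch
from torch.utils.checkpoint import checkpoint


def _block(c, d, omega, phi, start, stop):
    # Contribution of every start step i to output columns j in [start, stop).
    seq_len = c.shape[0]
    i = torch.arange(seq_len, device=c.device)
    j = torch.arange(start, stop, device=c.device)
    t = (j[None, :] - i[:, None]).to(c.dtype)   # t[i, j] = j - i, like t_step_init_mat
    mask = (t >= 0).to(c.dtype)                  # like MASK: only i <= j contributes
    t = t.clamp(min=0)[..., None]                # (seq_len, chunk, 1)
    terms = (c[:, None, :]
             * torch.exp(-d[:, None, :] / 2 * t)
             * torch.sin(phi[:, None, :] + omega[:, None, :] * t))
    return (terms * mask[..., None]).sum(dim=0)  # (chunk, n_imus)


def chunked_dynamics(c, d, omega, phi, chunk=256, use_checkpoint=True):
    # c, d, omega, phi: (seq_len, n_imus) tensors; returns (n_imus, seq_len).
    seq_len = c.shape[0]
    outs = []
    for start in range(0, seq_len, chunk):
        stop = min(start + chunk, seq_len)
        if use_checkpoint and torch.is_grad_enabled():
            # Store only this block's inputs; recompute its intermediates in backward.
            outs.append(checkpoint(_block, c, d, omega, phi, start, stop,
                                   use_reentrant=False))
        else:
            outs.append(_block(c, d, omega, phi, start, stop))
    return torch.cat(outs, dim=0).T
```

The linear and angular terms would each be one call, e.g. chunked_dynamics(c.squeeze(1), d.squeeze(1), omega, phi.squeeze(1)) plus the same with the _theta tensors, followed by adding min_orig_accel_norm and noise_bias as before. With chunk=256 and seq_len = 10000, each temporary is 10000 × 256 × 13 × 4 B ≈ 133 MB for all IMUs together; checkpointing pays for that with roughly one extra forward pass of each block during backward, so time it on your own workload.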