I'm trying to implement a regularization term for the loss function of a neural network.
from torch import nn
import torch
import numpy as np
reg_sig = torch.randn([32, 9, 5])
reg_adj = torch.randn([32, 9, 9, 4])
Maug = reg_adj.shape[0]
n_node = 9
n_bond_features = 4
n_atom_features = 5
SM_f = nn.Softmax(dim=2)
SM_W = nn.Softmax(dim=3)
p_f = SM_f(reg_sig)
p_W = SM_W(reg_adj)
Sig = nn.Sigmoid()
q = 1 - p_f[:, :, 4]
A = 1 - p_W[:, :, :, 0]
A_0 = torch.eye(n_node)
A_0 = A_0.reshape((1, n_node, n_node))
A_i = A
B = A_0.repeat(reg_sig.size(0), 1, 1)
for i in range(1, n_node):
    A_i = Sig(100 * (torch.bmm(A_i, A) - 0.5))
    B += A_i
C = Sig(100 * (B - 0.5))
reg_g_ij = torch.randn([reg_sig.size(0), n_node, n_node])
for i in range(n_node):
    for j in range(n_node):
        reg_g_ij[:, i, j] = q[:, i] * q[:, j] * (1 - C[:, i, j]) + (1 - q[:, i] * q[:, j]) * C[:, i, j]
I believe that my implementation is not computationally efficient and would like some suggestions on which parts I can change. Specifically, I would like to get rid of the loops and do them using matrix operations if possible. Any suggestions, working examples, or links to useful torch functions would be appreciated.
1 Answer
I don't have many improvements to offer, just one major one. Like you suspected, your implementation is not efficient. This is because filling a Torch/NumPy array element by element with a double for loop is not the preferred way to do this kind of reduction. What is preferred is the use of torch.einsum: it takes an index equation and reduces the input tensors into the final result.
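As a quick illustration of how einsum's index notation works (a minimal sketch with made-up tensors, not part of the original code; any index absent from the output is summed over):

import torch

a = torch.randn(32, 9)        # (batch, n_node)
b = torch.randn(32, 9, 9)     # (batch, n_node, n_node)
# Batched outer product: out[n, i, j] = a[n, i] * a[n, j]
outer = torch.einsum('ni,nj->nij', a, a)
# Batched matrix multiply, equivalent to torch.bmm(b, b):
# out[n, i, k] = sum_j b[n, i, j] * b[n, j, k]
bmm_like = torch.einsum('nij,njk->nik', b, b)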
The first thing to note is that your equation for reg_g_ij is not in its most simplified form.
In your code, we start with:
q_i * q_j * (1 - C_ij) + (1 - q_i * q_j) * C_ij
But it can be reduced to:
q_i * q_j * (1 - 2 * C_ij) + C_ij
You can prove it yourself with a few lines of algebra: expanding the first form gives q_i * q_j - q_i * q_j * C_ij + C_ij - q_i * q_j * C_ij, and collecting terms yields q_i * q_j * (1 - 2 * C_ij) + C_ij.
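As a quick numerical sanity check (a small sketch reusing the q and C tensors computed in the question's code), the two forms agree elementwise:

# qq[n, i, j] = q[n, i] * q[n, j], built by broadcasting
qq = q.unsqueeze(2) * q.unsqueeze(1)
original_form = qq * (1 - C) + (1 - qq) * C
simplified_form = qq * (1 - 2 * C) + C
print(torch.allclose(original_form, simplified_form, atol=1e-6))  # True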
The last small thing is to call .unsqueeze(0) when you're expanding the dimensions of a tensor. In this case it expands the identity matrix from size (9, 9) to (1, 9, 9).
A_0 = torch.eye(n_node).unsqueeze(0)
A_i = A
B = A_0.repeat(reg_sig.size(0), 1, 1)
for i in range(1, n_node):
    A_i = Sig(100 * (torch.bmm(A_i, A) - 0.5))
    B += A_i
C = Sig(100 * (B - 0.5))
reg_g_ij = torch.einsum('ij,ik,ijk->ijk', q, q, 1 - 2 * C) + C
When profiling this approach, we see a pretty big reduction in time:
In [257]: %timeit new(reg_sig, reg_adj)
1000 loops, best of 5: 745 μs per loop
In [258]: %timeit orig(reg_sig, reg_adj)
The slowest run took 4.85 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 5.44 ms per loop
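For reference, the names new and orig in the timings above are assumed to wrap the two versions roughly as below (a sketch of the benchmark harness, not code from the answer; the shared setup helper _build_q_C is hypothetical):

import torch
from torch import nn

def _build_q_C(reg_sig, reg_adj):
    # Hypothetical shared setup: softmaxes, q, A, and the smoothed connectivity matrix C.
    q = 1 - nn.Softmax(dim=2)(reg_sig)[:, :, 4]
    A = 1 - nn.Softmax(dim=3)(reg_adj)[:, :, :, 0]
    Sig = nn.Sigmoid()
    n_node = reg_sig.size(1)
    A_i = A
    B = torch.eye(n_node).unsqueeze(0).repeat(reg_sig.size(0), 1, 1)
    for _ in range(1, n_node):
        A_i = Sig(100 * (torch.bmm(A_i, A) - 0.5))
        B += A_i
    C = Sig(100 * (B - 0.5))
    return q, C

def orig(reg_sig, reg_adj):
    # Double-for-loop version from the question.
    q, C = _build_q_C(reg_sig, reg_adj)
    n_node = reg_sig.size(1)
    reg_g_ij = torch.empty(reg_sig.size(0), n_node, n_node)
    for i in range(n_node):
        for j in range(n_node):
            reg_g_ij[:, i, j] = q[:, i] * q[:, j] * (1 - C[:, i, j]) + (1 - q[:, i] * q[:, j]) * C[:, i, j]
    return reg_g_ij

def new(reg_sig, reg_adj):
    # Vectorized einsum version from the answer.
    q, C = _build_q_C(reg_sig, reg_adj)
    return torch.einsum('ij,ik,ijk->ijk', q, q, 1 - 2 * C) + C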
Thanks for the suggestions, I got it down to this before seeing the post:
temp = torch.einsum('bi, bj->bij', q, q)
and then
reg_g_ij = torch.einsum('bij, bij->bij', temp, 1 - C) + torch.einsum('bij, bij->bij', 1 - temp, C)
, but I missed the algebra. – Blade, Jan 4, 2020 at 1:11