How does ot.dist work with sequences? #524
Hello,
I treat OT as a black box, so this may be a naive question.
I'm following the Wasserstein 2 Minibatch GAN with PyTorch example to train my own model, but I get an error. My inputs and outputs are sequences. Here is my code:
```python
import torch
import torch.nn as nn
import ot

# model, device, batch_size and the input/target tensors are defined elsewhere
ab = (torch.ones(batch_size) / batch_size).to(device)
sgd = torch.optim.Adam(model.parameters(), lr=0.001)
CE_loss = nn.CrossEntropyLoss(ignore_index=41)
for epoch in range(1000):
    logits, c_emb, t_emb = model(phonetic, linguistic, transcript)
    # print(logits.shape)  # batch x classes x time
    # print(c_emb.shape)   # batch x time x features
    # print(t_emb.shape)   # batch x time x features
    M = ot.dist(c_emb, t_emb)
    loss_W = ot.emd2(ab, ab, M).to(device)
    loss_CE = CE_loss(logits, output)
    loss = loss_W + loss_CE
    loss.backward()
    sgd.step()
    sgd.zero_grad()
```
The error:
```
    M = ot.dist(c_emb, t_emb)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 307, in dist
    return euclidean_distances(x1, x2, squared=True)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 253, in euclidean_distances
    a2 = nx.einsum('ij,ij->i', X, X)
  File "/opt/conda/lib/python3.8/site-packages/ot/backend.py", line 1897, in einsum
    return torch.einsum(subscripts, *operands)
  File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
```
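A minimal reproduction of the shape mismatch, using only `torch.einsum` with the same subscripts that, according to the traceback, POT's `euclidean_distances` uses internally. The tensor shapes are illustrative assumptions. The flattening step at the end is one way to obtain a valid 2-D input, and it assumes every sequence is padded to the same length so that time steps line up.

```python
import torch

# The traceback shows ot.dist computing per-row squared norms with
# einsum('ij,ij->i', X, X), so each input must be 2-D: (n_samples, n_features).
x_2d = torch.randn(4, 8)       # 4 samples, 8 features each
x_3d = torch.randn(4, 10, 8)   # batch x time x features, like c_emb

# Works: one squared norm per sample
norms_2d = torch.einsum('ij,ij->i', x_2d, x_2d)
print(norms_2d.shape)  # torch.Size([4])

# Fails: 'ij' names two dimensions but the operand has three
try:
    torch.einsum('ij,ij->i', x_3d, x_3d)
except RuntimeError as err:
    print("RuntimeError:", err)

# Flattening time and features into one vector per sequence gives a 2-D input
x_flat = x_3d.reshape(x_3d.shape[0], -1)   # shape (4, 80)
norms_flat = torch.einsum('ij,ij->i', x_flat, x_flat)
print(norms_flat.shape)  # torch.Size([4])
```

Passing such flattened tensors to `ot.dist` compares sequences position by position; whether that is a meaningful cost depends on the model.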
How can I use ot.dist with sequences correctly?
Replies: 1 comment
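You seem to have a shape problem here. ot.dist works only between samples in vector format, following the SciPy cdist API, so both inputs must be 2-D arrays of shape (n_samples, n_features). Your c_emb and t_emb are 3-D (batch x time x features), which is why the einsum 'ij,ij->i' inside euclidean_distances fails. For sequences, you should write your own distance function in pure PyTorch so that gradients can backpropagate through it. The function should return an n x m matrix, where n and m are the numbers of sequences in c_emb and t_emb, respectively.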
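A minimal sketch of such a distance function, under stated assumptions: `sequence_sq_distances` is a hypothetical helper (not part of POT), both batches are padded to the same length T, and sequences are compared step by step, so that M[i, j] = sum over t of ||x[i, t] - y[j, t]||^2. It uses only differentiable torch operations and returns the n x m cost matrix described above.

```python
import torch


def sequence_sq_distances(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise squared distances between two batches of sequences.

    Hypothetical helper: x has shape (n, T, d), y has shape (m, T, d).
    Returns an (n, m) matrix with M[i, j] = sum_t ||x[i, t] - y[j, t]||^2.
    """
    if x.dim() != 3 or y.dim() != 3 or x.shape[1:] != y.shape[1:]:
        raise ValueError("expected tensors of shape (n, T, d) and (m, T, d)")
    # Broadcast to (n, m, T, d), then sum squared differences over time and features
    diff = x[:, None, :, :] - y[None, :, :, :]
    return (diff ** 2).sum(dim=(2, 3))


torch.manual_seed(0)
c_emb = torch.randn(5, 10, 8, requires_grad=True)  # batch x time x features
t_emb = torch.randn(5, 10, 8)
M = sequence_sq_distances(c_emb, t_emb)
print(M.shape)  # torch.Size([5, 5])

M.sum().backward()  # gradients reach c_emb, so the cost is trainable
print(c_emb.grad.shape)  # torch.Size([5, 10, 8])
```

In the training loop from the question, `M = ot.dist(c_emb, t_emb)` would become `M = sequence_sq_distances(c_emb, t_emb)`, with `ot.emd2(ab, ab, M)` kept as is; POT's PyTorch backend is designed to let the gradient of `ot.emd2` flow back through `M`. Padded time steps still contribute to this cost, so masking or pooling over time may be needed for variable-length sequences. For long sequences the (n, m, T, d) intermediate can be large; expanding ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b on the flattened sequences avoids it.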
I'm converting this to a discussion since it does not appear to be a bug in POT.