
How does ot.dist work with sequences? #524

Unanswered
huutuongtu asked this question in Q&A

Hello,

I consider OT a black box, so I may ask something stupid.

I'm following the Wasserstein 2 Minibatch GAN with PyTorch example to train my own model, but I get an error. My inputs and outputs are sequences. Here is my code:

import torch
import torch.nn as nn
import ot

ab = (torch.ones(batch_size) / batch_size).to(device)
sgd = torch.optim.Adam(model.parameters(), lr=0.001)
CE_loss = nn.CrossEntropyLoss(ignore_index=41)
for epoch in range(1000):
    logits, c_emb, t_emb = model(phonetic, linguistic, transcript)
    # print(logits.shape) #batch x classes x time
    # print(c_emb.shape) #batch x time x features
    # print(t_emb.shape) #batch x time x features
    M = ot.dist(c_emb, t_emb)
    loss_W = ot.emd2(ab, ab, M).to(device)
    loss_CE = CE_loss(logits, output)
    loss = loss_W + loss_CE
    loss.backward()
    sgd.step()
    sgd.zero_grad()

The error:

    M = ot.dist(c_emb, t_emb)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 307, in dist
    return euclidean_distances(x1, x2, squared=True)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 253, in euclidean_distances
    a2 = nx.einsum('ij,ij->i', X, X)
  File "/opt/conda/lib/python3.8/site-packages/ot/backend.py", line 1897, in einsum
    return torch.einsum(subscripts, *operands)
  File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
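The traceback shows ot.dist calling euclidean_distances, which computes squared row norms with nx.einsum('ij,ij->i', X, X). That subscript pattern only fits 2D arrays. The following minimal PyTorch sketch, using random tensors with assumed sizes, reproduces the failure for a batch x time x features tensor and shows that the same call works once each sequence is flattened into a single vector:

```python
import torch

# Assumed sizes, mirroring the shape comments in the question: batch x time x features
batch, time, features = 4, 7, 16
c_emb = torch.randn(batch, time, features)

# Same einsum pattern as in the traceback: it expects 2D (samples x dims) operands
try:
    torch.einsum("ij,ij->i", c_emb, c_emb)
    raised = False
except RuntimeError as err:
    raised = True
    print("3D input fails:", err)

# Flattening each sequence into one vector makes the same call valid
flat = c_emb.reshape(batch, -1)                  # batch x (time * features)
sq_norms = torch.einsum("ij,ij->i", flat, flat)  # one squared norm per sequence
print(sq_norms.shape)                            # torch.Size([4])
```

Flattening only makes sense if all sequences share the same length and their time steps are meaningfully aligned; otherwise a dedicated sequence distance is needed.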

How can I use ot.dist with sequences correctly?


Replies: 1 comment


You seem to have a size problem here: ot.dist expects 2D inputs, but c_emb and t_emb are 3D (batch x time x features). For sequences, you should write your own distance function in pure PyTorch so that gradients can fully backpropagate. The function should return an n x m matrix, where n and m are the numbers of sequences in c_emb and t_emb, respectively. ot.dist only works between samples given in vector format (2D arrays of shape n x d), following the scipy cdist API.
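A minimal sketch of such a function in pure PyTorch, under the assumption that c_emb and t_emb contain sequences of the same length whose time steps can be compared position by position. The name sequence_sq_dist and the choice of distance (sum over time steps of squared Euclidean distances between frames) are illustrative, not part of POT:

```python
import torch


def sequence_sq_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise squared distances between two batches of sequences.

    x: (n, T, d), y: (m, T, d). Assumes equal length T and position-wise alignment.
    Returns an (n, m) cost matrix with M[i, j] = sum_t ||x[i, t] - y[j, t]||^2.
    """
    if x.dim() != 3 or y.dim() != 3 or x.shape[1:] != y.shape[1:]:
        raise ValueError("expected (n, T, d) and (m, T, d) tensors with matching T and d")
    diff = x[:, None, :, :] - y[None, :, :, :]  # (n, m, T, d) via broadcasting
    return diff.pow(2).sum(dim=(2, 3))          # (n, m), differentiable w.r.t. x and y


# Random stand-ins for the model outputs (batch x time x features)
c_emb = torch.randn(4, 7, 16, requires_grad=True)
t_emb = torch.randn(4, 7, 16, requires_grad=True)
M = sequence_sq_dist(c_emb, t_emb)
print(M.shape)  # torch.Size([4, 4])
M.sum().backward()  # gradients reach both embeddings
```

With equal-length, aligned sequences this equals the squared Euclidean distance between the flattened sequences, i.e. what ot.dist(c_emb.reshape(n, -1), t_emb.reshape(m, -1)) computes with its default 'sqeuclidean' metric. The resulting M can then be passed to ot.emd2(ab, ab, M) as in the question's loop. Sequences of different lengths would need a different per-pair distance.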

I'm converting this to a discussion since it does not seem to be a bug in POT.

This discussion was converted from issue #522 on September 19, 2023 07:07.
