Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bf86ea9

Browse files
committed
[Fix] device error
1 parent eddffdf commit bf86ea9

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

‎ffmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self, dims: List[int], device: str) -> None:
8282
) for i in range(len(dims)-1)
8383
]
8484

85-
self.classify_layer = torch.nn.Linear(dims[-1], 10)
85+
self.classify_layer = torch.nn.Linear(dims[-1], 10).to(device)
8686
self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
8787
self.optimizer = torch.optim.AdamW(self.classify_layer.parameters(), lr=0.01)
8888
self.softmax = torch.nn.Softmax(dim=1)

‎train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def main() -> None:
3636
loss_logger.reset()
3737

3838
for inputs, labels in pbar:
39-
break
4039
pos_inputs = create_pos_data(inputs, labels).to(device)
4140
neg_inputs = create_neg_data(inputs, labels).to(device)
41+
labels = labels.to(device)
4242

4343
loss = model(pos_inputs=pos_inputs, neg_inputs=neg_inputs, pos_labels=labels)
4444
loss_logger.update(loss, inputs.shape[0])

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /