We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bf86ea9 commit 305a910Copy full SHA for 305a910
train.py
@@ -17,7 +17,7 @@ def main() -> None:
17
18
# DataLoader
19
train_dataloader = get_mnist_dataloader(_mode="train", batch_size=batch_size)
20
- val_dataloader = get_mnist_dataloader(_mode="val", batch_size=1)
+ val_dataloader = get_mnist_dataloader(_mode="val", batch_size=16)
21
22
# Device
23
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -53,9 +53,9 @@ def main() -> None:
53
predicts = []
54
targets = []
55
for inputs, labels in tqdm(val_dataloader):
56
- inputs_all_labels = create_test_data(inputs).to(device)
+ # inputs_all_labels = create_test_data(inputs).to(device)
57
58
- predict = model.predict(inputs_all_labels)
+ predict = model.predict(inputs)
59
60
predicts.extend(predict.tolist())
61
targets.append(labels.item())
AltStyle によって変換されたページ (->オリジナル) / アドレス: モード: デフォルト 音声ブラウザ ルビ付き 配色反転 文字拡大 モバイル
0 commit comments