Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0958531

Browse files
committed
[Fix] Validation
1 parent 5faddba commit 0958531

File tree

3 files changed

+30
-24
lines changed

3 files changed

+30
-24
lines changed

‎ffmodel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,17 @@ def forward(
6868

6969
return pos_outputs.detach(), neg_outputs.detach(), loss.detach()
7070

71-
71+
72+
# TODO: This replemented cannot convert into DEVICE we set...
7273
class FFClassifier(torch.nn.Module):
73-
def __init__(self, dims: List[int]) -> None:
74+
def __init__(self, dims: List[int], device: str) -> None:
7475
super().__init__()
7576
self.layers = [
7677
FFLinear(
7778
in_features=dims[i],
7879
out_features=dims[i+1],
7980
lr=0.01,
81+
device=device,
8082
) for i in range(len(dims)-1)
8183
]
8284

‎test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@
1111

1212
def main() -> None:
1313
# DataLoader
14-
# train_dataloader = get_mnist_dataloader(_mode="train", batch_size=16)
15-
# val_dataloader = get_mnist_dataloader(_mode="val", batch_size=16)
1614
test_dataloader = get_mnist_dataloader(_mode="test", batch_size=1)
15+
16+
# Device
17+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1718

1819
# FFModel
19-
# model = FFClassifier([28*28, 2000, 2000, 2000, 2000])
20-
model = torch.load("./models/epoch80.ckpt")
20+
model = torch.load("./models/epoch31.ckpt").eval()
2121

2222
# Evaluation
2323
predicts = []
2424
targets = []
2525
for inputs, labels in tqdm(test_dataloader):
26-
inputs_all_labels = create_test_data(inputs)
26+
inputs_all_labels = create_test_data(inputs).to(device)
2727

2828
predict = model.predict(inputs_all_labels)
2929
predicts.append(predict.item())

‎train.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# coding: utf-8
2+
from sklearn import metrics
23
import torch
34
from tqdm import tqdm
45

56
from dataloader import get_mnist_dataloader
67
from ffmodel import FFClassifier
7-
from utils import AverageMeter, create_pos_data, create_neg_data
8+
from utils import AverageMeter, create_pos_data, create_neg_data, create_test_data
89

910
torch.manual_seed(2999)
1011

@@ -16,12 +17,15 @@ def main() -> None:
1617

1718
# DataLoader
1819
train_dataloader = get_mnist_dataloader(_mode="train", batch_size=batch_size)
19-
val_dataloader = get_mnist_dataloader(_mode="val", batch_size=batch_size)
20-
test_dataloader = get_mnist_dataloader(_mode="test", batch_size=batch_size)
20+
val_dataloader = get_mnist_dataloader(_mode="val", batch_size=1)
21+
22+
# Device
23+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
2124

2225
# FFModel
23-
model = FFClassifier([28*28, 2000, 2000, 2000, 2000])
24-
26+
model = FFClassifier([28*28, 2000, 2000, 2000, 2000], device=device)
27+
torch.compile(model)
28+
2529
# Loss Logger
2630
loss_logger = AverageMeter()
2731

@@ -32,8 +36,8 @@ def main() -> None:
3236
loss_logger.reset()
3337

3438
for inputs, labels in pbar:
35-
pos_inputs = create_pos_data(inputs, labels)
36-
neg_inputs = create_neg_data(inputs, labels)
39+
pos_inputs = create_pos_data(inputs, labels).to(device)
40+
neg_inputs = create_neg_data(inputs, labels).to(device)
3741

3842
loss = model(pos_inputs, neg_inputs)
3943
loss_logger.update(loss, inputs.shape[0])
@@ -43,20 +47,20 @@ def main() -> None:
4347

4448
# Validation
4549
model.eval()
46-
pbar = tqdm(val_dataloader, desc=f"Valid - Epoch [{epoch}/{num_epochs}] Loss: {loss_logger.avg:.4f}")
47-
loss_logger.reset()
4850

49-
for inputs, labels in pbar:
50-
pos_inputs = create_pos_data(inputs, labels)
51-
neg_inputs = create_neg_data(inputs, labels)
52-
53-
with torch.no_grad():
54-
loss = model(pos_inputs, neg_inputs, train_mode=False)
51+
# Evaluation
52+
predicts = []
53+
targets = []
54+
for inputs, labels in tqdm(val_dataloader):
55+
inputs_all_labels = create_test_data(inputs).to(device)
5556

56-
loss_logger.update(loss, inputs.shape[0])
57-
pbar.set_description(f"Valid - Epoch [{epoch}/{num_epochs}] Loss: {loss_logger.avg:.4f}")
57+
predict = model.predict(inputs_all_labels)
58+
predicts.append(predict.item())
59+
targets.append(labels.item())
5860

61+
print(metrics.classification_report(targets, predicts))
5962
print()
6063

64+
6165
if __name__ == "__main__":
6266
main()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /