1
1
# coding: utf-8
2
+ from sklearn import metrics
2
3
import torch
3
4
from tqdm import tqdm
4
5
5
6
from dataloader import get_mnist_dataloader
6
7
from ffmodel import FFClassifier
7
- from utils import AverageMeter , create_pos_data , create_neg_data
8
+ from utils import AverageMeter , create_pos_data , create_neg_data , create_test_data
8
9
9
10
torch .manual_seed (2999 )
10
11
@@ -16,12 +17,15 @@ def main() -> None:
16
17
17
18
# DataLoader
18
19
train_dataloader = get_mnist_dataloader (_mode = "train" , batch_size = batch_size )
19
- val_dataloader = get_mnist_dataloader (_mode = "val" , batch_size = batch_size )
20
- test_dataloader = get_mnist_dataloader (_mode = "test" , batch_size = batch_size )
20
+ val_dataloader = get_mnist_dataloader (_mode = "val" , batch_size = 1 )
21
+
22
+ # Device
23
+ device = "cuda:0" if torch .cuda .is_available () else "cpu"
21
24
22
25
# FFModel
23
- model = FFClassifier ([28 * 28 , 2000 , 2000 , 2000 , 2000 ])
24
-
26
+ model = FFClassifier ([28 * 28 , 2000 , 2000 , 2000 , 2000 ], device = device )
27
+ torch .compile (model )
28
+
25
29
# Loss Logger
26
30
loss_logger = AverageMeter ()
27
31
@@ -32,8 +36,8 @@ def main() -> None:
32
36
loss_logger .reset ()
33
37
34
38
for inputs , labels in pbar :
35
- pos_inputs = create_pos_data (inputs , labels )
36
- neg_inputs = create_neg_data (inputs , labels )
39
+ pos_inputs = create_pos_data (inputs , labels ). to ( device )
40
+ neg_inputs = create_neg_data (inputs , labels ). to ( device )
37
41
38
42
loss = model (pos_inputs , neg_inputs )
39
43
loss_logger .update (loss , inputs .shape [0 ])
@@ -43,20 +47,20 @@ def main() -> None:
43
47
44
48
# Validation
45
49
model .eval ()
46
- pbar = tqdm (val_dataloader , desc = f"Valid - Epoch [{ epoch } /{ num_epochs } ] Loss: { loss_logger .avg :.4f} " )
47
- loss_logger .reset ()
48
50
49
- for inputs , labels in pbar :
50
- pos_inputs = create_pos_data (inputs , labels )
51
- neg_inputs = create_neg_data (inputs , labels )
52
-
53
- with torch .no_grad ():
54
- loss = model (pos_inputs , neg_inputs , train_mode = False )
51
+ # Evaluation
52
+ predicts = []
53
+ targets = []
54
+ for inputs , labels in tqdm (val_dataloader ):
55
+ inputs_all_labels = create_test_data (inputs ).to (device )
55
56
56
- loss_logger .update (loss , inputs .shape [0 ])
57
- pbar .set_description (f"Valid - Epoch [{ epoch } /{ num_epochs } ] Loss: { loss_logger .avg :.4f} " )
57
+ predict = model .predict (inputs_all_labels )
58
+ predicts .append (predict .item ())
59
+ targets .append (labels .item ())
58
60
61
+ print (metrics .classification_report (targets , predicts ))
59
62
print ()
60
63
64
+
61
65
if __name__ == "__main__" :
62
66
main ()
0 commit comments