@@ -81,21 +81,42 @@ def __init__(self, dims: List[int], device: str) -> None:
                 device=device,
             ) for i in range(len(dims) - 1)
         ]
+
+        self.classify_layer = torch.nn.Linear(dims[-1], 10)
+        self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
+        self.optimizer = torch.optim.AdamW(self.classify_layer.parameters(), lr=0.01)
+        self.softmax = torch.nn.Softmax(dim=1)
 
-    def forward(self, pos_inputs: torch.Tensor, neg_inputs: torch.Tensor, train_mode: bool = True) -> torch.Tensor:
+    def forward(
+        self,
+        pos_inputs: torch.Tensor,
+        neg_inputs: torch.Tensor,
+        pos_labels: torch.Tensor,
+        train_mode: bool = True,
+    ) -> torch.Tensor:
         total_loss = 0.0
+
+        # Forward layers
         for layer in self.layers:
             pos_inputs, neg_inputs, loss = layer(pos_inputs, neg_inputs, train_mode)
             total_loss += loss.item()
 
-        return total_loss
+        # Classifier layer (the last layer)
+        pos_outputs = self.classify_layer(pos_inputs)
+        pos_outputs = self.softmax(pos_outputs)
+        loss = self.criterion(pos_outputs, pos_labels)
+        loss.backward()
+        self.optimizer.step()
+
+        return total_loss + loss.item()
 
     @torch.no_grad()
     def predict(self, inputs: torch.Tensor, num_classes: int = 10) -> int:
         for layer in self.layers:
             inputs = layer.linear_transform(inputs)
 
-        goodness = inputs.pow(2).mean(1)
+        outputs = self.classify_layer(inputs)
+        outputs = self.softmax(outputs)
 
-        return torch.argmax(goodness)
+        return torch.argmax(outputs, dim=1)
 
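For reference, below is a minimal sketch of the conventional way to train a classifier head like the one this commit adds. Two PyTorch details are worth noting against the diff above: torch.nn.CrossEntropyLoss applies log-softmax to its input internally, so it is normally fed raw logits rather than Softmax outputs, and optimizer.zero_grad() is normally called before backward() so gradients do not accumulate across batches. The feature width (500), batch size, and standalone variable names are illustrative assumptions, not values from this commit.

    import torch

    # Standalone stand-ins for the attributes the commit adds.
    classify_layer = torch.nn.Linear(500, 10)  # assumes dims[-1] == 500
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(classify_layer.parameters(), lr=0.01)

    features = torch.randn(64, 500)        # stand-in for the forward-forward layer outputs
    labels = torch.randint(0, 10, (64,))   # stand-in for pos_labels

    optimizer.zero_grad()                  # clear gradients left over from the previous batch
    logits = classify_layer(features)      # raw logits; no explicit Softmax
    loss = criterion(logits, labels)       # log-softmax happens inside the criterion
    loss.backward()
    optimizer.step()

At inference time, the explicit Softmax in predict() is harmless: softmax is monotonic, so the argmax over probabilities equals the argmax over logits.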