Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit eddffdf

Browse files
committed
[Add] classify layer
1 parent 0958531 commit eddffdf

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

‎ffmodel.py‎

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,42 @@ def __init__(self, dims: List[int], device: str) -> None:
8181
device=device,
8282
) for i in range(len(dims)-1)
8383
]
84+
85+
self.classify_layer = torch.nn.Linear(dims[-1], 10)
86+
self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
87+
self.optimizer = torch.optim.AdamW(self.classify_layer.parameters(), lr=0.01)
88+
self.softmax = torch.nn.Softmax(dim=1)
8489

85-
def forward(self, pos_inputs: torch.Tensor, neg_inputs: torch.Tensor, train_mode: bool = True) -> torch.Tensor:
90+
def forward(
91+
self,
92+
pos_inputs: torch.Tensor,
93+
neg_inputs: torch.Tensor,
94+
pos_labels: torch.Tensor,
95+
train_mode: bool = True,
96+
) -> torch.Tensor:
8697
total_loss = 0.0
98+
99+
# Forward layers
87100
for layer in self.layers:
88101
pos_inputs, neg_inputs, loss = layer(pos_inputs, neg_inputs, train_mode)
89102
total_loss += loss.item()
90103

91-
return total_loss
104+
# Classifier Layer (the last layer)
105+
pos_outputs = self.classify_layer(pos_inputs)
106+
pos_outputs = self.softmax(pos_outputs)
107+
loss = self.criterion(pos_outputs, pos_labels)
108+
loss.backward()
109+
self.optimizer.step()
110+
111+
return total_loss + loss.item()
92112

93113
@torch.no_grad()
94114
def predict(self, inputs: torch.Tensor, num_classes: int = 10) -> int:
95115
for layer in self.layers:
96116
inputs = layer.linear_transform(inputs)
97117

98-
goodness = inputs.pow(2).mean(1)
118+
outputs = self.classify_layer(inputs)
119+
outptus = self.softmax(outputs)
99120

100-
return torch.argmax(goodness)
121+
return torch.argmax(outputs, dim=1)
101122

‎train.py‎

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ def main() -> None:
3636
loss_logger.reset()
3737

3838
for inputs, labels in pbar:
39+
break
3940
pos_inputs = create_pos_data(inputs, labels).to(device)
4041
neg_inputs = create_neg_data(inputs, labels).to(device)
4142

42-
loss = model(pos_inputs, neg_inputs)
43+
loss = model(pos_inputs=pos_inputs, neg_inputs=neg_inputs, pos_labels=labels)
4344
loss_logger.update(loss, inputs.shape[0])
4445
pbar.set_description(f"Train - Epoch [{epoch}/{num_epochs}] Loss: {loss_logger.avg:.4f}")
4546

@@ -55,7 +56,8 @@ def main() -> None:
5556
inputs_all_labels = create_test_data(inputs).to(device)
5657

5758
predict = model.predict(inputs_all_labels)
58-
predicts.append(predict.item())
59+
60+
predicts.extend(predict.tolist())
5961
targets.append(labels.item())
6062

6163
print(metrics.classification_report(targets, predicts))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /