Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 6717487

Browse files
committed
[Add] CIFAR10 trainning
1 parent 85b2b31 commit 6717487

File tree

4 files changed

+177
-2
lines changed

4 files changed

+177
-2
lines changed

‎analysis.ipynb

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 3,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"Files already downloaded and verified\n",
13+
"Mean: tensor(0.4734)\n",
14+
"Std: tensor(0.2182)\n"
15+
]
16+
}
17+
],
18+
"source": [
19+
"from torchvision import datasets, transforms\n",
20+
"import torch\n",
21+
"\n",
22+
"# Load CIFAR-10 dataset\n",
23+
"cifar10_train = datasets.CIFAR10(root=\"CIFAR10\", train=True, download=True, transform=transforms.ToTensor())\n",
24+
"\n",
25+
"# Compute mean and std\n",
26+
"mean = 0.0\n",
27+
"var = 0.0\n",
28+
"num_pixels = 0\n",
29+
"for images, _ in cifar10_train:\n",
30+
" flattened_images = images.view(-1)\n",
31+
" mean += flattened_images.mean()\n",
32+
" var += flattened_images.var()\n",
33+
" num_pixels += flattened_images.numel()\n",
34+
"\n",
35+
"mean = mean / len(cifar10_train)\n",
36+
"std = torch.sqrt(var / len(cifar10_train))\n",
37+
"\n",
38+
"print(\"Mean:\", mean)\n",
39+
"print(\"Std:\", std)\n",
40+
"\n"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"metadata": {},
47+
"outputs": [],
48+
"source": []
49+
}
50+
],
51+
"metadata": {
52+
"interpreter": {
53+
"hash": "c6747903916dc58a7ba40bc6ae759cb749c0812e4b412aba9b66e06fcd59c242"
54+
},
55+
"kernelspec": {
56+
"display_name": "Python 3.8.17 ('venv': venv)",
57+
"language": "python",
58+
"name": "python3"
59+
},
60+
"language_info": {
61+
"codemirror_mode": {
62+
"name": "ipython",
63+
"version": 3
64+
},
65+
"file_extension": ".py",
66+
"mimetype": "text/x-python",
67+
"name": "python",
68+
"nbconvert_exporter": "python",
69+
"pygments_lexer": "ipython3",
70+
"version": "3.8.17"
71+
},
72+
"orig_nbformat": 4
73+
},
74+
"nbformat": 4,
75+
"nbformat_minor": 2
76+
}

‎dataloader.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from torch.utils.data import DataLoader, random_split
44
from torchvision import transforms
5-
from torchvision.datasets import MNIST
5+
from torchvision.datasets import MNIST, CIFAR10
66

77

88
def get_mnist_dataloader(_mode: str = "train", batch_size: int = 32) -> DataLoader:
@@ -33,3 +33,36 @@ def get_mnist_dataloader(_mode: str = "train", batch_size: int = 32) -> DataLoad
3333
_dataset = MNIST("", train=False, download=True, transform=transform)
3434

3535
return DataLoader(_dataset, batch_size=batch_size)
36+
37+
38+
def get_cifar10_dataloader(_mode: str = "train", batch_size: int = 32) -> DataLoader:
39+
mean = (0.4914, 0.4822, 0.4465)
40+
std = (0.247, 0.243, 0.261)
41+
42+
transform = transforms.Compose([
43+
transforms.ToTensor(),
44+
transforms.Normalize(mean=mean, std=std),
45+
transforms.Lambda(lambda x: torch.flatten(x)),
46+
])
47+
if _mode == "train":
48+
_dataset = CIFAR10("CIFAR10", train=True, download=True, transform=transform)
49+
_dataset, _ = random_split(
50+
_dataset,
51+
(
52+
int(len(_dataset) * 0.9),
53+
len(_dataset) - int(len(_dataset) * 0.9),
54+
),
55+
)
56+
elif _mode == "val":
57+
_dataset = CIFAR10("CIFAR10", train=True, download=True, transform=transform)
58+
_, _dataset = random_split(
59+
_dataset,
60+
(
61+
int(len(_dataset) * 0.9),
62+
len(_dataset) - int(len(_dataset) * 0.9),
63+
),
64+
)
65+
elif _mode == "test":
66+
_dataset = CIFAR10("CIFAR10", train=False, download=True, transform=transform)
67+
68+
return DataLoader(_dataset, batch_size=batch_size)

‎ffmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
self.relu = torch.nn.ReLU()
3232

3333
def linear_transform(self, inputs: torch.Tensor) -> torch.Tensor:
34-
# L2 Norm & smoothy TODO: why???
34+
# L2 Norm & smoothy
3535
inputs_l2_norm = inputs.norm(2, 1, keepdim=True) + 1e-4
3636

3737
# Normalization

‎train_cifar.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# coding: utf-8
2+
from sklearn import metrics
3+
import torch
4+
from tqdm import tqdm
5+
6+
from dataloader import get_cifar10_dataloader
7+
from ffmodel import FFClassifier
8+
from utils import AverageMeter, create_pos_data, create_neg_data, create_test_data
9+
10+
torch.manual_seed(2999)
11+
12+
13+
def main() -> None:
14+
# Settings
15+
num_epochs = 80
16+
batch_size = 64
17+
18+
# DataLoader
19+
train_dataloader = get_cifar10_dataloader(_mode="train", batch_size=batch_size)
20+
val_dataloader = get_cifar10_dataloader(_mode="val", batch_size=1)
21+
22+
# Device
23+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
24+
25+
# FFModel
26+
model = FFClassifier([32*32*3, 2000, 2000, 2000, 2000], device=device)
27+
torch.compile(model)
28+
29+
# Loss Logger
30+
loss_logger = AverageMeter()
31+
32+
# Train
33+
for epoch in range(1, num_epochs+1):
34+
model.train()
35+
pbar = tqdm(train_dataloader, desc=f"Train - Epoch [{epoch}/{num_epochs}] Loss: {loss_logger.avg:.4f}")
36+
loss_logger.reset()
37+
38+
for inputs, labels in pbar:
39+
pos_inputs = create_pos_data(inputs, labels).to(device)
40+
neg_inputs = create_neg_data(inputs, labels).to(device)
41+
42+
loss = model(pos_inputs=pos_inputs, neg_inputs=neg_inputs)
43+
loss_logger.update(loss, inputs.shape[0])
44+
pbar.set_description(f"Train - Epoch [{epoch}/{num_epochs}] Loss: {loss_logger.avg:.4f}")
45+
46+
torch.save(model, f"./models/epoch{epoch}.ckpt")
47+
48+
# Validation
49+
model.eval()
50+
51+
# Evaluation
52+
predicts = []
53+
targets = []
54+
for inputs, labels in tqdm(val_dataloader):
55+
inputs = create_test_data(inputs).to(device)
56+
predict = model.predict(inputs)
57+
58+
predicts.append(predict.item())
59+
targets.append(labels.item())
60+
61+
print(metrics.classification_report(targets, predicts))
62+
print()
63+
64+
65+
if __name__ == "__main__":
66+
main()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /