
Commit 7e96460

Initial commit
1 parent 9c2d093 commit 7e96460

10 files changed: +283 -2 lines

README.md

Lines changed: 39 additions & 2 deletions
Replaces the two-line stub README ("# pytorch-to-javascript" / "Convert PyTorch models into JavaScript models via ONNX.js or TensorFlow.js.") with:

## Run PyTorch models in the browser using ONNX.js

Run PyTorch models in the browser with JavaScript by first converting your PyTorch model into the ONNX format and then loading that ONNX model in your website or app using ONNX.js. In the video tutorial below, I take you through this process using the demo example of a handwritten digit recognition model trained on the MNIST dataset.
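For orientation, the conversion step boils down to a single call to `torch.onnx.export`. Here is a minimal, self-contained sketch using a hypothetical toy model (the repo's real script, convert_to_onnx.py, appears further down):

```python
# Minimal sketch of the PyTorch -> ONNX step, using a stand-in model
# (hypothetical; not this repo's Net).
import torch
import torch.nn as nn

model = nn.Linear(4, 2)          # stand-in for a trained model
model.eval()                     # export in inference mode
dummy_input = torch.zeros(1, 4)  # shape must match what forward() expects
torch.onnx.export(model, dummy_input, 'model.onnx')
```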
### Tutorial

[<img src="https://img.youtube.com/vi/IK7nBOLYzdE/hqdefault.jpg">](https://www.youtube.com/watch?v=IK7nBOLYzdE)

### Live Demo and Code Sandbox

* [Live demo](https://vgzep.csb.app/)
* [Code sandbox](https://codesandbox.io/s/pytorch-to-javascript-with-onnx-vgzep)
### The files in this repo (and a description of what they do)

```
├── demo
│   ├── test.html (a minimal test that the generated ONNX model works; uses ONNX.js to load and run it)
│   └── onnx_model.onnx (a copy of the generated ONNX model, used by the test code)
├── convert_to_onnx.py (converts the trained PyTorch model into an ONNX model)
├── inference_mnist_model.py (the PyTorch model definition (without the trained parameters) used by convert_to_onnx.py to generate the ONNX model)
├── inputs_batch_preview.png (a preview of a batch of augmented input data, generated by preview_dataset.py)
├── onnx_model.onnx (the ONNX model generated by convert_to_onnx.py)
├── preview_dataset.py (for testing out different types of data augmentation)
├── pytorch_model.pt (the trained PyTorch model parameters, generated by train_mnist_model.py and used by convert_to_onnx.py to generate the ONNX model)
└── train_mnist_model.py (trains the PyTorch model and saves the trained parameters as pytorch_model.pt)
```
### The benefits of running a model in the browser:
* Faster inference times with smaller models.
* Easy to host and scale (only static files).
* Offline support.
* User privacy (the data can stay on the device).

### The benefits of using a backend server:
* Faster load times (no model to download).
* Faster, more consistent inference times with larger models (can take advantage of GPUs and other accelerators).
* Model privacy (you don't have to share your model if you want to keep it private).

convert_to_onnx.py

Lines changed: 15 additions & 0 deletions
```python
import torch

from inference_mnist_model import Net


def main():
    pytorch_model = Net()
    pytorch_model.load_state_dict(torch.load('pytorch_model.pt'))
    pytorch_model.eval()
    # Net's forward() expects a flat 280x280 RGBA canvas buffer, so the
    # dummy input used to trace the graph has that same shape.
    dummy_input = torch.zeros(280 * 280 * 4)
    torch.onnx.export(pytorch_model, dummy_input, 'onnx_model.onnx', verbose=True)


if __name__ == '__main__':
    main()
```
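A quick way to sanity-check the exported file is to run it with onnxruntime. A minimal sketch, assuming onnxruntime and numpy are installed (neither is used elsewhere in this repo):

```python
# Hypothetical sanity check (not part of this repo): run the exported
# ONNX model on an all-zeros input and confirm the output shape.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('onnx_model.onnx')
input_name = sess.get_inputs()[0].name
(output,) = sess.run(None, {input_name: np.zeros(280 * 280 * 4, dtype=np.float32)})
print(output.shape)  # expected: (1, 10), one probability per digit class
```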

demo/onnx_model.onnx

4.58 MB
Binary file not shown.

demo/test.html

Lines changed: 16 additions & 0 deletions
```html
<html>
<body>
  <script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script>
  <script>
    async function test() {
      const sess = new onnx.InferenceSession()
      await sess.loadModel('./onnx_model.onnx')
      // The model expects a flat 280x280 RGBA buffer (all zeros here).
      const input = new onnx.Tensor(new Float32Array(280 * 280 * 4), 'float32', [280 * 280 * 4])
      const outputMap = await sess.run([input])
      const outputTensor = outputMap.values().next().value
      console.log(`Output tensor: ${outputTensor.data}`)
    }
    test()
  </script>
</body>
</html>
```
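Note: most browsers won't let `loadModel` fetch `./onnx_model.onnx` from a `file://` URL, so serve the demo directory over HTTP first (for example, with Python's built-in `http.server` module) and then open test.html from localhost.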

inference_mnist_model.py

Lines changed: 38 additions & 0 deletions
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# MNIST mean and standard deviation, used to normalize the inputs.
MEAN = 0.1307
STANDARD_DEVIATION = 0.3081


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # The input is a flat 280x280 RGBA canvas buffer from the browser.
        x = x.reshape(280, 280, 4)
        # Keep only the alpha channel (the drawn strokes).
        x = torch.narrow(x, dim=2, start=3, length=1)
        x = x.reshape(1, 1, 280, 280)
        # Downsample 280x280 -> 28x28, scale to [0, 1], then normalize.
        x = F.avg_pool2d(x, 10, stride=10)
        x = x / 255
        x = (x - MEAN) / STANDARD_DEVIATION

        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # softmax (not log_softmax) so the browser receives probabilities.
        output = F.softmax(x, dim=1)
        return output
```
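Unlike the training-time model in train_mnist_model.py, this Net accepts the flat 280x280 RGBA buffer that the browser canvas produces and ends in softmax, so the page receives probabilities directly. A minimal shape check, assuming this file is saved as inference_mnist_model.py:

```python
# Feed a zeroed flat RGBA canvas buffer through Net and check the output.
import torch
from inference_mnist_model import Net

model = Net()  # untrained weights; this only verifies tensor shapes
model.eval()
with torch.no_grad():
    probs = model(torch.zeros(280 * 280 * 4))
print(probs.shape)         # torch.Size([1, 10])
print(probs.sum().item())  # ~1.0, since forward() ends in softmax
```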

inputs_batch_preview.png

135 KB
(Image: preview grid of one augmented MNIST input batch; not shown.)

onnx_model.onnx

4.58 MB
Binary file not shown.

preview_dataset.py

Lines changed: 39 additions & 0 deletions
```python
import torch
import torchvision


def main():
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            'data', train=True, download=True,
            transform=torchvision.transforms.Compose([

                # Uncomment one of the transforms below to preview a
                # single augmentation in isolation:

                # torchvision.transforms.RandomAffine(
                #     degrees=30),

                # torchvision.transforms.RandomAffine(
                #     degrees=0, translate=(0.0, 0.5)),

                # torchvision.transforms.RandomAffine(
                #     degrees=0, translate=(0.5, 0.5)),

                # torchvision.transforms.RandomAffine(
                #     degrees=0, scale=(0.25, 1)),

                # torchvision.transforms.RandomAffine(
                #     degrees=0, shear=(-30, 30, -30, 30)),

                torchvision.transforms.RandomAffine(
                    degrees=30, translate=(0.5, 0.5), scale=(0.25, 1),
                    shear=(-30, 30, -30, 30)),

                torchvision.transforms.ToTensor(),
            ])),
        batch_size=800)
    inputs_batch, labels_batch = next(iter(train_loader))
    grid = torchvision.utils.make_grid(inputs_batch, nrow=40, pad_value=1)
    torchvision.utils.save_image(grid, 'inputs_batch_preview.png')


if __name__ == '__main__':
    main()
```
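Running `python preview_dataset.py` downloads MNIST into `data/` (if needed) and writes a 40-column grid of one 800-image augmented batch to inputs_batch_preview.png, so you can eyeball the effect of the currently enabled transforms.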

pytorch_model.pt

4.58 MB
Binary file not shown.

train_mnist_model.py

Lines changed: 136 additions & 0 deletions
```python
"""
This code is from PyTorch's MNIST example (with only a few changes):
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.Compose([
                           # Add random transformations to the image.
                           transforms.RandomAffine(
                               degrees=30, translate=(0.5, 0.5), scale=(0.25, 1),
                               shear=(-30, 30, -30, 30)),

                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    # Unlike the upstream example, the trained parameters are saved
    # unconditionally here (the --save-model flag is unused).
    torch.save(model.state_dict(), "pytorch_model.pt")


if __name__ == '__main__':
    main()
```
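Running `python train_mnist_model.py` downloads MNIST into `data/`, trains for 14 epochs by default, and writes the trained parameters to pytorch_model.pt, which convert_to_onnx.py then loads to produce the ONNX model.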
