
Commit 6eb696e

Add notebook and pretrained model
1 parent 07e682d

File tree

6 files changed (+390, -42 lines)

Mental rotation.ipynb

Lines changed: 335 additions & 0 deletions
Large diffs are not rendered by default.
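The notebook itself is not rendered in this view. As a rough sketch of the kind of query it presumably demonstrates, the snippet below loads the pretrained checkpoint added in this commit and reuses the `x_mu, _, r, _ = model(x, v)` call from `run-gqn.py`; the batch shapes (15 context views of 64x64 RGB frames, 7-dimensional transformed viewpoints) are assumptions based on the paper's Shepard-Metzler setup, with random tensors standing in for real data.

```python
import torch

# The checkpoint was written with torch.save(model, "model-final.pt"),
# so loading it unpickles the full module; the gqn package must be
# importable for this to succeed.
model = torch.load("model-final.pt", map_location="cpu")
model.eval()

# Stand-in batch: 1 scene, 15 views, 64x64 RGB images and 7-dim
# transformed viewpoints (assumed shapes, not taken from this diff).
x = torch.randn(1, 15, 3, 64, 64)
v = torch.randn(1, 15, 7)

with torch.no_grad():
    # Same unpacking as run-gqn.py: a reconstruction x_mu and a scene
    # representation r, with the two remaining outputs ignored.
    x_mu, _, r, _ = model(x, v)
print(x_mu.shape, r.shape)
```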

README.md

Lines changed: 16 additions & 11 deletions

@@ -10,22 +10,27 @@ described in the paper look at the article by [DeepMind](https://deepmind.com/bl
 The current implementation generalises to any of the datasets described
 in the paper. However, currently, *only the Shepard-Metzler dataset* has
 been implemented. To use this dataset you must download the [tf-records
-from DeepMind](https://github.com/deepmind/gqn-datasets) and convert them to PyTorch tensors.
+from DeepMind](https://github.com/deepmind/gqn-datasets) and convert them to PyTorch tensors,
+such as by using the [gqn_datasets_translator](https://github.com/l3robot/gqn_datasets_translator).
 
-## Implementation
+The model can be trained in full, in accordance with the paper, by running the
+script `run-gqn.py`.
 
-The implementation shown in this repository consists of the `tower`
-representation architecture along with the generative model that is
-similar to the one described in "Towards conceptual compression" by
-Gregor et al.
+## Implementation
 
-![](https://kevinzakka.github.io/assets/rnn/draw2.gif)
+The implementation shown in this repository consists of all of the
+representation architectures described in the paper along with the
+generative model that is similar to the one described in
+"Towards conceptual compression" by Gregor et al.
 
 Additionally, this repository also contains implementations of the **DRAW
 model and the ConvolutionalDRAW** model both described by Gregor et al.
-These serve as the basis for the generative model in the GQN.
 
-## Results
+## Contributing
+
+The best way to contribute to this project is to train the model as described
+in the paper (by running `run-gqn.py`) and submit a pull request with the
+fully trained model.
 
-Currently, the results are pending as the model is very computationally
-costly to train for the datasets described in the paper.
+Currently, the repository contains a model `model-final.pt` that has only
+been trained on a subset of the data.
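A minimal sketch of the resulting data pipeline, assuming the tf-records have already been converted into a `train` directory: the `ShepardMetzler` constructor and the `DataLoader` call mirror those in `shepardmetzler.py` and `run-gqn.py` below, while wiring `transform_viewpoint` in as the `target_transform` is an assumption.

```python
from torch.utils.data import DataLoader
from shepardmetzler import ShepardMetzler, transform_viewpoint

# Converted Shepard-Metzler scenes, e.g. produced by gqn_datasets_translator.
dataset = ShepardMetzler(root_dir="train", target_transform=transform_viewpoint)
loader = DataLoader(dataset, batch_size=36, shuffle=True)

# Each batch is an (images, viewpoints) pair, as consumed by run-gqn.py.
x, v = next(iter(loader))
```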

gqn/representation.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 
 
 class TowerRepresentation(nn.Module):
-    def __init__(self, n_channels=3, v_dim=7, r_dim=256, pool=True):
+    def __init__(self, n_channels, v_dim, r_dim=256, pool=True):
         """
         Network that generates a condensed representation
         vector from a joint input of image and viewpoint.
@@ -68,7 +68,7 @@ def forward(self, x, v):
 
 
 class PyramidRepresentation(nn.Module):
-    def __init__(self, n_channels=3, v_dim=7, r_dim=256, pool=True):
+    def __init__(self, n_channels, v_dim, r_dim=256):
         """
         Network that generates a condensed representation
         vector from a joint input of image and viewpoint.
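With the defaults removed, both representation networks now require the channel and viewpoint dimensions up front. A quick sketch using the values that were previously the defaults (3 image channels, 7-dimensional viewpoints); the `gqn.representation` import path follows the file location, and the input shapes are assumptions.

```python
import torch
from gqn.representation import TowerRepresentation, PyramidRepresentation

# The old defaults (n_channels=3, v_dim=7) must now be passed explicitly;
# note that PyramidRepresentation no longer takes a `pool` argument.
tower = TowerRepresentation(n_channels=3, v_dim=7, r_dim=256, pool=True)
pyramid = PyramidRepresentation(n_channels=3, v_dim=7, r_dim=256)

# Assumed shapes for a single image/viewpoint pair.
x = torch.randn(1, 3, 64, 64)
v = torch.randn(1, 7)
r = tower(x, v)
```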

model-final.pt

46.9 MB (binary file not shown)

run-gqn.py

Lines changed: 22 additions & 29 deletions

@@ -18,29 +18,15 @@
 from torchvision.utils import save_image
 
 from gqn import GenerativeQueryNetwork
-from shepardmetzler import ShepardMetzler, Scene
+from shepardmetzler import ShepardMetzler, Scene, transform_viewpoint
 
 cuda = torch.cuda.is_available()
 device = torch.device("cuda:0" if cuda else "cpu")
 
-def transform_viewpoint(v):
-    """
-    Transforms the viewpoint vector into a consistent
-    representation
-    """
-    w, z = torch.split(v, 3, dim=-1)
-    y, p = torch.split(z, 1, dim=-1)
-
-    # position, [yaw, pitch]
-    view_vector = [w, torch.cos(y), torch.sin(y), torch.cos(p), torch.sin(p)]
-    v_hat = torch.cat(view_vector, dim=-1)
-
-    return v_hat
-
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Generative Query Network on Shepard Metzler Example')
-    parser.add_argument('--epochs', type=int, default=10000, help='number of epochs to train (default: 10000)')
+    parser.add_argument('--gradient_steps', type=int, default=2*(10**6), help='number of gradient steps to run (default: 2 million)')
     parser.add_argument('--batch_size', type=int, default=36, help='size of batch (default: 36)')
     parser.add_argument('--data_dir', type=str, help='location of training data', default="train")
     parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
@@ -71,7 +57,13 @@ def transform_viewpoint(v):
     kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
 
-    for epoch in range(args.epochs):
+    # Number of gradient steps
+    s = 0
+    while True:
+        if s >= args.gradient_steps:
+            torch.save(model, "model-final.pt")
+            break
+
         for x, v in tqdm(loader):
             if args.fp16:
                 x, v = x.half(), v.half()
@@ -96,26 +88,27 @@ def transform_viewpoint(v):
 
             optimizer.step()
             optimizer.zero_grad()
+
+            s += 1
+
+            # Keep a checkpoint every 100,000 steps
+            if s % 100000 == 0:
+                torch.save(model, "model-{}.pt".format(s))
 
             with torch.no_grad():
-                print("Epoch: {} |ELBO\t{} |NLL\t{} |KL\t{}".format(epoch, elbo.item(), reconstruction.item(), kl_divergence.item()))
-
-                if epoch % 5 == 0:
-                    x, v = next(iter(loader))
-                    x, v = x.to(device), v.to(device)
+                print("|Steps: {}\t|NLL: {}\t|KL: {}\t|".format(s, reconstruction.item(), kl_divergence.item()))
 
-                    x_mu, _, r, _ = model(x, v)
+                x, v = next(iter(loader))
+                x, v = x.to(device), v.to(device)
 
-                    r = r.view(-1, 1, 16, 16)
+                x_mu, _, r, _ = model(x, v)
 
-                    save_image(r.float(), "representation-{}.jpg".format(epoch))
-                    save_image(x_mu.float(), "reconstruction-{}.jpg".format(epoch))
+                r = r.view(-1, 1, 16, 16)
 
-                if epoch % 10 == 0:
-                    torch.save(model, "model-{}.pt".format(epoch))
+                save_image(r.float(), "representation.jpg")
+                save_image(x_mu.float(), "reconstruction.jpg")
 
             # Anneal learning rate
-            s = epoch + 1
             mu = max(mu_f + (mu_i - mu_f)*(1 - s/(1.6 * 10**6)), mu_f)
             optimizer.lr = mu * math.sqrt(1 - 0.999**s)/(1 - 0.9**s)
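The two annealing lines at the end combine a linear decay of the base rate over 1.6 million steps with Adam's bias-correction factor. A small standalone sketch of the schedule follows; `mu_i` and `mu_f` are defined elsewhere in the script and do not appear in this diff, so the values below (5e-4 and 5e-5, the paper's) are assumptions.

```python
import math

mu_i, mu_f = 5e-4, 5e-5  # assumed; defined outside the lines shown above

def annealed_lr(s):
    # Linear decay from mu_i to mu_f over 1.6 million steps (then flat),
    # scaled by Adam's bias-correction term, mirroring the diff above.
    mu = max(mu_f + (mu_i - mu_f) * (1 - s / (1.6 * 10**6)), mu_f)
    return mu * math.sqrt(1 - 0.999**s) / (1 - 0.9**s)

for s in (1, 10**5, 10**6, 2 * 10**6):
    print("step {:>7}: lr = {:.2e}".format(s, annealed_lr(s)))
```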

shepardmetzler.py

Lines changed: 15 additions & 0 deletions

@@ -8,6 +8,21 @@
 Scene = collections.namedtuple('Scene', ['frames', 'cameras'])
 
 
+def transform_viewpoint(v):
+    """
+    Transforms the viewpoint vector into a consistent
+    representation
+    """
+    w, z = torch.split(v, 3, dim=-1)
+    y, p = torch.split(z, 1, dim=-1)
+
+    # position, [yaw, pitch]
+    view_vector = [w, torch.cos(y), torch.sin(y), torch.cos(p), torch.sin(p)]
+    v_hat = torch.cat(view_vector, dim=-1)
+
+    return v_hat
+
+
 class ShepardMetzler(Dataset):
     def __init__(self, root_dir, transform=None, target_transform=None):
         self.root_dir = root_dir
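The moved `transform_viewpoint` maps a raw 5-dimensional viewpoint, three position coordinates plus yaw and pitch, to the 7-dimensional form the model consumes, replacing each angle with its cosine and sine. A quick sanity check:

```python
import torch
from shepardmetzler import transform_viewpoint

# One raw viewpoint: (x, y, z, yaw, pitch), with both angles zero.
v = torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0]])

v_hat = transform_viewpoint(v)
print(v_hat.shape)  # torch.Size([1, 7])
print(v_hat)        # tensor([[1., 2., 3., 1., 0., 1., 0.]])
```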
