Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7aab770

Browse files
committed
Fix tfrecord-converter
1 parent 08bf31e commit 7aab770

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

‎run-gqn.py‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
6868

6969
def step(engine, batch):
70+
model.train()
71+
7072
x, v = batch
7173
x, v = x.to(device), v.to(device)
7274
x, v, x_q, v_q = partition(x, v)
@@ -101,7 +103,10 @@ def step(engine, batch):
101103
# Trainer and metrics
102104
trainer = Engine(step)
103105
metric_names = ["elbo", "kl", "sigma", "mu"]
104-
metrics = [RunningAverage(output_transform=lambda x: x[m]).attach(trainer, m) for m in metric_names]
106+
RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
107+
RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
108+
RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
109+
RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
105110
ProgressBar().attach(trainer, metric_names=metric_names)
106111

107112
# Model checkpointing
@@ -142,6 +147,7 @@ def save_images(engine):
142147

143148
@trainer.on(Events.EPOCH_COMPLETED)
144149
def validate(engine):
150+
model.eval()
145151
with torch.no_grad():
146152
x, v = next(iter(valid_loader))
147153
x, v = x.to(device), v.to(device)

‎scripts/data.sh‎

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
#!/usr/bin/env bash
22

3-
LOCATION=1ドル
4-
BATCH_SIZE=2ドル
3+
LOCATION=1ドル# example: /tmp/data
4+
BATCH_SIZE=2ドル# example: 64
55

66
echo "Downloading data"
77
gsutil -m cp -R gs://gqn-dataset/shepard_metzler_5_parts $LOCATION
88

9-
echo "Deleting small records"
10-
TRAIN_PATH="$LOCATION/shepard_metzler_5_parts/train"
11-
find "$TRAIN_PATH/*.tfrecord" -type f -size -10M | xargs rm# remove smaller than 10mb
9+
echo "Deleting small records"# less than 10MB
10+
DATA_PATH="$LOCATION/shepard_metzler_5_parts/**/*.tfrecord"
11+
find $DATA_PATH -type f -size -10M | xargs rm
1212

1313
echo "Converting data"
1414
python tfrecord-converter.py $LOCATION shepard_metzler_5_parts -b $BATCH_SIZE -m "train"
1515
echo "Training data: done"
1616
python tfrecord-converter.py $LOCATION shepard_metzler_5_parts -b $BATCH_SIZE -m "test"
17-
echo "Testing data: done"
17+
echo "Testing data: done"
18+
19+
echo "Removing original records"
20+
rm -rf "$LOCATION/shepard_metzler_5_parts/**/*.tfrecord"

‎scripts/gpu.sh‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ python ../run-gqn.py \
1414
--log_dir "../logs" \
1515
--data_parallel "True" \
1616
--batch_size 1 \
17-
--n_workers 6
17+
--workers 6

‎scripts/tfrecord-converter.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def process(record):
4646

4747
# Convert
4848
images = tf.map_fn(tf.image.decode_jpeg, tf.reshape(images, [-1]), **kwargs)
49-
images = tf.reshape(images, (-1, SEQ_DIM, 3, IMG_DIM, IMG_DIM))
49+
images = tf.reshape(images, (-1, SEQ_DIM, IMG_DIM, IMG_DIM, 3))
5050
poses = tf.reshape(poses, (-1, SEQ_DIM, POSE_DIM))
5151

5252
# Numpy conversion
@@ -64,8 +64,8 @@ def convert(record, batch_size):
6464
batch_process = lambda r: chunk(process(r), batch_size)
6565

6666
for i, batch in enumerate(batch_process(record)):
67-
path = os.path.join(path, "{0:}-{1:02}.pt.gz".format(basename, i))
68-
with gzip.open(path, 'wb') as f:
67+
p = os.path.join(path, "{0:}-{1:02}.pt.gz".format(basename, i))
68+
with gzip.open(p, 'wb') as f:
6969
torch.save(list(batch), f)
7070

7171
if __name__ == '__main__':
@@ -91,4 +91,4 @@ def convert(record, batch_size):
9191

9292
with mp.Pool(processes=mp.cpu_count()) as pool:
9393
f = partial(convert, batch_size=args.batch_size)
94-
pool.map(f, records)
94+
pool.map(f, records)

‎shepardmetzler.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def __getitem__(self, idx):
4949
images, viewpoints = list(zip(*data))
5050

5151
# (b, m, c, h, w)
52-
images = torch.FloatTensor(images)
52+
images = torch.FloatTensor(images)/255
53+
images = images.permute(0, 1, 4, 2, 3)
5354
if self.transform:
5455
images = self.transform(images)
5556

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /