@@ -46,7 +46,7 @@ def process(record):
4646
4747 # Convert 
4848 images  =  tf .map_fn (tf .image .decode_jpeg , tf .reshape (images , [- 1 ]), ** kwargs )
49-  images  =  tf .reshape (images , (- 1 , SEQ_DIM , 3 , IMG_DIM , IMG_DIM ))
49+  images  =  tf .reshape (images , (- 1 , SEQ_DIM , IMG_DIM , IMG_DIM , 3 ))
5050 poses  =  tf .reshape (poses , (- 1 , SEQ_DIM , POSE_DIM ))
5151
5252 # Numpy conversion 
@@ -64,8 +64,8 @@ def convert(record, batch_size):
6464 batch_process  =  lambda  r : chunk (process (r ), batch_size )
6565
6666 for  i , batch  in  enumerate (batch_process (record )):
67-  path  =  os .path .join (path , "{0:}-{1:02}.pt.gz" .format (basename , i ))
68-  with  gzip .open (path , 'wb' ) as  f :
67+  p  =  os .path .join (path , "{0:}-{1:02}.pt.gz" .format (basename , i ))
68+  with  gzip .open (p , 'wb' ) as  f :
6969 torch .save (list (batch ), f )
7070
7171if  __name__  ==  '__main__' :
@@ -91,4 +91,4 @@ def convert(record, batch_size):
9191
9292 with  mp .Pool (processes = mp .cpu_count ()) as  pool :
9393 f  =  partial (convert , batch_size = args .batch_size )
94-  pool .map (f , records )
94+  pool .map (f , records )
0 commit comments