1818from  torchvision .utils  import  save_image 
1919
2020from  gqn  import  GenerativeQueryNetwork 
21- from  shepardmetzler  import  ShepardMetzler , Scene 
21+ from  shepardmetzler  import  ShepardMetzler , Scene ,  transform_viewpoint 
2222
2323cuda  =  torch .cuda .is_available ()
2424device  =  torch .device ("cuda:0"  if  cuda  else  "cpu" )
2525
26- def  transform_viewpoint (v ):
27-  """ 
28-  Transforms the viewpoint vector into a consistent 
29-  representation 
30-  """ 
31-  w , z  =  torch .split (v , 3 , dim = - 1 )
32-  y , p  =  torch .split (z , 1 , dim = - 1 )
33- 34-  # position, [yaw, pitch] 
35-  view_vector  =  [w , torch .cos (y ), torch .sin (y ), torch .cos (p ), torch .sin (p )]
36-  v_hat  =  torch .cat (view_vector , dim = - 1 )
37- 38-  return  v_hat 
39- 4026
4127if  __name__  ==  '__main__' :
4228 parser  =  argparse .ArgumentParser (description = 'Generative Query Network on Shepard Metzler Example' )
43-  parser .add_argument ('--epochs ' , type = int , default = 10000 , help = 'number of epochs  to train  (default: 10000 )' )
29+  parser .add_argument ('--gradient_steps ' , type = int , default = 2 * ( 10 ** 6 ) , help = 'number of gradient steps  to run  (default: 2 million )' )
4430 parser .add_argument ('--batch_size' , type = int , default = 36 , help = 'size of batch (default: 36)' )
4531 parser .add_argument ('--data_dir' , type = str , help = 'location of training data' , default = "train" )
4632 parser .add_argument ('--workers' , type = int , help = 'number of data loading workers' , default = 2 )
@@ -71,7 +57,13 @@ def transform_viewpoint(v):
7157 kwargs  =  {'num_workers' : args .workers , 'pin_memory' : True } if  cuda  else  {}
7258 loader  =  DataLoader (dataset , batch_size = args .batch_size , shuffle = True , ** kwargs )
7359
74-  for  epoch  in  range (args .epochs ):
60+  # Number of gradient steps 
61+  s  =  0 
62+  while  True :
63+  if  s  >=  args .gradient_steps :
64+  torch .save (model , "model-final.pt" )
65+  break 
66+ 7567 for  x , v  in  tqdm (loader ):
7668 if  args .fp16 :
7769 x , v  =  x .half (), v .half ()
@@ -96,26 +88,27 @@ def transform_viewpoint(v):
9688
9789 optimizer .step ()
9890 optimizer .zero_grad ()
91+ 92+  s  +=  1 
93+ 94+  # Keep a checkpoint every 100,000 steps 
95+  if  s  %  100000  ==  0 :
96+  torch .save (model , "model-{}.pt" .format (s ))
9997
10098 with  torch .no_grad ():
101-  print ("Epoch: {} |ELBO\t {} |NLL\t {} |KL\t {}" .format (epoch , elbo .item (), reconstruction .item (), kl_divergence .item ()))
102- 103-  if  epoch  %  5  ==  0 :
104-  x , v  =  next (iter (loader ))
105-  x , v  =  x .to (device ), v .to (device )
99+  print ("|Steps: {}\t |NLL: {}\t |KL: {}\t |" .format (s , reconstruction .item (), kl_divergence .item ()))
106100
107-  x_mu , _ , r , _  =  model (x , v )
101+  x , v  =  next (iter (loader ))
102+  x , v  =  x .to (device ), v .to (device )
108103
109-  r = r . view ( - 1 ,  1 ,  16 ,  16 )
104+  x_mu ,  _ ,  r ,  _ = model ( x ,  v )
110105
111-  save_image (r .float (), "representation-{}.jpg" .format (epoch ))
112-  save_image (x_mu .float (), "reconstruction-{}.jpg" .format (epoch ))
106+  r  =  r .view (- 1 , 1 , 16 , 16 )
113107
114-  if epoch % 10 == 0 : 
115-  torch . save ( model , "model-{}.pt"  . format ( epoch ) )
108+  save_image ( r . float (),  "representation.jpg" ) 
109+  save_image ( x_mu . float () , "reconstruction.jpg"  )
116110
117111 # Anneal learning rate 
118-  s  =  epoch  +  1 
119112 mu  =  max (mu_f  +  (mu_i  -  mu_f )* (1  -  s / (1.6  *  10 ** 6 )), mu_f )
120113 optimizer .lr  =  mu  *  math .sqrt (1  -  0.999 ** s )/ (1  -  0.9 ** s )
121114
0 commit comments