Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ed09b64

Browse files
author
Morvan Zhou
committed
update
1 parent 3d8df0c commit ed09b64

1 file changed

Lines changed: 4 additions & 11 deletions

File tree

‎tutorial-contents/406_GAN.py‎

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@ def artist_works(): # painting from the famous artist (real target)
3737
paintings = torch.from_numpy(paintings).float()
3838
return Variable(paintings)
3939

40-
41-
def G_ideas(): # the random ideas for generator to draw something
42-
z = torch.randn(BATCH_SIZE, N_IDEAS)
43-
return Variable(z)
44-
45-
4640
G = nn.Sequential( # Generator
4741
nn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)
4842
nn.ReLU(),
@@ -63,15 +57,14 @@ def G_ideas(): # the random ideas for generator to draw something
6357
plt.show()
6458
for step in range(10000):
6559
artist_paintings = artist_works() # real painting from artist
66-
G_paintings = G(G_ideas()) # fake painting from G (random ideas)
60+
G_ideas = Variable(torch.randn(BATCH_SIZE, N_IDEAS)) # random ideas
61+
G_paintings = G(G_ideas) # fake painting from G (random ideas)
6762

6863
prob_artist0 = D(artist_paintings) # D try to increase this prob
6964
prob_artist1 = D(G_paintings) # D try to reduce this prob
7065

71-
D_score0 = torch.log(prob_artist0) # maximise this for D
72-
D_score1 = torch.log(1. - prob_artist1) # maximise this for D
73-
D_loss = - torch.mean(D_score0 + D_score1) # minimise the negative of both two above for D
74-
G_loss = torch.mean(D_score1) # minimise D score w.r.t G
66+
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
67+
G_loss = torch.mean(torch.log(1. - prob_artist1))
7568

7669
opt_D.zero_grad()
7770
D_loss.backward(retain_variables=True) # retain_variables for reusing computational graph

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /