Skip to content

Commit 6b5af87

Browse files
authored
Merge pull request yunjey#25 from DingKe/master
use buit-in detach
2 parents 1aab031 + d38e95c commit 6b5af87

File tree

6 files changed

+12
-12
lines changed

6 files changed

+12
-12
lines changed

tutorials/08 - Language Model/main-gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def forward(self, x, h):
6161

6262
# Truncated Backpropagation
6363
def detach(states):
64-
return [Variable(state.data) for state in states]
64+
return [state.detach() for state in states]
6565

6666
# Training
6767
for epoch in range(num_epochs):
@@ -119,4 +119,4 @@ def detach(states):
119119
print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
120120

121121
# Save the Trained Model
122-
torch.save(model.state_dict(), 'model.pkl')
122+
torch.save(model.state_dict(), 'model.pkl')

tutorials/08 - Language Model/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def forward(self, x, h):
6161

6262
# Truncated Backpropagation
6363
def detach(states):
64-
return [Variable(state.data) for state in states]
64+
return [state.detach() for state in states]
6565

6666
# Training
6767
for epoch in range(num_epochs):
@@ -119,4 +119,4 @@ def detach(states):
119119
print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
120120

121121
# Save the Trained Model
122-
torch.save(model.state_dict(), 'model.pkl')
122+
torch.save(model.state_dict(), 'model.pkl')

tutorials/10 - Generative Adversarial Network/main-gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def forward(self, x):
7777

7878
noise = Variable(torch.randn(images.size(0), 128)).cuda()
7979
fake_images = generator(noise)
80-
outputs = discriminator(fake_images)
80+
outputs = discriminator(fake_images.detach())
8181
fake_loss = criterion(outputs, fake_labels)
8282
fake_score = outputs
8383

@@ -107,4 +107,4 @@ def forward(self, x):
107107

108108
# Save the Models
109109
torch.save(generator.state_dict(), './generator.pkl')
110-
torch.save(discriminator.state_dict(), './discriminator.pkl')
110+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/10 - Generative Adversarial Network/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def forward(self, x):
7777

7878
noise = Variable(torch.randn(images.size(0), 128))
7979
fake_images = generator(noise)
80-
outputs = discriminator(fake_images)
80+
outputs = discriminator(fake_images.detach())
8181
fake_loss = criterion(outputs, fake_labels)
8282
fake_score = outputs
8383

@@ -107,4 +107,4 @@ def forward(self, x):
107107

108108
# Save the Models
109109
torch.save(generator.state_dict(), './generator.pkl')
110-
torch.save(discriminator.state_dict(), './discriminator.pkl')
110+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/11 - Deep Convolutional Generative Adversarial Network/main-gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def forward(self, x):
102102

103103
noise = Variable(torch.randn(images.size(0), 128)).cuda()
104104
fake_images = generator(noise)
105-
outputs = discriminator(fake_images)
105+
outputs = discriminator(fake_images.detach())
106106
fake_loss = criterion(outputs, fake_labels)
107107
fake_score = outputs
108108

@@ -131,4 +131,4 @@ def forward(self, x):
131131

132132
# Save the Models
133133
torch.save(generator.state_dict(), './generator.pkl')
134-
torch.save(discriminator.state_dict(), './discriminator.pkl')
134+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/11 - Deep Convolutional Generative Adversarial Network/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def forward(self, x):
102102

103103
noise = Variable(torch.randn(images.size(0), 128))
104104
fake_images = generator(noise)
105-
outputs = discriminator(fake_images)
105+
outputs = discriminator(fake_images.detch())
106106
fake_loss = criterion(outputs, fake_labels)
107107
fake_score = outputs
108108

@@ -131,4 +131,4 @@ def forward(self, x):
131131

132132
# Save the Models
133133
torch.save(generator.state_dict(), './generator.pkl')
134-
torch.save(discriminator.state_dict(), './discriminator.pkl')
134+
torch.save(discriminator.state_dict(), './discriminator.pkl')

0 commit comments

Comments
 (0)