Skip to content

Commit 3a7a9cf

Browse files
committed
model serialization code changed
1 parent 6c74dfa commit 3a7a9cf

File tree

16 files changed

+41
-37
lines changed

16 files changed

+41
-37
lines changed

tutorials/01 - Linear Regression/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ def forward(self, x):
6161
plt.show()
6262

6363
# Save the Model
64-
torch.save(model, 'model.pkl')
64+
torch.save(model.state_dict(), 'model.pkl')

tutorials/02 - Logistic Regression/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,4 @@ def forward(self, x):
7979
print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
8080

8181
# Save the Model
82-
torch.save(model, 'model.pkl')
82+
torch.save(model.state_dict(), 'model.pkl')

tutorials/03 - Feedforward Neural Network/main-gpu.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,7 @@ def forward(self, x):
8181
total += labels.size(0)
8282
correct += (predicted.cpu() == labels).sum()
8383

84-
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
84+
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
85+
86+
# Save the Model
87+
torch.save(net.state_dict(), 'model.pkl')

tutorials/03 - Feedforward Neural Network/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,7 @@ def forward(self, x):
8181
total += labels.size(0)
8282
correct += (predicted == labels).sum()
8383

84-
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
84+
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
85+
86+
# Save the Model
87+
torch.save(net.state_dict(), 'model.pkl')

tutorials/04 - Convolutional Neural Network/main-gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ def forward(self, x):
9090
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9191

9292
# Save the Trained Model
93-
torch.save(cnn, 'cnn.pkl')
93+
torch.save(cnn.state_dict(), 'cnn.pkl')

tutorials/04 - Convolutional Neural Network/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ def forward(self, x):
9090
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9191

9292
# Save the Trained Model
93-
torch.save(cnn, 'cnn.pkl')
93+
torch.save(cnn.state_dict(), 'cnn.pkl')

tutorials/05 - Deep Residual Network/main-gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,4 @@ def forward(self, x):
144144
print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
145145

146146
# Save the Model
147-
torch.save(resnet, 'resnet.pkl')
147+
torch.save(resnet.state_dict(), 'resnet.pkl')

tutorials/05 - Deep Residual Network/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(self, x):
103103
return out
104104

105105
resnet = ResNet(ResidualBlock, [2, 2, 2, 2])
106-
resnet
106+
107107

108108
# Loss and Optimizer
109109
criterion = nn.CrossEntropyLoss()
@@ -127,7 +127,7 @@ def forward(self, x):
127127
print ("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" %(epoch+1, 80, i+1, 500, loss.data[0]))
128128

129129
# Decaying Learning Rate
130-
if (epoch+1) % 30 == 0:
130+
if (epoch+1) % 20 == 0:
131131
lr /= 3
132132
optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
133133

@@ -144,4 +144,4 @@ def forward(self, x):
144144
print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
145145

146146
# Save the Model
147-
torch.save(resnet, 'resnet.pkl')
147+
torch.save(resnet.state_dict(), 'resnet.pkl')

tutorials/06 - Recurrent Neural Network/main-gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ def forward(self, x):
9292
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9393

9494
# Save the Model
95-
torch.save(rnn, 'rnn.pkl')
95+
torch.save(rnn.state_dict(), 'rnn.pkl')

tutorials/06 - Recurrent Neural Network/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ def forward(self, x):
9292
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9393

9494
# Save the Model
95-
torch.save(rnn, 'rnn.pkl')
95+
torch.save(rnn.state_dict(), 'rnn.pkl')

0 commit comments

Comments
 (0)