1
# Implementation of https://arxiv.org/pdf/1512.03385.pdf
2
2
# See section 4.2 for model architecture on CIFAR-10.
3
3
# Some part of the code was referenced below.
4
4
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
5
- import torch
5
+ import torch
6
6
import torch .nn as nn
7
7
import torchvision .datasets as dsets
8
8
import torchvision .transforms as transforms
9
9
from torch .autograd import Variable
10
10
11
- # Image Preprocessing
11
+ # Image Preprocessing
12
12
transform = transforms .Compose ([
13
13
transforms .Scale (40 ),
14
14
transforms .RandomHorizontalFlip (),
17
17
18
18
# CIFAR-10 Dataset
19
19
# Training split: fetched into ./data/ if missing and passed through the
# augmentation pipeline bound to `transform`.
train_dataset = dsets.CIFAR10(root='./data/',
                              train=True,
                              transform=transform,
                              download=True)

# Test split: only converted to tensors, no augmentation. No download flag
# is passed here, so the files are expected to already be under ./data/.
test_dataset = dsets.CIFAR10(root='./data/',
                             train=False,
                             transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)

# Evaluation batches keep the dataset order.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100,
                                          shuffle=False)
36
36
37
37
# 3x3 Convolution
38
38
def conv3x3(in_channels, out_channels, stride=1):
    """Return a 3x3 convolution layer with padding 1 and no bias term.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        An ``nn.Conv2d`` module configured with ``kernel_size=3``,
        ``padding=1`` and ``bias=False``.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)
41
41
42
42
# Residual Block
@@ -49,7 +49,7 @@ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
49
49
self .conv2 = conv3x3 (out_channels , out_channels )
50
50
self .bn2 = nn .BatchNorm2d (out_channels )
51
51
self .downsample = downsample
52
-
52
+
53
53
def forward (self , x ):
54
54
residual = x
55
55
out = self .conv1 (x )
@@ -76,7 +76,7 @@ def __init__(self, block, layers, num_classes=10):
76
76
self .layer3 = self .make_layer (block , 64 , layers [1 ], 2 )
77
77
self .avg_pool = nn .AvgPool2d (8 )
78
78
self .fc = nn .Linear (64 , num_classes )
79
-
79
+
80
80
def make_layer (self , block , out_channels , blocks , stride = 1 ):
81
81
downsample = None
82
82
if (stride != 1 ) or (self .in_channels != out_channels ):
@@ -89,7 +89,7 @@ def make_layer(self, block, out_channels, blocks, stride=1):
89
89
for i in range (1 , blocks ):
90
90
layers .append (block (out_channels , out_channels ))
91
91
return nn .Sequential (* layers )
92
-
92
+
93
93
def forward (self , x ):
94
94
out = self .conv (x )
95
95
out = self .bn (out )
@@ -101,36 +101,36 @@ def forward(self, x):
101
101
out = out .view (out .size (0 ), - 1 )
102
102
out = self .fc (out )
103
103
return out
104
-
104
+
105
105
# Instantiate the network with ResidualBlock as the building block and
# [3, 3, 3] as the per-layer block counts, then move its parameters to the GPU.
resnet = ResNet(ResidualBlock, [3, 3, 3])
resnet.cuda()

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
# `lr` is kept as a module-level name because the training loop divides it
# when it decays the learning rate.
lr = 0.001
optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
112
-
113
- # Training
112
+
113
+ # Training
114
114
# Number of passes over the training set and batches per pass; both feed the
# progress message so it stays correct if the loader or epoch count changes.
num_epochs = 80
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move the batch to the GPU, where the model was placed with .cuda().
        images = images.cuda()
        labels = labels.cuda()

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            # loss.item() extracts the Python float from the scalar loss;
            # indexing it with [0] fails on 0-dim tensors.
            print("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f"
                  % (epoch + 1, num_epochs, i + 1, steps_per_epoch, loss.item()))

    # Decaying Learning Rate
    if (epoch + 1) % 20 == 0:
        lr /= 3
        # Update the existing optimizer in place so Adam keeps its running
        # moment estimates instead of restarting from a fresh optimizer.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
133
+
134
134
# Test
135
135
correct = 0
136
136
total = 0
@@ -144,4 +144,4 @@ def forward(self, x):
144
144
print ('Accuracy of the model on the test images: %d %%' % (100 * correct / total ))
145
145
146
146
# Save the Model
147
# Persist only the learned parameters (the state dict), not the module object.
torch.save(resnet.state_dict(), 'resnet.pkl')
0 commit comments