Skip to content

Commit 057cd87

Browse files
committed
fix vgg bug
1 parent a382fbb commit 057cd87

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

models/vgg.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class VGG(nn.Module):
2525
def __init__(self, features, num_classes=1000):
2626
super(VGG, self).__init__()
2727
self.features = features
28-
self.classifier = nn.Linear(512, 10)
28+
self.classifier = nn.Linear(512, num_classes)
2929
self._initialize_weights()
3030

3131
def forward(self, x):
@@ -74,69 +74,65 @@ def make_layers(cfg, batch_norm=False):
7474
}
7575

7676

77-
def vgg11(pretrained=False, **kwargs):
77+
def vgg11(**kwargs):
7878
"""VGG 11-layer model (configuration "A")
7979
8080
Args:
8181
pretrained (bool): If True, returns a model pre-trained on ImageNet
8282
"""
8383
model = VGG(make_layers(cfg['A']), **kwargs)
84-
if pretrained:
85-
model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
8684
return model
8785

8886

8987
def vgg11_bn(**kwargs):
9088
"""VGG 11-layer model (configuration "A") with batch normalization"""
91-
return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
89+
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
90+
return model
9291

9392

94-
def vgg13(pretrained=False, **kwargs):
93+
def vgg13(**kwargs):
9594
"""VGG 13-layer model (configuration "B")
9695
9796
Args:
9897
pretrained (bool): If True, returns a model pre-trained on ImageNet
9998
"""
10099
model = VGG(make_layers(cfg['B']), **kwargs)
101-
if pretrained:
102-
model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
103100
return model
104101

105102

106103
def vgg13_bn(**kwargs):
107104
"""VGG 13-layer model (configuration "B") with batch normalization"""
108-
return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
105+
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
106+
return model
109107

110108

111-
def vgg16(pretrained=False, **kwargs):
109+
def vgg16(**kwargs):
112110
"""VGG 16-layer model (configuration "D")
113111
114112
Args:
115113
pretrained (bool): If True, returns a model pre-trained on ImageNet
116114
"""
117115
model = VGG(make_layers(cfg['D']), **kwargs)
118-
if pretrained:
119-
model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
120116
return model
121117

122118

123119
def vgg16_bn(**kwargs):
124120
"""VGG 16-layer model (configuration "D") with batch normalization"""
125-
return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
121+
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
122+
return model
126123

127124

128-
def vgg19(pretrained=False, **kwargs):
125+
def vgg19(**kwargs):
129126
"""VGG 19-layer model (configuration "E")
130127
131128
Args:
132129
pretrained (bool): If True, returns a model pre-trained on ImageNet
133130
"""
134131
model = VGG(make_layers(cfg['E']), **kwargs)
135-
if pretrained:
136-
model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
137132
return model
138133

139134

140135
def vgg19_bn(**kwargs):
141136
"""VGG 19-layer model (configuration 'E') with batch normalization"""
142-
return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
137+
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
138+
return model

0 commit comments

Comments
 (0)