Skip to content

Commit 90359ff

Browse files
removed sigmoid function from critic
1 parent ec69b09 commit 90359ff

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

GAN/WGAN/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torchsummary import summary
33

44

5-
class Discriminator(nn.Module):
5+
class Critic(nn.Module):
66
def __init__(self, img_channels, feature_d):
77
super().__init__()
88
self.net = nn.Sequential(
@@ -17,7 +17,7 @@ def __init__(self, img_channels, feature_d):
1717
kernel_size=4, stride=2, padding=1),
1818
nn.Conv2d(feature_d*8, 1,
1919
kernel_size=4, stride=2, padding=0),
20-
nn.Sigmoid()
20+
2121
)
2222
self.initialize_weights()
2323

@@ -55,7 +55,8 @@ def __init__(self, z_dim, img_channels, feature_d):
5555

5656
def __block(self, in_channels, out_channels, **kwargs):
5757

58-
return [nn.ConvTranspose2d(in_channels, out_channels, **kwargs, bias=False),
58+
return [nn.ConvTranspose2d(in_channels, out_channels,
59+
**kwargs, bias=False),
5960
nn.BatchNorm2d(out_channels),
6061
nn.ReLU()]
6162

@@ -69,7 +70,7 @@ def forward(self, x):
6970

7071

7172
if __name__ == "__main__":
72-
disc = Discriminator(1, 128)
73+
disc = Critic(1, 128)
7374
gen = Faker(100, 1, 128)
7475

7576
summary(disc, (1, 64, 64))

0 commit comments

Comments
 (0)