2
2
from torchsummary import summary
3
3
4
4
5
- class Discriminator (nn .Module ):
5
+ class Critic (nn .Module ):
6
6
def __init__ (self , img_channels , feature_d ):
7
7
super ().__init__ ()
8
8
self .net = nn .Sequential (
@@ -17,7 +17,7 @@ def __init__(self, img_channels, feature_d):
17
17
kernel_size = 4 , stride = 2 , padding = 1 ),
18
18
nn .Conv2d (feature_d * 8 , 1 ,
19
19
kernel_size = 4 , stride = 2 , padding = 0 ),
20
- nn . Sigmoid ()
20
+
21
21
)
22
22
self .initialize_weights ()
23
23
@@ -55,7 +55,8 @@ def __init__(self, z_dim, img_channels, feature_d):
55
55
56
56
def __block (self , in_channels , out_channels , ** kwargs ):
57
57
58
- return [nn .ConvTranspose2d (in_channels , out_channels , ** kwargs , bias = False ),
58
+ return [nn .ConvTranspose2d (in_channels , out_channels ,
59
+ ** kwargs , bias = False ),
59
60
nn .BatchNorm2d (out_channels ),
60
61
nn .ReLU ()]
61
62
@@ -69,7 +70,7 @@ def forward(self, x):
69
70
70
71
71
72
if __name__ == "__main__" :
72
- disc = Discriminator (1 , 128 )
73
+ disc = Critic (1 , 128 )
73
74
gen = Faker (100 , 1 , 128 )
74
75
75
76
summary (disc , (1 , 64 , 64 ))
0 commit comments