-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathvae.py
159 lines (132 loc) · 5.7 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import pytorch_lightning as pl
import torch
import os
import torch.nn as nn
from torch.nn import functional as F
from pl_bolts.models.autoencoders.components import resnet18_encoder, resnet18_decoder
from collections import OrderedDict
class VAE(nn.Module):
    """Variational autoencoder with a beta-weighted KL term.

    The encoder and decoder are the ResNet-18 components imported from
    ``pl_bolts``; two linear heads map the encoder output to the mean and
    log-variance of the latent Gaussian.

    Args:
        beta: Multiplier applied to the KL term in ``get_elbo_loss``.
        enc_out_dim: Input width of the ``fc_mu`` / ``fc_var`` linear heads
            (presumably the encoder's output feature size; confirm against
            ``resnet18_encoder``).
        latent_dim: Size of the latent vector ``z``.
        input_height: Passed to the decoder and used to size the top-k
            reconstruction loss.
        device: Device that ``reconstruct`` moves latents to and that
            ``load_weights`` maps checkpoints onto.
    """

    def __init__(self, beta=4, enc_out_dim=512, latent_dim=256, input_height=128, device='cpu'):
        super().__init__()
        self.beta = beta
        self.latent_dim = latent_dim
        self.device = device
        self.input_height = input_height

        # encoder, decoder
        self.encoder = resnet18_encoder(False, False)
        self.decoder = nn.Sequential(
            resnet18_decoder(
                latent_dim=latent_dim,
                input_height=input_height,
                first_conv=False,
                maxpool1=False),
            nn.Tanh()  # bounds decoder output to [-1, 1]; presumably matches the input normalisation
        )

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # for the gaussian likelihood
        # NOTE(review): not referenced by any other method of this class and
        # not included in save_weights/load_weights.
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

        # default fraction used by get_elbo_loss when no `p` is passed
        self.p = 0.2

    @staticmethod
    def _unwrap_parallel(module):
        """Return the module inside an ``nn.DataParallel`` wrapper, else ``module`` itself."""
        return module.module if isinstance(module, nn.DataParallel) else module

    def _posterior_params(self, x):
        """Run the encoder and both linear heads; return ``(mu, log_var)``."""
        x_encoded = self.encoder(x)
        return self.fc_mu(x_encoded), self.fc_var(x_encoded)

    def reparameterise(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` otherwise.

        Args:
            mu: Mean of the latent Gaussian.
            logvar: Log-variance of the latent Gaussian (same shape as ``mu``).

        Returns:
            A tensor shaped like ``mu``.
        """
        if not self.training:
            # deterministic encoding at evaluation time
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # standard-normal noise on std's device/dtype
        return mu + eps * std

    def gaussian_likelihood(self, mean, logscale, sample):
        """Log-probability of ``sample`` under ``Normal(mean, exp(logscale))``.

        The per-element log-probabilities are summed over dims 1, 2 and 3, so
        ``sample`` must have at least four dimensions; one value per index of
        dim 0 is returned.
        """
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(mean, scale)
        log_pxz = dist.log_prob(sample)
        return log_pxz.sum(dim=(1, 2, 3))

    def encode_image(self, x):
        """Encode ``x`` and return a latent ``z`` (sampled in training mode, the mean otherwise)."""
        mu, log_var = self._posterior_params(x)
        return self.reparameterise(mu, log_var)

    def get_elbo_loss(self, x, p=None):
        """Compute the training loss ``recon_loss + beta * kld`` for a batch.

        The reconstruction term is the sum of the ``k`` largest element-wise
        squared errors, with ``k = int(p * input_height * input_height * 3)``.

        Args:
            x: Input batch fed to the encoder and used as the reconstruction target.
            p: Fraction used to size the top-k selection. Defaults to ``self.p``.

        Returns:
            ``(elbo, log_dict)`` where ``elbo`` is the loss tensor and
            ``log_dict`` holds the Python floats ``'elbo'``, ``'recon_loss'``
            and ``'kl'``.
        """
        if p is None:
            p = self.p

        mu, log_var = self._posterior_params(x)
        z = self.reparameterise(mu, log_var)
        x_hat = self.decoder(z)

        # NOTE(review): k is derived from the element count of ONE 3-channel
        # square image, but the top-k runs over the flattened errors of the
        # whole batch. Confirm this is intended.
        boot_recon = int(p * self.input_height * self.input_height * 3)
        squared_err = F.mse_loss(x_hat, x, reduction='none').reshape(-1)
        recon_loss = squared_err.topk(boot_recon, sorted=False)[0].sum()

        # KL( N(mu, exp(log_var)) || N(0, 1) ), summed over batch and latent dims
        kld = 0.5 * torch.sum(log_var.exp() - log_var - 1 + mu.pow(2))
        elbo = recon_loss + self.beta * kld

        log_dict = {
            'elbo': elbo.item(),
            'recon_loss': recon_loss.item(),
            'kl': kld.item()
        }
        return elbo, log_dict

    def reconstruct(self, n_preds, sampled_noise=None):
        """Decode latent vectors into images without tracking gradients.

        Args:
            n_preds: Number of latents to draw from a standard normal when
                ``sampled_noise`` is not given.
            sampled_noise: Optional latent tensor to decode instead of sampling.

        Returns:
            The decoder output, moved to the CPU.
        """
        if sampled_noise is None:
            # z comes from Normal(0, 1)
            prior = torch.distributions.Normal(
                torch.zeros((self.latent_dim,)), torch.ones((self.latent_dim,)))
            z = prior.rsample((n_preds,))
        else:
            z = sampled_noise

        with torch.no_grad():
            pred = self.decoder(z.to(self.device)).cpu()
        return pred

    def forward(self, x):
        """Encode ``x``, draw (or take the mean) latent, and return the decoded ``x_hat``."""
        mu, log_var = self._posterior_params(x)
        z = self.reparameterise(mu, log_var)
        return self.decoder(z)

    def save_weights(self, fpath):
        """Save the encoder, decoder, ``fc_mu`` and ``fc_var`` state dicts to ``fpath``."""
        print('saving checkpoint...')
        checkpoint = {
            'encoder': self.encoder.state_dict(),
            'fc_mu': self.fc_mu.state_dict(),
            'fc_var': self.fc_var.state_dict(),
            'decoder': self.decoder.state_dict()
        }
        torch.save(checkpoint, fpath)
        print(f"checkpoint saved at {fpath}")

    def load_weights(self, fpath):
        """Load a checkpoint written by ``save_weights``.

        Keys saved from ``nn.DataParallel``-wrapped modules are sanitised, and
        the weights are loaded into the underlying module, so this works both
        before and after ``dataparallel()`` has been called.

        Raises:
            AssertionError: If ``fpath`` is not an existing file.
        """
        if not os.path.isfile(fpath):
            raise AssertionError(f"No weights file found at {fpath}")

        checkpoint = torch.load(fpath, map_location=self.device)
        for name in ('encoder', 'fc_mu', 'fc_var', 'decoder'):
            # The sanitised keys carry no 'module.' prefix, so they must be
            # loaded into the bare module rather than a DataParallel wrapper.
            target = self._unwrap_parallel(getattr(self, name))
            target.load_state_dict(self.sanitise_state_dict(checkpoint[name]))
        print('checkpoint loaded at {}'.format(fpath))

    def dataparallel(self, ngpu):
        """Wrap the encoder, decoder and both linear heads in ``nn.DataParallel`` on GPUs ``0..ngpu-1``."""
        device_ids = list(range(ngpu))
        print(f"using {ngpu} gpus, gpu id: {device_ids}")
        self.encoder = nn.DataParallel(self.encoder, device_ids)
        self.decoder = nn.DataParallel(self.decoder, device_ids)
        self.fc_mu = nn.DataParallel(self.fc_mu, device_ids)
        self.fc_var = nn.DataParallel(self.fc_var, device_ids)

    def sanitise_state_dict(self, state_dict):
        """Strip the leading ``module.`` prefix that ``nn.DataParallel`` adds to keys.

        Only a prefix at the very start of a key is removed; keys that merely
        contain the text ``module`` elsewhere are left untouched.

        Args:
            state_dict (OrderedDict): the weights to be loaded

        Returns:
            output_dict (OrderedDict): the same values, in the same order,
            keyed without the ``module.`` prefix
        """
        prefix = 'module.'
        output_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith(prefix):
                output_dict[k[len(prefix):]] = v
            else:
                output_dict[k] = v
        return output_dict