Source code for dsipts.models.d3vae.encoder

# -*-Encoding: utf-8 -*-
"""
Description:
    The model architecture of the bidirectional vae.
    Note: Parts of the code are borrowed from 'https://github.com/NVlabs/NVAE'
Authors:
    Li,Yan (liyan22021121@gmail.com)
"""
import math
import numpy as np
import torch
import torch.nn as nn
from .neural_operations import OPS, EncCombinerCell, DecCombinerCell, Conv2D, get_skip_connection
from .utils import get_stride_for_cell_type, get_arch_cells


class Cell(nn.Module):
    """Chain of operations with a scaled residual connection.

    The ops named in ``arch`` are looked up in ``OPS`` and applied one after
    another; the result is scaled by 0.1 and added to a skip branch built by
    ``get_skip_connection``.

    Parameters
    ----------
    Cin : int
        Input channels of the first op and of the skip branch.
    Cout : int
        Output channels of every op.
    cell_type : str
        Key passed to ``get_stride_for_cell_type`` to obtain the stride of
        the first op and of the skip branch.
    arch : sequence
        Names of the primitives (keys of ``OPS``), in application order.
    use_se : bool
        Stored on the instance; not read anywhere else in this class.
    """

    def __init__(self, Cin, Cout, cell_type, arch, use_se):
        super().__init__()
        self.cell_type = cell_type

        skip_stride = get_stride_for_cell_type(self.cell_type)
        self.skip = get_skip_connection(Cin, skip_stride, channel_mult=2)
        self.use_se = use_se

        self._num_nodes = len(arch)
        self._ops = nn.ModuleList()
        for position, primitive in enumerate(arch):
            if position == 0:
                # Only the first op changes the channel count and may stride.
                node = OPS[primitive](Cin, Cout, get_stride_for_cell_type(self.cell_type))
            else:
                node = OPS[primitive](Cout, Cout, 1)
            self._ops.append(node)

    def forward(self, s):
        """Return ``skip(s) + 0.1 * ops(s)``."""
        shortcut = self.skip(s)
        out = s
        for node in self._ops:
            out = node(out)
        return shortcut + 0.1 * out
def soft_clamp5(x: torch.Tensor):
    """Smoothly squash ``x`` into the open interval (-5, 5).

    Computes ``5 * tanh(x / 5)`` element-wise and returns a new tensor; the
    input is left untouched.
    """
    scaled = x / 5.
    return torch.tanh(scaled) * 5.
def sample_normal_jit(mu, sigma):
    """Draw a reparameterised Gaussian sample.

    Parameters
    ----------
    mu : torch.Tensor
        Mean of the distribution.
    sigma : torch.Tensor
        Standard deviation, broadcastable against ``mu``.

    Returns
    -------
    z : torch.Tensor
        The sample ``mu + sigma * eps``.
    eps : torch.Tensor
        The standard-normal noise that produced ``z`` (same shape as ``mu``).
    """
    eps = torch.randn_like(mu)
    # Computed out of place on purpose: an in-place mul_/add_ chain on `eps`
    # would make the returned noise alias `z` and no longer be the noise.
    z = eps * sigma + mu
    return z, eps
class Normal:
    """Element-wise Gaussian parameterised by a mean and a log std-dev.

    Both ``mu`` and ``log_sigma`` are passed through ``soft_clamp5`` before
    use; ``sigma`` is the exponential of the clamped ``log_sigma``, optionally
    multiplied by the temperature ``temp``.
    """

    def __init__(self, mu, log_sigma, temp=1.):
        self.mu = soft_clamp5(mu)
        clamped_log_sigma = soft_clamp5(log_sigma)
        self.sigma = torch.exp(clamped_log_sigma)
        if temp != 1.:
            self.sigma *= temp

    def sample(self):
        """Return ``(z, eps)`` as produced by ``sample_normal_jit``."""
        return sample_normal_jit(self.mu, self.sigma)

    def sample_given_eps(self, eps):
        """Return the sample obtained from the supplied noise ``eps``."""
        scaled_noise = eps * self.sigma
        return scaled_noise + self.mu

    def log_p(self, samples):
        """Element-wise log density of ``samples`` under this distribution."""
        z_score = (samples - self.mu) / self.sigma
        quadratic = - 0.5 * z_score * z_score
        log_norm_const = 0.5 * np.log(2 * np.pi)
        return quadratic - log_norm_const - torch.log(self.sigma)

    def kl(self, normal_dist):
        """Element-wise KL divergence from this distribution to ``normal_dist``."""
        mean_gap = (self.mu - normal_dist.mu) / normal_dist.sigma
        scale_ratio = self.sigma / normal_dist.sigma
        squares = mean_gap * mean_gap + scale_ratio * scale_ratio
        return 0.5 * squares - 0.5 - torch.log(scale_ratio)
class NormalDecoder:
    """Gaussian output head built from a 4-D parameter tensor.

    The channel axis (dim 1) of ``param`` is split at ``C // 2``: the leading
    part is the mean, the remainder the log std-dev. ``self.sigma`` is
    ``exp(log_sigma) + 1e-2``; the wrapped ``Normal`` is built from the raw
    ``mu`` / ``log_sigma`` slices and derives its own sigma.
    """

    def __init__(self, param):
        # Unpacking four values requires `param` to be exactly 4-D.
        _, channels, _, _ = param.size()
        half = channels // 2
        self.num_c = half
        self.mu = param[:, :half, :, :]
        self.log_sigma = param[:, half:, :, :]
        self.sigma = self.log_sigma.exp() + 1e-2
        self.dist = Normal(self.mu, self.log_sigma)

    def log_prob(self, samples):
        """Element-wise log density of ``samples`` under the wrapped ``Normal``."""
        return self.dist.log_p(samples)

    def sample(self,):
        """Draw one sample from the wrapped ``Normal`` (the noise is discarded)."""
        drawn, _ = self.dist.sample()
        return drawn
def log_density_gaussian(sample, mu, logvar):
    """Batch-normalised loss derived from Gaussian log densities.

    The element-wise Gaussian log density of ``sample`` is computed first.
    For every batch item the function then takes the difference between

    * the log-sum-exp over channels of the density summed over dims 2 and 3,
    * the sum over dims 2 and 3 of the per-location log-sum-exp over channels.

    These per-item values are min-max normalised across the batch and
    averaged into a scalar.

    Parameters
    ----------
    sample : torch.Tensor
        Values at which to evaluate the density; must be 4-D because dims 2
        and 3 are reduced explicitly.
    mu : torch.Tensor or float
        Mean.
    logvar : torch.Tensor or float
        Log variance.

    Returns
    -------
    torch.Tensor
        Scalar in ``[0, 1]``. It is 0 when every batch item yields the same
        value (including a batch of one), where the min-max range is zero.
    """
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((sample - mu) ** 2 * inv_var)

    log_qz = torch.logsumexp(torch.sum(log_density, [2, 3]), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_density, dim=1, keepdim=False).sum((1, 2))
    loss_p_z = log_qz - log_prod_qzi

    lowest = torch.min(loss_p_z)
    span = torch.max(loss_p_z) - lowest
    # A zero range would produce 0/0 = NaN; divide by 1 there instead, which
    # leaves the (all-zero) numerator unchanged. Non-zero ranges are untouched.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return ((loss_p_z - lowest) / safe_span).mean()
class Encoder(nn.Module):
    """Hierarchical latent-variable network with a GRU output head.

    The forward pass runs a stem convolution, a pre-processing stack and an
    encoder tower, samples one latent per group while walking the decoder
    tower, applies a post-processing stack and a 2-channel convolution, and
    finally feeds each output channel through a GRU and a linear projection.

    Parameters
    ----------
    channel_mult :
        Factor by which the channel multiplier grows at each ``down_pre``
        cell and shrinks at each ``up_post`` cell.
    mult :
        Initial channel multiplier.
    prediction_length : int
        Hidden size of the GRU.
    num_preprocess_blocks, num_preprocess_cells : int
        Number of pre-processing blocks and cells per block.
    num_channels_enc, num_channels_dec : int
        Base channel counts of the encoder / decoder side.
    arch_instance :
        Passed to ``get_arch_cells`` to obtain the per-cell-type op lists.
    num_latent_per_group : int
        Channels of each sampled latent.
    groups_per_scale : int
        Number of latent groups.
    num_postprocess_blocks, num_postprocess_cells : int
        Number of post-processing blocks and cells per block.
    embedding_dimension, hidden_size : int
        Their sum is the input width of the final projection; the sum plus
        one sizes the last axis of the learned prior feature map.
    target_dim : int
        Output width of the final projection.
    sequence_length : int
        Input size of the GRU; also sizes the second axis of the prior map.
    num_layers, dropout_rate :
        Forwarded to ``nn.GRU``.
    """

    def __init__(self, channel_mult, mult, prediction_length, num_preprocess_blocks, num_preprocess_cells,
                 num_channels_enc, arch_instance, num_latent_per_group, num_channels_dec, groups_per_scale,
                 num_postprocess_blocks, num_postprocess_cells, embedding_dimension, hidden_size, target_dim,
                 sequence_length, num_layers, dropout_rate):
        super(Encoder, self).__init__()
        self.channel_mult = channel_mult
        self.mult = mult
        self.prediction_length = prediction_length
        self.num_preprocess_blocks = num_preprocess_blocks
        self.num_preprocess_cells = num_preprocess_cells
        self.num_channels_enc = num_channels_enc
        self.arch_instance = get_arch_cells(arch_instance)
        self.stem = Conv2D(1, num_channels_enc, 3, padding=1, bias=True)
        self.num_latent_per_group = num_latent_per_group
        self.num_channels_dec = num_channels_dec
        self.groups_per_scale = groups_per_scale
        self.num_postprocess_blocks = num_postprocess_blocks
        self.num_postprocess_cells = num_postprocess_cells
        self.use_se = False
        self.input_size = embedding_dimension
        self.hidden_size = hidden_size
        self.projection = nn.Linear(embedding_dimension + hidden_size, target_dim)

        # Each pre-processing block multiplies channels by `channel_mult` and
        # (presumably, via its stride-2 cell — confirm in utils) halves the
        # spatial axes; the learned prior map is sized to match that output.
        c_scaling = self.channel_mult ** self.num_preprocess_blocks
        spatial_scaling = 2 ** self.num_preprocess_blocks
        prior_ftr0_size = (int(c_scaling * self.num_channels_dec),
                           sequence_length // spatial_scaling,
                           (embedding_dimension + hidden_size + 1) // spatial_scaling)
        self.prior_ftr0 = nn.Parameter(torch.rand(size=prior_ftr0_size), requires_grad=True)
        self.z0_size = [self.num_latent_per_group,
                        sequence_length // spatial_scaling,
                        (embedding_dimension + hidden_size + 1) // spatial_scaling]

        # The init_* helpers update self.mult as they go, so call order matters.
        self.pre_process = self.init_pre_process(self.mult)
        self.enc_tower = self.init_encoder_tower(self.mult)
        self.enc0 = nn.Sequential(
            nn.ELU(),
            Conv2D(self.num_channels_enc * self.mult, self.num_channels_enc * self.mult,
                   kernel_size=1, bias=True),
            nn.ELU())
        self.enc_sampler, self.dec_sampler = self.init_sampler(self.mult)
        self.dec_tower = self.init_decoder_tower(self.mult)
        self.post_process = self.init_post_process(self.mult)
        self.image_conditional = nn.Sequential(
            nn.ELU(),
            Conv2D(int(self.num_channels_dec * self.mult), 2, 3, padding=1, bias=True))
        self.rnn = nn.GRU(
            input_size=sequence_length,
            hidden_size=prediction_length,
            num_layers=num_layers,
            dropout=dropout_rate,
            batch_first=True,
        )

    def init_pre_process(self, mult):
        """Build the pre-processing cells and store the resulting multiplier.

        The last cell of every block is a ``down_pre`` cell that multiplies
        the channel count by ``channel_mult``; the others are ``normal_pre``.
        """
        pre_process = nn.ModuleList()
        for _ in range(self.num_preprocess_blocks):
            for c in range(self.num_preprocess_cells):
                if c == self.num_preprocess_cells - 1:
                    arch = self.arch_instance['down_pre']
                    num_ci = int(self.num_channels_enc * mult)
                    num_co = int(self.channel_mult * num_ci)
                    cell = Cell(num_ci, num_co, cell_type='down_pre', arch=arch, use_se=self.use_se)
                    mult = self.channel_mult * mult
                else:
                    arch = self.arch_instance['normal_pre']
                    # int() keeps the channel count integral, as in every
                    # other channel computation of this class.
                    num_c = int(self.num_channels_enc * mult)
                    cell = Cell(num_c, num_c, cell_type='normal_pre', arch=arch, use_se=self.use_se)
                pre_process.append(cell)
        self.mult = mult
        return pre_process

    def init_encoder_tower(self, mult):
        """Build one ``normal_enc`` cell per group, with an ``EncCombinerCell``
        after every group except the last."""
        enc_tower = nn.ModuleList()
        for g in range(self.groups_per_scale):
            arch = self.arch_instance['normal_enc']
            num_c = int(self.num_channels_enc * mult)
            enc_tower.append(Cell(num_c, num_c, cell_type='normal_enc', arch=arch, use_se=self.use_se))
            if g != self.groups_per_scale - 1:
                num_ce = int(self.num_channels_enc * mult)
                num_cd = int(self.num_channels_dec * mult)
                enc_tower.append(EncCombinerCell(num_ce, num_cd, num_ce, cell_type='combiner_enc'))
        self.mult = mult
        return enc_tower

    def init_decoder_tower(self, mult):
        """Build one ``DecCombinerCell`` per group, preceded by a
        ``normal_dec`` cell for every group except the first."""
        dec_tower = nn.ModuleList()
        for g in range(self.groups_per_scale):
            num_c = int(self.num_channels_dec * mult)
            if g != 0:
                arch = self.arch_instance['normal_dec']
                dec_tower.append(Cell(num_c, num_c, cell_type='normal_dec', arch=arch, use_se=self.use_se))
            dec_tower.append(DecCombinerCell(num_c, self.num_latent_per_group, num_c, cell_type='combiner_dec'))
        self.mult = mult
        return dec_tower

    def init_sampler(self, mult):
        """Build the convolutions that emit ``(mu, log_sigma)`` for each group.

        Returns ``(enc_sampler, dec_sampler)``. ``dec_sampler`` has no entry
        for the first group and is not used by ``forward``.
        """
        enc_sampler = nn.ModuleList()
        dec_sampler = nn.ModuleList()
        for g in range(self.groups_per_scale):
            num_c = int(self.num_channels_enc * mult)
            enc_sampler.append(
                Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=3, padding=1, bias=True))
            if g != 0:
                num_c = int(self.num_channels_dec * mult)
                dec_sampler.append(nn.Sequential(
                    nn.ELU(),
                    Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=1, padding=0, bias=True)))
        return enc_sampler, dec_sampler

    def init_post_process(self, mult):
        """Build the post-processing cells and store the resulting multiplier.

        The first cell of every block is an ``up_post`` cell that divides the
        channel count by ``channel_mult``; the others are ``normal_post``.
        """
        post_process = nn.ModuleList()
        for _ in range(self.num_postprocess_blocks):
            for c in range(self.num_postprocess_cells):
                if c == 0:
                    arch = self.arch_instance['up_post']
                    num_ci = int(self.num_channels_dec * mult)
                    num_co = int(num_ci / self.channel_mult)
                    cell = Cell(num_ci, num_co, cell_type='up_post', arch=arch, use_se=self.use_se)
                    mult = mult / self.channel_mult
                else:
                    arch = self.arch_instance['normal_post']
                    num_c = int(self.num_channels_dec * mult)
                    cell = Cell(num_c, num_c, cell_type='normal_post', arch=arch, use_se=self.use_se)
                post_process.append(cell)
        self.mult = mult
        return post_process

    def _sample_group(self, param):
        """Split ``param`` into (mu, log_sigma), sample z and score it."""
        mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
        z, _ = Normal(mu_q, log_sig_q).sample()
        return z, log_density_gaussian(z, mu_q, log_sig_q)

    def forward(self, x):
        """Encode ``x``, sample the latents and decode.

        Parameters
        ----------
        x : torch.Tensor
            Input fed to the stem after the affine map ``2 * x - 1``.
            Assumes a 4-D tensor with one channel (the stem is built with one
            input channel) — TODO confirm against ``Conv2D``.

        Returns
        -------
        logits : torch.Tensor
            Output of the GRU head after the final linear projection.
        total_c : torch.Tensor
            Mean over the latent groups of ``log_density_gaussian``.
        all_z : list of torch.Tensor
            The sampled latent of every group, first group first.
        """
        s = self.stem(2 * x - 1.0)
        for cell in self.pre_process:
            s = cell(s)

        # Run the encoder tower, remembering each combiner cell together with
        # the features it will later merge with the decoder state.
        combiner_cells_enc = []
        combiner_cells_s = []
        for cell in self.enc_tower:
            if cell.cell_type == 'combiner_enc':
                combiner_cells_enc.append(cell)
                combiner_cells_s.append(s)
            else:
                s = cell(s)
        # The decoder consumes them in the opposite order.
        combiner_cells_enc.reverse()
        combiner_cells_s.reverse()

        # First group: sampled straight from the top encoder features. Its
        # loss seeds the accumulator, because the final mean divides by the
        # number of groups, this one included.
        z, total_c = self._sample_group(self.enc_sampler[0](self.enc0(s)))
        all_z = [z]

        # The decoder starts from the learned prior map, one copy per item.
        s = self.prior_ftr0.unsqueeze(0).expand(z.size(0), -1, -1, -1)

        idx_dec = 0
        for cell in self.dec_tower:
            if cell.cell_type == 'combiner_dec':
                if idx_dec > 0:
                    ftr = combiner_cells_enc[idx_dec - 1](combiner_cells_s[idx_dec - 1], s)
                    z, loss_qz = self._sample_group(self.enc_sampler[idx_dec](ftr))
                    all_z.append(z)
                    total_c = total_c + loss_qz
                s = cell(s, z)
                idx_dec += 1
            else:
                s = cell(s)

        for cell in self.post_process:
            s = cell(s)

        logits = self.image_conditional(s)
        # NOTE(review): the loop bound is the number of latent groups, yet it
        # indexes the channels of `image_conditional`, which has 2 outputs.
        # This only lines up when groups_per_scale <= 2 — confirm the intent.
        tmp_tot = []
        for i in range(idx_dec):
            # Integer indexing already drops the channel axis; no squeeze(),
            # which would also drop a batch axis of size 1 and break permute.
            tmp, _ = self.rnn(logits[:, i, :, :].permute(0, 2, 1))
            tmp_tot.append(tmp.permute(0, 2, 1))
        logits = torch.stack(tmp_tot, 1)
        logits = self.projection(logits[..., -(self.input_size + self.hidden_size):])

        total_c = total_c / idx_dec
        return logits, total_c, all_z

    def decoder_output(self, logits):
        """Wrap ``logits`` in a ``NormalDecoder``."""
        return NormalDecoder(logits)