import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
'''
class VectorQuantizer(nn.Module):
"""
Inspired by the Sonnet implementation of VQ-VAE (https://arxiv.org/abs/1711.00937)
at https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py and by the
PyTorch implementation from zalandoresearch at https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb.
Implements the algorithm presented in
'Neural Discrete Representation Learning' by van den Oord et al.
https://arxiv.org/abs/1711.00937
Input any tensor to be quantized. Last dimension will be used as space in
which to quantize. All other dimensions will be flattened and will be seen
as different examples to quantize.
The output tensor will have the same shape as the input.
For example a tensor with shape [16, 32, 32, 64] will be reshaped into
[16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized
independently.
Args:
embedding_dim: integer representing the dimensionality of the tensors in the
quantized space. Inputs to the modules must be in this format as well.
num_embeddings: integer, the number of vectors in the quantized space.
commitment_cost: scalar which controls the weighting of the loss terms
(see equation 4 in the paper - this variable is Beta).
"""
def __init__(self, num_embeddings, embedding_dim, commitment_cost, device):
super(VectorQuantizer, self).__init__()
self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
self._commitment_cost = commitment_cost
self._device = device
def forward(self, inputs, compute_distances_if_possible=True, record_codebook_stats=False):
"""
Connects the module to some inputs.
Args:
inputs: Tensor, final dimension must be equal to embedding_dim. All other
leading dimensions will be flattened and treated as a large batch.
Returns:
loss: Tensor containing the loss to optimize.
quantize: Tensor containing the quantized version of the input.
perplexity: Tensor containing the perplexity of the encodings.
encodings: Tensor containing the discrete encodings, ie which element
of the quantized space each input element was mapped to.
distances
"""
# Convert inputs from BCHW -> BHWC
inputs = inputs.permute(1, 2, 0).contiguous()
input_shape = inputs.shape
_, time, batch_size = input_shape
# Flatten input
flat_input = inputs.view(-1, self._embedding_dim)
# Compute distances between encoded audio frames and embedding vectors
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
"""
encoding_indices: Tensor containing the discrete encoding indices, ie
which element of the quantized space each input element was mapped to.
"""
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, dtype=torch.float).to(self._device)
encodings.scatter_(1, encoding_indices, 1)
# Compute distances between encoding vectors
if not self.training and compute_distances_if_possible:
_encoding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(flat_input, r=2)]
encoding_distances = torch.tensor(_encoding_distances).to(self._device).view(batch_size, -1)
else:
encoding_distances = None
# Compute distances between embedding vectors
if not self.training and compute_distances_if_possible:
_embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(self._embedding.weight, r=2)]
embedding_distances = torch.tensor(_embedding_distances).to(self._device)
else:
embedding_distances = None
# Sample nearest embedding
if not self.training and compute_distances_if_possible:
_frames_vs_embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in product(flat_input, self._embedding.weight.detach())]
frames_vs_embedding_distances = torch.tensor(_frames_vs_embedding_distances).to(self._device).view(batch_size, time, -1)
else:
frames_vs_embedding_distances = None
# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
# TODO: Check if the more readable self._embedding.weight.index_select(dim=1, index=encoding_indices) works better
concatenated_quantized = self._embedding.weight[torch.argmin(distances, dim=1).detach().cpu()] if not self.training or record_codebook_stats else None
# Losses
e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
q_latent_loss = torch.mean((quantized - inputs.detach())**2)
commitment_loss = self._commitment_cost * e_latent_loss
vq_loss = q_latent_loss + commitment_loss
quantized = inputs + (quantized - inputs).detach() # Trick to prevent backpropagation of quantized
avg_probs = torch.mean(encodings, dim=0)
"""
The perplexity is a useful value to track during training.
It indicates how many codes are 'active' on average.
"""
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # Exponential entropy
# Convert quantized from BHWC -> BCHW
return vq_loss, quantized.permute(2, 0, 1).contiguous(), \
perplexity, encodings.view(batch_size, time, -1), \
distances.view(batch_size, time, -1), encoding_indices, \
{'e_latent_loss': e_latent_loss.item(), 'q_latent_loss': q_latent_loss.item(),
'commitment_loss': commitment_loss.item(), 'vq_loss': vq_loss.item()}, \
encoding_distances, embedding_distances, frames_vs_embedding_distances, concatenated_quantized
@property
def embedding(self):
return self._embedding
'''
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a gradient-trained codebook.

    Follows the VQ-VAE formulation from 'Neural Discrete Representation
    Learning' (van den Oord et al., https://arxiv.org/abs/1711.00937).

    Args:
        num_embeddings: number of vectors in the codebook.
        embedding_dim: size of each codebook vector. Dimension 1 of the
            tensor passed to ``forward`` is expected to have this size.
        commitment_cost: weight applied to the encoder (commitment) term of
            the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        # Small symmetric uniform initialisation of the codebook.
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings,
                                             1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs: rank-3 tensor; dimension 1 holds the vectors to quantize
                (it is moved last and flattened into rows of
                ``embedding_dim`` values).

        Returns:
            A tuple ``(loss, quantized, perplexity, encodings)``:
            ``loss`` is the scalar codebook loss plus weighted commitment
            loss; ``quantized`` has the same layout as ``inputs`` and uses
            the straight-through estimator; ``perplexity`` is the
            exponential of the entropy of the average code usage;
            ``encodings`` is the one-hot assignment reshaped to
            ``(inputs.shape[0], -1, num_embeddings)``.
        """
        # Move the embedding axis last: (d0, embedding, d2) -> (d0, d2, embedding).
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # One row per vector to quantize.
        flat_input = inputs.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Squared Euclidean distance, expanded as |x|^2 + |e|^2 - 2 x.e
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # One-hot encoding of the nearest codebook entry. The tensor takes the
        # codebook's dtype so the matmul below also works after the module has
        # been cast (e.g. ``.double()``); a hard-coded float32 would fail there.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                dtype=codebook.dtype, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Look up the selected codebook vectors and restore the input shape.
        quantized = torch.matmul(encodings, codebook).view(input_shape)

        # e_latent_loss pulls the encoder output towards the codebook,
        # q_latent_loss pulls the codebook towards the encoder output.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the codebook vector,
        # gradient flows to ``inputs`` unchanged.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Restore the original axis order before returning.
        return (loss,
                quantized.permute(0, 2, 1).contiguous(),
                perplexity,
                encodings.view(input_shape[0], -1, encodings.shape[1]))
class VectorQuantizerEMA(nn.Module):
    """Nearest-neighbour vector quantizer whose codebook is updated by EMA.

    The codebook is not trained through gradients: while the module is in
    training mode, each ``forward`` call moves the codebook towards the
    vectors assigned to each entry using exponential moving averages.

    Args:
        num_embeddings: number of vectors in the codebook.
        embedding_dim: size of each codebook vector. Dimension 1 of the
            tensor passed to ``forward`` is expected to have this size.
        commitment_cost: weight applied to the commitment loss.
        decay: EMA decay factor used for the running statistics.
        epsilon: Laplace-smoothing constant for the cluster sizes.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        # Running (EMA) count of vectors assigned to each codebook entry.
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # Running (EMA) sum of the vectors assigned to each codebook entry.
        self._ema_w = nn.Parameter(torch.empty(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        """Quantize ``inputs`` and, in training mode, update the codebook.

        Args:
            inputs: rank-3 tensor; dimension 1 holds the vectors to quantize
                (it is moved last and flattened into rows of
                ``embedding_dim`` values).

        Returns:
            A tuple ``(loss, quantized, perplexity, encodings)``:
            ``loss`` is the weighted commitment loss; ``quantized`` has the
            same layout as ``inputs`` and uses the straight-through
            estimator; ``perplexity`` is the exponential of the entropy of
            the average code usage; ``encodings`` is the one-hot assignment
            reshaped to ``(inputs.shape[0], -1, num_embeddings)``.
        """
        # Move the embedding axis last: (d0, embedding, d2) -> (d0, d2, embedding).
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # One row per vector to quantize.
        flat_input = inputs.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # The assignment is not differentiable, so no graph is needed here.
        # Skipping it also ensures autograd never holds on to the codebook,
        # which is modified in place further down.
        with torch.no_grad():
            # Squared Euclidean distance, expanded as |x|^2 + |e|^2 - 2 x.e
            distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                         + torch.sum(codebook ** 2, dim=1)
                         - 2 * torch.matmul(flat_input, codebook.t()))
            encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # One-hot assignment in the codebook's dtype, so the matmuls below
        # keep working after the module has been cast (e.g. ``.double()``).
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                dtype=codebook.dtype, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Selected codebook vectors, computed with the pre-update codebook.
        # Only the detached value is ever used, so detach up front.
        quantized = torch.matmul(encodings, codebook.detach()).view(input_shape)

        if self.training:
            # Update the statistics in place. Re-wrapping the tensors in new
            # nn.Parameter objects on every step would leave optimizers, DDP
            # wrappers and other holders pointing at stale tensors.
            with torch.no_grad():
                self._ema_cluster_size.mul_(self._decay).add_(
                    torch.sum(encodings, 0), alpha=1 - self._decay)

                # Laplace smoothing of the cluster size.
                n = torch.sum(self._ema_cluster_size)
                self._ema_cluster_size.copy_(
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)

                dw = torch.matmul(encodings.t(), flat_input)
                self._ema_w.mul_(self._decay).add_(dw, alpha=1 - self._decay)

                self._embedding.weight.copy_(
                    self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Commitment loss: pulls the encoder output towards the codebook.
        e_latent_loss = F.mse_loss(quantized, inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the codebook vector,
        # gradient flows to ``inputs`` unchanged.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Restore the original axis order before returning.
        return (loss,
                quantized.permute(0, 2, 1).contiguous(),
                perplexity,
                encodings.view(input_shape[0], -1, encodings.shape[1]))
class Residual(nn.Module):
    """Residual block: ReLU -> 3-wide conv -> ReLU -> 1-wide conv, plus skip.

    Args:
        in_channels: channels of the input tensor.
        hidden_channels: channels produced by the block; the skip addition
            in ``forward`` requires this to be addable to the input.
        num_residual_hiddens: channels of the intermediate convolution.
    """

    def __init__(self, in_channels, hidden_channels, num_residual_hiddens):
        super(Residual, self).__init__()
        # Layer order (and therefore the ``_block.<index>`` state-dict keys)
        # is kept as ReLU, conv, ReLU, conv.
        self._block = nn.Sequential(
            # This ReLU must NOT be in-place: it receives the block input,
            # which is also the skip connection. An in-place ReLU would
            # overwrite the caller's tensor and turn ``x + block(x)`` into
            # ``relu(x) + block(relu(x))``.
            nn.ReLU(inplace=False),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=num_residual_hiddens,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False
            ),
            # In-place is fine here: it only touches the intermediate
            # activation created by the convolution above.
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=num_residual_hiddens,
                out_channels=hidden_channels,
                kernel_size=1,
                stride=1,
                bias=False
            )
        )

    def forward(self, x):
        """Return ``x + block(x)`` without modifying ``x``."""
        return x + self._block(x)
class ResidualStack(nn.Module):
    """Sequence of ``Residual`` blocks followed by a final ReLU.

    Args:
        in_channels: channels of the input tensor.
        hidden_channels: channels produced by every residual block.
        num_residual_layers: number of residual blocks in the stack.
        num_residual_hiddens: intermediate channels inside each block.
    """

    def __init__(self, in_channels, hidden_channels, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        # Build one independent block per layer. ``[Residual(...)] * n`` would
        # repeat a single module object n times, so every layer would share
        # the same weights.
        self._layers = nn.ModuleList([
            Residual(in_channels, hidden_channels, num_residual_hiddens)
            for _ in range(self._num_residual_layers)
        ])

    def forward(self, x):
        """Apply every residual block in order, then a ReLU."""
        for layer in self._layers:
            x = layer(x)
        return F.relu(x)
class Encoder(nn.Module):
    """1-D convolutional encoder with skip connections and a residual stack.

    Args:
        in_channels: channels of the input signal.
        hidden_channels: channels used by every layer after the first.
        num_residual_layers: number of blocks in the final residual stack.
    """

    def __init__(self, in_channels, hidden_channels,num_residual_layers=3):
        super(Encoder, self).__init__()

        # Settings shared by the length-preserving hidden -> hidden layers.
        same_length = dict(in_channels=hidden_channels,
                           out_channels=hidden_channels,
                           kernel_size=3,
                           padding=1)

        self._conv_1 = nn.Conv1d(in_channels=in_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3,
                                 padding=1)
        self._conv_2 = nn.Conv1d(**same_length)
        # The only strided layer: kernel 4, stride 2, padding 1.
        self._conv_3 = nn.Conv1d(in_channels=hidden_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=4,
                                 stride=2,
                                 padding=1)
        self._conv_4 = nn.Conv1d(**same_length)
        self._conv_5 = nn.Conv1d(**same_length)

        self._residual_stack = ResidualStack(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=hidden_channels // 2
        )

    def forward(self, inputs):
        """Encode ``inputs`` and return the resulting feature map."""
        stem = F.relu(self._conv_1(inputs))
        skip_2 = F.relu(self._conv_2(stem)) + stem

        down = F.relu(self._conv_3(skip_2))
        skip_4 = F.relu(self._conv_4(down)) + down
        skip_5 = F.relu(self._conv_5(skip_4)) + skip_4

        return self._residual_stack(skip_5) + skip_5
class Jitter(nn.Module):
    """
    Jitter implementation from [Chorowski et al., 2019].
    During training, each latent vector can replace either one or both of
    its neighbors. As in dropout, this prevents the model from
    relying on consistency across groups of tokens. Additionally,
    this regularization also promotes latent representation stability
    over time: a latent vector extracted at time step t must strive
    to also be useful at time steps t - 1 or t + 1.

    Args:
        probability: chance, per position along dimension 2, that the
            vector at that position is replaced by one of its neighbors.
    """

    def __init__(self, probability=0.12):
        super(Jitter, self).__init__()
        self._probability = probability

    def forward(self, quantized):
        """Randomly replace positions along dimension 2 with a neighbor.

        ``quantized`` is modified in place and also returned. Replacement
        values are taken from a detached snapshot made before any change,
        so a replaced position never sees an already-replaced neighbor.
        """
        length = quantized.size(2)
        # A single position has no neighbor to copy from.
        if length < 2:
            return quantized

        original_quantized = quantized.detach().clone()

        for i in range(length):
            # Replace with probability ``self._probability``. (Indexing
            # ``[True, False]`` with a 1/0 draw would invert the decision and
            # replace with probability 1 - p.)
            if np.random.random() >= self._probability:
                continue

            if i == 0:
                neighbor_index = i + 1
            elif i == length - 1:
                neighbor_index = i - 1
            else:
                # "We independently sample whether it is to be replaced with
                # the token right after or before it."
                neighbor_index = i + int(np.random.choice([-1, 1], p=[0.5, 0.5]))

            quantized[:, :, i] = original_quantized[:, :, neighbor_index]

        return quantized
class Decoder(nn.Module):
    """1-D decoder: conv, 2x upsampling, residual stack, three transposed convs.

    Args:
        in_channels: channels of the tensor to decode.
        hidden_channels: channels used by the intermediate layers.
        out_channels: channels of the reconstructed output.
        num_residual_layers: number of blocks in the residual stack.
    """

    def __init__(self, in_channels, hidden_channels,out_channels,num_residual_layers=3):
        super(Decoder, self).__init__()

        # Constructed but currently not applied in ``forward``.
        self._jitter = Jitter(0.125)

        self._conv_1 = nn.Conv1d(in_channels=in_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3,
                                 padding=1)

        self._upsample = nn.Upsample(scale_factor=2)

        self._residual_stack = ResidualStack(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=hidden_channels // 2
        )

        self._conv_trans_1 = nn.ConvTranspose1d(in_channels=hidden_channels,
                                                out_channels=hidden_channels,
                                                kernel_size=3,
                                                padding=1)
        # Kernel 4 / padding 2 shortens the sequence by one step ...
        self._conv_trans_2 = nn.ConvTranspose1d(in_channels=hidden_channels,
                                                out_channels=hidden_channels,
                                                kernel_size=4,
                                                padding=2)
        # ... and kernel 4 / padding 1 lengthens it by one step again.
        self._conv_trans_3 = nn.ConvTranspose1d(in_channels=hidden_channels,
                                                out_channels=out_channels,
                                                kernel_size=4,
                                                padding=1)

    def forward(self, x,is_training=True):
        """Decode ``x`` into a reconstruction.

        ``is_training`` is accepted for interface compatibility but has no
        effect: the jitter step it used to guard is disabled.
        """
        # NOTE: jitter is intentionally not applied here
        # (formerly: if is_training: x = self._jitter(x)).
        hidden = self._upsample(self._conv_1(x))
        hidden = self._residual_stack(hidden)

        for deconv in (self._conv_trans_1, self._conv_trans_2):
            hidden = F.relu(deconv(hidden))

        # No activation on the output layer.
        return self._conv_trans_3(hidden)
class VQVAE(nn.Module):
    """VQ-VAE model: encoder, 1x1 projection, vector quantizer, decoder.

    Args:
        in_channels: channels of the input signal.
        hidden_channels: channels used inside the encoder and decoder.
        out_channels: channels of the reconstruction.
        num_embeddings: number of codebook vectors.
        embedding_dim: size of each codebook vector.
        commitment_cost: weight of the commitment loss.
        decay: EMA decay; a value greater than 0 selects the EMA quantizer,
            anything else selects the gradient-trained quantizer.
    """

    def __init__(self,in_channels, hidden_channels, out_channels,num_embeddings, embedding_dim, commitment_cost, decay):
        super(VQVAE, self).__init__()

        self._encoder = Encoder(in_channels, hidden_channels)

        # Project encoder features to the codebook dimensionality.
        self._pre_vq_conv = nn.Conv1d(in_channels=hidden_channels,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)

        use_ema = decay > 0.0
        if use_ema:
            quantizer = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                           commitment_cost, decay)
        else:
            logging.info('CARE NOT TESTED')
            quantizer = VectorQuantizer(num_embeddings, embedding_dim,
                                        commitment_cost)
        self._vq_vae = quantizer

        self._decoder = Decoder(in_channels=embedding_dim,
                                hidden_channels=hidden_channels,
                                out_channels=out_channels)

    def forward(self, x,is_training=True):
        """Encode, quantize and decode ``x``.

        Returns:
            A tuple ``(loss, x_recon, perplexity, quantized, encodings)``
            where the first, third, fourth and fifth items come from the
            quantizer and ``x_recon`` is the decoder output.
        """
        latents = self._pre_vq_conv(self._encoder(x))
        loss, quantized, perplexity, encodings = self._vq_vae(latents)
        x_recon = self._decoder(quantized, is_training)
        return loss, x_recon, perplexity, quantized, encodings