import torch
import torch.nn.init as init
from torch import nn
import numpy as np
from numba import jit
from torch.autograd import Function
def get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss):
    """Summarise a model's capabilities as a human-readable message.

    Each flag toggles the word ``NOT`` in the corresponding line of the
    returned multi-line string.

    Args:
        handle_multivariate: True if the model supports multivariate output.
        handle_future_covariates: True if future covariates are supported.
        handle_categorical_variables: True if categorical covariates are supported.
        handle_quantile_loss: True if a quantile loss can be used.

    Returns:
        str: four capability lines joined together.
    """
    def _tag(flag):
        # empty string when supported, "NOT" otherwise (mirrors the original template)
        return "" if flag else "NOT"

    parts = [
        f'Can {_tag(handle_multivariate)} handle multivariate output \n',
        f'Can {_tag(handle_future_covariates)} handle future covariates\n',
        f'Can {_tag(handle_categorical_variables)} handle categorical covariates\n',
        f'Can {_tag(handle_quantile_loss)} handle Quantile loss function',
    ]
    return ''.join(parts)
class SinkhornDistance():
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """

    def __init__(self, eps, max_iter, reduction='none'):
        # NOTE(review): this class does not inherit from nn.Module, so this
        # super() call just resolves to object.__init__ — harmless but vestigial.
        super(SinkhornDistance, self).__init__()
        self.eps = eps            # entropic regularization coefficient
        self.max_iter = max_iter  # hard cap on Sinkhorn iterations
        self.reduction = reduction

    def compute(self, x, y):
        """Run Sinkhorn iterations and return the approximate OT cost.

        Args:
            x: point cloud, shape (N, P1, D) or (P1, D).
            y: point cloud, shape (N, P2, D) or (P2, D).

        Returns:
            Tensor: per-batch Sinkhorn cost, reduced per ``self.reduction``.
        """
        # The Sinkhorn algorithm takes as input three variables :
        C = self._cost_matrix(x, y).to(x.device)  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]
        # both marginals are fixed with equal (uniform) weights; the squeeze()
        # drops the batch dimension when batch_size == 1
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / x_points).squeeze().to(x.device)
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).squeeze().to(x.device)
        # dual potentials, iterated in the log domain for numerical stability
        u = torch.zeros_like(mu).to(x.device)
        v = torch.zeros_like(nu).to(x.device)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0
        # Stopping criterion on the mean absolute change of u
        thresh = 1e-1
        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u  # useful to check the update
            # log-domain updates; the 1e-8 guards log(0) for degenerate marginals
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            actual_nits += 1
            if err.item() < thresh:
                break
        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance: expected cost under the transport plan
        cost = torch.sum(pi * C, dim=(-2, -1))
        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return cost  # , pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates:
        $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)  # (..., P1, 1, D)
        y_lin = y.unsqueeze(-3)  # (..., 1, P2, D)
        # broadcast to (..., P1, P2, D) and reduce over the feature axis
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
class QuantileLossMO(nn.Module):
    """Pinball (quantile) loss averaged over output channels and quantiles.

    Expects predictions of shape (B, L, C, Q) — batch, length, channels,
    quantiles — and targets of shape (B, L, C).
    """

    def __init__(self, quantiles):
        super().__init__()
        # list of quantile levels, one per last-axis slice of the predictions
        self.quantiles = quantiles

    def forward(self, preds, target):
        """Return the scalar quantile loss, normalised by C and Q."""
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        num_channels = preds.shape[2]
        total = 0
        for c in range(num_channels):
            # residuals per quantile slice for this channel, each (B, L)
            residuals = [target[:, :, c] - preds[:, :, c, k]
                         for k in range(len(self.quantiles))]
            # pinball loss: max((q-1)e, q*e); abs() kept from the original
            # formulation (a no-op for q in (0, 1))
            pinball = [torch.abs(torch.max((q - 1) * e, q * e))
                       for q, e in zip(self.quantiles, residuals)]
            # concatenate along the length axis, sum per sample, mean over batch
            total = total + torch.mean(torch.sum(torch.cat(pinball, dim=1), dim=1))
        return total / num_channels / len(self.quantiles)
class L1Loss(nn.Module):
    """Mean absolute error on the first slice of the last prediction axis.

    Predictions are (B, L, C, Q); only index 0 of the Q axis is compared
    against the (B, L, C) target.
    """

    def __init__(self):
        super().__init__()
        # delegate the reduction to the built-in criterion
        self.f = nn.L1Loss()

    def forward(self, preds, target):
        point_forecast = preds[:, :, :, 0]
        return self.f(point_forecast, target)
class Permute(nn.Module):
    """Swap the last two axes of a 3-D tensor: (B, L, C) -> (B, C, L)."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # equivalent to torch.permute(input, (0, 2, 1))
        return input.permute(0, 2, 1)
def get_activation(activation):
    """Resolve an activation given as Python source text (e.g. ``'nn.ReLU'``)
    into the corresponding object, evaluated in this module's namespace.

    NOTE(review): ``eval`` executes arbitrary code — ensure ``activation``
    never originates from untrusted input (user-supplied configs, requests).
    """
    return eval(activation)
def weight_init_zeros(m):
    """Zero-initialise the parameters of module ``m``.

    Usage:
        model.apply(weight_init_zeros)

    Recurrent modules (LSTM/LSTMCell/GRU/GRUCell) have every parameter set
    to 0; for ``nn.GRU`` the first third of each bias vector is then filled
    with -1 (as in the original code). ``nn.LayerNorm`` keeps its identity
    init (weight=1, bias=0). Any other module with ``weight``/``bias``
    attributes is zeroed; parameter-free containers are skipped.
    """
    if isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        # the original branched on parameter rank but wrote 0.0 in both arms,
        # so a single loop is equivalent
        for param in m.parameters():
            init.constant_(param.data, 0.0)
        if isinstance(m, nn.GRU):
            # after zeroing, set the first third of every bias to -1
            for names in m._all_weights:
                for name in filter(lambda n: "bias" in n, names):
                    bias = getattr(m, name)
                    n = bias.size(0)
                    bias.data[:n // 3].fill_(-1.)
    elif isinstance(m, nn.Embedding):
        init.constant_(m.weight, 0.0)
    elif isinstance(m, nn.LayerNorm):
        init.zeros_(m.bias)
        init.ones_(m.weight)
    else:
        try:
            init.constant_(m.weight.data, 0.0)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        except AttributeError:
            # containers / parameter-free modules expose no weight or bias
            pass
def weight_init(m):
    """Randomly initialise the parameters of module ``m``.

    Usage:
        model = Model()
        model.apply(weight_init)

    Scheme (unchanged from the original, duplication collapsed):
      - Conv1d / ConvTranspose1d: normal weight and bias
      - Conv2d/3d, ConvTranspose2d/3d, Linear: Xavier-normal weight, normal bias
      - BatchNorm1d/2d/3d: weight ~ N(1, 0.02), bias = 0
      - LSTM/LSTMCell/GRU/GRUCell: orthogonal matrices, normal vectors;
        GRU biases additionally get -1 in their first third
      - Embedding: weight ~ N(0, 0.02)
      - LayerNorm: weight = 1, bias = 0
    """
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d,
                        nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)  # weight matrices
            else:
                init.normal_(param.data)      # bias vectors
        if isinstance(m, nn.GRU):
            # overwrite the first third of every bias with -1
            for names in m._all_weights:
                for name in filter(lambda n: "bias" in n, names):
                    bias = getattr(m, name)
                    n = bias.size(0)
                    bias.data[:n // 3].fill_(-1.)
    elif isinstance(m, nn.Embedding):
        init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.LayerNorm):
        init.zeros_(m.bias)
        init.ones_(m.weight)
def pairwise_distances(x, y=None):
    """Squared Euclidean distance matrix between two point sets.

    Args:
        x: (N, d) tensor.
        y: optional (M, d) tensor; when omitted, distances are computed
           within ``x`` itself.

    Returns:
        (N, M) tensor where dist[i, j] = ||x[i] - y[j]||^2, clamped at 0.
    """
    sq_x = x.pow(2).sum(dim=1).view(-1, 1)
    if y is None:
        other = x
        sq_y = sq_x.view(1, -1)
    else:
        other = y
        sq_y = y.pow(2).sum(dim=1).view(1, -1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    dist = sq_x + sq_y - 2.0 * torch.mm(x, other.t())
    # numerical noise can push true zeros slightly negative
    return torch.clamp(dist, 0.0, float('inf'))
@jit(nopython = True)
def compute_softdtw(D, gamma):
    """Forward dynamic program of soft-DTW.

    Args:
        D: (N, M) pairwise cost matrix between two sequences.
        gamma: smoothing parameter (> 0); smaller values approach hard DTW.

    Returns:
        R: (N+2, M+2) accumulated soft-cost matrix, padded by one cell on
        each border; R[N, M] holds the soft-DTW discrepancy.
    """
    N = D.shape[0]
    M = D.shape[1]
    # pad with a large value so out-of-range neighbours never win the soft-min
    R = np.zeros((N + 2, M + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # soft-min over the three DTW predecessors (diag, up, left),
            # computed with the log-sum-exp trick for numerical stability
            r0 = -R[i - 1, j - 1] / gamma
            r1 = -R[i - 1, j] / gamma
            r2 = -R[i, j - 1] / gamma
            rmax = max(max(r0, r1), r2)
            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
            softmin = - gamma * (np.log(rsum) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R
@jit(nopython = True)
def compute_softdtw_backward(D_, R, gamma):
    """Backward pass of soft-DTW: gradient of R[N, M] w.r.t. ``D_``.

    Args:
        D_: (N, M) pairwise cost matrix used in the forward pass.
        R: (N+2, M+2) accumulated cost matrix from ``compute_softdtw``.
           NOTE: mutated in place (border cells are overwritten below).
        gamma: smoothing parameter used in the forward pass.

    Returns:
        (N, M) gradient matrix E[1:N+1, 1:M+1] (the soft alignment).
    """
    N = D_.shape[0]
    M = D_.shape[1]
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    E[-1, -1] = 1
    # seal the padded border so the exp() terms vanish there
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    # reverse dynamic program: propagate sensitivities from (N, M) to (1, 1)
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    return E[1:N + 1, 1:M + 1]
class SoftDTWBatch(Function):
    """Autograd Function computing the mean soft-DTW value over a batch of
    pairwise cost matrices, backed by the numba kernels above.

    The numba kernels are CPU-only, so tensors are detached and round-tripped
    through numpy, then copied back to the input device.
    """

    @staticmethod
    def forward(ctx, D, gamma = 1.0): # D.shape: [batch_size, N , N]
        dev = D.device
        batch_size,N,N = D.shape
        gamma = torch.FloatTensor([gamma]).to(dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()
        total_loss = 0
        # keep each sample's full R matrix for the backward pass
        R = torch.zeros((batch_size, N+2 ,N+2)).to(dev)
        for k in range(0, batch_size): # loop over all D in the batch
            Rk = torch.FloatTensor(compute_softdtw(D_[k,:,:], g_)).to(dev)
            R[k:k+1,:,:] = Rk
            # R[-2, -2] is the soft-DTW value for sample k
            total_loss = total_loss + Rk[-2,-2]
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        D, R, gamma = ctx.saved_tensors
        batch_size,N,N = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()
        E = torch.zeros((batch_size, N ,N)).to(dev)
        for k in range(batch_size):
            Ek = torch.FloatTensor(compute_softdtw_backward(D_[k,:,:], R_[k,:,:], g_)).to(dev)
            E[k:k+1,:,:] = Ek
        # NOTE(review): forward divides by batch_size but the 1/batch_size
        # factor is not applied here — confirm this scaling is intended.
        return grad_output * E, None
@jit(nopython = True)
def my_max(x, gamma):
    """Smoothed maximum of vector ``x`` with temperature ``gamma``.

    Returns:
        (value, weights): the log-sum-exp smoothed max, and the softmax
        weight vector exp((x - max)/gamma) / Z (its gradient w.r.t. x).
    """
    # use the log-sum-exp trick
    max_x = np.max(x)
    exp_x = np.exp((x - max_x) / gamma)
    Z = np.sum(exp_x)
    return gamma * np.log(Z) + max_x, exp_x / Z
@jit(nopython = True)
def my_min(x, gamma):
    """Smoothed minimum: negated smoothed max of ``-x``.

    Returns (value, weights) where weights are the soft-argmin vector.
    """
    min_x, argmax_x = my_max(-x, gamma)
    return - min_x, argmax_x
@jit(nopython = True)
def my_max_hessian_product(p, z, gamma):
    """Hessian-vector product of the smoothed max.

    Args:
        p: softmax weight vector returned by ``my_max``.
        z: direction vector.
        gamma: temperature used in ``my_max``.
    """
    return ( p * z - p * np.sum(p * z) ) /gamma
@jit(nopython = True)
def my_min_hessian_product(p, z, gamma):
    """Hessian-vector product of the smoothed min (negated max product)."""
    return - my_max_hessian_product(p, z, gamma)
@jit(nopython = True)
def dtw_grad(theta, gamma):
    """Smoothed-DTW value and gradient via a forward-backward recursion.

    Args:
        theta: (m, n) pairwise cost matrix.
        gamma: smoothing parameter (> 0).

    Returns:
        (value, grad, Q, E): the soft-DTW value V[m, n]; the (m, n)
        gradient (soft alignment matrix); and the padded intermediates Q
        (soft-min weights per cell) and E, needed for Hessian products.
    """
    m = theta.shape[0]
    n = theta.shape[1]
    V = np.zeros((m + 1, n + 1))
    # large values on the first row/column forbid paths leaving the grid
    V[:, 0] = 1e10
    V[0, :] = 1e10
    V[0, 0] = 0
    Q = np.zeros((m + 2, n + 2, 3))
    # forward pass: accumulate soft-min costs and their weights
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            v, Q[i, j] = my_min(np.array([V[i, j - 1],
                                          V[i - 1, j - 1],
                                          V[i - 1, j]]) , gamma)
            V[i, j] = theta[i - 1, j - 1] + v
    E = np.zeros((m + 2, n + 2))
    E[m + 1, :] = 0
    E[:, n + 1] = 0
    E[m + 1, n + 1] = 1
    Q[m + 1, n + 1] = 1
    # backward pass: accumulate expected path occupancy per cell
    for i in range(m,0,-1):
        for j in range(n,0,-1):
            E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
                      Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                      Q[i + 1, j, 2] * E[i + 1, j]
    return V[m, n], E[1:m + 1, 1:n + 1], Q, E
@jit(nopython = True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
    """Hessian-vector product of smoothed DTW in direction ``Z``.

    Args:
        theta: (m, n) cost matrix (unused directly; kept for signature parity).
        Z: (m, n) direction matrix.
        Q, E: padded intermediates returned by ``dtw_grad``.
        gamma: smoothing parameter used in ``dtw_grad``.

    Returns:
        (V_dot[m, n], E_dot[1:m+1, 1:n+1]): directional derivative of the
        value, and the (m, n) Hessian-vector product.
    """
    m = Z.shape[0]
    n = Z.shape[1]
    V_dot = np.zeros((m + 1, n + 1))
    V_dot[0, 0] = 0
    Q_dot = np.zeros((m + 2, n + 2, 3))
    # forward pass: differentiate the soft-min recursion along direction Z
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            V_dot[i, j] = Z[i - 1, j - 1] + \
                          Q[i, j, 0] * V_dot[i, j - 1] + \
                          Q[i, j, 1] * V_dot[i - 1, j - 1] + \
                          Q[i, j, 2] * V_dot[i - 1, j]
            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
    E_dot = np.zeros((m + 2, n + 2))
    # backward pass: product rule over the occupancy recursion of dtw_grad
    for j in range(n,0,-1):
        for i in range(m,0,-1):
            E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
                          Q[i, j + 1, 0] * E_dot[i, j + 1] + \
                          Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                          Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
                          Q_dot[i + 1, j, 2] * E[i + 1, j] + \
                          Q[i + 1, j, 2] * E_dot[i + 1, j]
    return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]
class PathDTWBatch(Function):
    """Autograd Function returning the batch-mean soft-DTW alignment matrix.

    Forward returns the (N, N) soft path matrix averaged over the batch;
    backward computes the corresponding Hessian-vector product with the
    numba kernel ``dtw_hessian_prod``. CPU-only kernels, hence the
    numpy round-trips.
    """

    @staticmethod
    def forward(ctx, D, gamma): # D.shape: [batch_size, N , N]
        batch_size,N,N = D.shape
        device = D.device
        D_cpu = D.detach().cpu().numpy()
        gamma_gpu = torch.FloatTensor([gamma]).to(device)
        # per-sample soft alignment matrices plus the intermediates (Q, E)
        # needed by the backward pass
        grad_gpu = torch.zeros((batch_size, N ,N)).to(device)
        Q_gpu = torch.zeros((batch_size, N+2 ,N+2,3)).to(device)
        E_gpu = torch.zeros((batch_size, N+2 ,N+2)).to(device)
        for k in range(0,batch_size): # loop over all D in the batch
            _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k,:,:], gamma)
            grad_gpu[k,:,:] = torch.FloatTensor(grad_cpu_k).to(device)
            Q_gpu[k,:,:,:] = torch.FloatTensor(Q_cpu_k).to(device)
            E_gpu[k,:,:] = torch.FloatTensor(E_cpu_k).to(device)
        ctx.save_for_backward(grad_gpu,D, Q_gpu ,E_gpu, gamma_gpu)
        return torch.mean(grad_gpu, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        device = grad_output.device
        grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
        D_cpu = D_gpu.detach().cpu().numpy()
        Q_cpu = Q_gpu.detach().cpu().numpy()
        E_cpu = E_gpu.detach().cpu().numpy()
        gamma = gamma.detach().cpu().numpy()[0]
        Z = grad_output.detach().cpu().numpy()
        batch_size,N,N = D_cpu.shape
        Hessian = torch.zeros((batch_size, N ,N)).to(device)
        for k in range(0,batch_size):
            # Hessian-vector product of sample k in direction grad_output
            _, hess_k = dtw_hessian_prod(D_cpu[k,:,:], Z, Q_cpu[k,:,:,:], E_cpu[k,:,:], gamma)
            Hessian[k:k+1,:,:] = torch.FloatTensor(hess_k).to(device)
        # NOTE(review): forward averages over the batch but no 1/batch_size
        # factor is applied here — confirm this scaling is intended.
        return Hessian, None