import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
"""
A copy-paste from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
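# Illustrative sanity check, not part of the original module: with dropout disabled, the
# reference implementation above should agree (up to floating-point tolerance) with
# torch.nn.functional.scaled_dot_product_attention. Shapes assumed: (batch, heads, seq, head_dim).
def _sdpa_sanity_check():
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    ours = scaled_dot_product_attention(q, k, v, is_causal=True)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(ours, ref, atol=1e-5)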
class RevIN(nn.Module):
"""
Reversible Instance Normalization (RevIN) https://openreview.net/pdf?id=cGDAkQo1C0p
https://github.com/ts-kim/RevIN
"""
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()
    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x
    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # per-instance mean/std over all dims except batch and channel
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        # invert the affine transform, then restore the original scale and location
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x
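# Illustrative usage, not part of the original module: RevIN is typically wrapped around a
# forecasting model, normalizing the input window and denormalizing the predictions.
# The linear "model" and the tensor shapes below are placeholders chosen for this sketch.
def _revin_example():
    batch, seq_len, horizon, channels = 4, 96, 24, 7
    revin = RevIN(num_features=channels)
    model = nn.Linear(seq_len, horizon)  # placeholder forecaster acting on the time dimension
    x = torch.randn(batch, seq_len, channels)  # input window of shape (B, L, C)
    x_norm = revin(x, mode='norm')  # normalize per instance and channel
    y_hat = model(x_norm.transpose(1, 2)).transpose(1, 2)  # forecast in normalized space, (B, H, C)
    return revin(y_hat, mode='denorm')  # map predictions back to the original scale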
[docs]
class SAM(torch.optim.Optimizer):
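    """
    Sharpness-Aware Minimization (SAM), https://arxiv.org/abs/2010.01412
    """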
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # Perturb weights in the gradient direction
        if zero_grad:
            self.zero_grad()
    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # Restore original weights
        self.base_optimizer.step()  # Apply the sharpness-aware update
        if zero_grad:
            self.zero_grad()
    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires a closure, but none was provided"
        with torch.enable_grad():
            closure()  # First forward-backward pass
        self.first_step(zero_grad=True)
        with torch.enable_grad():
            closure()  # Second forward-backward pass
        self.second_step()
    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        grads = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        return torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0, device=shared_device)
    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        if hasattr(self, "base_optimizer"):  # Ensure the base optimizer exists before restoring its state
            self.base_optimizer.load_state_dict(state_dict["base_optimizer"])
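# Illustrative usage, not part of the original module: SAM wraps a base optimizer, and its
# `step` expects a closure that re-runs the forward-backward pass, since every update needs
# two passes. The tiny model, loss, and random batch below are placeholders for this sketch.
def _sam_example():
    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)  # runs first_step and second_step internally
    return model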