# Source code for dsipts.models.samformer.utils

import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Pure-PyTorch reference implementation of scaled dot-product attention.

    A copy-paste from
    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

    :param query: query tensor of shape (..., L, E)
    :param key: key tensor of shape (..., S, E)
    :param value: value tensor of shape (..., S, Ev)
    :param attn_mask: optional mask broadcastable to (L, S); boolean masks keep
        positions where True, float masks are added to the attention logits
    :param dropout_p: dropout probability applied to the attention weights.
        NOTE(review): dropout is applied with train=True unconditionally, so it
        is active even at inference time unless dropout_p is 0 — this matches
        the PyTorch reference snippet.
    :param is_causal: if True, apply a lower-triangular causal mask
        (attn_mask must then be None)
    :param scale: optional scaling factor; defaults to 1/sqrt(E)
    :return: attention output of shape (..., L, Ev)
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    # The bias is created directly in the query's dtype/device, so no later
    # cast is needed (the original had a dead, unassigned `attn_bias.to(...)`).
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        # Disallow attending to future positions.
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: True means "keep", False means "mask out".
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            # Float mask: added directly to the attention logits.
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    Normalizes each series instance to zero mean / unit variance on the way
    in (mode='norm'), remembering the statistics, and restores the original
    scale and offset on the way out (mode='denorm').

    References:
        https://openreview.net/pdf?id=cGDAkQo1C0p
        https://github.com/ts-kim/RevIN
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Apply normalization ('norm') or its inverse ('denorm') to x."""
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        # Learnable per-channel scale and shift, each of shape (C,).
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every axis except batch (first) and channel (last);
        # statistics are detached so gradients do not flow through them.
        reduce_dims = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        out = (x - self.mean) / self.stdev
        if self.affine:
            out = out * self.affine_weight + self.affine_bias
        return out

    def _denormalize(self, x):
        out = x
        if self.affine:
            # eps*eps guards against division by a zero affine weight.
            out = (out - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        return out * self.stdev + self.mean
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) optimizer wrapper.

    Wraps a base optimizer and performs the two-step SAM update: first
    perturb the weights towards the local loss maximum (first_step), then
    undo the perturbation and apply the base optimizer's update using the
    gradients computed at the perturbed point (second_step).
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        """
        :param params: iterable of parameters (or param groups) to optimize
        :param base_optimizer: optimizer class (e.g. torch.optim.SGD) that performs the actual update
        :param rho: neighborhood size for the perturbation, must be non-negative
        :param adaptive: if True, use Adaptive SAM (perturbation scaled by |w|)
        :param kwargs: extra hyper-parameters forwarded to the base optimizer (e.g. lr)
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the param groups so both optimizers see the same parameters
        # and hyper-parameters.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights in the (scaled) gradient direction: w <- w + e(w)."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)  # guard against zero gradients
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Remember the unperturbed weights so second_step can restore them.
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # Perturb weights in the gradient direction
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights, then apply the base optimizer's update."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # Restore original weights
        self.base_optimizer.step()  # Apply the sharpness-aware update
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Full SAM step; closure must re-evaluate the loss and call backward()."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        with torch.enable_grad():
            closure()  # First forward-backward pass
        self.first_step(zero_grad=True)
        with torch.enable_grad():
            closure()  # Second forward-backward pass
        self.second_step()

    def _grad_norm(self):
        # Accumulate the norm on one device in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        grads = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        # No gradients at all -> zero norm (first_step then perturbs by nothing).
        return torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0, device=shared_device)

    def load_state_dict(self, state_dict):
        """Load optimizer state and keep the base optimizer's param groups in sync.

        Fix: a plain ``Optimizer.state_dict()`` contains no "base_optimizer"
        key, so the previous unconditional ``state_dict["base_optimizer"]``
        lookup raised KeyError on standard checkpoints. The nested state is
        now loaded only when present, and the base optimizer is re-pointed at
        the rebuilt param groups.
        """
        super().load_state_dict(state_dict)
        if hasattr(self, "base_optimizer"):  # Ensure base optimizer exists
            # super().load_state_dict rebuilt self.param_groups; re-share them.
            self.base_optimizer.param_groups = self.param_groups
            if "base_optimizer" in state_dict:
                self.base_optimizer.load_state_dict(state_dict["base_optimizer"])