Source code for dsipts.models.Samformer
## Copyright https://github.com/romilbert/samformer/tree/main?tab=MIT-1-ov-file#readme
## Modified for notation alignmenet and batch structure
## extended to what inside samformer folder
import torch
import torch.nn as nn
import numpy as np
from .samformer.utils import scaled_dot_product_attention, RevIN
from .base import Base
from .utils import QuantileLossMO,Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
[docs]
class Samformer(Base):
handle_multivariate = True
handle_future_covariates = False # or at least it seems...
handle_categorical_variables = False #solo nel encoder
handle_quantile_loss = False # NOT EFFICIENTLY ADDED, TODO fix this
description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
[docs]
def __init__(self,
out_channels: int,
past_steps: int,
future_steps: int,
past_channels: int,
future_channels: int,
embs: List[int],
# specific params
hidden_size:int,
use_revin: bool,
rho: float=0.5,
dropout_rate: float=0.1,
activation: str='',
persistence_weight:float=0.0,
loss_type: str='l1',
quantiles:List[float]=[],
optim:Union[str,None]=None,
optim_config:Union[dict,None]=None,
scheduler_config:Union[dict,None]=None,
**kwargs)->None:
"""Samformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention.
https://arxiv.org/pdf/2402.10198
Args:
out_channels (int): number of variables to be predicted
past_steps (int): Lookback window length
future_steps (int): Horizon window length
past_channels (int): number of past variables
future_channels (int): number of future auxiliary variables
embs (List[int]): list of embeddings
hidden_size (int): first embedding size of the model ('r' in the paper)
d_model (int): second embedding size (r^{tilda} in the model). Should be smaller than hidden_size
n_head (int): number of heads
n_layer_decoder (int): number layers
dropout_rate (float):
class_strategy (str): strategy (see paper) projection/average/cls_token
activation (str, optional): activation function to be used 'nn.GELU'.
persistence_weight (float, optional): Defaults to 0.0.
loss_type (str, optional): Defaults to 'l1'.
quantiles (List[float], optional): Defaults to []. NOT USED
optim (Union[str,None], optional): Defaults to None.
optim_config (Union[dict,None], optional): Defaults to None.
scheduler_config (Union[dict,None], optional): Defaults to None.
"""
super().__init__(**kwargs)
if activation == 'torch.nn.SELU':
beauty_string('SELU do not require BN','info',self.verbose)
use_bn = False
if isinstance(activation,str):
activation = get_activation(activation)
self.save_hyperparameters(logger=False)
# self.dropout = dropout_rate
self.persistence_weight = persistence_weight
self.optim_config = optim_config
self.scheduler_config = scheduler_config
self.loss_type = loss_type
self.future_steps = future_steps
if len(quantiles)==0:
self.mul = 1
self.use_quantiles = False
if self.loss_type == 'mse':
self.loss = nn.MSELoss()
else:
self.loss = nn.L1Loss()
else:
assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
self.mul = len(quantiles)
self.use_quantiles = True
self.loss = QuantileLossMO(quantiles)
self.out_channels = out_channels
self.revin = RevIN(num_features=past_channels)
self.compute_keys = nn.Linear(past_steps, hidden_size)
self.compute_queries = nn.Linear(past_steps, hidden_size)
self.compute_values = nn.Linear(past_steps, past_steps)
self.linear_forecaster = nn.Linear(past_steps, future_steps)
self.use_revin = use_revin
[docs]
def forward(self, batch:dict)-> float:
x = batch['x_num_past'].to(self.device)
BS = x.shape[0]
if self.use_revin:
x_norm = self.revin(x.transpose(1, 2), mode='norm').transpose(1, 2) # (n, D, L)
else:
x_norm = x
# Channel-Wise Attention
queries = self.compute_queries(x_norm) # (n, D, hid_dim)
keys = self.compute_keys(x_norm) # (n, D, hid_dim)
values = self.compute_values(x_norm) # (n, D, L)
if hasattr(nn.functional, 'scaled_dot_product_attention'):
att_score = nn.functional.scaled_dot_product_attention(queries, keys, values) # (n, D, L)
else:
att_score = scaled_dot_product_attention(queries, keys, values) # (n, D, L)
out = x_norm + att_score # (n, D, L)
# Linear Forecasting
out = self.linear_forecaster(out) # (n, D, H)
# RevIN Denormalization
if self.use_revin:
out = self.revin(out.transpose(1, 2), mode='denorm').transpose(1, 2) # (n, D, H)
return out.reshape(BS,self.future_steps,self.out_channels,self.mul)