## Source code for dsipts.models.Autoformer

## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
## Code modified to align the notation and the batch generation
## extended to all the models present in the informer and autoformer folders


from torch import  nn
import torch
from .base import Base
from typing import List,Union
from ..data_structure.utils import beauty_string
from .utils import  get_activation,get_scope,QuantileLossMO
from .autoformer.layers import AutoCorrelation, AutoCorrelationLayer, Encoder, Decoder,\
    EncoderLayer, DecoderLayer, my_Layernorm, series_decomp,PositionalEmbedding


  
class Autoformer(Base):
    """Autoformer (series-decomposition transformer with auto-correlation)
    adapted to the dsipts batch format.

    Capability flags are consumed by `get_scope` to build the model
    description shown to the user.
    """

    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates,
                            handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 label_len: int,
                 past_channels: int,
                 future_channels: int,
                 out_channels: int,
                 d_model: int,
                 embs: List[int],
                 kernel_size: int,
                 activation: str = 'torch.nn.ReLU',
                 factor: int = 5,
                 n_head: int = 1,
                 n_layer_encoder: int = 2,
                 n_layer_decoder: int = 2,
                 hidden_size: int = 1048,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[int] = [],
                 dropout_rate: float = 0.1,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 **kwargs) -> None:
        """
        Args:
            past_steps (int): number of past datapoints used, not used here
            future_steps (int): number of future lags to predict
            label_len (int): overlap length between encoder input and decoder input
            past_channels (int): number of numeric past variables, must be >0
            future_channels (int): number of future numeric variables
            out_channels (int): number of output channels
            d_model (int): dimension of the attention model
            embs (List[int]): list of the initial dimensions of the categorical variables
            kernel_size (int): kernel size of the moving-average decomposition
            activation (str, optional): activation function (pytorch path). Default 'torch.nn.ReLU'
            factor (int, optional): attention factor, see the Autoformer paper. Default 5
            n_head (int, optional): number of attention heads. Default 1
            n_layer_encoder (int, optional): number of encoding layers. Default 2
            n_layer_decoder (int, optional): number of decoding layers. Default 2
            hidden_size (int, optional): hidden size of the feed-forward sublayers. Default 1048
            persistence_weight (float): weight controlling the divergence from the persistence model. Default 0.0
            loss_type (str, optional): 'mse' or 'l1' (custom penalized losses are handled by the base class). Default 'l1'
            quantiles (List[float], optional): quantile loss is used when len(quantiles)==3
                (usually 0.1, 0.5, 0.9); L1/MSE loss when the list is empty. Defaults to [].
            dropout_rate (float, optional): dropout rate in Dropout layers. Defaults to 0.1.
            optim (str, optional): if not None it expects a pytorch optim path.
                Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
        """
        super().__init__(**kwargs)
        beauty_string(self.description, 'info', True)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU do not require BN', 'info', self.verbose)
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            # lightning can call the constructor twice; on the second call the
            # activation string has already been resolved to a class
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice ', 'info', self.verbose)
        self.save_hyperparameters(logger=False)
        self.future_steps = future_steps  # mandatory attribute for the base class
        self.use_quantiles = False
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight

        if len(quantiles) > 0:
            assert len(quantiles) == 3, beauty_string('ONLY 3 quantiles permitted', 'info', True)
            self.use_quantiles = True
            self.mul = len(quantiles)
            self.loss = QuantileLossMO(quantiles)
        else:
            self.use_quantiles = False
            self.mul = 1
            if self.loss_type == 'mse':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()

        self.seq_len = past_steps
        self.label_len = label_len
        self.pred_len = future_steps

        # moving-average series decomposition block
        self.decomp = series_decomp(kernel_size)

        # one embedding table per categorical variable (+1 slot for the extra index)
        self.embs = nn.ModuleList(nn.Embedding(k + 1, d_model) for k in embs)

        # project numeric channels into the attention dimension
        self.linear_encoder = nn.Sequential(
            nn.Linear(past_channels, past_channels * 2),
            activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(past_channels * 2, d_model * 2),
            activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_model * 2, d_model))
        self.linear_decoder = nn.Sequential(
            nn.Linear(future_channels, future_channels * 2),
            activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(future_channels * 2, d_model * 2),
            activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_model * 2, d_model))
        # NOTE(review): maps the channel dimension of (trend + seasonal);
        # this assumes that sum has past_channels features — confirm against
        # the Decoder/DecoderLayer projections in autoformer.layers
        self.final_layer = nn.Linear(past_channels, out_channels)

        # Encoder: stacked auto-correlation layers with layer-wise decomposition
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor,
                                        attention_dropout=dropout_rate,
                                        output_attention=False),
                        d_model, n_head),
                    d_model,
                    hidden_size,
                    moving_avg=kernel_size,
                    dropout=dropout_rate,
                    activation=activation
                ) for _ in range(n_layer_encoder)
            ],
            norm_layer=my_Layernorm(d_model)
        )
        # Decoder: masked self auto-correlation + cross auto-correlation
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(True, factor,
                                        attention_dropout=dropout_rate,
                                        output_attention=False),
                        d_model, n_head),
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor,
                                        attention_dropout=dropout_rate,
                                        output_attention=False),
                        d_model, n_head),
                    d_model,
                    out_channels,
                    hidden_size,
                    moving_avg=kernel_size,
                    dropout=dropout_rate,
                    activation=activation,
                ) for _ in range(n_layer_decoder)
            ],
            norm_layer=my_Layernorm(d_model),
            # the seasonal output carries one slot per quantile (self.mul)
            projection=nn.Linear(d_model, out_channels * self.mul, bias=True)
        )

        # positional embeddings for the encoder and decoder inputs
        self.pee = PositionalEmbedding(d_model=d_model)
        self.ped = PositionalEmbedding(d_model=d_model)

    def forward(self, batch: dict) -> torch.Tensor:
        """Run one forward pass on a dsipts batch.

        Args:
            batch (dict): expects 'x_num_past', 'idx_target_future' and
                'x_num_future'; 'x_cat_past' / 'x_cat_future' are optional.

        Returns:
            torch.Tensor: predictions of shape [B, future_steps, out_channels, mul]
                where mul is the number of quantiles (1 when no quantile loss).
        """
        idx_target_future = batch['idx_target_future'][0]
        x_seq = batch['x_num_past'].to(self.device)

        if 'x_cat_future' in batch.keys():
            cat_future = batch['x_cat_future'].to(self.device)
        if 'x_cat_past' in batch.keys():
            cat_past = batch['x_cat_past'].to(self.device)
        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
            # zero out the target columns of the future window so the ground
            # truth cannot leak into the decoder input
            x_future[:, -self.pred_len:, idx_target_future] = 0

        # positional embeddings; the decoder one only needs the sequence
        # length, hence the zero placeholder tensor
        pee = self.pee(x_seq).repeat(x_seq.shape[0], 1, 1)
        ped = self.ped(torch.zeros(x_seq.shape[0], self.pred_len + self.label_len).float()).repeat(x_seq.shape[0], 1, 1)

        # add the categorical embeddings (summed over the variables) to the
        # positional embeddings; skip cleanly when no embedding is configured
        if 'x_cat_past' in batch.keys() and len(self.embs) > 0:
            emb_sum = self.embs[0](cat_past[:, :, 0])
            for i in range(1, len(self.embs)):
                emb_sum = emb_sum + self.embs[i](cat_past[:, :, i])
            pee = pee + emb_sum
        if 'x_cat_future' in batch.keys() and len(self.embs) > 0:
            emb_sum = self.embs[0](cat_future[:, :, 0])
            for i in range(1, len(self.embs)):
                emb_sum = emb_sum + self.embs[i](cat_future[:, :, i])
            ped = ped + emb_sum

        # decoder initialization: the trend starts from the past mean, the
        # seasonal part from zeros (standard Autoformer initialization)
        mean = torch.mean(x_seq, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        zeros = torch.zeros([x_future.shape[0], self.pred_len, x_seq.shape[2]], device=x_seq.device)
        seasonal_init, trend_init = self.decomp(x_seq)
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        # NOTE(review): seasonal_init is built but never consumed below (the
        # decoder input comes from linear_decoder(x_future) instead); kept to
        # preserve the original behavior — confirm whether it should feed the decoder
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)

        # encoder
        enc_out = self.linear_encoder(x_seq) + pee
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # decoder
        dec_out = self.linear_decoder(x_future) + ped
        seasonal_part, trend_part = self.decoder(dec_out, enc_out,
                                                 x_mask=None, cross_mask=None,
                                                 trend=trend_init)

        # final projection: keep only the prediction window and split the last
        # dimension into (out_channels, mul)
        dec_out = self.final_layer(trend_part + seasonal_part)
        BS = dec_out.shape[0]
        return dec_out[:, -self.pred_len:, :].reshape(BS, self.pred_len, -1, self.mul)  # [B, L, D, MUL]