Source code for dsipts.models.Informer


## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
## Code modified to align the notation and the batch generation
## The modifications extend to everything present in the informer and autoformer folders
from torch import nn
import torch
from .base import Base
from typing import List, Union

from .informer.encoder import Encoder, EncoderLayer, ConvLayer
from .informer.decoder import Decoder, DecoderLayer
from .informer.attn import FullAttention, ProbAttention, AttentionLayer
from .informer.embed import DataEmbedding
from ..data_structure.utils import beauty_string
from .utils import get_scope, QuantileLossMO
  
    
  
class Informer(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 d_model: int,
                 embs: List[int],
                 hidden_size: int,
                 n_layer_encoder: int,
                 n_layer_decoder: int,
                 out_channels: int,
                 mix: bool = True,
                 activation: str = 'torch.nn.ReLU',
                 remove_last: bool = False,
                 attn: str = 'prob',
                 distil: bool = True,
                 factor: int = 5,
                 n_head: int = 1,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[int] = [],
                 dropout_rate: float = 0.1,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 **kwargs) -> None:
        """Informer

        Args:
            past_steps (int): number of past datapoints used; not used here
            future_steps (int): number of future lags to predict
            past_channels (int): number of numeric past variables, must be >0
            future_channels (int): number of numeric future variables
            d_model (int): dimension of the attention model
            embs (List[int]): list of the initial dimensions of the categorical variables
            hidden_size (int): hidden size of the linear block
            n_layer_encoder (int): layers to use in the encoder
            n_layer_decoder (int): layers to use in the decoder
            out_channels (int): number of output channels
            mix (bool, optional): use mix attention in the generative decoder. Defaults to True.
            activation (str, optional): relu or gelu. Defaults to 'relu'.
            remove_last (bool, optional): if True the model tries to predict the difference with respect to the last observation. Defaults to False.
            attn (str, optional): attention used in the encoder, options: [prob, full]. Defaults to 'prob'.
            distil (bool, optional): whether to use distilling in the encoder. Defaults to True.
            factor (int, optional): probsparse attention factor. Defaults to 5.
            n_head (int, optional): number of attention heads, the same in the encoder and decoder. Defaults to 1.
            persistence_weight (float): weight controlling the divergence from the persistence model. Defaults to 0.
            loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Defaults to l1.
            quantiles (List[int], optional): quantiles for the quantile loss; either empty or exactly 3 quantiles. Defaults to [].
            dropout_rate (float, optional): dropout rate in Dropout layers. Defaults to 0.1.
            optim (str, optional): if not None it expects a pytorch optim method. Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the Adam optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
""" super().__init__(**kwargs) self.save_hyperparameters(logger=False) beauty_string("BE SURE TO SETUP split_params: shift: ${model_configs.future_steps} BECAUSE IT IS REQUIRED",'info',True) self.future_steps = future_steps self.use_quantiles = False self.optim = optim self.optim_config = optim_config self.scheduler_config = scheduler_config self.persistence_weight = persistence_weight self.loss_type = loss_type self.remove_last = remove_last if len(quantiles)>0: assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True) self.use_quantiles = True self.mul = len(quantiles) self.loss = QuantileLossMO(quantiles) else: self.use_quantiles = False self.mul = 1 if self.loss_type == 'mse': self.loss = nn.MSELoss() else: self.loss = nn.L1Loss() self.enc_embedding = DataEmbedding(past_channels, d_model, embs, dropout_rate) self.dec_embedding = DataEmbedding(future_channels, d_model, embs, dropout_rate) # Attention Attn = ProbAttention if attn=='prob' else FullAttention # Encoder self.encoder = Encoder( [ EncoderLayer( AttentionLayer(Attn(False, factor, attention_dropout=dropout_rate, output_attention=False), d_model, n_head, mix=False), d_model, hidden_size, dropout=dropout_rate, activation=activation ) for _ in range(n_layer_encoder) ], [ ConvLayer( d_model ) for _ in range(n_layer_encoder-1) ] if distil else None, norm_layer=torch.nn.LayerNorm(d_model) ) # Decoder self.decoder = Decoder( [ DecoderLayer( AttentionLayer(Attn(True, factor, attention_dropout=dropout_rate, output_attention=False), d_model, n_head, mix=mix), AttentionLayer(FullAttention(False, factor, attention_dropout=dropout_rate, output_attention=False), d_model, n_head, mix=False), d_model, hidden_size, dropout=dropout_rate, activation=activation, ) for _ in range(n_layer_decoder) ], norm_layer=torch.nn.LayerNorm(d_model) ) self.projection = nn.Linear(d_model, out_channels*self.mul, bias=True)
    def forward(self, batch):
        # original signature: x_enc, x_mark_enc, x_dec, x_mark_dec, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None
        x_enc = batch['x_num_past'].to(self.device)
        idx_target_future = batch['idx_target_future'][0]

        if 'x_cat_past' in batch.keys():
            x_mark_enc = batch['x_cat_past'].to(self.device)
        else:
            x_mark_enc = None
        enc_self_mask = None

        # decoder input: the target columns are zeroed over the forecast horizon
        x_dec = batch['x_num_future'].to(self.device)
        x_dec[:, -self.future_steps:, idx_target_future] = 0

        if 'x_cat_future' in batch.keys():
            x_mark_dec = batch['x_cat_future'].to(self.device)
        else:
            x_mark_dec = None
        dec_self_mask = None
        dec_enc_mask = None

        # optionally predict the difference with respect to the last past observation
        if self.remove_last:
            idx_target = batch['idx_target'][0]
            x_start = x_enc[:, -1, idx_target].unsqueeze(1)
            x_enc[:, :, idx_target] -= x_start

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        res = dec_out[:, -self.future_steps:, :].unsqueeze(3)
        if self.remove_last:
            res += x_start.unsqueeze(1)
        BS = res.shape[0]
        return res.reshape(BS, self.future_steps, -1, self.mul)
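# Reader's note (added, not part of the original source): forward() consumes the dsipts
# batch dictionary rather than the separate tensors of the original Informer API. From
# the code above, the relevant keys are:
#   'x_num_past'        float tensor (BS, past_steps, past_channels), encoder input
#   'x_cat_past'        optional integer tensor with the past categorical features
#   'x_num_future'      float tensor with the future numeric variables, at least
#                       future_steps long; the target columns are zeroed on the last
#                       future_steps positions before decoding
#   'x_cat_future'      optional integer tensor with the future categorical features
#   'idx_target_future' indices of the target columns inside x_num_future
#   'idx_target'        indices of the target columns inside x_num_past (used when remove_last)
# The returned tensor has shape (BS, future_steps, out_channels, mul), where mul is
# len(quantiles) when the quantile loss is active and 1 otherwise.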