Source code for dsipts.models.TIDE

import torch
import torch.nn as nn
import numpy as np
from .tft import sub_nn
from .base import Base
from .utils import QuantileLossMO, get_scope
from typing import List, Union
from ..data_structure.utils import beauty_string

class TIDE(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 out_channels: int,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 embs: List[int],
                 # specific params
                 hidden_size: int,
                 d_model: int,
                 n_add_enc: int,
                 n_add_dec: int,
                 dropout_rate: float,
                 activation: str = '',
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[float] = [],
                 optim: Union[str, None] = None,
                 optim_config: Union[dict, None] = None,
                 scheduler_config: Union[dict, None] = None,
                 **kwargs) -> None:
        """Long-term Forecasting with TiDE: Time-series Dense Encoder https://arxiv.org/abs/2304.08424

        This NN uses the ResidualBlock as its subnet: an MLP with a skip connection plus activation and dropout.
        Every encoder and decoder head is composed of one ResidualBlock, as are the temporal decoder and the
        feature projection for covariates.

        Args:
            out_channels (int): number of variables to be predicted
            past_steps (int): lookback window length
            future_steps (int): horizon window length
            past_channels (int): number of past variables
            future_channels (int): number of future auxiliary variables
            embs (List[int]): cardinalities of the categorical variables to be embedded
            hidden_size (int): first embedding size of the model ('r' in the paper)
            d_model (int): second embedding size (r tilde in the paper); should be smaller than hidden_size
            n_add_enc (int): number of additional heads for the encoder part of the NN; one head is always present by default
            n_add_dec (int): number of additional heads for the decoder part of the NN; one head is always present by default
            dropout_rate (float): dropout rate used inside the ResidualBlocks
            activation (str, optional): activation function to be used in the ResidualBlock, e.g. 'nn.GELU'. Defaults to '' (nn.ReLU).
            persistence_weight (float, optional): Defaults to 0.0.
            loss_type (str, optional): type of loss ('l1' or 'mse'). Defaults to 'l1'.
            quantiles (List[float], optional): quantiles for the quantile loss; either empty or exactly three. Defaults to [].
            optim (Union[str, None], optional): Defaults to None.
            optim_config (Union[dict, None], optional): Defaults to None.
            scheduler_config (Union[dict, None], optional): Defaults to None.
""" super().__init__(**kwargs) self.save_hyperparameters(logger=False) # self.dropout = dropout_rate self.persistence_weight = persistence_weight self.optim = optim self.optim_config = optim_config self.scheduler_config = scheduler_config self.loss_type = loss_type if len(quantiles)==0: self.mul = 1 self.use_quantiles = False if self.loss_type == 'mse': self.loss = nn.MSELoss() else: self.loss = nn.L1Loss() else: assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True) self.mul = len(quantiles) self.use_quantiles = True self.loss = QuantileLossMO(quantiles) self.hidden_size = hidden_size # r self.d_model = d_model # r^tilda self.past_steps = past_steps # lookback size self.future_steps = future_steps # horizon size self.past_channels = past_channels # psat_vars self.future_channels = future_channels # fut_vars self.output_channels = out_channels # target_vars # for other numerical variables in the past self.aux_past_channels = past_channels - out_channels self.linear_aux_past = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_past_channels)]) # for numerical variables in the future self.aux_fut_channels = future_channels self.linear_aux_fut = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_fut_channels)]) # embedding categorical for both past and future (ASSUMING BOTH AVAILABLE OR NO ONE) self.seq_len = past_steps + future_steps self.emb_cat_var = sub_nn.embedding_cat_variables(self.seq_len, future_steps, hidden_size, embs, self.device) ## FEATURE PROJECTION # past if self.aux_past_channels>0: self.feat_proj_past = ResidualBlock(2*hidden_size, d_model, dropout_rate, activation) else: self.feat_proj_past = ResidualBlock(hidden_size, d_model, dropout_rate, activation) # future if self.aux_fut_channels>0: self.feat_proj_fut = ResidualBlock(2*hidden_size, d_model, dropout_rate, activation) else: self.feat_proj_fut = ResidualBlock(hidden_size, d_model, dropout_rate, activation) # # ENCODER self.enc_dim_input = past_steps*self.output_channels + (past_steps+future_steps)*d_model self.enc_dim_output = future_steps*d_model self.first_encoder = ResidualBlock(self.enc_dim_input, self.enc_dim_output, dropout_rate, activation) self.aux_encoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_enc)]) # # DECODER self.first_decoder = ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) self.aux_decoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_dec)]) ## TEMPORAL DECOER self.temporal_decoder = ResidualBlock(2*d_model, out_channels*self.mul, dropout_rate, activation) # linear for Y lookback self.linear_target = nn.Linear(past_steps*out_channels, future_steps*out_channels*self.mul)
    def forward(self, batch: dict) -> torch.Tensor:
        """Forward pass of the network

        Args:
            batch (dict): dictionary with the loaded variables

        Returns:
            torch.Tensor: predictions of shape [B, future_steps, out_channels, mul]
        """
        # LOADING AUTOREGRESSIVE CONTEXT OF TARGET VARIABLES
        num_past = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        y_past = num_past[:, :, idx_target]
        B = y_past.shape[0]

        # LOADING EMBEDDED CATEGORICAL VARIABLES
        emb_cat_past, emb_cat_fut = self.cat_categorical_vars(batch)
        emb_cat_past = torch.mean(emb_cat_past, dim=2)
        emb_cat_fut = torch.mean(emb_cat_fut, dim=2)

        ### LOADING PAST AND FUTURE NUMERICAL VARIABLES
        # load the auxiliary (AUX) numerical variables into the model
        if self.aux_past_channels > 0:  # if we have other numerical variables about the past
            aux_num_past = self.remove_var(num_past, idx_target, 2)  # remove the autoregressive variable
            assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS", 'section', True)  # check that we are using the expected number of past variables
            # embed all vars separately, then take their mean
            aux_emb_num_past = torch.Tensor().to(self.device)
            for i, layer in enumerate(self.linear_aux_past):
                aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
                aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
            aux_emb_num_past = torch.mean(aux_emb_num_past, dim=2)
        else:
            aux_emb_num_past = None  # no available vars

        if self.aux_fut_channels > 0:  # if we have other numerical variables about the future
            aux_num_fut = batch['x_num_future'].to(self.device)
            assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND {aux_num_fut.size(2)} VARS", 'section', True)  # check that we are using the expected number of future variables
            # embed all vars separately, then take their mean
            aux_emb_num_fut = torch.Tensor().to(self.device)
            for j, layer in enumerate(self.linear_aux_fut):
                aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
                aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
            aux_emb_num_fut = torch.mean(aux_emb_num_fut, dim=2)
        else:
            aux_emb_num_fut = None  # no available vars

        # past tilde
        if self.aux_past_channels > 0:
            emb_past = torch.cat((emb_cat_past, aux_emb_num_past), dim=2)  # [B, L, 2R]
            proj_past = self.feat_proj_past(emb_past, True)  # [B, L, R tilde]
        else:
            proj_past = self.feat_proj_past(emb_cat_past, True)  # [B, L, R tilde]
        # fut tilde
        if self.aux_fut_channels > 0:
            emb_fut = torch.cat((emb_cat_fut, aux_emb_num_fut), dim=2)  # [B, H, 2R]
            proj_fut = self.feat_proj_fut(emb_fut, True)  # [B, H, R tilde]
        else:
            proj_fut = self.feat_proj_fut(emb_cat_fut, True)  # [B, H, R tilde]

        concat = torch.cat((y_past.view(B, -1), proj_past.view(B, -1), proj_fut.view(B, -1)), dim=1)  # [B, L*out_channels + (L+H)*R tilde]

        dense_enc = self.first_encoder(concat)
        for lay_enc in self.aux_encoder:
            dense_enc = lay_enc(dense_enc)

        dense_dec = self.first_decoder(dense_enc)
        for lay_dec in self.aux_decoder:
            dense_dec = lay_dec(dense_dec)

        temp_dec_input = torch.cat((dense_dec.view(B, self.future_steps, self.d_model), proj_fut), dim=2)
        temp_dec_output = self.temporal_decoder(temp_dec_input, False)
        temp_dec_output = temp_dec_output.view(B, self.future_steps, self.output_channels, self.mul)

        # global linear residual connection over the target lookback
        linear_regr = self.linear_target(y_past.view(B, -1))
        linear_output = linear_regr.view(B, self.future_steps, self.output_channels, self.mul)

        output = temp_dec_output + linear_output
        return output
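
    # Shape sketch of the forward pass (assumed batch layout, reusing the
    # illustrative model from the sketch after __init__: B=32, L=64, H=16,
    # out_channels=1, mul=3):
    #
    #   batch = {'x_num_past':   torch.randn(32, 64, 3),
    #            'x_num_future': torch.randn(32, 16, 2),
    #            'x_cat_past':   torch.zeros(32, 64, 2, dtype=torch.long),
    #            'x_cat_future': torch.zeros(32, 16, 2, dtype=torch.long),
    #            'idx_target':   [[0]]}
    #   out = model(batch)  # [32, 16, 1, 3] == [B, H, out_channels, mul]
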
    # function to concatenate the embedded categorical variables
    def cat_categorical_vars(self, batch: dict):
        """Extract the embedded categorical context about past and future

        Args:
            batch (dict): keys checked -> ['x_cat_past', 'x_cat_future']

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: cat_emb_past, cat_emb_fut
        """
        cat_past = None
        cat_fut = None
        # GET AVAILABLE CATEGORICAL CONTEXT
        if 'x_cat_past' in batch.keys():
            cat_past = batch['x_cat_past'].to(self.device)
        if 'x_cat_future' in batch.keys():
            cat_fut = batch['x_cat_future'].to(self.device)
        # CONCAT THEM, according to how self.emb_cat_var is used
        if cat_past is None:
            emb_cat_full = self.emb_cat_var(batch['x_num_past'].shape[0], self.device)
        else:
            cat_full = torch.cat((cat_past, cat_fut), dim=1)
            emb_cat_full = self.emb_cat_var(cat_full, self.device)
        cat_emb_past = emb_cat_full[:, :self.past_steps, :, :]
        cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]
        return cat_emb_past, cat_emb_fut
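
    # For example, with past_steps=64 and future_steps=16 the concatenated
    # embedding spans 64+16=80 time steps, so the slices above give
    #   cat_emb_past: [B, 64, n, hidden_size]
    #   cat_emb_fut:  [B, 16, n, hidden_size]
    # where n is the size of emb_cat_full's third dimension (the embedded
    # categorical features produced by sub_nn.embedding_cat_variables).
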
    # function to extract from batch['x_num_past'] all variables except the autoregressive one
    def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: list, dimension: int) -> torch.Tensor:
        """Remove variables from a tensor at the chosen positions along the chosen dimension

        Args:
            tensor (torch.Tensor): starting tensor
            indexes_to_exclude (list): indexes of the chosen dimension that we want to exclude
            dimension (int): dimension of the tensor on which we want to work (a single dim, not a list of dims!)

        Returns:
            torch.Tensor: new tensor without the chosen variables
        """
        remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
        # select the desired sub-tensor
        extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)
        return extracted_subtensors
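
    # Toy example (illustrative only): dropping the autoregressive channel 0
    # along the last dimension.
    #
    #   t = torch.arange(6.).reshape(1, 2, 3)  # [B=1, L=2, C=3]
    #   self.remove_var(t, [0], 2)             # -> [1, 2, 2], channels 1 and 2 kept
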

class ResidualBlock(nn.Module):
    def __init__(self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ''):
        """Residual Block, the basic layer of the architecture: an MLP with one hidden layer, activation, dropout and a skip connection.

        The working dimension is basically d_model, but in_size and out_size are kept explicit to handle the
        different dimensions at the different stages of the NN.

        Args:
            in_size (int): input dimension
            out_size (int): output dimension
            dropout_rate (float): dropout rate
            activation_fun (str, optional): activation function to use in the Residual Block. Defaults to '' (nn.ReLU).
        """
        super().__init__()
        self.direct_linear = nn.Linear(in_size, out_size, bias=False)  # skip connection
        if activation_fun == '':
            self.act = nn.ReLU()
        else:
            activation = eval(activation_fun)
            self.act = activation()
        self.lin = nn.Linear(in_size, out_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.final_norm = nn.LayerNorm(out_size)
    def forward(self, x, apply_final_norm=True):
        direct_x = self.direct_linear(x)
        x = self.dropout(self.lin(self.act(x)))
        out = x + direct_x
        if apply_final_norm:
            return self.final_norm(out)
        return out
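
# Minimal usage sketch of the block in isolation (arbitrary sizes):
#
#   block = ResidualBlock(in_size=16, out_size=8, dropout_rate=0.1)
#   y = block(torch.randn(4, 16))         # [4, 8], final LayerNorm applied
#   y = block(torch.randn(4, 16), False)  # [4, 8], final LayerNorm skipped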