import torch
import torch.nn as nn
from .tft import sub_nn
from .base import Base
from .utils import QuantileLossMO
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
class TFT(Base):
    """Temporal Fusion Transformer for multi-horizon time-series forecasting."""

    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

    def __init__(self,
                 d_model: int,
                 out_channels: int,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 num_layers_RNN: int,
                 embs: List[int],
                 d_head: int,
                 n_head: int,
                 dropout_rate: float,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[float] = [],  # read-only default, never mutated
                 optim: Union[str, None] = None,
                 optim_config: Union[dict, None] = None,
                 scheduler_config: Union[dict, None] = None,
                 **kwargs) -> None:
        """TEMPORAL FUSION TRANSFORMER - Multi-Horizon TimeSeries Forecasting.

        - Direct model: predicts all future steps at once.
        - Multi-output forecasting: predicts one or more variables.
        - Multi-horizon forecasting: predicts variables at multiple future time steps.
        - Attention based: enhances selection of relevant past time steps and learns
          long-term dependencies; attention weights act as importance magnitudes per head.
        - RNN enrichment: the LSTM provides an initial approximation of the target
          variable(s), then refined by the rest of the net.
        - Gating mechanisms: minimize the contribution of irrelevant variables.
        - Prediction intervals (quantile regression): outputs percentiles at each
          timestep, usually [10th, 50th, 90th].

        TFT facilitates interpretability by identifying:
        - global importance of variables for past and for future
        - temporal patterns
        - significant events

        Args:
            d_model (int): general hidden dimension across the net. Could be changed in sub-nets.
            out_channels (int): number of variables to predict.
            past_steps (int): steps of the look-back window.
            future_steps (int): steps in the future to be predicted.
            past_channels (int): total number of variables available in the past.
            future_channels (int): total number of variables available in the future.
            num_layers_RNN (int): number of layers for the recurrent net (here LSTM).
            embs (List[int]): embedding dimensions for added categorical variables
                (here for pos_seq, is_fut, pos_fut).
            d_head (int): attention head dimension.
            n_head (int): number of attention heads.
            dropout_rate (float): common rate for all dropout layers used.
            persistence_weight (float, optional): weight of the persistence term in the
                loss (semantics defined in Base — TODO confirm). Defaults to 0.0.
            loss_type (str, optional): type of point-prediction loss ('mse' -> MSELoss,
                anything else -> L1Loss). Ignored when quantiles are used. Defaults to 'l1'.
            quantiles (List[float], optional): quantiles to predict. Only an empty list
                (exact value) or a list of length 3 is allowed. Defaults to [].
            optim (Union[str, None], optional): optimizer specification consumed by Base
                (TODO confirm semantics). Defaults to None.
            optim_config (dict, optional): optimizer configuration consumed by Base
                (TODO confirm semantics). Defaults to None.
            scheduler_config (dict, optional): scheduler configuration consumed by Base
                (TODO confirm semantics). Defaults to None.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)
        self.future_steps = future_steps
        self.d_model = d_model
        self.out_channels = out_channels
        # linear embedding of the target variable(s); shared by past and future
        # because it is the same variable in both windows
        self.target_linear = nn.Linear(out_channels, d_model)
        # auxiliary past variables = every past channel except the target one(s)
        self.aux_past_channels = past_channels - out_channels
        # one linear per auxiliary past variable
        self.linear_aux_past = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_past_channels)])
        # all future channels are auxiliary (the future target is what we predict)
        self.aux_fut_channels = future_channels
        # one linear per auxiliary future variable
        self.linear_aux_fut = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_fut_channels)])
        # categorical variables are assumed either unavailable or available for both
        # past and future, so they are embedded over the full sequence length
        seq_len = past_steps + future_steps
        self.emb_cat_var = sub_nn.embedding_cat_variables(seq_len, future_steps, d_model, embs, self.device)
        # LSTM producing a first approximate inference of the target variable(s);
        # its output is not re-embedded yet (done in forward)
        self.rnn = sub_nn.LSTM_Model(num_var=out_channels,
                                     d_model=d_model,
                                     pred_step=future_steps,
                                     num_layers=num_layers_RNN,
                                     dropout=dropout_rate)
        # TFT building blocks:
        # - residual connections
        # - Gated Residual Networks (GRN)
        # - interpretable multi-head attention
        self.res_conn1_past = sub_nn.ResidualConnection(d_model, dropout_rate)
        self.res_conn1_fut = sub_nn.ResidualConnection(d_model, dropout_rate)
        self.grn1_past = sub_nn.GRN(d_model, dropout_rate)
        self.grn1_fut = sub_nn.GRN(d_model, dropout_rate)
        self.InterpretableMultiHead = sub_nn.InterpretableMultiHead(d_model, d_head, n_head)
        self.res_conn2_att = sub_nn.ResidualConnection(d_model, dropout_rate)
        self.grn2_att = sub_nn.GRN(d_model, dropout_rate)
        self.res_conn3_out = sub_nn.ResidualConnection(d_model, dropout_rate)
        self.persistence_weight = persistence_weight
        self.loss_type = loss_type
        # output head: point prediction or quantile prediction
        assert (len(quantiles) == 0) or (len(quantiles) == 3), beauty_string('Only 3 quantiles are availables, otherwise set quantiles=[]','block',True)
        if len(quantiles) == 0:
            self.mul = 1
            self.use_quantiles = False
            self.outLinear = nn.Linear(d_model, out_channels)
            if self.loss_type == 'mse':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()
        else:
            # the assert above already guarantees len(quantiles) == 3 here
            self.mul = len(quantiles)
            self.use_quantiles = True
            self.outLinear = nn.Linear(d_model, out_channels * len(quantiles))
            self.loss = QuantileLossMO(quantiles)
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config

    def forward(self, batch: dict) -> torch.Tensor:
        """Temporal Fusion Transformer forward pass.

        Collecting data:
        - extract the autoregressive variable(s)
        - embed them and compute a first approximate prediction via the LSTM
        - 'summary_past' and 'summary_fut' collect data about past and future:
          all variables are concatenated on dimension 2, then mixed by a MEAN
          over that dimension

        TFT actual computations:
        - residual connection for summary_past with the embedded past target
        - residual connection for summary_fut with the embedded approximate future target
        - GRN1 for past and for future
        - ATTENTION(summary_fut, summary_past, embedded past target)
        - residual connection of the attention with itself
        - GRN2 for the attention
        - residual connection of the attention with summary_fut
        - linear head and reshape

        Args:
            batch (dict): keys used are
                ['x_num_past', 'idx_target', 'x_num_future', 'x_cat_past', 'x_cat_future'].

        Returns:
            torch.Tensor: shape [B, self.future_steps, self.out_channels, self.mul]
                (self.mul is 1 for point predictions, 3 for quantile predictions).
        """
        num_past = batch['x_num_past'].to(self.device)
        # PAST TARGET NUMERICAL VARIABLE
        # always available (autoregressive); used to compute the RNN approximation
        idx_target = batch['idx_target'][0]
        target_num_past = num_past[:, :, idx_target]
        # target variables communicate with each other through this linear layer
        target_emb_num_past = self.target_linear(target_num_past)
        target_num_fut_approx = self.rnn(target_emb_num_past)
        # embed the approximate future predictions with the same target embedding
        target_emb_num_fut_approx = self.target_linear(target_num_fut_approx)
        # summaries start with the embedded target only; every other variable is
        # appended along dimension 2
        summary_past = target_emb_num_past.unsqueeze(2)
        summary_fut = target_emb_num_fut_approx.unsqueeze(2)
        ### PAST NUMERICAL VARIABLES (AUX = auxiliary)
        if self.aux_past_channels > 0:
            # drop the target index along the variable dimension
            aux_num_past = self.remove_var(num_past, idx_target, 2)
            # sanity check: expected number of auxiliary past variables
            # (size(2), not shape(2): torch.Size is not callable)
            assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS",'section',True)
            aux_emb_num_past = torch.Tensor().to(aux_num_past.device)
            for i, layer in enumerate(self.linear_aux_past):
                aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
                aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
            ## update summary about past
            summary_past = torch.cat((summary_past, aux_emb_num_past), dim=2)
        ### FUTURE NUMERICAL VARIABLES
        if self.aux_fut_channels > 0:
            aux_num_fut = batch['x_num_future'].to(self.device)
            # sanity check: expected number of auxiliary future variables
            assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUT VARS AND {aux_num_fut.size(2)} VARS",'section',True)
            aux_emb_num_fut = torch.Tensor().to(aux_num_fut.device)
            for j, layer in enumerate(self.linear_aux_fut):
                aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
                aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
            ## update summary about future
            summary_fut = torch.cat((summary_fut, aux_emb_num_fut), dim=2)
        ### CATEGORICAL VARIABLES
        if 'x_cat_past' in batch.keys() and 'x_cat_future' in batch.keys():
            # here we assume the same number and kind of variables in past and future:
            # embed the full sequence at once, then split back into past and future
            cat_past = batch['x_cat_past'].to(self.device)
            cat_fut = batch['x_cat_future'].to(self.device)
            cat_full = torch.cat((cat_past, cat_fut), dim=1)
            emb_cat_full = self.emb_cat_var(cat_full, self.device)
        else:
            # no categorical data in the batch: the embedder only needs the batch size
            emb_cat_full = self.emb_cat_var(num_past.shape[0], self.device)
        cat_emb_past = emb_cat_full[:, :-self.future_steps, :, :]
        cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]
        ## update summaries with categorical embeddings
        summary_past = torch.cat((summary_past, cat_emb_past), dim=2)
        summary_fut = torch.cat((summary_fut, cat_emb_fut), dim=2)
        # mix all collected variables by averaging over the variable dimension
        summary_past = torch.mean(summary_past, dim=2)
        summary_fut = torch.mean(summary_fut, dim=2)
        ### residual connections from the LSTM approximation
        summary_past = self.res_conn1_past(summary_past, target_emb_num_past)
        summary_fut = self.res_conn1_fut(summary_fut, target_emb_num_fut_approx)
        ### GRN1
        summary_past = self.grn1_past(summary_past)
        summary_fut = self.grn1_fut(summary_fut)
        ### INTERPRETABLE MULTI-HEAD ATTENTION (Q=future, K=past, V=embedded past target)
        attention = self.InterpretableMultiHead(summary_fut, summary_past, target_emb_num_past)
        ### residual connection of the attention with itself
        attention = self.res_conn2_att(attention, attention)
        ### GRN2
        attention = self.grn2_att(attention)
        ### residual connection with the output of GRN1 (future summary)
        out = self.res_conn3_out(attention, summary_fut)
        ### OUT
        out = self.outLinear(out)
        # self.mul >= 1 always, so the reshape is unconditional:
        # [B, future_steps, out_channels, mul] (mul==1 for point predictions)
        out = out.view(-1, self.future_steps, self.out_channels, self.mul)
        return out

    # helper to extract from batch['x_num_past'] all variables except the autoregressive one(s)
    def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: Union[List[int], torch.Tensor], dimension: int) -> torch.Tensor:
        """Remove the given indices from a tensor along the chosen dimension.

        Args:
            tensor (torch.Tensor): starting tensor.
            indexes_to_exclude (Union[List[int], torch.Tensor]): indices of the chosen
                dimension to exclude (anything supporting `in` membership tests).
            dimension (int): dimension of the tensor on which to operate.

        Returns:
            torch.Tensor: new tensor without the chosen variables.
        """
        remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
        # select the desired sub-tensor
        extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)
        return extracted_subtensors