Source code for dsipts.models.PatchTST

## Copyright https://github.com/yuqinie98/PatchTST/blob/main/LICENSE
## Modified for notation alignment and batch structure
## Extended with the components in the patchtst folder



from torch import  nn
import torch
from .base import Base
from typing import List,Union
from ..data_structure.utils import beauty_string
from .utils import  get_scope
from .utils import  get_activation
from .patchtst.layers import series_decomp, PatchTST_backbone



  
class PatchTST(Base):
    """PatchTST forecaster adapted to the dsipts batch structure.

    The past numeric series (optionally concatenated with a summed embedding of
    the past categorical variables) is fed to a ``PatchTST_backbone``. With
    ``decomposition=True`` the input is first split by ``series_decomp`` into a
    residual and a trend component, each processed by its own backbone, and the
    two forecasts are summed.
    """

    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = True
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates,
                            handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 patch_len: int,
                 past_channels: int,
                 future_channels: int,
                 out_channels: int,
                 d_model: int,
                 embs: List[int],
                 kernel_size: int,
                 decomposition: bool = True,
                 activation: str = 'torch.nn.ReLU',
                 n_head: int = 1,
                 n_layer: int = 2,
                 stride: int = 8,
                 remove_last: bool = False,
                 hidden_size: int = 1048,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: Union[List[int], None] = None,
                 dropout_rate: float = 0.1,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 **kwargs) -> None:
        """Build the embeddings and the PatchTST backbone(s).

        Args:
            past_steps (int): number of past datapoints; used as the backbone
                ``context_window`` and (with ``future_steps``) as ``max_seq_len``.
            future_steps (int): number of future lags to predict (backbone ``target_window``).
            patch_len (int): length of each patch.
            past_channels (int): number of numeric past variables, must be > 0.
            future_channels (int): number of future numeric variables (not used by the layers built here).
            out_channels (int): number of output channels (not used by the layers built here).
            d_model (int): dimension of the attention model and of each categorical embedding.
            embs (List[int]): for each categorical variable, the size passed to an
                ``nn.Embedding`` with ``k+1`` rows.
            kernel_size (int): kernel size of ``series_decomp`` (only used when ``decomposition`` is True).
            decomposition (bool, optional): if True, model residual and trend with two separate
                backbones and sum their outputs. Defaults to True.
            activation (str, optional): activation function name resolved by ``get_activation``.
                Defaults to 'torch.nn.ReLU'.
            n_head (int, optional): number of attention heads. Defaults to 1.
            n_layer (int, optional): number of encoder layers. Defaults to 2.
            stride (int, optional): stride between consecutive patches. Defaults to 8.
            remove_last (bool, optional): forwarded to the backbone as ``subtract_last``.
                Defaults to False.
            hidden_size (int, optional): backbone feed-forward size (``d_ff``). Defaults to 1048.
            persistence_weight (float, optional): weight controlling the divergence from the
                persistence model. Defaults to 0.
            loss_type (str, optional): 'mse' selects ``nn.MSELoss`` for ``self.loss``; any other
                value selects ``nn.L1Loss``. Custom losses such as 'linear_penalization' or
                'exponential_penalization' are presumably handled by ``Base`` via
                ``self.loss_type`` (not visible here). Defaults to 'l1'.
            quantiles (List[int], optional): NOT USED YET. Defaults to None (treated as []).
            dropout_rate (float, optional): dropout rate used throughout the backbone. Defaults to 0.1.
            optim (str, optional): if not None it expects a pytorch optim method.
                Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for the StepLR scheduler. Defaults to None.
            **kwargs: forwarded to ``Base.__init__``.
        """
        super().__init__(**kwargs)
        # Avoid a shared mutable default; normalise before save_hyperparameters
        # so the recorded value is still [] as with the previous default.
        if quantiles is None:
            quantiles = []

        if activation == 'torch.nn.SELU':
            beauty_string('SELU do not require BN','info',self.verbose)
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightening, the constructior is called twice ','info',self.verbose)

        # NOTE(review): Lightning reads the current values of the __init__
        # arguments here, so keep this after the normalisation above — confirm
        # against the Lightning version in use.
        self.save_hyperparameters(logger=False)

        self.use_quantiles = False
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight
        self.remove_last = remove_last
        self.future_steps = future_steps  # this is mandatory

        if self.loss_type == 'mse':
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.L1Loss()

        # One embedding table per categorical variable. Their outputs are summed
        # in forward, so all of them together add exactly d_model input channels
        # (and none when there are no categorical variables).
        self.embs = nn.ModuleList([nn.Embedding(k + 1, d_model) for k in embs])
        if len(self.embs) > 0:
            past_channels += d_model

        # Hyper-parameters shared by every backbone built below; only the
        # activation instance differs (a fresh one per backbone).
        backbone_kwargs = dict(
            c_in=past_channels, context_window=past_steps, target_window=future_steps,
            patch_len=patch_len, stride=stride, max_seq_len=past_steps + future_steps,
            n_layers=n_layer, d_model=d_model, n_heads=n_head, d_k=None, d_v=None,
            d_ff=hidden_size, norm='BatchNorm', attn_dropout=dropout_rate,
            dropout=dropout_rate, key_padding_mask='auto', padding_var=None,
            attn_mask=None, res_attention=True, pre_norm=False, store_attn=False,
            pe='zeros', learn_pe=True, fc_dropout=dropout_rate,
            head_dropout=dropout_rate, padding_patch='end', pretrain_head=False,
            head_type='flatten', individual=False, revin=True, affine=False,
            subtract_last=remove_last, verbose=False,
        )

        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(kernel_size)
            # Creation order (trend, then residual) kept stable for reproducible init.
            self.model_trend = PatchTST_backbone(act=activation(), **backbone_kwargs)
            self.model_res = PatchTST_backbone(act=activation(), **backbone_kwargs)
        else:
            self.model = PatchTST_backbone(act=activation(), **backbone_kwargs)

    def forward(self, batch):
        """Forecast the target channels.

        Args:
            batch (dict): must contain 'x_num_past' and 'idx_target'; 'x_cat_past'
                is optional and only used when the model has embedding tables.

        Returns:
            torch.Tensor: the backbone output permuted back to
            [Batch, length, Channel], given a trailing singleton axis and
            restricted to the channels in ``batch['idx_target'][0]``.
        """
        # x: [Batch, Input length, Channel]
        x_seq = batch['x_num_past'].to(self.device)

        # Only append embeddings when __init__ reserved channels for them;
        # otherwise a batch carrying 'x_cat_past' would try to concatenate None.
        if 'x_cat_past' in batch and len(self.embs) > 0:
            cat_past = batch['x_cat_past'].to(self.device)
            embedded = None
            for i, emb in enumerate(self.embs):
                term = emb(cat_past[:, :, i])
                embedded = term if embedded is None else embedded + term
            x_seq = torch.cat([x_seq, embedded], dim=2)

        if self.decomposition:
            res_init, trend_init = self.decomp_module(x_seq)
            # x: [Batch, Channel, Input length]
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
        else:
            # x: [Batch, Channel, Input length]
            x = self.model(x_seq.permute(0, 2, 1))
        # presumably [Batch, future_steps, Channel] since target_window=future_steps
        x = x.permute(0, 2, 1)

        out = x.unsqueeze(3)
        # NOTE(review): uses the target indices of the first sample only;
        # assumes they are identical across the batch.
        idx_target = batch['idx_target'][0]
        return out[:, :, idx_target, :]