Source code for dsipts.models.LinearTS


## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
## Code modified to align the notation and the batch generation
## extended to all models present in the informer and autoformer folders

from torch import nn
import torch
from .base import Base
from .utils import QuantileLossMO, get_activation, get_scope
from typing import List, Union
from ..data_structure.utils import beauty_string
  
    


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on both ends of the time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
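# Illustrative sketch (not part of the original module): with stride=1 and an odd
# kernel_size, the edge padding above keeps the sequence length unchanged, so the
# smoothed trend can later be subtracted element-wise from the input.
#
#   >>> ma = moving_avg(kernel_size=25, stride=1)
#   >>> x = torch.randn(4, 96, 7)        # (batch, length, channels)
#   >>> ma(x).shape
#   torch.Size([4, 96, 7])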
class series_decomp(nn.Module):
    """
    Series decomposition block
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean
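# Illustrative sketch (not part of the original module): series_decomp returns the
# residual (seasonal) part and the moving-average trend; by construction they add
# back up to the input.
#
#   >>> dec = series_decomp(kernel_size=25)
#   >>> x = torch.randn(4, 96, 7)
#   >>> res, trend = dec(x)
#   >>> torch.allclose(res + trend, x)
#   True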
class LinearTS(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
    description += '\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'
    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 embs: List[int],
                 cat_emb_dim: int,
                 kernel_size: int,
                 sum_emb: bool,
                 out_channels: int,
                 hidden_size: int,
                 dropout_rate: float = 0.1,
                 activation: str = 'torch.nn.ReLU',
                 kind: str = 'linear',
                 use_bn: bool = False,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[int] = [],
                 n_classes: int = 0,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 simple: bool = False,
                 **kwargs) -> None:
        """Linear model from https://github.com/cure-lab/LTSF-Linear/blob/main/run_longExp.py

        Args:
            past_steps (int): number of past datapoints used
            future_steps (int): number of future lags to predict
            past_channels (int): number of numeric past variables, must be >0
            future_channels (int): number of future numeric variables
            embs (List[int]): list of the initial dimensions of the categorical variables
            cat_emb_dim (int): final dimension of each categorical variable
            kernel_size (int): kernel dimension for the initial moving average
            sum_emb (bool): if true the contribution of each embedding will be summed up, otherwise stacked
            out_channels (int): number of output channels
            hidden_size (int): hidden size of the linear block
            dropout_rate (float, optional): dropout rate in Dropout layers. Default 0.1
            activation (str, optional): pytorch activation function. Default torch.nn.ReLU
            kind (str, optional): one among linear, dlinear (de-trending), nlinear (differential). Defaults to 'linear'.
            use_bn (bool, optional): if true BN layers will be added and dropouts will be removed. Default False
            quantiles (List[int], optional): quantile loss is used if len(quantiles)>0 (usually 0.1, 0.5, 0.9), L1 loss if len(quantiles)==0. Defaults to [].
            persistence_weight (float): weight controlling the divergence from the persistence model. Default 0
            loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Default l1.
            n_classes (int): number of classes (0 in regression)
            optim (str, optional): if not None it expects a pytorch optim method. Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the Adam optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
            simple (bool, optional): if simple, the model is the same as the one illustrated in the paper, otherwise a more elaborate one based on the same idea
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)
        self.past_steps = past_steps
        self.future_steps = future_steps
        self.kind = kind
        self.past_channels = past_channels
        self.future_channels = future_channels
        self.embs = nn.ModuleList()
        self.sum_emb = sum_emb
        self.persistence_weight = persistence_weight
        self.loss_type = loss_type
        self.simple = simple

        if n_classes == 0:
            self.is_classification = False
            if len(quantiles) > 0:
                self.use_quantiles = True
                self.mul = len(quantiles)
                self.loss = QuantileLossMO(quantiles)
            else:
                self.use_quantiles = False
                self.mul = 1
                if self.loss_type == 'mse':
                    self.loss = nn.MSELoss()
                else:
                    self.loss = nn.L1Loss()
        else:
            self.is_classification = True
            self.use_quantiles = False
            self.mul = n_classes
            self.loss = torch.nn.CrossEntropyLoss()
            #assert out_channels==1, "Classification require only one channel"

        emb_channels = 0
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config

        for k in embs:
            self.embs.append(nn.Embedding(k + 1, cat_emb_dim))
            emb_channels += cat_emb_dim

        if sum_emb and (emb_channels > 0):
            emb_channels = cat_emb_dim
            beauty_string('Using sum', 'info', self.verbose)
        else:
            beauty_string('Using stacked', 'info', self.verbose)

        ## one linear head for each output channel
        self.linear = nn.ModuleList()

        if kind == 'dlinear':
            self.decompsition = series_decomp(kernel_size)
            self.Linear_Trend = nn.ModuleList()
            for _ in range(out_channels):
                self.Linear_Trend.append(nn.Linear(past_steps, future_steps))

        for _ in range(out_channels):
            if simple:
                self.linear.append(nn.Linear(past_steps, self.future_steps * self.mul))
            else:
                self.linear.append(nn.Sequential(
                    nn.Linear(emb_channels * (past_steps + future_steps) + past_steps * past_channels + future_channels * future_steps, hidden_size),
                    activation(),
                    nn.BatchNorm1d(hidden_size) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size, hidden_size // 2),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 2) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 2, hidden_size // 4),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 4) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 4, hidden_size // 8),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 8) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 8, self.future_steps * self.mul)))
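    # Worked example (assumed values, for illustration only): with simple=False each
    # per-channel head receives one flattened vector per sample of size
    #   emb_channels*(past_steps+future_steps) + past_steps*past_channels + future_channels*future_steps
    # e.g. past_steps=64, future_steps=16, past_channels=3, future_channels=1, two
    # categorical features with sum_emb=True and cat_emb_dim=8 gives
    # 8*(64+16) + 64*3 + 1*16 = 640 + 192 + 16 = 848 input features.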
    def forward(self, batch):
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]

        if self.kind == 'nlinear':
            x_start = x[:, -1, idx_target].unsqueeze(1)  ##BxC
            x[:, :, idx_target] -= x_start
        if self.kind == 'alinear':
            x[:, :, idx_target] = 0
        if self.kind == 'dlinear':
            x_start = x[:, :, idx_target]
            seasonal_init, trend_init = self.decompsition(x_start)
            seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
            x[:, :, idx_target] = seasonal_init.permute(0, 2, 1)
            tmp = []
            for j in range(len(self.Linear_Trend)):
                tmp.append(self.Linear_Trend[j](trend_init[:, j, :]))
            trend = torch.stack(tmp, 2)

        if self.simple is False:
            if 'x_cat_future' in batch.keys():
                cat_future = batch['x_cat_future'].to(self.device)
            if 'x_cat_past' in batch.keys():
                cat_past = batch['x_cat_past'].to(self.device)
            if 'x_num_future' in batch.keys():
                x_future = batch['x_num_future'].to(self.device)
            else:
                x_future = None

            tmp = [x]
            tmp_emb = None
            for i in range(len(self.embs)):
                if self.sum_emb:
                    if i > 0:
                        tmp_emb += self.embs[i](cat_past[:, :, i])
                    else:
                        tmp_emb = self.embs[i](cat_past[:, :, i])
                else:
                    tmp.append(self.embs[i](cat_past[:, :, i]))
            if self.sum_emb and (len(self.embs) > 0):
                tmp.append(tmp_emb)  ##BxLxC
            tot_past = torch.cat(tmp, 2).flatten(1)

            tmp = []
            for i in range(len(self.embs)):
                if self.sum_emb:
                    if i > 0:
                        tmp_emb += self.embs[i](cat_future[:, :, i])
                    else:
                        tmp_emb = self.embs[i](cat_future[:, :, i])
                else:
                    tmp.append(self.embs[i](cat_future[:, :, i]))
            if self.sum_emb and (len(self.embs)):
                tmp.append(tmp_emb)
            if x_future is not None:
                tmp.append(x_future)
            if len(tmp) > 0:
                tot_future = torch.cat(tmp, 2).flatten(1)
                tot = torch.cat([tot_past, tot_future], 1)
            else:
                tot = tot_past
            tot = tot.unsqueeze(2).repeat(1, 1, len(self.linear)).permute(0, 2, 1)
        else:
            tot = seasonal_init

        res = []
        B = tot.shape[0]
        for j in range(len(self.linear)):
            res.append(self.linear[j](tot[:, j, :]).reshape(B, -1, self.mul))  ## BxLxCxMUL
        res = torch.stack(res, 2)

        if self.kind == 'nlinear':
            #res BxLxCx3
            #start BxCx1
            res += x_start.unsqueeze(1)
        if self.kind == 'dlinear':
            res = res + trend.unsqueeze(3)
        return res
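# Illustrative usage sketch (not part of the original module; shapes and batch keys
# follow the conventions read in forward above, and any extra kwargs required by the
# Base/Lightning constructor are omitted). Note that with simple=True the forward
# pass feeds the decomposed seasonal component to the linear heads, so kind='dlinear'
# is assumed here.
#
#   >>> model = LinearTS(past_steps=64, future_steps=16, past_channels=3,
#   ...                  future_channels=0, embs=[], cat_emb_dim=8, kernel_size=25,
#   ...                  sum_emb=True, out_channels=1, hidden_size=256,
#   ...                  kind='dlinear', simple=True)
#   >>> batch = {'x_num_past': torch.randn(8, 64, 3), 'idx_target': [[0]]}
#   >>> model(batch).shape               # B x future_steps x out_channels x mul
#   torch.Size([8, 16, 1, 1])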