# Source code for dsipts.models.CrossFormer
## Copyright 2023 Yunhao Zhang and Junchi Yan (https://github.com/Thinklab-SJTU/Crossformer?tab=Apache-2.0-1-ov-file#readme)
## Code modified to align the notation and the batch generation with this library
## and extended to everything present in the crossformer folder
from torch import nn
import torch
from .base import Base
from typing import List,Union
from einops import repeat
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .crossformer.cross_encoder import Encoder
from .crossformer.cross_decoder import Decoder
from .crossformer.cross_embed import DSW_embedding
from math import ceil
# self, past_channels, past_steps, future_steps, seg_len, win_size = 4,
# factor=10, d_model=512, hidden_size = 1024, n_head=8, n_layer_encoder=3,
# dropout=0.0, baseline = False,
[docs]
class CrossFormer(Base):
    """CrossFormer forecaster (https://openreview.net/forum?id=vSVLM2j9eie).

    The past numeric channels are embedded segment by segment (``DSW_embedding``),
    passed through the ``Encoder``, and a learned decoder positional embedding is
    decoded by the ``Decoder`` into the forecast horizon. Only ``batch['x_num_past']``
    is consumed (no future covariates, no categorical variables).
    """

    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates,
                            handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 d_model: int,
                 embs: List[int],
                 hidden_size: int,
                 n_head: int,
                 seg_len: int,
                 n_layer_encoder: int,
                 win_size: int,
                 out_channels: int,
                 factor: int = 5,
                 remove_last=False,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: Union[List[float], None] = None,
                 dropout_rate: float = 0.1,
                 optim: Union[str, None] = None,
                 optim_config: Union[dict, None] = None,
                 scheduler_config: Union[dict, None] = None,
                 **kwargs) -> None:
        """CrossFormer (https://openreview.net/forum?id=vSVLM2j9eie)

        Args:
            past_steps (int): number of past time steps in the input window. It is
                rounded up to a multiple of ``seg_len``; ``forward`` left-pads the
                input to that length.
            future_steps (int): number of future lags to predict.
            past_channels (int): number of numeric past variables, must be > 0.
            future_channels (int): number of future numeric variables (not used by
                this model; only recorded through ``save_hyperparameters``).
            d_model (int): dimension of the attention model.
            embs (List[int]): initial dimensions of the categorical variables (not
                used by this model; only recorded through ``save_hyperparameters``).
            hidden_size (int): hidden size of the linear block.
            n_head (int): number of attention heads.
            seg_len (int): segment length (L_seg), see the paper for more details.
            n_layer_encoder (int): number of encoder layers. The decoder is built
                with ``n_layer_encoder + 1`` layers.
            win_size (int): window size for segment merging.
            out_channels (int): number of output channels (not used directly: the
                output keeps the channels selected by ``batch['idx_target']``).
            factor (int, optional): number of routers in the Cross-Dimension Stage
                of TSA (c), see the paper. Defaults to 5.
            remove_last (bool, optional): if True the model predicts the difference
                with respect to the last observation, which is added back to the
                output. Defaults to False.
            persistence_weight (float, optional): weight controlling the divergence
                from the persistence model. Defaults to 0.
            loss_type (str, optional): 'mse' sets ``self.loss`` to ``nn.MSELoss``;
                any other value sets it to ``nn.L1Loss``. The raw value is also
                stored in ``self.loss_type`` (custom losses such as
                linear_penalization or exponential_penalization are presumably
                handled by the base class -- TODO confirm). Defaults to 'l1'.
            quantiles (List[float], optional): NOT USED YET. ``None`` is treated as
                an empty list.
            dropout_rate (float, optional): dropout rate in Dropout layers.
                Defaults to 0.1.
            optim (str, optional): if not None it expects a pytorch optim method.
                Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the optimizer.
                Defaults to None.
            scheduler_config (dict, optional): configuration for the stepLR
                scheduler. Defaults to None.
            **kwargs: forwarded to ``Base.__init__``.
        """
        # Normalise before save_hyperparameters so the recorded value stays [].
        # This also avoids a shared mutable default argument.
        if quantiles is None:
            quantiles = []
        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)
        self.use_quantiles = False
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight
        self.remove_last = remove_last
        if self.loss_type == 'mse':
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.L1Loss()
        self.future_steps = future_steps

        # Round past/future lengths up to a multiple of seg_len so both split
        # into whole segments; past_steps_add is the left padding applied in forward.
        self.pad_past_steps = ceil(1.0 * past_steps / seg_len) * seg_len
        self.pad_future_steps = ceil(1.0 * future_steps / seg_len) * seg_len
        self.past_steps_add = self.pad_past_steps - past_steps
        in_seg_num = self.pad_past_steps // seg_len
        out_seg_num = self.pad_future_steps // seg_len

        # Embedding: segment-wise value embedding plus a learned positional
        # embedding with one entry per (channel, past segment).
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)
        self.enc_pos_embedding = nn.Parameter(torch.randn(1, past_channels, in_seg_num, d_model))
        self.pre_norm = nn.LayerNorm(d_model)

        # Encoder
        self.encoder = Encoder(n_layer_encoder, win_size, d_model, n_head, hidden_size,
                               block_depth=1, dropout=dropout_rate,
                               in_seg_num=in_seg_num, factor=factor)

        # Decoder: its input is a learned positional embedding, one entry per
        # (channel, future segment).
        self.dec_pos_embedding = nn.Parameter(torch.randn(1, past_channels, out_seg_num, d_model))
        self.decoder = Decoder(seg_len, n_layer_encoder + 1, d_model, n_head, hidden_size,
                               dropout_rate, out_seg_num=out_seg_num, factor=factor)

    def forward(self, batch):
        """Predict the target channels for one batch.

        Args:
            batch (dict): must contain 'x_num_past' (past numeric values, presumably
                shaped (batch, past_steps, past_channels) -- TODO confirm) and
                'idx_target' (target channel indices; only its first entry is used).
                The input tensors are not modified.

        Returns:
            torch.Tensor: 4-D predictions, presumably
            (batch, future_steps, n_targets, 1) -- TODO confirm.
        """
        idx_target = batch['idx_target'][0]
        x_seq = batch['x_num_past'].to(self.device)
        if self.remove_last:
            # clone(): x_start must not alias x_seq. Otherwise an in-place
            # subtraction would also zero x_start (or raise a memory-overlap error).
            x_start = x_seq[:, -1:, :].clone()
            # Out-of-place, so the caller's batch tensor is left untouched.
            x_seq = x_seq - x_start
        batch_size = x_seq.shape[0]

        # Left-pad by repeating the first time step up to pad_past_steps.
        if self.past_steps_add != 0:
            x_seq = torch.cat((x_seq[:, :1, :].expand(-1, self.past_steps_add, -1), x_seq), dim=1)

        x_seq = self.enc_value_embedding(x_seq)
        x_seq = x_seq + self.enc_pos_embedding
        x_seq = self.pre_norm(x_seq)
        enc_out = self.encoder(x_seq)

        # The decoder input is the positional embedding replicated per batch element.
        dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=batch_size)
        predict_y = self.decoder(dec_in, enc_out)

        # Drop the padded future steps and add a trailing singleton axis.
        res = predict_y[:, :self.future_steps, :].unsqueeze(3)
        if self.remove_last:
            # x_start is (B, 1, C); unsqueeze(3) aligns its channel axis with
            # res's axis 2. unsqueeze(1) would broadcast to (B, F, C, C).
            res = res + x_start.unsqueeze(3)
        return res[:, :, idx_target, :]