# Source code for dsipts.models.Persistent


from torch import nn
from .base import  Base
from .utils import L1Loss
from ..data_structure.utils import beauty_string
from .utils import  get_scope

class Persistent(Base):
    """Naive persistence baseline: repeat the last observed target values.

    The prediction copies the last past value of each target channel across
    the whole forecast horizon; nothing learned is used in ``forward``.
    """

    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(
        handle_multivariate,
        handle_future_covariates,
        handle_categorical_variables,
        handle_quantile_loss,
    )

    def __init__(
        self,
        future_steps: int,
        past_steps: int,
        loss_type: str = None,  # not used but needed
        persistence_weight: float = 0.1,  # not used but needed
        optim_config: dict = None,
        scheduler_config: dict = None,
        **kwargs,
    ) -> None:
        """Persistent model propagating the last observed values.

        Args:
            future_steps (int): number of future lags to predict.
            past_steps (int): number of past lags in the input window. Not used
                by the prediction itself, but stored for compatibility with the
                other models.
            loss_type (str, optional): ignored; ``self.loss_type`` is always
                set to ``'l1'``. Kept for interface compatibility. Defaults to None.
            persistence_weight (float, optional): ignored. Kept for interface
                compatibility. Defaults to 0.1.
            optim_config (dict, optional): configuration for the Adam optimizer.
                Defaults to None. Useless for this model.
            scheduler_config (dict, optional): configuration for the StepLR
                scheduler. Defaults to None. Useless for this model.
            **kwargs: forwarded unchanged to ``Base.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): assumes Base is a LightningModule; save_hyperparameters
        # reads the __init__ arguments from the calling frame, so keep the call
        # directly inside __init__.
        self.save_hyperparameters(logger=False)
        self.past_steps = past_steps
        self.future_steps = future_steps
        self.optim = None
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        # The original code first assigned a different L1Loss here and then
        # overwrote it with nn.L1Loss; only the effective one is kept, at the
        # same position so submodule registration order is unchanged.
        self.loss = nn.L1Loss()
        # Dummy layer, never used in forward; presumably present so the
        # optimizer has at least one parameter to work with — TODO confirm.
        self.fake = nn.Linear(1, 1)
        self.use_quantiles = False
        # Forced regardless of the ``loss_type`` argument.
        self.loss_type = 'l1'

    def forward(self, batch):
        """Repeat the last observed value of each target channel over the horizon.

        It is mandatory to implement this method.

        Args:
            batch (dict): batch of the dataloader. Reads ``'x_num_past'``
                (indexed as [B, L, C]) and ``'idx_target'``. Only the first
                entry of ``'idx_target'`` is used, so the target channel indices
                are assumed identical for every sample in the batch.

        Returns:
            torch.tensor: predictions of shape [B, future_steps, n_targets, 1].
        """
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        # Last time step of the target channels: B,1,C
        x_start = x[:, -1, idx_target].unsqueeze(1)
        # repeat (unlike expand) copies the data, so the output is a standalone
        # tensor rather than a broadcast view; final shape is B,F,C,1.
        res = x_start.repeat(1, self.future_steps, 1).unsqueeze(3)
        return res