dsipts.models.base module
- class dsipts.models.base.Base(verbose)
Bases: LightningModule
This is the base model. Every model implemented must override the __init__ method and the forward method. The inference step is optional: by default it uses the forward method, but recurrent networks should implement their own inference method.
- handle_multivariate = False
- handle_future_covariates = False
- handle_categorical_variables = False
- handle_quantile_loss = False
- description = 'Can NOT handle multivariate output \nCan NOT handle future covariates\nCan NOT handle categorical covariates\nCan NOT handle Quantile loss function'
- abstractmethod __init__(verbose)
This is the base model. Every model implemented must override the __init__ method and the forward method. The inference step is optional: by default it uses the forward method, but recurrent networks should implement their own inference method.
- abstractmethod forward(batch)
Forward method used during the training loop.
- Parameters:
batch (dict) – the batch structure. The keys are:
  - y: the target variable(s). This is always present.
  - x_num_past: the numerical past variables. This is always present.
  - x_num_future: the numerical future variables.
  - x_cat_past: the categorical past variables.
  - x_cat_future: the categorical future variables.
  - idx_target: index of the target features in the past array.
- Returns:
output of the model
- Return type:
torch.Tensor
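The batch contract described above can be illustrated with a minimal sketch. To stay self-contained it subclasses `torch.nn.Module` rather than `Base`, so it is a stand-in and not the library's own code. The class name `LinearSketch`, its constructor arguments, and all tensor shapes are assumptions made for illustration; only the batch keys and the capability flag names come from the documentation.

```python
import torch
from torch import nn


class LinearSketch(nn.Module):
    """Hypothetical stand-in for a Base subclass (not part of dsipts)."""

    # Capability flags mirroring the class attributes documented above.
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False

    def __init__(self, past_steps, future_steps, n_features, n_targets):
        super().__init__()
        self.future_steps = future_steps
        self.n_targets = n_targets
        # One linear map from the flattened past window to the forecast horizon.
        self.linear = nn.Linear(past_steps * n_features, future_steps * n_targets)

    def forward(self, batch):
        # "x_num_past" is documented as always present.
        # Assumed shape: [batch, past_steps, n_features].
        x = batch["x_num_past"]
        out = self.linear(x.flatten(start_dim=1))
        # Assumed output shape: [batch, future_steps, n_targets].
        return out.view(-1, self.future_steps, self.n_targets)


# A toy batch holding only the two keys documented as always present.
batch = {
    "x_num_past": torch.randn(8, 16, 3),
    "y": torch.randn(8, 4, 1),
}
model = LinearSketch(past_steps=16, future_steps=4, n_features=3, n_targets=1)
prediction = model(batch)
```

A model that sets `handle_future_covariates` or `handle_categorical_variables` would also read `x_num_future`, `x_cat_past`, or `x_cat_future` from the batch. Since only `y` and `x_num_past` are documented as always present, such a model should probably check for the other keys before using them.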