dsipts.models.base module

dsipts.models.base.standardize_momentum(x, order)[source]
dsipts.models.base.dilate_loss(outputs, targets, alpha, gamma, device)[source]
class dsipts.models.base.Base(verbose)[source]

Bases: LightningModule

This is the basic model; each implemented model must override the __init__ and forward methods. Overriding the inference step is optional: by default it calls forward, but for recurrent networks you should implement your own method.

handle_multivariate = False
handle_future_covariates = False
handle_categorical_variables = False
handle_quantile_loss = False
description = 'Can NOT handle multivariate output \nCan NOT handle future covariates\nCan NOT handle categorical covariates\nCan NOT handle Quantile loss function'
abstractmethod __init__(verbose)[source]

This is the basic model; each implemented model must override the __init__ and forward methods. Overriding the inference step is optional: by default it calls forward, but for recurrent networks you should implement your own method.
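A minimal sketch of what a subclass following this contract might look like. Base itself subclasses LightningModule; nn.Module stands in here so the example is self-contained, and names such as LinearForecaster, past_steps and future_steps are illustrative, not part of dsipts.

```python
import torch
import torch.nn as nn

class LinearForecaster(nn.Module):
    # capability flags mirroring the Base class attributes
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False

    def __init__(self, past_steps, future_steps, verbose=False):
        super().__init__()
        self.verbose = verbose
        # one linear map from the past window to the future window,
        # applied independently to each target channel
        self.linear = nn.Linear(past_steps, future_steps)

    def forward(self, batch):
        # select the target channels from the numerical past variables
        x = batch["x_num_past"][:, :, batch["idx_target"]]  # (B, past, n_targets)
        out = self.linear(x.transpose(1, 2))                # (B, n_targets, future)
        return out.transpose(1, 2)                          # (B, future, n_targets)
```

Since forward consumes the whole batch dict rather than positional tensors, the training loop in Base can stay model-agnostic.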

abstractmethod forward(batch)[source]

Forward method used during the training loop

Parameters:

batch (dict) – the batch structure. The keys are:

- y: the target variable(s); always present
- x_num_past: the numerical past variables; always present
- x_num_future: the numerical future variables
- x_cat_past: the categorical past variables
- x_cat_future: the categorical future variables
- idx_target: index of the target features in the past array

Returns:

output of the model

Return type:

torch.tensor
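A sketch of a batch matching the keys above. The (batch, time, channels) shape convention and the specific sizes are assumptions for illustration; check them against your data loader.

```python
import torch

B, past, future = 4, 16, 8     # batch size and window lengths (illustrative)
n_num, n_fut, n_cat = 3, 2, 1  # feature counts (illustrative)

batch = {
    "y":            torch.randn(B, future, 1),                # target variable(s)
    "x_num_past":   torch.randn(B, past, n_num),              # numerical past variables
    "x_num_future": torch.randn(B, future, n_fut),            # numerical future variables
    "x_cat_past":   torch.randint(0, 5, (B, past, n_cat)),    # categorical past variables
    "x_cat_future": torch.randint(0, 5, (B, future, n_cat)),  # categorical future variables
    "idx_target":   [0],  # the target sits at column 0 of x_num_past
}
```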

inference(batch)[source]

Usually it is fine to return the output of the forward method, but sometimes it is not (e.g. for RNNs)

Parameters:

batch (dict) – the batch structure, as in forward

Returns:

the inference result

Return type:

torch.tensor
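A hypothetical sketch of why an RNN needs its own inference method: during training forward can teacher-force with the known targets, while at inference time each prediction must be fed back as the next input. GRUForecaster, future_steps and hidden_size are illustrative names, not part of dsipts, and nn.Module again stands in for LightningModule.

```python
import torch
import torch.nn as nn

class GRUForecaster(nn.Module):
    def __init__(self, future_steps, hidden_size=16, verbose=False):
        super().__init__()
        self.future_steps = future_steps
        self.cell = nn.GRUCell(1, hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def _encode(self, batch):
        # run the cell over the past target values to build a hidden state
        x = batch["x_num_past"][:, :, batch["idx_target"]]  # (B, past, 1)
        h = x.new_zeros(x.size(0), self.cell.hidden_size)
        for t in range(x.size(1)):
            h = self.cell(x[:, t, :], h)
        return x, h

    def forward(self, batch):
        # training pass: teacher forcing with the true targets shifted by one
        x, h = self._encode(batch)
        inputs = torch.cat([x[:, -1:, :], batch["y"][:, :-1, :]], dim=1)
        preds = []
        for t in range(self.future_steps):
            h = self.cell(inputs[:, t, :], h)
            preds.append(self.head(h))
        return torch.stack(preds, dim=1)  # (B, future, 1)

    def inference(self, batch):
        # autoregressive pass: feed each prediction back as the next input
        x, h = self._encode(batch)
        step = x[:, -1, :]
        preds = []
        for _ in range(self.future_steps):
            h = self.cell(step, h)
            step = self.head(h)
            preds.append(step)
        return torch.stack(preds, dim=1)  # (B, future, 1)
```

Calling forward at inference time would require the future targets in the batch, which is exactly what is unavailable; that is the case the default-to-forward behaviour cannot cover.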