dsipts.models.samformer.utils module

dsipts.models.samformer.utils.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)

Copied from the equivalent Python implementation given in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
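The PyTorch documentation describes this computation as softmax(Q Kᵀ * scale + bias) V, where scale defaults to 1/sqrt(E) and E is the last dimension of query. The NumPy sketch below makes the roles of attn_mask, is_causal and scale concrete. It follows the conventions stated on that PyTorch page (a boolean mask marks with True the positions that may be attended, a float mask is added to the scores); the extra rng argument is a hypothetical addition used only for dropout and is not part of the dsipts signature.

```python
import numpy as np


def scaled_dot_product_attention(query, key, value, attn_mask=None,
                                 dropout_p=0.0, is_causal=False, scale=None,
                                 rng=None):
    """NumPy sketch of softmax(Q K^T * scale + bias) V (illustrative only)."""
    L, S = query.shape[-2], key.shape[-2]
    # Default scale is 1/sqrt(head dimension), as in the PyTorch reference.
    scale_factor = 1.0 / np.sqrt(query.shape[-1]) if scale is None else scale
    attn_bias = np.zeros((L, S), dtype=query.dtype)
    if is_causal:
        # Query position i may only attend to key positions j <= i.
        causal = np.tril(np.ones((L, S), dtype=bool))
        attn_bias = np.where(causal, attn_bias, -np.inf)
    if attn_mask is not None:
        if attn_mask.dtype == bool:
            # Boolean mask: True means "take part in attention".
            attn_bias = np.where(attn_mask, attn_bias, -np.inf)
        else:
            # Float mask: added directly to the attention scores.
            attn_bias = attn_bias + attn_mask
    scores = query @ np.swapaxes(key, -2, -1) * scale_factor + attn_bias
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    if dropout_p > 0.0:
        # Inverted dropout on the attention weights (hypothetical rng argument).
        rng = np.random.default_rng() if rng is None else rng
        keep = rng.random(weights.shape) >= dropout_p
        weights = weights * keep / (1.0 - dropout_p)
    return weights @ value
```

With is_causal=True the first query can only see the first key, so its output row equals the first row of value; passing the lower-triangular boolean mask explicitly gives the same result.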

class dsipts.models.samformer.utils.RevIN(num_features, eps=1e-5, affine=True)

Bases: Module

Reversible Instance Normalization (RevIN). Paper: https://openreview.net/pdf?id=cGDAkQo1C0p; reference code: https://github.com/ts-kim/RevIN

Parameters:
  • num_features (int) – the number of features or channels

  • eps (float) – a small value added for numerical stability (default: 1e-5)

  • affine (bool) – if True, RevIN has learnable affine parameters (default: True)

__init__(num_features, eps=1e-5, affine=True)

forward(x, mode)

Normalize x, or revert a previous normalization, depending on mode.

Note

Although the recipe for the forward pass is defined within this function, call the module instance itself instead of forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently ignores them.
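A minimal NumPy sketch of what RevIN computes, assuming the behavior of the reference code at https://github.com/ts-kim/RevIN: inputs shaped (batch, seq_len, num_features), statistics taken per instance over the time axis, and mode values "norm" and "denorm". The class name RevINSketch is hypothetical, and raising ValueError for an unknown mode is this sketch's own choice; the documented class is a torch.nn.Module whose affine weight and bias are learnable when affine=True.

```python
import numpy as np


class RevINSketch:
    """Hypothetical NumPy sketch of Reversible Instance Normalization."""

    def __init__(self, num_features, eps=1e-5, affine=True):
        self.eps = eps
        self.affine = affine
        # Learnable in the real module; here fixed at the identity transform.
        self.weight = np.ones(num_features)
        self.bias = np.zeros(num_features)

    def __call__(self, x, mode):
        # Assumed layout: (batch, seq_len, num_features).
        if mode == "norm":
            # Per-instance statistics over time, stored for the reverse step.
            self.mean = x.mean(axis=1, keepdims=True)
            self.stdev = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
            x = (x - self.mean) / self.stdev
            if self.affine:
                x = x * self.weight + self.bias
            return x
        if mode == "denorm":
            # Undo the affine map, then restore the stored scale and level.
            if self.affine:
                x = (x - self.bias) / (self.weight + self.eps * self.eps)
            return x * self.stdev + self.mean
        raise ValueError(f"unknown mode: {mode!r}")
```

In a forecasting model the input window would be normalized with mode "norm", the network would predict in the normalized space, and the prediction would be mapped back with mode "denorm" using the statistics stored from the input window.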

class dsipts.models.samformer.utils.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)

Bases: Optimizer

__init__(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)
first_step(zero_grad=False)
second_step(zero_grad=False)
step(closure=None)

Perform a single optimization step to update the parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
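SAM here presumably stands for Sharpness-Aware Minimization, the optimizer used by SAMformer, and its signature matches the widely used davda54/sam implementation. This page does not document the internals, so the following is only a NumPy sketch of the standard update under that assumption: first_step moves the weights to w + e(w), with e(w) = rho * g / ||g|| (or, with adaptive=True, the |w|-scaled variant), and second_step restores w and applies the base optimizer using the gradient taken at the perturbed point. The function sam_update and the plain-SGD base step are hypothetical illustrations, not part of dsipts.

```python
import numpy as np


def sam_update(w, grad_fn, lr=0.1, rho=0.05, adaptive=False):
    """One SAM update of weights w, with plain SGD as the base optimizer."""
    g = grad_fn(w)
    # First step: move to the approximate worst-case point within radius rho.
    scale = np.abs(w) if adaptive else 1.0
    grad_norm = np.linalg.norm(scale * g)
    e_w = (w ** 2 if adaptive else 1.0) * g * rho / (grad_norm + 1e-12)
    w_adv = w + e_w
    # Second step: gradient at the perturbed point, applied to the original w.
    g_adv = grad_fn(w_adv)
    return w - lr * g_adv


# Toy loss 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([3.0, 4.0])
w_new = sam_update(w, lambda v: v)
```

Here g = [3, 4], e(w) = 0.05 * [3, 4] / 5 = [0.03, 0.04], and w_new = [3, 4] - 0.1 * [3.03, 4.04] = [2.697, 3.596]. Because the gradient is evaluated twice per update, implementations following davda54/sam (assuming dsipts does) need step() to receive a closure that recomputes the loss and calls backward(); the same effect can be obtained manually by calling first_step(zero_grad=True) after a first backward pass and second_step(zero_grad=True) after a second one.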

load_state_dict(state_dict)

Load the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) do not affect the loading process. To use the parameter names in custom cases (for example, when the parameters in the loaded state dict differ from those initialized in the optimizer), implement a custom register_load_state_dict_pre_hook that adapts the loaded dict accordingly. If param_names exist in the param_groups of the loaded state dict, they are saved and override the current names, if any, in the optimizer state. If they do not exist in the loaded state dict, the optimizer's param_names remain unchanged.