Source code for dsipts.models.VQVAEA


## Copyright (c) 2020 Andrej Karpathy https://github.com/karpathy/minGPT?tab=MIT-1-ov-file#readme
## Modified for notation alignment, batch structure and adapted for timeseries
## Extended with the code inside the vva folder

from torch import  nn
import torch
from torch.nn import functional as F
from .base import Base
from typing import List, Union
from .vva.minigpt import Block
from .vva.vqvae import VQVAE
import logging
from random import random
from ..data_structure.utils import beauty_string
from .utils import  get_scope

torch.autograd.set_detect_anomaly(True)

        

class VQVAEA(Base):
    handle_multivariate = False
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

    def __init__(self,
                 past_steps:int,
                 future_steps:int,
                 past_channels:int,
                 future_channels:int,
                 hidden_channels:int,
                 embs:List[int],
                 d_model:int,
                 max_voc_size:int,
                 num_layers:int,
                 dropout_rate:float,
                 commitment_cost:float,
                 decay:float,
                 n_heads:int,
                 out_channels:int,
                 epoch_vqvae: int,
                 persistence_weight:float=0.0,
                 loss_type: str='l1',
                 quantiles:List[int]=[],
                 optim:Union[str,None]=None,
                 optim_config:dict=None,
                 scheduler_config:dict=None,
                 **kwargs)->None:
        """Custom VQVAE + GPT encoder-decoder for timeseries.

        A VQVAE tokenizes the (univariate) series, then a minGPT-style transformer predicts
        the future tokens, which are decoded back to values.

        Args:
            past_steps (int): number of past datapoints used
            future_steps (int): number of future lags to predict
            past_channels (int): number of numeric past variables, must be >0
            future_channels (int): number of future numeric variables
            hidden_channels (int): hidden channels of the VQVAE encoder/decoder
            embs (List): list of the initial dimension of the categorical variables
            d_model (int): dimension of the VQVAE codebook vectors and of the GPT token embeddings
            max_voc_size (int): number of codebook entries (vocabulary size of the GPT)
            num_layers (int): number of transformer blocks
            dropout_rate (float): dropout rate in Dropout layers
            commitment_cost (float): commitment cost of the VQVAE loss
            decay (float): EMA decay of the VQVAE codebook
            n_heads (int): number of attention heads
            out_channels (int): number of output channels, currently must be 1
            epoch_vqvae (int): epoch after which the VQVAE is frozen and only the GPT keeps training
            persistence_weight (float): weight controlling the divergence from persistence model. Default 0
            loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Default l1
            quantiles (List[int], optional): not used here, quantile loss is not supported by this model. Defaults to [].
            optim (str, optional): if not None it expects a pytorch optim method. Defaults to None that is mapped to Adam.
            optim_config (dict, optional): configuration for Adam optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for stepLR scheduler. Defaults to None.

        See the illustrative instantiation sketch after this constructor.
""" super().__init__(**kwargs) self.save_hyperparameters(logger=False) self.d_model = d_model self.max_voc_size = max_voc_size self.future_steps = future_steps self.epoch_vqvae = epoch_vqvae ##PRIMA VQVAE assert out_channels==1, beauty_string('Working only for one singal','section',True) assert past_steps%2==0 and future_steps%2==0, beauty_string('There are some issue with the deconder in case of odd length','section',True) self.vqvae = VQVAE(in_channels=1, hidden_channels=hidden_channels,out_channels=1,num_embeddings= max_voc_size,embedding_dim=d_model,commitment_cost=commitment_cost,decay=decay ) ##POI GPT self.block_size = past_steps//2 + future_steps//2 -1 self.sentence_length = future_steps//2 self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(max_voc_size, d_model), wpe = nn.Embedding(self.block_size, d_model), drop = nn.Dropout(dropout_rate), h = nn.ModuleList([Block( d_model,dropout_rate,n_heads,dropout_rate,self.block_size) for _ in range(num_layers)]), ##care can be different dropouts ln_f = nn.LayerNorm(d_model), lm_head = nn.Linear(d_model, max_voc_size, bias=False) )) # report number of parameters (note we don't count the decoder parameters in lm_head) n_params = sum(p.numel() for p in self.transformer.parameters()) beauty_string("number of parameters: %.2fM" % (n_params/1e6,),'info',self.verbose) self.use_quantiles = False self.is_classification = True self.optim_config = optim_config
    def configure_optimizers(self):
        #return torch.optim.Adam(self.vqvae.parameters(), lr=self.optim_config.lr_vqvae,
        #                        weight_decay=self.optim_config.weight_decay_vqvae)
        return torch.optim.AdamW([
            {'params':self.vqvae.parameters(),       'lr':self.optim_config.lr_vqvae, 'weight_decay':self.optim_config.weight_decay_vqvae},
            {'params':self.transformer.parameters(), 'lr':self.optim_config.lr_gpt,   'weight_decay':self.optim_config.weight_decay_gpt},
        ])
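    # ------------------------------------------------------------------
    # Illustrative only (not from the original source): the parameter groups
    # above read lr_vqvae / weight_decay_vqvae / lr_gpt / weight_decay_gpt as
    # attributes of optim_config; a configuration object along these lines is
    # assumed, e.g. via types.SimpleNamespace:
    #
    #   optim_config = SimpleNamespace(lr_vqvae=1e-3, weight_decay_vqvae=0.0,
    #                                  lr_gpt=5e-4,   weight_decay_gpt=0.01)
    # ------------------------------------------------------------------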
    def gpt(self,tokens):
        b, t = tokens.size()
        assert t <= self.block_size, beauty_string(f"Cannot forward sequence of length {t}, block size is only {self.block_size}",'section',True)
        pos = torch.arange(0, t, dtype=torch.long, device=self.device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(tokens) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)    # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.transformer.lm_head(x)
        return logits
    def forward(self, batch):
        ##VQVAE
        idx_target = batch['idx_target'][0]
        data = batch['x_num_past'][:,:,idx_target]
        if self.current_epoch > self.epoch_vqvae:
            # after epoch_vqvae the VQVAE is frozen: only the GPT is trained
            with torch.no_grad():
                vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
            loss_vqvae = 0
        else:
            vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
            if random()<0.001:
                beauty_string(perplexity,'info',self.verbose)
            recon_error = F.mse_loss(data_recon.squeeze(), data.squeeze())
            loss_vqvae = recon_error + vq_loss

        if self.current_epoch > self.epoch_vqvae:
            with torch.no_grad():
                _, _, _, quantized_y, encodings_y = self.vqvae(batch['y'].permute(0,2,1))

            ##GPT: teacher forcing on the token sequence [past tokens, future tokens[:-1]];
            ##the targets are the shifted sequence with the past part masked to -1
            tokens = torch.cat([encodings_x.argmax(dim=2), encodings_y.argmax(dim=2)[:,0:-1]], 1)
            tokens_y = torch.cat([encodings_x.argmax(dim=2)[:,0:-1], encodings_y.argmax(dim=2)], 1)
            tokens_y[:,0:encodings_x.shape[1]-1] = -1
            logits = self.gpt(tokens)
            loss_gpt = F.cross_entropy(logits.view(-1, logits.size(-1)), tokens_y.view(-1), ignore_index=-1)

            ## now reconstruct y, because that is the desired output
            with torch.no_grad():
                encoding_indices = torch.argmax(logits.reshape(-1,self.max_voc_size), dim=1).unsqueeze(1)
                encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
                encodings.scatter_(1, encoding_indices, 1)
                quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(data.shape[0],-1,self.d_model) ##B x L x hidden
                quantized = quantized.permute(0, 2, 1).contiguous()
                y_hat = self.vqvae._decoder(quantized,False).squeeze()[:,-self.future_steps:]
            l1_loss = nn.L1Loss()(y_hat, batch['y'].squeeze())
            return y_hat, loss_vqvae + loss_gpt + l1_loss
        else:
            return None, loss_vqvae
    def training_step(self, batch, batch_idx):
        """pytorch lightning training step

        :meta private:
        """
        _, loss = self(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """pytorch lightning validation step

        :meta private:
        """
        _, loss = self(batch)
        return loss
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, num_samples=100):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        assert do_sample is False, 'sampling (do_sample=True) is not implemented yet'
        if do_sample:
            idx = idx.repeat(num_samples,1,1)
            for _ in range(max_new_tokens):
                tmp = []
                for i in range(num_samples):
                    idx_cond = idx[i,:,:] if idx.size(2) <= self.block_size else idx[i,:, -self.block_size:]
                    logits = self.gpt(idx_cond)
                    logits = logits[:, -1, :] / temperature
                    if top_k is not None:
                        v, _ = torch.topk(logits, top_k)
                        logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1, replacement=True)
                    tmp.append(idx_next)
                tmp = torch.cat(tmp,dim=1).T.unsqueeze(2)
                idx = torch.cat((idx, tmp), dim=2)
            return idx
        else:
            for _ in range(max_new_tokens):
                # if the sequence context is growing too long we must crop it at block_size
                idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
                # forward the model to get the logits for the index in the sequence
                logits = self.gpt(idx_cond)
                # pluck the logits at the final step and scale by desired temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, top_k)
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                # take the most likely element (greedy decoding, since sampling is disabled)
                _, idx_next = torch.topk(probs, k=1, dim=-1)
                # append the chosen index to the running sequence and continue
                idx = torch.cat((idx, idx_next), dim=1)
            return idx
    def inference(self, batch:dict)->torch.tensor:
        idx_target = batch['idx_target'][0]
        data = batch['x_num_past'][:,:,idx_target].to(self.device)
        vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
        x = encodings_x.argmax(dim=2)
        inp = x[:, :self.sentence_length]
        # let the model complete the rest of the sequence
        cat = self.generate(inp, self.sentence_length, do_sample=False) # can not handle sampling here yet :-)
        encoding_indices = cat.flatten().unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
        encodings.scatter_(1, encoding_indices, 1)
        quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(x.shape[0],-1,self.d_model) ##B x L x hidden
        quantized = quantized.permute(0, 2, 1).contiguous()
        y_hat = self.vqvae._decoder(quantized,False).squeeze()[:,-self.future_steps:]
        return y_hat.unsqueeze(2).unsqueeze(3) ## B x L x C x 3
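
# ---------------------------------------------------------------------------
# Illustrative only (not part of the original module): a rough sketch of the
# batch layout that forward/inference above expect, inferred from the keys
# they access; the shapes are assumptions, not taken from the dsipts docs.
#
#   batch = {
#       'x_num_past': torch.randn(B, past_steps, n_past_channels),  # past numeric series
#       'idx_target': [torch.tensor([0])],                          # column index of the target series
#       'y':          torch.randn(B, future_steps, 1),              # future target (training only)
#   }
# ---------------------------------------------------------------------------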