Source code for dsipts.models.xlstm.xLSTM

import torch
import torch.nn as nn
import torch.nn.functional as F



class mLSTM(nn.Module):
    """Matrix-memory LSTM (mLSTM) stack: each layer holds a (hidden x hidden)
    matrix memory updated through exponential input/forget gates and read out
    with a query projection."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super(mLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        # NOTE: all per-layer projections are built on input_size, so stacking
        # more than one layer assumes input_size == hidden_size (which is how
        # xLSTM constructs this module). The LSTMCell parameters are kept for
        # initialization and device handling; the recurrence in forward() is
        # the matrix-memory update, not the classic LSTM cell.
        self.lstms = nn.ModuleList([nn.LSTMCell(input_size, hidden_size) for _ in range(num_layers)])
        self.dropout_layers = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers - 1)])
        # query/key/value projections feeding the matrix memory
        self.W_q = nn.Linear(input_size, hidden_size)
        self.W_k = nn.Linear(input_size, hidden_size)
        self.W_v = nn.Linear(input_size, hidden_size)
        self.exp_input_gates = nn.ModuleList([nn.Linear(input_size, hidden_size) for _ in range(num_layers)])
        self.exp_forget_gates = nn.ModuleList([nn.Linear(input_size, hidden_size) for _ in range(num_layers)])
        self.output_gates = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        self.reset_parameters()

    def reset_parameters(self):
        for lstm in self.lstms:
            nn.init.xavier_uniform_(lstm.weight_ih)
            nn.init.xavier_uniform_(lstm.weight_hh)
            nn.init.zeros_(lstm.bias_ih)
            nn.init.zeros_(lstm.bias_hh)
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.zeros_(self.W_q.bias)
        nn.init.zeros_(self.W_k.bias)
        nn.init.zeros_(self.W_v.bias)
        for gate in list(self.exp_input_gates) + list(self.exp_forget_gates) + list(self.output_gates):
            nn.init.xavier_uniform_(gate.weight)
            nn.init.zeros_(gate.bias)

    def forward(self, input_seq, hidden_state=None):
        batch_size = input_seq.size(0)
        seq_length = input_seq.size(1)
        if hidden_state is None:
            hidden_state = self.init_hidden(batch_size)
        output_seq = []
        for t in range(seq_length):
            x = input_seq[:, t, :]
            # queries/keys/values are computed once per step from the layer-0 input
            queries = self.W_q(x)
            keys = self.W_k(x)
            values = self.W_v(x)
            new_hidden_state = []
            # iterate by index: zipping with the num_layers - 1 dropout layers
            # would silently drop the last layer
            for i in range(self.num_layers):
                h, C = hidden_state[i]
                if h is None:
                    # fall back to a zero state with the correct shapes
                    h = x.new_zeros(batch_size, self.hidden_size)
                    C = x.new_zeros(batch_size, self.hidden_size, self.hidden_size)
                # exponential input and forget gates
                ii = torch.exp(self.exp_input_gates[i](x))
                f = torch.exp(self.exp_forget_gates[i](x))
                # matrix-memory update: C_t = f * C + i * (v k^T); the gates are
                # (batch, hidden) and scale the rows of the (batch, hidden, hidden)
                # memory, hence the trailing unsqueeze
                C_t = f.unsqueeze(-1) * C + ii.unsqueeze(-1) * torch.bmm(values.unsqueeze(2), keys.unsqueeze(1))
                # memory read-out: h_tilde = C_t q
                attn_output = torch.bmm(C_t, queries.unsqueeze(2)).squeeze(2)
                o = torch.sigmoid(self.output_gates[i](h))
                h = o * attn_output
                new_hidden_state.append((h, C_t))
                x = self.dropout_layers[i](h) if i < self.num_layers - 1 else h
            hidden_state = new_hidden_state
            output_seq.append(x)
        output_seq = torch.stack(output_seq, dim=1)
        return output_seq, hidden_state

    def init_hidden(self, batch_size):
        hidden_state = []
        for lstm in self.lstms:
            h = torch.zeros(batch_size, self.hidden_size, device=lstm.weight_ih.device)
            # the matrix memory C is (batch, hidden, hidden)
            C = torch.zeros(batch_size, self.hidden_size, self.hidden_size, device=lstm.weight_ih.device)
            hidden_state.append((h, C))
        return hidden_state
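

# Minimal shape-check sketch for mLSTM (illustrative only: the batch, length,
# and size values below are assumptions, not dsipts defaults). It runs one
# forward pass and inspects the (h, C) state returned for each layer.
def _mlstm_shape_sketch():
    batch, seq_len, size = 4, 12, 32  # input_size == hidden_size, as the stacked layers assume
    model = mLSTM(input_size=size, hidden_size=size, num_layers=2, dropout=0.1)
    x = torch.randn(batch, seq_len, size)
    out, state = model(x)
    assert out.shape == (batch, seq_len, size)
    assert state[0][0].shape == (batch, size)        # hidden vector h
    assert state[0][1].shape == (batch, size, size)  # matrix memory C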


class sLSTM(nn.Module):
    """Scalar-memory LSTM (sLSTM) stack: standard LSTM cells followed by
    exponential gating of the cell state."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.lstms = nn.ModuleList([nn.LSTMCell(input_size if i == 0 else hidden_size, hidden_size) for i in range(num_layers)])
        self.dropout_layers = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers - 1)])
        self.exp_forget_gates = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        self.exp_input_gates = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        self.reset_parameters()

    def reset_parameters(self):
        for lstm in self.lstms:
            nn.init.xavier_uniform_(lstm.weight_ih)
            nn.init.xavier_uniform_(lstm.weight_hh)
            nn.init.zeros_(lstm.bias_ih)
            nn.init.zeros_(lstm.bias_hh)
        for gate in list(self.exp_forget_gates) + list(self.exp_input_gates):
            nn.init.xavier_uniform_(gate.weight)
            nn.init.zeros_(gate.bias)

    def forward(self, input_seq, hidden_state=None):
        batch_size = input_seq.size(0)
        seq_length = input_seq.size(1)
        if hidden_state is None:
            hidden_state = self.init_hidden(batch_size)
        output_seq = []
        for t in range(seq_length):
            x = input_seq[:, t, :]
            new_hidden_state = []
            # iterate by index: zipping with the num_layers - 1 dropout layers
            # would silently drop the last layer
            for i in range(self.num_layers):
                lstm = self.lstms[i]
                if hidden_state[i][0] is None:
                    h, c = lstm(x)
                else:
                    h, c = lstm(x, (hidden_state[i][0], hidden_state[i][1]))
                # exponential gating of the cell state; note that the input-gate
                # term multiplies a zero tensor, so only the forget gate actually
                # rescales c here
                f = torch.exp(self.exp_forget_gates[i](h))
                ii = torch.exp(self.exp_input_gates[i](h))
                c = f * c + ii * lstm.weight_hh.new_zeros(batch_size, self.hidden_size)
                new_hidden_state.append((h, c))
                x = self.dropout_layers[i](h) if i < self.num_layers - 1 else h
            hidden_state = new_hidden_state
            output_seq.append(x)
        output_seq = torch.stack(output_seq, dim=1)
        return output_seq, hidden_state

    def init_hidden(self, batch_size):
        hidden_state = []
        for lstm in self.lstms:
            h = torch.zeros(batch_size, self.hidden_size, device=lstm.weight_ih.device)
            c = torch.zeros(batch_size, self.hidden_size, device=lstm.weight_ih.device)
            hidden_state.append((h, c))
        return hidden_state
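

# Minimal usage sketch for sLSTM (illustrative sizes, assumed for the example).
# Unlike mLSTM, each layer's state is a pair of (batch, hidden_size) tensors,
# and the returned state can be fed back in to continue a sequence.
def _slstm_usage_sketch():
    model = sLSTM(input_size=8, hidden_size=16, num_layers=2)
    x = torch.randn(4, 10, 8)
    out, state = model(x)                       # state: list of (h, c) per layer
    assert out.shape == (4, 10, 16)
    out2, state = model(x, hidden_state=state)  # warm-start from the previous state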


class xLSTMBlock(nn.Module):
    """Residual xLSTM block: an sLSTM or mLSTM core followed by a linear
    projection, GELU, LayerNorm over the residual, and dropout."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, bidirectional=False, lstm_type="slstm"):
        super(xLSTMBlock, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.lstm_type = lstm_type
        if lstm_type == "slstm":
            self.lstm = sLSTM(input_size, hidden_size, num_layers, dropout)
        elif lstm_type == "mlstm":
            self.lstm = mLSTM(input_size, hidden_size, num_layers, dropout)
        else:
            raise ValueError(f"Invalid LSTM type: {lstm_type}")
        self.norm = nn.LayerNorm(input_size)
        self.activation = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout)
        if bidirectional:
            # NOTE: the wrapped sLSTM/mLSTM only run forward in time, so this
            # projection consumes the same hidden features duplicated in forward()
            self.proj = nn.Linear(2 * hidden_size, input_size)
        else:
            if lstm_type == "mlstm":
                # feed-forward up-projection applied before the mLSTM core
                self.up_proj = nn.Sequential(
                    nn.Linear(input_size, 4 * input_size),
                    nn.GELU(),
                    nn.Linear(4 * input_size, input_size),
                )
            self.proj = nn.Linear(hidden_size, input_size)
        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self, "up_proj"):
            nn.init.xavier_uniform_(self.up_proj[0].weight)
            nn.init.zeros_(self.up_proj[0].bias)
            nn.init.xavier_uniform_(self.up_proj[2].weight)
            nn.init.zeros_(self.up_proj[2].bias)
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, input_seq, hidden_state=None):
        if hasattr(self, "up_proj"):
            input_seq = self.up_proj(input_seq)
        lstm_output, hidden_state = self.lstm(input_seq, hidden_state)
        if self.lstm_type == "slstm":
            hidden_state = [[hidden_state[i][0], hidden_state[i][1]] for i in range(len(hidden_state))]
        if self.bidirectional:
            # lstm_output only has hidden_size channels, so the slices below
            # duplicate the same features to feed the 2 * hidden_size projection
            lstm_output = torch.cat((lstm_output[:, :, :self.hidden_size], lstm_output[:, :, -self.hidden_size:]), dim=-1)
        output = self.activation(self.proj(lstm_output))
        # residual connection (against the possibly up-projected input),
        # then LayerNorm and dropout
        output = self.norm(output + input_seq)
        output = self.dropout_layer(output)
        return output, hidden_state
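

# Block-contract sketch (illustrative sizes, assumed for the example): a block
# maps (batch, seq_len, input_size) to the same shape, since the projection and
# residual bring the LSTM output back to input_size, so blocks can be stacked.
def _xlstm_block_sketch():
    block = xLSTMBlock(input_size=16, hidden_size=16, num_layers=1, lstm_type="mlstm")
    x = torch.randn(2, 5, 16)
    out, state = block(x)
    assert out.shape == x.shape  # shape-preserving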


class xLSTM(nn.Module):
    """Stack of xLSTM blocks behind an initial linear embedding layer."""

    def __init__(self, input_size, hidden_size, num_layers, num_blocks, dropout=0.0, bidirectional=False, lstm_type="slstm"):
        super(xLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.lstm_type = lstm_type
        # blocks operate in hidden_size space; the initial linear layer embeds
        # the raw input_size features into that space
        self.blocks = nn.ModuleList([xLSTMBlock(hidden_size, hidden_size, num_layers, dropout, bidirectional, lstm_type) for _ in range(num_blocks)])
        self.initial = nn.Linear(self.input_size, self.hidden_size)

    def forward(self, input_seq, hidden_states=None):
        if hidden_states is None:
            hidden_states = [None] * self.num_blocks
        output_seq = self.initial(input_seq)
        for i, block in enumerate(self.blocks):
            output_seq, hidden_state = block(output_seq, hidden_states[i])
            if self.lstm_type == "slstm":
                hidden_states[i] = [[hidden_state[j][0], hidden_state[j][1]] for j in range(len(hidden_state))]
            else:
                hidden_states[i] = hidden_state
        return output_seq, hidden_states
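

# End-to-end usage sketch (illustrative: the hyperparameters and shapes below
# are assumptions for the example, not dsipts defaults). It embeds 7 input
# features, runs two sLSTM blocks, and warm-starts a second call from the
# returned per-block states.
def _xlstm_usage_sketch():
    model = xLSTM(input_size=7, hidden_size=32, num_layers=2, num_blocks=2, dropout=0.1, lstm_type="slstm")
    x = torch.randn(8, 24, 7)                       # (batch, seq_len, features)
    out, states = model(x)                          # out: (batch, seq_len, hidden_size)
    assert out.shape == (8, 24, 32)
    out2, states = model(x, hidden_states=states)   # continue from previous states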