# -*- coding: utf-8 -*-
"""
Authors:
Li,Yan (liyan22021121@gmail.com)
"""
import torch
import torch.nn as nn
import math
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding, precomputed once and stored as a buffer."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Only the sequence length of x is used; the batch dimension broadcasts.
        return self.pe[:, :x.size(1)]
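
# A minimal usage sketch (shapes below are illustrative, not from the original
# module): the returned encoding depends only on the sequence length and
# broadcasts over the batch dimension, e.g.
#   pos = PositionalEmbedding(d_model=128)
#   pos(torch.zeros(4, 96, 7)).shape  # -> torch.Size([1, 96, 128])
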
class TokenEmbedding(nn.Module):
    """Projects each time step's c_in input features to d_model via a circular Conv1d."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # padding=1 keeps the sequence length with kernel_size=3 on torch >= 1.5;
        # older releases handled circular padding differently and need padding=2.
        # Compare parsed major/minor numbers so versions such as '1.10.0' are not
        # misordered by a lexicographic string comparison.
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # (batch, seq_len, c_in) -> (batch, c_in, seq_len) for Conv1d, then back.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
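
# A minimal usage sketch (shapes below are illustrative assumptions): each time
# step's c_in features are mapped to a d_model-dimensional token, with the
# sequence length preserved by the circular convolution, e.g.
#   tok = TokenEmbedding(c_in=7, d_model=128)
#   tok(torch.randn(4, 96, 7)).shape  # -> torch.Size([4, 96, 128])
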
class TemporalEmbedding(nn.Module):
    """Embeds integer time features (month, day, weekday, hour[, minute]) into d_model."""

    def __init__(self, d_model, freq='h'):
        super(TemporalEmbedding, self).__init__()
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13
        # Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        Embed = nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
            self.fc = nn.Linear(5 * d_model, d_model)
        else:
            self.fc = nn.Linear(4 * d_model, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()
        # Expected column order: [month, day, weekday, hour, minute]; the minute
        # column is only consumed when freq='t'.
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])
        if hasattr(self, 'minute_embed'):
            out = torch.cat((minute_x, hour_x, weekday_x, day_x, month_x), dim=2)
        else:
            out = torch.cat((hour_x, weekday_x, day_x, month_x), dim=2)
        return self.fc(out)
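
# A minimal usage sketch (shapes below are illustrative assumptions): with
# freq='h' the input carries four integer columns ordered [month, day, weekday,
# hour], matching the indexing in forward above; freq='t' adds a fifth minute
# column whose ids must lie in [0, 3] (minute_size=4, presumably coarse
# minute buckets), e.g.
#   temb = TemporalEmbedding(d_model=128, freq='h')
#   temb(torch.zeros(4, 96, 4).long()).shape  # -> torch.Size([4, 96, 128])
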
class DataEmbedding(nn.Module):
    """Combines value, positional, and optional categorical covariate embeddings."""

    def __init__(self, c_in, d_model, embs, dropout=0.1):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        # self.temporal_embedding = TemporalEmbedding(d_model=d_model, freq=freq)
        # `embs` lists the maximum category id of each covariate column in x_mark;
        # each column gets its own nn.Embedding with k + 1 entries.
        self.emb_list = nn.ModuleList()
        if embs is not None:
            for k in embs:
                self.emb_list.append(nn.Embedding(k + 1, d_model))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # Sum one learned embedding per categorical column of x_mark; starting
        # from 0 keeps this well-defined when no covariate embeddings were given.
        tot = 0
        for i, emb in enumerate(self.emb_list):
            tot = tot + emb(x_mark[:, :, i])
        x = self.value_embedding(x) + tot + self.position_embedding(x)
        return self.dropout(x)
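
# A minimal smoke test (a sketch; the shapes and category sizes below are
# illustrative assumptions, not part of the original module): embed a dummy
# batch with two categorical covariates whose maximum ids are 11 and 6.
if __name__ == "__main__":
    batch, seq_len, c_in, d_model = 4, 96, 10, 128
    x = torch.randn(batch, seq_len, c_in)
    # x_mark carries one integer category-id column per entry in `embs`.
    x_mark = torch.stack(
        (torch.randint(0, 12, (batch, seq_len)),
         torch.randint(0, 7, (batch, seq_len))), dim=-1)
    emb = DataEmbedding(c_in=c_in, d_model=d_model, embs=[11, 6])
    out = emb(x, x_mark)
    print(out.shape)  # expected: torch.Size([4, 96, 128])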