Source code for dsipts.models.crossformer.cross_encoder

import torch
import torch.nn as nn

from .attn import TwoStageAttentionLayer
from math import ceil

class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size` segments in each dimension are merged into one
    segment to obtain a representation at a coarser scale.
    We set win_size = 2 in our paper.
    '''
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        """
        x: [B, ts_d, seg_num, d_model]
        """
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            # repeat the last segments so seg_num becomes divisible by win_size
            pad_num = self.win_size - pad_num
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)

        # take every win_size-th segment and concatenate along the feature dim
        seg_to_merge = []
        for i in range(self.win_size):
            seg_to_merge.append(x[:, :, i::self.win_size, :])
        x = torch.cat(seg_to_merge, -1)  # [B, ts_d, seg_num/win_size, win_size*d_model]

        x = self.norm(x)
        x = self.linear_trans(x)

        return x
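A minimal shape check for SegMerging (a sketch; the tensor sizes are illustrative and the import path is taken from this page's module name). With 5 segments and win_size = 2, one segment is repeated for padding and ceil(5 / 2) = 3 merged segments come out:

import torch
from dsipts.models.crossformer.cross_encoder import SegMerging

merge = SegMerging(d_model=16, win_size=2)
x = torch.randn(4, 7, 5, 16)   # [B, ts_d, seg_num, d_model]; seg_num = 5 is odd
y = merge(x)
print(y.shape)                 # torch.Size([4, 7, 3, 16])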
class scale_block(nn.Module):
    '''
    One segment merging layer followed by multiple TSA layers in each scale.
    The parameter `depth` determines the number of TSA layers used in each scale.
    We set depth = 1 in the paper.
    '''
    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout,
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        # merge segments only at coarser scales (win_size > 1)
        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList()
        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model,
                                                             n_heads, d_ff, dropout))

    def forward(self, x):
        _, ts_dim, _, _ = x.shape

        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x
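A usage sketch for scale_block, assuming the dsipts package (and its TwoStageAttentionLayer dependency) is importable; sizes are illustrative. Note that `seg_num` must match the segment count *after* merging, so an input with 10 segments and win_size = 2 needs seg_num = 5:

import torch
from dsipts.models.crossformer.cross_encoder import scale_block

blk = scale_block(win_size=2, d_model=16, n_heads=4, d_ff=32, depth=1,
                  dropout=0.1, seg_num=5, factor=10)
x = torch.randn(4, 7, 10, 16)  # [B, ts_d, seg_num, d_model]
y = blk(x)
print(y.shape)                 # torch.Size([4, 7, 5, 16]): segments halved by SegMerging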
class Encoder(nn.Module):
    '''
    The Encoder of Crossformer.
    '''
    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                 in_seg_num=10, factor=10):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList()

        # the first block keeps the original segment granularity (win_size = 1)
        self.encode_blocks.append(scale_block(1, d_model, n_heads, d_ff, block_depth,
                                              dropout, in_seg_num, factor))
        # each subsequent block merges win_size segments, shrinking seg_num geometrically
        for i in range(1, e_blocks):
            self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth,
                                                  dropout, ceil(in_seg_num / win_size**i), factor))

    def forward(self, x):
        encode_x = []
        encode_x.append(x)

        # collect the representation at every scale, including the input itself
        for block in self.encode_blocks:
            x = block(x)
            encode_x.append(x)

        return encode_x
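An end-to-end shape trace through the Encoder under the same assumptions. With in_seg_num = 8, win_size = 2, and e_blocks = 3, the forward pass returns the input plus one tensor per scale, with segment counts 8, 8, 4, 2:

import torch
from dsipts.models.crossformer.cross_encoder import Encoder

enc = Encoder(e_blocks=3, win_size=2, d_model=16, n_heads=4, d_ff=32,
              block_depth=1, dropout=0.1, in_seg_num=8, factor=10)
x = torch.randn(4, 7, 8, 16)   # [B, ts_d, in_seg_num, d_model]
outs = enc(x)
for o in outs:
    print(o.shape)             # seg dims: 8 (input), 8, 4, 2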