# -*- coding: utf-8 -*-
"""
Authors:
Li,Yan (liyan22021121@gmail.com)
"""
from torch import nn
def weights_init(m):
    """Initialize a module's weights DCGAN-style; apply with ``net.apply(weights_init)``.

    Conv* layers:      weight ~ N(0.0, 0.2)
    BatchNorm* layers: weight ~ N(1.0, 0.2), bias = 0

    Modules whose class name contains neither "Conv" nor "BatchNorm"
    are left untouched.

    NOTE(review): the classic DCGAN recipe uses std 0.02, not 0.2 --
    confirm the larger std is intentional before changing it.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.2)
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(1.0, 0.2)
        m.bias.data.fill_(0)
class MyConvo2d(nn.Module):
    """2D convolution with 'same'-style padding derived from the kernel size.

    Padding is (kernel_size - 1) // 2, so odd kernel sizes at stride 1
    preserve the spatial dimensions.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride=1, bias=True):
        super(MyConvo2d, self).__init__()
        self.padding = int((kernel_size - 1) / 2)
        # BUG FIX: the original passed stride=1 unconditionally, silently
        # ignoring the caller's `stride` argument. Forward it properly.
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size,
                              stride=stride, padding=self.padding, bias=bias)

    def forward(self, input):
        """Apply the convolution to `input` of shape (N, C_in, H, W)."""
        return self.conv(input)
class Square(nn.Module):
    """Element-wise square activation: forward(x) == x ** 2."""

    def __init__(self):
        # Stateless module; nothing to configure.
        super(Square, self).__init__()

    def forward(self, in_vect):
        """Return the element-wise square of `in_vect`."""
        squared = in_vect ** 2
        return squared
class Swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x) (a.k.a. SiLU)."""

    def __init__(self):
        # Stateless module; nothing to configure.
        super(Swish, self).__init__()

    def forward(self, in_vect):
        # Use the tensor method instead of the long-deprecated
        # nn.functional.sigmoid alias.
        return in_vect * in_vect.sigmoid()
class MeanPoolConv(nn.Module):
    """2x2 mean pooling followed by a convolution (halves H and W).

    BUG FIX: the original forward skipped the pooling step entirely
    (`output = input`), so no downsampling happened even though
    ResidualBlock(resample='down') and Res12_Quadratic's hw/8 flatten
    rely on the spatial size halving. Assumes even H and W.
    """

    def __init__(self, input_dim, output_dim, kernel_size):
        super(MeanPoolConv, self).__init__()
        self.conv = MyConvo2d(input_dim, output_dim, kernel_size)

    def forward(self, input):
        # Average each non-overlapping 2x2 patch, then convolve.
        output = (input[:, :, ::2, ::2] + input[:, :, 1::2, ::2] +
                  input[:, :, ::2, 1::2] + input[:, :, 1::2, 1::2]) / 4
        return self.conv(output)
class ConvMeanPool(nn.Module):
    """Convolution followed by 2x2 mean pooling (halves H and W).

    BUG FIX: the original forward omitted the pooling step, so the
    module was a plain convolution even though ResidualBlock's 'down'
    path and downstream LayerNorm/flatten shapes expect the spatial
    size to halve. Assumes even H and W.
    """

    def __init__(self, input_dim, output_dim, kernel_size):
        super(ConvMeanPool, self).__init__()
        self.conv = MyConvo2d(input_dim, output_dim, kernel_size)

    def forward(self, input):
        output = self.conv(input)
        # Average each non-overlapping 2x2 patch.
        output = (output[:, :, ::2, ::2] + output[:, :, 1::2, ::2] +
                  output[:, :, ::2, 1::2] + output[:, :, 1::2, 1::2]) / 4
        return output
class ResidualBlock(nn.Module):
    """Pre-activation residual block with optional downsampling.

    Args:
        input_dim:  input channel count.
        output_dim: output channel count.
        kernel_size: kernel size of the two main convolutions.
        hw: input spatial size (assumed square), used for LayerNorm shapes.
        resample:
            'down' -- halve the spatial size (MeanPoolConv shortcut,
                      ConvMeanPool second conv).
            'none' -- keep spatial size; 1x1 conv shortcut.
            None   -- keep spatial size; identity shortcut when
                      input_dim == output_dim, 1x1 conv otherwise.
        normalize: apply LayerNorm before each activation.
        AF: activation module. The same (stateless) instance is reused for
            both activations and, being a default argument, shared across
            blocks -- harmless for ELU/ReLU-style activations.

    Raises:
        ValueError: for an unsupported `resample` mode (the original code
        silently left conv_1/conv_2 undefined and crashed later in forward).
    """

    def __init__(self, input_dim, output_dim, kernel_size, hw,
                 resample=None, normalize=False, AF=nn.ELU()):
        super(ResidualBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.kernel_size = kernel_size
        self.resample = resample
        self.normalize = normalize
        self.bn1 = None
        self.bn2 = None
        self.relu1 = AF
        self.relu2 = AF
        if resample not in ('down', 'none', None):
            raise ValueError("unsupported resample mode: %r" % (resample,))
        # Both norms see pre-downsampling maps (conv_1 keeps the spatial
        # size; only conv_2 downsamples), hence [input_dim, hw, hw].
        self.bn1 = nn.LayerNorm([input_dim, hw, hw])
        self.bn2 = nn.LayerNorm([input_dim, hw, hw])
        if resample == 'down':
            self.conv_shortcut = MeanPoolConv(input_dim, output_dim, kernel_size=1)
            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size=kernel_size, bias=False)
            self.conv_2 = ConvMeanPool(input_dim, output_dim, kernel_size=kernel_size)
        else:
            # 'none' and None both keep the spatial size.
            # BUG FIX: resample=None previously defined neither conv_1 nor
            # conv_2, so forward() always raised AttributeError.
            self.conv_shortcut = MyConvo2d(input_dim, output_dim, kernel_size=1)
            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size=kernel_size, bias=False)
            self.conv_2 = MyConvo2d(input_dim, output_dim, kernel_size=kernel_size)

    def forward(self, input):
        # Identity shortcut only when shapes already match and no
        # resampling was requested; otherwise project with conv_shortcut.
        if self.input_dim == self.output_dim and self.resample is None:
            shortcut = input
        else:
            shortcut = self.conv_shortcut(input)
        # Pre-activation main path: (norm) -> act -> conv, twice.
        output = input
        if self.normalize:
            output = self.bn1(output)
        output = self.relu1(output)
        output = self.conv_1(output)
        if self.normalize:
            output = self.bn2(output)
        output = self.relu2(output)
        output = self.conv_2(output)
        return shortcut + output
class Res12_Quadratic(nn.Module):
    """Residual network with a quadratic scoring head.

    Downsamples the input three times (hw -> hw/8) while widening the
    channels (dim -> 8*dim), flattens, and scores each sample with
    ln1(f) * ln2(f) + lq(f ** 2), returning a 1-D tensor of scalars.
    `hw` is assumed to be divisible by 8.
    """

    def __init__(self, inchan, dim, hw, normalize=False, AF=None):
        super(Res12_Quadratic, self).__init__()
        self.hw = hw
        self.dim = dim
        self.inchan = inchan
        half, quarter, eighth = int(hw / 2), int(hw / 4), int(hw / 8)
        self.conv1 = MyConvo2d(inchan, dim, 3)
        # Three down/none residual stages: channels double at each 'down'.
        self.rb1 = ResidualBlock(dim, 2 * dim, 3, int(hw), resample='down', normalize=normalize, AF=AF)
        self.rbc1 = ResidualBlock(2 * dim, 2 * dim, 3, half, resample='none', normalize=normalize, AF=AF)
        self.rb2 = ResidualBlock(2 * dim, 4 * dim, 3, half, resample='down', normalize=normalize, AF=AF)
        self.rbc2 = ResidualBlock(4 * dim, 4 * dim, 3, quarter, resample='none', normalize=normalize, AF=AF)
        self.rb3 = ResidualBlock(4 * dim, 8 * dim, 3, quarter, resample='down', normalize=normalize, AF=AF)
        self.rbc3 = ResidualBlock(8 * dim, 8 * dim, 3, eighth, resample='none', normalize=normalize, AF=AF)
        feat_dim = eighth * eighth * 8 * dim
        # Quadratic head: two linear factors plus a linear map of the squares.
        self.ln1 = nn.Linear(feat_dim, 1)
        self.ln2 = nn.Linear(feat_dim, 1)
        self.lq = nn.Linear(feat_dim, 1)
        self.Square = Square()

    def forward(self, x_in):
        h = self.conv1(x_in)
        for stage in (self.rb1, self.rbc1, self.rb2, self.rbc2, self.rb3, self.rbc3):
            h = stage(h)
        flat = h.view(-1, int(self.hw / 8) * int(self.hw / 8) * 8 * self.dim)
        score = self.ln1(flat) * self.ln2(flat) + self.lq(self.Square(flat))
        return score.view(-1)