# -*-Encoding: utf-8 -*-
"""
Authors:
Li,Yan (liyan22021121@gmail.com)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from collections import OrderedDict
BN_EPS = 1e-5
SYNC_BN = False
OPS = OrderedDict([
    ('res_elu', lambda Cin, Cout, stride: ELUConv(Cin, Cout, 3, stride, 1)),
    ('res_bnelu', lambda Cin, Cout, stride: BNELUConv(Cin, Cout, 3, stride, 1)),
    ('res_bnswish', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 1)),
    ('res_bnswish5', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 2, 2)),
    ('mconv_e6k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=5, g=1)),
    ('mconv_e3k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=1)),
    ('mconv_e3k5g8', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=8)),
    ('mconv_e6k11g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=11, g=0)),
])
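
# Illustrative usage (not part of the original module): every OPS entry is a
# factory that takes (Cin, Cout, stride) and returns an nn.Module, e.g.
#
#     cell = OPS['res_bnswish'](32, 64, 2)   # BN + Swish + 3x3 conv, downsampling by 2
#     up   = OPS['res_bnswish'](64, 32, -1)  # stride -1 requests nearest-neighbour upsampling
#
# so network architectures can be assembled from op names alone.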


class SyncBatchNormSwish(_BatchNorm):
    """BatchNorm followed by a Swish activation.

    This single-process version falls back to standard (non-synchronized) batch
    norm; ``process_group`` is stored only for interface compatibility.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNormSwish, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        self.process_group = process_group
        self.ddp_gpu_size = None

    def forward(self, input):
        exponential_average_factor = self.momentum
        out = F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        # apply the Swish activation so the layer matches its name (BN + Swish)
        return out * torch.sigmoid(out)


def get_skip_connection(C, stride, channel_mult):
    if stride == 1:
        return Identity()
    elif stride == 2:
        return FactorizedReduce(C, int(channel_mult * C))
    elif stride == -1:
        return nn.Sequential(UpSample(), Conv2D(C, int(C / channel_mult), kernel_size=1))


def norm(t, dim):
    return torch.sqrt(torch.sum(t * t, dim))


def logit(t):
    return torch.log(t) - torch.log(1 - t)


def act(t):
    # memory-efficient Swish: the custom autograd function below saves only the input
    return SwishFN.apply(t)


class SwishFN(torch.autograd.Function):
    """Memory-efficient Swish: stores only the input for the backward pass."""

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return act(x)


def normalize_weight_jit(log_weight_norm, weight):
    n = torch.exp(log_weight_norm)
    wn = torch.sqrt(torch.sum(weight * weight, dim=[1, 2, 3]))  # norm(w) per output channel
    weight = n * weight / (wn.view(-1, 1, 1, 1) + 1e-5)
    return weight


class Conv2D(nn.Conv2d):
    """2D convolution with optional weight normalization.

    When ``weight_norm`` is True, the effective weight is re-parameterized as
    ``exp(log_weight_norm) * w / ||w||`` (per output channel) and recomputed on
    every forward call. ``data_init`` is stored for compatibility, but
    data-dependent initialization is not performed in this simplified version.
    """

    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False,
                 data_init=False, weight_norm=True):
        super(Conv2D, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias)
        self.log_weight_norm = None
        if weight_norm:
            init = norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1)
            self.log_weight_norm = nn.Parameter(torch.log(init + 1e-2), requires_grad=True)
        self.data_init = data_init
        self.init_done = False
        self.weight_normalized = self.normalize_weight()

    def forward(self, x):
        # re-normalize the weight on every call so gradients flow through log_weight_norm
        self.weight_normalized = self.normalize_weight()
        bias = self.bias
        return F.conv2d(x, self.weight_normalized, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def normalize_weight(self):
        """Applies weight normalization."""
        if self.log_weight_norm is not None:
            weight = normalize_weight_jit(self.log_weight_norm, self.weight)
        else:
            weight = self.weight
        return weight
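
# Illustrative sketch (not in the original code): with weight_norm=True, Conv2D
# learns each filter's scale through log_weight_norm while the direction comes
# from the underlying weight tensor, e.g. (hypothetical shapes)
#
#     conv = Conv2D(16, 32, kernel_size=3, padding=1)
#     y = conv(torch.randn(4, 16, 8, 8))   # -> (4, 32, 8, 8)
#     # conv.weight_normalized == exp(conv.log_weight_norm) * w / (||w|| + 1e-5)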


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class SyncBatchNorm(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SyncBatchNorm, self).__init__()
        # single-GPU fallback: wrap a regular BatchNorm2d
        self.bn = nn.BatchNorm2d(*args, **kwargs)

    def forward(self, x):
        return self.bn(x)


# quick switch between multi-gpu and single-gpu batch norm
def get_batchnorm(*args, **kwargs):
    return nn.BatchNorm2d(*args, **kwargs)


class ELUConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
        super(ELUConv, self).__init__()
        self.upsample = stride == -1
        stride = abs(stride)
        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation,
                             data_init=True)

    def forward(self, x):
        out = F.elu(x)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv_0(out)
        return out


class BNELUConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
        super(BNELUConv, self).__init__()
        self.upsample = stride == -1
        stride = abs(stride)
        self.bn = get_batchnorm(C_in, eps=BN_EPS, momentum=0.05)
        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)

    def forward(self, x):
        x = self.bn(x)
        out = F.elu(x)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv_0(out)
        return out


class BNSwishConv(nn.Module):
    """BN + Swish + Conv2d."""

    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
        super(BNSwishConv, self).__init__()
        self.upsample = stride == -1
        stride = abs(stride)
        self.bn_act = SyncBatchNormSwish(C_in, eps=BN_EPS, momentum=0.05)
        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): of size (B, C_in, H, W)
        """
        out = self.bn_act(x)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv_0(out)
        return out


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.conv_1 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
        self.conv_2 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
        self.conv_3 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
        self.conv_4 = Conv2D(C_in, C_out - 3 * (C_out // 4), 1, stride=2, padding=0, bias=True)

    def forward(self, x):
        out = act(x)
        conv1 = self.conv_1(out)
        conv2 = self.conv_2(out[:, :, 1:, :])
        conv3 = self.conv_3(out)
        conv4 = self.conv_4(out[:, :, 1:, :])
        out = torch.cat([conv1, conv2, conv3, conv4], dim=1)
        return out


class UpSample(nn.Module):
    def __init__(self):
        super(UpSample, self).__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)


class EncCombinerCell(nn.Module):
    def __init__(self, Cin1, Cin2, Cout, cell_type):
        super(EncCombinerCell, self).__init__()
        self.cell_type = cell_type
        # Cin = Cin1 + Cin2
        self.conv = Conv2D(Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x1, x2):
        x2 = self.conv(x2)
        out = x1 + x2
        return out


# original combiner
class DecCombinerCell(nn.Module):
    def __init__(self, Cin1, Cin2, Cout, cell_type):
        super(DecCombinerCell, self).__init__()
        self.cell_type = cell_type
        self.conv = Conv2D(Cin1 + Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x1, x2):
        out = torch.cat([x1, x2], dim=1)
        out = self.conv(out)
        return out


class ConvBNSwish(nn.Module):
    def __init__(self, Cin, Cout, k=3, stride=1, groups=1, dilation=1):
        padding = dilation * (k - 1) // 2
        super(ConvBNSwish, self).__init__()
        self.conv = nn.Sequential(
            Conv2D(Cin, Cout, k, stride, padding, groups=groups, bias=False, dilation=dilation, weight_norm=False),
            SyncBatchNormSwish(Cout, eps=BN_EPS, momentum=0.05)  # drop-in replacement for BN + Swish
        )

    def forward(self, x):
        return self.conv(x)


class SE(nn.Module):
    """Squeeze-and-Excitation block: channel-wise gating from global average pooling."""

    def __init__(self, Cin, Cout):
        super(SE, self).__init__()
        num_hidden = max(Cout // 16, 4)
        self.se = nn.Sequential(nn.Linear(Cin, num_hidden), nn.ReLU(inplace=True),
                                nn.Linear(num_hidden, Cout), nn.Sigmoid())

    def forward(self, x):
        se = torch.mean(x, dim=[2, 3])
        se = se.view(se.size(0), -1)
        se = self.se(se)
        se = se.view(se.size(0), -1, 1, 1)
        return x * se


class InvertedResidual(nn.Module):
    def __init__(self, Cin, Cout, stride, ex, dil, k, g):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2, -1]
        hidden_dim = int(round(Cin * ex))
        self.use_res_connect = self.stride == 1 and Cin == Cout
        self.upsample = self.stride == -1
        self.stride = abs(self.stride)
        # g == 0 means a depthwise convolution in the expanded space
        groups = hidden_dim if g == 0 else g

        layers0 = [nn.UpsamplingNearest2d(scale_factor=2)] if self.upsample else []
        layers = [get_batchnorm(Cin, eps=BN_EPS, momentum=0.05),
                  ConvBNSwish(Cin, hidden_dim, k=1),
                  ConvBNSwish(hidden_dim, hidden_dim, stride=self.stride, groups=groups, k=k, dilation=dil),
                  Conv2D(hidden_dim, Cout, 1, 1, 0, bias=False, weight_norm=False),
                  get_batchnorm(Cout, momentum=0.05)]
        layers0.extend(layers)
        self.conv = nn.Sequential(*layers0)

    def forward(self, x):
        return self.conv(x)
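

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original module).
    # It exercises a few of the building blocks above with made-up shapes.
    x = torch.randn(2, 8, 16, 16)

    cell = OPS['res_bnswish'](8, 8, 1)                  # BN + Swish + 3x3 conv, stride 1
    down = get_skip_connection(8, 2, channel_mult=2)    # FactorizedReduce: 8 -> 16 channels, H/2 x W/2
    up = get_skip_connection(8, -1, channel_mult=2)     # UpSample + 1x1 conv: 8 -> 4 channels, 2H x 2W

    print(cell(x).shape)   # torch.Size([2, 8, 16, 16])
    print(down(x).shape)   # torch.Size([2, 16, 8, 8])
    print(up(x).shape)     # torch.Size([2, 4, 32, 32])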