# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力
# 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要
# 通道注意力 一般都要比较多的通道加注意力
# SENet
# 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个通道乘一个权重 就分配了各个通道的注意力 把这个与原图残差回去 与原图融合 这样对比原图来说 形状 CHW都没变
# 注意力机制 可以即插即用 CHW都没变
import torch
import os
import torch.nn as nn
from torchvision.models import resnet18,ResNet18_Weights
from torchvision.models.resnet import _resnet,BasicBlock
path=os.path.dirname(__file__)
onnxpath=os.path.join(path,"assets/resnet_SE-Identity.onnx")
onnxpath=os.path.relpath(onnxpath)
class SENet1(nn.Module):
def __init__(self,inchannel,r=16):
super().__init__()
# 全局平均池化 把所以通道 整个通道进行平均池化
self.inchannel=inchannel
self.pool1=nn.AdaptiveAvgPool2d(1)
# 对全局平均池化后的结果 赋予每个通道的权重 不选择最大池化因为不是在突出最大的特征
# 这里不是直接一个全连接生成 权重 而是用两个全连接来生成 权重 第一个relu激活 第二个Sigmoid 为每一个通道生成一个0-1的权重
# 第一个全连接输出的通道数数量要缩小一下,不能直接传入多少就输出多少,不然参数量太多,第二个通道再输出回去就行
# 缩放因子
self.fc1=nn.Sequential(nn.Linear(self.inchannel,self.inchannel//r),nn.ReLU())
self.fc2=nn.Sequential(nn.Linear(self.inchannel//r,self.inchannel),nn.Sigmoid())
# fc1 用relu会信息丢失 保证inchannel//r 至少要32
# 用两层全连接可以增加注意力层的健壮性
def forward(self,x):
x1=self.pool1(x)
x1=x1.view(x1.shape[0],-1)
x1=self.fc1(x1)
x1=self.fc2(x1)
# 得到了每一个通道的权重
x1=x1.unsqueeze(2).unsqueeze(3)
# 与原来的相乘
return x*x1
def demo1():
torch.manual_seed(666)
img1=torch.rand(1,128,224,224)
senet1=SENet1(img1.shape[1],2)
res=senet1.forward(img1)
print(res.shape)
# 可以把SE模块加入到经典的CNN模型里面 有残差模块的在残差模块后面加入SE 残差模块的输出 当SE模块的输入
# 在卷积后的数据与原数据相加之前 把卷积的数据和 依靠卷积后的数据产生的SE模块的数据 相乘 然后再与原数据相加
# 这个要看源码 进行操作
# 也可以不在 残差后面 进行 有很多种插入SE的方式
# 要找到 网络的残差模块
def demo2():
# 把SE模块加入到ResNet18
# 继承一个BasicBlock类 对resnet18的残差模块进行一些重写
class BasicBlock_SE(BasicBlock):
def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1, base_width = 64, dilation = 1, norm_layer = None):
super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
self.se=SENet1(inplanes)# SE-Identity 加法 在 数据传进来的时候备份两份数据 一份卷积 一份加注意力SE模块 然后两个结果相加输出
def forward(self, x):
identity = x
identity=self.se(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(identity)
out += identity
out = self.relu(out)
return out
# self.se=SENet1(planes)# SE-POST 加法 在 残差模块彻底完成了后加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# out=self.se(out)
# return out
# self.se=SENet1(inplanes)# SE-PRE 加法 在 残差模块卷积之前加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out=self.se(x)
# out = self.conv1(out)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# return out
# self.se=SENet1(planes)# Standard_SE 加法 在 残差模块卷积h后加注意力SE模块 然后与原数据项加结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out=self.se(out)
# out += identity
# out = self.relu(out)
# return out
def resnet18_SE(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock_SE, [2, 2, 2, 2], weights, progress, **kwargs)
model1=resnet18_SE()
x = torch.randn(1, 3, 224, 224)
# 导出onnx
torch.onnx.export(
model1,
x,
onnxpath,
verbose=True, # 输出转换过程
input_names=["input"],
output_names=["output"],
)
print("onnx导出成功")
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# 改模型不仅需要 加 一个网络结构 而且也需要注意前向传播 有没有问题
def demo3(): # 在resnet18中的后期 层里面加 SE 前期层不加
class ResNet_SE_laye(ResNet):
def __init__(self, block, layers, num_classes = 1000, zero_init_residual = False, groups = 1, width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
def _layer_update_SE(self):
self.se=SENet1(self.layer3[1].conv2.out_channels,8)
self.layer3[1].conv2=nn.Sequential(self.layer3[1].conv2,self.se)
print(self.layer3)
pass
return self.layer3
def _resnet_SE_layer(
block,
layers,
weights,
progress: bool,
**kwargs,
):
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ResNet_SE_laye(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
def resnet18_SE_layer(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet_SE_layer(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
model=resnet18_SE_layer()
# print(model)
layer=model._layer_update_SE()
torch.onnx.export(layer,torch.rand(1,128,224,224),"layer.onnx")
  
pass
  
if __name__=="__main__":
# demo1()
# demo2()
pass
