模块出处
[TCSVT 24] [link] [code] DSNet: A Novel Way to Use Atrous Convolutions in Semantic Segmentation
模块名称
Multi-Scale Attention Fusion (MSAF)
模块作用
双级特征融合
模块结构
模块思想
MSAF的主要思想是让网络根据训练损失自行学习特征融合权重，从而使模型能够有选择地融合来自不同尺度的信息。
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


def _attention_branch(channels, inter_channels, pool=None):
    """Build one bottleneck attention branch used by MSAF.

    The branch is ``[pool] -> 1x1 conv (C -> hidden) -> BN -> ReLU ->
    1x1 conv (hidden -> C) -> BN``.  The optional ``pool`` layer is placed
    first so the ``nn.Sequential`` indices (and therefore ``state_dict`` keys)
    are identical to the original hand-written branches.

    Args:
        channels: number of input/output channels C.
        inter_channels: hidden (bottleneck) width.
        pool: optional pooling module applied before the convolutions.

    Returns:
        nn.Sequential: the assembled branch.
    """
    layers = [] if pool is None else [pool]
    layers += [
        nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(inter_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(channels),
    ]
    return nn.Sequential(*layers)


class MSAF(nn.Module):
    """Multi-Scale Attention Fusion (DSNet, TCSVT 2024).

    Fuses two feature maps ``x`` and ``residual`` of the same shape.  A fusion
    weight ``wei = sigmoid(local + global + ctx4x4 + ctx8x8)`` is computed from
    their sum, and the output is ``2 * x * wei + 2 * residual * (1 - wei)``.

    Args:
        channels: number of channels C of both inputs.
        r: channel reduction ratio of the attention bottleneck; the hidden
            width is ``max(1, channels // r)``.
    """

    def __init__(self, channels=64, r=4):
        super().__init__()
        # Clamp to at least one hidden channel: with channels < r the plain
        # floor division yields 0, i.e. a degenerate zero-width bottleneck.
        inter_channels = max(1, int(channels // r))
        # Branches are created in the original order so parameter
        # initialisation consumes the RNG identically.
        # Pixel-wise attention at full resolution.
        self.local_att = _attention_branch(channels, inter_channels)
        # Regional context on 4x4 and 8x8 pooled grids (upsampled in forward).
        self.context1 = _attention_branch(
            channels, inter_channels, nn.AdaptiveAvgPool2d((4, 4)))
        self.context2 = _attention_branch(
            channels, inter_channels, nn.AdaptiveAvgPool2d((8, 8)))
        # Image-level attention on a 1x1 pooled map (broadcast in forward).
        self.global_att = _attention_branch(
            channels, inter_channels, nn.AdaptiveAvgPool2d(1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual`` with learned multi-scale weights.

        Args:
            x: feature map; assumed (N, C, H, W) as required by the Conv2d
                layers, with C == ``channels``.
            residual: second feature map, same shape as ``x``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        h, w = x.shape[-2:]
        xa = x + residual
        xl = self.local_att(xa)
        # Bring the pooled context maps back to the input resolution.
        c1 = F.interpolate(self.context1(xa), size=(h, w), mode='nearest')
        c2 = F.interpolate(self.context2(xa), size=(h, w), mode='nearest')
        # (N, C, 1, 1) result broadcasts over H and W in the sum below.
        xg = self.global_att(xa)
        wei = self.sigmoid(xl + xg + c1 + c2)
        # The factor 2 makes wei == 0.5 reduce to plain addition x + residual.
        return 2 * x * wei + 2 * residual * (1 - wei)


if __name__ == '__main__':
    msaf = MSAF()
    x1 = torch.randn([2, 64, 16, 16])
    x2 = torch.randn([2, 64, 16, 16])
    out = msaf(x1, x2)
    print(out.shape)  # torch.Size([2, 64, 16, 16])