您的位置:首页 > 汽车 > 新车 > 品牌vi设计机构_舆情优化_新闻网站排行榜_自助发外链网站

品牌vi设计机构_舆情优化_新闻网站排行榜_自助发外链网站

2025/6/19 4:06:10 来源:https://blog.csdn.net/wei582636312/article/details/147248687  浏览:    关键词:品牌vi设计机构_舆情优化_新闻网站排行榜_自助发外链网站
品牌vi设计机构_舆情优化_新闻网站排行榜_自助发外链网站

在这里插入图片描述

文章目录

  • 1、Feature Fusion Module
  • 2、代码实现

paper:A dual encoder crack segmentation network with Haar wavelet-based high–low frequency attention

Code:https://github.com/zZhiG/DECS-Net


1、Feature Fusion Module

CNN 和 Transformer 分别擅长于提取局部特征和全局上下文信息,但单独使用存在局限性。CNN 捕获的局部特征细节丰富,但难以建模长距离依赖关系。而Transformer 善于建模长距离依赖关系,但容易受到复杂背景的干扰。所以为了更好融合两部分的特征,这篇论文提出一种 特征融合模块(Feature Fusion Module)。其目的是有效地融合 CNN 编码器和 Transformer 编码器提取的中间特征,以便更好地进行裂纹分割。

实现过程:

  1. 首先使用 1x1 卷积将 CNN 编码器特征和 Transformer 编码器特征调整到相同的维度。
  2. 通道注意力 (CA):对调整后的特征分别进行通道注意力操作。具体来说,对来自 CNN 编码器和 Transformer 编码器的特征进行通道维度上的权重调整,减少冗余信息,并充分利用互补特征。
  3. 跨域融合块 (CFB):对新的特征分别进行跨域融合块操作,得到交叉融合特征。具体来说,通过多头自注意力机制实现不同域特征之间的交叉融合,增强特征交互,并提取更丰富的语义信息。
  4. 相关增强 (CE):对交叉融合特征进行相关增强操作,得到相关增强特征。具体来说,通过矩阵乘法操作建模 CNN 和 Transformer 编码器特征之间的跨域相关性,增强重要信息,抑制无关信息。
  5. 特征融合块 (FFB):将通道注意力特征、交叉融合特征、相关增强特征进行拼接,得到最终的融合特征。
  6. 将融合特征输入特征融合块进行进一步处理,得到最终输出特征。

优势:

  • 能够有效地融合 CNN 和 Transformer 的优势,提高裂纹分割的精度和鲁棒性。
  • 通过通道注意力和跨域融合,减少冗余信息,增强特征交互。
  • 能够更好地适应复杂背景和噪声干扰。

Feature Fusion Module 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
from einops.einops import rearrangeclass DSC(nn.Module):def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):super(DSC, self).__init__()self.c_in = c_inself.c_out = c_outself.dw = nn.Conv2d(c_in, c_in, k_size, stride, padding, groups=c_in)self.pw = nn.Conv2d(c_in, c_out, 1, 1)def forward(self, x):out = self.dw(x)out = self.pw(out)return outclass IDSC(nn.Module):def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):super(IDSC, self).__init__()self.c_in = c_inself.c_out = c_outself.dw = nn.Conv2d(c_out, c_out, k_size, stride, padding, groups=c_out)self.pw = nn.Conv2d(c_in, c_out, 1, 1)def forward(self, x):out = self.pw(x)out = self.dw(out)return outclass FFM(nn.Module):def __init__(self, dim1, dim2):super().__init__()self.trans_c = nn.Conv2d(dim1, dim2, 1)self.avg = nn.AdaptiveAvgPool2d(1)self.li1 = nn.Linear(dim2, dim2)self.li2 = nn.Linear(dim2, dim2)self.qx = DSC(dim2, dim2)self.kx = DSC(dim2, dim2)self.vx = DSC(dim2, dim2)self.projx = DSC(dim2, dim2)self.qy = DSC(dim2, dim2)self.ky = DSC(dim2, dim2)self.vy = DSC(dim2, dim2)self.projy = DSC(dim2, dim2)self.concat = nn.Conv2d(dim2 * 2, dim2, 1)self.fusion = nn.Sequential(IDSC(dim2 * 4, dim2),nn.BatchNorm2d(dim2),nn.GELU(),DSC(dim2, dim2),nn.BatchNorm2d(dim2),nn.GELU(),nn.Conv2d(dim2, dim2, 1),nn.BatchNorm2d(dim2),nn.GELU())def forward(self, x, y):b, c, h, w = x.shapeB, N, C = y.shapeH = W = int(N ** 0.5)x = self.trans_c(x)y = y.reshape(B, H, W, C).permute(0, 3, 1, 2)avg_x = self.avg(x).permute(0, 2, 3, 1)avg_y = self.avg(y).permute(0, 2, 3, 1)x_weight = self.li1(avg_x)y_weight = self.li2(avg_y)x = x.permute(0, 2, 3, 1) * x_weighty = y.permute(0, 2, 3, 1) * y_weightout1 = x * yout1 = out1.permute(0, 3, 1, 2)x = x.permute(0, 3, 1, 2)y = y.permute(0, 3, 1, 2)qy = self.qy(y).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)kx = self.kx(x).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)vx = self.vx(x).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)attnx = (qy @ kx.transpose(-2, -1)) * (C ** -0.5)attnx = attnx.softmax(dim=-1)attnx = (attnx @ vx).transpose(2, 3).reshape(B, H // 4, w // 4, 4, 4, C)attnx = attnx.transpose(2, 3).reshape(B, H, W, C).permute(0, 3, 1, 2)attnx = self.projx(attnx)qx = self.qx(x).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)ky = self.ky(y).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)vy = self.vy(y).reshape(B, 8, C // 8, H // 4, 4, W // 4, 4).permute(0, 3, 5, 1, 4, 6, 2).reshape(B, N // 16, 8,16, C // 8)attny = (qx @ ky.transpose(-2, -1)) * (C ** -0.5)attny = attny.softmax(dim=-1)attny = (attny @ vy).transpose(2, 3).reshape(B, H // 4, w // 4, 4, 4, C)attny = attny.transpose(2, 3).reshape(B, H, W, C).permute(0, 3, 1, 2)attny = self.projy(attny)out2 = torch.cat([attnx, attny], dim=1)out2 = self.concat(out2)out = torch.cat([x, y, out1, out2], dim=1)out = self.fusion(out)return outif __name__ == '__main__':x = torch.randn(4, 64, 128, 128).cuda()y = torch.randn(4, 64, 128, 128).cuda()y = rearrange(y, 'b c h w -> b (h w) c')model = FFM(64, 64).cuda()out = model(x,y)print(out.shape)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com