您的位置:首页 > 教育 > 锐评 > 算法面经手撕系列(2)--手撕BatchNormlization

算法面经手撕系列(2)--手撕BatchNormlization

2025/5/13 22:56:42 来源:https://blog.csdn.net/Dr_maker/article/details/142219408  浏览:    关键词:算法面经手撕系列(2)--手撕BatchNormlization

BatchNormlization

  BatchNormlization的编码流程:

  1. init阶段初始化 C i n C_in Cin大小的scale向量和shift向量,同时初始化相同大小的滑动均值向量和滑动标准差向量;
  2. forward时沿着非channel维度计算均值、有偏方差
  3. 依据得到均值和有偏方差进行归一化
  4. 对归一化的结果进行缩放和平移

代码

 代码如下:

class BN(nn.Module):def __init__(self,C_in):super(BN,self).__init__()self.scale=nn.Parameter(torch.ones(C_in).view(1,-1,1,1))self.shift=nn.Parameter(torch.zeros(C_in).view(1,-1,1,1))self.momentum=0.9self.register_buffer('running_mean',torch.zeros(C_in).view(1,-1,1,1))self.register_buffer('running_var',torch.zeros(C_in).view(1,-1,1,1))self.eps=1e-9def forward(self,x):if self.training:N,C,H,W=x.shapemean=x.mean(dim=[0,2,3],keepdim=True)var=x.var(dim=[0,2,3],keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)self.running_mean=self.momentum*self.running_mean+(1-self.momentum)*meanself.running_var=self.momentum*self.running_var+(1-self.momentum)*varelse:x=(x-self.running_mean)/torch.sqrt(self.running_var+self.eps)return xif __name__=="__main__":input=torch.rand(10,3,5,5)model=BN(3)res=model(input)print('cool')

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com