BatchNormlization
BatchNormlization的编码流程:
- init阶段初始化 C i n C_in Cin大小的scale向量和shift向量,同时初始化相同大小的滑动均值向量和滑动标准差向量;
- forward时沿着非channel维度计算均值、有偏方差
- 依据得到均值和有偏方差进行归一化
- 对归一化的结果进行缩放和平移
代码
代码如下:
class BN(nn.Module):def __init__(self,C_in):super(BN,self).__init__()self.scale=nn.Parameter(torch.ones(C_in).view(1,-1,1,1))self.shift=nn.Parameter(torch.zeros(C_in).view(1,-1,1,1))self.momentum=0.9self.register_buffer('running_mean',torch.zeros(C_in).view(1,-1,1,1))self.register_buffer('running_var',torch.zeros(C_in).view(1,-1,1,1))self.eps=1e-9def forward(self,x):if self.training:N,C,H,W=x.shapemean=x.mean(dim=[0,2,3],keepdim=True)var=x.var(dim=[0,2,3],keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)self.running_mean=self.momentum*self.running_mean+(1-self.momentum)*meanself.running_var=self.momentum*self.running_var+(1-self.momentum)*varelse:x=(x-self.running_mean)/torch.sqrt(self.running_var+self.eps)return xif __name__=="__main__":input=torch.rand(10,3,5,5)model=BN(3)res=model(input)print('cool')