1. Model Initialization Module
Reference: project code
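1.1 Parameter counting function
```python
def count_parameters(model):
    # Sum the element counts of every tensor that receives gradient updates.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```
Technical notes:
- `numel()` returns the total number of elements in a tensor.
- `requires_grad` filters for the parameters that are updated by gradients, so frozen parameters are excluded.
- The total is a rough measure of model complexity; a typical Transformer-base has about 65M parameters.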
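As a quick check of what the counter reports, the sketch below applies the same expression to a tiny stand-in model. The two-layer `nn.Sequential` is an assumption made for illustration and is not the project's Transformer. It also shows that freezing a layer removes it from the total.

```python
from torch import nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Hypothetical toy model: Linear(10, 5) has 10*5 + 5 = 55 parameters,
# Linear(5, 2) has 5*2 + 2 = 12.
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
total = count_parameters(toy)

# Freeze the first layer; its 55 parameters no longer count as trainable.
for p in toy[0].parameters():
    p.requires_grad = False
trainable_after_freeze = count_parameters(toy)

print(total, trainable_after_freeze)
```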
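1.2 Weight initialization
```python
from torch import nn

def initialize_weights(m):
    # Only modules with a weight of 2 or more dimensions are re-initialized;
    # 1-D parameters such as biases and LayerNorm weights are skipped.
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)
```
How the initialization works:
- Kaiming initialization is designed for the ReLU family of activation functions.
- It keeps the variance of the activations consistent from layer to layer during the forward pass.
- Formula (with PyTorch's default gain of $\sqrt{2}$ and fan-in mode): $W \sim U(-\sqrt{6/n_{in}}, \sqrt{6/n_{in}})$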
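To make the bound concrete, here is a small arithmetic sketch in plain Python. The value `fan_in = 512` is chosen only as an example, and `kaiming_uniform_bound` is a hypothetical helper, not a PyTorch function. A uniform distribution on $[-b, b]$ has variance $b^2/3$, so $b = \sqrt{6/n_{in}}$ gives a weight variance of $2/n_{in}$.

```python
import math

def kaiming_uniform_bound(fan_in, gain=math.sqrt(2.0)):
    # bound = gain * sqrt(3 / fan_in); with gain = sqrt(2) this equals sqrt(6 / fan_in)
    return gain * math.sqrt(3.0 / fan_in)

fan_in = 512  # example value
bound = kaiming_uniform_bound(fan_in)
variance = bound ** 2 / 3.0  # variance of U(-bound, bound)

print(f"bound={bound:.4f}, variance={variance:.6f}, 2/fan_in={2.0 / fan_in:.6f}")
```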
1.3 Model instantiation
```python
model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_heads,
                    n_layers=n_layers,
                    drop_prob=drop_prob,
                    device=device).to(device)
```
Key parameters:
Parameter | Typical value | Purpose |
---|---|---|
d_model | 512 | Dimension of the vector representations |
n_head | 8 | Number of attention heads |
ffn_hidden | 2048 | Hidden dimension of the feed-forward network |
n_layers | 6 | Number of stacked encoder/decoder layers |
drop_prob | 0.1 | Dropout probability |
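The typical values in the table can be collected into one configuration object. The sketch below is one possible way to organize them; the `TransformerConfig` class is hypothetical and not part of the project. It also checks that d_model is divisible by n_head, since standard multi-head attention gives each head a slice of size d_model / n_head.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TransformerConfig:  # hypothetical helper, not in the project
    d_model: int = 512
    n_head: int = 8
    ffn_hidden: int = 2048
    n_layers: int = 6
    drop_prob: float = 0.1

    def __post_init__(self):
        # Each head receives an equal share of the model dimension.
        if self.d_model % self.n_head != 0:
            raise ValueError("d_model must be divisible by n_head")

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_head

cfg = TransformerConfig()
print(cfg.head_dim)
```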
2. Training Preparation Module
2.1 Optimizer configuration
```python
optimizer = Adam(params=model.parameters(),
                 lr=init_lr,
                 weight_decay=weight_decay,
                 eps=adam_eps)
```
The Adam update rule:
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\hat{m}_t$$
where $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected estimates of the first and second moments of the gradient.
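For reference, one Adam step on a single scalar parameter can be written out by hand. This is a minimal sketch of the textbook update with the usual moment recursions. The values $\beta_1 = 0.9$ and $\beta_2 = 0.999$ are the common defaults and are assumed here; the sketch ignores weight_decay and is not PyTorch's implementation.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adam_step(theta, grad=0.5, m=m, v=v, t=1, lr=0.1)
print(theta)
```

On the first step the bias correction gives $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so the parameter moves by roughly lr regardless of the gradient's scale.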
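2.2 Learning rate scheduler
```python
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,  # deprecated in recent PyTorch releases
                                                 factor=factor,
                                                 patience=patience)
```
Scheduling strategy:
- The scheduler monitors the validation loss.
- When the loss plateaus, the learning rate is multiplied by factor (typically 0.5).
- patience=5 means that 5 consecutive epochs without improvement are tolerated; the decay is triggered by the next epoch that still shows no improvement.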
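The patience behavior can be observed directly on a stand-in model. In this sketch the `nn.Linear` module, the constant loss of 1.0, and the values factor=0.5 and patience=2 are all assumptions chosen to keep the run short.

```python
from torch import nn, optim

model = nn.Linear(2, 1)  # stand-in model; only its parameters matter here
optimizer = optim.Adam(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

lrs = []
# A validation loss that never improves after the first epoch.
for epoch in range(4):
    scheduler.step(1.0)
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call records the best loss, the next two calls are tolerated by patience=2, and the fourth call halves the learning rate.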
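2.3 Loss function
```python
criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
```
How padding is handled:
- ignore_index excludes padding positions from the loss, so they contribute no gradient.
- The loss is computed against target tokens, so passing src_pad_idx is only correct when the source and target vocabularies share the same padding index; otherwise trg_pad_idx is the value to use.
- The loss expression becomes:
$$\mathcal{L} = -\sum_{i=1}^{n} y_i \log p_i \cdot \mathbb{I}(y_i \neq \text{pad})$$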
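The masking can be verified on a few made-up logits. In the sketch below, `pad_idx = 0` and the tensor shapes are assumptions for the example. With the default reduction='mean', the result equals the per-token loss averaged over the non-padding positions only.

```python
import torch
from torch import nn

pad_idx = 0  # assumed padding index for this example
logits = torch.randn(4, 5)             # 4 token positions, vocabulary of 5
targets = torch.tensor([2, 0, 3, 0])   # positions 1 and 3 are padding

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss = criterion(logits, targets)

# Manual check: average the per-token loss over non-padding positions only.
per_token = nn.functional.cross_entropy(logits, targets, reduction='none')
mask = targets != pad_idx
manual = per_token[mask].mean()

print(loss.item(), manual.item())
```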
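3. Training and Evaluation Module
3.1 Training function
```python
import torch

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        # Decoder input: the target without its last position.
        output = model(src, trg[:, :-1])
        # Flatten to (batch * seq_len, vocab_size) for the loss.
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        # Labels: the target shifted left by one, which drops the start token.
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg)
        loss.backward()
        # Rescale gradients so that their global norm does not exceed clip.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)
```
Key points:
- Teacher forcing: the ground-truth target sequence is fed to the decoder as input.
- Sequence slicing: trg[:, :-1] drops the last position (the end token for sequences that fill the batch length, padding for shorter ones), and trg[:, 1:] drops the start token, so each position is trained to predict the next token.
- Gradient clipping guards against exploding gradients.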
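The two slices are easiest to see on a concrete batch. The token ids below are hypothetical (1 for the start token, 2 for the end token, 0 for padding) and are not taken from the project's vocabulary.

```python
import torch

# Hypothetical token ids: 1 = <sos>, 2 = <eos>, 0 = <pad>
trg = torch.tensor([[1, 7, 8, 9, 2],
                    [1, 5, 6, 2, 0]])

decoder_input = trg[:, :-1]   # what the decoder sees
labels = trg[:, 1:]           # what it must predict at each position

print(decoder_input.tolist())
print(labels.tolist())
```

In the second row the dropped last position is padding rather than the end token, and the trailing padding label is the kind of position that ignore_index removes from the loss.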
3.2 Evaluation function
```python
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            # Teacher-forced forward pass: the decoder sees the ground-truth prefix.
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            # batch_size is a module-level name, not an argument; a smaller
            # final batch raises IndexError, which the except below skips.
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    # Greedy choice: index of the highest-scoring token per position.
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except Exception:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu
```
How BLEU is computed:
$$BLEU = BP \cdot \exp\left(\sum_{n=1}^N w_n \log p_n\right)$$
where:
- BP is the brevity penalty.
- $p_n$ is the n-gram precision.
- $w_n$ is the weight of each n-gram order (usually uniform).
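The formula can be turned into a short single-reference implementation. This is a minimal sketch under stated assumptions: uniform weights $w_n = 1/N$ with $N = 4$, clipped n-gram counts, and no smoothing. The function names are hypothetical, and it is not the project's `get_bleu`.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def sentence_bleu(hypothesis, reference, max_n=4):
    """Single-reference BLEU with uniform weights w_n = 1 / max_n."""
    if not hypothesis:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        hyp_counts = ngrams(hypothesis, n)
        ref_counts = ngrams(reference, n)
        total = sum(hyp_counts.values())
        # Clipped counts: an n-gram is credited at most as often as it occurs in the reference.
        matched = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        if total == 0 or matched == 0:
            return 0.0  # no smoothing in this sketch
        log_precisions.append(math.log(matched / total))
    # Brevity penalty: 1 if the hypothesis is longer than the reference, else exp(1 - r/c).
    c, r = len(hypothesis), len(reference)
    bp = 1.0 if c > r else math.exp(1 - r / c)
    return bp * math.exp(sum(log_precisions) / max_n)

reference = "the cat sat on the mat".split()
perfect = sentence_bleu(reference, reference)
truncated = sentence_bleu("the cat sat on".split(), reference)
print(perfect, truncated)
```

The truncated hypothesis has perfect n-gram precision at every order, so its score is determined entirely by the brevity penalty $\exp(1 - 6/4)$.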
4. Run Control Module
4.1 Training loop
```python
def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        # Leave the learning rate untouched during the warm-up epochs.
        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Checkpoint only when the validation loss improves.
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        # Rewrite the full metric history after every epoch.
        with open('result/train_loss.txt', 'w') as f:
            f.write(str(train_losses))
        with open('result/bleu.txt', 'w') as f:
            f.write(str(bleus))
        with open('result/test_loss.txt', 'w') as f:
            f.write(str(test_losses))

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} | Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')
```
Model saving strategy:
- The validation loss is the criterion for saving a checkpoint.
- `model.state_dict()` saves a snapshot of the parameters.
- The file name includes the validation loss, which makes versions easy to tell apart.
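A checkpoint saved this way holds only tensors, so restoring it means rebuilding the same architecture and loading the state into it. The round trip below is a sketch on a stand-in `nn.Linear` module with a made-up loss value and a temporary directory, all of which are assumptions for the example.

```python
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the trained Transformer
valid_loss = 1.2345      # example value

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model-{0}.pt'.format(valid_loss))
    torch.save(model.state_dict(), path)

    # Restoring requires building the same architecture first.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path))

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                               restored.state_dict().values()))
print(same)
```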
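5. Engineering Practice Notes
5.1 Training tips
- Warm-up: learning rate decay is not activated during the first warmup epochs (the scheduler is stepped only once step > warmup).
- Mixed precision training: `torch.cuda.amp` can be used to speed up training.
- Gradient accumulation: gradients from several small batches are accumulated to emulate the effect of a larger batch.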
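Gradient accumulation can be checked numerically on a toy regression problem. In this sketch the `nn.Linear` model, the random data, and the choice of 4 micro-batches are assumptions. Dividing each micro-batch loss by the number of accumulation steps makes the summed gradients match a single pass over the full batch, provided the micro-batches are of equal size and the loss is a mean.

```python
import torch
from torch import nn

torch.manual_seed(0)
x, y = torch.randn(8, 3), torch.randn(8, 1)
criterion = nn.MSELoss()

# Reference: one backward pass over the full batch of 8.
full = nn.Linear(3, 1)
criterion(full(x), y).backward()

# Accumulation: 4 micro-batches of 2, each loss divided by the number of steps.
accum = nn.Linear(3, 1)
accum.load_state_dict(full.state_dict())
accum_steps = 4
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (criterion(accum(xb), yb) / accum_steps).backward()  # gradients add up in .grad

match = torch.allclose(full.weight.grad, accum.weight.grad, atol=1e-6)
print(match)
```

In a training loop, optimizer.step() and optimizer.zero_grad() would then be called once per accum_steps micro-batches rather than after every batch.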
5.2 Performance optimization
```python
torch.backends.cudnn.benchmark = True  # enable the cuDNN autotuner
torch.autograd.set_detect_anomaly(False)  # keep anomaly detection off for speed
```
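5.3 Extensions
Example of a multi-GPU adaptation. Note that `nn.DataParallel` implements data parallelism (the module is replicated on each GPU and the batch is split), not model parallelism.
```python
from torch import nn
from models.model.transformer import Transformer

class ParallelTransformer(Transformer):
    def __init__(self, *args, **kwargs):
        # Forward all constructor arguments to Transformer unchanged.
        super().__init__(*args, **kwargs)
        self.encoder = nn.DataParallel(self.encoder)
        self.decoder = nn.DataParallel(self.decoder)
```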
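This section has analyzed the code from several angles, from the implementation through to the underlying theory. It keeps the original code structure intact while using flowcharts to separate how each module operates. In practice, the hyperparameters can be adjusted to the scale of the task; for large-scale pretraining, an environment with 8 V100 GPUs combined with mixed precision training is recommended to improve training efficiency.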
Source code (appendix):
```python
"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
import math
import time

from torch import nn, optim
from torch.optim import Adam

from data import *
from models.model.transformer import Transformer
from util.bleu import idx_to_word, get_bleu
from util.epoch_timer import epoch_time


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        # In-place variant; the name without the trailing underscore is deprecated.
        nn.init.kaiming_uniform_(m.weight.data)


model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_heads,
                    n_layers=n_layers,
                    drop_prob=drop_prob,
                    device=device).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')
model.apply(initialize_weights)
optimizer = Adam(params=model.parameters(),
                 lr=init_lr,
                 weight_decay=weight_decay,
                 eps=adam_eps)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=factor,
                                                 patience=patience)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)


def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu


def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        f = open('result/train_loss.txt', 'w')
        f.write(str(train_losses))
        f.close()

        f = open('result/bleu.txt', 'w')
        f.write(str(bleus))
        f.close()

        f = open('result/test_loss.txt', 'w')
        f.write(str(test_losses))
        f.close()

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} | Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')


if __name__ == '__main__':
    run(total_epoch=epoch, best_loss=inf)
```