Simplified code example
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class MyModel(nn.Module):
    # Minimal placeholder model; substitute the real network here
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 1)

    def forward(self, x):
        return self.net(x)

def setup(rank, world_size):
    # Set up the distributed environment; all processes must agree on the
    # rendezvous address and port before init_process_group is called
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    # Tear down the distributed environment
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    # Create the model and move it to this process's GPU
    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Create the optimizer and loss function
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    # Training loop; data and target are random stand-ins for a real dataset
    for epoch in range(10):
        optimizer.zero_grad()
        data = torch.randn(32, 10, device=rank)
        target = torch.randn(32, 1, device=rank)
        output = ddp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    cleanup()

if __name__ == "__main__":
    world_size = 4  # assume 4 GPUs
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
- The torch.multiprocessing.spawn function launches one process per GPU. Each process calls the train function and receives its own rank along with the total world_size (i.e., the number of GPUs); see the first sketch after this list for the calling convention.
- Inside train, setup is called first to initialize the distributed environment; the model is then created and wrapped in DDP so it can be trained across multiple GPUs. The training loop runs next, and cleanup is called afterwards to tear down the distributed environment (a data-loading sketch follows this list).
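A standalone sketch of the spawn calling convention (the worker function and its print statement are illustrative only, not part of the original example): mp.spawn prepends each process's index as the first positional argument, followed by whatever is passed in args.

import torch.multiprocessing as mp

def worker(rank, world_size):
    # spawn supplies rank (0..nprocs-1) automatically; world_size comes from args=
    print(f"worker {rank} of {world_size} started")

if __name__ == "__main__":
    world_size = 4
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)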
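The training loop above assumes data and target are already prepared on each rank. A minimal sketch of one common way to do that, sharding a dataset across processes with DistributedSampler (ToyDataset and make_loader are hypothetical names introduced here for illustration):

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class ToyDataset(Dataset):
    # Hypothetical dataset of random (feature, label) pairs
    def __init__(self, n=1024, dim=10):
        self.x = torch.randn(n, dim)
        self.y = torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def make_loader(rank, world_size, batch_size=32):
    dataset = ToyDataset()
    # Each rank iterates over a disjoint shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

Inside train(rank, world_size), one would typically call loader.sampler.set_epoch(epoch) at the start of each epoch so shuffling differs between epochs, then iterate over the loader and move each batch to the local GPU before the forward pass.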
For more background on the reasoning behind this setup, see: https://zhuanlan.zhihu.com/p/581677880