本文共 2432 字,大约阅读时间需要 8 分钟。
PyTorch的多机多卡分布式训练(DDP)技术,能够显著提升训练效率。以下是基于实践总结的一系列优化技巧和实现方案。
Batch Normalization(BN)在深度学习中扮演着重要角色。然而,在多卡训练环境下,传统BN存在性能瓶颈。PyTorch提出的SyncBN(Sync Batch Normalization)有效解决了这一问题。
SyncBN通过分布式接口all_gather,实现了真正的多卡BN训练。其核心优化包括只传输小批量的mean和variance,减少了通信开销。
torch.nn.SyncBatchNorm.convert_sync_batchnorm
将普通BN替换为SyncBN。_BatchNorm
,否则无法直接替换。梯度累加(Gradient Accumulation)允许多个小步骤合并为一个大步骤,提升训练效率。然而,在DDP环境下,传统梯度累加方式存在效率低下问题。
model.no_sync()
临时禁用梯度同步。nullcontext
或contextlib.nullcontext
,实现梯度累加与同步控制。from contextlib import nullcontextif local_rank != -1: model = DDP(model)optimizer.zero_grad()for i, (data, label) in enumerate(dataloader): my_context = model.no_sync if local_rank != -1 and i % K != 0 else nullcontext with my_context(): prediction = model(data) loss_fn(prediction, label).backward() if i % K == 0: optimizer.step() optimizer.zero_grad()
在多卡训练环境下,推理测试难以充分利用多卡资源。
SequentialDistributedSampler
,实现连续数据分割。all_gather
接口,聚合各卡推理结果。class SequentialDistributedSampler(torch.utils.data.sampler.Sampler): def __init__(self, dataset, batch_size, rank, num_replicas): # ...(完整实现见参考代码)
def distributed_concat(tensor, num_total_examples): output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(output_tensors, tensor) concat = torch.cat(output_tensors, dim=0) return concat[:num_total_examples]
为保证一致性,需在不同进程间分配不同的随机种子。
def init_seeds(seed=0, cuda_deterministic=True): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if cuda_deterministic: cudnn.deterministic = True cudnn.benchmark = False else: cudnn.deterministic = False cudnn.benchmark = Truerank = torch.distributed.get_rank()init_seeds(1 + rank)
if rank == 0: # 进行独特操作 torch.distributed.barrier()
torch.distributed.barrier()
实现多进程同步。logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN)logging.error("This is a fatal log!")
通过以上优化,充分发挥多卡环境性能。建议结合实际场景灵活调整配置,逐步提升训练效率。更多内容请参考前两篇DDP系列文章。
转载地址:http://pgrfk.baihongyu.com/