加入收藏 | 设为首页 | 会员中心 | 我要投稿 威海站长网 (https://www.0631zz.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 运营中心 > 建站资源 > 优化 > 正文

代码详解:用Pytorch训练快速神经网络的9个技巧

发布时间:2019-08-19 10:17:21 所属栏目:优化 来源:读芯术
导读:事实上,你的模型可能还停留在石器时代的水平。估计你还在用32位精度或*GASP(一般活动仿真语言)*训练,甚至可能只在单GPU上训练。如果市面上有99个加速指南,但你可能只看过1个?(没错,就是这样)。但这份终极指南,会一步步教你清除模型中所有的(GP模型)。

(4) 使用多GPUs时需注意的事项

  • 如果该设备上已存在model.cuda(),那么它不会完成任何操作。
  • 始终输入到设备列表中的第一个设备上。
  • 跨设备传输数据非常昂贵,不到万不得已不要这样做。
  • 优化器和梯度将存储在GPU 0上。因此,GPU 0使用的内存很可能比其他处理器大得多。

9. 多节点GPU训练

代码详解:用Pytorch训练快速神经网络的9个技巧

每台机器上的各GPU都可获取一份模型的副本。每台机器分得一部分数据,并仅针对该部分数据进行训练。各机器彼此同步梯度。

做到了这一步,就可以在几分钟内训练Imagenet数据集了! 这没有想象中那么难,但需要更多有关计算集群的知识。这些指令假定你正在集群上使用SLURM。

Pytorch在各个GPU上跨节点复制模型并同步梯度,从而实现多节点训练。因此,每个模型都是在各GPU上独立初始化的,本质上是在数据的一个分区上独立训练的,只是它们都接收来自所有模型的梯度更新。

高级阶段:

  • 在各GPU上初始化一个模型的副本(确保设置好种子,使每个模型初始化到相同的权值,否则操作会失效。)
  • 将数据集分成子集。每个GPU只在自己的子集上训练。
  • On .backward() 所有副本都会接收各模型梯度的副本。只有此时,模型之间才会相互通信。

Pytorch有一个很好的抽象概念,叫做分布式数据并行处理,它可以为你完成这一操作。要使用DDP(分布式数据并行处理),需要做4件事:

  1. def tng_dataloader(): 
  2.       
  3. d = MNIST() 
  4.  
  5.      # 4: Add distributed sampler 
  6.      # sampler sends a portion of tng data to each machine 
  7.      dist_sampler = DistributedSampler(dataset) 
  8.      dataloader = DataLoader(d, shuffle=False, sampler=dist_sampler) 
  9.  
  10. def main_process_entrypoint(gpu_nb):  
  11.      # 2: set up connections  between all gpus across all machines 
  12.      # all gpus connect to a single GPU "root" 
  13.      # the default uses env:// 
  14.  
  15.      world = nb_gpus * nb_nodes 
  16.      dist.init_process_group("nccl", rank=gpu_nb, worldworld_size=world) 
  17.  
  18.      # 3: wrap model in DPP 
  19.      torch.cuda.set_device(gpu_nb) 
  20.      model.cuda(gpu_nb) 
  21.      model = DistributedDataParallel(model, device_ids=[gpu_nb]) 
  22.  
  23.      # train your model now... 
  24.  
  25. if  __name__ == '__main__':  
  26.      # 1: spawn number of processes 
  27.      # your cluster will call main for each machine 
  28.      mp.spawn(main_process_entrypoint, nprocs=8) 

Pytorch团队对此有一份详细的实用教程

(https://github.com/pytorch/examples/blob/master/imagenet/main.py?source=post_page---------------------------)。

(编辑:威海站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

热点阅读