博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch半精度浮点型网络训练问题
阅读量:4873 次
发布时间:2019-06-11

本文共 533 字,大约阅读时间需要 1 分钟。

用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:

1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

  另外,SGD算法对于半精度和全精度计算均没有问题。

 

还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

 将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

转载于:https://www.cnblogs.com/yanxingang/p/10148712.html

你可能感兴趣的文章
CSS关键字
查看>>
UIAlertView
查看>>
ES6快速入门(三)类与模块
查看>>
赛博web
查看>>
Java动手动脑第四讲课堂作业
查看>>
PowerDesigner 数据建模技术视频教程
查看>>
Webpack 开发服务器代理设置解决跨域问题
查看>>
Solr 15 - Solr添加和更新索引的过程 (文档的路由细节)
查看>>
DOS命令
查看>>
Oracle merge基本使用
查看>>
03-树1 树的同构
查看>>
第九周周记
查看>>
AdvStringGrid入门使用
查看>>
C#图像处理——ImageProcessor
查看>>
NOI2004 降雨量
查看>>
WPF的TextBox水印效果详解
查看>>
oracle启动服务和监听命令
查看>>
毒药和酒
查看>>
浅谈linux内核中内存分配函数
查看>>
走近SpringBoot
查看>>