用Pytorch基于MNIST实现手写数字识别

代码的基本结构还是延续我通过深度学习神经网络,基于MNIST实现手写数字识别 的结构,只是神经网络部分使用了Pytorch的API 。
有一些地方要多说一点,但是不展开:
1、激活函数选用了ReLU,而非之前的sigmoid,二者的不同,网上文章很多,有机会总结一下 。
2、可以跟前文的代码进行比较看,主要看train、query两个方法,感受一下Pytorch的封装 。
3、用Pytorch构建的神经网络,在训练、测试时采用对应的模式train()、eval(),主要是对BN、Dropout层进行设置,具体的情况有机会详细说一下 。
4、本次还是使用了200个隐藏层节点,学习率0.1,使用了6W+条数据用以训练,1W+条数据用以测试,激活函数分别用ReLU,Sigmoid训练了7个世代,结果如下:
Sigmoid 7世代准确率=94.21%该循环程序运行时间: 404.66520285606384Relu 7世代准确率=98.15%该循环程序运行时间: 396.1038899421692之前代码,7个世代结果:
【用Pytorch基于MNIST实现手写数字识别】准确率=97.26%该循环程序运行时间: 314.45252776145935


    推荐阅读