资讯专栏INFORMATION COLUMN

机器学习 | CNN卷积神经网络

ghnor / 1324人阅读

摘要:测试结果最后两行分别为预测类别与真实类别。这里将其分成了类展平多维的卷积图成训练优化器损失函数开始训练将数据输入并且得到计算与真实值之间的误差清空上一步残余更新参数值误差反向传播,让参数进行更新将更新后的参数值施加到的上测试选取个数据

测试结果

最后两行分别为预测类别与真实类别。


数据预览

这里的数据使用的是mnist数据集,大家可以将代码中的DOWNLOAD_MNIST值修改为True进行自动下载。


代码

</>复制代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.utils.data as Data
  4. import torchvision # 数据库模块
  5. import matplotlib.pyplot as plt
  6. #训练整批数据多少次,这里为了节约时间,只训练一次
  7. EPOCH=1
  8. #每次批处理50个数据
  9. BATCH_SIZE=50
  10. #学习效率
  11. LR=0.001
  12. # 如果已经下载好了mnist数据就写上False
  13. DOWNLOAD_MNIST = False
  14. #训练的数据集:Mnist手写数字
  15. train_data=torchvision.datasets.MNIST(
  16. #保存或提取数据集的位置
  17. root="./mnist/",
  18. #该数据是训练数据
  19. train=True,
  20. #转换PIL.Image or numpy.ndarraytorch.FloatTensor (C x H x W), 训练的时候 normalize[0.0, 1.0] 区间
  21. transform=torchvision.transforms.ToTensor(),
  22. #没下载就下载,下载了就不用再下了
  23. download=DOWNLOAD_MNIST,
  24. )
  25. #绘制一下数据集
  26. #黑色的地方的值都是0, 白色的地方值大于0.
  27. print(train_data.train_data.size()) # (60000, 28, 28)
  28. print(train_data.train_labels.size()) # (60000)
  29. plt.imshow(train_data.train_data[2].numpy(), cmap="gray")
  30. plt.title("%i" % train_data.train_labels[2])
  31. plt.show()
  32. #测试数据
  33. test_data=torchvision.datasets.MNIST(root="./mnist/",train=False)
  34. #批训练50samples,1 channel,28x28 (50, 1, 28, 28)
  35. train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
  36. #这里只测试了前2000个
  37. #特征
  38. test_x=torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)[:2000]/255.
  39. #标签
  40. test_y=test_data.test_labels[:2000]
  41. #构建CNN模型
  42. class CNN(nn.Module):
  43. def __init__(self):
  44. super(CNN,self).__init__()
  45. #input shape(1,28,28)
  46. self.conv1=nn.Sequential(
  47. #卷积
  48. nn.Conv2d(
  49. in_channels=1,
  50. out_channels=16,
  51. #filter size
  52. kernel_size=5,
  53. #filter movement/step
  54. stride=1,
  55. #如果想要con2d出来的图片长宽没有变化,
  56. #padding=(kernel_size-1)/2当stride=1
  57. padding=2,
  58. ),
  59. #output shape(16,28,28)
  60. #激励函数
  61. nn.ReLU(),
  62. #池化
  63. # 在2x2空间里向下采样,output shape(16,14,14)
  64. nn.MaxPool2d(kernel_size=2),
  65. )
  66. #input shape(16,14,14)
  67. self.conv2=nn.Sequential(
  68. nn.Conv2d(16,32,5,1,2),
  69. #output shape(32,14,14)
  70. #激励函数
  71. nn.ReLU(),
  72. #output shape(32,7,7)
  73. nn.MaxPool2d(2),
  74. )
  75. #全连接层——进行分类。这里将其分成了10类
  76. self.out=nn.Linear(32*7*7,10)
  77. def forward(self,x):
  78. x=self.conv1(x)
  79. x=self.conv2(x)
  80. #展平多维的卷积图成(batch_size,32*7*7)
  81. x=x.view(x.size(0),-1)
  82. output=self.out(x)
  83. return output
  84. cnn=CNN()
  85. print(cnn)
  86. #训练
  87. #优化器
  88. optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
  89. #损失函数
  90. loss_func=nn.CrossEntropyLoss()
  91. #开始训练
  92. for epoch in range(EPOCH):
  93. for step,(b_x,b_y) in enumerate(train_loader):
  94. #将数据输入nn并且得到output
  95. output=cnn(b_x)
  96. #计算output与真实值之间的误差
  97. loss=loss_func(output,b_y)
  98. #清空上一步残余更新参数值
  99. optimizer.zero_grad()
  100. #误差反向传播,让参数进行更新
  101. loss.backward()
  102. #将更新后的参数值施加到nn的parameters上
  103. optimizer.step()
  104. #测试:选取10个数据
  105. test_output=cnn(test_x[:10])
  106. pred_y=torch.max(test_output,1)[1].data.numpy().squeeze()
  107. print(pred_y, "prediction number")
  108. print(test_y[:10].numpy(), "real number")
  109. # if __name__=="__main__":
  110. # print("hello word")

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/43525.html

相关文章

  • 机器学习基础】卷积神经网络CNN)基础

    摘要:而在卷积神经网络中,这两个神经元可以共用一套参数,用来做同一件事情。卷积神经网络的基本结构卷积神经网络的基本结构如图所示从右到左,输入一张图片卷积层池化层卷积层池化层展开全连接神经网络输出。最近几天陆续补充了一些线性回归部分内容,这节继续机器学习基础部分,这节主要对CNN的基础进行整理,仅限于基础原理的了解,更复杂的内容和实践放在以后再进行总结。卷积神经网络的基本原理  前面对全连接神经网络...

    番茄西红柿 评论0 收藏2637
  • CNN超参数优化和可视化技巧详解

    摘要:在计算机视觉领域,对卷积神经网络简称为的研究和应用都取得了显著的成果。文章讨论了在卷积神经网络中,该如何调整超参数以及可视化卷积层。卷积神经网络可以完成这项任务。 在深度学习中,有许多不同的深度网络结构,包括卷积神经网络(CNN或convnet)、长短期记忆网络(LSTM)和生成对抗网络(GAN)等。在计算机视觉领域,对卷积神经网络(简称为CNN)的研究和应用都取得了显著的成果。CNN网络最...

    Fundebug 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<