使用 Python 实现的卷积神经网络初学者指南

磐创AI 2021-10-27 17:01:56
Python 使用 神经网络 实现 卷积


磐创AI分享

来源 | geekwire

编辑 | 白峰

目录

  1. 卷积神经网络简介
  2. 其组成部分
  • 输入层
  • 卷积层
  • 池化层
  • 全连接层
  1. CNN 在数据集上的实际实现

CNN简介

卷积神经网络是一种专为处理图像和视频而设计的深度学习算法。它以图像为输入,提取和学习图像的特征,并根据学习到的特征进行分类。

该算法的灵感来自于人脑的一部分,即视觉皮层。视觉皮层是人脑的一部分,负责处理来自外界的视觉信息。它有不同的层,每一层都有自己的功能,即每一层从图像或任何视觉中提取一些信息,最后将从每一层接收到的所有信息组合起来,对图像/视觉进行解释或分类。

同样,CNN有各种滤波器,每个滤波器从图像中提取一些信息,例如边缘、不同种类的形状(垂直、水平、圆形),然后将所有这些组合起来识别图像。

现在,这里的问题可能是:为什么我们不能将人工神经网络用于相同的目的?这是因为ANN有一些缺点:

  • 对于 ANN 模型来说,训练大尺寸图像和不同类型的图像通道的计算量太大。
  • 它无法从图像中捕获所有信息,而 CNN 模型可以捕获图像的空间依赖性。
  • 另一个原因是人工神经网络对图像中物体的位置很敏感,即如果同一物体的位置或地点发生变化,它将无法正确分类。

CNN的组成部分

CNN模型分两步工作:特征提取和分类

特征提取是将各种过滤器和图层应用于图像以从中提取信息和特征的阶段,完成后将传递到下一阶段,即分类,根据问题的目标变量对它们进行分类。

典型的 CNN 模型如下所示:

  • 输入层
  • 卷积层+激活函数
  • 池化层
  • 全连接层

来源:https://learnopencv.com/image-classification-using-convolutional-neural-networks-in-keras/

让我们详细了解每一层。

输入层

顾名思义,它是我们的输入图像,可以是灰度或 RGB。每个图像由范围从 0 到 255 的像素组成。我们需要对它们进行归一化,即在将其传递给模型之前转换 0 到 1 之间的范围。

下面是大小为 4*4 的输入图像的示例,它有 3 个通道,即 RGB 和像素值。

来源:https://medium.com/@raycad.seedotech/convolutional-neural-network-cnn-8d1908c010ab

卷积层

卷积层是将过滤器应用于我们的输入图像以提取或检测其特征的层。过滤器多次应用于图像并创建一个有助于对输入图像进行分类的特征图。让我们借助一个例子来理解这一点。为简单起见,我们将采用具有归一化像素的 2D 输入图像。

在上图中,我们有一个大小为 66 的输入图像,并对其应用了 33 的过滤器来检测一些特征。在这个例子中,我们只应用了一个过滤器,但在实践中,许多这样的过滤器被用于从图像中提取信息。

将过滤器应用于图像的结果是我们得到一个 4*4 的特征图,其中包含有关输入图像的一些信息。许多这样的特征图是在实际应用中生成的。

让我们深入了解获取上图中特征图的一些数学原理。

如上图所示,第一步过滤器应用于图像的绿色高亮部分,将图像的像素值与过滤器的值相乘(如图中使用线条所示),然后相加得到最终值。

在下一步中,过滤器将移动一列,如下图所示。这种跳转到下一列或行的过程称为 stride,在本例中,我们将 stride设为1,这意味着我们将移动一列。

类似地,过滤器通过整个图像,我们得到最终的特征图。一旦我们获得特征图,就会对其应用激活函数来引入非线性。

这里需要注意的一点是,我们得到的特征图小于我们图像的大小。随着我们增加 stride 的值,特征图的大小会减小。

这就是过滤器如何以 1 的步幅穿过整个图像

池化层

池化层应用在卷积层之后,用于降低特征图的维度,有助于保留输入图像的重要信息或特征,并减少计算时间。

使用池化,可以创建一个较低分辨率的输入版本,该版本仍然包含输入图像的大元素或重要元素。

最常见的池化类型是最大池化和平均池化。

下图显示了最大池化的工作原理。使用我们从上面的例子中得到的特征图来应用池化。这里我们使用了一个大小为 2*2的池化层,步长为 2。

取每个突出显示区域的最大值,并获得大小为 2*2的新版本输入图像,因此在应用池化后,特征图的维数减少了。

全连接层

到目前为止,我们已经执行了特征提取步骤,现在是分类部分。全连接层(如我们在 ANN 中所使用的)用于将输入图像分类为标签。该层将从前面的步骤(即卷积层和池化层)中提取的信息连接到输出层,并最终将输入分类为所需的标签。

CNN 模型的完整过程可以在下图中看到。

来源:https://developersbreach.com/convolution-neural-network-deep-learning/

CNN在Python中的实现

我们将使用 Mnist Digit 分类数据集,我们在ANN的实际实现的上一篇博客中使用了该数据集。为了更好地理解CNN的应用,请先参考上一篇博客:https://www.analyticsvidhya.com/blog/2021/08/implementing-artificial-neural-network-on-unstructured-data/

#importing the required libraries
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
#loading data
(X_train,y_train) , (X_test,y_test)=mnist.load_data()
#reshaping data
X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], X_train.shape[2], 1))
X_test = X_test.reshape((X_test.shape[0],X_test.shape[1],X_test.shape[2],1))
#checking the shape after reshaping
print(X_train.shape)
print(X_test.shape)
#normalizing the pixel values
X_train=X_train/255
X_test=X_test/255
#defining model
model=Sequential()
#adding convolution layer
model.add(Conv2D(32,(3,3),activation='relu',input_shape=(28,28,1)))
#adding pooling layer
model.add(MaxPool2D(2,2))
#adding fully connected layer
model.add(Flatten())
model.add(Dense(100,activation='relu'))
#adding output layer
model.add(Dense(10,activation='softmax'))
#compiling the model
model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
#fitting the model
model.fit(X_train,y_train,epochs=10)

输出:

#evaluting the model
model.evaluate(X_test,y_test)

尾注

希望这篇文章对你有所帮助。

本文分享自微信公众号 - 磐创AI(xunixs)

原文出处及转载信息见文内详细说明,如有侵权,请联系 [email protected] 删除。

原始发表时间: 2021-10-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

版权声明
本文为[磐创AI]所创,转载请带上原文链接,感谢
https://cloud.tencent.com/developer/article/1893999

  1. 【算法学习】237. 删除链表中的节点(java / c / c++ / python / go)
  2. 【算法学习】1672. 最富有客户的资产总量(java / c / c++ / python / go / rust)
  3. 【算法学习】771. 宝石与石头(java / c / c++ / python / go / rust)
  4. 【算法学习】02.03. 删除中间节点(java / c / c++ / python / go)
  5. 【算法学习】1769. 移动所有球到每个盒子所需的最小操作数(java / c / c++ / python / go / rust)
  6. 【算法学习】1486. 数组异或操作(java / c / c++ / python / go / rust)
  7. 【算法学习】LCP 44. 开幕式焰火(java / c / c++ / python / go / rust)
  8. 【算法学习】剑指 Offer 58 - II. 左旋转字符串(java / c / c++ / python / go / rust)
  9. python的学校疑问难题求解
  10. 大学python题 作业题 基础题
  11. Python字典的知识,输出的样例为,最高分:89
  12. python写入文件失败且程序提前中止
  13. 用Python写一个学生字典,帮帮忙
  14. Python,能不能帮帮忙,真的不会
  15. [python] yield 和 readline() 的使用问题
  16. python安装找不到问题救救孩子
  17. python中循环结构完成数字游戏
  18. 如何用python实现多列vlookup(excle操作)
  19. python语言deLong‘s test:通过统计学的角度来比较两个ROC曲线、检验两个ROC曲线的差异是否具有统计显著性
  20. LPC55S69 MicroPython模组和库函数
  21. LPC55S69 IoT Kit专属 Micropython模组和库函数简介
  22. 安装LPC55S69 MicroPython模块是遇到的CDC Interface驱动问题
  23. 使用soundcard在Python中操作声卡
  24. 自动化快速上手--Python(7)--【字典】--每天半小时
  25. Python之循环结构【包括列表、for语句、range()函数、while语句、循环嵌套、break、continue、算法优化等】
  26. Python模块安装与异常处理详解(numpy、pygame、matplotlib等)
  27. Python__init__.py作用
  28. python 爬取网页时出现多种错误
  29. Python中关于大量绘制速度曲线的问题
  30. python-async的安装和使用方法
  31. Matlab的fread(fild,1,int32)迁移到python变成什么
  32. 想用python开发一个音频过滤器,请指导?
  33. python使用openpyxl读取Excel文件显示No such file or directory
  34. xmoji虚拟头像交互如何使用python(像深度学习)制作?
  35. python 打开页面页面的链接,为什么总是报错呀?
  36. Python中DataLoader的batch_size、shuffle的疑惑。
  37. python安装pymssql库,可以import,但无法调用函数
  38. 【Python学习教程】常用的8个Python数据可视化库!
  39. python处理csv中的时间
  40. 数据结构,元音统计(Python)
  41. python的site-packages复制直接到其他电脑环境上能用吗
  42. Pycharm如何给项目配置python解释器
  43. conda创建python虚拟环境
  44. Python selenium的爬虫无法完整爬取整个页面的内容
  45. 高清版!这18张 Python 数据科学速查表,让你的代码变得更强大!
  46. python代码不会敲,请好心老哥帮助我一下
  47. Python敲七输出符合的个数
  48. Python 有人能给提供简单的思路嘛
  49. python单次运行写入csv成功,循环写入失败
  50. python利用os模块进行增量备份
  51. 【算法学习】807. 保持城市天际线(java / c / c++ / python / go / rust)
  52. 如何利用python输出等腰杨辉三角
  53. python按键执行倒计时小程序不能实现要求,要怎么改才好?
  54. Python request模块post请求的问题
  55. Django连接已有Oracle时的主键设置问题,没主键无法查询怎么办?
  56. 如何用python的dictionary编写一个联系人通讯录程序
  57. 如果Python里range反向输出,不输出步长会怎么样?
  58. 一个关于Python pip的问题: 出现Cannot open \python\Scripts\pip-script.py报错
  59. 富婆闺蜜非让我用Python给她写个淘宝双十一抢购脚本,那只能安排了
  60. 【全网最全】python正则表达式大全,所有讲解都在这,包教包会,学不会找我!