当前位置: 首页 > news >正文

用Python实现神经网络(六)

传统神经网络的一个局限是它对于图像平移变换--即一个猫在右上角的图片与猫在中心的图片是不同对待的。卷积神经网络Convolutional neural networks (CNNs)用于处理这种问题 。 因为CNN可以处理图像的平移,它被认为很有用,而且CNN架构被认为是目标识别/检测最选进的技术。

要理解为什么需要CNN,我们从一个例子开始。假如我们要分类一张图像里是否有垂直的线 (可能是告诉我们是否有1存在)。 为了简单起见,我们假定图像是5 × 5像素大小的。垂直的线(数字1)可以用一些方法表示 :

们也可以检查MNIST数据集里数字1的不同表示方法。一张数字1的图片见图9-1

9-1. 与数字1对应的图像的像素

在图里,越红的地方就是我们越常写的地方,模糊的地方是比较少写的地方。中间的像素最红,因为人们最常在那个地方写,不管他们写1的角度如何--垂直或向左或向右斜。在下面的一节,你会注意到神经网络预测不会准确当图像平移一些像素时。在后一节,我们会理解CNN如何解决图像平移的问题。

传统神经网络的问题

刚才提到的情况,传统的神经网络突出图像为1仅当中间的像素被突出而别的像素不突出时(因为许多人在中间突出像素)。

要更好的理解这个问题 ,我们看一下代码 :

    1. 下载数据集并提取训练集和测试集:

from keras.datasets import mnist import matplotlib.pyplot as plt

%matplotlib inline

# load (downloaded if needed) the MNIST dataset  (X_train,  y_train),  (X_test,   y_test)  =   mnist.load_data()

# plot 4 images as gray scale plt.subplot(221)

plt.imshow(X_train[0],   cmap=plt.get_cmap('gray')) plt.subplot(222)

plt.imshow(X_train[1],   cmap=plt.get_cmap('gray')) plt.subplot(223)

plt.imshow(X_train[2],   cmap=plt.get_cmap('gray')) plt.subplot(224)

plt.imshow(X_train[3],   cmap=plt.get_cmap('gray'))

# show the plot plt.show()

    1. 导入相关的包:

import numpy as np

from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense

from keras.layers import Dropout from  keras.layers  import  Flatten

from   keras.layers.convolutional   import   Conv2D

from keras.layers.convolutional import MaxPooling2D from  keras.utils  import  np_utils

from  keras  import  backend  as  K

    1. 获取训练集里的数字1:

X_train1  =  X_train[y_train==1]

    1. 变改形状变归一化数据集:

num_pixels  =  X_train.shape[1]  *  X_train.shape[2] X_train    =    X_train.reshape(X_train.shape[0],num_pixels

).astype('float32')

X_test  =  X_test.reshape(X_test.shape[0],num_pixels). astype('float32')

X_train = X_train / 255 X_test = X_test / 255

    1. 独热编码标签:

y_train    =    np_utils.to_categorical(y_train) y_test    =    np_utils.to_categorical(y_test) num_classes = y_train.shape[1]

    1. 构建模型并运行它:

model   =   Sequential()

model.add(Dense(1000,  input_dim=num_pixels,  activation='relu')) model.add(Dense(num_classes,  activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=[''accuracy'])

model.fit(X_train,  y_train,  validation_data=(X_test,  y_test), epochs=5, batch_size=1024, verbose=1)

我们绘制平均1的标签:

pic=np.zeros((28,28)) pic2=np.copy(pic)

for i in range(X_train1.shape[0]): pic2=X_train1[i,:,:]

pic=pic+pic2 pic=(pic/X_train1.shape[0]) plt.imshow(pic)

图 9-2显示结果。

9-2.  平均 1图像

情况 1

在这种情况,新的图像被创建 (图 9-3)通过原始图像向左平移一个像素:

for i in range(pic.shape[0]): if i<20:

pic[:,i]=pic[:,i+1] plt.imshow(pic)

9-3. 平均 1图像向左平移一个像素

我们继续并用构建的模型预测图9-3的标签:

model.predict(pic.reshape(1,784))

我们看到错误的预测为8作为输出。

情况 2

不平移原始平均1图像创建新的图像 (图 9-4):

pic=np.zeros((28,28)) pic2=np.copy(pic)

for i in range(X_train1.shape[0]): pic2=X_train1[i,:,:] pic=pic+pic2

pic=(pic/X_train1.shape[0]) plt.imshow(pic)

9-4.  平均 1图像

这个图像的预测如下:

model.predict(pic.reshape(1,784))

我们看到准确的预测1作为输出。

情况 3

通过原始的平均1的图像向右平移一个像素创建新的图像( 9-5):

pic=np.zeros((28,28)) pic2=np.copy(pic)

for i in range(X_train1.shape[0]): pic2=X_train1[i,:,:] pic=pic+pic2

pic=(pic/X_train1.shape[0]) pic2=np.copy(pic)

for i in range(pic.shape[0]): if ((i>6) and (i<26)):

pic[:,i]=pic2[:,(i-1)] plt.imshow(pic)

9-5. 平均1图像向右平移一个像素

我们继续并用构建的模型预测上面的图像:

model.predict(pic.reshape(1,784))

我们准确的预测1作为输出。

情况 4

通过向右平移原始的平均1图像2个像素创建图像 ( 9-6):

pic=np.zeros((28,28)) pic2=np.copy(pic)

for i in range(X_train1.shape[0]): pic2=X_train1[i,:,:] pic=pic+pic2

pic=(pic/X_train1.shape[0]) pic2=np.copy(pic)

for i in range(pic.shape[0]): if ((i>6) and (i<26)):

pic[:,i]=pic2[:,(i-2)] plt.imshow(pic)

9-6.平均1图像向右平移2个像素

我们用构建的模型预测图像的标签:

model.predict(pic.reshape(1,784))

我们看到错误的预测 3作为输出。

从上面的情况,你可以看到传统的神经网络对于平移的数据不会得到很好的结果。这些情况需要不同的网络来处理。这就是卷积神经网络convolutional neural network (CNN) 产生的原因。

http://www.lryc.cn/news/596056.html

相关文章:

  • 【计算机网络 篇】TCP基本认识和TCP三次握手相关问题
  • WebSocket心跳机制实现要点
  • 深入浅出理解 TCP 与 UDP:网络传输协议的核心差异与应用
  • 基于SpringBoot+Vue的高校特长互助系统(WebSocket实时聊天、协同过滤算法、ECharts图形化分析)
  • JavaScript,发生异常,try...catch...finally处理,继续向上层调用者传递异常信息
  • zabbix“专家坐诊”第295期问答
  • 服务器无法访问公网的原因及解决方案
  • 在 WebSocket 中使用 @Autowired 时遇到空指针异常
  • XML高效处理类 - 专为Office文档XML处理优化
  • 智能制造——解读52页汽车设计制造一体化整车产品生命周期PLM解决方案【附全文阅读】
  • 智慧制造合同解决方案
  • React 项目性能优化概要
  • 客户案例 | Jabil 整合 IT 与运营,大规模转型制造流程
  • 厚铜板载流革命与精密压合工艺——高可靠性PCB批量制造的新锚点
  • 中小制造企业如何对技术图纸进行管理?
  • OneCode 3.0 @FormAnnotation 注解速查手册
  • 漫画版:细说金仓数据库
  • Qt/C++源码/监控设备模拟器/支持onvif和gb28181/多路批量模拟/虚拟监控摄像头
  • 秋招Day17 - Spring - AOP
  • 《基于蛋白质组学的精准医学》:研究进展与未来展望
  • 双指针算法介绍及使用(上)
  • GitHub 上的开源项目 ticktick(滴答清单)
  • MSTP技术
  • 【加解密与C】Rot系列(四)RotSpecial
  • 解决http下浏览器无法开启麦克风问题
  • haproxy七层均衡
  • n1 armbian docker compose 部署aipan mysql
  • 理解后端开发中的API设计原则
  • 清华大学顶刊发表|破解无人机抓取与投递难题
  • 第三章 Freertos物联网实战esp8266模块