Contents: data preparation; building the model (input layer x, hidden layer h1, hidden layer h2, output layer); defining the training method (placeholder for the true labels, loss function, optimizer); defining how to evaluate the model's accuracy (per-example correctness, summing and averaging the correct predictions); training, with plots of the loss and the accuracy; evaluating the model's accuracy; making predictions; finding the misclassified examples.

GitHub repo: https://github.com/fz861062923/TensorFlow

Note: the dataset is downloaded from an external network, and a mysterious force may greet you with an HTTP 403.
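If the automatic download is blocked, one workaround is to fetch the four MNIST archives yourself (from any mirror you can reach) and drop them into MNIST_data/ before calling read_data_sets, which only downloads files it cannot find locally. A minimal check of what is still missing (a sketch, not part of the original notebook):

import os

mnist_files = ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
               "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]
missing = [f for f in mnist_files
           if not os.path.exists(os.path.join("MNIST_data", f))]
print("archives still missing:", missing)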
Data preparation
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
WARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.locals.wrap.locals.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.

print('train images :', mnist.train.images.shape, 'labels:', mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape, 'labels:', mnist.validation.labels.shape)
print('test images :', mnist.test.images.shape, 'labels:', mnist.test.labels.shape)

train images : (55000, 784) labels: (55000, 10)
validation images: (5000, 784) labels: (5000, 10)
test images : (10000, 784) labels: (10000, 10)
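Because the data was loaded with one_hot=True, each label is a length-10 one-hot vector rather than a single digit. A quick sanity check (illustrative only, not part of the original notebook):

import numpy as np
print(mnist.train.labels[0])             # a length-10 one-hot vector
print(np.argmax(mnist.train.labels[0]))  # the digit it encodes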
Build the model

def layer(output_dim, input_dim, inputs, activation=None):  # the activation function defaults to None
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))  # create and initialize the weights W with normally distributed random numbers
    b = tf.Variable(tf.random_normal([1, output_dim]))          # create and initialize the biases b the same way
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs
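The layer() helper simply computes inputs·W + b and optionally applies an activation. A throwaway shape check (a sketch; the demo_ names are not part of the original model):

demo_in = tf.placeholder("float", [None, 784])
demo_out = layer(output_dim=3, input_dim=784, inputs=demo_in)  # W: [784, 3], b: [1, 3]
print(demo_out.shape)  # (?, 3): one 3-dimensional output per input row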
Build the input layer x

x = tf.placeholder("float", [None, 784])

Build hidden layer h1

h1 = layer(output_dim=1000, input_dim=784, inputs=x, activation=tf.nn.relu)

Build hidden layer h2

h2 = layer(output_dim=1000, input_dim=1000, inputs=h1, activation=tf.nn.relu)

Build the output layer

y_predict = layer(output_dim=10, input_dim=1000, inputs=h2, activation=None)

Define the training method
Create a placeholder for the labels (ground truth) of the training data

y_label = tf.placeholder("float", [None, 10])  # the first dimension is None because the number of training examples varies

Define the loss function

# cross-entropy usually gives better results when training deep learning models
loss_function = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict, labels=y_label))
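Note that the output layer was built with activation=None: y_predict holds raw logits, because softmax_cross_entropy_with_logits_v2 applies the softmax internally. If class probabilities are ever needed at inference time, they can be obtained explicitly (a side sketch; y_prob is not used elsewhere in this notebook):

y_prob = tf.nn.softmax(y_predict)  # turns the logits into per-class probabilities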
Choose the optimizer

optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                    .minimize(loss_function)
# the optimizer uses loss_function to compute the error and updates the model's weights and biases so that the error is minimized

Define how to evaluate the model's accuracy

Check whether each example is predicted correctly

correct_prediction = tf.equal(tf.argmax(y_label, 1), tf.argmax(y_predict, 1))  # argmax converts the one-hot encoding to the position of the 1, which makes comparison easy

Sum and average the correct predictions

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
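The same equal → cast → mean pattern is easy to check in plain NumPy (a tiny illustration only; these arrays are made up):

import numpy as np
pred  = np.array([3, 1, 4])
truth = np.array([3, 1, 5])
print(np.mean((pred == truth).astype("float32")))  # 0.6666667: two of the three predictions are correct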
Start training

trainEpochs = 15                                          # run 15 training epochs
batchSize = 100                                           # 100 examples per batch
totalBatchs = int(mnist.train.num_examples / batchSize)   # number of batches to run in each epoch
epoch_list = []; accuracy_list = []; loss_list = []
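A quick check of the batching arithmetic (illustrative): 55,000 training images split into batches of 100 gives 550 parameter updates per epoch.

print(totalBatchs)  # expected: 550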
from time import time
startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(trainEpochs):                 # run 15 training epochs
    for i in range(totalBatchs):                 # each epoch runs 550 batches of training
        batch_x, batch_y = mnist.train.next_batch(batchSize)   # read the next batch of training data
        sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})

    # compute the loss and accuracy on the validation data
    loss, acc = sess.run([loss_function, accuracy],
                         feed_dict={x: mnist.validation.images,          # validation features
                                    y_label: mnist.validation.labels})   # validation labels

    epoch_list.append(epoch)
    loss_list.append(loss)
    accuracy_list.append(acc)
    print("Train Epoch:", "%02d" % (epoch + 1),
          "Loss", "{:.9f}".format(loss), "Accuracy", acc)

duration = time() - startTime
print("Train Finished takes:", duration)

Train Epoch: 01 Loss 133.117172241 Accuracy 0.9194
Train Epoch: 02 Loss 88.949943542 Accuracy 0.9392
Train Epoch: 03 Loss 80.701606750 Accuracy 0.9446
Train Epoch: 04 Loss 72.045913696 Accuracy 0.9506
Train Epoch: 05 Loss 71.911483765 Accuracy 0.9502
Train Epoch: 06 Loss 63.642936707 Accuracy 0.9558
Train Epoch: 07 Loss 67.192626953 Accuracy 0.9494
Train Epoch: 08 Loss 55.959281921 Accuracy 0.9618
Train Epoch: 09 Loss 58.867351532 Accuracy 0.9592
Train Epoch: 10 Loss 61.904548645 Accuracy 0.9612
Train Epoch: 11 Loss 58.283069611 Accuracy 0.9608
Train Epoch: 12 Loss 54.332244873 Accuracy 0.9646
Train Epoch: 13 Loss 58.152175903 Accuracy 0.9624
Train Epoch: 14 Loss 51.552104950 Accuracy 0.9688
Train Epoch: 15 Loss 52.803482056 Accuracy 0.9678
Train Finished takes: 545.0556836128235

Plot the loss over the training epochs
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()            # get the current figure
fig.set_size_inches(4, 2)  # set the figure size
plt.plot(epoch_list, loss_list, label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')

<matplotlib.legend.Legend at 0x1edb8d4c240>

Plot the accuracy over the training epochs
plt.plot(epoch_list, accuracy_list, label='accuracy')
fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.ylim(0.8, 1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

Evaluate the model's accuracy
print("Accuracy:", sess.run(accuracy,
                            feed_dict={x: mnist.test.images,
                                       y_label: mnist.test.labels}))

Accuracy: 0.9643

Make predictions
prediction_result = sess.run(tf.argmax(y_predict, 1), feed_dict={x: mnist.test.images})
prediction_result[:10]

array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)

import matplotlib.pyplot as plt
import numpy as np

def plot_images_labels_prediction(images, labels, prediction, idx, num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num > 25:
        num = 25
    for i in range(0, num):
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(np.reshape(images[idx], (28, 28)), cmap='binary')
        title = "label=" + str(np.argmax(labels[idx]))
        if len(prediction) > 0:
            title += ",predict=" + str(prediction[idx])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx += 1
    plt.show()

plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 0)

y_predict_Onehot = sess.run(y_predict, feed_dict={x: mnist.test.images})
y_predict_Onehot[8]

array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,
        5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],
      dtype=float32)

Find the misclassified examples
for i in range(400):
    if prediction_result[i] != np.argmax(mnist.test.labels[i]):
        print("i" + str(i), "label", np.argmax(mnist.test.labels[i]),
              "predict", prediction_result[i])

i8 label 5 predict 6
i18 label 3 predict 8
i149 label 2 predict 4
i151 label 9 predict 8
i233 label 8 predict 7
i241 label 9 predict 8
i245 label 3 predict 5
i247 label 4 predict 2
i259 label 6 predict 0
i320 label 9 predict 1
i340 label 5 predict 3
i381 label 3 predict 7
i386 label 6 predict 5

sess.close()
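To eyeball one of these mistakes, the plotting helper defined earlier can be pointed at a single misclassified index, for example index 8, which the run above reported as label 5 but predicted 6. This only touches the NumPy arrays already fetched, so it still works after the session is closed (a usage sketch, not in the original post):

plot_images_labels_prediction(mnist.test.images, mnist.test.labels,
                              prediction_result, idx=8, num=1)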