This note records how to train on the CIFAR10 dataset using a custom Layer and a custom Model.

CIFAR10 dataset download: https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

The custom Layer and Model here are deliberately simple: the parameter count is small and there are no convolutional layers or dropout, so the final accuracy is not high; this is for practice only.
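For reference, the same fully connected architecture can be sketched with Keras' built-in Dense layers; this is only a minimal comparison sketch, assuming the same 512/512/256/256/10 layer sizes used by the custom model below.

import tensorflow as tf
from tensorflow.keras import Sequential, layers

# Reference sketch (assumption: same layer sizes as the custom model below),
# built from standard Dense layers instead of a custom Layer/Model.
reference_model = Sequential([
    layers.Flatten(input_shape=(32, 32, 3)),  # 32x32x3 image -> 3072 features
    layers.Dense(512, activation='relu'),
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(10),  # logits, no activation
])
reference_model.summary()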
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, dtype=tf.int32)
    return x, y

batchsize = 128
# The CIFAR10 dataset can also be downloaded directly over the network
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
# The CIFAR10 training labels have shape [50000, 1]; squeeze removes the size-1 dimension, giving [50000]
print('y.shape:', y.shape)
y = tf.squeeze(y)
print('squeezed y.shape:', y.shape)
y_val = tf.squeeze(y_val)
# One-hot encode the labels
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print('Datasets:', x.shape, ',', y.shape, 'x.min():', x.min(), 'x.max():', x.max())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)

sample = next(iter(train_db))
print('Batch:', sample[0].shape, sample[1].shape)

# Custom Layer
class MyDense(layers.Layer):
    def __init__(self, input_dim, output_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim],
                                      initializer=tf.random_uniform_initializer(0, 1.0))
        self.bias = self.add_weight(name='b', shape=[output_dim],
                                    initializer=tf.random_uniform_initializer(0, 1.0))
        # self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])
        # self.bias = self.add_weight(name='b', shape=[output_dim])

    def call(self, inputs, training=None):
        x = inputs @ self.kernel + self.bias
        return x

class MyNetwork(keras.Model):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyDense(32 * 32 * 3, 512)
        self.fc2 = MyDense(512, 512)
        self.fc3 = MyDense(512, 256)
        self.fc4 = MyDense(256, 256)
        self.fc5 = MyDense(256, 10)

    def call(self, inputs, training=None):
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        x = self.fc1(x)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        x = tf.nn.relu(x)
        # Return logits
        return x

total_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)
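Once training finishes, the model can be evaluated on test_db and its weights saved for later reuse. A minimal sketch follows; the checkpoint path 'cifar10_mynetwork.ckpt' is just an example name.

# Minimal sketch: evaluate on the test set and save the trained weights
network.evaluate(test_db)
network.save_weights('cifar10_mynetwork.ckpt')  # the path is an arbitrary example
# To reuse the weights later, build a fresh model and load them:
# new_network = MyNetwork()
# new_network.load_weights('cifar10_mynetwork.ckpt')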
Run results: