
```python
import os
import warnings
warnings.filterwarnings("ignore")

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Folder containing the data
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# Training set
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# Validation set
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
```
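Before building the model it is worth confirming that each folder holds the number of images the generators are expected to see. A minimal sketch, assuming the directory layout above:

```python
# Count the files in each class folder (sanity check only).
for name, path in [('train cats', train_cats_dir), ('train dogs', train_dogs_dir),
                   ('validation cats', validation_cats_dir), ('validation dogs', validation_dogs_dir)]:
    print(name, len(os.listdir(path)))
```

If the split is the usual 2,000 training / 1,000 validation images, these counts line up with the `steps_per_epoch` and `validation_steps` values used later.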
```python
model = tf.keras.models.Sequential([
    # If training is slow, you can reduce the input size further
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),

    # Flatten the feature maps for the fully connected layers
    tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(512, activation='relu'),
    # A single sigmoid unit is enough for binary classification
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.summary()
```
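`model.summary()` lists each layer's output shape and parameter count. The spatial size reaching `Flatten` can also be checked by hand: a valid 3×3 convolution trims 2 pixels from each dimension, and each 2×2 max-pooling halves it (rounding down), as the short calculation below shows.

```python
# Trace the spatial size through the three Conv2D + MaxPooling2D pairs.
size = 64
for _ in range(3):
    size = (size - 2) // 2   # 3x3 valid convolution, then 2x2 pooling
print(size)                  # 6, so Flatten produces 6 * 6 * 128 = 4608 features
```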
Configure the optimizer, loss, and metrics
```python
model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=1e-4),
              metrics=['acc'])
```
```python
# Rescale pixel values from [0, 255] to [0, 1]
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,             # directory path
    target_size=(64, 64),  # resize every image to this size
    batch_size=20,
    # 'categorical' for one-hot labels; 'binary' is enough for two classes
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(64, 64),
    batch_size=20,
    class_mode='binary')
```
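`flow_from_directory` yields batches of `(images, labels)` and infers the class mapping from the folder names. A minimal sketch to inspect one batch (the printed mapping is an expectation based on the folder names above):

```python
x_batch, y_batch = next(train_generator)
print(x_batch.shape)                   # (20, 64, 64, 3), values scaled to [0, 1]
print(y_batch.shape)                   # (20,), binary labels
print(train_generator.class_indices)   # expected: {'cats': 0, 'dogs': 1}
```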
```python
history = model.fit(
    train_generator,
    steps_per_epoch=100,    # 2000 images = batch_size * steps
    epochs=20,
    validation_data=validation_generator,
    validation_steps=50,    # 1000 images = batch_size * steps
    verbose=2)
```
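Once training finishes, the model can be saved so later inference does not require retraining. A minimal sketch; the filename is only illustrative:

```python
# Persist architecture + weights; the path is an example.
model.save('cats_vs_dogs_small.h5')
# model = tf.keras.models.load_model('cats_vs_dogs_small.h5')
```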
```
Epoch 1/20
100/100 - 7s - loss: 0.6892 - acc: 0.5325 - val_loss: 0.6705 - val_acc: 0.5970
Epoch 2/20
100/100 - 6s - loss: 0.6595 - acc: 0.6055 - val_loss: 0.6346 - val_acc: 0.6470
Epoch 3/20
100/100 - 6s - loss: 0.6350 - acc: 0.6515 - val_loss: 0.6358 - val_acc: 0.6320
Epoch 4/20
100/100 - 7s - loss: 0.5936 - acc: 0.6865 - val_loss: 0.5906 - val_acc: 0.6780
Epoch 5/20
100/100 - 7s - loss: 0.5530 - acc: 0.7170 - val_loss: 0.5978 - val_acc: 0.6670
Epoch 6/20
100/100 - 8s - loss: 0.5179 - acc: 0.7490 - val_loss: 0.5484 - val_acc: 0.7140
Epoch 7/20
100/100 - 8s - loss: 0.4854 - acc: 0.7725 - val_loss: 0.5686 - val_acc: 0.7080
Epoch 8/20
100/100 - 8s - loss: 0.4595 - acc: 0.7905 - val_loss: 0.5452 - val_acc: 0.7150
Epoch 9/20
100/100 - 8s - loss: 0.4406 - acc: 0.7885 - val_loss: 0.5453 - val_acc: 0.7210
Epoch 10/20
100/100 - 7s - loss: 0.4109 - acc: 0.8170 - val_loss: 0.5317 - val_acc: 0.7270
Epoch 11/20
100/100 - 8s - loss: 0.3892 - acc: 0.8285 - val_loss: 0.5384 - val_acc: 0.7220
Epoch 12/20
100/100 - 8s - loss: 0.3542 - acc: 0.8570 - val_loss: 0.5480 - val_acc: 0.7180
Epoch 13/20
100/100 - 8s - loss: 0.3421 - acc: 0.8580 - val_loss: 0.5355 - val_acc: 0.7420
Epoch 14/20
100/100 - 8s - loss: 0.3217 - acc: 0.8665 - val_loss: 0.5572 - val_acc: 0.7340
Epoch 15/20
100/100 - 8s - loss: 0.2931 - acc: 0.8805 - val_loss: 0.5545 - val_acc: 0.7400
Epoch 16/20
100/100 - 8s - loss: 0.2739 - acc: 0.8870 - val_loss: 0.5540 - val_acc: 0.7360
Epoch 17/20
100/100 - 8s - loss: 0.2535 - acc: 0.9040 - val_loss: 0.5564 - val_acc: 0.7380
Epoch 18/20
100/100 - 8s - loss: 0.2257 - acc: 0.9245 - val_loss: 0.5710 - val_acc: 0.7420
Epoch 19/20
100/100 - 8s - loss: 0.2084 - acc: 0.9350 - val_loss: 0.5734 - val_acc: 0.7460
Epoch 20/20
100/100 - 8s - loss: 0.2258 - acc: 0.9130 - val_loss: 0.5897 - val_acc: 0.7300
```
```python
import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
```
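By the final epochs training accuracy climbs above 90% while validation accuracy plateaus around 73% and validation loss starts rising: the model is overfitting the 2,000 training images. A common next step (not part of the code above) is to enable augmentation in the training generator, for example:

```python
# Augment only the training data; validation images should stay unmodified.
augmented_train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,        # random rotations up to 40 degrees
    width_shift_range=0.2,    # horizontal shifts up to 20% of the width
    height_shift_range=0.2,   # vertical shifts up to 20% of the height
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
```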

