• TensorFlow 2.9 Bits and Pieces (6): Model Training and Evaluation


    Contents

    Where is the fit function defined?

    Parameters of the fit function

    x

    y

    epochs

    Where is the evaluate function defined?

    Parameters of the evaluate function

    Source code of the fit function

    Source code of the evaluate function


    Parts (2) through (6) of this TensorFlow 2.9 Bits and Pieces series all revolve around training and evaluating a model on the MNIST dataset with TensorFlow 2.9.

    The Python environment is 3.8.

    All the code is debugged in PyCharm.

    After preparing the data, building the model, and compiling it, the next steps are model training with the fit function and model evaluation with the evaluate function.

    import tensorflow as tf

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(28, 28)))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.2))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])

    model.fit(x_train, y_train, epochs=5)
    model.evaluate(x_test, y_test, verbose=2)

    With that, the program is essentially complete; run it and take a look at the output.

    Where is the fit function defined?

    From model = tf.keras.models.Sequential() we know that model is an instance of the Sequential class, so let's look inside the Sequential class.

    Jumping through the layers of indirection, the definition of the Sequential class is found in keras.engine.sequential:

    @keras_export('keras.Sequential', 'keras.models.Sequential')
    class Sequential(functional.Functional):
      """`Sequential` groups a linear stack of layers into a `tf.keras.Model`.

      `Sequential` provides training and inference features on this model.

    As you can see, the Sequential class itself does not define a fit function.

    However, Sequential inherits from the Functional class in the functional module, so let's jump there next:

    # pylint: disable=g-classes-have-attributes
    class Functional(training_lib.Model):
      """A `Functional` model is a `Model` defined as a directed graph of layers.

      Three types of `Model` exist: subclassed `Model`, `Functional` model,
      and `Sequential` (a special case of `Functional`).

      In general, more Keras features are supported with `Functional`
      than with subclassed `Model`s, specifically:

      - Model cloning (`keras.models.clone`)
      - Serialization (`model.get_config()/from_config`, `model.to_json()`)
      - Whole-model saving (`model.save()`)

      A `Functional` model can be instantiated by passing two arguments to
      `__init__`. The first argument is the `keras.Input` Tensors that represent
      the inputs to the model. The second argument specifies the output
      tensors that represent the outputs of this model. Both arguments can be a
      nested structure of tensors.

    Again, the Functional class does not define a fit function either.

    But Functional in turn inherits from the Model class in the training module.

    Note that the module is training, not training_lib: training_lib is just an alias (from keras.engine import training as training_lib).

    Jumping once more:

    @keras_export('keras.Model', 'keras.models.Model')
    class Model(base_layer.Layer, version_utils.ModelVersionSelector):
      """`Model` groups layers into an object with training and inference features.

    The fit function is defined in the Model class of the training module. The source is long, so only part of it is shown here; the full version is pasted at the end.

    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose='auto',
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
      """Trains the model for a fixed number of epochs (iterations on a dataset).

    Parameters of the fit function

    fit has quite a few parameters, but the training code above actually uses only three of them: x (the input data), y (the labels for the input data), and epochs (the number of training epochs).

    x

    x is the input data and can be of many types, such as a NumPy array or a TensorFlow tensor; the docstring lists all of them, and a sketch of the tf.data.Dataset option follows the excerpt below.

    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding array/tensors,
        if the model has named inputs.
      - A `tf.data` dataset. Should return a tuple
        of either `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
      - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
        or `(inputs, targets, sample_weights)`.
      - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
        callable that takes a single argument of type
        `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
        `DatasetCreator` should be used when users prefer to specify the
        per-replica batching and sharding logic for the `Dataset`.
        See `tf.keras.utils.experimental.DatasetCreator` doc for more
        information.
      A more detailed description of unpacking behavior for iterator types
      (Dataset, generator, Sequence) is given below. If using
      `tf.distribute.experimental.ParameterServerStrategy`, only
      `DatasetCreator` type is supported for `x`.
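
    As a concrete example of the dataset option, the MNIST arrays from the script above can be wrapped in a tf.data.Dataset and passed as x on their own. This is a sketch of mine, not code from the original post; note that y is omitted because the dataset already yields (inputs, targets) pairs:

    import tensorflow as tf

    # Wrap the NumPy arrays into a shuffled, batched (inputs, targets) dataset.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).batch(32)

    # y is left unset: the targets come out of the dataset itself.
    model.fit(train_ds, epochs=5)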

    Earlier versions of TensorFlow also had a function called fit_generator. Instead of arrays or tensors, it took a generator that yields both the input data and the corresponding labels. It was mainly used when resources (e.g. GPU memory) were limited and the dataset was large: the data was sliced into chunks and fed in piece by piece for training.

    def fit_generator(self,
                      generator,
                      steps_per_epoch=None,
                      epochs=1,
                      verbose=1,
                      callbacks=None,
                      validation_data=None,
                      validation_steps=None,
                      validation_freq=1,
                      class_weight=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      shuffle=True,
                      initial_epoch=0):
      """Fits the model on data yielded batch-by-batch by a Python generator.

    In TensorFlow 2.9, however, fit_generator has been deprecated, because its functionality has been folded into fit.

    The x parameter of the current fit function accepts a generator just the same, as the sketch below shows.
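
    A minimal sketch of mine, using the MNIST arrays from the script above (the batch size and random sampling are illustrative, not from the original post):

    import numpy as np

    def mnist_batches(x, y, batch_size=32):
        # Yield (inputs, targets) tuples batch by batch, forever.
        while True:
            idx = np.random.randint(0, len(x), batch_size)
            yield x[idx], y[idx]

    # Since the generator never stops, steps_per_epoch tells Keras
    # when one epoch is over.
    model.fit(mnist_batches(x_train, y_train),
              steps_per_epoch=len(x_train) // 32,
              epochs=5)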

    y

    y holds the labels for the input data. Its type mirrors that of x, for example a NumPy array or a TensorFlow tensor, but it must be consistent with x:

    if x is a NumPy array, then y must also be a NumPy array.

    If x is a generator, the generator already yields both the inputs and their labels, so y should be left unset.

    y: Target data. Like the input data `x`,
      it could be either Numpy array(s) or TensorFlow tensor(s).
      It should be consistent with `x` (you cannot have Numpy inputs and
      tensor targets, or inversely). If `x` is a dataset, generator,
      or `keras.utils.Sequence` instance, `y` should
      not be specified (since targets will be obtained from `x`).

    epochs

    epochs is the number of training epochs; some people also call these "iterations" over the dataset.

    If the steps_per_epoch parameter is not set, one epoch means one complete pass over all of the data in x and y (see the sketch after the excerpt below for how epochs interacts with initial_epoch).

    epochs: Integer. Number of epochs to train the model.
      An epoch is an iteration over the entire `x` and `y`
      data provided
      (unless the `steps_per_epoch` flag is set to
      something other than None).
      Note that in conjunction with `initial_epoch`,
      `epochs` is to be understood as "final epoch".
      The model is not trained for a number of iterations
      given by `epochs`, but merely until the epoch
      of index `epochs` is reached.
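
    The "final epoch" wording matters when epochs is combined with initial_epoch, for example when resuming an interrupted run. A sketch of mine, reusing the model from the script above:

    # First run: trains epochs 0 through 4.
    model.fit(x_train, y_train, epochs=5)

    # Resumed run: epochs=8 is the epoch index to stop at, so this trains
    # only 3 more epochs (5, 6, 7), not 8 additional ones.
    model.fit(x_train, y_train, initial_epoch=5, epochs=8)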

    Where is the evaluate function defined?

    Following the same chain of inheritance as fit, it can be found in the training module.

    Parameters of the evaluate function

    The parameters of evaluate are similar to those of fit:

    @traceback_utils.filter_traceback
    def evaluate(self,
                 x=None,
                 y=None,
                 batch_size=None,
                 verbose='auto',
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 max_queue_size=10,
                 workers=1,
                 use_multiprocessing=False,
                 return_dict=False,
                 **kwargs):
      """Returns the loss value & metrics values for the model in test mode.

      Computation is done in batches (see the `batch_size` arg.)

    In the evaluation code above, we passed verbose=2.

    The fit function has this parameter too; it controls how progress information is printed.

    verbose takes four values: 'auto', 0, 1, and 2, with 'auto' as the default.

    verbose=2 prints a single summary line (one line per epoch in fit) instead of a progress bar.

    Trying each value yourself makes the effect clear; a short sketch follows the excerpt below.

    verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = single line.
      `"auto"` defaults to 1 for most cases, and to 2 when used with
      `ParameterServerStrategy`. Note that the progress bar is not
      particularly useful when logged to a file, so `verbose=2` is
      recommended when not running interactively (e.g. in a production
      environment).
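
    A quick way to compare the modes, reusing the model and test data from the script above (a sketch, not from the original post):

    model.evaluate(x_test, y_test, verbose=0)  # silent; only returns the values
    model.evaluate(x_test, y_test, verbose=1)  # progress bar
    model.evaluate(x_test, y_test, verbose=2)  # a single summary line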

    Source code of the fit function

    @traceback_utils.filter_traceback
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose='auto',
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
      """Trains the model for a fixed number of epochs (iterations on a dataset).

      Args:
          x: Input data. It could be:
            - A Numpy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A TensorFlow tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A dict mapping input names to the corresponding array/tensors,
              if the model has named inputs.
            - A `tf.data` dataset. Should return a tuple
              of either `(inputs, targets)` or
              `(inputs, targets, sample_weights)`.
            - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
              or `(inputs, targets, sample_weights)`.
            - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
              callable that takes a single argument of type
              `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
              `DatasetCreator` should be used when users prefer to specify the
              per-replica batching and sharding logic for the `Dataset`.
              See `tf.keras.utils.experimental.DatasetCreator` doc for more
              information.
            A more detailed description of unpacking behavior for iterator types
            (Dataset, generator, Sequence) is given below. If using
            `tf.distribute.experimental.ParameterServerStrategy`, only
            `DatasetCreator` type is supported for `x`.
          y: Target data. Like the input data `x`,
            it could be either Numpy array(s) or TensorFlow tensor(s).
            It should be consistent with `x` (you cannot have Numpy inputs and
            tensor targets, or inversely). If `x` is a dataset, generator,
            or `keras.utils.Sequence` instance, `y` should
            not be specified (since targets will be obtained from `x`).
          batch_size: Integer or `None`.
            Number of samples per gradient update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
          epochs: Integer. Number of epochs to train the model.
            An epoch is an iteration over the entire `x` and `y`
            data provided
            (unless the `steps_per_epoch` flag is set to
            something other than None).
            Note that in conjunction with `initial_epoch`,
            `epochs` is to be understood as "final epoch".
            The model is not trained for a number of iterations
            given by `epochs`, but merely until the epoch
            of index `epochs` is reached.
          verbose: 'auto', 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            'auto' defaults to 1 for most cases, but 2 when used with
            `ParameterServerStrategy`. Note that the progress bar is not
            particularly useful when logged to a file, so verbose=2 is
            recommended when not running interactively (eg, in a production
            environment).
          callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during training.
            See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
            and `tf.keras.callbacks.History` callbacks are created automatically
            and need not be passed into `model.fit`.
            `tf.keras.callbacks.ProgbarLogger` is created or not based on
            `verbose` argument to `model.fit`.
            Callbacks with batch-level calls are currently unsupported with
            `tf.distribute.experimental.ParameterServerStrategy`, and users are
            advised to implement epoch-level calls instead with an appropriate
            `steps_per_epoch` value.
          validation_split: Float between 0 and 1.
            Fraction of the training data to be used as validation data.
            The model will set apart this fraction of the training data,
            will not train on it, and will evaluate
            the loss and any model metrics
            on this data at the end of each epoch.
            The validation data is selected from the last samples
            in the `x` and `y` data provided, before shuffling. This argument is
            not supported when `x` is a dataset, generator or
            `keras.utils.Sequence` instance.
            If both `validation_data` and `validation_split` are provided,
            `validation_data` will override `validation_split`.
            `validation_split` is not yet supported with
            `tf.distribute.experimental.ParameterServerStrategy`.
          validation_data: Data on which to evaluate
            the loss and any model metrics at the end of each epoch.
            The model will not be trained on this data. Thus, note the fact
            that the validation loss of data provided using `validation_split`
            or `validation_data` is not affected by regularization layers like
            noise and dropout.
            `validation_data` will override `validation_split`.
            `validation_data` could be:
              - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
              - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
              - A `tf.data.Dataset`.
              - A Python generator or `keras.utils.Sequence` returning
                `(inputs, targets)` or `(inputs, targets, sample_weights)`.
            `validation_data` is not yet supported with
            `tf.distribute.experimental.ParameterServerStrategy`.
          shuffle: Boolean (whether to shuffle the training data
            before each epoch) or str (for 'batch'). This argument is ignored
            when `x` is a generator or an object of tf.data.Dataset.
            'batch' is a special option for dealing
            with the limitations of HDF5 data; it shuffles in batch-sized
            chunks. Has no effect when `steps_per_epoch` is not `None`.
          class_weight: Optional dictionary mapping class indices (integers)
            to a weight (float) value, used for weighting the loss function
            (during training only).
            This can be useful to tell the model to
            "pay more attention" to samples from
            an under-represented class.
          sample_weight: Optional Numpy array of weights for
            the training samples, used for weighting the loss function
            (during training only). You can either pass a flat (1D)
            Numpy array with the same length as the input samples
            (1:1 mapping between weights and samples),
            or in the case of temporal data,
            you can pass a 2D array with shape
            `(samples, sequence_length)`,
            to apply a different weight to every timestep of every sample. This
            argument is not supported when `x` is a dataset, generator, or
            `keras.utils.Sequence` instance, instead provide the sample_weights
            as the third element of `x`.
          initial_epoch: Integer.
            Epoch at which to start training
            (useful for resuming a previous training run).
          steps_per_epoch: Integer or `None`.
            Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If x is a
            `tf.data` dataset, and 'steps_per_epoch'
            is None, the epoch will run until the input dataset is exhausted.
            When passing an infinitely repeating dataset, you must specify the
            `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
            will run indefinitely with an infinitely repeating dataset.
            This argument is not supported with array inputs.
            When using `tf.distribute.experimental.ParameterServerStrategy`:
              * `steps_per_epoch=None` is not supported.
          validation_steps: Only relevant if `validation_data` is provided and
            is a `tf.data` dataset. Total number of steps (batches of
            samples) to draw before stopping when performing validation
            at the end of every epoch. If 'validation_steps' is None, validation
            will run until the `validation_data` dataset is exhausted. In the
            case of an infinitely repeated dataset, it will run into an
            infinite loop. If 'validation_steps' is specified and only part of
            the dataset will be consumed, the evaluation will start from the
            beginning of the dataset at each epoch. This ensures that the same
            validation samples are used every time.
          validation_batch_size: Integer or `None`.
            Number of samples per validation batch.
            If unspecified, will default to `batch_size`.
            Do not specify the `validation_batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
          validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
            If an integer, specifies how many training epochs to run before a
            new validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.
          max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
          workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up
            when using process-based threading. If unspecified, `workers`
            will default to 1.
          use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to
            the generator as they can't be passed easily to children processes.

      Unpacking behavior for iterator-like inputs:
        A common pattern is to pass a tf.data.Dataset, generator, or
        tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
        yield not only features (x) but optionally targets (y) and sample weights.
        Keras requires that the output of such iterator-likes be unambiguous. The
        iterator should return a tuple of length 1, 2, or 3, where the optional
        second and third elements will be used for y and sample_weight
        respectively. Any other type provided will be wrapped in a length one
        tuple, effectively treating everything as 'x'. When yielding dicts, they
        should still adhere to the top-level tuple structure.
        e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
        features, targets, and weights from the keys of a single dict.

        A notable unsupported data type is the namedtuple. The reason is that
        it behaves like both an ordered datatype (tuple) and a mapping
        datatype (dict). So given a namedtuple of the form:
          `namedtuple("example_tuple", ["y", "x"])`
        it is ambiguous whether to reverse the order of the elements when
        interpreting the value. Even worse is a tuple of the form:
          `namedtuple("other_tuple", ["x", "y", "z"])`
        where it is unclear if the tuple was intended to be unpacked into x, y,
        and sample_weight or passed through as a single element to `x`. As a
        result the data processing code will simply raise a ValueError if it
        encounters a namedtuple. (Along with instructions to remedy the issue.)

      Returns:
          A `History` object. Its `History.history` attribute is
          a record of training loss values and metrics values
          at successive epochs, as well as validation loss values
          and validation metrics values (if applicable).

      Raises:
          RuntimeError: 1. If the model was never compiled or,
          2. If `model.fit` is wrapped in `tf.function`.

          ValueError: In case of mismatch between the provided input data
            and what the model expects or when the input data is empty.
      """
      base_layer.keras_api_gauge.get_cell('fit').set(True)
      # Legacy graph support is contained in `training_v1.Model`.
      version_utils.disallow_legacy_graph('Model', 'fit')
      self._assert_compile_was_called()
      self._check_call_args('fit')
      _disallow_inside_tf_function('fit')
      verbose = _get_verbosity(verbose, self.distribute_strategy)
      if validation_split and validation_data is None:
        # Create the validation data using the training data. Only supported for
        # `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split))
      if validation_data:
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
      if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
        self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
            self.distribute_strategy)
      with self.distribute_strategy.scope(), \
           training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = data_adapter.get_data_handler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self,
            steps_per_execution=self._steps_per_execution)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
          callbacks = callbacks_module.CallbackList(
              callbacks,
              add_history=True,
              add_progbar=verbose != 0,
              model=self,
              verbose=verbose,
              epochs=epochs,
              steps=data_handler.inferred_steps)

        self.stop_training = False
        self.train_function = self.make_train_function()
        self._train_counter.assign(0)
        callbacks.on_train_begin()
        training_logs = None
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (  # pylint: disable=protected-access
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        logs = None
        for epoch, iterator in data_handler.enumerate_epochs():
          self.reset_metrics()
          callbacks.on_epoch_begin(epoch)
          with data_handler.catch_stop_iteration():
            data_handler._initial_step = self._maybe_load_initial_step_from_ckpt()  # pylint: disable=protected-access
            for step in data_handler.steps():
              with tf.profiler.experimental.Trace(
                  'train',
                  epoch_num=epoch,
                  step_num=step,
                  batch_size=batch_size,
                  _r=1):
                callbacks.on_train_batch_begin(step)
                tmp_logs = self.train_function(iterator)
                if data_handler.should_sync:
                  context.async_wait()
                logs = tmp_logs  # No error, now safe to assign to logs.
                end_step = step + data_handler.step_increment
                callbacks.on_train_batch_end(end_step, logs)
                if self.stop_training:
                  break

          logs = tf_utils.sync_to_numpy_or_python_type(logs)
          if logs is None:
            raise ValueError('Unexpected result of `train_function` '
                             '(Empty logs). Please use '
                             '`Model.compile(..., run_eagerly=True)`, or '
                             '`tf.config.run_functions_eagerly(True)` for more '
                             'information of where went wrong, or file a '
                             'issue/bug to `tf.keras`.')
          epoch_logs = copy.copy(logs)

          # Run validation.
          if validation_data and self._should_eval(epoch, validation_freq):
            # Create data_handler for evaluation and cache it.
            if getattr(self, '_eval_data_handler', None) is None:
              self._eval_data_handler = data_adapter.get_data_handler(
                  x=val_x,
                  y=val_y,
                  sample_weight=val_sample_weight,
                  batch_size=validation_batch_size or batch_size,
                  steps_per_epoch=validation_steps,
                  initial_epoch=0,
                  epochs=1,
                  max_queue_size=max_queue_size,
                  workers=workers,
                  use_multiprocessing=use_multiprocessing,
                  model=self,
                  steps_per_execution=self._steps_per_execution)
            val_logs = self.evaluate(
                x=val_x,
                y=val_y,
                sample_weight=val_sample_weight,
                batch_size=validation_batch_size or batch_size,
                steps=validation_steps,
                callbacks=callbacks,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                return_dict=True,
                _use_cached_eval_dataset=True)
            val_logs = {'val_' + name: val for name, val in val_logs.items()}
            epoch_logs.update(val_logs)

          callbacks.on_epoch_end(epoch, epoch_logs)
          training_logs = epoch_logs
          if self.stop_training:
            break

        if isinstance(self.optimizer, optimizer_experimental.Optimizer):
          self.optimizer.finalize_variable_values(self.trainable_variables)

        # If eval data_handler exists, delete it after all epochs are done.
        if getattr(self, '_eval_data_handler', None) is not None:
          del self._eval_data_handler
        callbacks.on_train_end(logs=training_logs)
        return self.history
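
    To connect a few of the parameters documented above, here is a sketch of mine (not from the original post) that exercises validation_split, batch_size, and callbacks together:

    # Hold out the last 10% of the training data for validation, and stop
    # early if the validation loss stops improving for two epochs.
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  patience=2)

    history = model.fit(x_train, y_train,
                        batch_size=64,
                        epochs=20,
                        validation_split=0.1,
                        callbacks=[early_stop])

    # The returned History records training and validation metrics per epoch.
    print(history.history.keys())  # loss/acc plus val_loss/val_acc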

    Source code of the evaluate function

    @traceback_utils.filter_traceback
    def evaluate(self,
                 x=None,
                 y=None,
                 batch_size=None,
                 verbose='auto',
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 max_queue_size=10,
                 workers=1,
                 use_multiprocessing=False,
                 return_dict=False,
                 **kwargs):
      """Returns the loss value & metrics values for the model in test mode.

      Computation is done in batches (see the `batch_size` arg.)

      Args:
          x: Input data. It could be:
            - A Numpy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A TensorFlow tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A dict mapping input names to the corresponding array/tensors,
              if the model has named inputs.
            - A `tf.data` dataset. Should return a tuple
              of either `(inputs, targets)` or
              `(inputs, targets, sample_weights)`.
            - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
              or `(inputs, targets, sample_weights)`.
            A more detailed description of unpacking behavior for iterator types
            (Dataset, generator, Sequence) is given in the `Unpacking behavior
            for iterator-like inputs` section of `Model.fit`.
          y: Target data. Like the input data `x`, it could be either Numpy
            array(s) or TensorFlow tensor(s). It should be consistent with `x`
            (you cannot have Numpy inputs and tensor targets, or inversely). If
            `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
            should not be specified (since targets will be obtained from the
            iterator/dataset).
          batch_size: Integer or `None`. Number of samples per batch of
            computation. If unspecified, `batch_size` will default to 32. Do not
            specify the `batch_size` if your data is in the form of a dataset,
            generators, or `keras.utils.Sequence` instances (since they generate
            batches).
          verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = single line.
            `"auto"` defaults to 1 for most cases, and to 2 when used with
            `ParameterServerStrategy`. Note that the progress bar is not
            particularly useful when logged to a file, so `verbose=2` is
            recommended when not running interactively (e.g. in a production
            environment).
          sample_weight: Optional Numpy array of weights for the test samples,
            used for weighting the loss function. You can either pass a flat (1D)
            Numpy array with the same length as the input samples
            (1:1 mapping between weights and samples), or in the case of
            temporal data, you can pass a 2D array with shape `(samples,
            sequence_length)`, to apply a different weight to every timestep
            of every sample. This argument is not supported when `x` is a
            dataset, instead pass sample weights as the third element of `x`.
          steps: Integer or `None`. Total number of steps (batches of samples)
            before declaring the evaluation round finished. Ignored with the
            default value of `None`. If x is a `tf.data` dataset and `steps` is
            None, 'evaluate' will run until the dataset is exhausted. This
            argument is not supported with array inputs.
          callbacks: List of `keras.callbacks.Callback` instances. List of
            callbacks to apply during evaluation. See
            [callbacks](/api_docs/python/tf/keras/callbacks).
          max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue. If unspecified,
            `max_queue_size` will default to 10.
          workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up when using process-based
            threading. If unspecified, `workers` will default to 1.
          use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to the
            generator as they can't be passed easily to children processes.
          return_dict: If `True`, loss and metric results are returned as a dict,
            with each key being the name of the metric. If `False`, they are
            returned as a list.
          **kwargs: Unused at this time.

      See the discussion of `Unpacking behavior for iterator-like inputs` for
      `Model.fit`.

      Returns:
          Scalar test loss (if the model has a single output and no metrics)
          or list of scalars (if the model has multiple outputs
          and/or metrics). The attribute `model.metrics_names` will give you
          the display labels for the scalar outputs.

      Raises:
          RuntimeError: If `model.evaluate` is wrapped in a `tf.function`.
      """
      base_layer.keras_api_gauge.get_cell('evaluate').set(True)
      version_utils.disallow_legacy_graph('Model', 'evaluate')
      self._assert_compile_was_called()
      self._check_call_args('evaluate')
      self._check_sample_weight_warning(x, sample_weight)
      _disallow_inside_tf_function('evaluate')
      use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False)
      if kwargs:
        raise TypeError(f'Invalid keyword arguments: {list(kwargs.keys())}')

      if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
        self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
            self.distribute_strategy)

      verbose = _get_verbosity(verbose, self.distribute_strategy)
      with self.distribute_strategy.scope():
        # Use cached evaluation data only when it's called in `Model.fit`
        if (use_cached_eval_dataset
            and getattr(self, '_eval_data_handler', None) is not None):
          data_handler = self._eval_data_handler
        else:
          # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
          data_handler = data_adapter.get_data_handler(
              x=x,
              y=y,
              sample_weight=sample_weight,
              batch_size=batch_size,
              steps_per_epoch=steps,
              initial_epoch=0,
              epochs=1,
              max_queue_size=max_queue_size,
              workers=workers,
              use_multiprocessing=use_multiprocessing,
              model=self,
              steps_per_execution=self._steps_per_execution)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
          callbacks = callbacks_module.CallbackList(
              callbacks,
              add_history=True,
              add_progbar=verbose != 0,
              model=self,
              verbose=verbose,
              epochs=1,
              steps=data_handler.inferred_steps)

        logs = {}
        self.test_function = self.make_test_function()
        self._test_counter.assign(0)
        callbacks.on_test_begin()
        for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
          self.reset_metrics()
          with data_handler.catch_stop_iteration():
            for step in data_handler.steps():
              with tf.profiler.experimental.Trace('test', step_num=step, _r=1):
                callbacks.on_test_batch_begin(step)
                tmp_logs = self.test_function(iterator)
                if data_handler.should_sync:
                  context.async_wait()
                logs = tmp_logs  # No error, now safe to assign to logs.
                end_step = step + data_handler.step_increment
                callbacks.on_test_batch_end(end_step, logs)
        logs = tf_utils.sync_to_numpy_or_python_type(logs)
        callbacks.on_test_end(logs=logs)

        if return_dict:
          return logs
        else:
          return flatten_metrics_in_order(logs, self.metrics_names)
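
    One detail worth noting from the final lines: by default evaluate returns a scalar or a list ordered according to model.metrics_names, while return_dict=True returns the same values keyed by name. A sketch of mine reusing the script above:

    results = model.evaluate(x_test, y_test, verbose=2)
    print(model.metrics_names, results)      # e.g. ['loss', 'acc'] and a list

    results = model.evaluate(x_test, y_test, verbose=2, return_dict=True)
    print(results['loss'], results['acc'])   # same values, keyed by name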
  • Original article: https://blog.csdn.net/ytomc/article/details/126298659