The posts "TensorFlow 2.9 Bits and Pieces (2)" through "TensorFlow 2.9 Bits and Pieces (6)" all revolve around training and evaluating a model on the MNIST dataset with TensorFlow 2.9.
The Python environment is 3.8.
All code is debugged in PyCharm.
With the data built, the model constructed, and the model compiled, what comes next is training the model with the fit function and evaluating it with the evaluate function.
- import tensorflow as tf
-
- # Load MNIST and scale pixel values from [0, 255] to [0, 1].
- mnist = tf.keras.datasets.mnist
- (x_train, y_train), (x_test, y_test) = mnist.load_data()
- x_train, x_test = x_train / 255.0, x_test / 255.0
-
- # A simple classifier: flatten the 28x28 image, one hidden
- # layer with dropout, then a 10-way softmax.
- model = tf.keras.models.Sequential()
- model.add(tf.keras.layers.Input(shape=(28, 28)))
- model.add(tf.keras.layers.Flatten())
- model.add(tf.keras.layers.Dense(128, activation='relu'))
- model.add(tf.keras.layers.Dropout(0.2))
- model.add(tf.keras.layers.Dense(10, activation='softmax'))
-
- model.compile(optimizer='adam',
-               loss='sparse_categorical_crossentropy',
-               metrics=['acc'])
-
- # Train for 5 epochs, then report loss and accuracy on the test set.
- model.fit(x_train, y_train, epochs=5)
-
- model.evaluate(x_test, y_test, verbose=2)
With that, the code is essentially complete. Let's look at what it prints when run.

From model = tf.keras.models.Sequential() we know that model is an instance of the Sequential class, so let's look inside Sequential.
After jumping through a few files, the Sequential class turns out to be defined in keras.engine.sequential:
- @keras_export('keras.Sequential', 'keras.models.Sequential')
- class Sequential(functional.Functional):
- """`Sequential` groups a linear stack of layers into a `tf.keras.Model`.
- `Sequential` provides training and inference features on this model.
As you can see, the Sequential class does not define a fit function itself.
But Sequential inherits from the Functional class in the functional module, so let's keep jumping:
- # pylint: disable=g-classes-have-attributes
- class Functional(training_lib.Model):
- """A `Functional` model is a `Model` defined as a directed graph of layers.
- Three types of `Model` exist: subclassed `Model`, `Functional` model,
- and `Sequential` (a special case of `Functional`).
- In general, more Keras features are supported with `Functional`
- than with subclassed `Model`s, specifically:
- - Model cloning (`keras.models.clone`)
- - Serialization (`model.get_config()/from_config`, `model.to_json()`
- - Whole-model saving (`model.save()`)
- A `Functional` model can be instantiated by passing two arguments to
- `__init__`. The first argument is the `keras.Input` Tensors that represent
- the inputs to the model. The second argument specifies the output
- tensors that represent the outputs of this model. Both arguments can be a
- nested structure of tensors.
Again, there is no fit function in the Functional class.
But Functional inherits from the Model class in the training module.
Note that the module is training, not training_lib; training_lib is merely an alias (from keras.engine import training as training_lib).
Jump once more:
- @keras_export('keras.Model', 'keras.models.Model')
- class Model(base_layer.Layer, version_utils.ModelVersionSelector):
- """`Model` groups layers into an object with training and inference features.
And indeed, the fit function is defined on the Model class in the training module. The source is too long to reproduce here, so only an excerpt follows; the complete version is pasted at the end:
- def fit(self,
- x=None,
- y=None,
- batch_size=None,
- epochs=1,
- verbose='auto',
- callbacks=None,
- validation_split=0.,
- validation_data=None,
- shuffle=True,
- class_weight=None,
- sample_weight=None,
- initial_epoch=0,
- steps_per_epoch=None,
- validation_steps=None,
- validation_batch_size=None,
- validation_freq=1,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False):
- """Trains the model for a fixed number of epochs (iterations on a dataset).
fit takes quite a few parameters, but the training code above uses only three of them: x (the input data), y (the labels for the input data), and epochs (the number of training epochs).
x, the input data, can be of many types, for example a NumPy array or a TensorFlow tensor:
- x: Input data. It could be:
- - A Numpy array (or array-like), or a list of arrays
- (in case the model has multiple inputs).
- - A TensorFlow tensor, or a list of tensors
- (in case the model has multiple inputs).
- - A dict mapping input names to the corresponding array/tensors,
- if the model has named inputs.
- - A `tf.data` dataset. Should return a tuple
- of either `(inputs, targets)` or
- `(inputs, targets, sample_weights)`.
- - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
- or `(inputs, targets, sample_weights)`.
- - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
- callable that takes a single argument of type
- `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
- `DatasetCreator` should be used when users prefer to specify the
- per-replica batching and sharding logic for the `Dataset`.
- See `tf.keras.utils.experimental.DatasetCreator` doc for more
- information.
- A more detailed description of unpacking behavior for iterator types
- (Dataset, generator, Sequence) is given below. If using
- `tf.distribute.experimental.ParameterServerStrategy`, only
- `DatasetCreator` type is supported for `x`.
Earlier versions of TensorFlow also had a function named fit_generator. Instead of arrays or tensors, it took a generator that yields both the input data and the corresponding labels. It was mainly used when resources (e.g. GPU memory) are limited and the dataset is large: the data is sliced up and fed in piece by piece for training.
- def fit_generator(self,
- generator,
- steps_per_epoch=None,
- epochs=1,
- verbose=1,
- callbacks=None,
- validation_data=None,
- validation_steps=None,
- validation_freq=1,
- class_weight=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- shuffle=True,
- initial_epoch=0):
- """Fits the model on data yielded batch-by-batch by a Python generator.
In TensorFlow 2.9, however, fit_generator has been deprecated, because its functionality was folded into fit.
Today's fit accepts a generator through the same x parameter.
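A minimal sketch of that usage, with a hypothetical helper that batches the MNIST arrays by hand (because the generator is infinite, steps_per_epoch must be given):
- def mnist_batches(x, y, batch_size=32):
-     # Hypothetical generator yielding (inputs, targets) tuples forever.
-     while True:
-         for i in range(0, len(x), batch_size):
-             yield x[i:i + batch_size], y[i:i + batch_size]
-
- model.fit(mnist_batches(x_train, y_train),
-           steps_per_epoch=len(x_train) // 32,
-           epochs=5)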
y is the labels for the input data. Its type mirrors x: it can likewise be a NumPy array or a TensorFlow tensor, but it must be consistent with x.
For example, if x is a NumPy array, then y must also be a NumPy array.
If x is a generator, the generator already yields both the inputs and their labels, so y should be left unspecified:
- y: Target data. Like the input data `x`,
- it could be either Numpy array(s) or TensorFlow tensor(s).
- It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset, generator,
- or `keras.utils.Sequence` instance, `y` should
- not be specified (since targets will be obtained from `x`).
epochs is the number of training epochs (some write-ups call these "iterations").
Unless the steps_per_epoch parameter is set, one epoch means one complete pass over all of the data in x and y:
- epochs: Integer. Number of epochs to train the model.
- An epoch is an iteration over the entire `x` and `y`
- data provided
- (unless the `steps_per_epoch` flag is set to
- something other than None).
- Note that in conjunction with `initial_epoch`,
- `epochs` is to be understood as "final epoch".
- The model is not trained for a number of iterations
- given by `epochs`, but merely until the epoch
- of index `epochs` is reached.
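The "final epoch" wording matters when you resume training. A quick sketch with the model from above:
- # First run: trains epochs 0 through 4.
- model.fit(x_train, y_train, epochs=5)
-
- # Resuming later: trains 5 more epochs (5 through 9), because
- # epochs=10 is the index at which to stop, not an additional count.
- model.fit(x_train, y_train, initial_epoch=5, epochs=10)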
The evaluate function works on the same principle as fit and is likewise found in the training module.
Its parameters are similar to those of fit:
- @traceback_utils.filter_traceback
- def evaluate(self,
- x=None,
- y=None,
- batch_size=None,
- verbose='auto',
- sample_weight=None,
- steps=None,
- callbacks=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- return_dict=False,
- **kwargs):
- """Returns the loss value & metrics values for the model in test mode.
- Computation is done in batches (see the `batch_size` arg.)
The evaluation code above passes verbose=2.
fit has this parameter too; it controls how progress information is printed.
verbose takes four values: 'auto', 0, 1, and 2, with 'auto' as the default.
verbose=2 replaces the progress bar with plain lines: one line per epoch in fit, and a single summary line in evaluate.
Trying the options yourself makes the differences clear:
- verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
- 0 = silent, 1 = progress bar, 2 = single line.
- `"auto"` defaults to 1 for most cases, and to 2 when used with
- `ParameterServerStrategy`. Note that the progress bar is not
- particularly useful when logged to a file, so `verbose=2` is
- recommended when not running interactively (e.g. in a production
- environment).
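A quick sketch of the options, reusing the model from above (the metric name 'acc' comes from the compile call earlier):
- # verbose=0: silent, verbose=1: progress bar, verbose=2: single line.
- loss, acc = model.evaluate(x_test, y_test, verbose=2)
-
- # return_dict=True keys the results by metric name,
- # e.g. {'loss': ..., 'acc': ...}, instead of returning a list.
- results = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
- print(results)
For reference, the complete source of fit and evaluate from the training module in TensorFlow 2.9 follows: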
- @traceback_utils.filter_traceback
- def fit(self,
- x=None,
- y=None,
- batch_size=None,
- epochs=1,
- verbose='auto',
- callbacks=None,
- validation_split=0.,
- validation_data=None,
- shuffle=True,
- class_weight=None,
- sample_weight=None,
- initial_epoch=0,
- steps_per_epoch=None,
- validation_steps=None,
- validation_batch_size=None,
- validation_freq=1,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False):
- """Trains the model for a fixed number of epochs (iterations on a dataset).
- Args:
- x: Input data. It could be:
- - A Numpy array (or array-like), or a list of arrays
- (in case the model has multiple inputs).
- - A TensorFlow tensor, or a list of tensors
- (in case the model has multiple inputs).
- - A dict mapping input names to the corresponding array/tensors,
- if the model has named inputs.
- - A `tf.data` dataset. Should return a tuple
- of either `(inputs, targets)` or
- `(inputs, targets, sample_weights)`.
- - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
- or `(inputs, targets, sample_weights)`.
- - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
- callable that takes a single argument of type
- `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
- `DatasetCreator` should be used when users prefer to specify the
- per-replica batching and sharding logic for the `Dataset`.
- See `tf.keras.utils.experimental.DatasetCreator` doc for more
- information.
- A more detailed description of unpacking behavior for iterator types
- (Dataset, generator, Sequence) is given below. If using
- `tf.distribute.experimental.ParameterServerStrategy`, only
- `DatasetCreator` type is supported for `x`.
- y: Target data. Like the input data `x`,
- it could be either Numpy array(s) or TensorFlow tensor(s).
- It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset, generator,
- or `keras.utils.Sequence` instance, `y` should
- not be specified (since targets will be obtained from `x`).
- batch_size: Integer or `None`.
- Number of samples per gradient update.
- If unspecified, `batch_size` will default to 32.
- Do not specify the `batch_size` if your data is in the
- form of datasets, generators, or `keras.utils.Sequence` instances
- (since they generate batches).
- epochs: Integer. Number of epochs to train the model.
- An epoch is an iteration over the entire `x` and `y`
- data provided
- (unless the `steps_per_epoch` flag is set to
- something other than None).
- Note that in conjunction with `initial_epoch`,
- `epochs` is to be understood as "final epoch".
- The model is not trained for a number of iterations
- given by `epochs`, but merely until the epoch
- of index `epochs` is reached.
- verbose: 'auto', 0, 1, or 2. Verbosity mode.
- 0 = silent, 1 = progress bar, 2 = one line per epoch.
- 'auto' defaults to 1 for most cases, but 2 when used with
- `ParameterServerStrategy`. Note that the progress bar is not
- particularly useful when logged to a file, so verbose=2 is
- recommended when not running interactively (eg, in a production
- environment).
- callbacks: List of `keras.callbacks.Callback` instances.
- List of callbacks to apply during training.
- See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
- and `tf.keras.callbacks.History` callbacks are created automatically
- and need not be passed into `model.fit`.
- `tf.keras.callbacks.ProgbarLogger` is created or not based on
- `verbose` argument to `model.fit`.
- Callbacks with batch-level calls are currently unsupported with
- `tf.distribute.experimental.ParameterServerStrategy`, and users are
- advised to implement epoch-level calls instead with an appropriate
- `steps_per_epoch` value.
- validation_split: Float between 0 and 1.
- Fraction of the training data to be used as validation data.
- The model will set apart this fraction of the training data,
- will not train on it, and will evaluate
- the loss and any model metrics
- on this data at the end of each epoch.
- The validation data is selected from the last samples
- in the `x` and `y` data provided, before shuffling. This argument is
- not supported when `x` is a dataset, generator or
- `keras.utils.Sequence` instance.
- If both `validation_data` and `validation_split` are provided,
- `validation_data` will override `validation_split`.
- `validation_split` is not yet supported with
- `tf.distribute.experimental.ParameterServerStrategy`.
- validation_data: Data on which to evaluate
- the loss and any model metrics at the end of each epoch.
- The model will not be trained on this data. Thus, note the fact
- that the validation loss of data provided using `validation_split`
- or `validation_data` is not affected by regularization layers like
- noise and dropout.
- `validation_data` will override `validation_split`.
- `validation_data` could be:
- - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
- - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
- - A `tf.data.Dataset`.
- - A Python generator or `keras.utils.Sequence` returning
- `(inputs, targets)` or `(inputs, targets, sample_weights)`.
- `validation_data` is not yet supported with
- `tf.distribute.experimental.ParameterServerStrategy`.
- shuffle: Boolean (whether to shuffle the training data
- before each epoch) or str (for 'batch'). This argument is ignored
- when `x` is a generator or an object of tf.data.Dataset.
- 'batch' is a special option for dealing
- with the limitations of HDF5 data; it shuffles in batch-sized
- chunks. Has no effect when `steps_per_epoch` is not `None`.
- class_weight: Optional dictionary mapping class indices (integers)
- to a weight (float) value, used for weighting the loss function
- (during training only).
- This can be useful to tell the model to
- "pay more attention" to samples from
- an under-represented class.
- sample_weight: Optional Numpy array of weights for
- the training samples, used for weighting the loss function
- (during training only). You can either pass a flat (1D)
- Numpy array with the same length as the input samples
- (1:1 mapping between weights and samples),
- or in the case of temporal data,
- you can pass a 2D array with shape
- `(samples, sequence_length)`,
- to apply a different weight to every timestep of every sample. This
- argument is not supported when `x` is a dataset, generator, or
- `keras.utils.Sequence` instance, instead provide the sample_weights
- as the third element of `x`.
- initial_epoch: Integer.
- Epoch at which to start training
- (useful for resuming a previous training run).
- steps_per_epoch: Integer or `None`.
- Total number of steps (batches of samples)
- before declaring one epoch finished and starting the
- next epoch. When training with input tensors such as
- TensorFlow data tensors, the default `None` is equal to
- the number of samples in your dataset divided by
- the batch size, or 1 if that cannot be determined. If x is a
- `tf.data` dataset, and 'steps_per_epoch'
- is None, the epoch will run until the input dataset is exhausted.
- When passing an infinitely repeating dataset, you must specify the
- `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
- will run indefinitely with an infinitely repeating dataset.
- This argument is not supported with array inputs.
- When using `tf.distribute.experimental.ParameterServerStrategy`:
- * `steps_per_epoch=None` is not supported.
- validation_steps: Only relevant if `validation_data` is provided and
- is a `tf.data` dataset. Total number of steps (batches of
- samples) to draw before stopping when performing validation
- at the end of every epoch. If 'validation_steps' is None, validation
- will run until the `validation_data` dataset is exhausted. In the
- case of an infinitely repeated dataset, it will run into an
- infinite loop. If 'validation_steps' is specified and only part of
- the dataset will be consumed, the evaluation will start from the
- beginning of the dataset at each epoch. This ensures that the same
- validation samples are used every time.
- validation_batch_size: Integer or `None`.
- Number of samples per validation batch.
- If unspecified, will default to `batch_size`.
- Do not specify the `validation_batch_size` if your data is in the
- form of datasets, generators, or `keras.utils.Sequence` instances
- (since they generate batches).
- validation_freq: Only relevant if validation data is provided. Integer
- or `collections.abc.Container` instance (e.g. list, tuple, etc.).
- If an integer, specifies how many training epochs to run before a
- new validation run is performed, e.g. `validation_freq=2` runs
- validation every 2 epochs. If a Container, specifies the epochs on
- which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
- validation at the end of the 1st, 2nd, and 10th epochs.
- max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
- input only. Maximum size for the generator queue.
- If unspecified, `max_queue_size` will default to 10.
- workers: Integer. Used for generator or `keras.utils.Sequence` input
- only. Maximum number of processes to spin up
- when using process-based threading. If unspecified, `workers`
- will default to 1.
- use_multiprocessing: Boolean. Used for generator or
- `keras.utils.Sequence` input only. If `True`, use process-based
- threading. If unspecified, `use_multiprocessing` will default to
- `False`. Note that because this implementation relies on
- multiprocessing, you should not pass non-picklable arguments to
- the generator as they can't be passed easily to children processes.
- Unpacking behavior for iterator-like inputs:
- A common pattern is to pass a tf.data.Dataset, generator, or
- tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
- yield not only features (x) but optionally targets (y) and sample weights.
- Keras requires that the output of such iterator-likes be unambiguous. The
- iterator should return a tuple of length 1, 2, or 3, where the optional
- second and third elements will be used for y and sample_weight
- respectively. Any other type provided will be wrapped in a length one
- tuple, effectively treating everything as 'x'. When yielding dicts, they
- should still adhere to the top-level tuple structure.
- e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
- features, targets, and weights from the keys of a single dict.
- A notable unsupported data type is the namedtuple. The reason is that
- it behaves like both an ordered datatype (tuple) and a mapping
- datatype (dict). So given a namedtuple of the form:
- `namedtuple("example_tuple", ["y", "x"])`
- it is ambiguous whether to reverse the order of the elements when
- interpreting the value. Even worse is a tuple of the form:
- `namedtuple("other_tuple", ["x", "y", "z"])`
- where it is unclear if the tuple was intended to be unpacked into x, y,
- and sample_weight or passed through as a single element to `x`. As a
- result the data processing code will simply raise a ValueError if it
- encounters a namedtuple. (Along with instructions to remedy the issue.)
- Returns:
- A `History` object. Its `History.history` attribute is
- a record of training loss values and metrics values
- at successive epochs, as well as validation loss values
- and validation metrics values (if applicable).
- Raises:
- RuntimeError: 1. If the model was never compiled or,
- 2. If `model.fit` is wrapped in `tf.function`.
- ValueError: In case of mismatch between the provided input data
- and what the model expects or when the input data is empty.
- """
- base_layer.keras_api_gauge.get_cell('fit').set(True)
- # Legacy graph support is contained in `training_v1.Model`.
- version_utils.disallow_legacy_graph('Model', 'fit')
- self._assert_compile_was_called()
- self._check_call_args('fit')
- _disallow_inside_tf_function('fit')
-
- verbose = _get_verbosity(verbose, self.distribute_strategy)
-
- if validation_split and validation_data is None:
- # Create the validation data using the training data. Only supported for
- # `Tensor` and `NumPy` input.
- (x, y, sample_weight), validation_data = (
- data_adapter.train_validation_split(
- (x, y, sample_weight), validation_split=validation_split))
-
- if validation_data:
- val_x, val_y, val_sample_weight = (
- data_adapter.unpack_x_y_sample_weight(validation_data))
-
- if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access
- self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
- self.distribute_strategy)
-
- with self.distribute_strategy.scope(), \
- training_utils.RespectCompiledTrainableState(self):
- # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
- data_handler = data_adapter.get_data_handler(
- x=x,
- y=y,
- sample_weight=sample_weight,
- batch_size=batch_size,
- steps_per_epoch=steps_per_epoch,
- initial_epoch=initial_epoch,
- epochs=epochs,
- shuffle=shuffle,
- class_weight=class_weight,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- model=self,
- steps_per_execution=self._steps_per_execution)
-
- # Container that configures and calls `tf.keras.Callback`s.
- if not isinstance(callbacks, callbacks_module.CallbackList):
- callbacks = callbacks_module.CallbackList(
- callbacks,
- add_history=True,
- add_progbar=verbose != 0,
- model=self,
- verbose=verbose,
- epochs=epochs,
- steps=data_handler.inferred_steps)
-
- self.stop_training = False
- self.train_function = self.make_train_function()
- self._train_counter.assign(0)
- callbacks.on_train_begin()
- training_logs = None
- # Handle fault-tolerance for multi-worker.
- # TODO(omalleyt): Fix the ordering issues that mean this has to
- # happen after `callbacks.on_train_begin`.
- data_handler._initial_epoch = ( # pylint: disable=protected-access
- self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
- logs = None
- for epoch, iterator in data_handler.enumerate_epochs():
- self.reset_metrics()
- callbacks.on_epoch_begin(epoch)
- with data_handler.catch_stop_iteration():
- data_handler._initial_step = self._maybe_load_initial_step_from_ckpt() # pylint: disable=protected-access
- for step in data_handler.steps():
- with tf.profiler.experimental.Trace(
- 'train',
- epoch_num=epoch,
- step_num=step,
- batch_size=batch_size,
- _r=1):
- callbacks.on_train_batch_begin(step)
- tmp_logs = self.train_function(iterator)
- if data_handler.should_sync:
- context.async_wait()
- logs = tmp_logs # No error, now safe to assign to logs.
- end_step = step + data_handler.step_increment
- callbacks.on_train_batch_end(end_step, logs)
- if self.stop_training:
- break
-
- logs = tf_utils.sync_to_numpy_or_python_type(logs)
- if logs is None:
- raise ValueError('Unexpected result of `train_function` '
- '(Empty logs). Please use '
- '`Model.compile(..., run_eagerly=True)`, or '
- '`tf.config.run_functions_eagerly(True)` for more '
- 'information of where went wrong, or file a '
- 'issue/bug to `tf.keras`.')
- epoch_logs = copy.copy(logs)
-
- # Run validation.
- if validation_data and self._should_eval(epoch, validation_freq):
- # Create data_handler for evaluation and cache it.
- if getattr(self, '_eval_data_handler', None) is None:
- self._eval_data_handler = data_adapter.get_data_handler(
- x=val_x,
- y=val_y,
- sample_weight=val_sample_weight,
- batch_size=validation_batch_size or batch_size,
- steps_per_epoch=validation_steps,
- initial_epoch=0,
- epochs=1,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- model=self,
- steps_per_execution=self._steps_per_execution)
- val_logs = self.evaluate(
- x=val_x,
- y=val_y,
- sample_weight=val_sample_weight,
- batch_size=validation_batch_size or batch_size,
- steps=validation_steps,
- callbacks=callbacks,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- return_dict=True,
- _use_cached_eval_dataset=True)
- val_logs = {'val_' + name: val for name, val in val_logs.items()}
- epoch_logs.update(val_logs)
-
- callbacks.on_epoch_end(epoch, epoch_logs)
- training_logs = epoch_logs
- if self.stop_training:
- break
-
- if isinstance(self.optimizer, optimizer_experimental.Optimizer):
- self.optimizer.finalize_variable_values(self.trainable_variables)
-
- # If eval data_handler exists, delete it after all epochs are done.
- if getattr(self, '_eval_data_handler', None) is not None:
- del self._eval_data_handler
- callbacks.on_train_end(logs=training_logs)
- return self.history
- @traceback_utils.filter_traceback
- def evaluate(self,
- x=None,
- y=None,
- batch_size=None,
- verbose='auto',
- sample_weight=None,
- steps=None,
- callbacks=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- return_dict=False,
- **kwargs):
- """Returns the loss value & metrics values for the model in test mode.
- Computation is done in batches (see the `batch_size` arg.)
- Args:
- x: Input data. It could be:
- - A Numpy array (or array-like), or a list of arrays
- (in case the model has multiple inputs).
- - A TensorFlow tensor, or a list of tensors
- (in case the model has multiple inputs).
- - A dict mapping input names to the corresponding array/tensors,
- if the model has named inputs.
- - A `tf.data` dataset. Should return a tuple
- of either `(inputs, targets)` or
- `(inputs, targets, sample_weights)`.
- - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
- or `(inputs, targets, sample_weights)`.
- A more detailed description of unpacking behavior for iterator types
- (Dataset, generator, Sequence) is given in the `Unpacking behavior
- for iterator-like inputs` section of `Model.fit`.
- y: Target data. Like the input data `x`, it could be either Numpy
- array(s) or TensorFlow tensor(s). It should be consistent with `x`
- (you cannot have Numpy inputs and tensor targets, or inversely). If
- `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
- should not be specified (since targets will be obtained from the
- iterator/dataset).
- batch_size: Integer or `None`. Number of samples per batch of
- computation. If unspecified, `batch_size` will default to 32. Do not
- specify the `batch_size` if your data is in the form of a dataset,
- generators, or `keras.utils.Sequence` instances (since they generate
- batches).
- verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
- 0 = silent, 1 = progress bar, 2 = single line.
- `"auto"` defaults to 1 for most cases, and to 2 when used with
- `ParameterServerStrategy`. Note that the progress bar is not
- particularly useful when logged to a file, so `verbose=2` is
- recommended when not running interactively (e.g. in a production
- environment).
- sample_weight: Optional Numpy array of weights for the test samples,
- used for weighting the loss function. You can either pass a flat (1D)
- Numpy array with the same length as the input samples
- (1:1 mapping between weights and samples), or in the case of
- temporal data, you can pass a 2D array with shape `(samples,
- sequence_length)`, to apply a different weight to every timestep
- of every sample. This argument is not supported when `x` is a
- dataset, instead pass sample weights as the third element of `x`.
- steps: Integer or `None`. Total number of steps (batches of samples)
- before declaring the evaluation round finished. Ignored with the
- default value of `None`. If x is a `tf.data` dataset and `steps` is
- None, 'evaluate' will run until the dataset is exhausted. This
- argument is not supported with array inputs.
- callbacks: List of `keras.callbacks.Callback` instances. List of
- callbacks to apply during evaluation. See
- [callbacks](/api_docs/python/tf/keras/callbacks).
- max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
- input only. Maximum size for the generator queue. If unspecified,
- `max_queue_size` will default to 10.
- workers: Integer. Used for generator or `keras.utils.Sequence` input
- only. Maximum number of processes to spin up when using process-based
- threading. If unspecified, `workers` will default to 1.
- use_multiprocessing: Boolean. Used for generator or
- `keras.utils.Sequence` input only. If `True`, use process-based
- threading. If unspecified, `use_multiprocessing` will default to
- `False`. Note that because this implementation relies on
- multiprocessing, you should not pass non-picklable arguments to the
- generator as they can't be passed easily to children processes.
- return_dict: If `True`, loss and metric results are returned as a dict,
- with each key being the name of the metric. If `False`, they are
- returned as a list.
- **kwargs: Unused at this time.
- See the discussion of `Unpacking behavior for iterator-like inputs` for
- `Model.fit`.
- Returns:
- Scalar test loss (if the model has a single output and no metrics)
- or list of scalars (if the model has multiple outputs
- and/or metrics). The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
- Raises:
- RuntimeError: If `model.evaluate` is wrapped in a `tf.function`.
- """
- base_layer.keras_api_gauge.get_cell('evaluate').set(True)
- version_utils.disallow_legacy_graph('Model', 'evaluate')
- self._assert_compile_was_called()
- self._check_call_args('evaluate')
- self._check_sample_weight_warning(x, sample_weight)
- _disallow_inside_tf_function('evaluate')
- use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False)
- if kwargs:
- raise TypeError(f'Invalid keyword arguments: {list(kwargs.keys())}')
-
- if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access
- self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
- self.distribute_strategy)
-
- verbose = _get_verbosity(verbose, self.distribute_strategy)
- with self.distribute_strategy.scope():
- # Use cached evaluation data only when it's called in `Model.fit`
- if (use_cached_eval_dataset
- and getattr(self, '_eval_data_handler', None) is not None):
- data_handler = self._eval_data_handler
- else:
- # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
- data_handler = data_adapter.get_data_handler(
- x=x,
- y=y,
- sample_weight=sample_weight,
- batch_size=batch_size,
- steps_per_epoch=steps,
- initial_epoch=0,
- epochs=1,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- model=self,
- steps_per_execution=self._steps_per_execution)
-
- # Container that configures and calls `tf.keras.Callback`s.
- if not isinstance(callbacks, callbacks_module.CallbackList):
- callbacks = callbacks_module.CallbackList(
- callbacks,
- add_history=True,
- add_progbar=verbose != 0,
- model=self,
- verbose=verbose,
- epochs=1,
- steps=data_handler.inferred_steps)
-
- logs = {}
- self.test_function = self.make_test_function()
- self._test_counter.assign(0)
- callbacks.on_test_begin()
- for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
- self.reset_metrics()
- with data_handler.catch_stop_iteration():
- for step in data_handler.steps():
- with tf.profiler.experimental.Trace('test', step_num=step, _r=1):
- callbacks.on_test_batch_begin(step)
- tmp_logs = self.test_function(iterator)
- if data_handler.should_sync:
- context.async_wait()
- logs = tmp_logs # No error, now safe to assign to logs.
- end_step = step + data_handler.step_increment
- callbacks.on_test_batch_end(end_step, logs)
- logs = tf_utils.sync_to_numpy_or_python_type(logs)
- callbacks.on_test_end(logs=logs)
-
- if return_dict:
- return logs
- else:
- return flatten_metrics_in_order(logs, self.metrics_names)