tf.compat.v1.estimator.tpu.TPUEstimator(
    model_fn=None,
    model_dir=None,
    config=None,
    params=None,
    use_tpu=True,
    train_batch_size=None,
    eval_batch_size=None,
    predict_batch_size=None,
    batch_axis=None,
    eval_on_tpu=True,
    export_to_tpu=True,
    export_to_cpu=True,
    warm_start_from=None,
    embedding_config_spec=None,
    export_saved_model_api_version=ExportSavedModelApiVersion.V1
)
Args
model_fn: Model function, as required by Estimator, that returns an EstimatorSpec or a TPUEstimatorSpec. The training_hooks, evaluation_hooks, and prediction_hooks must not capture any TPU Tensor created inside model_fn.
model_dir: Directory in which to save model parameters, the graph, and so on. It can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. If None, the model_dir in config is used if set. If both are set, they must be the same. If both are None, a temporary directory is used.
config: A tpu_config.RunConfig configuration object. Cannot be None.
params: An optional dict of hyperparameters that is passed to input_fn and model_fn. Keys are parameter names and values are basic Python types. Some keys are reserved for TPUEstimator, including 'batch_size'.
use_tpu: A bool indicating whether TPU support is enabled. Currently, TPU training and evaluation respect this flag, but eval_on_tpu can override where evaluation runs; see eval_on_tpu below.
train_batch_size: An int representing the global training batch size. TPUEstimator converts this global batch size to a per-shard batch size, passed as params['batch_size'], when calling input_fn and model_fn. Cannot be None if use_tpu is True. Must be divisible by the total number of replicas.
eval_batch_size: An int representing the evaluation batch size. Must be divisible by the total number of replicas.
predict_batch_size: An int representing the prediction batch size. Must be divisible by the total number of replicas.
batch_axis: A Python tuple of int values describing how each tensor produced by the Estimator input_fn should be split across the TPU compute shards. For example, if your input_fn produces (images, labels) where the images tensor is in HWCN format, your shard dimensions would be [3, 0], where 3 is the N dimension of the images tensor and 0 is the dimension along which to split the labels so that they match the corresponding images. If None is supplied and per_host_input_for_training is True, batches are sharded along the major dimension. If tpu_config.per_host_input_for_training is False or PER_HOST_V2, batch_axis is ignored.
eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, model_fn must return an EstimatorSpec when called with mode set to EVAL.
export_to_tpu: If True, export_saved_model() exports a metagraph for serving on TPU. Export modes that are not supported on TPU, such as EVAL, are ignored; for those modes only a CPU model is exported. Currently, export_to_tpu supports only PREDICT.
export_to_cpu: If True, export_saved_model() exports a metagraph for serving on CPU.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a tf.estimator.WarmStartSettings object to fully configure warm-starting. If a string filepath is provided instead of a WarmStartSettings object, all variables are warm-started, and vocabularies and Tensor names are assumed to be unchanged.
embedding_config_spec: Optional EmbeddingConfigSpec instance to enable TPU embedding support.
export_saved_model_api_version: An integer, 1 or 2, corresponding to V1 and V2 respectively (defaults to V1). With V1, export_saved_model() adds rewrite() and TPUPartitionedCallOp() for the user; with V2, the user is expected to add rewrite(), TPUPartitionedCallOp(), and so on in their model_fn.
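The model_dir precedence rules can be summarized in a few lines of Python. This is a minimal sketch, not TPUEstimator's implementation: resolve_model_dir is a hypothetical helper, config_model_dir stands in for config.model_dir, and plain string comparison is an assumption (the real library may normalize paths before comparing them).

```python
import tempfile
from typing import Optional


def resolve_model_dir(model_dir: Optional[str], config_model_dir: Optional[str]) -> str:
    """Hypothetical helper mirroring the documented model_dir rules."""
    if model_dir is not None and config_model_dir is not None:
        # Both set: they must agree (compared as plain strings in this sketch).
        if model_dir != config_model_dir:
            raise ValueError(
                f"model_dir {model_dir!r} does not match config.model_dir {config_model_dir!r}"
            )
        return model_dir
    if model_dir is not None:
        return model_dir
    if config_model_dir is not None:
        # model_dir is None: fall back to the directory set in the RunConfig.
        return config_model_dir
    # Neither is set: use a fresh temporary directory.
    return tempfile.mkdtemp()


print(resolve_model_dir(None, "/tmp/run1"))  # falls back to the config value
```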
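The batch-size arithmetic behind train_batch_size, and the divisibility requirement shared by eval_batch_size and predict_batch_size, can be sketched as follows. per_shard_batch_size and build_params are hypothetical helpers. The sketch assumes one shard per replica and ignores input-pipeline details such as per_host_input_for_training, which can change how the batch is divided in practice. Rejecting a user-supplied 'batch_size' key reflects the reserved-key rule in the params description.

```python
def per_shard_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Hypothetical helper: split a global batch size evenly across replicas."""
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {num_replicas} replicas"
        )
    return global_batch_size // num_replicas


def build_params(user_params: dict, global_batch_size: int, num_replicas: int) -> dict:
    """Hypothetical helper: the params dict that input_fn/model_fn would receive."""
    # 'batch_size' is reserved, so user code must not set it.
    if "batch_size" in user_params:
        raise ValueError("'batch_size' is a reserved params key")
    params = dict(user_params)
    params["batch_size"] = per_shard_batch_size(global_batch_size, num_replicas)
    return params


print(build_params({"learning_rate": 0.1}, global_batch_size=1024, num_replicas=8))
```

Under these assumptions, train_batch_size=1024 on 8 replicas means input_fn and model_fn would see params['batch_size'] == 128.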
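To make the batch_axis example concrete, the sketch below computes the per-shard shape of each tensor produced by input_fn for a given batch_axis. shard_shapes is a hypothetical helper that only does shape arithmetic: it assumes an even split and does not split any actual tensors.

```python
from typing import List, Sequence, Tuple


def shard_shapes(
    shapes: Sequence[Tuple[int, ...]],
    batch_axis: Sequence[int],
    num_shards: int,
) -> List[Tuple[int, ...]]:
    """Hypothetical helper: per-shard shape of each input tensor given batch_axis."""
    if len(shapes) != len(batch_axis):
        raise ValueError("batch_axis needs one entry per tensor")
    result = []
    for shape, axis in zip(shapes, batch_axis):
        # Each tensor is divided evenly along its own batch dimension.
        if shape[axis] % num_shards != 0:
            raise ValueError(f"dim {axis} of {shape} is not divisible by {num_shards}")
        per_shard = list(shape)
        per_shard[axis] = shape[axis] // num_shards
        result.append(tuple(per_shard))
    return result


# HWCN images are split along N (axis 3); labels are split along axis 0.
print(shard_shapes([(224, 224, 3, 1024), (1024,)], (3, 0), 8))
# → [(224, 224, 3, 128), (128,)]
```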
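The interaction between use_tpu and eval_on_tpu can be expressed as a small decision function. This is a hypothetical illustration of the documented behavior for training and evaluation only; execution_target is not part of TPUEstimator, and prediction is deliberately left out because the arguments above do not spell out where it runs.

```python
def execution_target(mode: str, use_tpu: bool, eval_on_tpu: bool = True) -> str:
    """Hypothetical helper: where training or evaluation runs, per the argument docs."""
    if mode == "train":
        # Training respects use_tpu directly.
        return "TPU" if use_tpu else "CPU/GPU"
    if mode == "eval":
        # Evaluation respects use_tpu, but eval_on_tpu=False overrides it.
        # In that case model_fn must return an EstimatorSpec for EVAL.
        return "TPU" if (use_tpu and eval_on_tpu) else "CPU/GPU"
    raise ValueError(f"mode {mode!r} is not covered by this sketch")


print(execution_target("eval", use_tpu=True, eval_on_tpu=False))  # CPU/GPU
```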