Tuner class
keras_tuner.Tuner(
    oracle,
    hypermodel=None,
    max_model_size=None,
    optimizer=None,
    loss=None,
    metrics=None,
    distribution_strategy=None,
    directory=None,
    project_name=None,
    logger=None,
    tuner_id=None,
    overwrite=False,
    executions_per_trial=1,
    **kwargs
)
Tuner class for Keras models.
This is the base Tuner class for all tuners for Keras models. It manages
the building, training, evaluation and saving of the Keras models. New
tuners can be created by subclassing the class.
All Keras-related logic lives in Tuner.run_trial() and its subroutines.
A subclass that overrides run_trial() without calling super().run_trial()
can therefore tune anything, not only Keras models.
Arguments
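oracle: Instance of the Oracle class.
hypermodel: Instance of the HyperModel class (or a callable that takes hyperparameters and returns a Model instance). It is optional when Tuner.run_trial() is overridden and does not use self.hypermodel.
max_model_size: Integer, maximum number of scalars in the parameters of a model. Models larger than this are rejected.
optimizer: Optional optimizer. It is used to override the optimizer argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
loss: Optional loss. May be used to override the loss argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
metrics: Optional metrics. May be used to override the metrics argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
distribution_strategy: Optional instance of tf.distribute.Strategy. If specified, each trial will run under this scope. For example, tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) will run each trial on two GPUs. Currently only single-worker strategies are supported.
directory: A string, the relative path to the working directory.
project_name: A string, the name to use as prefix for files saved by this Tuner.
tuner_id: Optional string, used as the ID of this Tuner.
overwrite: Boolean, defaults to False. If False, reloads an existing project of the same name if one is found. Otherwise, overwrites the project.
executions_per_trial: Integer, the number of executions (training a model from scratch, starting from a new initialization) to run per trial (model configuration).
**kwargs: Arguments for BaseTuner.
Attributes
remaining_trials: Number of trials remaining, None if max_trials is not set. This is useful when resuming a previously stopped search.
get_best_hyperparameters method
Tuner.get_best_hyperparameters(num_trials=1)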
Returns the best hyperparameters, as determined by the objective.
This method can be used to reinstantiate the (untrained) best model found during the search process.
Example
best_hp = tuner.get_best_hyperparameters()[0]
model = tuner.hypermodel.build(best_hp)
Arguments
num_trials: Optional number of HyperParameters objects to return.
Returns
List of HyperParameters objects sorted from the best to the worst.
get_best_models method
Tuner.get_best_models(num_models=1)
Returns the best model(s), as determined by the tuner's objective.
The models are loaded with the weights corresponding to their best checkpoint (at the end of the best epoch of the best trial).
This method is for querying the models trained during the search.
For best performance, it is recommended to retrain your Model on the
full dataset using the best hyperparameters found during search,
which can be obtained using tuner.get_best_hyperparameters().
Arguments
num_models: Optional number of best models to return. Defaults to 1.
Returns
List of trained model instances sorted from the best to the worst.
get_state method
Tuner.get_state()
Returns the current state of this object.
This method is called during save.
Returns
A dictionary of serializable objects as the state.
load_model method
Tuner.load_model(trial)
Loads a Model from a given trial.
For models that report intermediate results to the Oracle, load_model
should generally load the best reported step by relying on
trial.best_step.
Arguments
trial: A Trial instance, the Trial corresponding to the model to load.
on_epoch_begin method
Tuner.on_epoch_begin(trial, model, epoch, logs=None)
Called at the beginning of an epoch.
Arguments
trial: A Trial instance.
model: A Keras Model.
epoch: The current epoch number.
logs: Additional metrics.
on_batch_begin method
Tuner.on_batch_begin(trial, model, batch, logs)
Called at the beginning of a batch.
Arguments
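trial: A Trial instance.
model: A Keras Model.
batch: The current batch number within the current epoch.
logs: Additional metrics.
on_batch_end method
Tuner.on_batch_end(trial, model, batch, logs=None)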
Called at the end of a batch.
Arguments
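trial: A Trial instance.
model: A Keras Model.
batch: The current batch number within the current epoch.
logs: Additional metrics.
on_epoch_end method
Tuner.on_epoch_end(trial, model, epoch, logs=None)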
Called at the end of an epoch.
Arguments
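trial: A Trial instance.
model: A Keras Model.
epoch: The current epoch number.
logs: Dict. Metrics for this epoch. This should include the value of the objective for this epoch.
run_trial method
Tuner.run_trial(trial, *args, **kwargs)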
Evaluates a set of hyperparameter values.
This method is called multiple times during search to build and
evaluate the models with different hyperparameters and return the
objective value.
Example
You can use it with self.hypermodel to build and fit the model.
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    model = self.hypermodel.build(hp)
    return self.hypermodel.fit(hp, model, *args, **kwargs)
You can also use it as a black-box optimizer for anything.
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    x = hp.Float("x", -2.0, 2.0)
    y = x * x + 2 * x + 1
    return y
Arguments
trial: A Trial instance that contains the information needed to run this trial. Hyperparameters can be accessed via trial.hyperparameters.
*args: Positional arguments passed by search.
**kwargs: Keyword arguments passed by search.
Returns
A History object (the return value of model.fit()), a dictionary, a
float, or a list of any of these types.
If a dictionary is returned, it should contain the metrics to track.
The keys are the metric names, which must include the objective name;
the values are the metric values.
If a float is returned, it is used as the objective value.
If the model is evaluated multiple times, you may return a list of results of any of the types above. The final objective value is the average of the results in the list.
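To make the return-value rules concrete, here is a small pure-Python sketch of how a float, a dictionary, or a list of them could be reduced to a single objective value. The helper objective_from_result is hypothetical and is not Keras Tuner's implementation; History objects are left out to keep the example free of dependencies.

```python
def objective_from_result(result, objective_name):
    """Reduce a run_trial-style return value to one objective value (illustrative)."""
    if isinstance(result, list):
        # A list holds several evaluations: average their objective values.
        values = [objective_from_result(r, objective_name) for r in result]
        return sum(values) / len(values)
    if isinstance(result, dict):
        # A dictionary maps metric names to values and must contain the objective.
        return float(result[objective_name])
    # Anything else is assumed to be a plain number: the objective value itself.
    return float(result)


single = objective_from_result(0.5, "val_loss")
from_dict = objective_from_result({"val_loss": 0.25, "val_acc": 0.9}, "val_loss")
averaged = objective_from_result([0.25, {"val_loss": 0.75}], "val_loss")
print(single, from_dict, averaged)
```

The list case accepts mixed element types, mirroring the statement that a list may contain results of any of the supported types.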
results_summary method
Tuner.results_summary(num_trials=10)
Display tuning results summary.
The method prints a summary of the search results including the hyperparameter values and evaluation results for each trial.
Arguments
num_trials: Optional number of trials to display. Defaults to 10.
save_model method
Tuner.save_model(trial_id, model, step=0)
Saves a Model for a given trial.
Arguments
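trial_id: The ID of the Trial corresponding to this Model.
model: The trained model.
step: Integer, for models that report intermediate results to the Oracle, the step the saved file corresponds to. For example, for Keras models this is the number of epochs trained.
search method
Tuner.search(*fit_args, **fit_kwargs)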
Performs a search for the best hyperparameter configurations.
Arguments
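*fit_args: Positional arguments that should be passed to run_trial, for example the training and validation data.
**fit_kwargs: Keyword arguments that should be passed to run_trial, for example the training and validation data.
search_space_summary method
Tuner.search_space_summary(extended=False)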
Prints a summary of the search space.
The method prints a summary of the hyperparameters in the search
space. It can be called before calling the search method.
Arguments
extended: Optional boolean, whether to display an extended summary. Defaults to False.
set_state method
Tuner.set_state(state)
Sets the current state of this object.
This method is called during reload.
Arguments
state: A dictionary of serialized objects as the state to restore.