Early Stopping
This notebook will allow a developer to:
Understand the Early Stopping API.
Write Pythia policies for early stopping.
Installation and reference imports
!pip install google-vizier
import numpy as np
from vizier import pythia
Early Stopping
In hyperparameter optimization, early stopping saves resources by halting unpromising trials before they finish. Two main considerations for deciding whether to stop an active trial are:
At a macro level, how a trial’s performance compares globally with the rest of the trials. For example, we may stop a trial if it is predicted to significantly underperform the trials seen so far in the study.
At a micro level, how a trial’s intermediate measurements change over time. For example, in a classification task, validation accuracy that starts to decrease while training continues may indicate overfitting.
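Both considerations can be turned into simple rules. The sketch below is plain Python and independent of Vizier. macro_should_stop and micro_should_stop are hypothetical helpers, both assume higher metric values are better, and the micro rule is the common "patience" heuristic, not something Vizier prescribes.

```python
from typing import Sequence


def macro_should_stop(
    trial_value: float, peer_values: Sequence[float], quantile: float = 0.1
) -> bool:
  """Macro rule: stop if the trial ranks in the bottom `quantile` of its peers.

  Assumes higher metric values are better.
  """
  if not peer_values:
    return False  # Nothing to compare against yet.
  ranked = sorted(peer_values)
  # Simple lower-rank quantile: the value at position floor(quantile * (n - 1)).
  cutoff = ranked[int(quantile * (len(ranked) - 1))]
  return trial_value < cutoff


def micro_should_stop(curve: Sequence[float], patience: int = 3) -> bool:
  """Micro rule ("patience"): stop if the last `patience` measurements
  never beat the best value seen before them."""
  if len(curve) <= patience:
    return False  # Not enough history to judge a trend.
  best_before = max(curve[:-patience])
  return max(curve[-patience:]) <= best_before


# A trial well below its peers is flagged by the macro rule.
print(macro_should_stop(0.1, [0.2, 0.5, 0.6, 0.7, 0.9]))  # True
# Validation accuracy peaked and then declined for three steps.
print(micro_should_stop([0.50, 0.60, 0.70, 0.68, 0.66, 0.65]))  # True
```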
API
Based on these considerations, and to give policies full flexibility in deciding when to stop a trial, Pythia uses the following API (abridged below). The exact class entry point can be found here.
The EarlyStopRequest takes in a set of trial IDs to consider for early stopping. Note, however, that a policy may also stop trials outside this set.
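class EarlyStopRequest:
  """Early stopping request."""
  trial_ids: Optional[FrozenSet[int]]
  # Other fields omitted; the example below also reads `study_config` and
  # `study_guid` from the request.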
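A hypothetical stand-in can make the request's shape concrete. In the sketch below, ToyEarlyStopRequest and ids_to_consider are illustrative names, not Vizier APIs. Only the trial_ids field is mirrored, and a trial_ids of None is assumed to mean an empty request (a conservative choice, not documented Vizier behavior).

```python
import dataclasses
from typing import FrozenSet, Iterable, Optional


@dataclasses.dataclass(frozen=True)
class ToyEarlyStopRequest:
  """Hypothetical stand-in mirroring only EarlyStopRequest.trial_ids."""
  trial_ids: Optional[FrozenSet[int]] = None


def ids_to_consider(
    request: ToyEarlyStopRequest, active_ids: Iterable[int]
) -> list[int]:
  """Returns the requested active trial IDs, in sorted order.

  Assumption: trial_ids=None is treated as an empty request.
  """
  requested = request.trial_ids or frozenset()
  return [trial_id for trial_id in sorted(active_ids) if trial_id in requested]


print(ids_to_consider(ToyEarlyStopRequest(frozenset({2, 5, 9})), [1, 2, 3, 5]))  # [2, 5]
print(ids_to_consider(ToyEarlyStopRequest(), [3, 1]))  # []
```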
In addition, EarlyStopDecision denotes a single trial’s stopping decision, and the plural EarlyStopDecisions holds the decisions for a set of trials:
class EarlyStopDecision:
  """Stopping decision on a single trial."""
  id: int
  reason: str  # Free-text explanation, as passed in the example below.
  should_stop: bool

class EarlyStopDecisions:
  """This is the output of the Policy.early_stop() method."""
  decisions: list[EarlyStopDecision]
  metadata: vz.MetadataDelta
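To show how a set of per-trial scores becomes a decision list, the sketch below uses hypothetical stand-in dataclasses. ToyEarlyStopDecision, ToyEarlyStopDecisions, and decide are illustrative names, not Vizier classes, and the metadata field is omitted. The sketch emits one decision per judged trial, including explicit should_stop=False entries.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyEarlyStopDecision:
  """Hypothetical stand-in for a single trial's stopping decision."""
  id: int
  reason: str
  should_stop: bool = True


@dataclasses.dataclass
class ToyEarlyStopDecisions:
  """Hypothetical stand-in for the plural container (metadata omitted)."""
  decisions: list[ToyEarlyStopDecision] = dataclasses.field(default_factory=list)


def decide(scores: dict[int, float], cutoff: float) -> ToyEarlyStopDecisions:
  """Emits one decision per trial: stop it if its score is below `cutoff`."""
  out = ToyEarlyStopDecisions()
  for trial_id, score in sorted(scores.items()):
    stop = score < cutoff
    reason = 'Below cutoff' if stop else 'Above cutoff'
    out.decisions.append(ToyEarlyStopDecision(trial_id, reason, should_stop=stop))
  return out


result = decide({1: 0.9, 2: 0.1, 3: 0.5}, cutoff=0.4)
print([(d.id, d.should_stop) for d in result.decisions])
# [(1, False), (2, True), (3, False)]
```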
They are used by the Pythia policy’s early_stop method:
class Policy(abc.ABC):
  """Interface for Pythia2 Policy subclasses."""

  # suggest() is also declared abstract (omitted here).

  @abc.abstractmethod
  def early_stop(self, request: EarlyStopRequest) -> EarlyStopDecisions:
    """Decide which Trials Vizier should stop."""
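The abstract-method contract is what forces a concrete policy to provide early_stop. The sketch below uses a hypothetical, simplified stand-in: ToyPolicy, NeverStopPolicy, and IncompletePolicy are illustrative names, and the request and return types are reduced to a plain frozenset and dict. It shows that Python refuses to instantiate a subclass that leaves an abstract method unimplemented.

```python
import abc
from typing import FrozenSet


class ToyPolicy(abc.ABC):
  """Hypothetical stand-in for pythia.Policy, reduced to early stopping."""

  @abc.abstractmethod
  def early_stop(self, trial_ids: FrozenSet[int]) -> dict[int, bool]:
    """Maps each requested trial ID to whether it should be stopped."""


class NeverStopPolicy(ToyPolicy):
  """Baseline that keeps every requested trial running."""

  def early_stop(self, trial_ids: FrozenSet[int]) -> dict[int, bool]:
    return {trial_id: False for trial_id in sorted(trial_ids)}


class IncompletePolicy(ToyPolicy):
  """Does not override early_stop, so it cannot be instantiated."""


print(NeverStopPolicy().early_stop(frozenset({7, 3})))  # {3: False, 7: False}
try:
  IncompletePolicy()
except TypeError as err:
  print('TypeError:', err)
```

Because the real pythia.Policy also declares suggest() abstract, a policy that only does early stopping still has to define suggest(), even if it just raises NotImplementedError.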
Example usage
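As an example, suppose our rule is to stop every requested trial whose intermediate measurement at index 50 (zero-based) is too low, i.e. falls in the bottom 10% of all trials so far that have reached that step. The policy below assumes a single objective metric that is maximized.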
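class MyEarlyStoppingPolicy(pythia.Policy):
  """Stops a requested trial if its measurement at `index` is in the bottom 10%.

  Assumes a single objective metric that is maximized.
  """

  def __init__(self, policy_supporter: pythia.PolicySupporter, index: int = 50):
    self._policy_supporter = policy_supporter
    self._index = index  # Zero-based position in `trial.measurements`.

  def suggest(self, request: pythia.SuggestRequest) -> pythia.SuggestDecision:
    # `suggest` is also abstract in `pythia.Policy`; this policy only stops trials.
    raise NotImplementedError('This policy only implements early stopping.')

  def early_stop(
      self, request: pythia.EarlyStopRequest
  ) -> pythia.EarlyStopDecisions:
    metric_name = request.study_config.metric_information.item().name

    # Obtain the 10th-percentile cutoff over all trials that reached `index`.
    all_trials = self._policy_supporter.GetTrials(study_guid=request.study_guid)
    all_metrics = [
        # `metrics[name]` is a Metric object; `.value` is the float.
        trial.measurements[self._index].metrics[metric_name].value
        for trial in all_trials
        if len(trial.measurements) > self._index
    ]
    if not all_metrics:
      # No trial has reached `index` yet, so there is nothing to compare against.
      return pythia.EarlyStopDecisions(decisions=[])
    cutoff = np.percentile(all_metrics, 10)

    # Judge only requested trials that have reached `index`. `trial_ids` may be
    # None; this sketch treats that as an empty request.
    requested = request.trial_ids or frozenset()
    considered_trials = [
        trial
        for trial in all_trials
        if trial.id in requested and len(trial.measurements) > self._index
    ]
    stopping_decisions = []
    for trial in considered_trials:
      value = trial.measurements[self._index].metrics[metric_name].value
      if value < cutoff:
        decision = pythia.EarlyStopDecision(
            trial.id, reason='Below cutoff', should_stop=True
        )
      else:
        decision = pythia.EarlyStopDecision(
            trial.id, reason='Above cutoff', should_stop=False
        )
      stopping_decisions.append(decision)
    return pythia.EarlyStopDecisions(decisions=stopping_decisions)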
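Because the policy above needs a PolicySupporter and a live study, it can help to check the cutoff logic offline first. The sketch below is a hypothetical, Vizier-free replica (bottom_decile_stops is not part of Vizier) that stores each trial as a plain list of per-step metric values and assumes higher is better.

```python
import numpy as np


def bottom_decile_stops(
    curves: dict[int, list[float]], requested: set[int], index: int
) -> dict[int, bool]:
  """Vizier-free replica of the example policy's rule, for offline checks.

  `curves` maps a trial ID to its per-step metric values (higher is better).
  """
  # Cutoff over every trial that has a value at `index`.
  values = [curve[index] for curve in curves.values() if len(curve) > index]
  if not values:
    return {trial_id: False for trial_id in sorted(requested)}
  cutoff = np.percentile(values, 10)  # Linear interpolation by default.
  decisions = {}
  for trial_id in sorted(requested):
    curve = curves.get(trial_id, [])
    # Trials that have not reached `index` yet keep running.
    decisions[trial_id] = bool(len(curve) > index and curve[index] < cutoff)
  return decisions


curves = {
    1: [0.1, 0.2],
    2: [0.1, 0.5],
    3: [0.1, 0.6],
    4: [0.1, 0.7],
    5: [0.1, 0.8],
    6: [0.3],  # Has not reached index 1 yet.
}
# Cutoff = 0.2 + 0.4 * (0.5 - 0.2) = 0.32, so only trial 1 is stopped.
print(bottom_decile_stops(curves, {1, 2, 6}, index=1))  # {1: True, 2: False, 6: False}
```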