
Early Stopping

This notebook will allow a developer to:

  • Understand the Early Stopping API.

  • Write Pythia policies for early stopping.

Installation and reference imports

!pip install google-vizier
import numpy as np

from vizier import pythia

Early Stopping

In hyperparameter optimization, early stopping is a useful mechanism to prevent wasted resources by stopping unpromising trials. Two main considerations for determining whether to stop an active trial are:

  • At a macro level: how a trial’s performance compares with all other trials in the study. For example, we may stop a trial if it is predicted to significantly underperform the trials seen so far in the study.

  • At a micro level: how a trial’s own intermediate measurements change over time. For example, in a classification task, validation accuracy that starts to decrease as training continues may indicate overfitting.
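The micro-level signal can be made concrete with a simple patience rule. The sketch below is illustrative only: the function name and the `patience` parameter are hypothetical and not part of the Vizier API, and it assumes each trial reports one validation-accuracy value per intermediate measurement.

```python
from typing import Sequence


def micro_should_stop(val_accuracy: Sequence[float], patience: int = 3) -> bool:
  """Returns True if the last `patience` measurements never beat the earlier best.

  `val_accuracy` holds one validation-accuracy value per intermediate
  measurement, oldest first. (Hypothetical helper, not a Vizier API.)
  """
  if patience < 1:
    raise ValueError('patience must be at least 1')
  if len(val_accuracy) <= patience:
    # Too few measurements to compare against an earlier best.
    return False
  best_before = max(val_accuracy[:-patience])
  recent_best = max(val_accuracy[-patience:])
  return recent_best <= best_before


# Accuracy peaked at 0.70 and then declined for three measurements: stop.
print(micro_should_stop([0.50, 0.60, 0.70, 0.69, 0.68, 0.67]))  # True
# A new best (0.71) appeared within the patience window: keep going.
print(micro_should_stop([0.50, 0.60, 0.70, 0.69, 0.71, 0.70]))  # False
```

A rule like this only looks at one trial's curve; it says nothing about how that trial compares with the rest of the study, which is what the macro-level signal adds.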

API

Based on the above considerations, and to give policies full flexibility in deciding when to stop a trial, we use the abridged API below. The exact class entrypoint can be found here.

The EarlyStopRequest takes in a set of trial IDs for early stopping consideration. Note, however, that a policy may also stop trials outside this set.

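class EarlyStopRequest:
  """Early stopping request."""

  trial_ids: Optional[FrozenSet[int]]
  # The request also exposes `study_guid` and `study_config`, which the
  # example below uses to fetch trials and look up the metric name.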

In addition, EarlyStopDecision represents the stopping decision for a single trial, and EarlyStopDecisions collects the decisions for a set of trials:

class EarlyStopDecision:
  """Stopping decision on a single trial."""

  id: int
  reason: str
  should_stop: bool

class EarlyStopDecisions:
  """This is the output of the Policy.early_stop() method."""

  decisions: list[EarlyStopDecision]
  metadata: vz.MetadataDelta

These request and decision types are the input and output of the Pythia policy’s early_stop method:

class Policy(abc.ABC):
  """Interface for Pythia2 Policy subclasses."""

  @abc.abstractmethod
  def early_stop(self, request: EarlyStopRequest) -> EarlyStopDecisions:
    """Decide which Trials Vizier should stop."""

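To see how these pieces fit together without installing Vizier, here is a minimal sketch that re-declares the abridged classes above as plain dataclasses and implements a toy policy. The stand-in classes, the `MaxMeasurementsPolicy` name, and its `measurement_counts` argument are all hypothetical; the real classes live in `vizier.pythia` and carry more fields. The sketch also assumes that a `None` value for `trial_ids` means no trials were requested.

```python
import abc
import dataclasses
from typing import Dict, FrozenSet, List, Optional


# Hypothetical stand-ins mirroring the abridged API; not the real vizier.pythia classes.
@dataclasses.dataclass(frozen=True)
class EarlyStopRequest:
  trial_ids: Optional[FrozenSet[int]]


@dataclasses.dataclass(frozen=True)
class EarlyStopDecision:
  id: int
  reason: str
  should_stop: bool


@dataclasses.dataclass
class EarlyStopDecisions:
  decisions: List[EarlyStopDecision]


class Policy(abc.ABC):

  @abc.abstractmethod
  def early_stop(self, request: EarlyStopRequest) -> EarlyStopDecisions:
    """Decide which trials should stop."""


class MaxMeasurementsPolicy(Policy):
  """Hypothetical toy policy: stop a requested trial once it has reported
  `max_measurements` intermediate measurements."""

  def __init__(self, measurement_counts: Dict[int, int], max_measurements: int):
    # Maps trial ID -> number of measurements reported so far. A real policy
    # would fetch this through a PolicySupporter instead.
    self._counts = measurement_counts
    self._max = max_measurements

  def early_stop(self, request: EarlyStopRequest) -> EarlyStopDecisions:
    # Assumption of this sketch: None means "no trials requested".
    requested = request.trial_ids or frozenset()
    decisions = []
    for trial_id in sorted(requested):
      done = self._counts.get(trial_id, 0) >= self._max
      decisions.append(
          EarlyStopDecision(
              id=trial_id,
              reason='Reached measurement budget' if done else 'Within budget',
              should_stop=done,
          )
      )
    return EarlyStopDecisions(decisions=decisions)


policy = MaxMeasurementsPolicy({1: 10, 2: 3, 3: 12}, max_measurements=10)
result = policy.early_stop(EarlyStopRequest(trial_ids=frozenset({1, 2})))
print([(d.id, d.should_stop) for d in result.decisions])  # [(1, True), (2, False)]
```

Trial 3 has also exceeded the budget but is not in the request, so this toy policy leaves it alone; as noted above, a real policy is free to return decisions for such trials too.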
Example usage

As an example, suppose our rule is to stop every requested trial whose 50th intermediate measurement is too low, e.g. in the bottom 10% of the 50th measurements across all trials so far.
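The cutoff in this rule is a plain percentile computation. Below is a minimal, stdlib-only sketch of the arithmetic, using linear interpolation between the two closest ranks (the rule NumPy's `np.percentile` applies by default). The scores are made-up values and `trials_to_stop` is a hypothetical helper, not part of Vizier.

```python
from typing import Dict, List, Sequence, Set


def percentile(values: Sequence[float], q: float) -> float:
  """q-th percentile via linear interpolation between the closest ranks."""
  ordered = sorted(values)
  if not ordered:
    raise ValueError('percentile() of an empty sequence')
  rank = (len(ordered) - 1) * q / 100.0  # fractional position in sorted order
  lower = int(rank)
  upper = min(lower + 1, len(ordered) - 1)
  return ordered[lower] + (ordered[upper] - ordered[lower]) * (rank - lower)


def trials_to_stop(
    values_at_step: Dict[int, float], requested: Set[int], q: float = 10.0
) -> List[int]:
  """Requested trial IDs whose value at the chosen step is below the cutoff.

  Every requested ID must appear in `values_at_step`.
  """
  cutoff = percentile(list(values_at_step.values()), q)
  return sorted(t for t in requested if values_at_step[t] < cutoff)


# Hypothetical accuracy (in %) at the 50th measurement, keyed by trial ID.
scores = {1: 42, 2: 55, 3: 61, 4: 70, 5: 48, 6: 66, 7: 73, 8: 58, 9: 64, 10: 69}
print(percentile(scores.values(), 10))    # about 47.4
print(trials_to_stop(scores, {1, 2, 5}))  # [1]
```

With ten values the fractional rank is (10 − 1) × 0.1 = 0.9, so the cutoff is 42 + 0.9 × (48 − 42) = 47.4. Only trial 1 falls below it; trial 5 (48) is just above.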

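class MyEarlyStoppingPolicy(pythia.Policy):
  """Stops a requested trial if its 50th measurement is in the bottom 10%."""

  def __init__(self, policy_supporter: pythia.PolicySupporter, index: int = 50):
    self._policy_supporter = policy_supporter
    # `index` counts measurements from 1, so the 50th one is at position 49.
    self._position = index - 1

  def suggest(self, request):
    # pythia.Policy also declares suggest(); this policy only does early stopping.
    raise NotImplementedError('This policy only implements early_stop().')

  def _value_at_index(self, trial, metric_name):
    """Returns the metric value at the chosen measurement, or None if absent."""
    if len(trial.measurements) <= self._position:
      return None
    # metrics[...] holds a Metric object; .value is the float.
    return trial.measurements[self._position].metrics[metric_name].value

  def early_stop(
      self, request: pythia.EarlyStopRequest
  ) -> pythia.EarlyStopDecisions:
    # Assumes a single metric that is maximized, so "too low" means "bad".
    metric_name = request.study_config.metric_information.item().name

    # Obtain the 10th-percentile cutoff over all trials that reached the index.
    all_trials = self._policy_supporter.GetTrials(study_guid=request.study_guid)
    all_metrics = [
        value
        for value in (self._value_at_index(t, metric_name) for t in all_trials)
        if value is not None
    ]
    cutoff = np.percentile(all_metrics, 10) if all_metrics else None

    # Filter requested trials by cutoff; trial_ids is Optional, so treat None
    # as an empty set here.
    requested_ids = request.trial_ids or frozenset()
    stopping_decisions = []
    for trial in all_trials:
      if trial.id not in requested_ids:
        continue
      value = self._value_at_index(trial, metric_name)
      if value is None or cutoff is None:
        decision = pythia.EarlyStopDecision(
            trial.id, reason='Not enough measurements', should_stop=False
        )
      elif value < cutoff:
        decision = pythia.EarlyStopDecision(
            trial.id, reason='Below cutoff', should_stop=True
        )
      else:
        decision = pythia.EarlyStopDecision(
            trial.id, reason='At or above cutoff', should_stop=False
        )
      stopping_decisions.append(decision)
    return pythia.EarlyStopDecisions(decisions=stopping_decisions)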
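The example above treats lower values as worse, i.e. it assumes the metric is maximized. For a minimized metric such as loss, the mirror-image rule stops trials above the 90th percentile. Below is a minimal, stdlib-only sketch of a direction-aware cutoff; the `worst_decile_trials` helper and its `maximize` flag are hypothetical (in Vizier the direction would come from the study's metric configuration), and the loss values are made up.

```python
import statistics
from typing import Dict, List, Set


def worst_decile_trials(
    values_at_step: Dict[int, float], requested: Set[int], maximize: bool
) -> List[int]:
  """Requested trial IDs in the worst 10%, respecting the metric direction."""
  # Nine cut points (10th, 20th, ..., 90th percentile). 'inclusive' interpolates
  # linearly between the closest ranks; it needs at least two values.
  deciles = statistics.quantiles(values_at_step.values(), n=10, method='inclusive')
  if maximize:
    # Larger is better: the worst 10% lie below the 10th percentile.
    return sorted(t for t in requested if values_at_step[t] < deciles[0])
  # Smaller is better (e.g. loss): the worst 10% lie above the 90th percentile.
  return sorted(t for t in requested if values_at_step[t] > deciles[-1])


# Hypothetical loss at the 50th measurement, keyed by trial ID.
losses = {1: 0.9, 2: 0.2, 3: 0.3, 4: 0.25, 5: 0.4,
          6: 0.35, 7: 0.22, 8: 0.5, 9: 0.28, 10: 0.31}
print(worst_decile_trials(losses, {1, 8}, maximize=False))  # [1]
```

Here the 90th-percentile cutoff is 0.5 + 0.1 × (0.9 − 0.5) = 0.54, so trial 1 (0.9) is stopped while trial 8 (0.5) continues.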