{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "PNasEXNdDC4S" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/vizier/blob/main/docs/guides/developer/early_stopping.ipynb)\n", "\n", "# Early Stopping\n", "This notebook will allow a developer to:\n", "\n", "* Understand the Early Stopping API.\n", "* Write Pythia policies for early stopping.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "96nvhAjMDCXb" }, "source": [ "## Installation and reference imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iOHkUCw5v0DX" }, "outputs": [], "source": [ "!pip install google-vizier" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1zJ5SQGIDJrB" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "from vizier import pythia" ] }, { "cell_type": "markdown", "metadata": { "id": "vF5_CgfND1vl" }, "source": [ "## Early Stopping\n", "In hyperparameter optimization, early stopping saves resources by halting unpromising trials before they finish. Two main considerations for deciding whether to stop an active trial are:\n", "\n", "* **At a macro level, how a trial's performance compares to the rest of the trials globally.** For example, we may stop a trial if it is predicted to significantly underperform compared to the history of trials so far in the study.\n", "* **At a micro level, how a trial's intermediate measurements are changing over time.** For example, in a classification task, a trial may be overfitting once its test accuracy starts to decrease." ] }, { "cell_type": "markdown", "metadata": { "id": "4r54fvqcIfad" }, "source": [ "## API\n", "To give a policy full flexibility in deciding when to stop a trial under either consideration, we use the abridged API below.
The exact class entry points can be found [here](https://github.com/google/vizier/blob/main/vizier/pythia.py).\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "mdFOk5ar5hpN" }, "source": [ "The `EarlyStopRequest` takes in a set of trial IDs to consider for early stopping. Note, however, that a policy may also stop trials outside this set.\n", "\n", "```python\n", "class EarlyStopRequest:\n", "  \"\"\"Early stopping request.\"\"\"\n", "\n", "  trial_ids: Optional[FrozenSet[int]]\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "xg3ArdGZrINc" }, "source": [ "In addition, `EarlyStopDecision` denotes the stopping decision for a single trial, and the plural `EarlyStopDecisions` collects the decisions for a set of trials:\n", "\n", "```python\n", "class EarlyStopDecision:\n", "  \"\"\"Stopping decision on a single trial.\"\"\"\n", "\n", "  id: int\n", "  reason: str\n", "  should_stop: bool\n", "```\n", "\n", "```python\n", "class EarlyStopDecisions:\n", "  \"\"\"This is the output of the Policy.early_stop() method.\"\"\"\n", "\n", "  decisions: list[EarlyStopDecision]\n", "  metadata: vz.MetadataDelta\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "qDTvlrIzItWE" }, "source": [ "Both are used in the Pythia policy's `early_stop` method:\n", "\n", "```python\n", "class Policy(abc.ABC):\n", "  \"\"\"Interface for Pythia2 Policy subclasses.\"\"\"\n", "\n", "  @abc.abstractmethod\n", "  def early_stop(self, request: EarlyStopRequest) -\u003e EarlyStopDecisions:\n", "    \"\"\"Decide which Trials Vizier should stop.\"\"\"\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "0EZOVGTTDhmL" }, "source": [ "## Example usage\n", "As an example, suppose our rule is to stop every requested trial whose 50th intermediate measurement is too low, e.g. in the bottom 10% of all trials so far."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ngxgn_oaJHrR" }, "outputs": [], "source": [ "class MyEarlyStoppingPolicy(pythia.Policy):\n", "  \"\"\"Stops each requested trial whose 50th measurement is too low.\"\"\"\n", "\n", "  def __init__(self, policy_supporter: pythia.PolicySupporter, index: int = 50):\n", "    self._policy_supporter = policy_supporter\n", "    self._index = index\n", "\n", "  def early_stop(\n", "      self, request: pythia.EarlyStopRequest\n", "  ) -\u003e pythia.EarlyStopDecisions:\n", "    metric_name = request.study_config.metric_information.item().name\n", "\n", "    # Obtain cutoff for 10th percentile, using only the trials that have\n", "    # already reported a measurement at `self._index`.\n", "    all_trials = self._policy_supporter.GetTrials(study_guid=request.study_guid)\n", "    all_metrics = []\n", "    for trial in all_trials:\n", "      if len(trial.measurements) \u003e self._index:\n", "        all_metrics.append(\n", "            trial.measurements[self._index].metrics[metric_name].value\n", "        )\n", "    if not all_metrics:\n", "      # No trial has reached the measurement index yet; nothing to compare.\n", "      return pythia.EarlyStopDecisions(decisions=[])\n", "    cutoff = np.percentile(all_metrics, 10)\n", "\n", "    # Filter requested trials by cutoff, skipping trials that have not yet\n", "    # reported enough measurements.\n", "    considered_trials = [\n", "        trial\n", "        for trial in all_trials\n", "        if trial.id in request.trial_ids\n", "        and len(trial.measurements) \u003e self._index\n", "    ]\n", "    stopping_decisions = []\n", "    for trial in considered_trials:\n", "      if trial.measurements[self._index].metrics[metric_name].value \u003c cutoff:\n", "        decision = pythia.EarlyStopDecision(\n", "            trial.id, reason='Below cutoff', should_stop=True\n", "        )\n", "      else:\n", "        decision = pythia.EarlyStopDecision(\n", "            trial.id, reason='Above cutoff', should_stop=False\n", "        )\n", "      stopping_decisions.append(decision)\n", "    return pythia.EarlyStopDecisions(decisions=stopping_decisions)" ] } ], "metadata": { "colab": { "name": "Early Stopping.ipynb", "private_outputs": true, "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }