-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Basic implementation for mlos_benchd service #949
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,70 @@ | ||||||
#!/usr/bin/env python3 | ||||||
# | ||||||
# Copyright (c) Microsoft Corporation. | ||||||
# Licensed under the MIT License. | ||||||
# | ||||||
""" | ||||||
mlos_bench background execution daemon. | ||||||
|
||||||
This script is responsible for polling the storage for runnable experiments and | ||||||
executing them in parallel. | ||||||
|
||||||
See the current ``--help`` output for details. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
import argparse | ||||||
import time | ||||||
from concurrent.futures import ProcessPoolExecutor | ||||||
|
||||||
from mlos_bench.run import _main as mlos_bench_main | ||||||
from mlos_bench.storage import from_config | ||||||
|
||||||
|
||||||
def _main(args: argparse.Namespace) -> None:
    """
    Poll the storage for runnable experiments and run them in worker processes.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. The following attributes are read:
        ``storage`` (passed to ``from_config`` and forwarded to every
        ``mlos_bench`` run), ``poll_interval`` (seconds to sleep between polls)
        and ``num_workers`` (size of the process pool).

    Notes
    -----
    This function loops forever; it only returns if an exception propagates.
    """
    # Imported locally so that this function does not depend on edits to the
    # module-level import block.
    from concurrent.futures import FIRST_COMPLETED, wait

    storage = from_config(config=args.storage)
    poll_interval = float(args.poll_interval)
    num_workers = int(args.num_workers)
    # Futures of the experiment runs that were submitted and not yet reaped.
    in_flight = set()

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        while True:
            # Reap finished runs and surface their failures; a dropped future
            # would otherwise swallow any exception raised in the worker.
            finished = {run for run in in_flight if run.done()}
            in_flight -= finished
            for run in finished:
                error = run.exception()
                if error is not None:
                    print(f"Experiment run failed: {error!r}")

            # Do not claim more experiments than there are workers to run them:
            # get_runnable_experiment() marks the experiment as taken by this
            # host, so a claimed-but-queued experiment could not be picked up
            # by a daemon on another host.
            if len(in_flight) >= num_workers:
                wait(in_flight, timeout=poll_interval, return_when=FIRST_COMPLETED)
                continue

            exp_id = storage.get_runnable_experiment()
            if exp_id is None:
                print(f"No runnable experiment found. Sleeping for {poll_interval} second(s).")
                time.sleep(poll_interval)
                continue

            exp = storage.experiments[exp_id]
            root_env_config, _, _ = exp.root_env_config

            in_flight.add(
                executor.submit(
                    mlos_bench_main,
                    [
                        "--storage",
                        args.storage,
                        "--environment",
                        root_env_config,
                        "--experiment_id",
                        exp_id,
                        # TODO(review): we'll eventually need to include other
                        # things here too.
                    ],
                )
            )
|
||||||
|
||||||
if __name__ == "__main__":
    # Command line entry point of the mlos_benchd daemon.
    parser = argparse.ArgumentParser(description="mlos_benchd")
    parser.add_argument(
        "--storage",
        required=True,
        help="Path to the storage configuration file.",
    )
    parser.add_argument(
        "--num_workers",
        required=False,
        # Validate at parse time so a bad value yields a usage error
        # instead of a ValueError traceback from inside _main().
        type=int,
        default=1,
        help="Number of workers to use. Default is 1.",
    )
    parser.add_argument(
        "--poll_interval",
        required=False,
        # Fractional seconds are allowed; validated at parse time as well.
        type=float,
        default=1.0,
        help="Polling interval in seconds. Default is 1.",
    )
    _main(parser.parse_args())
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For testing, may also want some sort of hidden argument or environment variable used to set the |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,13 +23,17 @@ | |
""" | ||
|
||
import logging | ||
import os | ||
import platform | ||
from abc import ABCMeta, abstractmethod | ||
from collections.abc import Iterator, Mapping | ||
from contextlib import AbstractContextManager as ContextManager | ||
from datetime import datetime | ||
from types import TracebackType | ||
from typing import Any, Literal | ||
|
||
from pytz import UTC | ||
|
||
from mlos_bench.config.schemas import ConfigSchema | ||
from mlos_bench.dict_templater import DictTemplater | ||
from mlos_bench.environments.status import Status | ||
|
@@ -133,6 +137,17 @@ def experiment( # pylint: disable=too-many-arguments | |
the results of the experiment and related data. | ||
""" | ||
|
||
@abstractmethod | ||
def get_runnable_experiment(self) -> str | None: | ||
""" | ||
Get the ID of the experiment that can be run. | ||
|
||
Returns | ||
------- | ||
experiment_id : str | None | ||
ID of the experiment that can be run. | ||
""" | ||
|
||
class Experiment(ContextManager, metaclass=ABCMeta): | ||
# pylint: disable=too-many-instance-attributes | ||
""" | ||
|
@@ -150,6 +165,7 @@ def __init__( # pylint: disable=too-many-arguments | |
root_env_config: str, | ||
description: str, | ||
opt_targets: dict[str, Literal["min", "max"]], | ||
ts_start: datetime | None = None, | ||
): | ||
self._tunables = tunables.copy() | ||
self._trial_id = trial_id | ||
|
@@ -159,6 +175,11 @@ def __init__( # pylint: disable=too-many-arguments | |
) | ||
self._description = description | ||
self._opt_targets = opt_targets | ||
self._ts_start = ts_start or datetime.now(UTC) | ||
self._ts_end: datetime | None = None | ||
self._status = Status.PENDING | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should match what was stored in the backend for resumable Experiments, right? |
||
self._driver_name: str | None = None | ||
self._driver_pid: int | None = None | ||
self._in_context = False | ||
|
||
def __enter__(self) -> "Storage.Experiment": | ||
|
@@ -209,6 +230,9 @@ def _setup(self) -> None: | |
|
||
This method is called by `Storage.Experiment.__enter__()`. | ||
""" | ||
self._status = Status.RUNNING | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add some asserts on expected status to check for invalid state transitions. |
||
self._driver_name = platform.node() | ||
self._driver_pid = os.getpid() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These seem ok to initialize the values from There are a few cases I can think of:
|
||
|
||
def _teardown(self, is_ok: bool) -> None: | ||
""" | ||
|
@@ -221,6 +245,11 @@ def _teardown(self, is_ok: bool) -> None: | |
is_ok : bool | ||
True if there were no exceptions during the experiment, False otherwise. | ||
""" | ||
if is_ok: | ||
self._status = Status.SUCCEEDED | ||
else: | ||
self._status = Status.FAILED | ||
self._ts_end = datetime.now(UTC) | ||
|
||
@property | ||
def experiment_id(self) -> str: | ||
|
@@ -394,6 +423,10 @@ def _new_trial( | |
the results of the experiment trial run. | ||
""" | ||
|
||
@abstractmethod
def save(self) -> None:
    """
    Persist the experiment in the storage without starting to run it.

    Concrete storage backends must implement this method.
    """
|
||
class Trial(metaclass=ABCMeta): | ||
# pylint: disable=too-many-instance-attributes | ||
""" | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -41,6 +41,7 @@ def __init__( # pylint: disable=too-many-arguments | |||||||||||||
root_env_config: str, | ||||||||||||||
description: str, | ||||||||||||||
opt_targets: dict[str, Literal["min", "max"]], | ||||||||||||||
ts_start: datetime | None = None, | ||||||||||||||
): | ||||||||||||||
super().__init__( | ||||||||||||||
tunables=tunables, | ||||||||||||||
|
@@ -49,12 +50,12 @@ def __init__( # pylint: disable=too-many-arguments | |||||||||||||
root_env_config=root_env_config, | ||||||||||||||
description=description, | ||||||||||||||
opt_targets=opt_targets, | ||||||||||||||
ts_start=ts_start, | ||||||||||||||
) | ||||||||||||||
self._engine = engine | ||||||||||||||
self._schema = schema | ||||||||||||||
|
||||||||||||||
def _setup(self) -> None: | ||||||||||||||
super()._setup() | ||||||||||||||
def _ensure_persisted(self) -> None: | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
or
Suggested change
or
Suggested change
|
||||||||||||||
with self._engine.begin() as conn: | ||||||||||||||
# Get git info and the last trial ID for the experiment. | ||||||||||||||
# pylint: disable=not-callable | ||||||||||||||
|
@@ -90,6 +91,8 @@ def _setup(self) -> None: | |||||||||||||
git_repo=self._git_repo, | ||||||||||||||
git_commit=self._git_commit, | ||||||||||||||
root_env_config=self._root_env_config, | ||||||||||||||
ts_start=self._ts_start, | ||||||||||||||
status=self._status.name, | ||||||||||||||
) | ||||||||||||||
) | ||||||||||||||
conn.execute( | ||||||||||||||
|
@@ -125,6 +128,39 @@ def _setup(self) -> None: | |||||||||||||
exp_info.git_commit, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def save(self) -> None:
    """Store the experiment in the database without running it."""
    # All of the work is delegated to the shared persistence helper,
    # which `_setup()` also calls.
    self._ensure_persisted()
|
||||||||||||||
def _setup(self) -> None:
    """
    Prepare the experiment for running and record that in the database.

    Runs the base class setup, makes sure the experiment is persisted, and then
    writes the current status together with the driver name and PID to the
    experiment's row.
    """
    super()._setup()
    self._ensure_persisted()

    experiment_table = self._schema.experiment
    # NOTE(review): PR reviewers suggested this update might need to be
    # separated out into its own step - to be confirmed.
    mark_as_driven = (
        experiment_table.update()
        .where(experiment_table.c.exp_id == self._experiment_id)
        .values(
            {
                experiment_table.c.status: self._status.name,
                experiment_table.c.driver_name: self._driver_name,
                experiment_table.c.driver_pid: self._driver_pid,
            }
        )
    )
    with self._engine.begin() as conn:
        conn.execute(mark_as_driven)
|
||||||||||||||
def _teardown(self, is_ok: bool) -> None:
    """
    Finalize the experiment and record its outcome in the database.

    Parameters
    ----------
    is_ok : bool
        Forwarded to the base class teardown; afterwards the resulting status
        and end timestamp are written to the experiment's row.
    """
    super()._teardown(is_ok)

    experiment_table = self._schema.experiment
    record_outcome = (
        experiment_table.update()
        .where(experiment_table.c.exp_id == self._experiment_id)
        .values(
            {
                experiment_table.c.status: self._status.name,
                experiment_table.c.ts_end: self._ts_end,
            }
        )
    )
    with self._engine.begin() as conn:
        conn.execute(record_outcome)
|
||||||||||||||
def merge(self, experiment_ids: list[str]) -> None:
    """
    Merge the given experiments into this one (not implemented yet).

    Parameters
    ----------
    experiment_ids : list[str]
        IDs of the experiments that should be merged into this experiment.

    Raises
    ------
    NotImplementedError
        Always, after logging the request.
    """
    target_id = self._experiment_id
    _LOG.info("Merge: %s <- %s", target_id, experiment_ids)
    raise NotImplementedError("TODO: Merging experiments not implemented yet.")
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -5,10 +5,14 @@ | |||
"""Saving and restoring the benchmark data in SQL database.""" | ||||
|
||||
import logging | ||||
import platform | ||||
from datetime import datetime | ||||
from typing import Literal | ||||
|
||||
from sqlalchemy import URL, create_engine | ||||
from pytz import UTC | ||||
from sqlalchemy import URL, create_engine, exc | ||||
|
||||
from mlos_bench.environments.status import Status | ||||
from mlos_bench.services.base_service import Service | ||||
from mlos_bench.storage.base_experiment_data import ExperimentData | ||||
from mlos_bench.storage.base_storage import Storage | ||||
|
@@ -109,3 +113,48 @@ def experiments(self) -> dict[str, ExperimentData]: | |||
) | ||||
for exp in cur_exp.fetchall() | ||||
} | ||||
|
||||
def get_runnable_experiment(self) -> str | None:
    """
    Try to claim the oldest pending, unclaimed experiment for this host.

    Selects the experiment with the earliest ``ts_start`` that is not in the
    future, has status ``PENDING`` and has no ``driver_name`` yet, then tries
    to mark it as ``READY`` with this host's name as the driver.

    Returns
    -------
    experiment_id : str | None
        ID of the experiment that was claimed, or None if there was nothing to
        claim, the claim was lost to someone else, or a database error occurred.
    """
    experiment = self._schema.experiment
    with self._engine.connect() as conn:
        with conn.begin() as trans:
            try:
                experiment_row = conn.execute(
                    experiment.select()
                    .where(
                        experiment.c.status == Status.PENDING.name,
                        experiment.c.driver_name.is_(None),
                        experiment.c.ts_start <= datetime.now(UTC),
                    )
                    .order_by(experiment.c.ts_start.asc())
                    .limit(1)
                ).fetchone()
                if experiment_row:
                    # Try to grab it. The conditions are re-checked in the
                    # UPDATE itself so that it matches no rows if the
                    # experiment was claimed or changed state since the SELECT.
                    result = conn.execute(
                        experiment.update()
                        .where(
                            experiment.c.driver_name.is_(None),
                            experiment.c.status == Status.PENDING.name,
                            experiment.c.exp_id == experiment_row.exp_id,
                        )
                        .values(
                            {
                                experiment.c.driver_name: platform.node(),
                                experiment.c.status: Status.READY.name,
                            }
                        )
                    )
                    # The result object itself is always truthy; only the
                    # number of matched rows tells whether the claim succeeded.
                    if result.rowcount == 1:
                        # Succeeded, commit the transaction and return the ID
                        # to the calling code so that it can spawn a new
                        # `mlos_bench` process to execute this Experiment on
                        # this host in the background.
                        trans.commit()
                        return str(experiment_row.exp_id)
                    # Someone else probably grabbed it.
                    trans.rollback()
            except exc.SQLAlchemyError:
                # Probably a conflict: give up on this attempt and let the
                # caller poll again later.
                trans.rollback()

    return None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor tweaks to help make all of the docstring cross-referencing work. These might need some further adjustments.