Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 186 additions & 42 deletions src/ert/dark_storage/endpoints/experiment_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import dataclasses
import datetime
import logging
import os
import queue
Expand All @@ -11,7 +10,7 @@
import warnings
from base64 import b64decode
from queue import SimpleQueue
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import (
APIRouter,
Expand All @@ -23,19 +22,23 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import TypeAdapter
from starlette import status
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.websockets import WebSocket

from ert.base_model_context import use_runtime_plugins
from ert.config import ConfigWarning, QueueSystem
from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig
from ert.ensemble_evaluator.event import FullSnapshotEvent, SnapshotUpdateEvent
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot
from ert.plugins import get_site_plugins
from ert.run_models import StatusEvents
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from ert.run_models.everest_run_model import EverestExitCode
from ert.run_models.model_factory import (
EverestRunModel,
RunModelConfigs,
_instantiate_run_model,
)
from everest.config import EverestConfig
from everest.detached.everserver import (
ExperimentState,
Expand All @@ -48,6 +51,9 @@
EverEndpoints,
)

if TYPE_CHECKING:
from ert.run_models.run_model import RunModel

router = APIRouter(prefix="/experiment_server", tags=["experiment_server"])


Expand All @@ -64,6 +70,9 @@ class ExperimentRunnerState:
run_path: str | os.PathLike[str] | None = None
storage_path: str | os.PathLike[str] | None = None
start_time_unix: int | None = None
run_model: "RunModel | None" = dataclasses.field(default=None)
supports_rerunning_failed_realizations: bool = False
has_failed_realizations: bool = False


_runs: dict[str, ExperimentRunnerState] = {}
Expand Down Expand Up @@ -211,30 +220,80 @@ async def start_experiment(
request: Request,
background_tasks: BackgroundTasks,
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
rerun_from_run_id: str | None = None,
) -> JSONResponse:
_log(request)
_check_user(credentials)
run_id = str(uuid.uuid4())
if rerun_from_run_id is not None:
if rerun_from_run_id not in _runs:
raise HTTPException(
status_code=404, detail=f"Run '{rerun_from_run_id}' not found"
)
source_run = _runs[rerun_from_run_id]
if source_run.run_model is None:
raise HTTPException(
status_code=400,
detail=f"Run '{rerun_from_run_id}' has no run model to rerun.",
)
if not source_run.supports_rerunning_failed_realizations:
raise HTTPException(
status_code=400,
detail=f"Run '{rerun_from_run_id}' "
f"does not support rerunning failed realizations.",
)
run_state = ExperimentRunnerState(
config_path=source_run.config_path,
run_path=source_run.run_path,
storage_path=source_run.storage_path,
run_model=source_run.run_model,
supports_rerunning_failed_realizations=source_run.supports_rerunning_failed_realizations,
)
_runs[run_id] = run_state
runner = ExperimentRunner(None, run_id)
background_tasks.add_task(runner.run, rerun=True)
run_state.start_time_unix = int(time.time())
return JSONResponse(
{
"run_id": run_id,
"supports_rerunning_failed_realizations": (
run_state.supports_rerunning_failed_realizations
),
}
)
run_state = ExperimentRunnerState()
_runs[run_id] = run_state
request_data = await request.json()
# The output of warnings is the task of the user interface, not
# of everserver. Therefore we suppress them here:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConfigWarning)
config = EverestConfig.with_plugins(request_data)
adapter: TypeAdapter[RunModelConfigs] = TypeAdapter(RunModelConfigs)
config: RunModelConfigs | EverestConfig
if request_data.get("type") == "everest_config":
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConfigWarning)
config = EverestConfig.with_plugins(request_data)
else:
config = adapter.validate_python(request_data)
runner = ExperimentRunner(config, run_id)
try:
background_tasks.add_task(runner.run)
run_state.config_path = config.config_path

run_state.run_path = config.simulation_dir
run_state.storage_path = config.output_dir

if isinstance(config, EverestConfig):
run_state.config_path = config.config_file
run_state.run_path = config.output_dir
run_state.storage_path = str(config.storage_dir)
else:
run_state.config_path = config.user_config_file
run_state.run_path = config.runpath_config.runpath_format_string
run_state.storage_path = config.storage_path
# Assume client and server is always in the same timezone
# so disregard timestamps
run_state.start_time_unix = int(time.time())
return JSONResponse({"run_id": run_id})
return JSONResponse(
{
"run_id": run_id,
"supports_rerunning_failed_realizations": (
run_state.supports_rerunning_failed_realizations
),
}
)
except Exception as e:
run_state.status = ExperimentStatus(
status=ExperimentState.failed,
Expand Down Expand Up @@ -281,6 +340,71 @@ async def start_time(
return Response(str(run.start_time_unix), status_code=200)


@router.post("/" + EverEndpoints.check_runpath)
async def check_runpath(
request: Request,
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> JSONResponse:
_log(request)
_check_user(credentials)
status_queue: SimpleQueue[StatusEvents] = SimpleQueue()
request_data = await request.json()
adapter: TypeAdapter[RunModelConfigs] = TypeAdapter(RunModelConfigs)
try:
config = adapter.validate_python(request_data)
run_model = _instantiate_run_model(config, status_queue)
try:
runpath_exists = run_model.check_if_runpath_exists()
return JSONResponse(
{
"runpath_exists": runpath_exists,
"num_existing": run_model.get_number_of_existing_runpaths(),
"num_active": run_model.get_number_of_active_realizations(),
}
)
finally:
run_model._storage.close()
except Exception as e:
raise HTTPException(
status_code=422, detail=f"Could not check runpath: {e!s}"
) from e


@router.post("/" + EverEndpoints.delete_runpaths)
async def delete_runpaths(
request: Request,
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> Response:
_log(request)
_check_user(credentials)
status_queue: SimpleQueue[StatusEvents] = SimpleQueue()
request_data = await request.json()
adapter: TypeAdapter[RunModelConfigs] = TypeAdapter(RunModelConfigs)
try:
config = adapter.validate_python(request_data)
run_model = _instantiate_run_model(config, status_queue)
try:
run_model.rm_run_path()
finally:
run_model._storage.close()
return Response("Runpaths deleted.", 200)
except Exception as e:
raise HTTPException(
status_code=422, detail=f"Could not delete runpaths: {e!s}"
) from e


@router.get(f"/{EverEndpoints.has_failed_realizations}/{{run_id}}")
def has_failed_realizations_endpoint(
request: Request,
run: Annotated[ExperimentRunnerState, Depends(_get_run)],
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> JSONResponse:
_log(request)
_check_user(credentials)
return JSONResponse({"has_failed": run.has_failed_realizations})


@router.websocket(f"/{EverEndpoints.events}/{{run_id}}")
async def websocket_endpoint(websocket: WebSocket, run_id: str) -> None:
await websocket.accept()
Expand Down Expand Up @@ -328,27 +452,30 @@ async def _get_event(subscriber_id: str, run_id: str) -> StatusEvents:
class ExperimentRunner:
def __init__(
self,
everest_config: EverestConfig,
config: RunModelConfigs | None,
run_id: str,
) -> None:
super().__init__()

self._everest_config = everest_config
self._config = config
self._run_id = run_id

async def run(self) -> None:
async def run(self, rerun: bool = False) -> None:
run = _runs[self._run_id]
status_queue: SimpleQueue[StatusEvents] = SimpleQueue()
run_model: EverestRunModel | None = None
run_model: RunModel | None = None
try:
site_plugins = get_site_plugins()
with use_runtime_plugins(site_plugins):
run_model = EverestRunModel.create(
everest_config=self._everest_config,
experiment_name=f"EnOpt@{datetime.datetime.now().astimezone().isoformat(timespec='seconds')}",
target_ensemble="batch",
status_queue=status_queue,
runtime_plugins=site_plugins,
if rerun and run.run_model is not None:
run_model = run.run_model
run_model._status_queue = status_queue
else:
assert self._config is not None, (
"ExperimentRunner.run() called without config for a fresh run"
)
run_model = _instantiate_run_model(self._config, status_queue)
run.run_model = run_model
run.supports_rerunning_failed_realizations = (
run_model.supports_rerunning_failed_realizations
)
run.status = ExperimentStatus(
message="Experiment started", status=ExperimentState.running
Expand All @@ -362,10 +489,11 @@ async def run(self) -> None:
else EvaluatorServerConfig(use_ipc_protocol=False)
),
)
cancelled = False
while True:
if run.status.status == ExperimentState.stopped:
if run.status.status == ExperimentState.stopped and not cancelled:
run_model.cancel()
raise UserCancelled("Optimization aborted")
cancelled = True
try:
item: StatusEvents = status_queue.get(block=False)
except queue.Empty:
Expand All @@ -381,16 +509,27 @@ async def run(self) -> None:
for sub in list(run.subscribers.values()):
await sub.is_done()
break
await simulation_future
assert run_model.exit_code is not None
exp_status, msg = _get_optimization_status(
run_model.exit_code,
run.events,
)
run.status = ExperimentStatus(
message=msg,
status=exp_status,
)
try:
await simulation_future
except Exception:
if not cancelled:
raise
if not cancelled:
if isinstance(run_model, EverestRunModel):
assert run_model.exit_code is not None
exp_state, msg = _get_optimization_status(
run_model.exit_code, run.events
)
run_status = ExperimentStatus(
message=msg,
status=exp_state,
)
else:
run_status = ExperimentStatus(
message="Experiment completed.",
status=ExperimentState.completed,
)
run.status = run_status
except UserCancelled as e:
logging.getLogger(EXPERIMENT_SERVER).info(f"User cancelled: {e}")
except Exception as e:
Expand All @@ -400,8 +539,13 @@ async def run(self) -> None:
status=ExperimentState.failed,
)
finally:
if run_model and run_model._experiment:
run_model._experiment.status = run.status
if run_model is not None:
run.has_failed_realizations = run_model.has_failed_realizations()
if (
isinstance(run_model, EverestRunModel)
and run_model._experiment is not None
):
run_model._experiment.status = run.status

logging.getLogger(EXPERIMENT_SERVER).info(
f"ExperimentRunner done. Items left in queue: {status_queue.qsize()}"
Expand Down
Loading
Loading