diff --git a/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java b/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java index c40873dc..a9c3b233 100644 --- a/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java +++ b/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java @@ -43,6 +43,7 @@ public static ActorHandle createAppMaster( .substring(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX.length() + 1); creator.setResource(resourceName, resource.getValue()); } + creator.setMaxTaskRetries(3); return creator.remote(); } @@ -57,6 +58,16 @@ public static Map getRestartedExecutors( return handle.task(RayAppMaster::getRestartedExecutors).remote().get(); } + public static boolean finishApplication( + ActorHandle handle, + String appId, + String stateName, + int exitCode, + String diagnostics) { + return handle.task(RayAppMaster::finishApplication, appId, stateName, exitCode, diagnostics) + .remote().get(); + } + public static void stopAppMaster( ActorHandle handle) { handle.task(RayAppMaster::stop).remote().get(); diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 829517e5..ecf4b538 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -50,6 +50,7 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBibl import org.apache.spark._ import org.apache.spark.api.r.RUtils +import org.apache.spark.deploy.raydp.{DriverAppMasterReporter, DriverExitState} import org.apache.spark.deploy.rest._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -1011,6 +1012,19 @@ object SparkSubmit extends CommandLineUtils with Logging { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + private def finalizeDriverTermination(): Unit = { + val snapshot = DriverExitState.current() + if (DriverExitState.isTerminal(snapshot.state)) { + DriverAppMasterReporter.tryReportAndCleanup() + } + } + + private def describeFailure(t: Throwable): String = { + val message = Option(t.getMessage).filter(_.nonEmpty) + .getOrElse("No additional diagnostics available.") + s"${t.getClass.getName}: $message" + } + // Following constants are visible for testing. private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = "org.apache.spark.deploy.yarn.YarnClusterApplication" @@ -1020,6 +1034,19 @@ object SparkSubmit extends CommandLineUtils with Logging { "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" override def main(args: Array[String]): Unit = { + DriverExitState.reset() + DriverAppMasterReporter.reset() + val originalExitFn = exitFn + exitFn = (exitCode: Int) => { + if (exitCode == 0) { + DriverExitState.trySetFinished() + } else { + DriverExitState.trySetFailed(exitCode, s"SparkSubmit exited with status $exitCode") + } + finalizeDriverTermination() + originalExitFn(exitCode) + } + val submit = new SparkSubmit() { self => @@ -1050,7 +1077,18 @@ object SparkSubmit extends CommandLineUtils with Logging { } - submit.doSubmit(args) + try { + submit.doSubmit(args) + DriverExitState.trySetFinished() + finalizeDriverTermination() + } catch { + case t: Throwable => + DriverExitState.trySetFailed(1, describeFailure(t)) + finalizeDriverTermination() + throw t + } finally { + exitFn = originalExitFn + } } /** diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationInfo.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationInfo.scala index 4091ccdd..2dc02597 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationInfo.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationInfo.scala @@ -53,6 +53,8 @@ private[spark] class ApplicationInfo( var removedExecutors: ArrayBuffer[ExecutorDesc] = _ var coresGranted: Int = _ var endTime: Long = _ + var exitCode: Int = _ + var diagnostics: String = _ private var nextExecutorId: Int = _ // this only count those registered executors and minus removed executors private var registeredExecutors: Int = 0 @@ -65,6 +67,8 @@ private[spark] class ApplicationInfo( addressToExecutorId = new HashMap[RpcAddress, String] executorIdToHandler = new HashMap[String, ActorHandle[RayDPExecutor]] endTime = -1L + exitCode = 0 + diagnostics = null nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] } @@ -165,9 +169,21 @@ private[spark] class ApplicationInfo( def resetRetryCount(): Unit = _retryCount = 0 + def finish(endState: ApplicationState.Value, endExitCode: Int, endDiagnostics: String): Boolean = + synchronized { + if (isFinished) { + false + } else { + state = endState + exitCode = endExitCode + diagnostics = endDiagnostics + endTime = System.currentTimeMillis() + true + } + } + def markFinished(endState: ApplicationState.Value): Unit = { - state = endState - endTime = System.currentTimeMillis() + finish(endState, 0, null) } def isFinished: Boolean = { diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverAppMasterReporter.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverAppMasterReporter.scala new file mode 100644 index 00000000..4575ce7d --- /dev/null +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverAppMasterReporter.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.raydp + +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import io.ray.api.ActorHandle + +import org.apache.spark.internal.Logging + +object DriverAppMasterReporter extends Logging { + + private val reported = new AtomicBoolean(false) + + private var appId: String = null + private var masterHandle: ActorHandle[RayAppMaster] = null + + def reset(): Unit = synchronized { + reported.set(false) + appId = null + masterHandle = null + } + + def bind(appId: String): Unit = synchronized { + if (!reported.get()) { + this.appId = appId + } + } + + def bindMasterHandle(masterHandle: ActorHandle[RayAppMaster]): Unit = synchronized { + if (!reported.get()) { + this.masterHandle = masterHandle + } + } + + def tryReportAndCleanup(): Boolean = { + val snapshot = DriverExitState.current() + if (!DriverExitState.isTerminal(snapshot.state)) { + logDebug(s"Skip AppMaster report because driver state is not terminal: ${snapshot.state}") + false + } else { + val binding = synchronized { + if (reported.get()) None + else Some((appId, masterHandle)) + } + binding match { + case None => false + case Some((currentAppId, currentMasterHandle)) => + if (currentAppId == null || currentMasterHandle == null) { + logWarning("Skip reporting terminal application state because AppMaster binding " + + "is incomplete.") + false + } else { + try { + val accepted = RayAppMasterUtils.finishApplication( + currentMasterHandle, + currentAppId, + snapshot.state.toString, + snapshot.exitCode, + snapshot.diagnostics) + + if (!accepted) { + logWarning("Terminal application state report was not accepted by AppMaster; " + + "keeping reporter state for a later retry.") + false + } else { + reported.set(true) + if (currentMasterHandle != null) { + try { + RayAppMasterUtils.stopAppMaster(currentMasterHandle) + } catch { + case NonFatal(e) => + logWarning("Failed to stop AppMaster during driver cleanup", e) + } + } + synchronized { + appId = null + masterHandle = null + } + true + } + } catch { + case NonFatal(e) => + logWarning("Failed to report terminal application state to AppMaster", e) + false + } + } + } + } + } +} diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverExitState.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverExitState.scala new file mode 100644 index 00000000..81254f34 --- /dev/null +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/DriverExitState.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.raydp + +object DriverExitState { + + case class Snapshot(state: ApplicationState.Value, exitCode: Int, diagnostics: String) + + private var snapshot = Snapshot(ApplicationState.UNKNOWN, 0, null) + + def reset(): Unit = synchronized { + snapshot = Snapshot(ApplicationState.UNKNOWN, 0, null) + } + + def current(): Snapshot = synchronized { + snapshot + } + + def isTerminal(state: ApplicationState.Value): Boolean = { + state == ApplicationState.FINISHED || + state == ApplicationState.FAILED || + state == ApplicationState.KILLED + } + + def trySetFinished(): Boolean = synchronized { + trySet(ApplicationState.FINISHED, 0, null) + } + + def trySetFailed(exitCode: Int, diagnostics: String): Boolean = synchronized { + trySet(ApplicationState.FAILED, normalizedFailureCode(exitCode), diagnostics) + } + + def trySetKilled(exitCode: Int, diagnostics: String): Boolean = synchronized { + val normalizedExitCode = if (exitCode == 0) { + 143 + } else { + exitCode + } + trySet(ApplicationState.KILLED, normalizedExitCode, diagnostics) + } + + private def trySet( + state: ApplicationState.Value, + exitCode: Int, + diagnostics: String): Boolean = { + if (isTerminal(snapshot.state)) { + false + } else { + snapshot = Snapshot(state, exitCode, diagnostics) + true + } + } + + private def normalizedFailureCode(exitCode: Int): Int = { + if (exitCode == 0) { + 1 + } else { + exitCode + } + } +} diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala index 15e8ffc5..77838c48 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala @@ -26,7 +26,11 @@ case class RegisterApplication(appDescription: ApplicationDescription, driver: R case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends RayDPDeployMessage -case class UnregisterApplication(appId: String) extends RayDPDeployMessage +case class FinishApplication( + appId: String, + state: ApplicationState.Value, + exitCode: Int, + diagnostics: String) extends RayDPDeployMessage case class RegisterExecutor(executorId: String, nodeIp: String) extends RayDPDeployMessage diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala index f4cc823d..752d2468 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala @@ -83,6 +83,15 @@ class RayAppMaster(host: String, def getRestartedExecutors(): java.util.Map[String, String] = restartedExecutors.asJava + def finishApplication( + appId: String, + stateName: String, + exitCode: Int, + diagnostics: String): Boolean = { + endpoint.askSync[Boolean]( + FinishApplication(appId, ApplicationState.withName(stateName), exitCode, diagnostics)) + } + /** * This is used to represent the Spark on Ray cluster URL. */ @@ -141,13 +150,13 @@ class RayAppMaster(host: String, logInfo("Registered app " + appDescription.name + " with ID " + app.id) driver.send(RegisteredApplication(app.id, self)) schedule() - - case UnregisterApplication(appId) => - assert(appInfo != null && appInfo.id == appId) - appInfo.markFinished(ApplicationState.FINISHED) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case FinishApplication(appId, state, exitCode, diagnostics) => + assert(appInfo != null && appInfo.id == appId) + context.reply(appInfo.finish(state, exitCode, diagnostics)) + case RegisterExecutor(executorId, executorIp) => val success = appInfo.registerExecutor(executorId) if (success) { diff --git a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala index dce83a78..c7675760 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala @@ -59,7 +59,12 @@ class RayCoarseGrainedSchedulerBackend( private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sc.conf - override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + override protected def onStopRequest(): Unit = { + DriverExitState.trySetKilled( + 143, + "Spark launcher requested application stop.") + stop(SparkAppHandle.State.KILLED) + } } def prependPreferPath(cp: String): String = { @@ -93,6 +98,7 @@ class RayCoarseGrainedSchedulerBackend( masterHandle = RayAppMasterUtils.createAppMaster(cp, null, options.toBuffer.asJava, appMasterResources.toMap.asJava) + DriverAppMasterReporter.bindMasterHandle(masterHandle) uri = new URI(RayAppMasterUtils.getMasterUrl(masterHandle)) } else { uri = new URI(sparkUrl) @@ -195,9 +201,6 @@ class RayCoarseGrainedSchedulerBackend( override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) - if (masterHandle != null) { - RayAppMasterUtils.stopAppMaster(masterHandle) - } } def parseRayDPResourceRequirements(sparkConf: SparkConf): Map[String, Double] = { @@ -259,6 +262,7 @@ class RayCoarseGrainedSchedulerBackend( appId.set(id) launcherBackend.setAppId(id) appMasterRef.set(ref) + DriverAppMasterReporter.bind(id) registrationBarrier.release() } @@ -304,7 +308,9 @@ class RayCoarseGrainedSchedulerBackend( if (stopped.compareAndSet(false, true)) { try { super.stop() // this will stop all executors - appMasterRef.get.send(UnregisterApplication(appId.get)) + if (finalState == SparkAppHandle.State.KILLED) { + DriverAppMasterReporter.tryReportAndCleanup() + } } finally { appMasterRef.set(null) launcherBackend.setState(finalState) diff --git a/python/raydp/tests/test_driver_exit_state.py b/python/raydp/tests/test_driver_exit_state.py new file mode 100644 index 00000000..65b33e46 --- /dev/null +++ b/python/raydp/tests/test_driver_exit_state.py @@ -0,0 +1,181 @@ +import json +import os +import signal +import subprocess +import textwrap +import time +from pathlib import Path + +import ray +import pytest + + +FINISHED_SENTINEL = "RAYDP_E2E_FINISHED" +FAILED_SENTINEL = "RAYDP_E2E_INTENTIONAL_FAILURE" +KILLED_RUNNING_SENTINEL = "RAYDP_E2E_DRIVER_RUNNING" +KILL_AFTER_SECONDS = 30 + +# The scenario is passed by environment instead of as a script argument because +# bin/raydp-submit injects its own Spark --conf entries before the last argument. +DRIVER_SCRIPT = """ +import os +import time +from pyspark.sql import SparkSession + +scenario = os.environ["RAYDP_E2E_SCENARIO"] +spark = SparkSession.builder.appName("raydp-submit-state-e2e-" + scenario).getOrCreate() + +try: + # Run a real Spark action before reporting the sentinel so each case covers + # the raydp-submit -> Spark driver -> RayDP AppMaster path. + assert spark.range(0, 10).count() == 10 + + if scenario == "finished": + print("RAYDP_E2E_FINISHED", flush=True) + elif scenario == "failed": + print("RAYDP_E2E_INTENTIONAL_FAILURE", flush=True) + raise RuntimeError("RAYDP_E2E_INTENTIONAL_FAILURE") + elif scenario == "killed": + # Keep the driver alive long enough for pytest to terminate raydp-submit + # with a real external signal. + print("RAYDP_E2E_DRIVER_RUNNING", flush=True) + time.sleep(300) + else: + raise ValueError("unknown scenario: " + scenario) +finally: + spark.stop() +""" + + +@pytest.fixture(scope="module") +def ray_conf_path(tmp_path_factory): + # raydp-submit needs the same cluster metadata that raydp-submit normally + # receives in production. Reuse the Ray head started by Docker/CI. + started_ray_client = False + if not ray.is_initialized(): + ray.init(address="auto") + started_ray_client = True + try: + node = ray.worker.global_worker.node + options = { + "ray": { + "run-mode": "CLUSTER", + "node-ip": node.node_ip_address, + "address": node.address, + "session-dir": node.get_session_dir_path(), + } + } + finally: + if started_ray_client: + ray.shutdown() + + conf_path = tmp_path_factory.mktemp("ray-conf") / "ray.conf" + conf_path.write_text(json.dumps(options), encoding="utf-8") + return conf_path + + +@pytest.fixture(scope="module") +def driver_script_path(tmp_path_factory): + # Keep the driver app as a temp file so the test exercises bin/raydp-submit + # exactly like a user-submitted Python application. + script_path = tmp_path_factory.mktemp("raydp-submit-state") / "driver_state_app.py" + script_path.write_text(textwrap.dedent(DRIVER_SCRIPT), encoding="utf-8") + return script_path + + +def _repo_root(): + return Path(__file__).resolve().parents[3] + + +def _raydp_submit_command(ray_conf_path, driver_script_path): + # Use the smallest fixed Spark cluster shape that can run the action. The + # driver script must remain the final argument for bin/raydp-submit. + return [ + str(_repo_root() / "bin" / "raydp-submit"), + "--ray-conf", + str(ray_conf_path), + "--conf", + "spark.executor.cores=1", + "--conf", + "spark.executor.instances=1", + "--conf", + "spark.executor.memory=500m", + "--conf", + "spark.dynamicAllocation.enabled=false", + "--conf", + "spark.ui.enabled=false", + str(driver_script_path), + ] + + +def _subprocess_env(scenario): + env = os.environ.copy() + # Avoid inheriting PySpark launcher overrides from the outer test process. + env.pop("PYSPARK_DRIVER_PYTHON", None) + env.pop("PYSPARK_PYTHON", None) + env["RAYDP_E2E_SCENARIO"] = scenario + return env + + +def _run_raydp_submit(ray_conf_path, driver_script_path, scenario): + return subprocess.run( + _raydp_submit_command(ray_conf_path, driver_script_path), + cwd=str(_repo_root()), + env=_subprocess_env(scenario), + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=240, + ) + + +@pytest.mark.parametrize( + "scenario, expected_returncode, sentinel", + [ + ("finished", 0, FINISHED_SENTINEL), + ("failed", 1, FAILED_SENTINEL), + ], +) +def test_raydp_submit_terminal_state( + ray_conf_path, driver_script_path, scenario, expected_returncode, sentinel): + result = _run_raydp_submit(ray_conf_path, driver_script_path, scenario) + + # This smoke test validates user-visible terminal behavior: success exits 0, + # failure exits non-zero, and the driver reached the intended branch. + if expected_returncode == 0: + assert result.returncode == 0, result.stdout + else: + assert result.returncode != 0, result.stdout + assert sentinel in result.stdout + + +def test_raydp_submit_killed_smoke(ray_conf_path, driver_script_path, tmp_path): + output_path = tmp_path / "raydp-submit-killed.log" + with output_path.open("w", encoding="utf-8") as output_file: + proc = subprocess.Popen( + _raydp_submit_command(ray_conf_path, driver_script_path), + cwd=str(_repo_root()), + env=_subprocess_env("killed"), + text=True, + stdout=output_file, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + # The driver sleeps for 300s after the Spark action; waiting here keeps + # the test simple while still killing raydp-submit during driver runtime. + time.sleep(KILL_AFTER_SECONDS) + log = output_path.read_text(encoding="utf-8") + assert proc.poll() is None, log + assert KILLED_RUNNING_SENTINEL in log, ( + f"Driver never reached the killed branch; stdout so far:\n{log}" + ) + + os.killpg(proc.pid, signal.SIGTERM) + try: + proc.wait(timeout=30) + except subprocess.TimeoutExpired: + os.killpg(proc.pid, signal.SIGKILL) + proc.wait(timeout=10) + + assert proc.returncode != 0, output_path.read_text(encoding="utf-8")