diff --git a/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala b/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala index f53d9f5b4..2215b58a9 100644 --- a/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala +++ b/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala @@ -276,7 +276,7 @@ private class HttpClientTestBootstrap extends LifeCycle { when(session.name).thenReturn(None) when(session.appId).thenReturn(None) when(session.appInfo).thenReturn(AppInfo()) - when(session.state).thenReturn(SessionState.Idle) + when(session.state).thenReturn(SessionState.Idle()) when(session.proxyUser).thenReturn(None) when(session.kind).thenReturn(Spark) when(session.stop()).thenReturn(Future.successful(())) diff --git a/core/src/main/scala/org/apache/livy/sessions/Kind.scala b/core/src/main/scala/org/apache/livy/sessions/Kind.scala index 0a05c8fcd..06b19eba2 100644 --- a/core/src/main/scala/org/apache/livy/sessions/Kind.scala +++ b/core/src/main/scala/org/apache/livy/sessions/Kind.scala @@ -45,6 +45,12 @@ object Kind { case "sql" => SQL case other => throw new IllegalArgumentException(s"Invalid kind: $other") } + + val kinds: Seq[Kind] = Seq(Spark, PySpark, SparkR, Shared, SQL) + + def isValid(kind: String): Boolean = { + kinds.map(_.name).contains(kind) + } } class SessionKindModule extends SimpleModule("SessionKind") { diff --git a/core/src/main/scala/org/apache/livy/sessions/SessionState.scala b/core/src/main/scala/org/apache/livy/sessions/SessionState.scala index d731c9b14..87ef8b6ad 100644 --- a/core/src/main/scala/org/apache/livy/sessions/SessionState.scala +++ b/core/src/main/scala/org/apache/livy/sessions/SessionState.scala @@ -30,13 +30,13 @@ class FinishedSessionState( object SessionState { def apply(s: String): SessionState = s match { - case "not_started" => NotStarted - case "starting" => Starting - case "recovering" => Recovering - case "idle" => Idle - case "running" => Running - case "busy" => Busy - case "shutting_down" => ShuttingDown + case "not_started" => NotStarted() + case "starting" => Starting() + case "recovering" => Recovering() + case "idle" => Idle() + case "running" => Running() + case "busy" => Busy() + case "shutting_down" => ShuttingDown() case "error" => Error() case "dead" => Dead() case "killed" => Killed() @@ -44,19 +44,19 @@ object SessionState { case _ => throw new IllegalArgumentException(s"Illegal session state: $s") } - object NotStarted extends SessionState("not_started", true) + case class NotStarted() extends SessionState("not_started", true) - object Starting extends SessionState("starting", true) + case class Starting() extends SessionState("starting", true) - object Recovering extends SessionState("recovering", true) + case class Recovering() extends SessionState("recovering", true) - object Idle extends SessionState("idle", true) + case class Idle() extends SessionState("idle", true) - object Running extends SessionState("running", true) + case class Running() extends SessionState("running", true) - object Busy extends SessionState("busy", true) + case class Busy() extends SessionState("busy", true) - object ShuttingDown extends SessionState("shutting_down", false) + case class ShuttingDown() extends SessionState("shutting_down", false) case class Killed(override val time: Long = System.nanoTime()) extends FinishedSessionState("killed", false, time) @@ -69,4 +69,11 @@ object SessionState { case class Success(override val time: Long = System.nanoTime()) extends FinishedSessionState("success", false, time) + + val states: Seq[SessionState] = Seq(NotStarted(), Starting(), Recovering(), Idle(), Running(), + Busy(), ShuttingDown(), Killed(), Error(), Dead(), Success()) + + def isValid(state: String): Boolean = { + states.map(_.state).contains(state) + } } diff --git a/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala b/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala index cf68f7707..4962d76fc 100644 --- a/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala +++ b/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala @@ -113,7 +113,7 @@ class LivyRestClient(val httpClient: AsyncHttpClient, val livyEndpoint: String) class BatchSession(id: Int) extends Session(id, BATCH_TYPE) { def verifySessionDead(): Unit = verifySessionState(SessionState.Dead()) def verifySessionKilled(): Unit = verifySessionState(SessionState.Killed()) - def verifySessionRunning(): Unit = verifySessionState(SessionState.Running) + def verifySessionRunning(): Unit = verifySessionState(SessionState.Running()) def verifySessionSuccess(): Unit = verifySessionState(SessionState.Success()) } @@ -240,7 +240,7 @@ class LivyRestClient(val httpClient: AsyncHttpClient, val livyEndpoint: String) } def verifySessionIdle(): Unit = { - verifySessionState(SessionState.Idle) + verifySessionState(SessionState.Idle()) } def verifySessionKilled(): Unit = { diff --git a/integration-test/src/test/scala/org/apache/livy/test/BatchIT.scala b/integration-test/src/test/scala/org/apache/livy/test/BatchIT.scala index a6f4e73eb..48ea5f882 100644 --- a/integration-test/src/test/scala/org/apache/livy/test/BatchIT.scala +++ b/integration-test/src/test/scala/org/apache/livy/test/BatchIT.scala @@ -85,7 +85,7 @@ class BatchIT extends BaseIntegrationTestSuite with BeforeAndAfterAll { test("deleting a session should kill YARN app") { val output = newOutputPath() withTestLib(classOf[SimpleSparkApp], List(output, "false")) { s => - s.verifySessionState(SessionState.Running) + s.verifySessionState(SessionState.Running()) s.snapshot().appInfo.driverLogUrl.value should include ("containerlogs") val appId = s.appId() @@ -100,7 +100,7 @@ class BatchIT extends BaseIntegrationTestSuite with BeforeAndAfterAll { test("killing YARN app should change batch state to dead") { val output = newOutputPath() withTestLib(classOf[SimpleSparkApp], List(output, "false")) { s => - s.verifySessionState(SessionState.Running) + s.verifySessionState(SessionState.Running()) val appId = s.appId() // Kill the YARN app and check batch state should be KILLED. diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index ea8a761c5..b51d325f1 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -65,7 +65,7 @@ class Session( private implicit val formats = DefaultFormats - private var _state: SessionState = SessionState.NotStarted + private var _state: SessionState = SessionState.NotStarted() // Number of statements kept in driver's memory private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENTS) @@ -120,7 +120,7 @@ class Session( def start(): Future[SparkEntries] = { val future = Future { - changeState(SessionState.Starting) + changeState(SessionState.Starting()) // Always start SparkInterpreter after beginning, because we rely on SparkInterpreter to // initialize SparkContext and create SparkEntries. @@ -133,7 +133,7 @@ class Session( interpGroup.put(Spark, sparkInterp) } - changeState(SessionState.Idle) + changeState(SessionState.Idle()) entries }(interpreterExecutor) @@ -263,12 +263,12 @@ class Session( private def executeCode(interp: Option[Interpreter], executionCount: Int, code: String): String = { - changeState(SessionState.Busy) + changeState(SessionState.Busy()) def transitToIdle() = { val executingLastStatement = executionCount == newStatementId.intValue() - 1 if (_statements.isEmpty || executingLastStatement) { - changeState(SessionState.Idle) + changeState(SessionState.Idle()) } } diff --git a/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala index 7e23d1fc9..bcea0df29 100644 --- a/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala @@ -60,7 +60,7 @@ abstract class BaseSessionSpec(kind: Kind) // Session's constructor should fire an initial state change event. stateChangedCalled.intValue() shouldBe 1 Await.ready(session.start(), 30 seconds) - assert(session.state === SessionState.Idle) + assert(session.state === SessionState.Idle()) // There should be at least 1 state change event fired when session transits to idle. stateChangedCalled.intValue() should (be > 1) testCode(session) @@ -74,14 +74,14 @@ abstract class BaseSessionSpec(kind: Kind) val future = session.start() try { Await.ready(future, 60 seconds) - session.state should (equal (SessionState.Starting) or equal (SessionState.Idle)) + session.state should (equal (SessionState.Starting()) or equal (SessionState.Idle())) } finally { session.close() } } it should "eventually become the idle state" in withSession { session => - session.state should equal (SessionState.Idle) + session.state should equal (SessionState.Idle()) } } diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java b/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java index 787fc7793..e50143575 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java @@ -20,6 +20,7 @@ import java.util.*; import com.fasterxml.jackson.annotation.JsonValue; +import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,13 +74,18 @@ static void put(StatementState key, PREDECESSORS = Collections.unmodifiableMap(predecessors); } - static boolean isValid(StatementState from, StatementState to) { + static boolean isAllowed(StatementState from, StatementState to) { return PREDECESSORS.get(to).contains(from); } + public static boolean isValid(String state) { + return Arrays.stream(values()) + .map(x -> StringUtils.capitalize(x.state)).anyMatch(state::equals); + } + static void validate(StatementState from, StatementState to) { LOG.debug("{} -> {}", from, to); - if (!isValid(from, to)) { + if (!isAllowed(from, to)) { throw new IllegalStateException("Illegal Transition: " + from + " -> " + to); } } diff --git a/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala b/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala index c94fc04a2..ad3f7ddb5 100644 --- a/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala +++ b/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala @@ -115,7 +115,7 @@ object BatchSession extends Logging { id, name, appTag, - SessionState.Starting, + SessionState.Starting(), livyConf, owner, impersonatedUser, @@ -132,7 +132,7 @@ object BatchSession extends Logging { m.id, m.name, m.appTag, - SessionState.Recovering, + SessionState.Recovering(), livyConf, m.owner, m.proxyUser, @@ -184,7 +184,7 @@ class BatchSession( debug(s"$this state changed from $oldState to $newState") newState match { case SparkApp.State.RUNNING => - _state = SessionState.Running + _state = SessionState.Running() info(s"Batch session $id created [appid: ${appId.orNull}, state: ${state.toString}, " + s"info: ${appInfo.asJavaMap}]") case SparkApp.State.FINISHED => _state = SessionState.Success() diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala index bccdb4d92..b57725a6b 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala @@ -117,7 +117,7 @@ object InteractiveSession extends Logging { None, appTag, client, - SessionState.Starting, + SessionState.Starting(), request.kind, request.heartbeatTimeoutInSecond, livyConf, @@ -144,7 +144,7 @@ object InteractiveSession extends Logging { metadata.appId, metadata.appTag, client, - SessionState.Recovering, + SessionState.Recovering(), metadata.kind, metadata.heartbeatTimeoutS, livyConf, @@ -429,7 +429,7 @@ class InteractiveSession( override def onJobFailed(job: JobHandle[Void], cause: Throwable): Unit = errorOut() override def onJobSucceeded(job: JobHandle[Void], result: Void): Unit = { - transition(SessionState.Running) + transition(SessionState.Running()) info(s"Interactive session $id created [appid: ${appId.orNull}, " + s"owner: $owner, proxyUser:" + s" $proxyUser, state: ${state.toString}, kind: ${kind.toString}, " + @@ -440,7 +440,7 @@ class InteractiveSession( // Other code might call stop() to close the RPC channel. When RPC channel is closing, // this callback might be triggered. Check and don't call stop() to avoid nested called // if the session is already shutting down. - if (serverSideState != SessionState.ShuttingDown) { + if (serverSideState != SessionState.ShuttingDown()) { transition(SessionState.Error()) stop() app.foreach { a => @@ -460,18 +460,18 @@ class InteractiveSession( heartbeatTimeout.toSeconds.toInt, owner, proxyUser, rscDriverUri) override def state: SessionState = { - if (serverSideState == SessionState.Running) { + if (serverSideState == SessionState.Running()) { // If session is in running state, return the repl state from RSCClient. client .flatMap(s => Option(s.getReplState)) .map(SessionState(_)) - .getOrElse(SessionState.Busy) // If repl state is unknown, assume repl is busy. + .getOrElse(SessionState.Busy()) // If repl state is unknown, assume repl is busy. } else serverSideState } override def stopSession(): Unit = { try { - transition(SessionState.ShuttingDown) + transition(SessionState.ShuttingDown()) sessionStore.remove(RECOVERY_SESSION_TYPE, id) client.foreach { _.stop(true) } } catch { @@ -591,7 +591,7 @@ class InteractiveSession( private def ensureRunning(): Unit = synchronized { serverSideState match { - case SessionState.Running => + case SessionState.Running() => case _ => throw new IllegalStateException("Session is in state %s" format serverSideState) } diff --git a/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala b/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala index f8f98a2db..51e4ba3b7 100644 --- a/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala +++ b/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala @@ -155,7 +155,7 @@ class SessionManager[S <: Session, R <: RecoveryMetadata : ClassTag]( case _ => if (!sessionTimeoutCheck) { false - } else if (session.state == SessionState.Busy && sessionTimeoutCheckSkipBusy) { + } else if (session.state == SessionState.Busy() && sessionTimeoutCheckSkipBusy) { false } else if (session.isInstanceOf[BatchSession]) { false diff --git a/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala index cdd17832d..d148e2399 100644 --- a/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala @@ -38,7 +38,7 @@ object SessionServletSpec { override def recoveryMetadata: RecoveryMetadata = MockRecoveryMetadata(0) - override def state: SessionState = SessionState.Idle + override def state: SessionState = SessionState.Idle() override def start(): Unit = () diff --git a/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala index 9920586fe..71fe4dfe2 100644 --- a/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala @@ -64,7 +64,7 @@ class BatchServletSpec extends BaseSessionServletSpec[BatchSession, BatchRecover def testShowSessionProperties(name: Option[String]): Unit = { val id = 0 - val state = SessionState.Running + val state = SessionState.Running() val appId = "appid" val appInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) val log = IndexedSeq[String]("log1", "log2") diff --git a/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala index bc9ddc4d3..cfb818a67 100644 --- a/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala @@ -146,7 +146,7 @@ class BatchSessionSpec val m = BatchRecoveryMetadata(99, name, None, "appTag", null, None) val batch = BatchSession.recover(m, conf, sessionStore, Some(mockApp)) - batch.state shouldBe (SessionState.Recovering) + batch.state shouldBe (SessionState.Recovering()) batch.name shouldBe (name) batch.appIdKnown("appId") diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala index c97aa19e3..4e0ff386a 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala @@ -64,7 +64,7 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { when(session.appId).thenReturn(None) when(session.appInfo).thenReturn(AppInfo()) when(session.logLines()).thenReturn(IndexedSeq()) - when(session.state).thenReturn(SessionState.Idle) + when(session.state).thenReturn(SessionState.Idle()) when(session.stop()).thenReturn(Future.successful(())) when(session.proxyUser).thenReturn(None) when(session.heartbeatExpired).thenReturn(false) @@ -165,7 +165,7 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { val appId = "appid" val owner = "owner" val proxyUser = "proxyUser" - val state = SessionState.Running + val state = SessionState.Running() val kind = Spark val appInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) val log = IndexedSeq[String]("log1", "log2") diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala index 2e2148386..fb952f5b1 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala @@ -95,7 +95,7 @@ class InteractiveSessionSpec extends FunSpec it(desc) { assume(session != null, "No active session.") eventually(timeout(60 seconds), interval(100 millis)) { - session.state shouldBe (SessionState.Idle) + session.state shouldBe (SessionState.Idle()) } fn(session) } @@ -173,7 +173,7 @@ class InteractiveSessionSpec extends FunSpec verify(sessionStore, atLeastOnce()).save( MockitoMatchers.eq(InteractiveSession.RECOVERY_SESSION_TYPE), anyObject()) - session.state should (be(SessionState.Starting) or be(SessionState.Idle)) + session.state should (be(SessionState.Starting()) or be(SessionState.Idle())) } it("should propagate RSC configuration properties") { @@ -227,7 +227,7 @@ class InteractiveSessionSpec extends FunSpec result should equal (expectedResult) eventually(timeout(10 seconds), interval(30 millis)) { - session.state shouldBe (SessionState.Idle) + session.state shouldBe (SessionState.Idle()) } } @@ -266,7 +266,7 @@ class InteractiveSessionSpec extends FunSpec val s = InteractiveSession.recover(m, conf, sessionStore, None, Some(mockClient)) s.start() - s.state shouldBe (SessionState.Recovering) + s.state shouldBe (SessionState.Recovering()) s.appIdKnown("appId") verify(sessionStore, atLeastOnce()).save( @@ -283,7 +283,7 @@ class InteractiveSessionSpec extends FunSpec val s = InteractiveSession.recover(m, conf, sessionStore, None, Some(mockClient)) s.start() - s.state shouldBe (SessionState.Recovering) + s.state shouldBe (SessionState.Recovering()) s.appIdKnown("appId") verify(sessionStore, atLeastOnce()).save( diff --git a/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala index 164649257..ce99b017b 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala @@ -172,7 +172,7 @@ class JobApiSpec extends BaseInteractiveServletSpec { protected def waitForIdle(id: Int): Unit = { eventually(timeout(1 minute), interval(100 millis)) { jget[SessionInfo](s"/$id") { status => - status.state should be (SessionState.Idle.toString()) + status.state should be (SessionState.Idle().toString()) } } } diff --git a/server/src/test/scala/org/apache/livy/sessions/MockSession.scala b/server/src/test/scala/org/apache/livy/sessions/MockSession.scala index ddcbd4b5a..c4fbce11d 100644 --- a/server/src/test/scala/org/apache/livy/sessions/MockSession.scala +++ b/server/src/test/scala/org/apache/livy/sessions/MockSession.scala @@ -31,7 +31,7 @@ class MockSession(id: Int, owner: String, conf: LivyConf, name: Option[String] = override def logLines(): IndexedSeq[String] = IndexedSeq() - var serverState: SessionState = SessionState.Idle + var serverState: SessionState = SessionState.Idle() override def state: SessionState = serverState override def recoveryMetadata: RecoveryMetadata = RecoveryMetadata(0) diff --git a/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala b/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala index a5e9ffa0e..43427db62 100644 --- a/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala +++ b/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala @@ -67,7 +67,7 @@ class SessionManagerSpec extends FunSpec with Matchers with LivyBaseUnitTestSuit val session2 = manager.register(new MockSession(manager.nextId(), null, livyConf)) manager.get(session1.id).isDefined should be(true) manager.get(session2.id).isDefined should be(true) - session2.serverState = SessionState.Busy + session2.serverState = SessionState.Busy() eventually(timeout(5 seconds), interval(100 millis)) { Await.result(manager.collectGarbage(), Duration.Inf) (manager.get(session1.id).isDefined, manager.get(session2.id).isDefined) should @@ -132,12 +132,12 @@ class SessionManagerSpec extends FunSpec with Matchers with LivyBaseUnitTestSuit } // Batch session should not be gc-ed when alive - for (s <- Seq(SessionState.Running, - SessionState.Idle, - SessionState.Recovering, - SessionState.NotStarted, - SessionState.Busy, - SessionState.ShuttingDown)) { + for (s <- Seq(SessionState.Running(), + SessionState.Idle(), + SessionState.Recovering(), + SessionState.NotStarted(), + SessionState.Busy(), + SessionState.ShuttingDown())) { changeStateAndCheck(s) { sm => sm.get(session.id) should be (Some(session)) } }