diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java index 2b8751d49..a0b9269e4 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java @@ -47,4 +47,26 @@ public class AgentExecutionOptions { public static final ConfigOption RAG_ASYNC = new ConfigOption<>("rag.async", Boolean.class, true); + + /** Set to a positive value in milliseconds to enable short-term memory TTL; 0 disables it. */ + public static final ConfigOption SHORT_TERM_MEMORY_STATE_TTL_MS = + new ConfigOption<>("short-term-memory.state-ttl.ms", Long.class, 0L); + + /** Update policy for short-term memory TTL, consulted only when TTL is enabled. */ + public static final ConfigOption + SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE = + new ConfigOption<>( + "short-term-memory.state-ttl.update-type", + ShortTermMemoryTtlUpdate.class, + ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE); + + /** + * Visibility policy for expired short-term memory state, consulted only when TTL is enabled. + */ + public static final ConfigOption + SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY = + new ConfigOption<>( + "short-term-memory.state-ttl.visibility", + ShortTermMemoryTtlVisibility.class, + ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED); } diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlUpdate.java b/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlUpdate.java new file mode 100644 index 000000000..06b92de52 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlUpdate.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.agents; + +/** Defines when short-term memory state TTL is refreshed. */ +public enum ShortTermMemoryTtlUpdate { + ON_CREATE_AND_WRITE, + ON_READ_AND_WRITE +} diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlVisibility.java b/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlVisibility.java new file mode 100644 index 000000000..0e252c6cf --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ShortTermMemoryTtlVisibility.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.agents; + +/** Defines whether expired short-term memory state can be returned before cleanup. */ +public enum ShortTermMemoryTtlVisibility { + NEVER_RETURN_EXPIRED, + RETURN_EXPIRED_IF_NOT_CLEANED_UP +} diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java index 9b0d143eb..312c844ca 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java @@ -72,6 +72,7 @@ * public static ResourceDesc openAIResponses() { * return ResourceDescriptor.Builder.newBuilder(OpenAIResponsesModelConnection.class.getName()) * .addInitialArgument("api_key", System.getenv("OPENAI_API_KEY")) + * .addInitialArgument("api_base_url", System.getenv("OPENAI_API_URL")) * .addInitialArgument("timeout", 120) * .addInitialArgument("max_retries", 3) * .build(); diff --git a/python/flink_agents/api/core_options.py b/python/flink_agents/api/core_options.py index 5b575c3f6..f5247616c 100644 --- a/python/flink_agents/api/core_options.py +++ b/python/flink_agents/api/core_options.py @@ -82,6 +82,20 @@ class ErrorHandlingStrategy(Enum): IGNORE = "ignore" +class ShortTermMemoryTtlUpdate(Enum): + """Update policy for short-term memory TTL.""" + + ON_CREATE_AND_WRITE = "ON_CREATE_AND_WRITE" + ON_READ_AND_WRITE = "ON_READ_AND_WRITE" + + +class ShortTermMemoryTtlVisibility(Enum): + """Visibility policy for expired short-term memory state.""" + + NEVER_RETURN_EXPIRED = "NEVER_RETURN_EXPIRED" + RETURN_EXPIRED_IF_NOT_CLEANED_UP = "RETURN_EXPIRED_IF_NOT_CLEANED_UP" + + class LoggerType(Enum): """Built-in event logger types. @@ -179,3 +193,26 @@ class AgentExecutionOptions: config_type=bool, default=True, ) + + # Set to a positive value in milliseconds to enable short-term memory TTL; + # 0 disables it. + SHORT_TERM_MEMORY_STATE_TTL_MS = ConfigOption( + key="short-term-memory.state-ttl.ms", + config_type=int, + default=0, + ) + + # Update policy for short-term memory TTL, consulted only when TTL is enabled. + SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE = ConfigOption( + key="short-term-memory.state-ttl.update-type", + config_type=ShortTermMemoryTtlUpdate, + default=ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE, + ) + + # Visibility policy for expired short-term memory state, consulted only when TTL + # is enabled. + SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY = ConfigOption( + key="short-term-memory.state-ttl.visibility", + config_type=ShortTermMemoryTtlVisibility, + default=ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED, + ) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/short_term_memory_ttl_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/short_term_memory_ttl_test.py new file mode 100644 index 000000000..99564de15 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/short_term_memory_ttl_test.py @@ -0,0 +1,178 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import os +import sysconfig +import time +from pathlib import Path +from typing import Any + +from pydantic import BaseModel +from pyflink.common import Configuration +from pyflink.datastream import KeySelector, StreamExecutionEnvironment + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.core_options import ( + AgentExecutionOptions, + ShortTermMemoryTtlUpdate, + ShortTermMemoryTtlVisibility, +) +from flink_agents.api.decorators import action +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.runner_context import RunnerContext + +current_dir = Path(__file__).parent +os.environ["PYTHONPATH"] = ( + f"{current_dir.parent.parent.parent}:{sysconfig.get_paths()['purelib']}" +) + + +class TtlTestInput(BaseModel): + event_key: str + sleep_ms: int + + +class TtlTestKeySelector(KeySelector): + def get_key(self, value: TtlTestInput) -> str: + return "test_key" + + +class ShortTermMemoryTtlTestAgent(Agent): + @action(InputEvent.EVENT_TYPE) + @staticmethod + def input(event: Event, ctx: RunnerContext) -> None: + input_data = TtlTestInput.model_validate(InputEvent.from_event(event).input) + + short_term_memory = ctx.short_term_memory + existing_value = short_term_memory.get(input_data.event_key) + current_count = 0 + if isinstance(existing_value, int): + current_count = existing_value + elif isinstance(existing_value, float): + current_count = int(existing_value) + + short_term_memory.set(input_data.event_key, current_count + 1) + time.sleep(input_data.sleep_ms / 1000) + ctx.send_event( + OutputEvent( + output=( + f"{input_data.event_key}|" + f"{'NEW' if existing_value is None else 'EXISTING'}" + ) + ) + ) + + +def run_scenario( + ttl_ms: int, + sleep_ms: int, + *, + configure_ttl_ms: bool, + configure_ttl_options: bool, +) -> list[Any]: + config = Configuration() + config.set_string("python.pythonpath", os.environ["PYTHONPATH"]) + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_parallelism(1) + + input_stream = env.from_collection( + [ + TtlTestInput(event_key="event1", sleep_ms=sleep_ms), + TtlTestInput(event_key="event2", sleep_ms=sleep_ms), + TtlTestInput(event_key="event1", sleep_ms=sleep_ms), + ] + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + agents_config = agents_env.get_config() + if configure_ttl_ms: + agents_config.set(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS, ttl_ms) + if configure_ttl_options: + agents_config.set( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE, + ShortTermMemoryTtlUpdate.ON_CREATE_AND_WRITE, + ) + agents_config.set( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY, + ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED, + ) + + output_datastream = ( + agents_env.from_datastream( + input=input_stream, key_selector=TtlTestKeySelector() + ) + .apply(ShortTermMemoryTtlTestAgent()) + .to_datastream() + ) + + return list(output_datastream.execute_and_collect()) + + +def test_value_still_visible_before_ttl_expiry() -> None: + results = run_scenario( + 1000, + 0, + configure_ttl_ms=True, + configure_ttl_options=True, + ) + + assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"] + + +def test_ttl_configuration_disabled_with_zero_ttl() -> None: + results = run_scenario( + 0, + 2000, + configure_ttl_ms=True, + configure_ttl_options=True, + ) + + assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"] + + +def test_ttl_configuration_disabled_by_default() -> None: + results = run_scenario( + 0, + 2000, + configure_ttl_ms=False, + configure_ttl_options=True, + ) + + assert results == ["event1|NEW", "event2|NEW", "event1|EXISTING"] + + +def test_value_expires_after_ttl() -> None: + results = run_scenario( + 1000, + 2000, + configure_ttl_ms=True, + configure_ttl_options=True, + ) + + assert results == ["event1|NEW", "event2|NEW", "event1|NEW"] + + +def test_ttl_configuration_applied_with_default_update_type_and_visibility() -> None: + results = run_scenario( + 1000, + 2000, + configure_ttl_ms=True, + configure_ttl_options=False, + ) + + assert results == ["event1|NEW", "event2|NEW", "event1|NEW"] diff --git a/python/flink_agents/plan/tests/compatibility/create_python_option_from_java_option.py b/python/flink_agents/plan/tests/compatibility/create_python_option_from_java_option.py index c2da3714d..b7251ee17 100644 --- a/python/flink_agents/plan/tests/compatibility/create_python_option_from_java_option.py +++ b/python/flink_agents/plan/tests/compatibility/create_python_option_from_java_option.py @@ -19,7 +19,12 @@ from pyflink.util.java_utils import add_jars_to_context_class_loader -from flink_agents.api.core_options import AgentConfigOptions +from flink_agents.api.core_options import ( + AgentConfigOptions, + AgentExecutionOptions, + ShortTermMemoryTtlUpdate, + ShortTermMemoryTtlVisibility, +) # This script is used to verify that Java-defined configuration options # (e.g., AgentConfigOptions) are correctly exposed and accessible in the @@ -39,3 +44,36 @@ assert AgentConfigOptions.BASE_LOG_DIR.get_key() == "baseLogDir" assert AgentConfigOptions.BASE_LOG_DIR.get_type() is str assert AgentConfigOptions.BASE_LOG_DIR.get_default_value() is None + + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_key() + == "short-term-memory.state-ttl.ms" + ) + assert AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_type() is int + assert AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS.get_default_value() == 0 + + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_key() + == "short-term-memory.state-ttl.update-type" + ) + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_type() + is ShortTermMemoryTtlUpdate + ) + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE.get_default_value() + is ShortTermMemoryTtlUpdate.ON_READ_AND_WRITE + ) + + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_key() + == "short-term-memory.state-ttl.visibility" + ) + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_type() + is ShortTermMemoryTtlVisibility + ) + assert ( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY.get_default_value() + is ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED + ) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index a0a3813da..dde720eec 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -141,7 +141,7 @@ public void setup( public void open() throws Exception { super.open(); - stateManager.initializeKeyedStates(getRuntimeContext()); + stateManager.initializeKeyedStates(getRuntimeContext(), agentPlan); stateManager.initializeOperatorStates(getOperatorStateBackend()); // ResourceCache constructs its own long-lived ResourceContextImpl internally; on diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index 843a02078..4d4263739 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -19,11 +19,16 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.agents.ShortTermMemoryTtlUpdate; +import org.apache.flink.agents.api.agents.ShortTermMemoryTtlVisibility; +import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -37,6 +42,8 @@ import javax.annotation.Nullable; +import java.time.Duration; + import static org.apache.flink.agents.runtime.utils.StateUtil.*; /** @@ -56,9 +63,9 @@ * *

Lifecycle: instantiated by the operator's {@code initializeState()} (the Flink lifecycle runs * {@code initializeState} before {@code open}). Both {@link - * #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext)} and {@link - * #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's {@code - * open()}. There is no explicit close — the underlying state handles are owned by Flink. + * #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext, AgentPlan)} and + * {@link #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's + * {@code open()}. There is no explicit close — the underlying state handles are owned by Flink. * *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data * flows via method parameters (see for example {@link ActionTaskContextManager#transferContexts} @@ -87,7 +94,9 @@ class OperatorStateManager { * * @param runtimeContext the operator's runtime context, used to obtain keyed state handles. */ - void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext runtimeContext) + void initializeKeyedStates( + org.apache.flink.api.common.functions.RuntimeContext runtimeContext, + AgentPlan agentPlan) throws Exception { // init sensoryMemState MapStateDescriptor sensoryMemStateDescriptor = @@ -103,6 +112,7 @@ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext "shortTermMemory", TypeInformation.of(String.class), TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + maybeEnableShortTermMemoryTTL(shortTermMemStateDescriptor, agentPlan); shortTermMemState = runtimeContext.getMapState(shortTermMemStateDescriptor); // init sequence number state for per key message ordering @@ -121,6 +131,63 @@ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class))); } + /** + * When {@link AgentExecutionOptions#SHORT_TERM_MEMORY_STATE_TTL_MS} is positive, attaches Flink + * {@link StateTtlConfig} to the short-term memory {@link MapStateDescriptor}. Unset, null, or + * non-positive values disable TTL (Flink does not allow zero/negative TTL). + */ + private void maybeEnableShortTermMemoryTTL( + MapStateDescriptor descriptor, + AgentPlan agentPlan) { + Long ttlMs = + agentPlan.getConfig().get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS); + if (ttlMs == null || ttlMs <= 0) { + return; + } + + ShortTermMemoryTtlUpdate updateType = + agentPlan + .getConfig() + .get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE); + + ShortTermMemoryTtlVisibility stateVisibility = + agentPlan + .getConfig() + .get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY); + + StateTtlConfig ttlConfig = + StateTtlConfig.newBuilder(Duration.ofMillis(ttlMs)) + .setUpdateType(toFlinkUpdateType(updateType)) + .setStateVisibility(toFlinkStateVisibility(stateVisibility)) + .cleanupFullSnapshot() + .build(); + descriptor.enableTimeToLive(ttlConfig); + } + + private StateTtlConfig.UpdateType toFlinkUpdateType(ShortTermMemoryTtlUpdate updateType) { + switch (updateType) { + case ON_CREATE_AND_WRITE: + return StateTtlConfig.UpdateType.OnCreateAndWrite; + case ON_READ_AND_WRITE: + return StateTtlConfig.UpdateType.OnReadAndWrite; + default: + throw new IllegalArgumentException("Unsupported TTL update type: " + updateType); + } + } + + private StateTtlConfig.StateVisibility toFlinkStateVisibility( + ShortTermMemoryTtlVisibility stateVisibility) { + switch (stateVisibility) { + case NEVER_RETURN_EXPIRED: + return StateTtlConfig.StateVisibility.NeverReturnExpired; + case RETURN_EXPIRED_IF_NOT_CLEANED_UP: + return StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp; + default: + throw new IllegalArgumentException( + "Unsupported TTL state visibility: " + stateVisibility); + } + } + /** * Registers operator-level (non-keyed) state. * diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ShortTermMemoryTTLIntegrationTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ShortTermMemoryTTLIntegrationTest.java new file mode 100644 index 000000000..3317ca5e2 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ShortTermMemoryTTLIntegrationTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.agents.ShortTermMemoryTtlUpdate; +import org.apache.flink.agents.api.agents.ShortTermMemoryTtlVisibility; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** Integration test for Short-Term Memory TTL functionality. */ +class ShortTermMemoryTTLIntegrationTest { + + private static final String MEMORY_KEY = "test_key"; + + private static final class TestInput { + public String eventKey; + public long sleepMs; + + private TestInput() {} + + private TestInput(String eventKey, long sleepMs) { + this.eventKey = eventKey; + this.sleepMs = sleepMs; + } + } + + public static class TTLTestAgent extends Agent { + + @Action(listenEventTypes = {InputEvent.EVENT_TYPE}) + public static void input(org.apache.flink.agents.api.Event event, RunnerContext ctx) + throws Exception { + InputEvent inputEvent = (InputEvent) event; + TestInput input = (TestInput) inputEvent.getInput(); + + MemoryObject shortTermMemory = ctx.getShortTermMemory(); + MemoryObject memoryObject = shortTermMemory.get(input.eventKey); + + Object existingValue = null; + int currentCount = 0; + if (memoryObject != null && !memoryObject.isNestedObject()) { + existingValue = memoryObject.getValue(); + if (existingValue instanceof Integer) { + currentCount = (Integer) existingValue; + } else if (existingValue instanceof Number) { + currentCount = ((Number) existingValue).intValue(); + } + } + + shortTermMemory.set(input.eventKey, currentCount + 1); + Thread.sleep(input.sleepMs); + ctx.sendEvent( + new OutputEvent( + input.eventKey + "|" + (existingValue == null ? "NEW" : "EXISTING"))); + } + } + + @Test + void testValueStillVisibleBeforeTTLExpiry() throws Exception { + List results = runScenario(1000L, 0L, true, true); + + assertEquals(List.of("event1|NEW", "event2|NEW", "event1|EXISTING"), results); + } + + @Test + void testTTLConfigurationDisabledWithZeroTtl() throws Exception { + List results = runScenario(0L, 2000L, true, true); + + assertEquals(List.of("event1|NEW", "event2|NEW", "event1|EXISTING"), results); + } + + @Test + void testTTLConfigurationDisabledByDefault() throws Exception { + List results = runScenario(0L, 2000L, false, true); + + assertEquals(List.of("event1|NEW", "event2|NEW", "event1|EXISTING"), results); + } + + @Test + void testValueExpiresAfterTTL() throws Exception { + List results = runScenario(1000L, 2000L, true, true); + + assertEquals(List.of("event1|NEW", "event2|NEW", "event1|NEW"), results); + } + + @Test + void testTTLConfigurationAppliedWithDefaultUpdateTypeAndVisibility() throws Exception { + List results = runScenario(1000L, 2000L, true, false); + + assertEquals(List.of("event1|NEW", "event2|NEW", "event1|NEW"), results); + } + + private static List runScenario( + long ttlMs, long sleepMs, boolean configureTtlMs, boolean configureTtlOptions) + throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + AgentsExecutionEnvironment agentEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + AgentConfiguration agentsConfig = (AgentConfiguration) agentEnv.getConfig(); + if (configureTtlMs) { + agentsConfig.set(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS, ttlMs); + } + if (configureTtlOptions) { + agentsConfig.set( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE, + ShortTermMemoryTtlUpdate.ON_CREATE_AND_WRITE); + agentsConfig.set( + AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY, + ShortTermMemoryTtlVisibility.NEVER_RETURN_EXPIRED); + } + + List testData = new ArrayList<>(); + testData.add(new TestInput("event1", sleepMs)); + testData.add(new TestInput("event2", sleepMs)); + testData.add(new TestInput("event1", sleepMs)); + + DataStream inputStream = env.fromCollection(testData); + DataStream outputStream = + agentEnv.fromDataStream(inputStream, x -> MEMORY_KEY) + .apply(new TTLTestAgent()) + .toDataStream(); + + List results = new ArrayList<>(); + outputStream.map(Object::toString).executeAndCollect().forEachRemaining(results::add); + return results; + } +}