diff --git a/docs/content.zh/docs/dev/table/functions/ptfs.md b/docs/content.zh/docs/dev/table/functions/ptfs.md index 4190e9e3ba5a6..008a83a4aab5c 100644 --- a/docs/content.zh/docs/dev/table/functions/ptfs.md +++ b/docs/content.zh/docs/dev/table/functions/ptfs.md @@ -2275,6 +2275,189 @@ void testScalarOnly() throws Exception { {{< /tab >}} {{< /tabs >}} +#### Testing with State + +The harness supports all PTF state types: value state, `Row`, `ListView`, and `MapView`. + +{{< tabs "state-testing" >}} +{{< tab "Java" >}} +```java +// A PTF that uses all four state types: value state, Row, ListView, and MapView. +@DataTypeHint("ROW") +public class StatefulPTF extends ProcessTableFunction { + public static class ValueState { + public long count = 0L; + } + + public void eval( + @StateHint ValueState valueState, + @StateHint(type = @DataTypeHint("ROW")) Row rowState, + @StateHint(type = @DataTypeHint("ARRAY")) ListView listState, + @StateHint MapView mapState, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) throws Exception { + // Value state — increment counter + valueState.count++; + + // Row state — track the last value seen + int value = input.getFieldAs("value"); + rowState.setField("lastValue", value); + + // ListView state — accumulate values + listState.add(value); + + // MapView state — count occurrences by name + String name = input.getFieldAs("name"); + Integer tagCount = mapState.get(name); + mapState.put(name, tagCount == null ? 1 : tagCount + 1); + + collect(Row.of(valueState.count)); + } +} + +@Test +void testWithState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + harness.processElement(Row.of("Alice", 20)); + + List output = harness.getOutput(); + assertThat(output.get(0)).isEqualTo(Row.of("Alice", 1L)); + assertThat(output.get(1)).isEqualTo(Row.of("Alice", 2L)); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**Initial State Setup**: Use `.withInitialStateForKey()` to pre-populate state before processing. +State initialization is scoped per partition key: + +{{< tabs "initial-state" >}} +{{< tab "Java" >}} +```java +@Test +void testWithInitialState() throws Exception { + // Value state + StatefulPTF.ValueState initialValue = new StatefulPTF.ValueState(); + initialValue.count = 100L; + + // Row state + Row initialRow = Row.withNames(); + initialRow.setField("lastValue", 42); + + // ListView state + ListView initialList = new ListView<>(); + initialList.add(10); + initialList.add(20); + + // MapView state + MapView initialMap = new MapView<>(); + initialMap.put("Alice", 5); + + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + // Initial state is set per partition key + .withInitialStateForKey("valueState", Row.of("Alice"), initialValue) + .withInitialStateForKey("rowState", Row.of("Alice"), initialRow) + .withInitialStateForKey("listState", Row.of("Alice"), initialList) + .withInitialStateForKey("mapState", Row.of("Alice"), initialMap) + .build()) { + + harness.processElement(Row.of("Alice", 10)); + + List output = harness.getOutput(); + assertThat(output).containsExactly(Row.of("Alice", 101L)); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**State Introspection**: Use `getStateForKey()`, `getKeysForState()`, and `getStateForAllKeys()` to inspect state during tests: + +{{< tabs "state-introspection" >}} +{{< tab "Java" >}} +```java +@Test +void testStateIntrospection() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + harness.processElement(Row.of("Bob", 20)); + + // Check value state + StatefulPTF.ValueState aliceState = + harness.getStateForKey("valueState", Row.of("Alice")); + assertThat(aliceState.count).isEqualTo(1L); + + // Check Row state + Row aliceRowState = harness.getStateForKey("rowState", Row.of("Alice")); + assertThat(aliceRowState.getField("lastValue")).isEqualTo(10); + + // Check ListView state + ListView aliceList = harness.getStateForKey("listState", Row.of("Alice")); + assertThat(aliceList.getList()).containsExactly(10); + + // Check MapView state + MapView aliceMap = harness.getStateForKey("mapState", Row.of("Alice")); + assertThat(aliceMap.get("Alice")).isEqualTo(1); + + // Get all partition keys with state + Set keys = harness.getKeysForState("valueState"); + assertThat(keys).containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob")); + + // Get all state across partition keys + Map allState = + harness.getStateForAllKeys("valueState"); + assertThat(allState.get(Row.of("Bob")).count).isEqualTo(1L); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**State Mutation**: Use `setStateForKey()`, `clearStateForKey()`, and `clearStateEntryForKey()` to modify state during tests: + +{{< tabs "state-mutation" >}} +{{< tab "Java" >}} +```java +@Test +void testStateMutation() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + + // Overwrite a specific state entry for a partition key + StatefulPTF.ValueState newState = new StatefulPTF.ValueState(); + newState.count = 100L; + harness.setStateForKey("valueState", Row.of("Alice"), newState); + + // Clear a specific state entry (resets to default) + harness.clearStateEntryForKey("listState", Row.of("Alice")); + + // Clear all state for a partition key + harness.clearStateForKey(Row.of("Alice")); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + #### Configuring Table Argument Types In contexts where the harness can't infer the table argument types for table arguments (when using unannotated `Row` inputs, @@ -2348,8 +2531,8 @@ void testPOJO() throws Exception { ### PTF Features Unsupported by the TestHarness -- `Context` paramter -- State (`@StateHint`) +- `Context` parameter - Timers (`onTimer`) - `on_time` / `rowtime` - Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`) +- State TTL (state is supported but TTL expiration is not yet implemented) diff --git a/docs/content/docs/dev/table/functions/ptfs.md b/docs/content/docs/dev/table/functions/ptfs.md index 7181fc1a3d142..e7f4580d1241a 100644 --- a/docs/content/docs/dev/table/functions/ptfs.md +++ b/docs/content/docs/dev/table/functions/ptfs.md @@ -2278,6 +2278,189 @@ void testScalarOnly() throws Exception { {{< /tab >}} {{< /tabs >}} +#### Testing with State + +The harness supports all PTF state types: value state, `Row`, `ListView`, and `MapView`. + +{{< tabs "state-testing" >}} +{{< tab "Java" >}} +```java +// A PTF that uses all four state types: value state, Row, ListView, and MapView. +@DataTypeHint("ROW") +public class StatefulPTF extends ProcessTableFunction { + public static class ValueState { + public long count = 0L; + } + + public void eval( + @StateHint ValueState valueState, + @StateHint(type = @DataTypeHint("ROW")) Row rowState, + @StateHint(type = @DataTypeHint("ARRAY")) ListView listState, + @StateHint MapView mapState, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) throws Exception { + // Value state — increment counter + valueState.count++; + + // Row state — track the last value seen + int value = input.getFieldAs("value"); + rowState.setField("lastValue", value); + + // ListView state — accumulate values + listState.add(value); + + // MapView state — count occurrences by name + String name = input.getFieldAs("name"); + Integer tagCount = mapState.get(name); + mapState.put(name, tagCount == null ? 1 : tagCount + 1); + + collect(Row.of(valueState.count)); + } +} + +@Test +void testWithState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + harness.processElement(Row.of("Alice", 20)); + + List output = harness.getOutput(); + assertThat(output.get(0)).isEqualTo(Row.of("Alice", 1L)); + assertThat(output.get(1)).isEqualTo(Row.of("Alice", 2L)); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**Initial State Setup**: Use `.withInitialStateForKey()` to pre-populate state before processing. +State initialization is scoped per partition key: + +{{< tabs "initial-state" >}} +{{< tab "Java" >}} +```java +@Test +void testWithInitialState() throws Exception { + // Value state + StatefulPTF.ValueState initialValue = new StatefulPTF.ValueState(); + initialValue.count = 100L; + + // Row state + Row initialRow = Row.withNames(); + initialRow.setField("lastValue", 42); + + // ListView state + ListView initialList = new ListView<>(); + initialList.add(10); + initialList.add(20); + + // MapView state + MapView initialMap = new MapView<>(); + initialMap.put("Alice", 5); + + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + // Initial state is set per partition key + .withInitialStateForKey("valueState", Row.of("Alice"), initialValue) + .withInitialStateForKey("rowState", Row.of("Alice"), initialRow) + .withInitialStateForKey("listState", Row.of("Alice"), initialList) + .withInitialStateForKey("mapState", Row.of("Alice"), initialMap) + .build()) { + + harness.processElement(Row.of("Alice", 10)); + + List output = harness.getOutput(); + assertThat(output).containsExactly(Row.of("Alice", 101L)); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**State Introspection**: Use `getStateForKey()`, `getKeysForState()`, and `getStateForAllKeys()` to inspect state during tests: + +{{< tabs "state-introspection" >}} +{{< tab "Java" >}} +```java +@Test +void testStateIntrospection() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + harness.processElement(Row.of("Bob", 20)); + + // Check value state + StatefulPTF.ValueState aliceState = + harness.getStateForKey("valueState", Row.of("Alice")); + assertThat(aliceState.count).isEqualTo(1L); + + // Check Row state + Row aliceRowState = harness.getStateForKey("rowState", Row.of("Alice")); + assertThat(aliceRowState.getField("lastValue")).isEqualTo(10); + + // Check ListView state + ListView aliceList = harness.getStateForKey("listState", Row.of("Alice")); + assertThat(aliceList.getList()).containsExactly(10); + + // Check MapView state + MapView aliceMap = harness.getStateForKey("mapState", Row.of("Alice")); + assertThat(aliceMap.get("Alice")).isEqualTo(1); + + // Get all partition keys with state + Set keys = harness.getKeysForState("valueState"); + assertThat(keys).containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob")); + + // Get all state across partition keys + Map allState = + harness.getStateForAllKeys("valueState"); + assertThat(allState.get(Row.of("Bob")).count).isEqualTo(1L); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**State Mutation**: Use `setStateForKey()`, `clearStateForKey()`, and `clearStateEntryForKey()` to modify state during tests: + +{{< tabs "state-mutation" >}} +{{< tab "Java" >}} +```java +@Test +void testStateMutation() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build()) { + + harness.processElement(Row.of("Alice", 10)); + + // Overwrite a specific state entry for a partition key + StatefulPTF.ValueState newState = new StatefulPTF.ValueState(); + newState.count = 100L; + harness.setStateForKey("valueState", Row.of("Alice"), newState); + + // Clear a specific state entry (resets to default) + harness.clearStateEntryForKey("listState", Row.of("Alice")); + + // Clear all state for a partition key + harness.clearStateForKey(Row.of("Alice")); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + #### Configuring Table Argument Types In contexts where the harness can't infer the table argument types for table arguments (when using unannotated `Row` inputs, @@ -2351,8 +2534,8 @@ void testPOJO() throws Exception { ### PTF Features Unsupported by the TestHarness -- `Context` paramter -- State (`@StateHint`) +- `Context` parameter - Timers (`onTimer`) - `on_time` / `rowtime` - Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`) +- State TTL (state is supported but TTL expiration is not yet implemented) diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java new file mode 100644 index 0000000000000..07ed63309a849 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.dataview.ListView; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.conversion.DataStructureConverter; +import org.apache.flink.table.types.logical.ArrayType; + +import java.util.ArrayList; +import java.util.List; + +/** + * Converter for ListView state. + * + *

Converts between external ListView objects and internal ArrayData representation. + */ +@Internal +class ListViewStateConverter implements StateConverter { + + private final DataStructureConverter elementConverter; + private final ArrayData.ElementGetter elementGetter; + + ListViewStateConverter( + ArrayType arrayType, DataStructureConverter elementConverter) { + this.elementConverter = elementConverter; + this.elementGetter = ArrayData.createElementGetter(arrayType.getElementType()); + } + + @Override + public Object toInternal(Object external) { + ListView listView = (ListView) external; + List elements = listView.getList(); + + Object[] internalArray = new Object[elements.size()]; + for (int i = 0; i < elements.size(); i++) { + internalArray[i] = elementConverter.toInternal(elements.get(i)); + } + return new GenericArrayData(internalArray); + } + + @Override + public Object toExternal(Object internal) { + ArrayData arrayData = (ArrayData) internal; + ListView listView = new ListView<>(); + + List elements = new ArrayList<>(); + for (int i = 0; i < arrayData.size(); i++) { + Object internalElement = elementGetter.getElementOrNull(arrayData, i); + Object externalElement = elementConverter.toExternal(internalElement); + elements.add(externalElement); + } + listView.setList(elements); + return listView; + } + + @Override + public Object createNewInternalState() { + return new GenericArrayData(new Object[0]); + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java new file mode 100644 index 0000000000000..b173f8d090c6f --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.dataview.MapView; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericMapData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.conversion.DataStructureConverter; +import org.apache.flink.table.types.logical.MapType; + +import java.util.HashMap; +import java.util.Map; + +/** + * Converter for MapView state. + * + *

Converts between external MapView objects and internal MapData representation. + */ +@Internal +class MapViewStateConverter implements StateConverter { + + private final DataStructureConverter keyConverter; + private final DataStructureConverter valueConverter; + private final ArrayData.ElementGetter keyGetter; + private final ArrayData.ElementGetter valueGetter; + + MapViewStateConverter( + MapType mapType, + DataStructureConverter keyConverter, + DataStructureConverter valueConverter) { + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + this.keyGetter = ArrayData.createElementGetter(mapType.getKeyType()); + this.valueGetter = ArrayData.createElementGetter(mapType.getValueType()); + } + + @Override + public Object toInternal(Object external) { + MapView mapView = (MapView) external; + Map entries = mapView.getMap(); + + Map internalMap = new HashMap<>(); + for (Map.Entry entry : entries.entrySet()) { + Object internalKey = keyConverter.toInternal(entry.getKey()); + Object internalValue = valueConverter.toInternal(entry.getValue()); + internalMap.put(internalKey, internalValue); + } + return new GenericMapData(internalMap); + } + + @Override + public Object toExternal(Object internal) { + MapData mapData = (MapData) internal; + MapView mapView = new MapView<>(); + + Map entries = new HashMap<>(); + ArrayData keyArray = mapData.keyArray(); + ArrayData valueArray = mapData.valueArray(); + for (int i = 0; i < keyArray.size(); i++) { + Object internalKey = keyGetter.getElementOrNull(keyArray, i); + Object internalValue = valueGetter.getElementOrNull(valueArray, i); + Object externalKey = keyConverter.toExternal(internalKey); + Object externalValue = valueConverter.toExternal(internalValue); + entries.put(externalKey, externalValue); + } + mapView.setMap(entries); + return mapView; + } + + @Override + public Object createNewInternalState() { + return new GenericMapData(new HashMap<>()); + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java index 425c88c375357..d08e9b82a79e0 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java @@ -20,8 +20,8 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.annotation.ArgumentTrait; -import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.conversion.DataStructureConverter; import org.apache.flink.table.data.conversion.DataStructureConverters; import org.apache.flink.table.functions.FunctionContext; @@ -31,12 +31,15 @@ import org.apache.flink.table.types.AbstractDataType; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.FieldsDataType; +import org.apache.flink.table.types.inference.StateTypeStrategy; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.SystemTypeInference; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.ArrayType; import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.MapType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.StructuredType; import org.apache.flink.table.types.utils.TypeConversions; @@ -47,6 +50,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Parameter; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; @@ -56,6 +60,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -85,16 +90,16 @@ @PublicEvolving public class ProcessTableFunctionTestHarness implements AutoCloseable { - /** Holds input and output converters for a table argument. */ - private static class ConverterPair { - final DataStructureConverter input; - final DataStructureConverter output; + /** Holds converters for transforming table argument input rows. */ + private static class TableArgumentConverters { + final DataStructureConverter toNamedRow; + final DataStructureConverter toEvalArgument; - ConverterPair( - DataStructureConverter input, - DataStructureConverter output) { - this.input = input; - this.output = output; + TableArgumentConverters( + DataStructureConverter toNamedRow, + DataStructureConverter toEvalArgument) { + this.toNamedRow = toNamedRow; + this.toEvalArgument = toEvalArgument; } } @@ -103,6 +108,7 @@ private static class ConverterPair { private final List output; private boolean isOpen; private final HarnessCollector collector; + private final TestHarnessStateManager stateManager; private final String defaultTableArgument; private final Method evalMethod; @@ -110,10 +116,9 @@ private static class ConverterPair { private final Map argumentsByName; private final boolean isSingleTableFunction; - private final Map scalarArgumentValues; private boolean hasTableArguments = false; - private final Map argumentConverters; + private final Map argumentConverters; private final DataStructureConverter harnessOutputConverter; private ProcessTableFunctionTestHarness( @@ -121,17 +126,17 @@ private ProcessTableFunctionTestHarness( FunctionContext functionContext, Method evalMethod, List arguments, - Map scalarArgumentValues, - Map argumentConverters, - DataStructureConverter harnessOutputConverter) + Map argumentConverters, + DataStructureConverter harnessOutputConverter, + TestHarnessStateManager stateManager) throws Exception { this.function = function; this.functionContext = functionContext; this.evalMethod = evalMethod; this.arguments = arguments; - this.scalarArgumentValues = scalarArgumentValues; this.argumentConverters = argumentConverters; this.harnessOutputConverter = harnessOutputConverter; + this.stateManager = stateManager; this.output = new ArrayList<>(); this.collector = new HarnessCollector(); this.isOpen = false; @@ -143,13 +148,8 @@ private ProcessTableFunctionTestHarness( } } - final List tableArguments = new ArrayList<>(); - for (ArgumentInfo arg : arguments) { - if (arg.isTableArgument) { - tableArguments.add(arg); - this.hasTableArguments = true; - } - } + final List tableArguments = ArgumentInfo.filterTableArguments(arguments); + this.hasTableArguments = !tableArguments.isEmpty(); if (tableArguments.size() == 1) { this.defaultTableArgument = tableArguments.get(0).name; @@ -213,11 +213,14 @@ public void processElementForTable(String tableArgument, Row row) throws Excepti checkState(isOpen, "Harness not open"); checkNotNull(tableArgument, "tableArgument must not be null"); - ArgumentInfo tableArg = argumentsByName.get(tableArgument); - if (tableArg == null) { + ArgumentInfo arg = argumentsByName.get(tableArgument); + if (arg == null) { throw new IllegalArgumentException("Unknown table argument: " + tableArgument); } - invokeEval(tableArg, row); + if (!(arg instanceof TableArgumentInfo)) { + throw new IllegalArgumentException("'" + tableArgument + "' is not a table argument"); + } + invokeEval((TableArgumentInfo) arg, row); } /** Process a single element for a specific table argument. */ @@ -250,7 +253,7 @@ public void process() throws Exception { + "Use processElement() or processElementForTable() instead."); } - Object[] args = arguments.stream().map(arg -> scalarArgumentValues.get(arg.name)).toArray(); + Object[] args = arguments.stream().map(arg -> ((ScalarArgumentInfo) arg).value).toArray(); try { evalMethod.invoke(function, args); @@ -270,43 +273,69 @@ public void clearOutput() { output.clear(); } - /** - * Given a target table argument and a row to process, construct the right set of arguments for - * the PTF's eval function and attempt to invoke it. - */ - private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws Exception { - // Set collector context so it can prepend columns if needed - collector.setContext(activeTableArg, activeRow); + /** Get state for a specific partition key. */ + public T getStateForKey(String stateName, Row partitionKey) { + return stateManager.getStateForKey(stateName, partitionKey); + } - Object[] args = new Object[arguments.size()]; + /** Set state for a specific partition key. */ + public void setStateForKey(String stateName, Row partitionKey, Object state) throws Exception { + stateManager.setStateForKey(stateName, partitionKey, state); + } - for (int i = 0; i < arguments.size(); i++) { - ArgumentInfo arg = arguments.get(i); + /** Get all partition keys that have a specific state entry. */ + public Set getKeysForState(String stateName) { + return stateManager.getKeysForState(stateName); + } + + /** Get all state values for a state name across all partition keys. */ + public Map getStateForAllKeys(String stateName) { + return stateManager.getStateForAllKeys(stateName); + } + + /** Clear all state for a given partition key. */ + public void clearStateForKey(Row partitionKey) { + stateManager.clearStateForKey(partitionKey); + } - if (arg.isTableArgument && arg.name.equals(activeTableArg.name)) { - // If the argument is the active table argument, first convert the input row - // to an internal RowData type, and then convert the RowData to type that the - // argument expects. For Rows, this will structure the Row based on the table - // argument structure. Otherwise, for POJOs, it will pass the expected POJO to eval. + /** Clear specific state entry for a given partition key. */ + public void clearStateEntryForKey(String stateName, Row partitionKey) { + stateManager.clearStateEntryForKey(stateName, partitionKey); + } - ConverterPair pair = argumentConverters.get(arg.name); + private void invokeEval(TableArgumentInfo activeTableArg, Row activeRow) throws Exception { + TableArgumentConverters converters = argumentConverters.get(activeTableArg.name); - args[i] = pair.output.toExternalOrNull(pair.input.toInternalOrNull(activeRow)); + RowData rowData = (RowData) converters.toNamedRow.toInternal(activeRow); + Row namedRow = (Row) converters.toNamedRow.toExternal(rowData); + Object evalArgument = converters.toEvalArgument.toExternal(rowData); - } else if (arg.isScalar) { - args[i] = scalarArgumentValues.get(arg.name); + collector.setContext(activeTableArg, namedRow); - } else if (arg.isTableArgument) { - // Inactive table arguments receive null - args[i] = null; - } else { - throw new IllegalStateException( - "Unexpected argument type at position " + i + ": " + arg.name); + Row partitionKey = extractPartitionKey(activeTableArg, namedRow); + Map stateMap = stateManager.loadStateForKey(partitionKey); + + Object[] args = new Object[arguments.size()]; + int i = 0; + + for (ArgumentInfo arg : arguments) { + if (arg instanceof StateArgumentInfo) { + args[i++] = stateMap.get(arg.name); + } else if (arg instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) arg; + if (tableArg.name.equals(activeTableArg.name)) { + args[i++] = evalArgument; + } else { + args[i++] = null; + } + } else if (arg instanceof ScalarArgumentInfo) { + args[i++] = ((ScalarArgumentInfo) arg).value; } } try { evalMethod.invoke(function, args); + stateManager.updateStateForKey(partitionKey, stateMap); } catch (InvocationTargetException e) { String partitionInfo = activeTableArg.partitionColumnNames != null @@ -323,6 +352,16 @@ private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws Excep } } + private Row extractPartitionKey(TableArgumentInfo tableArg, Row row) { + if (tableArg.partitionColumnNames == null || tableArg.partitionColumnNames.length == 0) { + return Row.of(); + } + + Object[] keyValues = + Arrays.stream(tableArg.partitionColumnNames).map(row::getField).toArray(); + return Row.of(keyValues); + } + /** Collector implementation that stores output in the harness. */ private class HarnessCollector implements Collector { private ArgumentInfo activeTableArg; @@ -337,10 +376,11 @@ void setContext(ArgumentInfo tableArg, Row row) { public void collect(OUT record) { OUT finalRecord; - if (activeTableArg == null || !activeTableArg.isTableArgument) { + if (activeTableArg == null || !(activeTableArg instanceof TableArgumentInfo)) { finalRecord = record; } else { - switch (activeTableArg.prependStrategy) { + TableArgumentInfo tableArg = (TableArgumentInfo) activeTableArg; + switch (tableArg.prependStrategy) { case ALL_COLUMNS: finalRecord = prependAllColumns(record); break; @@ -352,7 +392,7 @@ public void collect(OUT record) { break; default: throw new IllegalStateException( - "Unknown prepend strategy: " + activeTableArg.prependStrategy); + "Unknown prepend strategy: " + tableArg.prependStrategy); } } @@ -384,8 +424,11 @@ private OUT prependPartitionKeys(OUT ptfOutput) { int totalPartitionKeyCount = 0; for (ArgumentInfo arg : arguments) { - if (arg.isSetSemantic && arg.partitionColumnNames != null) { - totalPartitionKeyCount += arg.partitionColumnNames.length; + if (arg instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) arg; + if (tableArg.isSetSemantic && tableArg.partitionColumnNames != null) { + totalPartitionKeyCount += tableArg.partitionColumnNames.length; + } } } @@ -394,19 +437,22 @@ private OUT prependPartitionKeys(OUT ptfOutput) { Row result = new Row(ptfRow.getKind(), totalArity); - // Extract partition key values from active row - Object[] partitionKeyValues = new Object[activeTableArg.partitionColumnNames.length]; - for (int i = 0; i < activeTableArg.partitionColumnNames.length; i++) { - String columnName = activeTableArg.partitionColumnNames[i]; - int columnIndex = getFieldIndex(activeTableArg.dataType, columnName); + TableArgumentInfo activeTableInfo = (TableArgumentInfo) activeTableArg; + Object[] partitionKeyValues = new Object[activeTableInfo.partitionColumnNames.length]; + for (int i = 0; i < activeTableInfo.partitionColumnNames.length; i++) { + String columnName = activeTableInfo.partitionColumnNames[i]; + int columnIndex = getFieldIndex(activeTableInfo.dataType, columnName); partitionKeyValues[i] = activeRow.getField(columnIndex); } int resultIndex = 0; for (ArgumentInfo arg : arguments) { - if (arg.isSetSemantic && arg.partitionColumnNames != null) { - for (int i = 0; i < arg.partitionColumnNames.length; i++) { - result.setField(resultIndex++, partitionKeyValues[i]); + if (arg instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) arg; + if (tableArg.isSetSemantic && tableArg.partitionColumnNames != null) { + for (int i = 0; i < tableArg.partitionColumnNames.length; i++) { + result.setField(resultIndex++, partitionKeyValues[i]); + } } } } @@ -496,6 +542,7 @@ public static class Builder { private final Map scalarArgs = new HashMap<>(); private final Map tableArgs = new HashMap<>(); private final Map partitionConfigs = new HashMap<>(); + private final Map stateArgs = new HashMap<>(); private Builder(Class> functionClass) { this.functionClass = checkNotNull(functionClass, "functionClass must not be null"); @@ -568,6 +615,20 @@ public Builder withScalarArgument(String argumentName, Object value) { return this; } + /** Sets initial state for a state parameter. */ + public Builder withInitialStateForKey( + String stateName, Row partitionKey, Object state) { + checkNotNull(stateName, "stateName must not be null"); + checkNotNull(partitionKey, "partitionKey must not be null"); + checkNotNull(state, "state must not be null"); + + stateArgs + .computeIfAbsent(stateName, k -> new StateArgumentConfiguration()) + .initialValues + .put(partitionKey, state); + return this; + } + // --------------------------------------------------------------------- // Partitioning // --------------------------------------------------------------------- @@ -610,6 +671,8 @@ public Builder withPartitionBy(String argumentName, String... columnNames) public ProcessTableFunctionTestHarness build() throws Exception { ProcessTableFunction function = instantiateFunction(); + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + DataTypeFactory dataTypeFactory = createDataTypeFactory(); TypeInference baseTypeInference = function.getTypeInference(dataTypeFactory); TypeInference systemTypeInference = @@ -618,21 +681,42 @@ public ProcessTableFunctionTestHarness build() throws Exception { List arguments = extractAndValidateTypeInference(function, systemTypeInference); - FunctionContext functionContext = - new FunctionContext(null, Thread.currentThread().getContextClassLoader(), null); + FunctionContext functionContext = new FunctionContext(null, classLoader, null); Method evalMethod = findEvalMethod(); validateEvalMethodSupported(evalMethod, arguments); validatePartitionConsistency(arguments); + validateInitialStateKeys(arguments); + + Map argumentConverters = new HashMap<>(); + Map stateConverters = new HashMap<>(); + createConverters(arguments, argumentConverters, stateConverters, classLoader); + + // Create state manager + List stateArguments = ArgumentInfo.filterStateArguments(arguments); + TestHarnessStateManager stateManager = + new TestHarnessStateManager( + stateArguments, stateConverters, extractPartitionKeyInfo(arguments)); + + // Populate initial state + for (Map.Entry entry : stateArgs.entrySet()) { + String stateName = entry.getKey(); + for (Map.Entry stateEntry : + entry.getValue().initialValues.entrySet()) { + stateManager.setStateForKey( + stateName, stateEntry.getKey(), stateEntry.getValue()); + } + } - Map argumentConverters = new HashMap<>(); - createConverters(arguments, argumentConverters); + // Extract table arguments for output type derivation + // SystemTypeInference needs table semantics for pass-through column deduplication + List tableArgs = ArgumentInfo.filterTableArguments(arguments); - // Derive output schema using SystemTypeInference (includes deduplication) + // Derive output schema using SystemTypeInference DataType derivedOutputType = deriveOutputTypeFromSystemInference( - function, dataTypeFactory, systemTypeInference, arguments); + function, dataTypeFactory, systemTypeInference, tableArgs); // Create output converter for PTF emissions DataStructureConverter harnessOutputConverter = @@ -643,23 +727,9 @@ public ProcessTableFunctionTestHarness build() throws Exception { functionContext, evalMethod, arguments, - extractScalarValues(arguments), argumentConverters, - harnessOutputConverter); - } - - /** Extracts scalar values from configs, creating a map keyed by argument name. */ - private Map extractScalarValues(List arguments) { - Map values = new HashMap<>(); - for (ArgumentInfo arg : arguments) { - if (arg.isScalar) { - ScalarArgumentConfiguration config = scalarArgs.get(arg.name); - if (config != null) { - values.put(arg.name, config.value); - } - } - } - return values; + harnessOutputConverter, + stateManager); } /** @@ -675,58 +745,60 @@ private DataStructureConverter createPTFOutputConverter( } /** - * Creates and initializes data structure converters for all table arguments. + * Creates and initializes converters for all table and state arguments. * - *

For Row types, both input and output converters are the same (between Row and - * RowData). - * - *

For structured types, input converter uses Row types (Row to RowData), and the output - * converter uses the structured type. + *

For table arguments with Row types, both converters are the same (between Row and + * RowData). For structured types, toNamedRow uses Row type (Row to RowData), and + * toEvalArgument uses the structured type. */ private void createConverters( - List arguments, Map argumentConverters) { - ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + List arguments, + Map argumentConverters, + Map stateConverters, + ClassLoader classLoader) + throws Exception { + + for (StateArgumentInfo stateArg : ArgumentInfo.filterStateArguments(arguments)) { + StateConverter converter = createStateConverter(stateArg.dataType, classLoader); + stateConverters.put(stateArg.name, converter); + } - for (ArgumentInfo arg : arguments) { - if (arg.isTableArgument) { - String converterKey = arg.name; - - LogicalType logicalType = arg.dataType.getLogicalType(); - boolean isStructuredType = - logicalType instanceof StructuredType - && ((StructuredType) logicalType) - .getImplementationClass() - .isPresent(); - - if (isStructuredType) { - StructuredType structuredType = (StructuredType) logicalType; - List rowFields = new ArrayList<>(); - for (StructuredType.StructuredAttribute attr : - structuredType.getAttributes()) { - rowFields.add(new RowType.RowField(attr.getName(), attr.getType())); - } - RowType rowType = new RowType(logicalType.isNullable(), rowFields); - DataType rowDataType = TypeConversions.fromLogicalToDataType(rowType); + for (TableArgumentInfo tableArg : ArgumentInfo.filterTableArguments(arguments)) { + String converterKey = tableArg.name; + + LogicalType logicalType = tableArg.dataType.getLogicalType(); + boolean isStructuredType = + logicalType instanceof StructuredType + && ((StructuredType) logicalType) + .getImplementationClass() + .isPresent(); + + if (isStructuredType) { + StructuredType structuredType = (StructuredType) logicalType; + List rowFields = new ArrayList<>(); + for (StructuredType.StructuredAttribute attr : structuredType.getAttributes()) { + rowFields.add(new RowType.RowField(attr.getName(), attr.getType())); + } + RowType rowType = new RowType(logicalType.isNullable(), rowFields); + DataType rowDataType = TypeConversions.fromLogicalToDataType(rowType); - DataStructureConverter inputConverter = - DataStructureConverters.getConverter(rowDataType); - inputConverter.open(classLoader); + DataStructureConverter toNamedRow = + DataStructureConverters.getConverter(rowDataType); + toNamedRow.open(classLoader); - DataStructureConverter outputConverter = - DataStructureConverters.getConverter(arg.dataType); - outputConverter.open(classLoader); + DataStructureConverter toEvalArgument = + DataStructureConverters.getConverter(tableArg.dataType); + toEvalArgument.open(classLoader); - argumentConverters.put( - converterKey, new ConverterPair(inputConverter, outputConverter)); - } else { - // For Row types, input and output converters are the same - DataStructureConverter converter = - DataStructureConverters.getConverter(arg.dataType); - converter.open(classLoader); + argumentConverters.put( + converterKey, new TableArgumentConverters(toNamedRow, toEvalArgument)); + } else { + DataStructureConverter converter = + DataStructureConverters.getConverter(tableArg.dataType); + converter.open(classLoader); - argumentConverters.put( - converterKey, new ConverterPair(converter, converter)); - } + argumentConverters.put( + converterKey, new TableArgumentConverters(converter, converter)); } } } @@ -757,8 +829,8 @@ private Method findEvalMethod() throws NoSuchMethodException { } /** - * Validates that the eval() method doesn't use unsupported features. Temporary, until state - * and context is supported. + * Validates that the eval() method doesn't use unsupported features. Temporary, until + * context is supported. */ private void validateEvalMethodSupported(Method evalMethod, List arguments) { Parameter[] parameters = evalMethod.getParameters(); @@ -774,34 +846,32 @@ private void validateEvalMethodSupported(Method evalMethod, List a + "Found Context parameter at position %d in eval() method. ", i)); } - - if (param.isAnnotationPresent(StateHint.class)) { - throw new IllegalStateException( - String.format( - "ProcessTableFunctionTestHarness does not yet support state parameters. " - + "Found @StateHint parameter at position %d in eval() method. ", - i)); - } } if (parameters.length != arguments.size()) { + long stateCount = ArgumentInfo.filterStateArguments(arguments).size(); + long nonStateCount = arguments.size() - stateCount; throw new IllegalStateException( String.format( - "Parameter count mismatch: eval() has %d parameters but only %d arguments were extracted. " - + "This may indicate missing @ArgumentHint annotations.", - parameters.length, arguments.size())); + "Parameter count mismatch: eval() has %d parameters but expected %d " + + "(%d state + %d table/scalar arguments). " + + "eval() signature: %s. " + + "This may indicate missing @ArgumentHint or @StateHint annotations.", + parameters.length, + arguments.size(), + stateCount, + nonStateCount, + evalMethod)); } - for (int i = 0; i < parameters.length; i++) { + for (int i = 0; i < arguments.size(); i++) { Parameter param = parameters[i]; Class paramType = param.getType(); ArgumentInfo arg = arguments.get(i); - if (arg.isScalar) { - ScalarArgumentConfiguration config = scalarArgs.get(arg.name); - if (config != null - && config.value != null - && !paramType.isAssignableFrom(config.value.getClass())) { + if (arg instanceof ScalarArgumentInfo) { + Object value = ((ScalarArgumentInfo) arg).value; + if (value != null && !paramType.isAssignableFrom(value.getClass())) { throw new IllegalStateException( String.format( "Type mismatch for scalar argument '%s' at position %d: " @@ -809,7 +879,7 @@ private void validateEvalMethodSupported(Method evalMethod, List a arg.name, i, paramType.getName(), - config.value.getClass().getName())); + value.getClass().getName())); } } } @@ -821,10 +891,13 @@ private void validateEvalMethodSupported(Method evalMethod, List a * matching data types. */ private void validatePartitionConsistency(List arguments) { - final List partitionedTables = new ArrayList<>(); + final List partitionedTables = new ArrayList<>(); for (ArgumentInfo arg : arguments) { - if (arg.isSetSemantic && arg.partitionColumnNames != null) { - partitionedTables.add(arg); + if (arg instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) arg; + if (tableArg.isSetSemantic && tableArg.partitionColumnNames != null) { + partitionedTables.add(tableArg); + } } } @@ -832,11 +905,11 @@ private void validatePartitionConsistency(List arguments) { return; } - final ArgumentInfo first = partitionedTables.get(0); + final TableArgumentInfo first = partitionedTables.get(0); final int expectedPartitionColumnCount = first.partitionColumnNames.length; for (int i = 1; i < partitionedTables.size(); i++) { - ArgumentInfo current = partitionedTables.get(i); + TableArgumentInfo current = partitionedTables.get(i); if (current.partitionColumnNames.length != expectedPartitionColumnCount) { throw new IllegalArgumentException( @@ -876,17 +949,74 @@ private void validatePartitionConsistency(List arguments) { } } - private DataType extractPartitionColumnType(ArgumentInfo arg, String columnName) { - if (!(arg.dataType instanceof FieldsDataType)) { + private void validateInitialStateKeys(List arguments) { + if (stateArgs.isEmpty()) { + return; + } + + // all partitioned tables share the same partition key shape, so any one is + // sufficient for validation. + Optional partitionedTable = + arguments.stream() + .filter(arg -> arg instanceof TableArgumentInfo) + .map(arg -> (TableArgumentInfo) arg) + .filter(t -> t.isSetSemantic && t.partitionColumnNames != null) + .findFirst(); + + if (partitionedTable.isEmpty()) { + return; + } + + TableArgumentInfo table = partitionedTable.get(); + int expectedArity = table.partitionColumnNames.length; + LogicalType[] expectedTypes = + Arrays.stream(table.partitionColumnNames) + .map(col -> extractPartitionColumnType(table, col).getLogicalType()) + .toArray(LogicalType[]::new); + + for (Map.Entry entry : stateArgs.entrySet()) { + for (Row key : entry.getValue().initialValues.keySet()) { + if (key.getArity() != expectedArity) { + throw new IllegalArgumentException( + String.format( + "Initial state key for state '%s' has arity %d, " + + "but partition key has arity %d.", + entry.getKey(), key.getArity(), expectedArity)); + } + + for (int i = 0; i < expectedArity; i++) { + Object value = key.getField(i); + Class expectedClass = expectedTypes[i].getDefaultConversion(); + if (value != null && !expectedClass.isInstance(value)) { + throw new IllegalArgumentException( + String.format( + "Initial state key for state '%s' has type %s " + + "at position %d, but partition column '%s' " + + "expects %s.", + entry.getKey(), + value.getClass().getSimpleName(), + i, + table.partitionColumnNames[i], + expectedClass.getSimpleName())); + } + } + } + } + } + + private DataType extractPartitionColumnType(TableArgumentInfo tableArg, String columnName) { + if (!(tableArg.dataType instanceof FieldsDataType)) { throw new IllegalStateException( String.format( "Cannot extract data type for partition column '%s' of argument '%s': " + "argument data type is not a FieldsDataType (actual: %s)", - columnName, arg.name, arg.dataType.getClass().getSimpleName())); + columnName, + tableArg.name, + tableArg.dataType.getClass().getSimpleName())); } - FieldsDataType fieldsDataType = (FieldsDataType) arg.dataType; - List fieldNames = getFieldNames(arg.dataType); + FieldsDataType fieldsDataType = (FieldsDataType) tableArg.dataType; + List fieldNames = getFieldNames(tableArg.dataType); List fieldDataTypes = fieldsDataType.getChildren(); int fieldIndex = fieldNames.indexOf(columnName); @@ -897,7 +1027,30 @@ private DataType extractPartitionColumnType(ArgumentInfo arg, String columnName) throw new IllegalStateException( String.format( "Partition column '%s' not found in argument '%s'", - columnName, arg.name)); + columnName, tableArg.name)); + } + + private TestHarnessStateManager.PartitionKeyInfo extractPartitionKeyInfo( + List arguments) { + Optional partitionedTable = + arguments.stream() + .filter(arg -> arg instanceof TableArgumentInfo) + .map(arg -> (TableArgumentInfo) arg) + .filter(t -> t.isSetSemantic && t.partitionColumnNames != null) + .findFirst(); + + if (partitionedTable.isEmpty()) { + return new TestHarnessStateManager.PartitionKeyInfo(0, null, null); + } + + TableArgumentInfo table = partitionedTable.get(); + String[] columnNames = table.partitionColumnNames; + LogicalType[] columnTypes = + Arrays.stream(columnNames) + .map(col -> extractPartitionColumnType(table, col).getLogicalType()) + .toArray(LogicalType[]::new); + return new TestHarnessStateManager.PartitionKeyInfo( + columnNames.length, columnNames, columnTypes); } // --------------------------------------------------------------------- @@ -911,7 +1064,8 @@ private DataType extractPartitionColumnType(ArgumentInfo arg, String columnName) * table argument rules, static argument trait validation, etc. */ private List extractAndValidateTypeInference( - ProcessTableFunction function, TypeInference systemTypeInference) { + ProcessTableFunction function, TypeInference systemTypeInference) + throws Exception { Optional> staticArgsOpt = systemTypeInference.getStaticArguments(); if (staticArgsOpt.isEmpty()) { @@ -928,7 +1082,7 @@ private List extractAndValidateTypeInference( } } - List arguments = new ArrayList<>(); + List tableAndScalarArguments = new ArrayList<>(); for (StaticArgument staticArg : userArgs) { boolean isScalar = staticArg.getTraits().contains(StaticArgumentTrait.SCALAR); @@ -940,7 +1094,7 @@ private List extractAndValidateTypeInference( if (isScalar || isTableArg) { ArgumentInfo argInfo = buildArgumentInfo(staticArg); - arguments.add(argInfo); + tableAndScalarArguments.add(argInfo); } else { throw new IllegalStateException( "Unknown argument type for StaticArgument. " @@ -948,9 +1102,98 @@ private List extractAndValidateTypeInference( } } - validateArgumentConfiguration(arguments); + validateArgumentConfiguration(tableAndScalarArguments); + + // Extract state arguments from TypeInference + List stateArguments = new ArrayList<>(); + + Map stateStrategies = + systemTypeInference.getStateTypeStrategies(); - return arguments; + DataTypeFactory dataTypeFactory = createDataTypeFactory(); + + List tableArgs = + ArgumentInfo.filterTableArguments(tableAndScalarArguments); + List argumentDataTypes = new ArrayList<>(); + for (TableArgumentInfo tArg : tableArgs) { + argumentDataTypes.add(tArg.dataType); + } + Map tableSemanticsMap = new HashMap<>(); + for (int i = 0; i < tableArgs.size(); i++) { + TableArgumentInfo tArg = tableArgs.get(i); + int[] partitionIndices = getPartitionColumnIndices(tArg); + tableSemanticsMap.put( + i, new TestHarnessTableSemantics(tArg.dataType, partitionIndices)); + } + + TestHarnessCallContext callContext = new TestHarnessCallContext(); + callContext.typeFactory = dataTypeFactory; + callContext.argumentDataTypes = argumentDataTypes; + callContext.tableSemantics = tableSemanticsMap; + callContext.functionDefinition = function; + callContext.name = function.getClass().getSimpleName(); + + for (Map.Entry entry : stateStrategies.entrySet()) { + String stateName = entry.getKey(); + StateTypeStrategy strategy = entry.getValue(); + + Optional dataTypeOpt = strategy.inferType(callContext); + if (dataTypeOpt.isEmpty()) { + throw new IllegalStateException( + String.format( + "Could not infer data type for state parameter '%s'", + stateName)); + } + DataType stateDataType = dataTypeOpt.get(); + + Optional ttlOpt = strategy.getTimeToLive(callContext); + stateArguments.add( + new StateArgumentInfo(stateName, stateDataType, ttlOpt.orElse(null))); + } + + List allArguments = new ArrayList<>(); + allArguments.addAll(stateArguments); + allArguments.addAll(tableAndScalarArguments); + + return allArguments; + } + + /** Creates appropriate StateConverter for the given state data type. */ + private StateConverter createStateConverter(DataType stateDataType, ClassLoader classLoader) + throws Exception { + LogicalType logicalType = stateDataType.getLogicalType(); + + if (logicalType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) logicalType; + DataType elementType = stateDataType.getChildren().get(0); + DataStructureConverter elementConverter = + DataStructureConverters.getConverter(elementType); + elementConverter.open(classLoader); + return new ListViewStateConverter(arrayType, elementConverter); + } else if (logicalType instanceof MapType) { + MapType mapType = (MapType) logicalType; + DataType keyType = stateDataType.getChildren().get(0); + DataType valueType = stateDataType.getChildren().get(1); + DataStructureConverter keyConverter = + DataStructureConverters.getConverter(keyType); + DataStructureConverter valueConverter = + DataStructureConverters.getConverter(valueType); + keyConverter.open(classLoader); + valueConverter.open(classLoader); + return new MapViewStateConverter(mapType, keyConverter, valueConverter); + } else if (logicalType instanceof RowType) { + RowType rowType = (RowType) logicalType; + DataStructureConverter converter = + DataStructureConverters.getConverter(stateDataType); + converter.open(classLoader); + return new RowStateConverter(converter, rowType); + } else { + DataStructureConverter converter = + DataStructureConverters.getConverter(stateDataType); + converter.open(classLoader); + Class stateClass = stateDataType.getConversionClass(); + return new StructuredTypeStateConverter(converter, stateClass); + } } /** Checks if an argument name is a system-reserved argument. */ @@ -1034,8 +1277,14 @@ private ArgumentInfo buildArgumentInfo(StaticArgument staticArg) { boolean hasPassColumnsThrough = staticArg.getTraits().contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH); - return new ArgumentInfo( - name, dataType, primaryTrait, partitionColumnNames, hasPassColumnsThrough); + if (primaryTrait == ArgumentTrait.SCALAR) { + ScalarArgumentConfiguration config = scalarArgs.get(name); + Object value = config != null ? config.value : null; + return new ScalarArgumentInfo(name, dataType, value); + } else { + return new TableArgumentInfo( + name, dataType, primaryTrait, partitionColumnNames, hasPassColumnsThrough); + } } private ArgumentTrait extractPrimaryTrait(EnumSet staticTraits) { @@ -1083,7 +1332,7 @@ private String[] extractAndValidatePartitionColumns( /** Validates scalar argument values are configured and no unknown arguments exist. */ private void validateArgumentConfiguration(List arguments) { for (ArgumentInfo arg : arguments) { - if (arg.isScalar && !scalarArgs.containsKey(arg.name)) { + if (arg instanceof ScalarArgumentInfo && !scalarArgs.containsKey(arg.name)) { throw new IllegalStateException( String.format( "Missing required scalar argument '%s'. " @@ -1137,22 +1386,20 @@ private DataType deriveOutputTypeFromSystemInference( ProcessTableFunction function, DataTypeFactory dataTypeFactory, TypeInference systemTypeInference, - List arguments) { + List arguments) { List argumentDataTypes = new ArrayList<>(); - for (ArgumentInfo arg : arguments) { + for (TableArgumentInfo arg : arguments) { argumentDataTypes.add(arg.dataType); } Map tableSemanticsMap = new HashMap<>(); for (int i = 0; i < arguments.size(); i++) { - ArgumentInfo arg = arguments.get(i); - if (arg.isTableArgument) { - int[] partitionIndices = getPartitionColumnIndices(arg); - TableSemantics semantics = - new TestHarnessTableSemantics(arg.dataType, partitionIndices); - tableSemanticsMap.put(i, semantics); - } + TableArgumentInfo arg = arguments.get(i); + int[] partitionIndices = getPartitionColumnIndices(arg); + TableSemantics semantics = + new TestHarnessTableSemantics(arg.dataType, partitionIndices); + tableSemanticsMap.put(i, semantics); } TestHarnessCallContext callContext = new TestHarnessCallContext(); @@ -1174,28 +1421,12 @@ private DataType deriveOutputTypeFromSystemInference( return outputTypeOpt.get(); } - private static List extractFieldNames(DataType dataType) { - LogicalType logicalType = dataType.getLogicalType(); - if (logicalType instanceof RowType) { - return ((RowType) logicalType).getFieldNames(); - } else if (logicalType instanceof StructuredType) { - return ((StructuredType) logicalType) - .getAttributes().stream() - .map(StructuredType.StructuredAttribute::getName) - .collect(java.util.stream.Collectors.toList()); - } else { - throw new IllegalStateException( - "Expected RowType or StructuredType, got: " - + logicalType.getClass().getSimpleName()); - } - } - - private int[] getPartitionColumnIndices(ArgumentInfo arg) { + private int[] getPartitionColumnIndices(TableArgumentInfo arg) { if (arg.partitionColumnNames == null || arg.partitionColumnNames.length == 0) { return new int[0]; } - List fieldNames = extractFieldNames(arg.dataType); + List fieldNames = getFieldNames(arg.dataType); int[] indices = new int[arg.partitionColumnNames.length]; for (int i = 0; i < arg.partitionColumnNames.length; i++) { @@ -1252,31 +1483,66 @@ private void handleEvalInvocationException( } /** - * Metadata for a single argument extracted from type inference. + * Base class for PTF eval() arguments. * *

Represents validated argument information combining PTF signature, type inference results, * and builder configuration. */ - private static class ArgumentInfo { + private abstract static class ArgumentInfo { final String name; final DataType dataType; + + ArgumentInfo(String name, DataType dataType) { + this.name = name; + this.dataType = dataType; + } + + static List filterStateArguments(List arguments) { + return arguments.stream() + .filter(arg -> arg instanceof StateArgumentInfo) + .map(arg -> (StateArgumentInfo) arg) + .collect(Collectors.toList()); + } + + static List filterTableArguments(List arguments) { + return arguments.stream() + .filter(arg -> arg instanceof TableArgumentInfo) + .map(arg -> (TableArgumentInfo) arg) + .collect(Collectors.toList()); + } + + static List filterScalarArguments(List arguments) { + return arguments.stream() + .filter(arg -> arg instanceof ScalarArgumentInfo) + .map(arg -> (ScalarArgumentInfo) arg) + .collect(Collectors.toList()); + } + } + + /** State parameter with TTL configuration. */ + static class StateArgumentInfo extends ArgumentInfo { + final Duration ttl; + + StateArgumentInfo(String name, DataType dataType, Duration ttl) { + super(name, dataType); + this.ttl = ttl; + } + } + + /** Table argument with partitioning and output prepending strategy. */ + private static class TableArgumentInfo extends ArgumentInfo { final String[] partitionColumnNames; - final boolean isScalar; - final boolean isTableArgument; final boolean isSetSemantic; final OutputPrependStrategy prependStrategy; - ArgumentInfo( + TableArgumentInfo( String name, DataType dataType, ArgumentTrait primaryTrait, String[] partitionColumnNames, boolean hasPassColumnsThrough) { - this.name = name; - this.dataType = dataType; + super(name, dataType); this.partitionColumnNames = partitionColumnNames; - this.isScalar = (primaryTrait == ArgumentTrait.SCALAR); - this.isTableArgument = (primaryTrait != ArgumentTrait.SCALAR); this.isSetSemantic = (primaryTrait == ArgumentTrait.SET_SEMANTIC_TABLE); this.prependStrategy = hasPassColumnsThrough @@ -1287,6 +1553,16 @@ private static class ArgumentInfo { } } + /** Scalar (constant) argument. */ + private static class ScalarArgumentInfo extends ArgumentInfo { + final Object value; + + ScalarArgumentInfo(String name, DataType dataType, Object value) { + super(name, dataType); + this.value = value; + } + } + private static class TableArgumentConfiguration { final AbstractDataType explicitType; @@ -1310,4 +1586,12 @@ private static class PartitionConfiguration { this.columnNames = columnNames; } } + + private static class StateArgumentConfiguration { + final Map initialValues; + + StateArgumentConfiguration() { + this.initialValues = new HashMap<>(); + } + } } diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java new file mode 100644 index 0000000000000..f293edb6aae91 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.data.conversion.DataStructureConverter; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.types.Row; + +/** Converter for {@link Row}-typed state. */ +@Internal +class RowStateConverter implements StateConverter { + + private final DataStructureConverter converter; + private final RowType rowType; + + RowStateConverter(DataStructureConverter converter, RowType rowType) { + this.converter = converter; + this.rowType = rowType; + } + + @Override + public Object toInternal(Object external) { + if (external == null) { + return null; + } + return converter.toInternal(external); + } + + @Override + public Object toExternal(Object internal) { + if (internal == null) { + return null; + } + return converter.toExternal(internal); + } + + @Override + public Object createNewInternalState() { + Row row = Row.withNames(); + rowType.getFieldNames().forEach(name -> row.setField(name, null)); + return converter.toInternal(row); + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java new file mode 100644 index 0000000000000..7b625396c9f9a --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; + +/** + * Converter between external state representations (ListView, MapView & value state) and internal + * storage formats (ArrayData, MapData, & RowData). + */ +@Internal +interface StateConverter { + + /** Converts an external state object to internal storage format. */ + Object toInternal(Object external) throws Exception; + + /** Converts an internal storage format to external state object. */ + Object toExternal(Object internal); + + /** Create new internal state instance. */ + Object createNewInternalState(); +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java new file mode 100644 index 0000000000000..5599aef14a7ac --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.data.conversion.DataStructureConverter; + +/** + * Converter for value state backed by structured types. + * + *

Converts between external value state objects and internal RowData representation. + */ +@Internal +class StructuredTypeStateConverter implements StateConverter { + + private final DataStructureConverter converter; + private final Class pojoClass; + + StructuredTypeStateConverter( + DataStructureConverter converter, Class pojoClass) { + this.converter = converter; + this.pojoClass = pojoClass; + } + + @Override + public Object toInternal(Object external) { + if (external == null) { + return null; + } + return converter.toInternal(external); + } + + @Override + public Object toExternal(Object internal) { + if (internal == null) { + return null; + } + return converter.toExternal(internal); + } + + @Override + public Object createNewInternalState() { + try { + Object newPojo = pojoClass.getDeclaredConstructor().newInstance(); + return converter.toInternal(newPojo); + } catch (Exception e) { + throw new RuntimeException( + "Failed to create new instance of POJO class: " + pojoClass.getName(), e); + } + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java index af748db7ad25c..ff01478fba109 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java @@ -28,6 +28,7 @@ import org.apache.flink.table.types.extraction.DataTypeExtractor; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.utils.LogicalTypeParser; +import org.apache.flink.table.types.utils.TypeConversions; import org.apache.flink.table.types.utils.TypeInfoDataTypeConverter; /** @@ -53,8 +54,7 @@ public DataType createDataType(AbstractDataType abstractDataType) { public DataType createDataType(String typeString) { LogicalType logicalType = LogicalTypeParser.parse(typeString, Thread.currentThread().getContextClassLoader()); - return org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType( - logicalType); + return TypeConversions.fromLogicalToDataType(logicalType); } @Override diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java new file mode 100644 index 0000000000000..5f051bbe45851 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * State manager for {@link ProcessTableFunctionTestHarness}. + * + *

Handles state storage, lifecycle, and conversion between external and internal storage + * formats. + */ +@Internal +class TestHarnessStateManager { + + private final Map> stateByKey = new HashMap<>(); + private final List stateArguments; + private final Map stateConverters; + private final PartitionKeyInfo partitionKeyInfo; + + TestHarnessStateManager( + List stateArguments, + Map stateConverters, + PartitionKeyInfo partitionKeyInfo) { + this.stateArguments = stateArguments; + this.stateConverters = stateConverters; + this.partitionKeyInfo = partitionKeyInfo; + } + + static class PartitionKeyInfo { + final int arity; + @Nullable final String[] columnNames; + @Nullable final Class[] columnTypes; + + PartitionKeyInfo( + int arity, + @Nullable String[] columnNames, + @Nullable LogicalType[] columnLogicalTypes) { + this.arity = arity; + this.columnNames = columnNames; + this.columnTypes = + columnLogicalTypes != null + ? Arrays.stream(columnLogicalTypes) + .map(LogicalType::getDefaultConversion) + .toArray(Class[]::new) + : null; + } + + void validate(Row key) { + if (key.getArity() != arity) { + throw new IllegalArgumentException( + String.format( + "Partition key has arity %d, but expected arity %d.", + key.getArity(), arity)); + } + if (columnTypes == null) { + return; + } + for (int i = 0; i < arity; i++) { + Object value = key.getField(i); + if (value != null && !columnTypes[i].isInstance(value)) { + String columnName = columnNames != null ? columnNames[i] : "position " + i; + throw new IllegalArgumentException( + String.format( + "Partition key has type %s at position %d, " + + "but partition column '%s' expects %s.", + value.getClass().getSimpleName(), + i, + columnName, + columnTypes[i].getSimpleName())); + } + } + } + } + + /** + * Load state for a partition key. Creates new state instances if none exist. Converts internal + * storage to external objects (value state, ListView, MapView). + */ + Map loadStateForKey(Row key) { + Map internalState = + stateByKey.computeIfAbsent(key, k -> createEmptyKeyState()); + + Map externalState = new HashMap<>(); + for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) { + Object internalData = internalState.get(stateArg.name); + Object external = convertToExternal(internalData, stateArg); + externalState.put(stateArg.name, external); + } + return externalState; + } + + /** + * Update mutated state after eval() invocation. Converts external objects to internal format. + */ + void updateStateForKey(Row key, Map externalState) throws Exception { + Map internalState = new HashMap<>(); + for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) { + Object external = externalState.get(stateArg.name); + Object internalData = convertToInternal(external, stateArg); + internalState.put(stateArg.name, internalData); + } + stateByKey.put(key, internalState); + } + + /** Clear all state for a partition key. */ + void clearStateForKey(Row key) { + partitionKeyInfo.validate(key); + stateByKey.remove(key); + } + + /** Clear specific state entry for a given partition key, resetting it to its default value. */ + void clearStateEntryForKey(String stateName, Row key) { + partitionKeyInfo.validate(key); + Map internalState = stateByKey.get(key); + if (internalState != null) { + ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = + findStateArgument(stateName); + internalState.put(stateName, createNewStateInternalData(stateArg)); + } + } + + /** Sets the state for a given partition key. */ + void setStateForKey(String stateName, Row key, Object externalState) throws Exception { + partitionKeyInfo.validate(key); + ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName); + Object internalData = convertToInternal(externalState, stateArg); + + Map internalState = + stateByKey.computeIfAbsent(key, k -> createEmptyKeyState()); + internalState.put(stateName, internalData); + } + + /** Get the state for a given partition key. */ + @SuppressWarnings("unchecked") + T getStateForKey(String stateName, Row key) { + partitionKeyInfo.validate(key); + Map internalState = stateByKey.get(key); + if (internalState == null) { + return null; + } + Object internalData = internalState.get(stateName); + if (internalData == null) { + return null; + } + return (T) convertToExternal(internalData, findStateArgument(stateName)); + } + + /** Get all partition keys that have a specific state entry. */ + Set getKeysForState(String stateName) { + return stateByKey.entrySet().stream() + .filter(entry -> entry.getValue().containsKey(stateName)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + /** Get all state values for a state name across all partition keys. */ + @SuppressWarnings("unchecked") + Map getStateForAllKeys(String stateName) { + ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName); + Map result = new HashMap<>(); + for (Map.Entry> entry : stateByKey.entrySet()) { + Object internalData = entry.getValue().get(stateName); + if (internalData != null) { + result.put(entry.getKey(), (T) convertToExternal(internalData, stateArg)); + } + } + return result; + } + + private Map createEmptyKeyState() { + Map newState = new HashMap<>(); + for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) { + newState.put(stateArg.name, createNewStateInternalData(stateArg)); + } + return newState; + } + + private Object createNewStateInternalData( + ProcessTableFunctionTestHarness.StateArgumentInfo stateArg) { + return stateConverters.get(stateArg.name).createNewInternalState(); + } + + private Object convertToExternal( + Object internalData, ProcessTableFunctionTestHarness.StateArgumentInfo stateArg) { + return stateConverters.get(stateArg.name).toExternal(internalData); + } + + private Object convertToInternal( + Object external, ProcessTableFunctionTestHarness.StateArgumentInfo stateArg) + throws Exception { + return stateConverters.get(stateArg.name).toInternal(external); + } + + private ProcessTableFunctionTestHarness.StateArgumentInfo findStateArgument(String stateName) { + for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) { + if (stateArg.name.equals(stateName)) { + return stateArg; + } + } + String available = + stateArguments.stream().map(arg -> arg.name).collect(Collectors.joining(", ")); + throw new IllegalArgumentException( + "Unknown state: '" + stateName + "'. Available states: [" + available + "]"); + } +} diff --git a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java index a02e9a8eefd2f..206f48e2ce153 100644 --- a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java +++ b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java @@ -24,13 +24,19 @@ import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.api.dataview.ListView; +import org.apache.flink.table.api.dataview.MapView; import org.apache.flink.table.functions.ProcessTableFunction; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -136,6 +142,25 @@ public void eval( } } + /** Stateful PTF with OPTIONAL_PARTITION_BY. */ + @DataTypeHint("ROW") + public static class StatefulOptionalPartitionPTF extends ProcessTableFunction { + public static class CounterState { + public long counter = 0L; + } + + public void eval( + @StateHint CounterState state, + @ArgumentHint({ + ArgumentTrait.SET_SEMANTIC_TABLE, + ArgumentTrait.OPTIONAL_PARTITION_BY + }) + Row input) { + state.counter++; + collect(Row.of(state.counter)); + } + } + /** Simple POJO for testing structured type input/output. */ public static class User { public String name; @@ -162,12 +187,12 @@ public boolean equals(Object o) { return false; } User user = (User) o; - return age == user.age && java.util.Objects.equals(name, user.name); + return age == user.age && Objects.equals(name, user.name); } @Override public int hashCode() { - return java.util.Objects.hash(name, age); + return Objects.hash(name, age); } } @@ -282,17 +307,94 @@ public void eval(Context ctx, @ArgumentHint(ArgumentTrait.ROW_SEMANTIC_TABLE) Ro } } - /** PTF with State parameter - should be rejected by test harness. */ - @DataTypeHint("ROW") - public static class PTFWithState extends ProcessTableFunction { - public static class CountState { + /** PTF with simple value state - counts rows per partition. */ + @DataTypeHint("ROW") + public static class PTFWithValueState extends ProcessTableFunction { + public static class CounterState { public long counter = 0L; } public void eval( - @StateHint CountState state, - @ArgumentHint(ArgumentTrait.ROW_SEMANTIC_TABLE) Row input) { - collect(input); + @StateHint CounterState state, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) { + state.counter++; + collect(Row.of(state.counter)); + } + } + + /** PTF with ListView state - accumulates values in a list. */ + @DataTypeHint("ROW>") + public static class PTFWithListViewState extends ProcessTableFunction { + public void eval( + @StateHint(type = @DataTypeHint("ARRAY")) ListView listState, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) + throws Exception { + Integer value = input.getFieldAs("value"); + listState.add(value); + + // Collect all values as an array + List values = new ArrayList<>(); + for (Integer v : listState.get()) { + values.add(v); + } + collect(Row.of((Object) values.toArray(new Integer[0]))); + } + } + + /** PTF with MapView state - counts occurrences of each key. */ + @DataTypeHint("ROW") + public static class PTFWithMapViewState extends ProcessTableFunction { + public void eval( + @StateHint MapView mapState, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) + throws Exception { + String key = input.getFieldAs("key"); + Integer count = mapState.get(key); + if (count == null) { + mapState.put(key, 1); + } else { + mapState.put(key, count + 1); + } + collect(Row.of(key, mapState.get(key))); + } + } + + /** PTF with Row state - mirrors the doc example using Row as state type. */ + @DataTypeHint("ROW") + public static class PTFWithRowState extends ProcessTableFunction { + public void eval( + @StateHint(type = @DataTypeHint("ROW")) Row memory, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) { + Long newCount = 1L; + if (memory.getField("count") != null) { + newCount += memory.getFieldAs("count"); + } + memory.setField("count", newCount); + collect(Row.of(newCount)); + } + } + + /** PTF with both value state and ListView state. */ + @DataTypeHint("ROW") + public static class PTFWithMultipleStates extends ProcessTableFunction { + public static class CounterState { + public long count = 0L; + } + + public void eval( + @StateHint CounterState counter, + @StateHint(type = @DataTypeHint("ARRAY")) ListView history, + @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) + throws Exception { + Integer value = input.getFieldAs("value"); + counter.count++; + history.add(value); + + int sum = 0; + for (Integer v : history.get()) { + sum += v; + } + collect(Row.of(counter.count, sum)); } } @@ -597,6 +699,56 @@ void testOptionalPartitionByWithPartition() throws Exception { } } + @Test + void testOptionalPartitionByWithStateNoPartition() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .build()) { + + harness.processElement(Row.of("A", 10)); + harness.processElement(Row.of("B", 20)); + harness.processElement(Row.of("A", 30)); + + List output = harness.getOutput(); + assertThat(output).hasSize(3); + assertThat(output.get(0)).isEqualTo(Row.of(1L)); + assertThat(output.get(1)).isEqualTo(Row.of(2L)); + assertThat(output.get(2)).isEqualTo(Row.of(3L)); + + StatefulOptionalPartitionPTF.CounterState state = + harness.getStateForKey("state", Row.of()); + assertThat(state.counter).isEqualTo(3L); + } + } + + @Test + void testOptionalPartitionByWithStateAndPartition() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "key") + .build()) { + + harness.processElement(Row.of("A", 10)); + harness.processElement(Row.of("B", 20)); + harness.processElement(Row.of("A", 30)); + + List output = harness.getOutput(); + assertThat(output).hasSize(3); + assertThat(output.get(0)).isEqualTo(Row.of("A", 1L)); + assertThat(output.get(1)).isEqualTo(Row.of("B", 1L)); + assertThat(output.get(2)).isEqualTo(Row.of("A", 2L)); + + StatefulOptionalPartitionPTF.CounterState stateA = + harness.getStateForKey("state", Row.of("A")); + StatefulOptionalPartitionPTF.CounterState stateB = + harness.getStateForKey("state", Row.of("B")); + assertThat(stateA.counter).isEqualTo(2L); + assertThat(stateB.counter).isEqualTo(1L); + } + } + // ------------------------------------------------------------------------- // Data Type Conversion Tests // ------------------------------------------------------------------------- @@ -902,7 +1054,7 @@ void testPassColumnsThroughWithMultipleTablesRejected() { // Verify that PASS_COLUMNS_THROUGH is rejected when used with multiple table arguments Exception exception = assertThrows( - org.apache.flink.table.api.ValidationException.class, + ValidationException.class, () -> { ProcessTableFunctionTestHarness.ofClass( InvalidPassColumnsThroughMultiTablePTF.class) @@ -998,22 +1150,6 @@ void testContextParameterRejected() { .contains("position 0"); } - @Test - void testStateParameterRejected() { - Exception exception = - assertThrows( - IllegalStateException.class, - () -> - ProcessTableFunctionTestHarness.ofClass(PTFWithState.class) - .withTableArgument("input", DataTypes.of("ROW")) - .build()); - - assertThat(exception.getMessage()) - .contains("does not yet support state parameters") - .contains("@StateHint parameter") - .contains("position 0"); - } - @Test void testSetSemanticMissingPartitionConfigThrows() { Exception exception = @@ -1062,4 +1198,497 @@ void testPartitionByDuplicateConfigThrows() { assertThat(exception.getMessage()).contains("Partition config already exists"); } + + // ------------------------------------------------------------------------- + // State Tests + // ------------------------------------------------------------------------- + + @Test + void testValueState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + assertThat(harness.getOutput()).containsExactly(Row.of("Alice", 1L)); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(1L); + + harness.processElementForTable("input", Row.of("Alice", 15)); + assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("Alice", 2L)); + + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(2L); + + harness.close(); + } + + @Test + void testValueStatePartitionIsolation() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Bob", 20)); + harness.processElementForTable("input", Row.of("Alice", 15)); + + PTFWithValueState.CounterState aliceState = + harness.getStateForKey("state", Row.of("Alice")); + PTFWithValueState.CounterState bobState = harness.getStateForKey("state", Row.of("Bob")); + + assertThat(aliceState.counter).isEqualTo(2L); + assertThat(bobState.counter).isEqualTo(1L); + + harness.close(); + } + + @Test + void testValueStateWithInitialState() throws Exception { + PTFWithValueState.CounterState initialState = new PTFWithValueState.CounterState(); + initialState.counter = 100L; + + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "id") + .withInitialStateForKey("state", Row.of(1), initialState) + .build(); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of(1)); + assertThat(state.counter).isEqualTo(100L); + + harness.processElement(Row.of(1)); + assertThat(harness.getOutput()).containsExactly(Row.of(1, 101L)); + + harness.processElement(Row.of(2)); + assertThat(harness.getOutput().get(1)).isEqualTo(Row.of(2, 1L)); + + harness.close(); + } + + @Test + void testGetStateKeys() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Bob", 20)); + harness.processElementForTable("input", Row.of("Charlie", 30)); + + Set keys = harness.getKeysForState("state"); + assertThat(keys) + .containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"), Row.of("Charlie")); + + harness.close(); + } + + @Test + void testGetAllState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Alice", 15)); + harness.processElementForTable("input", Row.of("Bob", 20)); + + Map allState = harness.getStateForAllKeys("state"); + + assertThat(allState).hasSize(2); + assertThat(allState.get(Row.of("Alice")).counter).isEqualTo(2L); + assertThat(allState.get(Row.of("Bob")).counter).isEqualTo(1L); + + harness.close(); + } + + @Test + void testListViewState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "key") + .build(); + + harness.processElementForTable("input", Row.of("A", 1)); + assertThat(harness.getOutput()).containsExactly(Row.of("A", new Integer[] {1})); + + harness.processElementForTable("input", Row.of("A", 2)); + assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("A", new Integer[] {1, 2})); + + ListView listState = harness.getStateForKey("listState", Row.of("A")); + assertThat(listState.get()).containsExactly(1, 2); + + harness.close(); + } + + @Test + void testMapViewState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .build(); + + harness.processElementForTable("input", Row.of("P1", "foo")); + assertThat(harness.getOutput()).containsExactly(Row.of("P1", "foo", 1)); + + harness.processElementForTable("input", Row.of("P1", "foo")); + assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("P1", "foo", 2)); + + harness.processElementForTable("input", Row.of("P1", "bar")); + assertThat(harness.getOutput().get(2)).isEqualTo(Row.of("P1", "bar", 1)); + + MapView mapState = harness.getStateForKey("mapState", Row.of("P1")); + assertThat(mapState.get("foo")).isEqualTo(2); + assertThat(mapState.get("bar")).isEqualTo(1); + + harness.close(); + } + + @Test + void testRowState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithRowState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + assertThat(harness.getOutput()).containsExactly(Row.of("Alice", 1L)); + + harness.processElementForTable("input", Row.of("Alice", 20)); + assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("Alice", 2L)); + + Row state = harness.getStateForKey("memory", Row.of("Alice")); + assertThat((Long) state.getFieldAs("count")).isEqualTo(2L); + + harness.close(); + } + + @Test + void testEmptyState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice")); + + assertThat(state).isNull(); + + harness.close(); + } + + @Test + void testClearStateForKey() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Alice", 15)); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(2L); + + harness.clearStateForKey(Row.of("Alice")); + + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state).isNull(); + + harness.processElementForTable("input", Row.of("Alice", 30)); + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(1L); + + harness.close(); + } + + @Test + void testClearStateEntry() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Alice", 15)); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(2L); + + harness.clearStateEntryForKey("state", Row.of("Alice")); + + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(0L); + + harness.processElementForTable("input", Row.of("Alice", 30)); + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(1L); + + harness.close(); + } + + @Test + void testMultipleStateParameters() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithMultipleStates.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "key") + .build(); + + harness.processElementForTable("input", Row.of("A", 10)); + harness.processElementForTable("input", Row.of("A", 20)); + harness.processElementForTable("input", Row.of("B", 5)); + + assertThat(harness.getOutput()) + .containsExactly(Row.of("A", 1L, 10), Row.of("A", 2L, 30), Row.of("B", 1L, 5)); + + PTFWithMultipleStates.CounterState counterA = + harness.getStateForKey("counter", Row.of("A")); + assertThat(counterA.count).isEqualTo(2L); + + ListView historyA = harness.getStateForKey("history", Row.of("A")); + assertThat(historyA.get()).containsExactly(10, 20); + + harness.close(); + } + + @Test + void testInitialStateWithListView() throws Exception { + ListView initialList = new ListView<>(); + initialList.add(100); + initialList.add(200); + + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "key") + .withInitialStateForKey("listState", Row.of("A"), initialList) + .build(); + + ListView listState = harness.getStateForKey("listState", Row.of("A")); + assertThat(listState.get()).containsExactly(100, 200); + + harness.processElementForTable("input", Row.of("A", 3)); + assertThat(harness.getOutput()).containsExactly(Row.of("A", new Integer[] {100, 200, 3})); + + harness.close(); + } + + @Test + void testInitialStateWithMapView() throws Exception { + MapView initialMap = new MapView<>(); + initialMap.put("existing", 42); + + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withInitialStateForKey("mapState", Row.of("P1"), initialMap) + .build(); + + MapView mapState = harness.getStateForKey("mapState", Row.of("P1")); + assertThat(mapState.get("existing")).isEqualTo(42); + + harness.processElementForTable("input", Row.of("P1", "existing")); + assertThat(harness.getOutput()).containsExactly(Row.of("P1", "existing", 43)); + + harness.close(); + } + + @Test + void testInitialStateKeyArityMismatch() { + Exception exception = + assertThrows( + IllegalArgumentException.class, + () -> + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .withInitialStateForKey( + "state", + Row.of("Alice", 42), + new PTFWithValueState.CounterState()) + .build()); + + assertThat(exception.getMessage()).contains("state"); + assertThat(exception.getMessage()).contains("arity 2"); + assertThat(exception.getMessage()).contains("arity 1"); + } + + @Test + void testInitialStateKeyTypeMismatch() { + Exception exception = + assertThrows( + IllegalArgumentException.class, + () -> + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .withInitialStateForKey( + "state", + Row.of(42), + new PTFWithValueState.CounterState()) + .build()); + + assertThat(exception.getMessage()).contains("state"); + assertThat(exception.getMessage()).contains("Integer"); + assertThat(exception.getMessage()).contains("name"); + assertThat(exception.getMessage()).contains("String"); + } + + @Test + void testSetStateForKey() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + harness.processElementForTable("input", Row.of("Alice", 10)); + harness.processElementForTable("input", Row.of("Alice", 20)); + + PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(2L); + + PTFWithValueState.CounterState newState = new PTFWithValueState.CounterState(); + newState.counter = 50L; + harness.setStateForKey("state", Row.of("Alice"), newState); + + state = harness.getStateForKey("state", Row.of("Alice")); + assertThat(state.counter).isEqualTo(50L); + + harness.processElementForTable("input", Row.of("Alice", 30)); + assertThat(harness.getOutput().get(2)).isEqualTo(Row.of("Alice", 51L)); + + harness.close(); + } + + @Test + void testInvalidStateNameInWithInitialState() { + Exception exception = + assertThrows( + IllegalArgumentException.class, + () -> + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "id") + .withInitialStateForKey( + "nonExistentState", Row.of(1), "value") + .build()); + + assertThat(exception.getMessage()).contains("Unknown state"); + assertThat(exception.getMessage()).contains("nonExistentState"); + assertThat(exception.getMessage()).contains("Available states"); + assertThat(exception.getMessage()).contains("state"); + } + + // ------------------------------------------------------------------------- + // Partition Key Validation Tests + // ------------------------------------------------------------------------- + + @Test + void testPartitionKeyValidationWrongArity() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + Exception exception = + assertThrows( + IllegalArgumentException.class, + () -> harness.getStateForKey("state", Row.of("Alice", "extra"))); + assertThat(exception.getMessage()).contains("arity 2"); + assertThat(exception.getMessage()).contains("expected arity 1"); + + harness.close(); + } + + @Test + void testPartitionKeyValidationWrongType() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + Exception exception = + assertThrows( + IllegalArgumentException.class, + () -> harness.getStateForKey("state", Row.of(123))); + assertThat(exception.getMessage()).contains("Integer"); + assertThat(exception.getMessage()).contains("name"); + assertThat(exception.getMessage()).contains("String"); + + harness.close(); + } + + @Test + void testPartitionKeyValidationOnSetState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + PTFWithValueState.CounterState state = new PTFWithValueState.CounterState(); + state.counter = 1L; + + assertThrows( + IllegalArgumentException.class, + () -> harness.setStateForKey("state", Row.of(1, 2), state)); + + harness.close(); + } + + @Test + void testPartitionKeyValidationOnClearState() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + assertThrows( + IllegalArgumentException.class, () -> harness.clearStateForKey(Row.of("a", "b"))); + + harness.close(); + } + + @Test + void testPartitionKeyValidationOnClearStateEntry() throws Exception { + ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class) + .withTableArgument("input", DataTypes.of("ROW")) + .withPartitionBy("input", "name") + .build(); + + assertThrows( + IllegalArgumentException.class, + () -> harness.clearStateEntryForKey("state", Row.of(42))); + + harness.close(); + } }