diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java new file mode 100644 index 000000000000..e499cbdf2c1e --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + *

For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ +public class AsyncDoFn extends DoFn, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn syncFn; + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. + private static final ConcurrentHashMap pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap itemsInBuffer = + new ConcurrentHashMap<>(); + + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class InFlightElement { + final CompletableFuture> future; + + InFlightElement(CompletableFuture> future) { + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires. + private static class AccumulatingOutputReceiver implements OutputReceiver { + private final List outputs = Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver(windowedValue -> outputs.add(windowedValue.getValue())); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. + @Override + public void output(T output) { + outputs.add(output); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(output); + } + + public List getOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool, + @Nullable Coder> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap> getProcessingElements() { + ConcurrentHashMap> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap>) (ConcurrentHashMap) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + } + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + lock.unlock(); + } + + if (threadPool != null) { + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) { + threadPool.shutdownNow(); + } + } catch (InterruptedException e) { + threadPool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + // Asynchronous Scheduling & Deduplication + // Submits tasks to the background thread pool. If an element with the same ID is already + // in-flight, + // the submission is silently ignored to enforce exactly-once semantics. + private boolean scheduleIfRoom( + KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { + lock.lock(); + try { + ConcurrentHashMap> activeElements = getProcessingElements(); + Object elementId = idFn.apply(element.getValue()); + + if (activeElements.containsKey(elementId)) { + LOG.info("Item {} already in processing elements", element); + return true; + } + + int currentBuffer = getItemsInBuffer().get(); + if (currentBuffer < maxItemsToBuffer || ignoreBuffer) { + java.util.concurrent.Executor executor = + useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); + + // Pending asynchronous task that will produce a list of outputs + CompletableFuture> future = + CompletableFuture.supplyAsync( + () -> { + try { + AccumulatingOutputReceiver receiver = + new AccumulatingOutputReceiver<>(); + DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn); + + DoFnInvoker.ArgumentProvider bundleArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return doFn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions(); + } + + @Override + public void output( + OutputT output, Instant timestamp, BoundedWindow window) { + receiver.output(output); + } + + @Override + public void output( + TupleTag tag, + T output, + Instant timestamp, + BoundedWindow window) { + throw new UnsupportedOperationException( + "Tagged output not supported in FinishBundleContext for AsyncDoFn"); + } + }; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Bundle"; + } + }; + + invoker.invokeStartBundle(bundleArgProvider); + + DoFnInvoker.ArgumentProvider processArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public InputT element(DoFn doFn) { + return element.getValue(); + } + + @Override + public OutputReceiver outputReceiver( + DoFn doFn) { + return receiver; + } + + @Override + public BoundedWindow window() { + return window; + } + + @Override + public Instant timestamp(DoFn doFn) { + return timestamp; + } + + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Process"; + } + }; + + invoker.invokeProcessElement(processArgProvider); + invoker.invokeFinishBundle(bundleArgProvider); + + return receiver.getOutputs(); + } catch (Exception e) { + throw new CompletionException(e); + } + }, + executor); + + // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for + // cancellation + CompletableFuture> unused = + future.whenComplete( + (res, ex) -> { + lock.lock(); + try { + getItemsInBuffer().decrementAndGet(); + } finally { + lock.unlock(); + } + }); + + activeElements.put(elementId, new InFlightElement<>(future)); + getItemsInBuffer().incrementAndGet(); + return true; + } + + return false; + } finally { + lock.unlock(); + } + } + + private void scheduleItem(KV element, BoundedWindow window, Instant timestamp) { + boolean done = false; + long sleepTime = INITIAL_BACKOFF_SLEEP_MS; + long totalSleep = 0; + long timeoutMs = timeout.getMillis(); + + while (!done && totalSleep < timeoutMs) { + done = scheduleIfRoom(element, window, timestamp, false); + if (!done) { + long sleep = Math.min(maxWaitTime.getMillis(), sleepTime); + if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) { + LOG.info( + "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.", + element, + getItemsInBuffer().get(), + sleep, + totalSleep); + } + try { + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for space in buffer", e); + } + sleepTime *= 2; + totalSleep += sleep; + } + } + // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. + } + + private Instant nextTimeToFire(@Nullable K key) { + long seed = (key == null) ? 0 : key.hashCode(); + Random random = new Random(seed); + double timerFrequencySec = timerFrequency.getMillis() / 1000.0; + double nowSec = System.currentTimeMillis() / 1000.0; + + double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; + double offset = random.nextDouble() * timerFrequencySec; + + return Instant.ofEpochMilli((long) ((base + offset) * 1000)); + } + + @ProcessElement + public void processElement( + ProcessContext c, + BoundedWindow window, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer) { + + KV element = c.element(); + scheduleItem(element, window, c.timestamp()); + toProcessState.add(element); + + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext c, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer, + OutputReceiver receiver) { + + commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver); + } + + // Synchronizes local task results with the runner's persistent state container. + // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. + private void commitFinishedItems( + Instant fireTimestamp, + BagState> toProcessState, + Timer timer, + OutputReceiver receiver) { + + Iterable> toProcessLocal = toProcessState.read(); + if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) { + // Early Exit: if BagState is empty, we skip checking activeElements for this key. + return; + } + + // Since fireTimestamp is key-scoped, we determine the current key from the first element in + // state + List> stateList = new ArrayList<>(); + K key = null; + for (KV element : toProcessLocal) { + stateList.add(element); + if (key == null) { + key = element.getKey(); + } + } + + if (verboseLogging) { + LOG.info("processing timer for key: {}", key); + } + + ConcurrentHashMap> activeElements = getProcessingElements(); + + List> toReturn = new ArrayList<>(); + Set> finishedItems = new HashSet<>(); + List> toReschedule = new ArrayList<>(); + + int itemsFinished = 0; + int itemsNotYetFinished = 0; + int itemsRescheduled = 0; + int itemsCancelled = 0; + + Set finishedElementIds = new HashSet<>(); + Set inFlightElementIds = new HashSet<>(); + Set rescheduledElementIds = new HashSet<>(); + + lock.lock(); + try { + for (KV element : stateList) { + Object elementId = idFn.apply(element.getValue()); + + // Skip processing if we already completed, rescheduled, or found this elementId active in + // this cycle + if (finishedElementIds.contains(elementId) + || rescheduledElementIds.contains(elementId) + || inFlightElementIds.contains(elementId)) { + continue; + } + + if (activeElements.containsKey(elementId)) { + InFlightElement inFlight = activeElements.get(elementId); + if (inFlight.future.isDone()) { + try { + if (!inFlight.future.isCancelled()) { + toReturn.add(inFlight.future.get()); + } + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + itemsFinished++; + } catch (Exception e) { + LOG.error("Error executing async task for element {}", element, e); + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + } + } else { + inFlightElementIds.add(elementId); + itemsNotYetFinished++; + } + } else { + LOG.info( + "Item {} found in state but not in local active elements, scheduling now", element); + toReschedule.add(element); + rescheduledElementIds.add(elementId); + itemsRescheduled++; + } + } + } finally { + lock.unlock(); + } + + // Reschedule missing elements + for (KV element : toReschedule) { + scheduleItem(element, GlobalWindow.INSTANCE, fireTimestamp); + } + + // Update State: keep only unfinished items + toProcessState.clear(); + int itemsInProcessingState = 0; + for (KV element : stateList) { + if (!finishedItems.contains(element)) { + toProcessState.add(element); + itemsInProcessingState++; + } + } + + // Emit completed outputs (Emit completed tasks immediately; do not wait for all active tasks to + // finish). + for (List outputs : toReturn) { + for (OutputT out : outputs) { + receiver.output(out); + } + } + + LOG.info( + "Items finished: {}, not yet finished: {}, rescheduled: {}, cancelled: {}, in processing state: {}", + itemsFinished, + itemsNotYetFinished, + itemsRescheduled, + itemsCancelled, + itemsInProcessingState); + + if (itemsInProcessingState > 0) { + Instant timeToFire = nextTimeToFire(key); + timer.set(timeToFire); + } + } + + // Package-private helper methods for testing direct execution without Pipeline / ProcessContext + // boilerplate + void processDirect( + KV element, + BoundedWindow window, + Instant timestamp, + BagState> toProcessState, + Timer timer) { + scheduleItem(element, window, timestamp); + toProcessState.add(element); + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + List commitFinishedItemsDirect( + Instant fireTimestamp, BagState> toProcessState, Timer timer) { + AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(); + commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); + return receiver.getOutputs(); + } + + boolean isEmpty() { + return getItemsInBuffer().get() == 0; + } + + int getItemsInBufferCount() { + return getItemsInBuffer().get(); + } + + static void resetState() { + lock.lock(); + try { + for (Map.Entry entry : pool.entrySet()) { + entry.getValue().shutdownNow(); + } + pool.clear(); + processingElements.clear(); + itemsInBuffer.clear(); + } finally { + lock.unlock(); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java new file mode 100644 index 000000000000..912aca3f309c --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -0,0 +1,733 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for verifying async processing structures and logic. */ +@RunWith(JUnit4.class) +public class AsyncDoFnTest implements Serializable { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + private final boolean useThreadPool = true; + + // Used for testing basic DoFn processing logic with optional latency. + private static class BasicDofn extends DoFn { + private final long sleepTimeMs; + private int processed = 0; + private final ReentrantLock lock = new ReentrantLock(); + + BasicDofn(long sleepTimeMs) { + this.sleepTimeMs = sleepTimeMs; + } + + BasicDofn() { + this(0); + } + + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + if (sleepTimeMs > 0) { + try { + Thread.sleep(sleepTimeMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + lock.lock(); + try { + processed += 1; + } finally { + lock.unlock(); + } + receiver.output(element); + } + + int getProcessed() { + lock.lock(); + try { + return processed; + } finally { + lock.unlock(); + } + } + } + + // Used for testing multi element processing with optional finish bundle call. + private static class MultiElementDoFn extends DoFn { + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + receiver.output(element); + receiver.output(element); + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + c.output("bundle end", Instant.now(), GlobalWindow.INSTANCE); + } + } + + // Used for testing BagState thread safety. + private static class FakeBagState implements BagState { + private final List items; + private final ReentrantLock lock = new ReentrantLock(); + + FakeBagState(List initialItems) { + this.items = new ArrayList<>(initialItems); + } + + FakeBagState(T initialItem) { + this(new ArrayList<>(List.of(initialItem))); + } + + FakeBagState() { + this(new ArrayList<>()); + } + + @Override + public void add(T item) { + lock.lock(); + try { + items.add(item); + } finally { + lock.unlock(); + } + } + + @Override + public void clear() { + lock.lock(); + try { + items.clear(); + } finally { + lock.unlock(); + } + } + + @Override + public Iterable read() { + lock.lock(); + try { + return new ArrayList<>(items); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + lock.lock(); + try { + return items.isEmpty(); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public BagState readLater() { + return this; + } + } + + // 4. Used for testing Timer mock implementations. + private static class FakeTimer implements Timer { + private Instant time = Instant.EPOCH; + + @Override + public void set(Instant absoluteTime) { + this.time = absoluteTime; + } + + @Override + public void setRelative() {} + + @Override + public void clear() { + this.time = Instant.EPOCH; + } + + @Override + public Timer offset(Duration offset) { + return this; + } + + @Override + public Timer align(Duration period) { + return this; + } + + @Override + public Timer withOutputTimestamp(Instant outputTime) { + return this; + } + + @Override + public Timer withNoOutputTimestamp() { + return this; + } + + @Override + public Instant getCurrentRelativeTime() { + return time; + } + } + + @Before + public void setUp() { + AsyncDoFn.resetState(); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn) { + waitForEmpty(asyncDoFn, 10); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn, int timeoutSeconds) { + int count = 0; + while (!asyncDoFn.isEmpty()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + count += 1; + if (count > timeoutSeconds) { + throw new RuntimeException("Timed out waiting for async dofn to be empty"); + } + } + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private void checkOutput(List result, List expectedOutput) { + List resultStr = new ArrayList<>(); + for (T val : result) { + resultStr.add(val.toString()); + } + List expectedStr = new ArrayList<>(); + for (T val : expectedOutput) { + expectedStr.add(val.toString()); + } + Collections.sort(resultStr); + Collections.sort(expectedStr); + assertEquals(expectedStr, resultStr); + } + + private void checkItemsInBuffer(AsyncDoFn asyncDoFn, int expectedCount) { + assertEquals(expectedCount, asyncDoFn.getItemsInBufferCount()); + } + + // Test 1: testCustomIdFn + // Verifies key extraction custom logic. Duplicate elements (same custom ID but different payload) + // should be recognized as already in-flight and deduplicated. + @Test + public void testCustomIdFn() { + class CustomIdObject implements Serializable { + final int elementId; + final String value; + + CustomIdObject(int elementId, String value) { + this.elementId = elementId; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof CustomIdObject)) return false; + CustomIdObject that = (CustomIdObject) o; + return elementId == that.elementId; + } + + @Override + public int hashCode() { + return java.util.Objects.hash(elementId); + } + + @Override + public String toString() { + return "CustomIdObject{id=" + elementId + ", val=" + value + "}"; + } + } + + class CustomIdDofn extends DoFn { + @ProcessElement + public void processElement(@Element CustomIdObject element, OutputReceiver receiver) { + receiver.output(element.value); + } + } + + CustomIdDofn dofn = new CustomIdDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + null, + null, + null, + x -> x.elementId, + useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", new CustomIdObject(1, "a")); + KV msg2 = KV.of("key1", new CustomIdObject(1, "b")); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("a")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 2: testBasic + // Verifies the standard end-to-end execution flow. Elements should be queued in persistent state + // and output correctly upon completion. + @Test + public void testBasic() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + assertNotEquals(Instant.EPOCH, fakeTimer.getCurrentRelativeTime()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 3: testMultiKey + // Verifies key grouping isolation. Firing a timer for one partition key must not release + // or interfere with elements queued under a different partition key. + @Test + public void testMultiKey() { + for (boolean useThreadPool : new boolean[] {true, false}) { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagStateKey1 = new FakeBagState<>(); + FakeBagState> fakeBagStateKey2 = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key2", "2"); + + asyncDoFn.processDirect( + msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey1, fakeTimer); + asyncDoFn.processDirect( + msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey2, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey2, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(1, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey1, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + } + } + + // Test 4: testLongItem + // Verifies that outputs are kept in-flight and not committed prematurely if the background + // execution task has not finished processing yet. + @Test + public void testLongItem() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(0, dofn.getProcessed()); + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn, 20); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 5: testLostItem + // Verifies if the local worker's in-memory cache is empty but the runner's + // persistent state contains pending items. + // The wrapper must automatically detect the mismatch and reschedule execution. + @Test + public void testLostItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + FakeBagState> fakeBagState = new FakeBagState<>(msg); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + } + + // Test 6: testCancelledItem + // Verifies active task cancellation. If a pending element is deleted from the runner's persistent + // state prior to a commit (e.g., due to a rollback), the background future task must be actively + // cancelled. + @Test + public void testCancelledItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key1", "2"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + fakeBagState.clear(); + fakeBagState.add(msg2); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 7: testMultiElementDofn + // Verifies support for DoFns that emit multiple outputs per element, and correctly aggregates + // outputs produced during the finishBundle stage of the sync DoFn's lifecycle. + @Test + public void testMultiElementDofn() { + MultiElementDoFn dofn = new MultiElementDoFn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Arrays.asList("1", "1", "bundle end")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 8: testDuplicates + // Verifies deduplication of duplicate elements under active processing. + // Identical elements should not spawn multiple concurrent background executions. + @Test + public void testDuplicates() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + fakeBagState.clear(); + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 9: testBufferCount + // Verifies accurate in-flight metrics tracking. + // The item count in the buffer must increment on task scheduling + // and decrement immediately upon execution completion. + @Test + public void testBufferCount() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg = KV.of("key1", "1"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 1); + + waitForEmpty(asyncDoFn); + checkItemsInBuffer(asyncDoFn, 0); + + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 0); + } + + // Test 10: testBufferStopsAcceptingItems + // Verifies queue boundaries and backpressure throttling. + // When concurrent threads push elements exceeding the capacity limit, + // the scheduler must block and delay submissions appropriately. + @Test + public void testBufferStopsAcceptingItems() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + 5, // max buffer capacity + null, + null, + null, + useThreadPool); + asyncDoFn.setup(null); + + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + ExecutorService poolExecutor = Executors.newFixedThreadPool(10); + List expectedOutput = new ArrayList<>(); + List> futures = new ArrayList<>(); + + for (int i = 0; i < 10; i++) { + final int idx = i; + expectedOutput.add(String.valueOf(idx)); + futures.add( + poolExecutor.submit( + () -> { + KV item = KV.of("key", String.valueOf(idx)); + asyncDoFn.processDirect( + item, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + })); + } + + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + assertEquals(5, asyncDoFn.getItemsInBufferCount()); + + waitForEmpty(asyncDoFn, 100); + + // Verify that all background tasks completed successfully without throwing exceptions + for (Future future : futures) { + try { + future.get(); // This will re-throw any exception that occurred in the background thread + } catch (Exception e) { + throw new AssertionError("Background task failed", e); + } + } + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn, 100); + + result.addAll( + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer)); + + checkOutput(result, expectedOutput); + checkItemsInBuffer(asyncDoFn, 0); + poolExecutor.shutdown(); + } + + // Test 11: testBufferWithCancellation + // Verifies backpressure behavior in conjunction with element cancellation. + // Elements that are actively cancelled during queue throttling should be dropped cleanly from the + // buffer. + @Test + public void testBufferWithCancellation() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key1", "2"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + checkItemsInBuffer(asyncDoFn, 2); + + fakeBagState.clear(); + fakeBagState.add(msg2); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 0); + checkOutput(result, Collections.singletonList("2")); + } + + // Test 12: testResetStateConcurrentTeardown + // Verifies safe resource cleanup during concurrent shutdown. + // Resetting the global shared execution state while workers are running + // must complete cleanly without thread or lock deadlocks. + @Test + public void testResetStateConcurrentTeardown() { + BasicDofn dofn = new BasicDofn(500); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + asyncDoFn.processDirect( + KV.of("key1", "1"), GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Verify calling resetState() while background tasks are running finishes cleanly + AsyncDoFn.resetState(); + } +}