diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java
new file mode 100644
index 000000000000..e499cbdf2c1e
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java
@@ -0,0 +1,673 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.ReentrantLock;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Class that wraps a dofn and converts it from one which processes elements synchronously to one
+ * which processes them asynchronously.
+ *
+ * <p>For synchronous dofns the default settings mean that many (100s) of elements will be processed
+ * in parallel and that processing an element will block all other work on that key. In addition
+ * runners are optimized for latencies less than a few seconds and longer operations can result in
+ * high retry rates. Async should be considered when the default parallelism is not correct and/or
+ * items are expected to take longer than a few seconds to process.
+ */
+public class AsyncDoFn<K, InputT, OutputT> extends DoFn<KV<K, InputT>, OutputT> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class);
+
+  // Lower bound of the default buffer capacity used when maxItemsToBuffer is not supplied.
+  private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10;
+  // Default total time scheduleItem keeps retrying while the buffer is full.
+  private static final int DEFAULT_TIMEOUT_SEC = 1;
+  // Default upper bound for a single backoff sleep in scheduleItem.
+  private static final int DEFAULT_MAX_WAIT_TIME_MS = 500;
+  // How long teardown waits for the thread pool to terminate before calling shutdownNow().
+  private static final int TEARDOWN_AWAIT_SEC = 5;
+  // First backoff sleep in scheduleItem; doubled after every failed attempt.
+  private static final int INITIAL_BACKOFF_SLEEP_MS = 10;
+  // Backpressure is logged once the accumulated wait in scheduleItem exceeds this value.
+  private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000;
+
+  // Per-key bag of elements that were received but whose outputs have not been emitted yet.
+  @StateId("to_process")
+  private final StateSpec<BagState<KV<K, InputT>>> toProcessSpec;
+
+  // Processing-time timer that triggers committing finished work for a key.
+  @TimerId("timer")
+  private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+  private final DoFn<InputT, OutputT> syncFn;
+  private final int parallelism;
+  private final Duration timerFrequency;
+  private final int maxItemsToBuffer;
+  private final Duration timeout;
+  private final Duration maxWaitTime;
+  // Maps an input value to the id used for in-flight deduplication.
+  private final SerializableFunction<InputT, Object> idFn;
+  private final boolean useThreadPool;
+  // Key into the static registries below; generated once in the constructor.
+  private final String uuid;
+
+  private transient @Nullable PipelineOptions pipelineOptions;
+
+  // Shared JVM-Wide States (Static Registries)
+  // Map-backed registry holding shared resources across serialized worker instances. Since runners
+  // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse.
+  private static final ConcurrentHashMap<String, ExecutorService> pool = new ConcurrentHashMap<>();
+  // activeElements (processingElements) is global JVM memory (all keys)
+  private static final ConcurrentHashMap<String, ConcurrentHashMap<Object, InFlightElement<?>>>
+      processingElements = new ConcurrentHashMap<>();
+  private static final ConcurrentHashMap<String, AtomicInteger> itemsInBuffer =
+      new ConcurrentHashMap<>();
+
+  // Single JVM-wide lock guarding updates to the registries above.
+  private static final ReentrantLock lock = new ReentrantLock();
+  private static final boolean verboseLogging = false;
+
+  /** Holder for the future that will produce the outputs of one submitted element. */
+  private static class InFlightElement<OutputT> {
+    final CompletableFuture<List<OutputT>> future;
+
+    InFlightElement(CompletableFuture<List<OutputT>> future) {
+      this.future = future;
+    }
+  }
+
+  // The In-Memory Accumulating Receiver
+  // Accumulates elements in-memory during asynchronous background worker execution.
+  // Buffered elements are only committed downstream once the parent task completes successfully
+  // and the timer fires.
+  // Only the values are retained: timestamps, windows and pane info passed by the wrapped DoFn
+  // are discarded by every output path below.
+  private static class AccumulatingOutputReceiver<T> implements OutputReceiver<T> {
+    // Synchronized because the list is filled on a worker thread and read on the bundle thread.
+    private final List<T> outputs = Collections.synchronizedList(new ArrayList<>());
+
+    @Override
+    public org.apache.beam.sdk.values.OutputBuilder<T> builder(T value) {
+      return org.apache.beam.sdk.values.WindowedValues.<T>builder()
+          .setValue(value)
+          .setTimestamp(Instant.now())
+          .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE))
+          .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING)
+          .setReceiver(windowedValue -> outputs.add(windowedValue.getValue()));
+    }
+
+    // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs.
+    // JVM optimization to prevent garbage collection pressure under high pipeline throughput.
+    @Override
+    public void output(T output) {
+      outputs.add(output);
+    }
+
+    // The supplied timestamp is intentionally not stored; only the value is buffered.
+    @Override
+    public void outputWithTimestamp(T output, Instant timestamp) {
+      outputs.add(output);
+    }
+
+    /** Returns the live (synchronized) list of buffered values, not a copy. */
+    public List<T> getOutputs() {
+      return outputs;
+    }
+  }
+
+  /**
+   * Creates an async wrapper whose state bag uses an inferred coder.
+   *
+   * <p>Delegates to the nine-argument constructor with a {@code null} coder; see that constructor
+   * for the meaning and defaults of the individual arguments.
+   */
+  public AsyncDoFn(
+      DoFn<InputT, OutputT> syncFn,
+      int parallelism,
+      Duration timerFrequency,
+      @Nullable Integer maxItemsToBuffer,
+      @Nullable Duration timeout,
+      @Nullable Duration maxWaitTime,
+      @Nullable SerializableFunction<InputT, Object> idFn,
+      boolean useThreadPool) {
+    this(
+        syncFn,
+        parallelism,
+        timerFrequency,
+        maxItemsToBuffer,
+        timeout,
+        maxWaitTime,
+        idFn,
+        useThreadPool,
+        null);
+  }
+
+  /**
+   * Creates an async wrapper around {@code syncFn}.
+   *
+   * @param syncFn the synchronous DoFn whose lifecycle methods are invoked on background threads
+   * @param parallelism size of the fixed thread pool created in {@code setup}; must be positive
+   * @param timerFrequency period used to compute the next timer firing; must be positive
+   * @param maxItemsToBuffer maximum number of in-flight items; defaults to {@code max(2 *
+   *     parallelism, 10)} when {@code null}
+   * @param timeout how long scheduling retries while the buffer is full; defaults to 1 second
+   * @param maxWaitTime upper bound for one backoff sleep; defaults to 500 ms
+   * @param idFn maps an input value to its deduplication id; defaults to the value itself
+   * @param useThreadPool whether tasks run on the dedicated pool instead of the common
+   *     ForkJoinPool
+   * @param coder coder for the state bag, or {@code null} to let it be inferred
+   * @throws IllegalArgumentException if {@code parallelism}, {@code timerFrequency} or a supplied
+   *     {@code maxItemsToBuffer} is not positive
+   */
+  public AsyncDoFn(
+      DoFn<InputT, OutputT> syncFn,
+      int parallelism,
+      Duration timerFrequency,
+      @Nullable Integer maxItemsToBuffer,
+      @Nullable Duration timeout,
+      @Nullable Duration maxWaitTime,
+      @Nullable SerializableFunction<InputT, Object> idFn,
+      boolean useThreadPool,
+      @Nullable Coder<KV<K, InputT>> coder) {
+    // Fail at construction time instead of later inside setup() (thread pool creation).
+    if (parallelism <= 0) {
+      throw new IllegalArgumentException("parallelism must be positive, got " + parallelism);
+    }
+    // A zero period makes the next-fire computation divide by zero and yield the epoch.
+    if (timerFrequency.getMillis() <= 0) {
+      throw new IllegalArgumentException(
+          "timerFrequency must be positive, got " + timerFrequency);
+    }
+    // With a non-positive capacity no element could ever be admitted to the buffer.
+    if (maxItemsToBuffer != null && maxItemsToBuffer <= 0) {
+      throw new IllegalArgumentException(
+          "maxItemsToBuffer must be positive, got " + maxItemsToBuffer);
+    }
+
+    this.syncFn = syncFn;
+    this.parallelism = parallelism;
+    this.timerFrequency = timerFrequency;
+    this.maxItemsToBuffer =
+        (maxItemsToBuffer != null)
+            ? maxItemsToBuffer
+            : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY);
+    this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC);
+    this.maxWaitTime =
+        (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS);
+    // Without a custom id function the (non-null) input value itself is the deduplication id.
+    this.idFn =
+        (idFn != null)
+            ? idFn
+            : (SerializableFunction<InputT, Object>)
+                input -> java.util.Objects.requireNonNull(input);
+    this.useThreadPool = useThreadPool;
+    this.uuid = UUID.randomUUID().toString();
+    this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag();
+  }
+
+ private ExecutorService getThreadPool() {
+ ExecutorService threadPool = pool.get(uuid);
+ if (threadPool == null) {
+ throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid);
+ }
+ return threadPool;
+ }
+
+  /**
+   * Looks up the map of in-flight elements registered for this instance's uuid.
+   *
+   * @throws IllegalStateException if no map is registered under the uuid
+   */
+  @SuppressWarnings("unchecked")
+  private ConcurrentHashMap<Object, InFlightElement<OutputT>> getProcessingElements() {
+    ConcurrentHashMap<Object, InFlightElement<?>> elements = processingElements.get(uuid);
+    if (elements == null) {
+      throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid);
+    }
+    // The registry is static and cannot name OutputT, so entries are stored with a wildcard and
+    // narrowed here through an intermediate wildcard cast.
+    return (ConcurrentHashMap<Object, InFlightElement<OutputT>>)
+        (ConcurrentHashMap<?, ?>) elements;
+  }
+
+ private AtomicInteger getItemsInBuffer() {
+ AtomicInteger buffer = itemsInBuffer.get(uuid);
+ if (buffer == null) {
+ throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid);
+ }
+ return buffer;
+ }
+
+  /**
+   * Remembers the pipeline options, runs the wrapped DoFn's setup, and registers the shared
+   * executor, in-flight map and counter for this instance's uuid if they do not exist yet.
+   */
+  @Setup
+  public void setup(PipelineOptions options) {
+    this.pipelineOptions = options;
+
+    // Setup the wrapped DoFn
+    DoFnInvokers.invokerFor(syncFn)
+        .invokeSetup(
+            new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
+              @Override
+              public PipelineOptions pipelineOptions() {
+                return options;
+              }
+
+              @Override
+              public String getErrorContext() {
+                return "AsyncDoFn/Setup";
+              }
+            });
+
+    if (useThreadPool) {
+      LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism);
+    }
+
+    // The executor is registered even when useThreadPool is false; in that case it is simply
+    // never handed any task.
+    lock.lock();
+    try {
+      pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism));
+      processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>());
+      itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0));
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  // Clean up JVM-wide shared resources to prevent thread leaks on the worker
+  /**
+   * Runs the wrapped DoFn's teardown and then unregisters and shuts down the resources stored
+   * under this instance's uuid.
+   *
+   * <p>The cleanup runs in a {@code finally} block so that an exception thrown by the wrapped
+   * teardown cannot leave the executor threads alive.
+   */
+  @Teardown
+  public void teardown() {
+    try {
+      DoFnInvokers.invokerFor(syncFn).invokeTeardown();
+    } finally {
+      // NOTE(review): resources are removed unconditionally by uuid. If several deserialized
+      // copies of this DoFn share the uuid in one JVM, the first teardown also removes them for
+      // the other copies - confirm whether reference counting is needed.
+      ExecutorService threadPool;
+      lock.lock();
+      try {
+        threadPool = pool.remove(uuid);
+        processingElements.remove(uuid);
+        itemsInBuffer.remove(uuid);
+      } finally {
+        lock.unlock();
+      }
+
+      if (threadPool != null) {
+        // Give running tasks a bounded grace period, then interrupt whatever is left.
+        threadPool.shutdown();
+        try {
+          if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) {
+            threadPool.shutdownNow();
+          }
+        } catch (InterruptedException e) {
+          threadPool.shutdownNow();
+          Thread.currentThread().interrupt();
+        }
+      }
+    }
+  }
+
+ // Asynchronous Scheduling & Deduplication
+ // Submits tasks to the background thread pool. If an element with the same ID is already
+ // in-flight,
+ // the submission is silently ignored to enforce exactly-once semantics.
+ private boolean scheduleIfRoom(
+ KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) {
+ lock.lock();
+ try {
+ ConcurrentHashMap> activeElements = getProcessingElements();
+ Object elementId = idFn.apply(element.getValue());
+
+ if (activeElements.containsKey(elementId)) {
+ LOG.info("Item {} already in processing elements", element);
+ return true;
+ }
+
+ int currentBuffer = getItemsInBuffer().get();
+ if (currentBuffer < maxItemsToBuffer || ignoreBuffer) {
+ java.util.concurrent.Executor executor =
+ useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool();
+
+ // Pending asynchronous task that will produce a list of outputs
+ CompletableFuture> future =
+ CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ AccumulatingOutputReceiver receiver =
+ new AccumulatingOutputReceiver<>();
+ DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn);
+
+ DoFnInvoker.ArgumentProvider bundleArgProvider =
+ new DoFnInvoker.BaseArgumentProvider() {
+ @Override
+ public PipelineOptions pipelineOptions() {
+ PipelineOptions options = pipelineOptions;
+ if (options == null) {
+ throw new IllegalStateException("PipelineOptions not set");
+ }
+ return options;
+ }
+
+ @Override
+ public DoFn.FinishBundleContext finishBundleContext(
+ DoFn doFn) {
+ return doFn.new FinishBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return pipelineOptions();
+ }
+
+ @Override
+ public void output(
+ OutputT output, Instant timestamp, BoundedWindow window) {
+ receiver.output(output);
+ }
+
+ @Override
+ public void output(
+ TupleTag tag,
+ T output,
+ Instant timestamp,
+ BoundedWindow window) {
+ throw new UnsupportedOperationException(
+ "Tagged output not supported in FinishBundleContext for AsyncDoFn");
+ }
+ };
+ }
+
+ @Override
+ public String getErrorContext() {
+ return "AsyncDoFn/Bundle";
+ }
+ };
+
+ invoker.invokeStartBundle(bundleArgProvider);
+
+ DoFnInvoker.ArgumentProvider processArgProvider =
+ new DoFnInvoker.BaseArgumentProvider() {
+ @Override
+ public InputT element(DoFn doFn) {
+ return element.getValue();
+ }
+
+ @Override
+ public OutputReceiver outputReceiver(
+ DoFn doFn) {
+ return receiver;
+ }
+
+ @Override
+ public BoundedWindow window() {
+ return window;
+ }
+
+ @Override
+ public Instant timestamp(DoFn doFn) {
+ return timestamp;
+ }
+
+ @Override
+ public PipelineOptions pipelineOptions() {
+ PipelineOptions options = pipelineOptions;
+ if (options == null) {
+ throw new IllegalStateException("PipelineOptions not set");
+ }
+ return options;
+ }
+
+ @Override
+ public String getErrorContext() {
+ return "AsyncDoFn/Process";
+ }
+ };
+
+ invoker.invokeProcessElement(processArgProvider);
+ invoker.invokeFinishBundle(bundleArgProvider);
+
+ return receiver.getOutputs();
+ } catch (Exception e) {
+ throw new CompletionException(e);
+ }
+ },
+ executor);
+
+ // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for
+ // cancellation
+ CompletableFuture> unused =
+ future.whenComplete(
+ (res, ex) -> {
+ lock.lock();
+ try {
+ getItemsInBuffer().decrementAndGet();
+ } finally {
+ lock.unlock();
+ }
+ });
+
+ activeElements.put(elementId, new InFlightElement<>(future));
+ getItemsInBuffer().incrementAndGet();
+ return true;
+ }
+
+ return false;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+  /**
+   * Tries to submit {@code element}, backing off with doubling sleeps while the buffer is full.
+   *
+   * <p>At least one attempt is always made, even with a zero timeout. Each sleep is at least 1 ms
+   * (so a zero {@code maxWaitTime} cannot busy-spin) and never exceeds the remaining timeout. If
+   * the timeout elapses the method returns without scheduling.
+   *
+   * @throws RuntimeException if the thread is interrupted while waiting
+   */
+  private void scheduleItem(KV<K, InputT> element, BoundedWindow window, Instant timestamp) {
+    long timeoutMs = timeout.getMillis();
+    long maxSleepMs = Math.max(1, maxWaitTime.getMillis());
+    long sleepTime = INITIAL_BACKOFF_SLEEP_MS;
+    long totalSleep = 0;
+
+    while (true) {
+      if (scheduleIfRoom(element, window, timestamp, false)) {
+        return;
+      }
+      long remaining = timeoutMs - totalSleep;
+      if (remaining <= 0) {
+        // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later.
+        return;
+      }
+
+      long sleep = Math.min(Math.min(maxSleepMs, sleepTime), remaining);
+      if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) {
+        LOG.info(
+            "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.",
+            element,
+            getItemsInBuffer().get(),
+            sleep,
+            totalSleep);
+      }
+      try {
+        Thread.sleep(sleep);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException("Interrupted while waiting for space in buffer", e);
+      }
+      // Cap the doubling at the per-sleep maximum; this also keeps the value from overflowing.
+      sleepTime = Math.min(sleepTime * 2, maxSleepMs);
+      totalSleep += sleep;
+    }
+  }
+
+ private Instant nextTimeToFire(@Nullable K key) {
+ long seed = (key == null) ? 0 : key.hashCode();
+ Random random = new Random(seed);
+ double timerFrequencySec = timerFrequency.getMillis() / 1000.0;
+ double nowSec = System.currentTimeMillis() / 1000.0;
+
+ double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec;
+ double offset = random.nextDouble() * timerFrequencySec;
+
+ return Instant.ofEpochMilli((long) ((base + offset) * 1000));
+ }
+
+  /**
+   * Attempts to submit the incoming element for background processing, records it in the state
+   * bag, and sets the timer that will later commit its outputs.
+   */
+  @ProcessElement
+  public void processElement(
+      ProcessContext c,
+      BoundedWindow window,
+      @StateId("to_process") BagState<KV<K, InputT>> toProcessState,
+      @TimerId("timer") Timer timer) {
+
+    KV<K, InputT> element = c.element();
+    scheduleItem(element, window, c.timestamp());
+    // The element is recorded in state even when scheduling timed out, so a later timer firing
+    // can find and reschedule it.
+    toProcessState.add(element);
+
+    Instant timeToFire = nextTimeToFire(element.getKey());
+    timer.set(timeToFire);
+  }
+
+  /** Timer callback: commits the finished work of the firing key to {@code receiver}. */
+  @OnTimer("timer")
+  public void onTimer(
+      OnTimerContext c,
+      @StateId("to_process") BagState<KV<K, InputT>> toProcessState,
+      @TimerId("timer") Timer timer,
+      OutputReceiver<OutputT> receiver) {
+
+    commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver);
+  }
+
+  // Synchronizes local task results with the runner's persistent state container.
+  // Emits successfully completed elements and reschedules lost work.
+  /**
+   * Reconciles the state bag of one key with the JVM-local in-flight map.
+   *
+   * <p>For every distinct element id in state: outputs of completed tasks are emitted and the id
+   * is dropped from state; tasks still running are left alone; ids with no local task are
+   * scheduled again. State entries are pruned by element id, so entries that share the id of a
+   * finished element are removed together even when their payloads differ. The timer is re-armed
+   * while anything remains in state.
+   *
+   * @param fireTimestamp timestamp handed to rescheduled elements
+   * @param toProcessState state bag of the current key
+   * @param timer timer to re-arm when work remains
+   * @param receiver destination for the outputs of completed tasks
+   */
+  private void commitFinishedItems(
+      Instant fireTimestamp,
+      BagState<KV<K, InputT>> toProcessState,
+      Timer timer,
+      OutputReceiver<OutputT> receiver) {
+
+    Iterable<KV<K, InputT>> toProcessLocal = toProcessState.read();
+    if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) {
+      // Early Exit: if BagState is empty, we skip checking activeElements for this key.
+      return;
+    }
+
+    // Since fireTimestamp is key-scoped, we determine the current key from the first element in
+    // state. Ids are computed here, once per entry and outside the JVM-wide lock.
+    List<KV<K, InputT>> stateList = new ArrayList<>();
+    List<Object> stateIds = new ArrayList<>();
+    K key = null;
+    for (KV<K, InputT> element : toProcessLocal) {
+      stateList.add(element);
+      stateIds.add(idFn.apply(element.getValue()));
+      if (key == null) {
+        key = element.getKey();
+      }
+    }
+
+    if (verboseLogging) {
+      LOG.info("processing timer for key: {}", key);
+    }
+
+    ConcurrentHashMap<Object, InFlightElement<OutputT>> activeElements = getProcessingElements();
+
+    List<List<OutputT>> toReturn = new ArrayList<>();
+    List<KV<K, InputT>> toReschedule = new ArrayList<>();
+
+    int itemsFinished = 0;
+    int itemsNotYetFinished = 0;
+    int itemsRescheduled = 0;
+    int itemsCancelled = 0;
+
+    Set<Object> seenElementIds = new HashSet<>();
+    Set<Object> finishedElementIds = new HashSet<>();
+
+    lock.lock();
+    try {
+      for (int i = 0; i < stateList.size(); i++) {
+        KV<K, InputT> element = stateList.get(i);
+        Object elementId = stateIds.get(i);
+
+        // Each id is classified once per cycle; later state entries with the same id are skipped.
+        if (!seenElementIds.add(elementId)) {
+          continue;
+        }
+
+        InFlightElement<OutputT> inFlight = activeElements.get(elementId);
+        if (inFlight == null) {
+          LOG.info(
+              "Item {} found in state but not in local active elements, scheduling now", element);
+          toReschedule.add(element);
+          itemsRescheduled++;
+        } else if (!inFlight.future.isDone()) {
+          itemsNotYetFinished++;
+        } else {
+          if (inFlight.future.isCancelled()) {
+            // A cancelled task has no outputs; it is only removed from state and counted.
+            itemsCancelled++;
+          } else {
+            try {
+              toReturn.add(inFlight.future.get());
+              itemsFinished++;
+            } catch (Exception e) {
+              // A failed task is logged and dropped from state rather than retried.
+              LOG.error("Error executing async task for element {}", element, e);
+            }
+          }
+          finishedElementIds.add(elementId);
+          activeElements.remove(elementId);
+        }
+      }
+    } finally {
+      lock.unlock();
+    }
+
+    // Reschedule missing elements. The original window is not stored in state, so the global
+    // window is used.
+    for (KV<K, InputT> element : toReschedule) {
+      scheduleItem(element, GlobalWindow.INSTANCE, fireTimestamp);
+    }
+
+    // Update State: keep only unfinished items. Nothing is rewritten when nothing finished.
+    int itemsInProcessingState;
+    if (finishedElementIds.isEmpty()) {
+      itemsInProcessingState = stateList.size();
+    } else {
+      toProcessState.clear();
+      itemsInProcessingState = 0;
+      for (int i = 0; i < stateList.size(); i++) {
+        if (!finishedElementIds.contains(stateIds.get(i))) {
+          toProcessState.add(stateList.get(i));
+          itemsInProcessingState++;
+        }
+      }
+    }
+
+    // Emit completed outputs (Emit completed tasks immediately; do not wait for all active tasks to
+    // finish).
+    for (List<OutputT> outputs : toReturn) {
+      for (OutputT out : outputs) {
+        receiver.output(out);
+      }
+    }
+
+    LOG.info(
+        "Items finished: {}, not yet finished: {}, rescheduled: {}, cancelled: {}, in processing state: {}",
+        itemsFinished,
+        itemsNotYetFinished,
+        itemsRescheduled,
+        itemsCancelled,
+        itemsInProcessingState);
+
+    if (itemsInProcessingState > 0) {
+      Instant timeToFire = nextTimeToFire(key);
+      timer.set(timeToFire);
+    }
+  }
+
+  // Package-private helper methods for testing direct execution without Pipeline / ProcessContext
+  // boilerplate
+  /**
+   * Performs the same steps as {@code processElement} using explicitly supplied element, window,
+   * timestamp, state and timer: schedule, record in state, set the timer.
+   */
+  void processDirect(
+      KV<K, InputT> element,
+      BoundedWindow window,
+      Instant timestamp,
+      BagState<KV<K, InputT>> toProcessState,
+      Timer timer) {
+    scheduleItem(element, window, timestamp);
+    toProcessState.add(element);
+    Instant timeToFire = nextTimeToFire(element.getKey());
+    timer.set(timeToFire);
+  }
+
+  /**
+   * Runs {@code commitFinishedItems} against the given state and timer and returns the emitted
+   * values instead of sending them to a runner-provided receiver.
+   */
+  List<OutputT> commitFinishedItemsDirect(
+      Instant fireTimestamp, BagState<KV<K, InputT>> toProcessState, Timer timer) {
+    AccumulatingOutputReceiver<OutputT> receiver = new AccumulatingOutputReceiver<>();
+    commitFinishedItems(fireTimestamp, toProcessState, timer, receiver);
+    return receiver.getOutputs();
+  }
+
+ boolean isEmpty() {
+ return getItemsInBuffer().get() == 0;
+ }
+
+ int getItemsInBufferCount() {
+ return getItemsInBuffer().get();
+ }
+
+  /**
+   * Shuts down every registered executor immediately and empties all static registries, for
+   * every uuid.
+   */
+  static void resetState() {
+    lock.lock();
+    try {
+      for (Map.Entry<String, ExecutorService> entry : pool.entrySet()) {
+        entry.getValue().shutdownNow();
+      }
+      pool.clear();
+      processingElements.clear();
+      itemsInBuffer.clear();
+    } finally {
+      lock.unlock();
+    }
+  }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java
new file mode 100644
index 000000000000..912aca3f309c
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java
@@ -0,0 +1,733 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.locks.ReentrantLock;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for verifying async processing structures and logic. */
+@RunWith(JUnit4.class)
+public class AsyncDoFnTest implements Serializable {
+
+ @Rule public final transient TestPipeline p = TestPipeline.create();
+ private final boolean useThreadPool = true;
+
+ // Used for testing basic DoFn processing logic with optional latency.
+ private static class BasicDofn extends DoFn {
+ private final long sleepTimeMs;
+ private int processed = 0;
+ private final ReentrantLock lock = new ReentrantLock();
+
+ BasicDofn(long sleepTimeMs) {
+ this.sleepTimeMs = sleepTimeMs;
+ }
+
+ BasicDofn() {
+ this(0);
+ }
+
+ @ProcessElement
+ public void processElement(@Element String element, OutputReceiver receiver) {
+ if (sleepTimeMs > 0) {
+ try {
+ Thread.sleep(sleepTimeMs);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+ lock.lock();
+ try {
+ processed += 1;
+ } finally {
+ lock.unlock();
+ }
+ receiver.output(element);
+ }
+
+ int getProcessed() {
+ lock.lock();
+ try {
+ return processed;
+ } finally {
+ lock.unlock();
+ }
+ }
+ }
+
+ // Used for testing multi element processing with optional finish bundle call.
+ private static class MultiElementDoFn extends DoFn {
+ @ProcessElement
+ public void processElement(@Element String element, OutputReceiver receiver) {
+ receiver.output(element);
+ receiver.output(element);
+ }
+
+ @FinishBundle
+ public void finishBundle(FinishBundleContext c) {
+ c.output("bundle end", Instant.now(), GlobalWindow.INSTANCE);
+ }
+ }
+
+ // Used for testing BagState thread safety.
+ private static class FakeBagState implements BagState {
+ private final List items;
+ private final ReentrantLock lock = new ReentrantLock();
+
+ FakeBagState(List initialItems) {
+ this.items = new ArrayList<>(initialItems);
+ }
+
+ FakeBagState(T initialItem) {
+ this(new ArrayList<>(List.of(initialItem)));
+ }
+
+ FakeBagState() {
+ this(new ArrayList<>());
+ }
+
+ @Override
+ public void add(T item) {
+ lock.lock();
+ try {
+ items.add(item);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void clear() {
+ lock.lock();
+ try {
+ items.clear();
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public Iterable read() {
+ lock.lock();
+ try {
+ return new ArrayList<>(items);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public ReadableState isEmpty() {
+ return new ReadableState() {
+ @Override
+ public Boolean read() {
+ lock.lock();
+ try {
+ return items.isEmpty();
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public ReadableState readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public BagState readLater() {
+ return this;
+ }
+ }
+
+ // Used for testing Timer mock implementations.
+ private static class FakeTimer implements Timer {
+ private Instant time = Instant.EPOCH;
+
+ @Override
+ public void set(Instant absoluteTime) {
+ this.time = absoluteTime;
+ }
+
+ @Override
+ public void setRelative() {}
+
+ @Override
+ public void clear() {
+ this.time = Instant.EPOCH;
+ }
+
+ @Override
+ public Timer offset(Duration offset) {
+ return this;
+ }
+
+ @Override
+ public Timer align(Duration period) {
+ return this;
+ }
+
+ @Override
+ public Timer withOutputTimestamp(Instant outputTime) {
+ return this;
+ }
+
+ @Override
+ public Timer withNoOutputTimestamp() {
+ return this;
+ }
+
+ @Override
+ public Instant getCurrentRelativeTime() {
+ return time;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ AsyncDoFn.resetState();
+ }
+
+ private void waitForEmpty(AsyncDoFn, ?, ?> asyncDoFn) {
+ waitForEmpty(asyncDoFn, 10);
+ }
+
+ private void waitForEmpty(AsyncDoFn, ?, ?> asyncDoFn, int timeoutSeconds) {
+ int count = 0;
+ while (!asyncDoFn.isEmpty()) {
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ count += 1;
+ if (count > timeoutSeconds) {
+ throw new RuntimeException("Timed out waiting for async dofn to be empty");
+ }
+ }
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ private void checkOutput(List result, List expectedOutput) {
+ List resultStr = new ArrayList<>();
+ for (T val : result) {
+ resultStr.add(val.toString());
+ }
+ List expectedStr = new ArrayList<>();
+ for (T val : expectedOutput) {
+ expectedStr.add(val.toString());
+ }
+ Collections.sort(resultStr);
+ Collections.sort(expectedStr);
+ assertEquals(expectedStr, resultStr);
+ }
+
+ private void checkItemsInBuffer(AsyncDoFn, ?, ?> asyncDoFn, int expectedCount) {
+ assertEquals(expectedCount, asyncDoFn.getItemsInBufferCount());
+ }
+
+ // Test 1: testCustomIdFn
+ // Verifies key extraction custom logic. Duplicate elements (same custom ID but different payload)
+ // should be recognized as already in-flight and deduplicated.
+ @Test
+ public void testCustomIdFn() {
+ class CustomIdObject implements Serializable {
+ final int elementId;
+ final String value;
+
+ CustomIdObject(int elementId, String value) {
+ this.elementId = elementId;
+ this.value = value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof CustomIdObject)) return false;
+ CustomIdObject that = (CustomIdObject) o;
+ return elementId == that.elementId;
+ }
+
+ @Override
+ public int hashCode() {
+ return java.util.Objects.hash(elementId);
+ }
+
+ @Override
+ public String toString() {
+ return "CustomIdObject{id=" + elementId + ", val=" + value + "}";
+ }
+ }
+
+ class CustomIdDofn extends DoFn {
+ @ProcessElement
+ public void processElement(@Element CustomIdObject element, OutputReceiver receiver) {
+ receiver.output(element.value);
+ }
+ }
+
+ CustomIdDofn dofn = new CustomIdDofn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn,
+ 1,
+ Duration.standardSeconds(5),
+ null,
+ null,
+ null,
+ x -> x.elementId,
+ useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+
+ KV msg1 = KV.of("key1", new CustomIdObject(1, "a"));
+ KV msg2 = KV.of("key1", new CustomIdObject(1, "b"));
+
+ asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ waitForEmpty(asyncDoFn);
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("a"));
+ assertEquals(0, fakeBagState.items.size());
+ }
+
+ // Test 2: testBasic
+ // Verifies the standard end-to-end execution flow. Elements should be queued in persistent state
+ // and output correctly upon completion.
+ @Test
+ public void testBasic() {
+ BasicDofn dofn = new BasicDofn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+ KV msg = KV.of("key1", "1");
+
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ assertEquals(1, fakeBagState.items.size());
+ assertNotEquals(Instant.EPOCH, fakeTimer.getCurrentRelativeTime());
+
+ waitForEmpty(asyncDoFn);
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("1"));
+ assertEquals(1, dofn.getProcessed());
+ assertEquals(0, fakeBagState.items.size());
+ }
+
+ // Test 3: testMultiKey
+ // Verifies key grouping isolation. Firing a timer for one partition key must not release
+ // or interfere with elements queued under a different partition key.
+ @Test
+ public void testMultiKey() {
+ for (boolean useThreadPool : new boolean[] {true, false}) {
+ BasicDofn dofn = new BasicDofn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagStateKey1 = new FakeBagState<>();
+ FakeBagState> fakeBagStateKey2 = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+
+ KV msg1 = KV.of("key1", "1");
+ KV msg2 = KV.of("key2", "2");
+
+ asyncDoFn.processDirect(
+ msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey1, fakeTimer);
+ asyncDoFn.processDirect(
+ msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey2, fakeTimer);
+
+ waitForEmpty(asyncDoFn);
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagStateKey2, fakeTimer);
+ checkOutput(result, Collections.singletonList("2"));
+ assertEquals(1, fakeBagStateKey1.items.size());
+ assertEquals(0, fakeBagStateKey2.items.size());
+
+ result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagStateKey1, fakeTimer);
+ checkOutput(result, Collections.singletonList("1"));
+ assertEquals(0, fakeBagStateKey1.items.size());
+ assertEquals(0, fakeBagStateKey2.items.size());
+ }
+ }
+
+ // Test 4: testLongItem
+ // Verifies that a commit issued before the background task finishes emits nothing and keeps the
+ // element in state; only the commit issued after waitForEmpty emits "1" and empties the state.
+ @Test
+ public void testLongItem() {
+ BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay; units not visible - confirm
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+ KV msg = KV.of("key1", "1");
+
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.emptyList()); // committed right away: no output expected yet
+ assertEquals(0, dofn.getProcessed());
+ assertEquals(1, fakeBagState.items.size()); // the element must still be pending in state
+
+ waitForEmpty(asyncDoFn, 20); // NOTE(review): meaning/units of the second argument - confirm
+
+ result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("1"));
+ assertEquals(1, dofn.getProcessed());
+ assertEquals(0, fakeBagState.items.size()); // the committed element has left the state
+ }
+
+ // Test 5: testLostItem
+ // Verifies recovery when the state holds an element that was never passed to processDirect
+ // (e.g. the worker's in-memory cache is empty but the runner's persistent state is not):
+ // the first commit emits nothing, and after waitForEmpty a second commit must emit "1".
+ @Test
+ public void testLostItem() {
+ BasicDofn dofn = new BasicDofn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeTimer fakeTimer = new FakeTimer();
+ KV msg = KV.of("key1", "1");
+ FakeBagState> fakeBagState = new FakeBagState<>(msg); // pre-populated; processDirect is never called
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.emptyList()); // nothing has finished yet, so nothing is emitted
+
+ waitForEmpty(asyncDoFn); // presumably waits for the work scheduled by the first commit - confirm
+
+ result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("1"));
+ }
+
+ // Test 6: testCancelledItem
+ // Verifies handling of an element that disappears from the runner's persistent state before a
+ // commit (e.g., due to a rollback): msg1 is removed from the bag after both elements were submitted,
+ // so the commit must emit only "2". NOTE(review): active task cancellation is not asserted here.
+ @Test
+ public void testCancelledItem() {
+ BasicDofn dofn = new BasicDofn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ KV msg1 = KV.of("key1", "1");
+ KV msg2 = KV.of("key1", "2");
+ FakeTimer fakeTimer = new FakeTimer();
+ FakeBagState> fakeBagState = new FakeBagState<>();
+
+ asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ waitForEmpty(asyncDoFn); // presumably blocks until no element is in flight - TODO confirm
+
+ fakeBagState.clear(); // rewrite the state by hand ...
+ fakeBagState.add(msg2); // ... so that only msg2 remains, as if msg1 had been rolled back
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("2")); // msg1's result must not be emitted
+ assertEquals(0, fakeBagState.items.size());
+ }
+
+ // Test 7: testMultiElementDofn
+ // Verifies support for DoFns that emit multiple outputs per element: one input "1" must commit as
+ // "1", "1", "bundle end" (the last value presumably comes from MultiElementDoFn's finishBundle).
+ @Test
+ public void testMultiElementDofn() {
+ MultiElementDoFn dofn = new MultiElementDoFn();
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+ KV msg = KV.of("key1", "1");
+
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ waitForEmpty(asyncDoFn); // presumably blocks until no element is in flight - TODO confirm
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Arrays.asList("1", "1", "bundle end")); // three outputs for a single input
+ assertEquals(0, fakeBagState.items.size());
+ }
+
+ // Test 8: testDuplicates
+ // Verifies de-duplication: the same element is submitted twice (with the state cleared in between)
+ // while the first submission is presumably still being processed; exactly one "1" must be committed.
+ @Test
+ public void testDuplicates() {
+ BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay; units not visible - confirm
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+ KV msg = KV.of("key1", "1");
+
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ fakeBagState.clear(); // wipe the state so only the second submission can add to it
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ assertEquals(1, fakeBagState.items.size()); // the resubmitted element is recorded exactly once
+
+ waitForEmpty(asyncDoFn);
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.singletonList("1")); // a single output despite two submissions
+ assertEquals(0, fakeBagState.items.size());
+ }
+
+ // Test 9: testBufferCount
+ // Verifies in-flight item accounting as observed through checkItemsInBuffer:
+ // the count must be 1 right after processDirect schedules the element,
+ // 0 once waitForEmpty returns, and still 0 after the finished item is committed.
+ @Test
+ public void testBufferCount() {
+ BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay; units not visible - confirm
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ KV msg = KV.of("key1", "1");
+ FakeTimer fakeTimer = new FakeTimer();
+ FakeBagState> fakeBagState = new FakeBagState<>();
+
+ asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ checkItemsInBuffer(asyncDoFn, 1); // checked immediately after scheduling
+
+ waitForEmpty(asyncDoFn);
+ checkItemsInBuffer(asyncDoFn, 0); // expected to drop to 0 before any commit happens
+
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkItemsInBuffer(asyncDoFn, 0); // the commit must not change the count
+ }
+
+ // Test 10: testBufferStopsAcceptingItems
+ // Verifies queue boundaries and backpressure throttling.
+ // When concurrent threads push elements exceeding the capacity limit,
+ // the scheduler must block and delay submissions appropriately.
+ @Test
+ public void testBufferStopsAcceptingItems() {
+ BasicDofn dofn = new BasicDofn(1000);
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn,
+ 1,
+ Duration.standardSeconds(5),
+ 5, // max buffer capacity
+ null,
+ null,
+ null,
+ useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeTimer fakeTimer = new FakeTimer();
+ FakeBagState> fakeBagState = new FakeBagState<>();
+
+ ExecutorService poolExecutor = Executors.newFixedThreadPool(10);
+ List expectedOutput = new ArrayList<>();
+ List> futures = new ArrayList<>();
+
+ for (int i = 0; i < 10; i++) {
+ final int idx = i;
+ expectedOutput.add(String.valueOf(idx));
+ futures.add(
+ poolExecutor.submit(
+ () -> {
+ KV item = KV.of("key", String.valueOf(idx));
+ asyncDoFn.processDirect(
+ item, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ }));
+ }
+
+ try {
+ Thread.sleep(200);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+
+ assertEquals(5, asyncDoFn.getItemsInBufferCount());
+
+ waitForEmpty(asyncDoFn, 100);
+
+ // Verify that all background tasks completed successfully without throwing exceptions
+ for (Future> future : futures) {
+ try {
+ future.get(); // This will re-throw any exception that occurred in the background thread
+ } catch (Exception e) {
+ throw new AssertionError("Background task failed", e);
+ }
+ }
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+
+ waitForEmpty(asyncDoFn, 100);
+
+ result.addAll(
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer));
+
+ checkOutput(result, expectedOutput);
+ checkItemsInBuffer(asyncDoFn, 0);
+ poolExecutor.shutdown();
+ }
+
+ // Test 11: testBufferWithCancellation
+ // Verifies buffer accounting when an in-flight element is removed from state: with two items
+ // buffered, msg1 is dropped from the bag; the first commit emits nothing and keeps msg2, and after
+ // waitForEmpty the second commit emits only "2" with the buffer count back at 0.
+ @Test
+ public void testBufferWithCancellation() {
+ BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay; units not visible - confirm
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ KV msg1 = KV.of("key1", "1");
+ KV msg2 = KV.of("key1", "2");
+ FakeTimer fakeTimer = new FakeTimer();
+ FakeBagState> fakeBagState = new FakeBagState<>();
+
+ asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+ asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ checkItemsInBuffer(asyncDoFn, 2); // both elements are counted as buffered
+
+ fakeBagState.clear(); // rewrite the state by hand ...
+ fakeBagState.add(msg2); // ... so that only msg2 remains
+
+ List result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkOutput(result, Collections.emptyList()); // no wait before this commit: nothing to emit
+ assertEquals(1, fakeBagState.items.size()); // msg2 stays pending
+
+ waitForEmpty(asyncDoFn);
+
+ result =
+ asyncDoFn.commitFinishedItemsDirect(
+ fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+ checkItemsInBuffer(asyncDoFn, 0);
+ checkOutput(result, Collections.singletonList("2")); // msg1's result must not be emitted
+ }
+
+ // Test 12: testResetStateConcurrentTeardown
+ // Verifies safe resource cleanup during concurrent shutdown: the static AsyncDoFn.resetState()
+ // is called about 50 ms after one element was submitted to a BasicDofn(500) wrapper.
+ // NOTE(review): nothing is asserted and no timeout is visible here; a deadlock would hang the test.
+ @Test
+ public void testResetStateConcurrentTeardown() {
+ BasicDofn dofn = new BasicDofn(500); // presumably a per-element delay; units not visible - confirm
+ AsyncDoFn asyncDoFn =
+ new AsyncDoFn<>(
+ dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+ asyncDoFn.setup(null);
+
+ FakeBagState> fakeBagState = new FakeBagState<>();
+ FakeTimer fakeTimer = new FakeTimer();
+
+ asyncDoFn.processDirect(
+ KV.of("key1", "1"), GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+ try {
+ Thread.sleep(50); // short pause so the submitted task is likely running when the reset happens
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+
+ // The test passes if resetState() returns normally while background tasks are still running
+ AsyncDoFn.resetState();
+ }
+}