From af6d166f7b0228e25f16fa30bde5f7580ab9c189 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 22 May 2026 21:03:19 +0000 Subject: [PATCH 1/2] Created an Asynchronous Wrapper for DoFn as well as JUnit tests for the Apache Beam Java SDK (#38529) --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 689 ++++++++++++++++ .../beam/sdk/transforms/AsyncDoFnTest.java | 733 ++++++++++++++++++ 2 files changed, 1422 insertions(+) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java new file mode 100644 index 000000000000..31c7e6c3d78c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -0,0 +1,689 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + *

For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ +public class AsyncDoFn extends DoFn, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn syncFn; + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. 
+ private static final ConcurrentHashMap pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap< + String, ConcurrentHashMap>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap itemsInBuffer = + new ConcurrentHashMap<>(); + + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class InFlightElement { + final KV element; + final CompletableFuture> future; + + InFlightElement(KV element, CompletableFuture> future) { + this.element = element; + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires. + private static class AccumulatingOutputReceiver implements OutputReceiver { + private final List outputs = Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver(windowedValue -> outputs.add(windowedValue.getValue())); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. 
+ @Override + public void output(T output) { + outputs.add(output); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(output); + } + + public List getOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool, + @Nullable Coder> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? 
StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap> getProcessingElements() { + ConcurrentHashMap> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap>) + (ConcurrentHashMap) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + } + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + 
lock.unlock(); + } + + if (threadPool != null) { + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) { + threadPool.shutdownNow(); + } + } catch (InterruptedException e) { + threadPool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + // Asynchronous Scheduling & Deduplication + // Submits tasks to the background thread pool. If an element with the same ID is already + // in-flight, + // the submission is silently ignored to enforce exactly-once semantics. + private boolean scheduleIfRoom( + KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { + lock.lock(); + try { + ConcurrentHashMap> activeElements = + getProcessingElements(); + Object elementId = idFn.apply(element.getValue()); + + if (activeElements.containsKey(elementId)) { + LOG.info("Item {} already in processing elements", element); + return true; + } + + int currentBuffer = getItemsInBuffer().get(); + if (currentBuffer < maxItemsToBuffer || ignoreBuffer) { + java.util.concurrent.Executor executor = + useThreadPool ? 
getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); + + // Pending asynchronous task that will produce a list of outputs + CompletableFuture> future = + CompletableFuture.supplyAsync( + () -> { + try { + AccumulatingOutputReceiver receiver = + new AccumulatingOutputReceiver<>(); + DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn); + + DoFnInvoker.ArgumentProvider bundleArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return doFn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions(); + } + + @Override + public void output( + OutputT output, Instant timestamp, BoundedWindow window) { + receiver.output(output); + } + + @Override + public void output( + TupleTag tag, + T output, + Instant timestamp, + BoundedWindow window) { + throw new UnsupportedOperationException( + "Tagged output not supported in FinishBundleContext for AsyncDoFn"); + } + }; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Bundle"; + } + }; + + invoker.invokeStartBundle(bundleArgProvider); + + DoFnInvoker.ArgumentProvider processArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public InputT element(DoFn doFn) { + return element.getValue(); + } + + @Override + public OutputReceiver outputReceiver( + DoFn doFn) { + return receiver; + } + + @Override + public BoundedWindow window() { + return window; + } + + @Override + public Instant timestamp(DoFn doFn) { + return timestamp; + } + + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + 
return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Process"; + } + }; + + invoker.invokeProcessElement(processArgProvider); + invoker.invokeFinishBundle(bundleArgProvider); + + return receiver.getOutputs(); + } catch (Exception e) { + throw new CompletionException(e); + } + }, + executor); + + // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for + // cancellation + CompletableFuture> unused = + future.whenComplete( + (res, ex) -> { + lock.lock(); + try { + getItemsInBuffer().decrementAndGet(); + } finally { + lock.unlock(); + } + }); + + activeElements.put(elementId, new InFlightElement<>(element, future)); + getItemsInBuffer().incrementAndGet(); + return true; + } + + return false; + } finally { + lock.unlock(); + } + } + + private void scheduleItem(KV element, BoundedWindow window, Instant timestamp) { + boolean done = false; + long sleepTime = INITIAL_BACKOFF_SLEEP_MS; + long totalSleep = 0; + long timeoutMs = timeout.getMillis(); + + while (!done && totalSleep < timeoutMs) { + done = scheduleIfRoom(element, window, timestamp, false); + if (!done) { + long sleep = Math.min(maxWaitTime.getMillis(), sleepTime); + if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) { + LOG.info( + "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.", + element, + getItemsInBuffer().get(), + sleep, + totalSleep); + } + try { + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for space in buffer", e); + } + sleepTime *= 2; + totalSleep += sleep; + } + } + // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. + } + + private Instant nextTimeToFire(@Nullable K key) { + long seed = (key == null) ? 
0 : key.hashCode(); + Random random = new Random(seed); + double timerFrequencySec = timerFrequency.getMillis() / 1000.0; + double nowSec = System.currentTimeMillis() / 1000.0; + + double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; + double offset = random.nextDouble() * timerFrequencySec; + + return Instant.ofEpochMilli((long) ((base + offset) * 1000)); + } + + @ProcessElement + public void processElement( + ProcessContext c, + BoundedWindow window, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer) { + + KV element = c.element(); + scheduleItem(element, window, c.timestamp()); + toProcessState.add(element); + + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext c, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer, + OutputReceiver receiver) { + + commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver); + } + + // Synchronizes local task results with the runner's persistent state container. + // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. + private void commitFinishedItems( + Instant fireTimestamp, + BagState> toProcessState, + Timer timer, + OutputReceiver receiver) { + + Iterable> toProcessLocal = toProcessState.read(); + if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) { + // Early Exit: if BagState is empty, we skip checking activeElements for this key. 
+ return; + } + + // Since fireTimestamp is key-scoped, we determine the current key from the first element in + // state + List> stateList = new ArrayList<>(); + K key = null; + for (KV element : toProcessLocal) { + stateList.add(element); + if (key == null) { + key = element.getKey(); + } + } + + if (verboseLogging) { + LOG.info("processing timer for key: {}", key); + } + + ConcurrentHashMap> activeElements = + getProcessingElements(); + Set stateIds = new HashSet<>(); + for (KV element : stateList) { + stateIds.add(idFn.apply(element.getValue())); + } + + List toCancel = new ArrayList<>(); + lock.lock(); + try { + // Cancel any active elements for this key that are no longer in runner's state + for (Map.Entry> entry : + activeElements.entrySet()) { + Object elementId = entry.getKey(); + InFlightElement inFlight = entry.getValue(); + + if (Objects.equals(inFlight.element.getKey(), key) && !stateIds.contains(elementId)) { + inFlight.future.cancel(true); + toCancel.add(elementId); + LOG.info("Cancelling item {} which is no longer in state", inFlight.element); + } + } + for (Object elementId : toCancel) { + activeElements.remove(elementId); + } + } finally { + lock.unlock(); + } + + List> toReturn = new ArrayList<>(); + List> finishedItems = new ArrayList<>(); + List> toReschedule = new ArrayList<>(); + + int itemsFinished = 0; + int itemsNotYetFinished = 0; + int itemsRescheduled = 0; + int itemsCancelled = toCancel.size(); + + lock.lock(); + try { + for (KV element : stateList) { + Object elementId = idFn.apply(element.getValue()); + if (activeElements.containsKey(elementId)) { + InFlightElement inFlight = activeElements.get(elementId); + if (inFlight.future.isDone()) { + try { + if (!inFlight.future.isCancelled()) { + toReturn.add(inFlight.future.get()); + } + finishedItems.add(element); + activeElements.remove(elementId); + itemsFinished++; + } catch (Exception e) { + LOG.error("Error executing async task for element {}", element, e); + 
finishedItems.add(element); + activeElements.remove(elementId); + } + } else { + itemsNotYetFinished++; + } + } else { + LOG.info( + "Item {} found in state but not in local active elements, scheduling now", element); + toReschedule.add(element); + itemsRescheduled++; + } + } + } finally { + lock.unlock(); + } + + // Reschedule missing elements + for (KV element : toReschedule) { + scheduleItem(element, GlobalWindow.INSTANCE, fireTimestamp); + } + + // Update State: keep only unfinished items + toProcessState.clear(); + int itemsInProcessingState = 0; + for (KV element : stateList) { + if (!finishedItems.contains(element)) { + toProcessState.add(element); + itemsInProcessingState++; + } + } + + // Emit completed outputs (Emit completed tasks immediately; do not wait for all active tasks to + // finish). + for (List outputs : toReturn) { + for (OutputT out : outputs) { + receiver.output(out); + } + } + + LOG.info( + "Items finished: {}, not yet finished: {}, rescheduled: {}, cancelled: {}, in processing state: {}", + itemsFinished, + itemsNotYetFinished, + itemsRescheduled, + itemsCancelled, + itemsInProcessingState); + + if (itemsInProcessingState > 0) { + Instant timeToFire = nextTimeToFire(key); + timer.set(timeToFire); + } + } + + // Package-private helper methods for testing direct execution without Pipeline / ProcessContext + // boilerplate + void processDirect( + KV element, + BoundedWindow window, + Instant timestamp, + BagState> toProcessState, + Timer timer) { + scheduleItem(element, window, timestamp); + toProcessState.add(element); + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + List commitFinishedItemsDirect( + Instant fireTimestamp, BagState> toProcessState, Timer timer) { + AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(); + commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); + return receiver.getOutputs(); + } + + boolean isEmpty() { + return getItemsInBuffer().get() 
== 0; + } + + int getItemsInBufferCount() { + return getItemsInBuffer().get(); + } + + static void resetState() { + lock.lock(); + try { + for (Map.Entry entry : pool.entrySet()) { + entry.getValue().shutdownNow(); + } + pool.clear(); + processingElements.clear(); + itemsInBuffer.clear(); + } finally { + lock.unlock(); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java new file mode 100644 index 000000000000..912aca3f309c --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -0,0 +1,733 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for verifying async processing structures and logic. */ +@RunWith(JUnit4.class) +public class AsyncDoFnTest implements Serializable { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + private final boolean useThreadPool = true; + + // Used for testing basic DoFn processing logic with optional latency. 
+ private static class BasicDofn extends DoFn { + private final long sleepTimeMs; + private int processed = 0; + private final ReentrantLock lock = new ReentrantLock(); + + BasicDofn(long sleepTimeMs) { + this.sleepTimeMs = sleepTimeMs; + } + + BasicDofn() { + this(0); + } + + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + if (sleepTimeMs > 0) { + try { + Thread.sleep(sleepTimeMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + lock.lock(); + try { + processed += 1; + } finally { + lock.unlock(); + } + receiver.output(element); + } + + int getProcessed() { + lock.lock(); + try { + return processed; + } finally { + lock.unlock(); + } + } + } + + // Used for testing multi element processing with optional finish bundle call. + private static class MultiElementDoFn extends DoFn { + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + receiver.output(element); + receiver.output(element); + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + c.output("bundle end", Instant.now(), GlobalWindow.INSTANCE); + } + } + + // Used for testing BagState thread safety. 
+ private static class FakeBagState implements BagState { + private final List items; + private final ReentrantLock lock = new ReentrantLock(); + + FakeBagState(List initialItems) { + this.items = new ArrayList<>(initialItems); + } + + FakeBagState(T initialItem) { + this(new ArrayList<>(List.of(initialItem))); + } + + FakeBagState() { + this(new ArrayList<>()); + } + + @Override + public void add(T item) { + lock.lock(); + try { + items.add(item); + } finally { + lock.unlock(); + } + } + + @Override + public void clear() { + lock.lock(); + try { + items.clear(); + } finally { + lock.unlock(); + } + } + + @Override + public Iterable read() { + lock.lock(); + try { + return new ArrayList<>(items); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + lock.lock(); + try { + return items.isEmpty(); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public BagState readLater() { + return this; + } + } + + // 4. Used for testing Timer mock implementations. 
+ private static class FakeTimer implements Timer { + private Instant time = Instant.EPOCH; + + @Override + public void set(Instant absoluteTime) { + this.time = absoluteTime; + } + + @Override + public void setRelative() {} + + @Override + public void clear() { + this.time = Instant.EPOCH; + } + + @Override + public Timer offset(Duration offset) { + return this; + } + + @Override + public Timer align(Duration period) { + return this; + } + + @Override + public Timer withOutputTimestamp(Instant outputTime) { + return this; + } + + @Override + public Timer withNoOutputTimestamp() { + return this; + } + + @Override + public Instant getCurrentRelativeTime() { + return time; + } + } + + @Before + public void setUp() { + AsyncDoFn.resetState(); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn) { + waitForEmpty(asyncDoFn, 10); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn, int timeoutSeconds) { + int count = 0; + while (!asyncDoFn.isEmpty()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + count += 1; + if (count > timeoutSeconds) { + throw new RuntimeException("Timed out waiting for async dofn to be empty"); + } + } + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private void checkOutput(List result, List expectedOutput) { + List resultStr = new ArrayList<>(); + for (T val : result) { + resultStr.add(val.toString()); + } + List expectedStr = new ArrayList<>(); + for (T val : expectedOutput) { + expectedStr.add(val.toString()); + } + Collections.sort(resultStr); + Collections.sort(expectedStr); + assertEquals(expectedStr, resultStr); + } + + private void checkItemsInBuffer(AsyncDoFn asyncDoFn, int expectedCount) { + assertEquals(expectedCount, asyncDoFn.getItemsInBufferCount()); + } + + // Test 1: testCustomIdFn + // Verifies key extraction custom logic. 
Duplicate elements (same custom ID but different payload) + // should be recognized as already in-flight and deduplicated. + @Test + public void testCustomIdFn() { + class CustomIdObject implements Serializable { + final int elementId; + final String value; + + CustomIdObject(int elementId, String value) { + this.elementId = elementId; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof CustomIdObject)) return false; + CustomIdObject that = (CustomIdObject) o; + return elementId == that.elementId; + } + + @Override + public int hashCode() { + return java.util.Objects.hash(elementId); + } + + @Override + public String toString() { + return "CustomIdObject{id=" + elementId + ", val=" + value + "}"; + } + } + + class CustomIdDofn extends DoFn { + @ProcessElement + public void processElement(@Element CustomIdObject element, OutputReceiver receiver) { + receiver.output(element.value); + } + } + + CustomIdDofn dofn = new CustomIdDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + null, + null, + null, + x -> x.elementId, + useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", new CustomIdObject(1, "a")); + KV msg2 = KV.of("key1", new CustomIdObject(1, "b")); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("a")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 2: testBasic + // Verifies the standard end-to-end execution flow. 
Elements should be queued in persistent state + // and output correctly upon completion. + @Test + public void testBasic() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + assertNotEquals(Instant.EPOCH, fakeTimer.getCurrentRelativeTime()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 3: testMultiKey + // Verifies key grouping isolation. Firing a timer for one partition key must not release + // or interfere with elements queued under a different partition key. 
+ @Test + public void testMultiKey() { + for (boolean useThreadPool : new boolean[] {true, false}) { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagStateKey1 = new FakeBagState<>(); + FakeBagState> fakeBagStateKey2 = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key2", "2"); + + asyncDoFn.processDirect( + msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey1, fakeTimer); + asyncDoFn.processDirect( + msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey2, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey2, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(1, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey1, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + } + } + + // Test 4: testLongItem + // Verifies that outputs are kept in-flight and not committed prematurely if the background + // execution task has not finished processing yet. 
+ @Test + public void testLongItem() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(0, dofn.getProcessed()); + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn, 20); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 5: testLostItem + // Verifies if the local worker's in-memory cache is empty but the runner's + // persistent state contains pending items. + // The wrapper must automatically detect the mismatch and reschedule execution. 
+ @Test + public void testLostItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + FakeBagState> fakeBagState = new FakeBagState<>(msg); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + } + + // Test 6: testCancelledItem + // Verifies active task cancellation. If a pending element is deleted from the runner's persistent + // state prior to a commit (e.g., due to a rollback), the background future task must be actively + // cancelled. + @Test + public void testCancelledItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key1", "2"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + fakeBagState.clear(); + fakeBagState.add(msg2); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 7: testMultiElementDofn + // Verifies support for DoFns that emit multiple outputs per element, and correctly aggregates + 
// outputs produced during the finishBundle stage of the sync DoFn's lifecycle. + @Test + public void testMultiElementDofn() { + MultiElementDoFn dofn = new MultiElementDoFn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Arrays.asList("1", "1", "bundle end")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 8: testDuplicates + // Verifies deduplication of duplicate elements under active processing. + // Identical elements should not spawn multiple concurrent background executions. + @Test + public void testDuplicates() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + fakeBagState.clear(); + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 9: testBufferCount + // Verifies accurate in-flight metrics tracking. 
+  // The item count in the buffer must increment on task scheduling
+  // and decrement immediately upon execution completion.
+  @Test
+  public void testBufferCount() {
+    BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay in ms; confirm in BasicDofn
+    AsyncDoFn asyncDoFn =
+        new AsyncDoFn<>(
+            dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+    asyncDoFn.setup(null);
+
+    KV msg = KV.of("key1", "1");
+    FakeTimer fakeTimer = new FakeTimer();
+    FakeBagState> fakeBagState = new FakeBagState<>();
+
+    asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+    checkItemsInBuffer(asyncDoFn, 1); // assumes the task has not finished yet at this point
+
+    waitForEmpty(asyncDoFn);
+    checkItemsInBuffer(asyncDoFn, 0); // count is expected to drop before any commit happens
+
+    asyncDoFn.commitFinishedItemsDirect(
+        fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+    checkItemsInBuffer(asyncDoFn, 0); // committing must not change the count again
+  }
+
+  // Test 10: testBufferStopsAcceptingItems
+  // Verifies queue boundaries and backpressure throttling.
+  // When concurrent threads push elements exceeding the capacity limit,
+  // the scheduler must block and delay submissions appropriately.
+  @Test
+  public void testBufferStopsAcceptingItems() { // ten concurrent submissions against a buffer capacity of five
+    BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay in ms; confirm in BasicDofn
+    AsyncDoFn asyncDoFn =
+        new AsyncDoFn<>(
+            dofn,
+            1,
+            Duration.standardSeconds(5),
+            5, // max buffer capacity
+            null,
+            null,
+            null,
+            useThreadPool);
+    asyncDoFn.setup(null);
+
+    FakeTimer fakeTimer = new FakeTimer();
+    FakeBagState> fakeBagState = new FakeBagState<>();
+
+    ExecutorService poolExecutor = Executors.newFixedThreadPool(10); // one submitter thread per element
+    List expectedOutput = new ArrayList<>();
+    List> futures = new ArrayList<>();
+
+    for (int i = 0; i < 10; i++) {
+      final int idx = i; // effectively-final copy for capture by the lambda
+      expectedOutput.add(String.valueOf(idx));
+      futures.add(
+          poolExecutor.submit(
+              () -> {
+                KV item = KV.of("key", String.valueOf(idx)); // all ten elements share one key
+                asyncDoFn.processDirect(
+                    item, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+              }));
+    }
+
+    try {
+      Thread.sleep(200); // NOTE(review): the exact-count assertion below depends on this wall-clock delay
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+    }
+
+    assertEquals(5, asyncDoFn.getItemsInBufferCount()); // only the capacity's worth may be in flight
+
+    waitForEmpty(asyncDoFn, 100); // second argument presumably bounds the wait; confirm in helper
+
+    // Verify that all background tasks completed successfully without throwing exceptions
+    for (Future future : futures) {
+      try {
+        future.get(); // This will re-throw any exception that occurred in the background thread
+      } catch (Exception e) {
+        throw new AssertionError("Background task failed", e);
+      }
+    }
+
+    List result =
+        asyncDoFn.commitFinishedItemsDirect(
+            fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+
+    waitForEmpty(asyncDoFn, 100);
+
+    result.addAll( // a second commit collects whatever the first one did not return
+        asyncDoFn.commitFinishedItemsDirect(
+            fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer));
+
+    checkOutput(result, expectedOutput);
+    checkItemsInBuffer(asyncDoFn, 0);
+    poolExecutor.shutdown(); // NOTE(review): skipped if an assertion above throws; consider try/finally
+  }
+
+  // Test 11: testBufferWithCancellation
+  // Verifies backpressure behavior in conjunction with element cancellation.
+  // Elements that are actively cancelled during queue throttling should be dropped cleanly from the
+  // buffer.
+  @Test
+  public void testBufferWithCancellation() { // NOTE(review): max-buffer argument is null here; confirm throttling is exercised
+    BasicDofn dofn = new BasicDofn(1000); // presumably a per-element delay in ms; confirm in BasicDofn
+    AsyncDoFn asyncDoFn =
+        new AsyncDoFn<>(
+            dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+    asyncDoFn.setup(null);
+
+    KV msg1 = KV.of("key1", "1"); // both messages share the key "key1"
+    KV msg2 = KV.of("key1", "2");
+    FakeTimer fakeTimer = new FakeTimer();
+    FakeBagState> fakeBagState = new FakeBagState<>();
+
+    asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+    asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+    checkItemsInBuffer(asyncDoFn, 2); // assumes neither task has finished yet
+
+    fakeBagState.clear(); // rewrite state while work is still in flight: only msg2 remains
+    fakeBagState.add(msg2);
+
+    List result =
+        asyncDoFn.commitFinishedItemsDirect(
+            fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+    checkOutput(result, Collections.emptyList()); // msg2 is not finished yet, so nothing is emitted
+    assertEquals(1, fakeBagState.items.size()); // the unfinished msg2 stays in state
+
+    waitForEmpty(asyncDoFn);
+
+    result =
+        asyncDoFn.commitFinishedItemsDirect(
+            fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer);
+    checkItemsInBuffer(asyncDoFn, 0);
+    checkOutput(result, Collections.singletonList("2")); // msg1's result must never be emitted
+  }
+
+  // Test 12: testResetStateConcurrentTeardown
+  // Verifies safe resource cleanup during concurrent shutdown.
+  // Resetting the global shared execution state while workers are running
+  // must complete cleanly without thread or lock deadlocks.
+  @Test
+  public void testResetStateConcurrentTeardown() { // NOTE(review): no assertion or timeout; a deadlock would hang, not fail
+    BasicDofn dofn = new BasicDofn(500); // presumably a per-element delay in ms; confirm in BasicDofn
+    AsyncDoFn asyncDoFn =
+        new AsyncDoFn<>(
+            dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool);
+    asyncDoFn.setup(null);
+
+    FakeBagState> fakeBagState = new FakeBagState<>();
+    FakeTimer fakeTimer = new FakeTimer();
+
+    asyncDoFn.processDirect(
+        KV.of("key1", "1"), GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer);
+
+    try {
+      Thread.sleep(50); // short pause, presumably so the task is running when the reset happens
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+    }
+
+    // Verify calling resetState() while background tasks are running finishes cleanly
+    AsyncDoFn.resetState(); // static call: resets state shared across AsyncDoFn instances
+  }
+}
From 218b24de97f64e3082798546cbcbe9072672ace4 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 22 May 2026 22:52:59 +0000 Subject: [PATCH 2/2] Optimize State reconciliation loop and eliminate O(N^2) complexity. Removed O(N) global activeElements scan. Fixed logic bug where duplicate elements were incorrectly marked for rescheduling. Optimized lookups by converting finishedItems from a list to a HashSet. --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 74 ++++++++----------- 1 file changed, 29 insertions(+), 45 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 31c7e6c3d78c..e499cbdf2c1e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -22,7 +22,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.UUID; @@ -99,8 +98,7 @@ public class AsyncDoFn extends DoFn, OutputT> // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse.
private static final ConcurrentHashMap pool = new ConcurrentHashMap<>(); // activeElements (processingElements) is global JVM memory (all keys) - private static final ConcurrentHashMap< - String, ConcurrentHashMap>> + private static final ConcurrentHashMap>> processingElements = new ConcurrentHashMap<>(); private static final ConcurrentHashMap itemsInBuffer = new ConcurrentHashMap<>(); @@ -108,12 +106,10 @@ public class AsyncDoFn extends DoFn, OutputT> private static final ReentrantLock lock = new ReentrantLock(); private static final boolean verboseLogging = false; - private static class InFlightElement { - final KV element; + private static class InFlightElement { final CompletableFuture> future; - InFlightElement(KV element, CompletableFuture> future) { - this.element = element; + InFlightElement(CompletableFuture> future) { this.future = future; } } @@ -212,13 +208,12 @@ private ExecutorService getThreadPool() { } @SuppressWarnings("unchecked") - private ConcurrentHashMap> getProcessingElements() { - ConcurrentHashMap> elements = processingElements.get(uuid); + private ConcurrentHashMap> getProcessingElements() { + ConcurrentHashMap> elements = processingElements.get(uuid); if (elements == null) { throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); } - return (ConcurrentHashMap>) - (ConcurrentHashMap) elements; + return (ConcurrentHashMap>) (ConcurrentHashMap) elements; } private AtomicInteger getItemsInBuffer() { @@ -298,8 +293,7 @@ private boolean scheduleIfRoom( KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { lock.lock(); try { - ConcurrentHashMap> activeElements = - getProcessingElements(); + ConcurrentHashMap> activeElements = getProcessingElements(); Object elementId = idFn.apply(element.getValue()); if (activeElements.containsKey(elementId)) { @@ -428,7 +422,7 @@ public String getErrorContext() { } }); - activeElements.put(elementId, new InFlightElement<>(element, future)); + 
activeElements.put(elementId, new InFlightElement<>(future)); getItemsInBuffer().incrementAndGet(); return true; } @@ -536,70 +530,60 @@ private void commitFinishedItems( LOG.info("processing timer for key: {}", key); } - ConcurrentHashMap> activeElements = - getProcessingElements(); - Set stateIds = new HashSet<>(); - for (KV element : stateList) { - stateIds.add(idFn.apply(element.getValue())); - } - - List toCancel = new ArrayList<>(); - lock.lock(); - try { - // Cancel any active elements for this key that are no longer in runner's state - for (Map.Entry> entry : - activeElements.entrySet()) { - Object elementId = entry.getKey(); - InFlightElement inFlight = entry.getValue(); - - if (Objects.equals(inFlight.element.getKey(), key) && !stateIds.contains(elementId)) { - inFlight.future.cancel(true); - toCancel.add(elementId); - LOG.info("Cancelling item {} which is no longer in state", inFlight.element); - } - } - for (Object elementId : toCancel) { - activeElements.remove(elementId); - } - } finally { - lock.unlock(); - } + ConcurrentHashMap> activeElements = getProcessingElements(); List> toReturn = new ArrayList<>(); - List> finishedItems = new ArrayList<>(); + Set> finishedItems = new HashSet<>(); List> toReschedule = new ArrayList<>(); int itemsFinished = 0; int itemsNotYetFinished = 0; int itemsRescheduled = 0; - int itemsCancelled = toCancel.size(); + int itemsCancelled = 0; + + Set finishedElementIds = new HashSet<>(); + Set inFlightElementIds = new HashSet<>(); + Set rescheduledElementIds = new HashSet<>(); lock.lock(); try { for (KV element : stateList) { Object elementId = idFn.apply(element.getValue()); + + // Skip processing if we already completed, rescheduled, or found this elementId active in + // this cycle + if (finishedElementIds.contains(elementId) + || rescheduledElementIds.contains(elementId) + || inFlightElementIds.contains(elementId)) { + continue; + } + if (activeElements.containsKey(elementId)) { - InFlightElement inFlight = 
activeElements.get(elementId); + InFlightElement inFlight = activeElements.get(elementId); if (inFlight.future.isDone()) { try { if (!inFlight.future.isCancelled()) { toReturn.add(inFlight.future.get()); } finishedItems.add(element); + finishedElementIds.add(elementId); activeElements.remove(elementId); itemsFinished++; } catch (Exception e) { LOG.error("Error executing async task for element {}", element, e); finishedItems.add(element); + finishedElementIds.add(elementId); activeElements.remove(elementId); } } else { + inFlightElementIds.add(elementId); itemsNotYetFinished++; } } else { LOG.info( "Item {} found in state but not in local active elements, scheduling now", element); toReschedule.add(element); + rescheduledElementIds.add(elementId); itemsRescheduled++; } }