diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index 0b09b364362f..26a66ab24d9d 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -72,7 +72,7 @@ class _SharedCache: def __init__(self, constructor, destructor): self._constructor = constructor self._destructor = destructor - self._live_owners = set() + self._live_owners = {} self._cache = {} self._lock = threading.Lock() self._counter = 0 @@ -82,10 +82,10 @@ def _next_id(self): self._counter += 1 return self._counter - def register(self): + def register(self, is_context=False): with self._lock: owner = self._next_id() - self._live_owners.add(owner) + self._live_owners[owner] = is_context return owner def purge(self, owner): @@ -97,7 +97,7 @@ def purge(self, owner): "shutdown, the subprocess was already cleaned up earlier.", owner) return - self._live_owners.remove(owner) + del self._live_owners[owner] for key, entry in list(self._cache.items()): if owner in entry.owners: entry.owners.remove(owner) @@ -108,14 +108,22 @@ def purge(self, owner): for value in to_delete: self._destructor(value) - def get(self, *key): + def get(self, *key, owner=None): if not self._live_owners: raise RuntimeError("At least one owner must be registered.") with self._lock: if key not in self._cache: self._cache[key] = _SharedCacheEntry(self._constructor(*key), set()) - for owner in self._live_owners: + if owner is not None: + if owner not in self._live_owners: + raise RuntimeError("The requesting owner must be registered.") self._cache[key].owners.add(owner) + for live_owner, is_context in self._live_owners.items(): + if is_context: + self._cache[key].owners.add(live_owner) + else: + for live_owner in self._live_owners: + self._cache[key].owners.add(live_owner) return self._cache[key].obj def force_remove(self, *key): @@ -180,7 +188,7 @@ def cache_subprocesses(cls): These subprocesses may be shared with other contexts as well. """ try: - unique_id = cls._cache.register() + unique_id = cls._cache.register(is_context=True) yield finally: cls._cache.purge(unique_id) @@ -214,7 +222,7 @@ def start(self): channel_ready = grpc.channel_ready_future(self._grpc_channel) while True: if process is not None and process.poll() is not None: - _LOGGER.error("Started job service with %s", process.args) + _LOGGER.error("Failed to start job service with %s", process.args) raise RuntimeError( 'Service failed to start up with error %s' % process.poll()) try: @@ -235,15 +243,16 @@ def start(self): def start_process(self): if self._owner_id is not None: self._cache.purge(self._owner_id) - self._owner_id = self._cache.register() - return self._cache.get(tuple(self._cmd), self._port, self._logger) + self._owner_id = self._cache.register(is_context=False) + return self._cache.get( + tuple(self._cmd), self._port, self._logger, owner=self._owner_id) def _really_start_process(cmd, port, logger): if not port: port, = pick_port(None) cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable endpoint = 'localhost:%s' % port - _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'")) + _LOGGER.warning("Really starting service at %s with cmd: %s", endpoint, cmd) process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) @@ -295,9 +304,11 @@ def stop_force(self): self._grpc_channel = None def _really_stop_process(process_and_endpoint): - process, _ = process_and_endpoint # pylint: disable=unpacking-non-sequence + process, endpoint = process_and_endpoint # pylint: disable=unpacking-non-sequence if not process: return + _LOGGER.warning( + "Really destroying service at %s with cmd: %s", endpoint, process.args) for _ in range(5): if process.poll() is not None: break diff --git a/sdks/python/apache_beam/utils/subprocess_server_test.py b/sdks/python/apache_beam/utils/subprocess_server_test.py index a44b89b17e37..a008ae05c52d 100644 --- a/sdks/python/apache_beam/utils/subprocess_server_test.py +++ b/sdks/python/apache_beam/utils/subprocess_server_test.py @@ -402,16 +402,16 @@ def mock_unregister(cb): self.assertEqual(len(registered_callbacks), 1) def test_concurrent_purge_race_condition(self): - # Concurrent threads attempting to check memebership and call purge for the same owner. - # Here we explicitly define a synchronized set to mimic the behavior of _live_owners. - # This set will block two threads on __contains__, allowing us to test the race condition. + # Concurrent threads attempting to check membership and call purge for the same owner. + # Here we explicitly define a synchronized dict to mimic the behavior of _live_owners. + # This dict will block two threads on __contains__, allowing us to test the race condition. cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None) owner = cache.register() barrier = threading.Barrier(2) exceptions = [] - class SynchronizedSet(set): + class SynchronizedDict(dict): def __contains__(self, item): res = super().__contains__(item) try: @@ -421,7 +421,7 @@ def __contains__(self, item): pass return res - cache._live_owners = SynchronizedSet(cache._live_owners) + cache._live_owners = SynchronizedDict(cache._live_owners) def purge_worker(): try: @@ -551,6 +551,53 @@ def __init__(self): # Clean up the other owner cache.purge(other_owner) + def test_non_context_owners_do_not_share_keys(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + # owner1 is a non-context owner (e.g., prism) + owner1 = cache.register(is_context=False) + a = cache.get('a', owner=owner1) + + # owner2 is another non-context owner (e.g., short-lived expansion service) + owner2 = cache.register(is_context=False) + b = cache.get('b', owner=owner2) + + # Verify that owner1 does not own 'b' + self.assertNotIn(owner1, cache._cache[('b', )].owners) + + # Verify that owner2 does not own 'a' + self.assertNotIn(owner2, cache._cache[('a', )].owners) + + # Purging owner2 should immediately destroy/remove 'b' + cache.purge(owner2) + self.assertNotIn(('b', ), cache._cache) + + # 'a' is still alive because owner1 is still registered + self.assertIn(('a', ), cache._cache) + + # Purging owner1 should destroy/remove 'a' + cache.purge(owner1) + self.assertNotIn(('a', ), cache._cache) + + def test_context_owner_owns_all_keys(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + # owner1 is a non-context owner (e.g., prism) + owner1 = cache.register(is_context=False) + + # owner2 is a context owner (e.g., cache_subprocesses) + owner2 = cache.register(is_context=True) + + # owner3 is another non-context owner (e.g., short-lived service) + owner3 = cache.register(is_context=False) + + # owner3 requests 'b' + b = cache.get('b', owner=owner3) + + # owner2 (context) should own 'b' + self.assertIn(owner2, cache._cache[('b', )].owners) + + # owner1 (non-context) should NOT own 'b' + self.assertNotIn(owner1, cache._cache[('b', )].owners) + if __name__ == '__main__': unittest.main()