diff --git a/declib/api/decompiler_client.py b/declib/api/decompiler_client.py index 6dd0b8a..3099b4a 100644 --- a/declib/api/decompiler_client.py +++ b/declib/api/decompiler_client.py @@ -792,14 +792,14 @@ def shutdown(self) -> None: _l.info("DecompilerClient shut down complete") def shutdown_server(self) -> None: - """Ask the server to tear down its decompiler interface, then disconnect. + """Ask the server to stop, then disconnect. Used by CLI commands like ``decompiler stop``. Regular usage should prefer :meth:`shutdown`, which leaves the server running. """ if self._socket: try: - self._send_request({"type": "shutdown_deci"}) + self._send_request({"type": "shutdown_server"}) except Exception: pass self.shutdown() @@ -1216,4 +1216,4 @@ def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVaria callback(gvar, **kwargs) except Exception as e: _l.error(f"Error in global variable change callback: {e}") - return gvar \ No newline at end of file + return gvar diff --git a/declib/decompilers/ghidra/interface.py b/declib/decompilers/ghidra/interface.py index 4e54684..588bce3 100644 --- a/declib/decompilers/ghidra/interface.py +++ b/declib/decompilers/ghidra/interface.py @@ -34,6 +34,10 @@ class GhidraDecompilerInterface(DecompilerInterface): CACHE_TIMEOUT = 5 + # GUI-backed Ghidra program APIs are not safe to enter concurrently from + # arbitrary socket worker threads. DecompilerServer routes calls through + # its main-thread dispatcher when this is enabled. + requires_main_thread_dispatch = True _program: Optional["Program"] flat_api: "FlatProgramAPI" @@ -76,6 +80,7 @@ def __init__( # main thread queue self._main_thread_queue = queue.Queue() self._results_queue = queue.Queue() + self.requires_main_thread_dispatch = not kwargs.get("headless", False) super().__init__( name="ghidra", @@ -1430,4 +1435,3 @@ def __function_code_units(self): [code_unit for code_unit in self.currentProgram.getListing().getCodeUnits(func.getBody(), True)] for func in self.currentProgram.getFunctionManager().getFunctions(True) ] - diff --git a/tests/test_client_server.py b/tests/test_client_server.py index f6849d2..23ee261 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -5,6 +5,7 @@ import threading import time import unittest +from types import MethodType from pathlib import Path from declib.api.decompiler_server import DecompilerServer @@ -217,6 +218,27 @@ def test_error_handling(self): # Test KeyError handling for non-existent function with self.assertRaises(KeyError, msg="Should raise KeyError for non-existent function"): self.client.functions[0xDEADBEEF] # Non-existent function + + def test_shutdown_server_requests_server_stop(self): + """Test that shutdown_server asks the remote server loop to stop.""" + sent_requests = [] + client = object.__new__(DecompilerClient) + client._socket = object() + client._event_listener_running = False + + def fake_send_request(self, request): + sent_requests.append(request) + + def fake_shutdown(self): + sent_requests.append({"type": "client_shutdown"}) + + client._send_request = MethodType(fake_send_request, client) + client.shutdown = MethodType(fake_shutdown, client) + + client.shutdown_server() + + self.assertEqual(sent_requests[0], {"type": "shutdown_server"}) + self.assertEqual(sent_requests[1], {"type": "client_shutdown"}) def test_client_context_manager(self): """Test client context manager functionality""" @@ -700,4 +722,4 @@ def func_hit(artifact, **kwargs): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_ghidra_server_dispatch.py b/tests/test_ghidra_server_dispatch.py new file mode 100644 index 0000000..57006ba --- /dev/null +++ b/tests/test_ghidra_server_dispatch.py @@ -0,0 +1,63 @@ +import sys +import types +import unittest + + +_MISSING = object() + + +class FakeCodeUnit: + EOL_COMMENT = 0 + PRE_COMMENT = 1 + POST_COMMENT = 2 + PLATE_COMMENT = 3 + REPEATABLE_COMMENT = 4 + + +def _restore_modules(saved_modules): + for name, module in saved_modules.items(): + if module is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = module + + +def _install_ghidra_import_stubs(): + module_names = ( + "pyghidra", + "pyghidra.core", + "jpype", + "declib.decompilers.ghidra.compat.imports", + "declib.decompilers.ghidra.interface", + ) + saved_modules = {name: sys.modules.get(name, _MISSING) for name in module_names} + + pyghidra_mod = types.ModuleType("pyghidra") + pyghidra_core_mod = types.ModuleType("pyghidra.core") + pyghidra_core_mod._analyze_program = lambda *args, **kwargs: None + pyghidra_core_mod._get_language = lambda *args, **kwargs: None + pyghidra_core_mod._get_compiler_spec = lambda *args, **kwargs: None + sys.modules.setdefault("pyghidra", pyghidra_mod) + sys.modules.setdefault("pyghidra.core", pyghidra_core_mod) + + jpype_mod = types.ModuleType("jpype") + jpype_mod.JClass = type + sys.modules.setdefault("jpype", jpype_mod) + + compat_imports_mod = types.ModuleType("declib.decompilers.ghidra.compat.imports") + compat_imports_mod.CodeUnit = FakeCodeUnit + sys.modules["declib.decompilers.ghidra.compat.imports"] = compat_imports_mod + return saved_modules + + +class TestGhidraServerDispatch(unittest.TestCase): + def test_gui_ghidra_requires_server_main_thread_dispatch(self): + saved_modules = _install_ghidra_import_stubs() + self.addCleanup(_restore_modules, saved_modules) + from declib.decompilers.ghidra.interface import GhidraDecompilerInterface + + self.assertTrue(GhidraDecompilerInterface.requires_main_thread_dispatch) + + +if __name__ == "__main__": + unittest.main()