Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions declib/api/decompiler_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,14 +792,14 @@ def shutdown(self) -> None:
_l.info("DecompilerClient shut down complete")

def shutdown_server(self) -> None:
"""Ask the server to tear down its decompiler interface, then disconnect.
"""Ask the server to stop, then disconnect.

Used by CLI commands like ``decompiler stop``. Regular usage should
prefer :meth:`shutdown`, which leaves the server running.
"""
if self._socket:
try:
self._send_request({"type": "shutdown_deci"})
self._send_request({"type": "shutdown_server"})
except Exception:
pass
self.shutdown()
Expand Down Expand Up @@ -1216,4 +1216,4 @@ def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVaria
callback(gvar, **kwargs)
except Exception as e:
_l.error(f"Error in global variable change callback: {e}")
return gvar
return gvar
6 changes: 5 additions & 1 deletion declib/decompilers/ghidra/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@

class GhidraDecompilerInterface(DecompilerInterface):
CACHE_TIMEOUT = 5
# GUI-backed Ghidra program APIs are not safe to enter concurrently from
# arbitrary socket worker threads. DecompilerServer routes calls through
# its main-thread dispatcher when this is enabled.
requires_main_thread_dispatch = True
_program: Optional["Program"]
flat_api: "FlatProgramAPI"

Expand Down Expand Up @@ -76,6 +80,7 @@ def __init__(
# main thread queue
self._main_thread_queue = queue.Queue()
self._results_queue = queue.Queue()
self.requires_main_thread_dispatch = not kwargs.get("headless", False)

super().__init__(
name="ghidra",
Expand Down Expand Up @@ -1430,4 +1435,3 @@ def __function_code_units(self):
[code_unit for code_unit in self.currentProgram.getListing().getCodeUnits(func.getBody(), True)]
for func in self.currentProgram.getFunctionManager().getFunctions(True)
]

24 changes: 23 additions & 1 deletion tests/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import time
import unittest
from types import MethodType
from pathlib import Path

from declib.api.decompiler_server import DecompilerServer
Expand Down Expand Up @@ -217,6 +218,27 @@ def test_error_handling(self):
# Test KeyError handling for non-existent function
with self.assertRaises(KeyError, msg="Should raise KeyError for non-existent function"):
self.client.functions[0xDEADBEEF] # Non-existent function

def test_shutdown_server_requests_server_stop(self):
"""Test that shutdown_server asks the remote server loop to stop."""
sent_requests = []
client = object.__new__(DecompilerClient)
client._socket = object()
client._event_listener_running = False

def fake_send_request(self, request):
sent_requests.append(request)

def fake_shutdown(self):
sent_requests.append({"type": "client_shutdown"})

client._send_request = MethodType(fake_send_request, client)
client.shutdown = MethodType(fake_shutdown, client)

client.shutdown_server()

self.assertEqual(sent_requests[0], {"type": "shutdown_server"})
self.assertEqual(sent_requests[1], {"type": "client_shutdown"})

def test_client_context_manager(self):
"""Test client context manager functionality"""
Expand Down Expand Up @@ -700,4 +722,4 @@ def func_hit(artifact, **kwargs):


if __name__ == "__main__":
unittest.main()
unittest.main()
63 changes: 63 additions & 0 deletions tests/test_ghidra_server_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import sys
import types
import unittest


_MISSING = object()


class FakeCodeUnit:
EOL_COMMENT = 0
PRE_COMMENT = 1
POST_COMMENT = 2
PLATE_COMMENT = 3
REPEATABLE_COMMENT = 4


def _restore_modules(saved_modules):
for name, module in saved_modules.items():
if module is _MISSING:
sys.modules.pop(name, None)
else:
sys.modules[name] = module


def _install_ghidra_import_stubs():
module_names = (
"pyghidra",
"pyghidra.core",
"jpype",
"declib.decompilers.ghidra.compat.imports",
"declib.decompilers.ghidra.interface",
)
saved_modules = {name: sys.modules.get(name, _MISSING) for name in module_names}

pyghidra_mod = types.ModuleType("pyghidra")
pyghidra_core_mod = types.ModuleType("pyghidra.core")
pyghidra_core_mod._analyze_program = lambda *args, **kwargs: None
pyghidra_core_mod._get_language = lambda *args, **kwargs: None
pyghidra_core_mod._get_compiler_spec = lambda *args, **kwargs: None
sys.modules.setdefault("pyghidra", pyghidra_mod)
sys.modules.setdefault("pyghidra.core", pyghidra_core_mod)

jpype_mod = types.ModuleType("jpype")
jpype_mod.JClass = type
sys.modules.setdefault("jpype", jpype_mod)

compat_imports_mod = types.ModuleType("declib.decompilers.ghidra.compat.imports")
compat_imports_mod.CodeUnit = FakeCodeUnit
sys.modules["declib.decompilers.ghidra.compat.imports"] = compat_imports_mod
return saved_modules


class TestGhidraServerDispatch(unittest.TestCase):
def test_gui_ghidra_requires_server_main_thread_dispatch(self):
saved_modules = _install_ghidra_import_stubs()
self.addCleanup(_restore_modules, saved_modules)
from declib.decompilers.ghidra.interface import GhidraDecompilerInterface

self.assertTrue(GhidraDecompilerInterface.requires_main_thread_dispatch)


if __name__ == "__main__":
unittest.main()