diff --git a/autoagent/environment/docker_env.py b/autoagent/environment/docker_env.py index 12f0d5d..4fdb9f3 100644 --- a/autoagent/environment/docker_env.py +++ b/autoagent/environment/docker_env.py @@ -5,6 +5,7 @@ import time import socket import json +import secrets from pathlib import Path import shutil wd = Path(__file__).parent.resolve() @@ -38,6 +39,7 @@ def __init__(self, config: Union[DockerConfig, Dict]): self.git_clone = config.git_clone self.setup_package = config.setup_package self.communication_port = config.communication_port + self.command_token = secrets.token_hex(32) self.conda_path = config.conda_path def init_container(self): @@ -100,7 +102,7 @@ def init_container(self): "-v", f"{self.local_workplace}:{self.docker_workplace}", "-w", f"{self.docker_workplace}", "-p", f"{self.communication_port}:{self.communication_port}", BASE_IMAGES, "/bin/bash", "-c", - f"python3 {self.docker_workplace}/tcp_server.py --workplace {self.workplace_name} --conda_path {self.conda_path} --port {self.communication_port}" + f"python3 {self.docker_workplace}/tcp_server.py --workplace {self.workplace_name} --conda_path {self.conda_path} --port {self.communication_port} --token {self.command_token}" ] # execute the docker command result = subprocess.run(docker_command, capture_output=True, text=True) @@ -159,7 +161,8 @@ def run_command(self, command, stream_callback=None): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect((hostname, port)) - s.sendall(command.encode()) + payload = json.dumps({"token": self.command_token, "command": command}) + s.sendall(payload.encode()) partial_line = "" while True: diff --git a/autoagent/environment/tcp_server.py b/autoagent/environment/tcp_server.py index 9ad41c8..f856e66 100644 --- a/autoagent/environment/tcp_server.py +++ b/autoagent/environment/tcp_server.py @@ -2,41 +2,71 @@ import subprocess import json import argparse +import hmac parser = argparse.ArgumentParser() parser.add_argument("--workplace", type=str, default=None) parser.add_argument("--conda_path", type=str, default=None) parser.add_argument("--port", type=int, default=None) +parser.add_argument("--token", type=str, default=None) args = parser.parse_args() + +def receive_all(conn, buffer_size=4096): + data = b"" + while True: + part = conn.recv(buffer_size) + data += part + if len(part) < buffer_size: + # 如果接收的数据小于缓冲区大小,可能已经接收完毕 + break + return data.decode() + + +def parse_command_request(raw_request: str, expected_token: str) -> str: + try: + request = json.loads(raw_request) + except json.JSONDecodeError as exc: + raise ValueError("command request must be JSON") from exc + if not isinstance(request, dict): + raise ValueError("command request must be a JSON object") + token = request.get("token") + command = request.get("command") + if not isinstance(token, str) or not hmac.compare_digest(token, expected_token): + raise PermissionError("valid command token required") + if not isinstance(command, str) or not command: + raise ValueError("command must be a non-empty string") + return command + + if __name__ == "__main__": assert args.workplace is not None, "Workplace is not specified" assert args.conda_path is not None, "Conda path is not specified" assert args.port is not None, "Port is not specified" + assert args.token is not None and args.token, "Command token is not specified" server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.bind(("0.0.0.0", args.port)) server.listen(1) print(f"Listening on port {args.port}...") - def receive_all(conn, buffer_size=4096): - data = b"" - while True: - part = conn.recv(buffer_size) - data += part - if len(part) < buffer_size: - # 如果接收的数据小于缓冲区大小,可能已经接收完毕 - break - return data.decode() - while True: conn, addr = server.accept() print(f"Connection from {addr}") while True: - # command = conn.recv(1024).decode() - command = receive_all(conn) - if not command: + raw_request = receive_all(conn) + if not raw_request: break - + try: + command = parse_command_request(raw_request, args.token) + except Exception as e: + error_response = { + "type": "final", + "status": -1, + "result": f"Rejected command request: {str(e)}" + } + conn.send(json.dumps(error_response).encode() + b"\n") + break + # Execute the command try: modified_command = f"/bin/bash -c 'source {args.conda_path}/etc/profile.d/conda.sh && conda activate autogpt && cd /{args.workplace} && {command}'" @@ -69,12 +99,4 @@ def receive_all(conn, buffer_size=4096): } conn.send(json.dumps(error_response).encode() + b"\n") - # Create a JSON response - # response = { - # "status": exit_code, - # "result": output - # } - - # # Send the JSON response - # conn.send(json.dumps(response).encode()) - conn.close() \ No newline at end of file + conn.close() diff --git a/autoagent/tcp_server.py b/autoagent/tcp_server.py index 0d11ba2..3c740f0 100644 --- a/autoagent/tcp_server.py +++ b/autoagent/tcp_server.py @@ -1,40 +1,59 @@ +import argparse +import hmac +import json import socket import subprocess -import json -import argparse parser = argparse.ArgumentParser() parser.add_argument("--workplace", type=str, default=None) +parser.add_argument("--token", type=str, default=None) args = parser.parse_args() + +def receive_all(conn, buffer_size=4096): + data = b"" + while True: + part = conn.recv(buffer_size) + data += part + if len(part) < buffer_size: + # 如果接收的数据小于缓冲区大小,可能已经接收完毕 + break + return data.decode() + + +def parse_command_request(raw_request: str, expected_token: str) -> str: + try: + request = json.loads(raw_request) + except json.JSONDecodeError as exc: + raise ValueError("command request must be JSON") from exc + if not isinstance(request, dict): + raise ValueError("command request must be a JSON object") + token = request.get("token") + command = request.get("command") + if not isinstance(token, str) or not hmac.compare_digest(token, expected_token): + raise PermissionError("valid command token required") + if not isinstance(command, str) or not command: + raise ValueError("command must be a non-empty string") + return command + + if __name__ == "__main__": assert args.workplace is not None, "Workplace is not specified" + assert args.token is not None and args.token, "Command token is not specified" server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.bind(("0.0.0.0", 12345)) server.listen(1) print("Listening on port 12345...") - def receive_all(conn, buffer_size=4096): - data = b"" - while True: - part = conn.recv(buffer_size) - data += part - if len(part) < buffer_size: - # 如果接收的数据小于缓冲区大小,可能已经接收完毕 - break - return data.decode() - while True: conn, addr = server.accept() print(f"Connection from {addr}") while True: - # command = conn.recv(1024).decode() - command = receive_all(conn) - if not command: + raw_request = receive_all(conn) + if not raw_request: break - - # Execute the command try: + command = parse_command_request(raw_request, args.token) modified_command = f"/bin/bash -c 'source /home/user/micromamba/etc/profile.d/conda.sh && conda activate autogpt && cd /{args.workplace} && {command}'" process = subprocess.Popen(modified_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) output = '' @@ -48,14 +67,14 @@ def receive_all(conn, buffer_size=4096): exit_code = process.wait() except Exception as e: exit_code = -1 - output = f"Error running command: {str(e)}" + output = f"Rejected or failed command request: {str(e)}" # Create a JSON response response = { "status": exit_code, "result": output } - + # Send the JSON response conn.send(json.dumps(response).encode()) - conn.close() \ No newline at end of file + conn.close() diff --git a/tests/test_tcp_server_auth.py b/tests/test_tcp_server_auth.py new file mode 100644 index 0000000..6621cb0 --- /dev/null +++ b/tests/test_tcp_server_auth.py @@ -0,0 +1,111 @@ +import json +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path + + +def _connect(port: int, timeout: float = 5): + deadline = time.time() + timeout + while time.time() < deadline: + try: + return socket.create_connection(("127.0.0.1", port), timeout=0.5) + except OSError: + time.sleep(0.1) + raise RuntimeError("tcp server did not start") + + +def _read_all(sock: socket.socket) -> str: + data = b"" + while True: + chunk = sock.recv(65535) + if not chunk: + break + data += chunk + return data.decode(errors="replace") + + +def test_environment_tcp_server_requires_command_token(tmp_path): + conda = tmp_path / "conda" + profile = conda / "etc" / "profile.d" + profile.mkdir(parents=True) + (profile / "conda.sh").write_text("conda(){ return 0; }\n") + workplace = tmp_path / "workplace" + workplace.mkdir() + marker = tmp_path / "unauthorized_marker" + port = 19191 + server = Path(__file__).resolve().parents[1] / "autoagent" / "environment" / "tcp_server.py" + proc = subprocess.Popen( + [ + sys.executable, + str(server), + "--workplace", + str(workplace).lstrip("/"), + "--conda_path", + str(conda), + "--port", + str(port), + "--token", + "expected-token", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + try: + sock = _connect(port) + sock.sendall(f"echo bad > {marker}".encode()) + sock.shutdown(socket.SHUT_WR) + response = _read_all(sock) + assert "valid command token required" in response or "command request must be JSON" in response + assert not marker.exists() + finally: + proc.terminate() + try: + proc.wait(timeout=2) + except subprocess.TimeoutExpired: + proc.kill() + + +def test_environment_tcp_server_accepts_valid_command_token(tmp_path): + conda = tmp_path / "conda" + profile = conda / "etc" / "profile.d" + profile.mkdir(parents=True) + (profile / "conda.sh").write_text("conda(){ return 0; }\n") + workplace = tmp_path / "workplace" + workplace.mkdir() + marker = tmp_path / "authorized_marker" + port = 19192 + server = Path(__file__).resolve().parents[1] / "autoagent" / "environment" / "tcp_server.py" + proc = subprocess.Popen( + [ + sys.executable, + str(server), + "--workplace", + str(workplace).lstrip("/"), + "--conda_path", + str(conda), + "--port", + str(port), + "--token", + "expected-token", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + try: + sock = _connect(port) + sock.sendall(json.dumps({"token": "expected-token", "command": f"echo ok > {marker}"}).encode()) + sock.shutdown(socket.SHUT_WR) + response = _read_all(sock) + assert '"status": 0' in response + assert marker.read_text().strip() == "ok" + finally: + proc.terminate() + try: + proc.wait(timeout=2) + except subprocess.TimeoutExpired: + proc.kill()