Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions autoagent/environment/docker_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import socket
import json
import secrets
from pathlib import Path
import shutil
wd = Path(__file__).parent.resolve()
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, config: Union[DockerConfig, Dict]):
self.git_clone = config.git_clone
self.setup_package = config.setup_package
self.communication_port = config.communication_port
self.command_token = secrets.token_hex(32)
self.conda_path = config.conda_path

def init_container(self):
Expand Down Expand Up @@ -100,7 +102,7 @@ def init_container(self):
"-v", f"{self.local_workplace}:{self.docker_workplace}",
"-w", f"{self.docker_workplace}", "-p", f"{self.communication_port}:{self.communication_port}", BASE_IMAGES,
"/bin/bash", "-c",
f"python3 {self.docker_workplace}/tcp_server.py --workplace {self.workplace_name} --conda_path {self.conda_path} --port {self.communication_port}"
f"python3 {self.docker_workplace}/tcp_server.py --workplace {self.workplace_name} --conda_path {self.conda_path} --port {self.communication_port} --token {self.command_token}"
]
# execute the docker command
result = subprocess.run(docker_command, capture_output=True, text=True)
Expand Down Expand Up @@ -159,7 +161,8 @@ def run_command(self, command, stream_callback=None):

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((hostname, port))
s.sendall(command.encode())
payload = json.dumps({"token": self.command_token, "command": command})
s.sendall(payload.encode())

partial_line = ""
while True:
Expand Down
68 changes: 45 additions & 23 deletions autoagent/environment/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,71 @@
import subprocess
import json
import argparse
import hmac

parser = argparse.ArgumentParser()
parser.add_argument("--workplace", type=str, default=None)
parser.add_argument("--conda_path", type=str, default=None)
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--token", type=str, default=None)
args = parser.parse_args()


def receive_all(conn, buffer_size=4096):
data = b""
while True:
part = conn.recv(buffer_size)
data += part
if len(part) < buffer_size:
# 如果接收的数据小于缓冲区大小,可能已经接收完毕
break
return data.decode()


def parse_command_request(raw_request: str, expected_token: str) -> str:
try:
request = json.loads(raw_request)
except json.JSONDecodeError as exc:
raise ValueError("command request must be JSON") from exc
if not isinstance(request, dict):
raise ValueError("command request must be a JSON object")
token = request.get("token")
command = request.get("command")
if not isinstance(token, str) or not hmac.compare_digest(token, expected_token):
raise PermissionError("valid command token required")
if not isinstance(command, str) or not command:
raise ValueError("command must be a non-empty string")
return command


if __name__ == "__main__":
assert args.workplace is not None, "Workplace is not specified"
assert args.conda_path is not None, "Conda path is not specified"
assert args.port is not None, "Port is not specified"
assert args.token is not None and args.token, "Command token is not specified"
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("0.0.0.0", args.port))
server.listen(1)

print(f"Listening on port {args.port}...")
def receive_all(conn, buffer_size=4096):
data = b""
while True:
part = conn.recv(buffer_size)
data += part
if len(part) < buffer_size:
# 如果接收的数据小于缓冲区大小,可能已经接收完毕
break
return data.decode()

while True:
conn, addr = server.accept()
print(f"Connection from {addr}")
while True:
# command = conn.recv(1024).decode()
command = receive_all(conn)
if not command:
raw_request = receive_all(conn)
if not raw_request:
break

try:
command = parse_command_request(raw_request, args.token)
except Exception as e:
error_response = {
"type": "final",
"status": -1,
"result": f"Rejected command request: {str(e)}"
}
conn.send(json.dumps(error_response).encode() + b"\n")
break

# Execute the command
try:
modified_command = f"/bin/bash -c 'source {args.conda_path}/etc/profile.d/conda.sh && conda activate autogpt && cd /{args.workplace} && {command}'"
Expand Down Expand Up @@ -69,12 +99,4 @@ def receive_all(conn, buffer_size=4096):
}
conn.send(json.dumps(error_response).encode() + b"\n")

# Create a JSON response
# response = {
# "status": exit_code,
# "result": output
# }

# # Send the JSON response
# conn.send(json.dumps(response).encode())
conn.close()
conn.close()
59 changes: 39 additions & 20 deletions autoagent/tcp_server.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,59 @@
import argparse
import hmac
import json
import socket
import subprocess
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workplace", type=str, default=None)
parser.add_argument("--token", type=str, default=None)
args = parser.parse_args()


def receive_all(conn, buffer_size=4096):
data = b""
while True:
part = conn.recv(buffer_size)
data += part
if len(part) < buffer_size:
# 如果接收的数据小于缓冲区大小,可能已经接收完毕
break
return data.decode()


def parse_command_request(raw_request: str, expected_token: str) -> str:
try:
request = json.loads(raw_request)
except json.JSONDecodeError as exc:
raise ValueError("command request must be JSON") from exc
if not isinstance(request, dict):
raise ValueError("command request must be a JSON object")
token = request.get("token")
command = request.get("command")
if not isinstance(token, str) or not hmac.compare_digest(token, expected_token):
raise PermissionError("valid command token required")
if not isinstance(command, str) or not command:
raise ValueError("command must be a non-empty string")
return command


if __name__ == "__main__":
assert args.workplace is not None, "Workplace is not specified"
assert args.token is not None and args.token, "Command token is not specified"
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("0.0.0.0", 12345))
server.listen(1)

print("Listening on port 12345...")
def receive_all(conn, buffer_size=4096):
data = b""
while True:
part = conn.recv(buffer_size)
data += part
if len(part) < buffer_size:
# 如果接收的数据小于缓冲区大小,可能已经接收完毕
break
return data.decode()

while True:
conn, addr = server.accept()
print(f"Connection from {addr}")
while True:
# command = conn.recv(1024).decode()
command = receive_all(conn)
if not command:
raw_request = receive_all(conn)
if not raw_request:
break

# Execute the command
try:
command = parse_command_request(raw_request, args.token)
modified_command = f"/bin/bash -c 'source /home/user/micromamba/etc/profile.d/conda.sh && conda activate autogpt && cd /{args.workplace} && {command}'"
process = subprocess.Popen(modified_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
output = ''
Expand All @@ -48,14 +67,14 @@ def receive_all(conn, buffer_size=4096):
exit_code = process.wait()
except Exception as e:
exit_code = -1
output = f"Error running command: {str(e)}"
output = f"Rejected or failed command request: {str(e)}"

# Create a JSON response
response = {
"status": exit_code,
"result": output
}

# Send the JSON response
conn.send(json.dumps(response).encode())
conn.close()
conn.close()
111 changes: 111 additions & 0 deletions tests/test_tcp_server_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import json
import socket
import subprocess
import sys
import tempfile
import time
from pathlib import Path


def _connect(port: int, timeout: float = 5):
deadline = time.time() + timeout
while time.time() < deadline:
try:
return socket.create_connection(("127.0.0.1", port), timeout=0.5)
except OSError:
time.sleep(0.1)
raise RuntimeError("tcp server did not start")


def _read_all(sock: socket.socket) -> str:
data = b""
while True:
chunk = sock.recv(65535)
if not chunk:
break
data += chunk
return data.decode(errors="replace")


def test_environment_tcp_server_requires_command_token(tmp_path):
conda = tmp_path / "conda"
profile = conda / "etc" / "profile.d"
profile.mkdir(parents=True)
(profile / "conda.sh").write_text("conda(){ return 0; }\n")
workplace = tmp_path / "workplace"
workplace.mkdir()
marker = tmp_path / "unauthorized_marker"
port = 19191
server = Path(__file__).resolve().parents[1] / "autoagent" / "environment" / "tcp_server.py"
proc = subprocess.Popen(
[
sys.executable,
str(server),
"--workplace",
str(workplace).lstrip("/"),
"--conda_path",
str(conda),
"--port",
str(port),
"--token",
"expected-token",
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
try:
sock = _connect(port)
sock.sendall(f"echo bad > {marker}".encode())
sock.shutdown(socket.SHUT_WR)
response = _read_all(sock)
assert "valid command token required" in response or "command request must be JSON" in response
assert not marker.exists()
finally:
proc.terminate()
try:
proc.wait(timeout=2)
except subprocess.TimeoutExpired:
proc.kill()


def test_environment_tcp_server_accepts_valid_command_token(tmp_path):
conda = tmp_path / "conda"
profile = conda / "etc" / "profile.d"
profile.mkdir(parents=True)
(profile / "conda.sh").write_text("conda(){ return 0; }\n")
workplace = tmp_path / "workplace"
workplace.mkdir()
marker = tmp_path / "authorized_marker"
port = 19192
server = Path(__file__).resolve().parents[1] / "autoagent" / "environment" / "tcp_server.py"
proc = subprocess.Popen(
[
sys.executable,
str(server),
"--workplace",
str(workplace).lstrip("/"),
"--conda_path",
str(conda),
"--port",
str(port),
"--token",
"expected-token",
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
try:
sock = _connect(port)
sock.sendall(json.dumps({"token": "expected-token", "command": f"echo ok > {marker}"}).encode())
sock.shutdown(socket.SHUT_WR)
response = _read_all(sock)
assert '"status": 0' in response
assert marker.read_text().strip() == "ok"
finally:
proc.terminate()
try:
proc.wait(timeout=2)
except subprocess.TimeoutExpired:
proc.kill()