Skip to content

Commit ea1c381

Browse files
edomora97veluca93
authored andcommitted
Add support for ephemeral services.
Ephemeral services are services that are not fixed in the configuration file, but dynamically added as they connect. This is especially useful in a setup in which cmsWorker/cmsContestWebServer are scaled dynamically, as one might do when configuring CMS for running on cloud services.
1 parent df9a5fa commit ea1c381

9 files changed

Lines changed: 203 additions & 13 deletions

File tree

cms/conf.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2323

2424
import errno
25+
import ipaddress
2526
import json
2627
import logging
2728
import os
29+
import socket
2830
import sys
2931
from collections import namedtuple
32+
from contextlib import closing
3033

3134
from .log import set_detailed_logs
3235

@@ -44,6 +47,7 @@ class ServiceCoord(namedtuple("ServiceCoord", "name shard")):
4447
service (thus identifying it).
4548
4649
"""
50+
4751
def __repr__(self):
4852
return "%s,%d" % (self.name, self.shard)
4953

@@ -53,6 +57,75 @@ class ConfigError(Exception):
5357
pass
5458

5559

60+
class EphemeralServiceConfig:
61+
"""Configuration of an ephemeral service. An ephemeral service is a
62+
normal service whose shard is chosen depending on its address and
63+
port. The port is assigned inside a range and the address must be
64+
inside the subnet.
65+
"""
66+
EPHEMERAL_SHARD_OFFSET = 10000
67+
68+
def __init__(self, subnet, min_port, max_port):
69+
self.subnet = ipaddress.ip_network(subnet)
70+
self.min_port = min_port
71+
self.max_port = max_port
72+
if min_port > max_port:
73+
raise ConfigError("Invalid port range: [%s, %s]"
74+
% (min_port, max_port))
75+
76+
def get_shard(self, address, port):
77+
"""Get the ephemeral shard for a service given its address and port.
78+
79+
address (IPv4Address|IPv6Address): address of the service.
80+
port (int): port of the service.
81+
82+
return (int): shard of the service
83+
"""
84+
if address not in self.subnet:
85+
raise ValueError("The address is not inside the subnet")
86+
host_id = int(address) & int(self.subnet.hostmask)
87+
num_ports = self.max_port - self.min_port + 1
88+
shard = host_id * num_ports + (port - self.min_port)
89+
return shard + self.EPHEMERAL_SHARD_OFFSET
90+
91+
def get_address(self, shard):
92+
"""Get the address and port of a service given its shard.
93+
94+
shard (int): shard of the service
95+
96+
return (Address): address and port of the service
97+
"""
98+
shard -= self.EPHEMERAL_SHARD_OFFSET
99+
num_ports = self.max_port - self.min_port + 1
100+
port_offset = shard % num_ports
101+
host_id = (shard - port_offset) // num_ports
102+
103+
port = self.min_port + port_offset
104+
addr = self.subnet.network_address + host_id
105+
if addr not in self.subnet:
106+
raise ValueError("The shard is not valid")
107+
return Address(str(addr), port)
108+
109+
def find_free_port(self, address):
110+
"""Find the first open port.
111+
112+
address (IPv4Address|IPv6Address): local address to bind to
113+
"""
114+
if address.version == 4:
115+
family = socket.AF_INET
116+
else:
117+
family = socket.AF_INET6
118+
for port in range(self.min_port, self.max_port+1):
119+
with closing(socket.socket(family, socket.SOCK_STREAM)) as sock:
120+
try:
121+
sock.bind((str(address), port))
122+
return port
123+
except socket.error:
124+
continue
125+
raise ValueError("No free port found in range [%s, %s] "
126+
"for address %s" % (minport, maxport, address))
127+
128+
56129
class AsyncConfig:
57130
"""This class will contain the configuration for the
58131
services. This needs to be populated at the initilization stage.
@@ -69,6 +142,7 @@ class AsyncConfig:
69142
"""
70143
core_services = {}
71144
other_services = {}
145+
ephemeral_services = {} # type: dict[str, EphemeralServiceConfig]
72146

73147

74148
async_config = AsyncConfig()
@@ -81,6 +155,7 @@ class Config:
81155
directory for information on the meaning of the fields.
82156
83157
"""
158+
84159
def __init__(self):
85160
"""Default values for configuration, plus decide if this
86161
instance is running from the system path or from the source
@@ -274,6 +349,18 @@ def _load_unique(self, path):
274349
self.async_config.other_services[coord] = Address(*shard)
275350
del data["other_services"]
276351

352+
for service_name in data["ephemeral_services"]:
353+
if service_name.startswith("_"):
354+
continue
355+
service = data["ephemeral_services"][service_name]
356+
self.async_config.ephemeral_services[service_name] = \
357+
EphemeralServiceConfig(
358+
service["subnet"],
359+
service["min_port"],
360+
service["max_port"],
361+
)
362+
del data["ephemeral_services"]
363+
277364
# Put everything else in self.
278365
for key, value in data.items():
279366
setattr(self, key, value)

cms/io/web_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(self, listen_port, handlers, parameters, shard=0,
106106
if num_proxies_used > 0:
107107
self.wsgi_app = ProxyFix(self.wsgi_app, num_proxies_used)
108108

109+
logger.info("%s listening on '%s' at port %d",
110+
type(self).__name__, listen_address, listen_port)
109111
self.web_server = WSGIServer((listen_address, listen_port), self)
110112

111113
def __call__(self, environ, start_response):

cms/server/contest/server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from cms.io import WebService
4646
from cms.locale import get_translations
4747
from cms.server.contest.jinja2_toolbox import CWS_ENVIRONMENT
48+
from cms.util import is_shard_ephemeral
4849
from cmscommon.binary import hex_to_bin
4950
from .handlers import HANDLERS
5051
from .handlers.base import ContestListHandler
@@ -73,8 +74,12 @@ def __init__(self, shard, contest_id=None):
7374
}
7475

7576
try:
76-
listen_address = config.contest_listen_address[shard]
77-
listen_port = config.contest_listen_port[shard]
77+
if is_shard_ephemeral(shard):
78+
index = 0
79+
else:
80+
index = shard
81+
listen_address = config.contest_listen_address[index]
82+
listen_port = config.contest_listen_port[index]
7883
except IndexError:
7984
raise ConfigError("Wrong shard number for %s, or missing "
8085
"address/port configuration. Please check "

cms/service/EvaluationService.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def enqueue(self, item, priority, timestamp):
161161
item_entry = item.to_dict()
162162
del item_entry["testcase_codename"]
163163
item_entry["multiplicity"] = 1
164-
entry = {"item": item_entry, "priority": priority, "timestamp": make_timestamp(timestamp)}
164+
entry = {"item": item_entry, "priority": priority,
165+
"timestamp": make_timestamp(timestamp)}
165166
self.queue_status_cumulative[key] = entry
166167
return success
167168

@@ -197,6 +198,11 @@ def _remove_from_cumulative_status(self, queue_entry):
197198
if self.queue_status_cumulative[key]["item"]["multiplicity"] == 0:
198199
del self.queue_status_cumulative[key]
199200

201+
def add_worker(self, worker_coord):
202+
"""Add a new worker to the pool.
203+
"""
204+
self.pool.add_worker(worker_coord, ephemeral=True)
205+
200206

201207
def with_post_finish_lock(func):
202208
"""Decorator for locking on self.post_finish_lock.
@@ -379,6 +385,13 @@ def workers_status(self):
379385
"""
380386
return self.get_executor().pool.get_status()
381387

388+
@rpc_method
389+
def add_worker(self, coord):
390+
"""Register a new worker to the list of workers.
391+
"""
392+
service, shard = coord
393+
self.get_executor().add_worker(ServiceCoord(service, shard))
394+
382395
def check_workers_timeout(self):
383396
"""We ask WorkerPool for the unresponsive workers, and we put
384397
again their operations in the queue.

cms/service/Worker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import gevent.lock
3232

33+
from cms import ServiceCoord
3334
from cms.db import SessionGen, Contest, enumerate_files
3435
from cms.db.filecacher import FileCacher, TombstoneError
3536
from cms.grading import JobException
@@ -64,6 +65,13 @@ def __init__(self, shard, fake_worker_time=None):
6465

6566
self._fake_worker_time = fake_worker_time
6667

68+
self.evaluation_service = self.connect_to(
69+
ServiceCoord("EvaluationService", 0),
70+
on_connect=self.on_es_connection)
71+
72+
def on_es_connection(self, address):
73+
self.evaluation_service.add_worker(coord=self._my_coord)
74+
6775
@rpc_method
6876
def precache_files(self, contest_id):
6977
"""RPC to ask the worker to precache of files in the contest.

cms/service/workerpool.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,20 @@ def wait_for_workers(self):
140140
"""Wait until a worker might be available."""
141141
self._workers_available_event.wait()
142142

143-
def add_worker(self, worker_coord):
143+
def add_worker(self, worker_coord, ephemeral=False):
144144
"""Add a new worker to the worker pool.
145145
146146
worker_coord (ServiceCoord): the coordinates of the worker.
147+
ephemeral (bool): remove the worker from the pool after the
148+
disconnection.
147149
148150
"""
149151
shard = worker_coord.shard
150152
# Instruct GeventLibrary to connect ES to the Worker.
151153
self._worker[shard] = self._service.connect_to(
152154
worker_coord,
153-
on_connect=self.on_worker_connected)
155+
on_connect=self.on_worker_connected,
156+
on_disconnect=lambda: self.on_worker_disconnected(worker_coord, ephemeral))
154157

155158
# And we fill all data.
156159
self._operations[shard] = WorkerPool.WORKER_INACTIVE
@@ -183,6 +186,24 @@ def on_worker_connected(self, worker_coord):
183186
# so we wake up the consumers.
184187
self._workers_available_event.set()
185188

189+
def on_worker_disconnected(self, worker_coord, ephemeral):
190+
"""If the worker is ephemeral, disable and the remove the worker
191+
form the pool.
192+
"""
193+
if not ephemeral:
194+
return
195+
shard = worker_coord.shard
196+
if self._operations[shard] != WorkerPool.WORKER_DISABLED:
197+
# disable the worker and re-enqueue the lost operations
198+
lost_operations = self.disable_worker(shard)
199+
for operation in lost_operations:
200+
logger.info("Operation %s put again in the queue because "
201+
"the worker disconnected.", operation)
202+
priority, timestamp = operation.side_data
203+
self._service.enqueue(operation, priority, timestamp)
204+
del self._worker[shard]
205+
logger.info("Worker %s removed", worker_coord)
206+
186207
def acquire_worker(self, operations):
187208
"""Tries to assign an operation to an available worker. If no workers
188209
are available then this returns None, otherwise this returns

cms/util.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import argparse
2525
import itertools
26+
import ipaddress
2627
import logging
2728
import netifaces
2829
import os
@@ -35,6 +36,7 @@
3536
import gevent.socket
3637

3738
from cms import ServiceCoord, ConfigError, async_config, config
39+
from cms.conf import EphemeralServiceConfig
3840

3941

4042
logger = logging.getLogger(__name__)
@@ -136,8 +138,19 @@ def get_safe_shard(service, provided_shard):
136138
raise (ValueError): if no safe shard can be returned.
137139
138140
"""
141+
addrs = _find_local_addresses()
142+
# Try to assign an ephemeral shard first. This needs to be done before
143+
# autodetecting the shared using the ip since here we cannot detect if
144+
# the service is already running on that port.
145+
if provided_shard is None and service in config.async_config.ephemeral_services:
146+
ephemeral_config = config.async_config.ephemeral_services[service]
147+
for addr in addrs:
148+
addr = ipaddress.ip_address(addr[1])
149+
if addr in ephemeral_config.subnet:
150+
port = ephemeral_config.find_free_port(addr)
151+
shard = ephemeral_config.get_shard(addr, port)
152+
return shard
139153
if provided_shard is None:
140-
addrs = _find_local_addresses()
141154
computed_shard = _get_shard_from_addresses(service, addrs)
142155
if computed_shard is None:
143156
logger.critical("Couldn't autodetect shard number and "
@@ -157,17 +170,30 @@ def get_safe_shard(service, provided_shard):
157170
return provided_shard
158171

159172

173+
def is_shard_ephemeral(shard):
174+
"""Checks if the shard is ephemeral.
175+
176+
shard (int): the shard to check.
177+
178+
return (bool): True if the shard is ephemeral.
179+
"""
180+
return shard >= EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET
181+
182+
160183
def get_service_address(key):
161184
"""Give the Address of a ServiceCoord.
162185
163186
key (ServiceCoord): the service needed.
164187
returns (Address): listening address of key.
165188
166189
"""
190+
service, shard = key
167191
if key in async_config.core_services:
168192
return async_config.core_services[key]
169193
elif key in async_config.other_services:
170194
return async_config.other_services[key]
195+
elif service in async_config.ephemeral_services:
196+
return async_config.ephemeral_services[service].get_address(shard)
171197
else:
172198
raise KeyError("Service not found.")
173199

@@ -179,11 +205,11 @@ def get_service_shards(service):
179205
returns (int): the number of shards defined in the configuration.
180206
181207
"""
182-
for i in itertools.count():
183-
try:
184-
get_service_address(ServiceCoord(service, i))
185-
except KeyError:
186-
return i
208+
count = 0
209+
for services in (async_config.core_services, async_config.other_services):
210+
count += len([0 for s in services if s.name == service])
211+
212+
return count
187213

188214

189215
def default_argument_parser(description, cls, ask_contest=None):

0 commit comments

Comments
 (0)