Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,13 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
_last_sync_time: float
_dead_nodes: List[_NodeDesc]

# CAS metrics tracking
_cas_metrics_lock: threading.Lock
_cas_total_attempts: int
_cas_successful_attempts: int
_cas_failed_attempts: int
_cas_start_time: float

def __init__(
self,
backend: RendezvousBackend,
Expand All @@ -432,6 +439,40 @@ def __init__(
self._last_sync_time = -1
self._dead_nodes = []

# Initialize CAS metrics lock (created once)
self._cas_metrics_lock = threading.Lock()
# Initialize CAS metrics values
self._init_cas_metrics()

def _init_cas_metrics(self) -> None:
"""Initialize CAS metrics tracking."""
with self._cas_metrics_lock:
self._cas_total_attempts = 0
self._cas_successful_attempts = 0
self._cas_failed_attempts = 0
self._cas_start_time = time.monotonic()

def get_cas_metrics(self) -> dict:
    """Return a snapshot of the CAS counters for debugging rendezvous issues.

    The counters and the clock are read while ``_cas_metrics_lock`` is held.

    Returns:
        A dict with the keys ``total_attempts``, ``successful_attempts``,
        ``failed_attempts``, ``success_rate_percent`` (0.0 when there were no
        attempts), ``elapsed_time_seconds`` (monotonic time since
        ``_cas_start_time``) and ``attempts_per_second`` (0.0 when the
        elapsed time is not positive).
    """
    with self._cas_metrics_lock:
        total = self._cas_total_attempts
        succeeded = self._cas_successful_attempts
        failed = self._cas_failed_attempts
        elapsed = time.monotonic() - self._cas_start_time

        # Both ratios guard their denominator so a fresh holder reports zeros.
        success_pct = (succeeded / total) * 100.0 if total > 0 else 0.0
        per_second = total / elapsed if elapsed > 0 else 0.0

        return {
            "total_attempts": total,
            "successful_attempts": succeeded,
            "failed_attempts": failed,
            "success_rate_percent": success_pct,
            "elapsed_time_seconds": elapsed,
            "attempts_per_second": per_second,
        }

def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
construct_and_record_rdzv_event(
name=f"{self.__class__.__name__}.{get_method_name()}",
Expand All @@ -458,9 +499,28 @@ def sync(self) -> Optional[bool]:

state_bits = pickle.dumps(self._state)

# Track CAS operation
with self._cas_metrics_lock:
self._cas_total_attempts += 1

set_response = self._backend.set_state(state_bits, self._token)
if set_response is not None:
state_bits, token, has_set = set_response

# Track CAS result
cas_failed = False
with self._cas_metrics_lock:
if has_set:
self._cas_successful_attempts += 1
else:
self._cas_failed_attempts += 1
cas_failed = True

# Add random delay on CAS failure to reduce thundering herd effect
# This spreads out retry attempts when multiple nodes compete for the same state update
# Delay is applied outside the lock to avoid blocking other threads
if cas_failed:
_delay(seconds=(0, 0.3))
else:
has_set = None

Expand Down Expand Up @@ -1322,6 +1382,9 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg)
log.info(msg)

# Reset CAS metrics for this rendezvous round
self._state_holder._init_cas_metrics()

try:
self._stop_heartbeats()

Expand Down Expand Up @@ -1362,6 +1425,17 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg, rank=rank)
log.info(msg)

# Print CAS metrics on success
cas_metrics = self._state_holder.get_cas_metrics()
if cas_metrics['total_attempts'] > 0:
log.info(
f"CAS METRICS [{self._this_node}] - "
f"Total: {cas_metrics['total_attempts']}, "
f"Success: {cas_metrics['successful_attempts']}, "
f"Failed: {cas_metrics['failed_attempts']}, "
f"Success Rate: {cas_metrics['success_rate_percent']:.1f}%"
)

# Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
# Fall back to tuple format if RendezvousInfo is not supported
if _RENDEZVOUS_INFO_AVAILABLE:
Expand Down
118 changes: 118 additions & 0 deletions src/nvidia_resiliency_ext/fault_tolerance/c10d_monkey_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python3

"""
Monkey patch for PyTorch's c10d_rendezvous_backend to add use_libuv support.

This patch modifies the _create_tcp_store function to accept and use the use_libuv
parameter from RendezvousParameters, allowing users to control whether to use
the libuv backend or the traditional socket backend for TCPStore.

Usage:
from nvidia_resiliency_ext.fault_tolerance.c10d_monkey_patch import apply_c10d_patch
apply_c10d_patch()
"""

import logging

logger = logging.getLogger(__name__)


def _patched_create_tcp_store(params: "RendezvousParameters") -> "TCPStore": # noqa: F821
"""
Patched version of _create_tcp_store that supports use_libuv parameter.

This function is identical to the original _create_tcp_store except it
extracts and uses the use_libuv parameter from RendezvousParameters.
"""
import os
from datetime import timedelta
from typing import cast

from torch.distributed import TCPStore
from torch.distributed.elastic.events import NodeState, construct_and_record_rdzv_event
from torch.distributed.elastic.rendezvous.api import RendezvousConnectionError
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import (
DEFAULT_PORT,
_matches_machine_hostname,
parse_rendezvous_endpoint,
)

host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)

cfg_is_host = params.get_as_bool("is_host")
# If the user has explicitly specified whether our process should host the
# the store, respect it.
if cfg_is_host is not None:
is_host = cfg_is_host
# Otherwise try to determine whether we are the host based on our hostname
# and IP address.
else:
is_host = _matches_machine_hostname(host)

# The timeout
read_timeout = cast(int, params.get_as_int("read_timeout", 60))
if read_timeout <= 0:
raise ValueError("The read timeout must be a positive integer.")

# The use_libuv parameter - NEW FUNCTIONALITY
use_libuv = params.get_as_bool("use_libuv", True)

# In specific cases we attempt to instantiate the store twice. For details
# see the explanation in the except clause below.
for is_server in [is_host, False]:
try:
store = TCPStore(
host,
port,
is_master=is_server,
multi_tenant=True,
timeout=timedelta(seconds=read_timeout),
use_libuv=use_libuv, # NEW PARAMETER
)

if is_server:
msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
construct_and_record_rdzv_event(
run_id=params.run_id, message=msg, node_state=NodeState.INIT
)
logger.info(msg)

break
except (ValueError, RuntimeError, TimeoutError) as exc:
# If we heuristically inferred the value of is_host as True and our
# first attempt to instantiate the TCP store has failed, try it one
# more time with is_host set to False. As an edge case there can be
# more than one process that is part of the same rendezvous on this
# machine and only one of them will eventually host the store.

if not is_server or cfg_is_host is not None:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc

return store # type: ignore[possibly-undefined]


def apply_c10d_patch():
    """
    Install the use_libuv-aware store factory into c10d_rendezvous_backend.

    Rebinds ``_create_tcp_store`` on the
    ``torch.distributed.elastic.rendezvous.c10d_rendezvous_backend`` module to
    ``_patched_create_tcp_store``.

    Raises:
        ImportError: If the backend module cannot be imported.
        Exception: Any other failure while patching; it is logged and re-raised.
    """
    try:
        from torch.distributed.elastic.rendezvous import (
            c10d_rendezvous_backend as backend_module,
        )

        # Swap the module-level factory for the patched one.
        setattr(backend_module, "_create_tcp_store", _patched_create_tcp_store)

        logger.info(
            "Successfully applied c10d_rendezvous_backend monkey patch for use_libuv support"
        )
    except Exception as err:
        # Import problems get their own message; every failure still propagates.
        if isinstance(err, ImportError):
            logger.error(f"Failed to import c10d_rendezvous_backend: {err}")
        else:
            logger.error(f"Failed to apply c10d monkey patch: {err}")
        raise
4 changes: 4 additions & 0 deletions src/nvidia_resiliency_ext/fault_tolerance/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def _register_ft_rdzv_handler():
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend

from ._ft_rendezvous import FtRendezvousHandler, create_handler
from .c10d_monkey_patch import apply_c10d_patch

# Apply monkey patch to add use_libuv support to c10d backend
apply_c10d_patch()

def _create_ft_rdzv_handler(params: RendezvousParameters) -> FtRendezvousHandler:
backend, store = create_backend(params)
Expand Down
Loading