Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 103 additions & 26 deletions src/OpenClaw.Shared/WebSocketClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public abstract class WebSocketClientBase : IDisposable
private bool _disposed;
private int _reconnectAttempts;
private int _reconnectLoopActive;
private long _connectionGeneration;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 };

Expand Down Expand Up @@ -111,49 +112,82 @@ public async Task ConnectAsync()
return;
}

var connectGeneration = Interlocked.Increment(ref _connectionGeneration);
ClientWebSocket? ws = null;

try
{
RaiseStatusChanged(ConnectionStatus.Connecting);
_logger.Info($"Connecting to {ClientRole}: {GatewayUrlForDisplay}");

_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
ws = new ClientWebSocket();
ws.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
_webSocket = ws;

// Set Origin header (convert ws/wss to http/https)
var uri = new Uri(_gatewayUrl);
var originScheme = uri.Scheme == "wss" ? "https" : "http";
var origin = $"{originScheme}://{uri.Host}:{uri.Port}";
_webSocket.Options.SetRequestHeader("Origin", origin);
ws.Options.SetRequestHeader("Origin", origin);

if (!string.IsNullOrEmpty(_credentials))
{
var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials);
_webSocket.Options.SetRequestHeader(
ws.Options.SetRequestHeader(
"Authorization",
$"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}");
}

await _webSocket.ConnectAsync(uri, _cts.Token);
await ws.ConnectAsync(uri, _cts.Token);
if (!IsCurrentConnection(ws, connectGeneration))
{
DisposeStaleSocket(ws);
return;
}

// Don't reset _reconnectAttempts here — TCP connect succeeding doesn't mean
// auth will succeed. Reset only after the full application-level handshake
// completes (subclass calls ResetReconnectAttempts after hello-ok).
_logger.Info($"{ClientRole} connected, waiting for challenge...");

await OnConnectedAsync();
if (!IsCurrentConnection(ws, connectGeneration))
{
DisposeStaleSocket(ws);
return;
}

_ = Task.Run(() => ListenForMessagesAsync(), _cts.Token);
_ = Task.Run(() => ListenForMessagesAsync(ws, connectGeneration), _cts.Token);
}
catch (OperationCanceledException)
{
if (ws != null)
{
DisposeStaleSocket(ws);
}
_logger.Debug($"{ClientRole} connect canceled (likely shutdown)");
}
catch (ObjectDisposedException)
{
if (ws != null)
{
DisposeStaleSocket(ws);
}
_logger.Debug($"{ClientRole} connect aborted after dispose");
}
catch (Exception ex)
{
if (ws != null && !IsCurrentConnection(ws, connectGeneration))
{
DisposeStaleSocket(ws);
_logger.Debug($"{ClientRole} stale connection failure ignored: {ex.Message}");
return;
}

if (ws != null)
{
DisposeStaleSocket(ws);
}
_logger.Error($"{ClientRole} connection failed", ex);
RaiseStatusChanged(ConnectionStatus.Error);

Expand All @@ -164,7 +198,23 @@ public async Task ConnectAsync()
}
}

private async Task ListenForMessagesAsync()
/// <summary>
/// Reports whether <paramref name="ws"/> is still the socket this client is using.
/// </summary>
/// <param name="ws">The socket instance the caller has been working with.</param>
/// <param name="generation">The connection generation value the caller captured earlier.</param>
/// <returns>
/// <c>true</c> only when the client is not disposed, the connection generation counter
/// still equals <paramref name="generation"/>, and the <c>_webSocket</c> field references
/// the very same instance as <paramref name="ws"/>; otherwise <c>false</c>.
/// </returns>
private bool IsCurrentConnection(ClientWebSocket ws, long generation)
{
    // A disposed client has no current connection at all.
    if (_disposed)
    {
        return false;
    }

    // The counter has moved on since the caller captured its value.
    if (Interlocked.Read(ref _connectionGeneration) != generation)
    {
        return false;
    }

    // Identity comparison: the field must hold this exact socket instance.
    return ReferenceEquals(_webSocket, ws);
}

/// <summary>
/// Detaches <paramref name="ws"/> from the client (if it is still the socket stored in
/// <c>_webSocket</c>) and disposes it. Never throws: a failure while disposing is logged
/// at debug level and otherwise ignored.
/// </summary>
/// <param name="ws">The socket to release.</param>
private void DisposeStaleSocket(ClientWebSocket ws)
{
    // Clear the field only if it still points at this socket. Performing the comparison
    // and the write as a single atomic step prevents a newer socket, published by another
    // thread between a separate check and assignment, from being overwritten with null.
    Interlocked.CompareExchange(ref _webSocket, null, ws);

    // slopwatch-ignore: SW003 Cleanup is best-effort for superseded sockets.
    try
    {
        ws.Dispose();
    }
    catch (Exception ex)
    {
        // Still best-effort, but leave a trace so a failing dispose is not invisible.
        _logger.Debug($"{ClientRole} dispose of stale WebSocket threw: {ex.Message}");
    }
}

private async Task ListenForMessagesAsync(ClientWebSocket ws, long connectionGeneration)
{
// Rent a pooled buffer — consistent with the SendRawAsync hot path; avoids a large
// (16–64 KB) heap allocation per connection that would otherwise land on the LOH.
Expand All @@ -173,10 +223,14 @@ private async Task ListenForMessagesAsync()

try
{
while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
while (ws.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(
var result = await ws.ReceiveAsync(
new ArraySegment<byte>(buffer, 0, ReceiveBufferSize), _cts.Token);
if (!IsCurrentConnection(ws, connectionGeneration))
{
break;
}

if (result.MessageType == WebSocketMessageType.Text)
{
Expand Down Expand Up @@ -211,50 +265,61 @@ private async Task ListenForMessagesAsync()
}
else if (result.MessageType == WebSocketMessageType.Close)
{
var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown";
var closeDesc = _webSocket.CloseStatusDescription ?? "no description";
var closeStatus = ws.CloseStatus?.ToString() ?? "unknown";
var closeDesc = ws.CloseStatusDescription ?? "no description";
_logger.Info($"Server closed connection: {closeStatus} - {closeDesc}");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
}
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.Warn("Connection closed prematurely");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
}
}
catch (OperationCanceledException) { /* Expected on shutdown/disconnect. */ }
catch (ObjectDisposedException) { /* CTS or WebSocket disposed during shutdown */ }
catch (Exception ex)
{
_logger.Error($"{ClientRole} listen error", ex);
OnError(ex);
RaiseStatusChanged(ConnectionStatus.Error);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnError(ex);
RaiseStatusChanged(ConnectionStatus.Error);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}

// Auto-reconnect if not intentionally disposed
if (!_disposed)
if (IsCurrentConnection(ws, connectionGeneration))
{
try
{
if (!_cts.Token.IsCancellationRequested && ShouldAutoReconnect())
{
await ReconnectWithBackoffAsync();
await ReconnectWithBackoffAsync(ws, connectionGeneration);
}
}
// slopwatch-ignore: SW003 Shutdown cancellation or disposal is expected and the caller already preserves the safe state.
catch (ObjectDisposedException) { /* CTS disposed during check */ }
}
}

protected async Task ReconnectWithBackoffAsync()
protected async Task ReconnectWithBackoffAsync(
ClientWebSocket? expectedSocket = null,
long expectedGeneration = 0)
{
if (Interlocked.CompareExchange(ref _reconnectLoopActive, 1, 0) != 0)
{
Expand All @@ -263,7 +328,10 @@ protected async Task ReconnectWithBackoffAsync()

try
{
while (!_disposed && !_cts.Token.IsCancellationRequested && ShouldAutoReconnect())
while (!_disposed
&& !_cts.Token.IsCancellationRequested
&& ShouldAutoReconnect()
&& IsReconnectOwner(expectedSocket, expectedGeneration))
{
var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)];
// Add 0-25% jitter to prevent thundering herd when multiple clients
Expand All @@ -276,16 +344,20 @@ protected async Task ReconnectWithBackoffAsync()

await Task.Delay(delay, _cts.Token);

if (_cts.Token.IsCancellationRequested || _disposed || !ShouldAutoReconnect())
if (_cts.Token.IsCancellationRequested
|| _disposed
|| !ShouldAutoReconnect()
|| !IsReconnectOwner(expectedSocket, expectedGeneration))
{
break;
}

// Safely dispose old socket
var oldSocket = _webSocket;
_webSocket = null;
try { oldSocket?.Dispose(); }
catch (Exception ex) { _logger.Debug($"WebSocketClientBase: Dispose of old WebSocket during reconnect threw: {ex.Message}"); }
var oldSocket = expectedSocket ?? _webSocket;
if (oldSocket != null)
{
DisposeStaleSocket(oldSocket);
}

await ConnectAsync();

Expand All @@ -308,6 +380,9 @@ protected async Task ReconnectWithBackoffAsync()
}
}

/// <summary>
/// Decides whether a reconnect loop bound to a particular socket may keep running.
/// </summary>
/// <param name="expectedSocket">
/// The socket the loop was started for, or <c>null</c> when the loop is not bound to any
/// particular socket.
/// </param>
/// <param name="expectedGeneration">
/// The connection generation captured together with <paramref name="expectedSocket"/>;
/// not consulted when <paramref name="expectedSocket"/> is <c>null</c>.
/// </param>
/// <returns>
/// <c>true</c> when no socket was supplied, or when the supplied socket and generation
/// still pass <c>IsCurrentConnection</c>; otherwise <c>false</c>.
/// </returns>
private bool IsReconnectOwner(ClientWebSocket? expectedSocket, long expectedGeneration)
{
    // With no socket to compare against there is no ownership restriction to enforce.
    if (expectedSocket is null)
    {
        return true;
    }

    return IsCurrentConnection(expectedSocket, expectedGeneration);
}

/// <summary>Send a text message over the WebSocket. Thread-safe.</summary>
protected async Task SendRawAsync(string message)
{
Expand Down Expand Up @@ -391,6 +466,8 @@ public void Dispose()

OnDisposing();

Interlocked.Increment(ref _connectionGeneration);

try { _cts.Cancel(); }
catch (Exception ex) { _logger.Debug($"{ClientRole} cts.Cancel during Dispose threw: {ex.Message}"); }

Expand Down
Loading
Loading