diff --git a/cli/lms_go.go b/cli/lms_go.go index 5e837ad..362df04 100644 --- a/cli/lms_go.go +++ b/cli/lms_go.go @@ -1,12 +1,15 @@ package main import ( + "context" "encoding/json" // Added this import "flag" "fmt" "os" "os/signal" "strings" + "syscall" + "time" "github.com/hypernetix/lmstudio-go/pkg/lmstudio" ) @@ -195,61 +198,108 @@ func truncateString(s string, maxLen int) string { } // loadModelWithProgress loads a model and displays a progress bar with model information -func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier string, logger lmstudio.Logger) error { - var modelInfo *lmstudio.Model - var modelDisplayed bool - var lastProgress float64 = -1 - - // Use the client's LoadModelWithProgress method - err := client.LoadModelWithProgress(modelIdentifier, func(progress float64, info *lmstudio.Model) { - // Display model info on first callback - if !modelDisplayed { - modelInfo = info - if modelInfo != nil { - quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format) - if modelInfo.Size > 0 { - // Extract format from model info for display +func loadModelWithProgress(client *lmstudio.LMStudioClient, loadTimeout time.Duration, modelIdentifier string, logger lmstudio.Logger) error { + // Create a context that can be cancelled + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling for Ctrl+C cancellation + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + // Channel to communicate completion or error + done := make(chan error, 1) + + // Start the loading process in a goroutine + go func() { + var modelInfo *lmstudio.Model + var modelDisplayed bool + var lastProgress float64 = -1 + + // Use the client's LoadModelWithProgressContext method + err := client.LoadModelWithProgressContext(ctx, loadTimeout, modelIdentifier, func(progress float64, info *lmstudio.Model) { + // Display model info on first callback + if !modelDisplayed { + modelInfo = info + if modelInfo != nil { format := modelInfo.Format - if format == "" && modelInfo.Path != "" { - if strings.Contains(modelInfo.Path, "MLX") { - format = "MLX" - } else if strings.Contains(modelInfo.Path, "GGUF") { - format = "GGUF" + if modelInfo.Size > 0 { + // Extract format from model info for display + if format == "" && modelInfo.Path != "" { + if strings.Contains(modelInfo.Path, "MLX") { + format = "MLX" + } else if strings.Contains(modelInfo.Path, "GGUF") { + format = "GGUF" + } } } - - // Display size and format like in the screenshot - sizeStr := formatSize(modelInfo.Size) - if format != "" { - quietPrintf("Model: %s (%s)\n", sizeStr, format) - } else { - quietPrintf("Model: %s\n", sizeStr) - } + quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), format) + } else { + quietPrintf("Loading model \"%s\" ...\n", modelIdentifier) } + modelDisplayed = true + } + + // Only update progress if it increased significantly to avoid flickering + if progress > lastProgress+0.001 || progress >= 1.0 { + displayProgressBar(progress) + lastProgress = progress + } + + // If model was already loaded, show completion immediately + if progress >= 1.0 { + quietPrintf("\n✓ Model loaded successfully\n") + } + }) + + done <- err + }() + + // Wait for either completion or cancellation signal + select { + case err := <-done: + // Loading completed (successfully or with error) + if err != nil { + if strings.Contains(err.Error(), "timed out") { + quietPrintf("\n⏰ Model loading timed out\n") + } else if strings.Contains(err.Error(), "cancelled") { + quietPrintf("\n⚠ Model loading cancelled\n") } else { - quietPrintf("Loading model \"%s\" ...\n", modelIdentifier) + quietPrintf("\nFailed to load model: %v\n", err) } - modelDisplayed = true + return err } + return nil + + case sig := <-sigChan: + // User pressed Ctrl+C or sent termination signal + logger.Debug("Received signal: %v", sig) - // Only update progress if it increased significantly to avoid flickering - if progress > lastProgress+0.001 || progress >= 1.0 { - displayProgressBar(progress) - lastProgress = progress + // Clear progress bar if displayed + if !quietMode { + fmt.Printf("\r%s\r", strings.Repeat(" ", 80)) // Clear the line } - // If model was already loaded, show completion immediately - if progress >= 1.0 { - quietPrintf("\n✓ Model loaded successfully\n") + quietPrintf("\n⚠ Model loading cancelled by user\n") + + // Cancel the context to stop the loading operation + cancel() + + // Wait a short time for graceful cancellation + select { + case <-done: + // Loading operation acknowledged the cancellation + case <-func() <-chan struct{} { + timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), time.Second*2) + defer timeoutCancel() + return timeoutCtx.Done() + }(): + // Timeout waiting for graceful cancellation + logger.Debug("Timeout waiting for loading cancellation") } - }) - if err != nil { - quietPrintf("\nFailed to load model: %v\n", err) - return err + return fmt.Errorf("model loading cancelled by user") } - - return nil } // displayProgressBar shows a progress bar similar to the screenshot @@ -308,6 +358,7 @@ func main() { jsonOutput := flag.Bool("json", false, "Output list commands in JSON format") quiet := flag.Bool("q", false, "Quiet mode - suppress all stdout messages except JSON output and errors") quietLong := flag.Bool("quiet", false, "Quiet mode - suppress all stdout messages except JSON output and errors") + loadTimeout := flag.Duration("timeout", 120*time.Second, "Timeout for loading a model") // Parse command line flags flag.Parse() @@ -463,7 +514,7 @@ func main() { // Load a model if *loadModel != "" { operation = true - if err := loadModelWithProgress(client, *loadModel, logger); err != nil { + if err := loadModelWithProgress(client, *loadTimeout, *loadModel, logger); err != nil { logger.Error("Failed to load model: %v", err) os.Exit(1) } diff --git a/pkg/lmstudio/conn.go b/pkg/lmstudio/conn.go index ba54bb5..d4f2f7e 100644 --- a/pkg/lmstudio/conn.go +++ b/pkg/lmstudio/conn.go @@ -23,14 +23,15 @@ type ChannelHandler interface { // namespaceConnection represents a connection to a specific LM Studio namespace type namespaceConnection struct { - logger Logger - namespace string - conn *websocket.Conn - nextID int - pendingCalls map[int]chan json.RawMessage - activeChannels map[int]ChannelHandler // Interface for different channel types - connected bool - mu sync.Mutex + logger Logger + namespace string + conn *websocket.Conn + nextID int + pendingCalls map[int]chan json.RawMessage + pendingUnloadCalls map[int]bool // Track unload calls to avoid logging errors for their responses + activeChannels map[int]ChannelHandler // Interface for different channel types + connected bool + mu sync.Mutex } // connect establishes a connection to a specific LM Studio namespace @@ -242,10 +243,22 @@ func (nc *namespaceConnection) handleMessages(ctx context.Context) { if msgType == "rpcResult" || msgType == "rpcError" { if callID, ok := msg["callId"].(float64); ok { nc.mu.Lock() - if ch, exists := nc.pendingCalls[int(callID)]; exists { + callIDInt := int(callID) + + // Check if this is a response to an unload call that we should ignore + isUnloadCall := nc.pendingUnloadCalls != nil && nc.pendingUnloadCalls[callIDInt] + if isUnloadCall { + // Clean up the pending unload call tracking + delete(nc.pendingUnloadCalls, callIDInt) + nc.logger.Debug("Received response for unload call ID %d from %s (ignoring)", callIDInt, nc.namespace) + nc.mu.Unlock() + continue + } + + if ch, exists := nc.pendingCalls[callIDInt]; exists { if msgType == "rpcResult" { ch <- message - delete(nc.pendingCalls, int(callID)) + delete(nc.pendingCalls, callIDInt) } else if msgType == "rpcError" { // Handle error responses if errObj, ok := msg["error"].(map[string]interface{}); ok { @@ -265,10 +278,10 @@ func (nc *namespaceConnection) handleMessages(ctx context.Context) { } } ch <- message - delete(nc.pendingCalls, int(callID)) + delete(nc.pendingCalls, callIDInt) } } else { - nc.logger.Error("Received response for unknown call ID %d from %s", int(callID), nc.namespace) + nc.logger.Error("Received response for unknown call ID %d from %s", callIDInt, nc.namespace) } nc.mu.Unlock() continue diff --git a/pkg/lmstudio/lmstudio_client.go b/pkg/lmstudio/lmstudio_client.go index 9f2c78f..606f236 100644 --- a/pkg/lmstudio/lmstudio_client.go +++ b/pkg/lmstudio/lmstudio_client.go @@ -3,7 +3,6 @@ package lmstudio import ( "context" "encoding/json" - "errors" "fmt" "math/rand" "sync" @@ -288,8 +287,29 @@ func (c *LMStudioClient) LoadModel(modelIdentifier string) error { return c.waitForModelLoading(channel, modelIdentifier, loadTimeout) } -// LoadModelWithProgress loads a specified model in LM Studio with progress reporting -func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressCallback func(progress float64, modelInfo *Model)) error { +// LoadModelWithProgress loads a specified model in LM Studio with progress reporting and cancellation support +func (c *LMStudioClient) LoadModelWithProgress( + loadTimeout time.Duration, + modelIdentifier string, + progressCallback func(progress float64, modelInfo *Model), +) error { + return c.LoadModelWithProgressContext(context.Background(), loadTimeout, modelIdentifier, progressCallback) +} + +// LoadModelWithProgressContext loads a specified model in LM Studio with progress reporting and cancellation support via context +func (c *LMStudioClient) LoadModelWithProgressContext( + ctx context.Context, + loadTimeout time.Duration, + modelIdentifier string, + progressCallback func(progress float64, modelInfo *Model), +) error { + // Check if context is already cancelled + select { + case <-ctx.Done(): + return fmt.Errorf("model loading cancelled before start: %w", ctx.Err()) + default: + } + // Get model information from downloaded models var modelInfo *Model downloadedModels, err := c.ListDownloadedModels() @@ -338,8 +358,46 @@ func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressC } // Use a longer timeout for model loading - some large models can take several minutes - loadTimeout := 120 * time.Second - return c.waitForModelLoading(channel, modelIdentifier, loadTimeout) + return c.waitForModelLoadingWithContext(ctx, channel, modelIdentifier, loadTimeout) +} + +// waitForModelLoadingWithContext waits for a model to finish loading with cancellation support +func (c *LMStudioClient) waitForModelLoadingWithContext(ctx context.Context, channel *ModelLoadingChannel, modelIdentifier string, loadTimeout time.Duration) error { + c.logger.Debug("Waiting for model %s to load (timeout: %d seconds)...", + modelIdentifier, int(loadTimeout.Seconds())) + + // Create a context that combines the parent context with timeout + timeoutCtx, cancel := context.WithTimeout(ctx, loadTimeout) + defer cancel() + + // Use the channel's context-aware wait method + result, err := channel.WaitForResultWithContext(timeoutCtx, loadTimeout) + + if err != nil { + // Check if this was a timeout vs other cancellation + if timeoutCtx.Err() == context.DeadlineExceeded { + c.logger.Error("Model loading timed out after %v for %s", loadTimeout, modelIdentifier) + return fmt.Errorf("model loading timed out after %v", loadTimeout) + } + + // Check if this was a cancellation from parent context + if ctx.Err() != nil { + c.logger.Debug("Model loading cancelled for %s: %v", modelIdentifier, ctx.Err()) + return fmt.Errorf("model loading cancelled: %w", ctx.Err()) + } + + // Other error + c.logger.Error("Model loading failed for %s: %v", modelIdentifier, err) + return fmt.Errorf("model loading failed: %w", err) + } + + if !result.Success { + c.logger.Error("Model %s failed to load", modelIdentifier) + return fmt.Errorf("model %s failed to load", modelIdentifier) + } + + c.logger.Debug("Model %s loaded successfully with identifier: %s", modelIdentifier, result.Identifier) + return nil } // UnloadModel unloads a specified model in LM Studio @@ -482,7 +540,7 @@ func (c *LMStudioClient) CheckStatus() (bool, error) { // If we can connect, check if we can make a simple API call _, err = conn.RemoteCall(ModelListDownloadedEndpoint, nil) if err != nil { - return false, errors.New("service is running but API is not responding correctly: " + err.Error()) + return false, fmt.Errorf("service is running but API is not responding correctly: %w", err) } // If we get here, the service is running and responding to API calls diff --git a/pkg/lmstudio/lmstudio_client_test.go b/pkg/lmstudio/lmstudio_client_test.go index 8d9843f..6fcc73f 100644 --- a/pkg/lmstudio/lmstudio_client_test.go +++ b/pkg/lmstudio/lmstudio_client_test.go @@ -1,6 +1,7 @@ package lmstudio import ( + "context" "fmt" "net/url" "strings" @@ -387,43 +388,31 @@ func TestLoadModelWithProgress(t *testing.T) { logger.Debug("Progress callback: %.2f%%, model: %v", progress*100, modelInfo != nil) } - // Call the method we're testing - err = client.LoadModelWithProgress("mock-model-0.5B", progressCallback) + // Call the method + err = client.LoadModelWithProgress(30*time.Second, "mock-model-0.5B", progressCallback) if err != nil { t.Fatalf("LoadModelWithProgress failed: %v", err) } - // Verify that progress callbacks were called + // Wait a moment to ensure all progress callbacks are processed + time.Sleep(100 * time.Millisecond) + + // Verify we got the expected number of progress callbacks callbackMutex.Lock() if len(progressCallbacks) == 0 { - t.Errorf("Expected progress callbacks, got none") + t.Errorf("Expected at least one progress callback, got %d", len(progressCallbacks)) } - // Verify progress values are reasonable (between 0 and 1) + // Check that progress values are in valid range for i, progress := range progressCallbacks { - if progress < 0 || progress > 1 { - t.Errorf("Progress callback %d has invalid value: %f (should be between 0 and 1)", i, progress) + if progress < 0.0 || progress > 1.0 { + t.Errorf("Progress callback %d has invalid value: %f (should be 0.0-1.0)", i, progress) } } - // Verify we got at least one progress update - if len(progressCallbacks) < 1 { - t.Errorf("Expected at least 1 progress callback, got %d", len(progressCallbacks)) - } - - // Verify model info was provided when available - hasModelInfo := false - for _, modelInfo := range modelInfoCallbacks { - if modelInfo != nil { - hasModelInfo = true - if modelInfo.ModelKey == "" { - t.Errorf("Expected model info to have ModelKey") - } - break - } - } - if !hasModelInfo { - t.Errorf("Expected at least one callback with model info") + // Check that the last progress value is 1.0 (completion) + if len(progressCallbacks) > 0 && progressCallbacks[len(progressCallbacks)-1] != 1.0 { + t.Errorf("Last progress value should be 1.0, got %f", progressCallbacks[len(progressCallbacks)-1]) } callbackMutex.Unlock() @@ -439,7 +428,7 @@ func TestLoadModelWithProgress(t *testing.T) { } } -// TestLoadModelWithProgressAlreadyLoaded tests LoadModelWithProgress when model is already loaded +// TestLoadModelWithProgressAlreadyLoaded tests LoadModelWithProgress with an already loaded model func TestLoadModelWithProgressAlreadyLoaded(t *testing.T) { fmt.Println("[TEST] TestLoadModelWithProgressAlreadyLoaded started") defer fmt.Println("[TEST] TestLoadModelWithProgressAlreadyLoaded finished or failed") @@ -459,21 +448,19 @@ func TestLoadModelWithProgressAlreadyLoaded(t *testing.T) { client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) defer client.Close() - // Track progress callbacks + // Track progress callbacks - for already loaded model we should get one callback with 100% var progressCallbacks []float64 - var modelInfoCallbacks []*Model var callbackMutex sync.Mutex progressCallback := func(progress float64, modelInfo *Model) { callbackMutex.Lock() defer callbackMutex.Unlock() progressCallbacks = append(progressCallbacks, progress) - modelInfoCallbacks = append(modelInfoCallbacks, modelInfo) - logger.Debug("Progress callback for already loaded: %.2f%%, model: %v", progress*100, modelInfo != nil) + logger.Debug("Progress callback: %.2f%%", progress*100) } - // Use a model that appears in the loaded models list (mock-model-7B) - err = client.LoadModelWithProgress("mock-model-7B", progressCallback) + // Call with a model that is already loaded (mock-model-7B) + err = client.LoadModelWithProgress(30*time.Second, "mock-model-7B", progressCallback) if err != nil { t.Fatalf("LoadModelWithProgress failed for already loaded model: %v", err) } @@ -511,8 +498,226 @@ func TestLoadModelWithProgressNilCallback(t *testing.T) { defer client.Close() // Call the method with nil callback (should not crash) - err = client.LoadModelWithProgress("mock-model-0.5B", nil) + err = client.LoadModelWithProgress(30*time.Second, "mock-model-0.5B", nil) if err != nil { t.Fatalf("LoadModelWithProgress with nil callback failed: %v", err) } } + +// TestLoadModelWithProgressContextCancellation tests LoadModelWithProgressContext with cancellation +func TestLoadModelWithProgressContextCancellation(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressContextCancellation started") + defer fmt.Println("[TEST] TestLoadModelWithProgressContextCancellation finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Create a context that will be cancelled + ctx, cancel := context.WithCancel(context.Background()) + + // Track progress callbacks + var progressCallbacks []float64 + var callbackMutex sync.Mutex + + progressCallback := func(progress float64, modelInfo *Model) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + progressCallbacks = append(progressCallbacks, progress) + logger.Debug("Progress callback before cancellation: %.2f%%", progress*100) + + // Cancel the context after receiving first progress update + if len(progressCallbacks) == 1 { + logger.Debug("Cancelling context after first progress update") + cancel() + } + } + + // Call the method with cancellation context + err = client.LoadModelWithProgressContext(ctx, 10*time.Second, "mock-model-0.5B", progressCallback) + + // Should receive a cancellation error + if err == nil { + t.Fatalf("Expected cancellation error, got nil") + } + + if !strings.Contains(err.Error(), "cancelled") { + t.Errorf("Expected cancellation error, got: %v", err) + } + + // Verify we got at least one progress callback before cancellation + callbackMutex.Lock() + if len(progressCallbacks) == 0 { + t.Errorf("Expected at least one progress callback before cancellation") + } + callbackMutex.Unlock() +} + +// TestLoadModelWithProgressContextTimeout tests LoadModelWithProgressContext with timeout +func TestLoadModelWithProgressContextTimeout(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressContextTimeout started") + defer fmt.Println("[TEST] TestLoadModelWithProgressContextTimeout finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Create a context that is already cancelled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel it immediately + + // Call the method with cancelled context + err = client.LoadModelWithProgressContext(ctx, 30*time.Second, "mock-model-0.5B", nil) + + // Should receive a cancellation error + if err == nil { + t.Fatalf("Expected cancellation error, got nil") + } + + if !strings.Contains(err.Error(), "cancelled") { + t.Errorf("Expected cancellation error, got: %v", err) + } +} + +// TestLoadModelWithProgressCancellation tests model load cancellation during loading +func TestLoadModelWithProgressCancellation(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressCancellation started") + defer fmt.Println("[TEST] TestLoadModelWithProgressCancellation finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Create a context that we'll cancel after first progress update + ctx, cancel := context.WithCancel(context.Background()) + + // Track progress callbacks + var progressCallbacks []float64 + var callbackMutex sync.Mutex + var cancelled bool + + progressCallback := func(progress float64, modelInfo *Model) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + progressCallbacks = append(progressCallbacks, progress) + logger.Debug("Progress callback: %.2f%%", progress*100) + + // Cancel after receiving the first progress update + if len(progressCallbacks) == 1 && !cancelled { + logger.Debug("Cancelling context after first progress update") + cancelled = true + cancel() + } + } + + // Call the method with cancellation context + err = client.LoadModelWithProgressContext(ctx, 30*time.Second, "mock-model-0.5B", progressCallback) + + // Should receive a cancellation error + if err == nil { + t.Fatalf("Expected cancellation error, got nil") + } + + if !strings.Contains(err.Error(), "cancelled") { + t.Errorf("Expected cancellation error, got: %v", err) + } + + // Verify we got at least one progress callback before cancellation + callbackMutex.Lock() + if len(progressCallbacks) == 0 { + t.Errorf("Expected at least one progress callback before cancellation") + } + callbackMutex.Unlock() +} + +// TestLoadModelWithProgressShortTimeout tests LoadModelWithProgress with a very short 1ms timeout +func TestLoadModelWithProgressShortTimeout(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressShortTimeout started") + defer fmt.Println("[TEST] TestLoadModelWithProgressShortTimeout finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Track progress callbacks + var progressCallbacks []float64 + var callbackMutex sync.Mutex + + progressCallback := func(progress float64, modelInfo *Model) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + progressCallbacks = append(progressCallbacks, progress) + logger.Debug("Progress callback: %.2f%%", progress*100) + } + + // Create a context that's already expired (deadline in the past) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Hour)) + defer cancel() + + // Call the method with expired context - this should fail immediately + err = client.LoadModelWithProgressContext(ctx, 30*time.Second, "mock-model-0.5B", progressCallback) + + // Should receive a cancellation error + if err == nil { + t.Fatalf("Expected timeout/cancellation error, got nil") + } + + if !strings.Contains(err.Error(), "cancelled") && !strings.Contains(err.Error(), "deadline exceeded") { + t.Errorf("Expected timeout or cancellation error, got: %v", err) + } + + // Verify that the operation was cancelled quickly - should have no progress callbacks + callbackMutex.Lock() + progressCount := len(progressCallbacks) + callbackMutex.Unlock() + + logger.Debug("Received %d progress callbacks with expired context", progressCount) + // With expired context, we should get 0 callbacks + if progressCount > 0 { + t.Errorf("Expected no progress callbacks due to expired context, got %d", progressCount) + } +} diff --git a/pkg/lmstudio/model_loading.go b/pkg/lmstudio/model_loading.go index af661d1..dd4853d 100644 --- a/pkg/lmstudio/model_loading.go +++ b/pkg/lmstudio/model_loading.go @@ -173,11 +173,14 @@ func (ch *ModelLoadingChannel) processMessage(message []byte) { // Model resolution event - log but don't take any action ch.conn.logger.Trace("Channel %d: model resolved", ch.channelID) case "success": - // Model loaded successfully - ch.conn.logger.Trace("Channel %d: success message received", ch.channelID) + // Model loading completed successfully if info, ok := messageContent["info"].(map[string]interface{}); ok { if identifier, ok := info["identifier"].(string); ok { ch.conn.logger.Debug("Channel %d: model loaded with identifier %s", ch.channelID, identifier) + + // Ensure progress reaches 100% when model loading completes + ch.forceUpdateProgress(1.0) + ch.mu.Lock() ch.isFinished = true ch.mu.Unlock() @@ -207,6 +210,9 @@ func (ch *ModelLoadingChannel) processMessage(message []byte) { ch.conn.logger.Debug("Channel %d success message received", ch.channelID) if content, ok := msg["content"].(map[string]interface{}); ok { if identifier, ok := content["identifier"].(string); ok { + // Ensure progress reaches 100% when model loading completes + ch.forceUpdateProgress(1.0) + ch.mu.Lock() ch.isFinished = true ch.mu.Unlock() @@ -271,13 +277,41 @@ func (ch *ModelLoadingChannel) updateProgress(progress float64) { } } +// forceUpdateProgress updates progress and always calls the callback, used for completion +func (ch *ModelLoadingChannel) forceUpdateProgress(progress float64) { + ch.mu.Lock() + defer ch.mu.Unlock() + + ch.lastProgress = progress + + // Always call progress function if provided, regardless of previous progress + if ch.progressFn != nil { + ch.progressFn(progress) + } +} + // WaitForResult waits for the model loading to complete func (ch *ModelLoadingChannel) WaitForResult(timeout time.Duration) (*ModelLoadingResult, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + return ch.WaitForResultWithContext(context.Background(), timeout) +} + +// WaitForResultWithContext waits for the model loading to complete with cancellation support +func (ch *ModelLoadingChannel) WaitForResultWithContext(ctx context.Context, timeout time.Duration) (*ModelLoadingResult, error) { + // Check if context is already cancelled + select { + case <-ctx.Done(): + ch.sendCancellationWithCleanup() + ch.conn.logger.Debug("Model loading cancelled before wait: %v", ctx.Err()) + return nil, fmt.Errorf("model loading cancelled: %w", ctx.Err()) + default: + } + + // Create a timeout context that respects the parent context + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - // Log occasional updates during waiting - ticker := time.NewTicker(5 * time.Second) + // Log occasional updates during waiting (every 2 seconds instead of 5 for better responsiveness) + ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() for { @@ -290,26 +324,122 @@ func (ch *ModelLoadingChannel) WaitForResult(timeout time.Duration) (*ModelLoadi ch.conn.logger.Error("Received error for channel %d (model: %s): %v", ch.channelID, ch.modelKey, err) return nil, err + case <-ctx.Done(): + // Parent context cancelled - immediate cancellation + ch.conn.logger.Debug("Parent context cancelled for channel %d (model: %s) - sending immediate cancellation", ch.channelID, ch.modelKey) + ch.sendCancellationWithCleanup() + return nil, fmt.Errorf("model loading cancelled: %w", ctx.Err()) + + case <-timeoutCtx.Done(): + if timeoutCtx.Err() == context.DeadlineExceeded { + ch.conn.logger.Debug("Model loading timed out for channel %d after %v", ch.channelID, timeout) + ch.sendCancellationWithCleanup() + return nil, fmt.Errorf("model loading timed out after %v", timeout) + } + // This should not happen since we handle parent context cancellation above + ch.conn.logger.Debug("Timeout context cancelled for channel %d: %v", ch.channelID, timeoutCtx.Err()) + ch.sendCancellationWithCleanup() + return nil, fmt.Errorf("model loading cancelled: %w", timeoutCtx.Err()) + case <-ticker.C: - // Periodically check for messages from server - ch.conn.logger.Trace("Channel %d (model: %s) waiting for messages... (%.1f%% done)", - ch.channelID, ch.modelKey, ch.lastProgress*100) + // Check context more frequently and log connection status + select { + case <-ctx.Done(): + ch.conn.logger.Debug("Parent context cancelled during ticker for channel %d", ch.channelID) + ch.sendCancellationWithCleanup() + return nil, fmt.Errorf("model loading cancelled: %w", ctx.Err()) + default: + } - // Check connection status + // Check connection status periodically ch.conn.mu.Lock() isConnected := ch.conn.connected ch.conn.mu.Unlock() if !isConnected { + ch.conn.logger.Debug("Connection lost while waiting for model to load") return nil, fmt.Errorf("connection lost while waiting for model to load") } - - case <-ctx.Done(): - return nil, fmt.Errorf("model loading timed out after %v", timeout) } } } +// sendCancellationWithCleanup sends cancellation and performs thorough cleanup +func (ch *ModelLoadingChannel) sendCancellationWithCleanup() { + ch.mu.Lock() + defer ch.mu.Unlock() + + if ch.isFinished { + return // Already finished + } + + ch.isFinished = true + + ch.conn.logger.Debug("Attempting to cancel model loading for channel %d (model: %s)", ch.channelID, ch.modelKey) + + // First, try to unload the model that's being loaded to force-stop the loading process + // This is more aggressive than just closing the channel and ensures proper cancellation + unloadCallID := ch.channelID + 10000 // Use a unique call ID + unloadMsg := map[string]interface{}{ + "type": "rpcCall", + "callId": unloadCallID, + "endpoint": "unloadModel", + "parameter": map[string]interface{}{ + "identifier": ch.modelKey, + }, + } + + ch.conn.logger.Debug("Sending unload request to force-stop loading for model: %s (call ID: %d)", ch.modelKey, unloadCallID) + + // Track this call ID so we don't log errors for the response + ch.conn.mu.Lock() + if ch.conn.pendingUnloadCalls == nil { + ch.conn.pendingUnloadCalls = make(map[int]bool) + } + ch.conn.pendingUnloadCalls[unloadCallID] = true + ch.conn.mu.Unlock() + + // Send the unload request (best effort, don't wait for response) + err := ch.conn.conn.WriteJSON(unloadMsg) + if err != nil { + ch.conn.logger.Debug("Failed to send unload request for model %s: %v", ch.modelKey, err) + } else { + ch.conn.logger.Debug("Sent unload request for model %s", ch.modelKey) + } + + // Small delay to allow the unload request to be processed + time.Sleep(100 * time.Millisecond) + + // Now send channel close message to cleanup the channel + closeMsg := map[string]interface{}{ + "type": "channelClose", + "channelId": ch.channelID, + } + + ch.conn.logger.Debug("Sending channel close for channel %d", ch.channelID) + + // Send the close message + err = ch.conn.conn.WriteJSON(closeMsg) + if err != nil { + ch.conn.logger.Debug("Failed to send channel close message for channel %d: %v", ch.channelID, err) + } else { + ch.conn.logger.Debug("Sent channel close message for channel %d", ch.channelID) + } + + // Clean up the channel from active channels + ch.conn.mu.Lock() + delete(ch.conn.activeChannels, ch.channelID) + ch.conn.mu.Unlock() + + // Signal the message handler to stop if not already closed + select { + case <-ch.cancelCh: + // Already closed + default: + close(ch.cancelCh) + } +} + // Close closes the model loading channel func (ch *ModelLoadingChannel) Close() error { ch.mu.Lock()