Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 94 additions & 43 deletions cli/lms_go.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package main

import (
"context"
"encoding/json" // Added this import
"flag"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/hypernetix/lmstudio-go/pkg/lmstudio"
)
Expand Down Expand Up @@ -195,61 +198,108 @@ func truncateString(s string, maxLen int) string {
}

// loadModelWithProgress loads a model and displays a progress bar with model information
func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier string, logger lmstudio.Logger) error {
var modelInfo *lmstudio.Model
var modelDisplayed bool
var lastProgress float64 = -1

// Use the client's LoadModelWithProgress method
err := client.LoadModelWithProgress(modelIdentifier, func(progress float64, info *lmstudio.Model) {
// Display model info on first callback
if !modelDisplayed {
modelInfo = info
if modelInfo != nil {
quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format)
if modelInfo.Size > 0 {
// Extract format from model info for display
func loadModelWithProgress(client *lmstudio.LMStudioClient, loadTimeout time.Duration, modelIdentifier string, logger lmstudio.Logger) error {
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set up signal handling for Ctrl+C cancellation
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

// Channel to communicate completion or error
done := make(chan error, 1)

// Start the loading process in a goroutine
go func() {
var modelInfo *lmstudio.Model
var modelDisplayed bool
var lastProgress float64 = -1

// Use the client's LoadModelWithProgressContext method
err := client.LoadModelWithProgressContext(ctx, loadTimeout, modelIdentifier, func(progress float64, info *lmstudio.Model) {
// Display model info on first callback
if !modelDisplayed {
modelInfo = info
if modelInfo != nil {
format := modelInfo.Format
if format == "" && modelInfo.Path != "" {
if strings.Contains(modelInfo.Path, "MLX") {
format = "MLX"
} else if strings.Contains(modelInfo.Path, "GGUF") {
format = "GGUF"
if modelInfo.Size > 0 {
// Extract format from model info for display
if format == "" && modelInfo.Path != "" {
if strings.Contains(modelInfo.Path, "MLX") {
format = "MLX"
} else if strings.Contains(modelInfo.Path, "GGUF") {
format = "GGUF"
}
}
}

// Display size and format like in the screenshot
sizeStr := formatSize(modelInfo.Size)
if format != "" {
quietPrintf("Model: %s (%s)\n", sizeStr, format)
} else {
quietPrintf("Model: %s\n", sizeStr)
}
quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), format)
} else {
quietPrintf("Loading model \"%s\" ...\n", modelIdentifier)
}
modelDisplayed = true
}

// Only update progress if it increased significantly to avoid flickering
if progress > lastProgress+0.001 || progress >= 1.0 {
displayProgressBar(progress)
lastProgress = progress
}

// If model was already loaded, show completion immediately
if progress >= 1.0 {
quietPrintf("\n✓ Model loaded successfully\n")
}
})

done <- err
}()

// Wait for either completion or cancellation signal
select {
case err := <-done:
// Loading completed (successfully or with error)
if err != nil {
if strings.Contains(err.Error(), "timed out") {
quietPrintf("\n⏰ Model loading timed out\n")
} else if strings.Contains(err.Error(), "cancelled") {
quietPrintf("\n⚠ Model loading cancelled\n")
} else {
quietPrintf("Loading model \"%s\" ...\n", modelIdentifier)
quietPrintf("\nFailed to load model: %v\n", err)
}
modelDisplayed = true
return err
}
return nil

case sig := <-sigChan:
// User pressed Ctrl+C or sent termination signal
logger.Debug("Received signal: %v", sig)

// Only update progress if it increased significantly to avoid flickering
if progress > lastProgress+0.001 || progress >= 1.0 {
displayProgressBar(progress)
lastProgress = progress
// Clear progress bar if displayed
if !quietMode {
fmt.Printf("\r%s\r", strings.Repeat(" ", 80)) // Clear the line
}

// If model was already loaded, show completion immediately
if progress >= 1.0 {
quietPrintf("\n✓ Model loaded successfully\n")
quietPrintf("\n⚠ Model loading cancelled by user\n")

// Cancel the context to stop the loading operation
cancel()

// Wait a short time for graceful cancellation
select {
case <-done:
// Loading operation acknowledged the cancellation
case <-func() <-chan struct{} {
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), time.Second*2)
defer timeoutCancel()
return timeoutCtx.Done()
}():
// Timeout waiting for graceful cancellation
logger.Debug("Timeout waiting for loading cancellation")
}
})

if err != nil {
quietPrintf("\nFailed to load model: %v\n", err)
return err
return fmt.Errorf("model loading cancelled by user")
}

return nil
}

// displayProgressBar shows a progress bar similar to the screenshot
Expand Down Expand Up @@ -308,6 +358,7 @@ func main() {
jsonOutput := flag.Bool("json", false, "Output list commands in JSON format")
quiet := flag.Bool("q", false, "Quiet mode - suppress all stdout messages except JSON output and errors")
quietLong := flag.Bool("quiet", false, "Quiet mode - suppress all stdout messages except JSON output and errors")
loadTimeout := flag.Duration("timeout", 120*time.Second, "Timeout for loading a model")

// Parse command line flags
flag.Parse()
Expand Down Expand Up @@ -463,7 +514,7 @@ func main() {
// Load a model
if *loadModel != "" {
operation = true
if err := loadModelWithProgress(client, *loadModel, logger); err != nil {
if err := loadModelWithProgress(client, *loadTimeout, *loadModel, logger); err != nil {
logger.Error("Failed to load model: %v", err)
os.Exit(1)
}
Expand Down
37 changes: 25 additions & 12 deletions pkg/lmstudio/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ type ChannelHandler interface {

// namespaceConnection represents a connection to a specific LM Studio namespace
type namespaceConnection struct {
logger Logger
namespace string
conn *websocket.Conn
nextID int
pendingCalls map[int]chan json.RawMessage
activeChannels map[int]ChannelHandler // Interface for different channel types
connected bool
mu sync.Mutex
logger Logger
namespace string
conn *websocket.Conn
nextID int
pendingCalls map[int]chan json.RawMessage
pendingUnloadCalls map[int]bool // Track unload calls to avoid logging errors for their responses
activeChannels map[int]ChannelHandler // Interface for different channel types
connected bool
mu sync.Mutex
}

// connect establishes a connection to a specific LM Studio namespace
Expand Down Expand Up @@ -242,10 +243,22 @@ func (nc *namespaceConnection) handleMessages(ctx context.Context) {
if msgType == "rpcResult" || msgType == "rpcError" {
if callID, ok := msg["callId"].(float64); ok {
nc.mu.Lock()
if ch, exists := nc.pendingCalls[int(callID)]; exists {
callIDInt := int(callID)

// Check if this is a response to an unload call that we should ignore
isUnloadCall := nc.pendingUnloadCalls != nil && nc.pendingUnloadCalls[callIDInt]
if isUnloadCall {
// Clean up the pending unload call tracking
delete(nc.pendingUnloadCalls, callIDInt)
nc.logger.Debug("Received response for unload call ID %d from %s (ignoring)", callIDInt, nc.namespace)
nc.mu.Unlock()
continue
}

if ch, exists := nc.pendingCalls[callIDInt]; exists {
if msgType == "rpcResult" {
ch <- message
delete(nc.pendingCalls, int(callID))
delete(nc.pendingCalls, callIDInt)
} else if msgType == "rpcError" {
// Handle error responses
if errObj, ok := msg["error"].(map[string]interface{}); ok {
Expand All @@ -265,10 +278,10 @@ func (nc *namespaceConnection) handleMessages(ctx context.Context) {
}
}
ch <- message
delete(nc.pendingCalls, int(callID))
delete(nc.pendingCalls, callIDInt)
}
} else {
nc.logger.Error("Received response for unknown call ID %d from %s", int(callID), nc.namespace)
nc.logger.Error("Received response for unknown call ID %d from %s", callIDInt, nc.namespace)
}
nc.mu.Unlock()
continue
Expand Down
70 changes: 64 additions & 6 deletions pkg/lmstudio/lmstudio_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package lmstudio
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"sync"
Expand Down Expand Up @@ -288,8 +287,29 @@ func (c *LMStudioClient) LoadModel(modelIdentifier string) error {
return c.waitForModelLoading(channel, modelIdentifier, loadTimeout)
}

// LoadModelWithProgress loads a specified model in LM Studio with progress reporting
func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressCallback func(progress float64, modelInfo *Model)) error {
// LoadModelWithProgress loads a specified model in LM Studio with progress reporting and cancellation support
func (c *LMStudioClient) LoadModelWithProgress(
loadTimeout time.Duration,
modelIdentifier string,
progressCallback func(progress float64, modelInfo *Model),
) error {
return c.LoadModelWithProgressContext(context.Background(), loadTimeout, modelIdentifier, progressCallback)
}

// LoadModelWithProgressContext loads a specified model in LM Studio with progress reporting and cancellation support via context
func (c *LMStudioClient) LoadModelWithProgressContext(
ctx context.Context,
loadTimeout time.Duration,
modelIdentifier string,
progressCallback func(progress float64, modelInfo *Model),
) error {
// Check if context is already cancelled
select {
case <-ctx.Done():
return fmt.Errorf("model loading cancelled before start: %w", ctx.Err())
default:
}

// Get model information from downloaded models
var modelInfo *Model
downloadedModels, err := c.ListDownloadedModels()
Expand Down Expand Up @@ -338,8 +358,46 @@ func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressC
}

// Use a longer timeout for model loading - some large models can take several minutes
loadTimeout := 120 * time.Second
return c.waitForModelLoading(channel, modelIdentifier, loadTimeout)
return c.waitForModelLoadingWithContext(ctx, channel, modelIdentifier, loadTimeout)
}

// waitForModelLoadingWithContext waits for a model to finish loading with cancellation support
func (c *LMStudioClient) waitForModelLoadingWithContext(ctx context.Context, channel *ModelLoadingChannel, modelIdentifier string, loadTimeout time.Duration) error {
c.logger.Debug("Waiting for model %s to load (timeout: %d seconds)...",
modelIdentifier, int(loadTimeout.Seconds()))

// Create a context that combines the parent context with timeout
timeoutCtx, cancel := context.WithTimeout(ctx, loadTimeout)
defer cancel()

// Use the channel's context-aware wait method
result, err := channel.WaitForResultWithContext(timeoutCtx, loadTimeout)

if err != nil {
// Check if this was a timeout vs other cancellation
if timeoutCtx.Err() == context.DeadlineExceeded {
c.logger.Error("Model loading timed out after %v for %s", loadTimeout, modelIdentifier)
return fmt.Errorf("model loading timed out after %v", loadTimeout)
}

// Check if this was a cancellation from parent context
if ctx.Err() != nil {
c.logger.Debug("Model loading cancelled for %s: %v", modelIdentifier, ctx.Err())
return fmt.Errorf("model loading cancelled: %w", ctx.Err())
}

// Other error
c.logger.Error("Model loading failed for %s: %v", modelIdentifier, err)
return fmt.Errorf("model loading failed: %w", err)
}

if !result.Success {
c.logger.Error("Model %s failed to load", modelIdentifier)
return fmt.Errorf("model %s failed to load", modelIdentifier)
}

c.logger.Debug("Model %s loaded successfully with identifier: %s", modelIdentifier, result.Identifier)
return nil
}

// UnloadModel unloads a specified model in LM Studio
Expand Down Expand Up @@ -482,7 +540,7 @@ func (c *LMStudioClient) CheckStatus() (bool, error) {
// If we can connect, check if we can make a simple API call
_, err = conn.RemoteCall(ModelListDownloadedEndpoint, nil)
if err != nil {
return false, errors.New("service is running but API is not responding correctly: " + err.Error())
return false, fmt.Errorf("service is running but API is not responding correctly: %w", err)
}

// If we get here, the service is running and responding to API calls
Expand Down
Loading