diff --git a/go.mod b/go.mod index d9fa347..56453b9 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/schollz/progressbar/v3 v3.17.1 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 - github.com/tmc/langchaingo v0.1.13 golang.org/x/net v0.52.0 modernc.org/sqlite v1.29.10 ) @@ -17,7 +16,6 @@ require ( require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect @@ -35,8 +33,6 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pkg/errors v0.9.1 // indirect - github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/go.sum b/go.sum index 1670f4b..c577dae 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= -github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -62,10 +60,6 @@ github.com/pdfcpu/pdfcpu v0.10.0 h1:K7Hv2tW/poMvG+DuowVGfYvNcN1Y7USS+8IebA3Z8+w= github.com/pdfcpu/pdfcpu v0.10.0/go.mod h1:Q2Z3sqdRqHTdIq1mPAUl8nfAoim8p3c1ASOaQ10mCpE= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= -github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -107,8 +101,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= -github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/internal/llm/azure.go b/internal/llm/azure.go index 66c8516..069ca73 100644 --- a/internal/llm/azure.go +++ b/internal/llm/azure.go @@ -1,47 +1,33 @@ package llm import ( + "bytes" "context" + "encoding/json" "fmt" + "io" + "net/http" "github.com/RandomCodeSpace/docsgraphcontext/internal/config" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/openai" ) type azureProvider struct { - chatLLM *openai.LLM - embedLLM *openai.LLM + endpoint string + apiKey string + apiVersion string chatModel string embedModel string + client *http.Client } func newAzureProvider(cfg *config.LLMConfig) (Provider, error) { - chatLLM, err := openai.New( - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.Azure.Endpoint), - openai.WithToken(cfg.Azure.APIKey), - openai.WithAPIVersion(cfg.Azure.APIVersion), - openai.WithModel(cfg.Azure.ChatModel), - ) - if err != nil { - return nil, fmt.Errorf("azure chat llm: %w", err) - } - embedLLM, err := openai.New( - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.Azure.Endpoint), - openai.WithToken(cfg.Azure.APIKey), - openai.WithAPIVersion(cfg.Azure.APIVersion), - openai.WithEmbeddingModel(cfg.Azure.EmbedModel), - ) - if err != nil { - return nil, fmt.Errorf("azure embed llm: %w", err) - } return &azureProvider{ - chatLLM: chatLLM, - embedLLM: embedLLM, + endpoint: cfg.Azure.Endpoint, + apiKey: cfg.Azure.APIKey, + apiVersion: cfg.Azure.APIVersion, chatModel: cfg.Azure.ChatModel, embedModel: cfg.Azure.EmbedModel, + client: &http.Client{}, }, nil } @@ -50,24 +36,65 @@ func (p *azureProvider) ModelID() string { return p.chatModel } func (p *azureProvider) Complete(ctx context.Context, prompt string, opts ...Option) (string, error) { o := applyOptions(opts) - callOpts := []llms.CallOption{ - llms.WithMaxTokens(o.maxTokens), - llms.WithTemperature(o.temperature), + + type message struct { + Role string `json:"role"` + Content string `json:"content"` + } + reqBody := map[string]any{ + "messages": []message{{Role: "user", Content: prompt}}, + "max_tokens": o.maxTokens, + "temperature": o.temperature, } if o.jsonMode { - callOpts = append(callOpts, llms.WithJSONMode()) + reqBody["response_format"] = map[string]string{"type": "json_object"} + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("azure complete marshal: %w", err) + } + + url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", + p.endpoint, p.chatModel, p.apiVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("azure complete request: %w", err) } - resp, err := llms.GenerateFromSinglePrompt(ctx, p.chatLLM, prompt, callOpts...) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", p.apiKey) + + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("azure complete: %w", err) } - return resp, nil + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("azure complete HTTP %d: %s", resp.StatusCode, b) + } + + var result struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("azure complete decode: %w", err) + } + if len(result.Choices) == 0 { + return "", fmt.Errorf("azure complete: empty response") + } + return result.Choices[0].Message.Content, nil } func (p *azureProvider) Embed(ctx context.Context, text string) ([]float32, error) { - vecs, err := p.embedLLM.CreateEmbedding(ctx, []string{text}) + vecs, err := p.EmbedBatch(ctx, []string{text}) if err != nil { - return nil, fmt.Errorf("azure embed: %w", err) + return nil, err } if len(vecs) == 0 { return nil, fmt.Errorf("azure embed: empty response") @@ -76,5 +103,44 @@ func (p *azureProvider) Embed(ctx context.Context, text string) ([]float32, erro } func (p *azureProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { - return p.embedLLM.CreateEmbedding(ctx, texts) + reqBody := map[string]any{"input": texts} + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("azure embed marshal: %w", err) + } + + url := fmt.Sprintf("%s/openai/deployments/%s/embeddings?api-version=%s", + p.endpoint, p.embedModel, p.apiVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("azure embed request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", p.apiKey) + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("azure embed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("azure embed HTTP %d: %s", resp.StatusCode, b) + } + + var result struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("azure embed decode: %w", err) + } + + vecs := make([][]float32, len(result.Data)) + for i, d := range result.Data { + vecs[i] = d.Embedding + } + return vecs, nil } diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index 8b7ad67..907e437 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -1,41 +1,29 @@ package llm import ( + "bytes" "context" + "encoding/json" "fmt" + "io" + "net/http" "github.com/RandomCodeSpace/docsgraphcontext/internal/config" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ollama" ) type ollamaProvider struct { - chatLLM *ollama.LLM - embedLLM *ollama.LLM + baseURL string chatModel string embedModel string + client *http.Client } func newOllamaProvider(cfg *config.LLMConfig) (Provider, error) { - chatLLM, err := ollama.New( - ollama.WithServerURL(cfg.Ollama.BaseURL), - ollama.WithModel(cfg.Ollama.ChatModel), - ) - if err != nil { - return nil, fmt.Errorf("ollama chat: %w", err) - } - embedLLM, err := ollama.New( - ollama.WithServerURL(cfg.Ollama.BaseURL), - ollama.WithModel(cfg.Ollama.EmbedModel), - ) - if err != nil { - return nil, fmt.Errorf("ollama embed: %w", err) - } return &ollamaProvider{ - chatLLM: chatLLM, - embedLLM: embedLLM, + baseURL: cfg.Ollama.BaseURL, chatModel: cfg.Ollama.ChatModel, embedModel: cfg.Ollama.EmbedModel, + client: &http.Client{}, }, nil } @@ -44,24 +32,55 @@ func (p *ollamaProvider) ModelID() string { return p.chatModel } func (p *ollamaProvider) Complete(ctx context.Context, prompt string, opts ...Option) (string, error) { o := applyOptions(opts) - callOpts := []llms.CallOption{ - llms.WithMaxTokens(o.maxTokens), - llms.WithTemperature(o.temperature), + + reqBody := map[string]any{ + "model": p.chatModel, + "prompt": prompt, + "stream": false, + "options": map[string]any{ + "num_predict": o.maxTokens, + "temperature": o.temperature, + }, } if o.jsonMode { - callOpts = append(callOpts, llms.WithJSONMode()) + reqBody["format"] = "json" } - resp, err := llms.GenerateFromSinglePrompt(ctx, p.chatLLM, prompt, callOpts...) + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("ollama complete marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("ollama complete request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("ollama complete: %w", err) } - return resp, nil + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("ollama complete HTTP %d: %s", resp.StatusCode, b) + } + + var result struct { + Response string `json:"response"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("ollama complete decode: %w", err) + } + return result.Response, nil } func (p *ollamaProvider) Embed(ctx context.Context, text string) ([]float32, error) { - vecs, err := p.embedLLM.CreateEmbedding(ctx, []string{text}) + vecs, err := p.EmbedBatch(ctx, []string{text}) if err != nil { - return nil, fmt.Errorf("ollama embed: %w", err) + return nil, err } if len(vecs) == 0 { return nil, fmt.Errorf("ollama embed: empty response") @@ -70,5 +89,37 @@ func (p *ollamaProvider) Embed(ctx context.Context, text string) ([]float32, err } func (p *ollamaProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { - return p.embedLLM.CreateEmbedding(ctx, texts) + reqBody := map[string]any{ + "model": p.embedModel, + "input": texts, + } + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("ollama embed marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama embed request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("ollama embed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama embed HTTP %d: %s", resp.StatusCode, b) + } + + var result struct { + Embeddings [][]float32 `json:"embeddings"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("ollama embed decode: %w", err) + } + return result.Embeddings, nil }