Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ require (
github.com/schollz/progressbar/v3 v3.17.1
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/tmc/langchaingo v0.1.13
golang.org/x/net v0.52.0
modernc.org/sqlite v1.29.10
)

require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
Expand All @@ -35,8 +33,6 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
Expand Down
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
Expand Down Expand Up @@ -62,10 +60,6 @@ github.com/pdfcpu/pdfcpu v0.10.0 h1:K7Hv2tW/poMvG+DuowVGfYvNcN1Y7USS+8IebA3Z8+w=
github.com/pdfcpu/pdfcpu v0.10.0/go.mod h1:Q2Z3sqdRqHTdIq1mPAUl8nfAoim8p3c1ASOaQ10mCpE=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down Expand Up @@ -107,8 +101,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
Expand Down
136 changes: 101 additions & 35 deletions internal/llm/azure.go
Original file line number Diff line number Diff line change
@@ -1,47 +1,33 @@
package llm

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/RandomCodeSpace/docsgraphcontext/internal/config"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
)

type azureProvider struct {
chatLLM *openai.LLM
embedLLM *openai.LLM
endpoint string
apiKey string
apiVersion string
chatModel string
embedModel string
client *http.Client
}

func newAzureProvider(cfg *config.LLMConfig) (Provider, error) {
chatLLM, err := openai.New(
openai.WithAPIType(openai.APITypeAzure),
openai.WithBaseURL(cfg.Azure.Endpoint),
openai.WithToken(cfg.Azure.APIKey),
openai.WithAPIVersion(cfg.Azure.APIVersion),
openai.WithModel(cfg.Azure.ChatModel),
)
if err != nil {
return nil, fmt.Errorf("azure chat llm: %w", err)
}
embedLLM, err := openai.New(
openai.WithAPIType(openai.APITypeAzure),
openai.WithBaseURL(cfg.Azure.Endpoint),
openai.WithToken(cfg.Azure.APIKey),
openai.WithAPIVersion(cfg.Azure.APIVersion),
openai.WithEmbeddingModel(cfg.Azure.EmbedModel),
)
if err != nil {
return nil, fmt.Errorf("azure embed llm: %w", err)
}
return &azureProvider{
chatLLM: chatLLM,
embedLLM: embedLLM,
endpoint: cfg.Azure.Endpoint,
apiKey: cfg.Azure.APIKey,
apiVersion: cfg.Azure.APIVersion,
chatModel: cfg.Azure.ChatModel,
embedModel: cfg.Azure.EmbedModel,
client: &http.Client{},
}, nil
}

Expand All @@ -50,24 +36,65 @@ func (p *azureProvider) ModelID() string { return p.chatModel }

func (p *azureProvider) Complete(ctx context.Context, prompt string, opts ...Option) (string, error) {
o := applyOptions(opts)
callOpts := []llms.CallOption{
llms.WithMaxTokens(o.maxTokens),
llms.WithTemperature(o.temperature),

type message struct {
Role string `json:"role"`
Content string `json:"content"`
}
reqBody := map[string]any{
"messages": []message{{Role: "user", Content: prompt}},
"max_tokens": o.maxTokens,
"temperature": o.temperature,
}
if o.jsonMode {
callOpts = append(callOpts, llms.WithJSONMode())
reqBody["response_format"] = map[string]string{"type": "json_object"}
}

body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("azure complete marshal: %w", err)
}

url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s",
p.endpoint, p.chatModel, p.apiVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("azure complete request: %w", err)
}
resp, err := llms.GenerateFromSinglePrompt(ctx, p.chatLLM, prompt, callOpts...)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", p.apiKey)

resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("azure complete: %w", err)
}
return resp, nil
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("azure complete HTTP %d: %s", resp.StatusCode, b)
}

var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("azure complete decode: %w", err)
}
if len(result.Choices) == 0 {
return "", fmt.Errorf("azure complete: empty response")
}
return result.Choices[0].Message.Content, nil
}

// Embed computes the embedding for a single text.
//
// NOTE(review): this span comes from a diff view, so two versions of the body
// are interleaved. The p.embedLLM.CreateEmbedding call and the wrapped
// "azure embed: %w" return look like the pre-change lines; the
// p.EmbedBatch call and the bare `return nil, err` look like their
// replacements. Confirm against the merged file.
// NOTE(review): the lines between the empty-response check and the closing
// brace are collapsed by the hunk marker, so the success return is not
// visible here.
func (p *azureProvider) Embed(ctx context.Context, text string) ([]float32, error) {
vecs, err := p.embedLLM.CreateEmbedding(ctx, []string{text})
vecs, err := p.EmbedBatch(ctx, []string{text})
if err != nil {
return nil, fmt.Errorf("azure embed: %w", err)
return nil, err
}
// An empty result for a one-element request is reported as an error.
if len(vecs) == 0 {
return nil, fmt.Errorf("azure embed: empty response")
Expand All @@ -76,5 +103,44 @@ func (p *azureProvider) Embed(ctx context.Context, text string) ([]float32, erro
}

// EmbedBatch requests embeddings for texts from the configured embedding
// deployment and returns one vector per input, in input order.
//
// An empty texts slice returns (nil, nil) without contacting the service.
// Vectors are placed by the "index" field of each response item, so the
// result does not depend on the order of items in the response. An error is
// returned when the request cannot be built or sent, when the service
// answers with a status other than 200 (the status code and a bounded prefix
// of the body are included), when the body cannot be decoded, or when the
// response does not hold exactly one vector per input.
func (p *azureProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	// Upper bound on how much of a failed response is copied into the error.
	const maxErrBody = 4 << 10

	if len(texts) == 0 {
		return nil, nil
	}

	body, err := json.Marshal(map[string]any{"input": texts})
	if err != nil {
		return nil, fmt.Errorf("azure embed marshal: %w", err)
	}

	reqURL := fmt.Sprintf("%s/openai/deployments/%s/embeddings?api-version=%s",
		p.endpoint, p.embedModel, p.apiVersion)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("azure embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("api-key", p.apiKey)

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("azure embed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read failure only shortens the message, and the
		// limit keeps a huge error page out of the returned error.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("azure embed HTTP %d: %s", resp.StatusCode, b)
	}

	var result struct {
		Data []struct {
			Index     int       `json:"index"`
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("azure embed decode: %w", err)
	}
	if len(result.Data) != len(texts) {
		return nil, fmt.Errorf("azure embed: got %d embeddings for %d inputs", len(result.Data), len(texts))
	}

	vecs := make([][]float32, len(texts))
	filled := make([]bool, len(texts))
	for _, d := range result.Data {
		// Reject indexes that are out of range or repeated; either would
		// leave some input without a vector.
		if d.Index < 0 || d.Index >= len(vecs) || filled[d.Index] {
			return nil, fmt.Errorf("azure embed: unexpected embedding index %d", d.Index)
		}
		filled[d.Index] = true
		vecs[d.Index] = d.Embedding
	}
	return vecs, nil
}
109 changes: 80 additions & 29 deletions internal/llm/ollama.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,29 @@
package llm

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/RandomCodeSpace/docsgraphcontext/internal/config"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
)

type ollamaProvider struct {
chatLLM *ollama.LLM
embedLLM *ollama.LLM
baseURL string
chatModel string
embedModel string
client *http.Client
}

func newOllamaProvider(cfg *config.LLMConfig) (Provider, error) {
chatLLM, err := ollama.New(
ollama.WithServerURL(cfg.Ollama.BaseURL),
ollama.WithModel(cfg.Ollama.ChatModel),
)
if err != nil {
return nil, fmt.Errorf("ollama chat: %w", err)
}
embedLLM, err := ollama.New(
ollama.WithServerURL(cfg.Ollama.BaseURL),
ollama.WithModel(cfg.Ollama.EmbedModel),
)
if err != nil {
return nil, fmt.Errorf("ollama embed: %w", err)
}
return &ollamaProvider{
chatLLM: chatLLM,
embedLLM: embedLLM,
baseURL: cfg.Ollama.BaseURL,
chatModel: cfg.Ollama.ChatModel,
embedModel: cfg.Ollama.EmbedModel,
client: &http.Client{},
}, nil
}

Expand All @@ -44,24 +32,55 @@ func (p *ollamaProvider) ModelID() string { return p.chatModel }

func (p *ollamaProvider) Complete(ctx context.Context, prompt string, opts ...Option) (string, error) {
o := applyOptions(opts)
callOpts := []llms.CallOption{
llms.WithMaxTokens(o.maxTokens),
llms.WithTemperature(o.temperature),

reqBody := map[string]any{
"model": p.chatModel,
"prompt": prompt,
"stream": false,
"options": map[string]any{
"num_predict": o.maxTokens,
"temperature": o.temperature,
},
}
if o.jsonMode {
callOpts = append(callOpts, llms.WithJSONMode())
reqBody["format"] = "json"
}
resp, err := llms.GenerateFromSinglePrompt(ctx, p.chatLLM, prompt, callOpts...)

body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("ollama complete marshal: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("ollama complete request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("ollama complete: %w", err)
}
return resp, nil
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("ollama complete HTTP %d: %s", resp.StatusCode, b)
}

var result struct {
Response string `json:"response"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("ollama complete decode: %w", err)
}
return result.Response, nil
}

// Embed computes the embedding for a single text.
//
// NOTE(review): this span comes from a diff view, so two versions of the body
// are interleaved. The p.embedLLM.CreateEmbedding call and the wrapped
// "ollama embed: %w" return look like the pre-change lines; the
// p.EmbedBatch call and the bare `return nil, err` look like their
// replacements. Confirm against the merged file.
// NOTE(review): the lines between the empty-response check and the closing
// brace are collapsed by the hunk marker, so the success return is not
// visible here.
func (p *ollamaProvider) Embed(ctx context.Context, text string) ([]float32, error) {
vecs, err := p.embedLLM.CreateEmbedding(ctx, []string{text})
vecs, err := p.EmbedBatch(ctx, []string{text})
if err != nil {
return nil, fmt.Errorf("ollama embed: %w", err)
return nil, err
}
// An empty result for a one-element request is reported as an error.
if len(vecs) == 0 {
return nil, fmt.Errorf("ollama embed: empty response")
Expand All @@ -70,5 +89,37 @@ func (p *ollamaProvider) Embed(ctx context.Context, text string) ([]float32, err
}

// EmbedBatch requests embeddings for texts from the server's /api/embed
// endpoint using the configured embedding model and returns the
// "embeddings" array of the reply.
//
// An empty texts slice returns (nil, nil) without contacting the server. An
// error is returned when the request cannot be built or sent, when the
// server answers with a status other than 200 (the status code and a bounded
// prefix of the body are included), when the body cannot be decoded, or when
// the reply does not hold exactly one vector per input.
func (p *ollamaProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	// Upper bound on how much of a failed response is copied into the error.
	const maxErrBody = 4 << 10

	if len(texts) == 0 {
		return nil, nil
	}

	body, err := json.Marshal(map[string]any{
		"model": p.embedModel,
		"input": texts,
	})
	if err != nil {
		return nil, fmt.Errorf("ollama embed marshal: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/embed", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("ollama embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("ollama embed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read failure only shortens the message, and the
		// limit keeps a huge error page out of the returned error.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("ollama embed HTTP %d: %s", resp.StatusCode, b)
	}

	var result struct {
		Embeddings [][]float32 `json:"embeddings"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("ollama embed decode: %w", err)
	}
	// Callers index the result by input position, so a short or long reply
	// must not be passed through silently.
	if len(result.Embeddings) != len(texts) {
		return nil, fmt.Errorf("ollama embed: got %d embeddings for %d inputs", len(result.Embeddings), len(texts))
	}
	return result.Embeddings, nil
}
Loading