Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion core/model/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type PriceCondition struct {
Resolution []string `json:"resolution,omitempty"`
Quality []string `json:"quality,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
InputMedia *bool `json:"input_media,omitempty"`
InputVideo *bool `json:"input_video,omitempty"`
OutputAudio *bool `json:"output_audio,omitempty"`
}
Expand Down Expand Up @@ -157,6 +158,10 @@ func priceConditionSpecificity(condition PriceCondition) int {
specificity++
}

if condition.InputMedia != nil {
specificity++
}

if condition.OutputAudio != nil {
specificity++
}
Expand Down Expand Up @@ -398,7 +403,8 @@ func (p *Price) ValidateConditionalPrices() error {
continue
}

if !boolConditionOverlap(condition.InputVideo, otherCondition.InputVideo) ||
if !boolConditionOverlap(condition.InputMedia, otherCondition.InputMedia) ||
!boolConditionOverlap(condition.InputVideo, otherCondition.InputVideo) ||
!boolConditionOverlap(condition.OutputAudio, otherCondition.OutputAudio) {
continue
}
Expand Down Expand Up @@ -696,6 +702,7 @@ type UsageContext struct {
NativeResolution string `gorm:"size:32" json:"native_resolution,omitempty"`
Quality string `gorm:"size:32" json:"quality,omitempty"`
ServiceTier string `gorm:"size:32" json:"service_tier,omitempty"`
InputMedia *bool ` json:"input_media,omitempty"`
InputVideo *bool ` json:"input_video,omitempty"`
OutputAudio *bool ` json:"output_audio,omitempty"`
}
Expand Down Expand Up @@ -733,6 +740,12 @@ func (c UsageContext) priceConditionMatches(
return false
}

if condition.InputMedia != nil {
if c.InputMedia == nil || *c.InputMedia != *condition.InputMedia {
return false
}
}

if condition.InputVideo != nil {
if c.InputVideo == nil || *c.InputVideo != *condition.InputVideo {
return false
Expand Down Expand Up @@ -765,6 +778,10 @@ func (c UsageContext) WithFallback(fallback UsageContext) UsageContext {
c.Quality = fallback.Quality
}

if c.InputMedia == nil {
c.InputMedia = fallback.InputMedia
}

if c.InputVideo == nil {
c.InputVideo = fallback.InputVideo
}
Expand Down
29 changes: 29 additions & 0 deletions core/model/usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,12 @@ func TestPrice_SelectConditionalPrice_WithMediaFlags(t *testing.T) {
price := model.Price{
OutputPrice: 0.20,
ConditionalPrices: []model.ConditionalPrice{
{
Condition: model.PriceCondition{
InputMedia: new(false),
},
Price: model.Price{OutputPrice: 0.012},
},
{
Condition: model.PriceCondition{
Resolution: []string{"720p"},
Expand Down Expand Up @@ -1190,6 +1196,13 @@ func TestPrice_SelectConditionalPrice_WithMediaFlags(t *testing.T) {
t.Fatalf("expected text-only price 0.046, got %v", textOnlyPrice.OutputPrice)
}

pureTextPrice := price.SelectConditionalPrice(model.Usage{}, model.UsageContext{
InputMedia: new(false),
})
if float64(pureTextPrice.OutputPrice) != 0.012 {
t.Fatalf("expected pure text price 0.012, got %v", pureTextPrice.OutputPrice)
}

unknownInputVideoPrice := price.SelectConditionalPrice(model.Usage{}, model.UsageContext{
Resolution: "720p",
})
Expand Down Expand Up @@ -1615,6 +1628,22 @@ func TestPrice_ValidateConditionalPrices_WithMediaConditions(t *testing.T) {
},
wantErr: false,
},
{
name: "same ranges with different input media flags are allowed",
price: model.Price{
ConditionalPrices: []model.ConditionalPrice{
{
Condition: model.PriceCondition{InputMedia: new(false)},
Price: model.Price{OutputPrice: 0.08},
},
{
Condition: model.PriceCondition{InputMedia: new(true)},
Price: model.Price{OutputPrice: 0.04},
},
},
},
wantErr: false,
},
{
name: "same ranges with different output audio flags are allowed",
price: model.Price{
Expand Down
158 changes: 158 additions & 0 deletions core/relay/adaptor/ali/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/bytedance/sonic/ast"
"github.com/gin-gonic/gin"
"github.com/labring/aiproxy/core/common"
coremodel "github.com/labring/aiproxy/core/model"
"github.com/labring/aiproxy/core/relay/adaptor"
"github.com/labring/aiproxy/core/relay/adaptor/openai"
"github.com/labring/aiproxy/core/relay/meta"
Expand Down Expand Up @@ -189,5 +190,162 @@ func ChatHandler(
result.Usage.WebSearchCount++
}

result.UsageContext = aliChatUsageContextWithDefaults(
aliChatUsageContext(result.Usage).
WithFallback(aliChatRequestUsageContext(&node)).
WithFallback(result.UsageContext),
)

return result, nil
}

// aliChatUsageContext derives the media flags of a UsageContext from the
// token counters in usage.
//
// InputMedia is set to true when any of the image, audio or video input token
// counts is positive, InputVideo when the video input token count is positive,
// and OutputAudio when the audio output token count is positive. A flag whose
// counters are not positive is left nil rather than set to false.
func aliChatUsageContext(usage coremodel.Usage) coremodel.UsageContext {
	// enabled hands out a fresh pointer to true on every call so the flags
	// never share storage.
	enabled := func() *bool {
		v := true
		return &v
	}

	var ctx coremodel.UsageContext

	sawVideo := usage.VideoInputTokens > 0
	sawMedia := usage.ImageInputTokens > 0 || usage.AudioInputTokens > 0 || sawVideo

	if sawMedia {
		ctx.InputMedia = enabled()
	}

	if sawVideo {
		ctx.InputVideo = enabled()
	}

	if usage.AudioOutputTokens > 0 {
		ctx.OutputAudio = enabled()
	}

	return ctx
}

// aliChatRequestUsageContext derives the media flags of a UsageContext from
// the request body held in node.
//
// Each flag is set to true when the matching request probe reports a hit
// (input media, input video, requested audio output); a flag whose probe
// reports nothing is left nil rather than set to false.
func aliChatRequestUsageContext(node *ast.Node) coremodel.UsageContext {
	// enabled hands out a fresh pointer to true on every call so the flags
	// never share storage.
	enabled := func() *bool {
		v := true
		return &v
	}

	var ctx coremodel.UsageContext

	if aliChatRequestHasInputMedia(node) {
		ctx.InputMedia = enabled()
	}

	if aliChatRequestHasInputVideo(node) {
		ctx.InputVideo = enabled()
	}

	if aliChatRequestWantsOutputAudio(node) {
		ctx.OutputAudio = enabled()
	}

	return ctx
}

// aliChatUsageContextWithDefaults returns usageContext with its still-unknown
// InputMedia and OutputAudio flags resolved to an explicit false.
//
// Flags that are already set are kept as they are. InputVideo is not
// defaulted here and stays nil when unknown.
func aliChatUsageContextWithDefaults(usageContext coremodel.UsageContext) coremodel.UsageContext {
	defaulted := []**bool{
		&usageContext.InputMedia,
		&usageContext.OutputAudio,
	}

	for _, slot := range defaulted {
		if *slot == nil {
			// new(bool) points at the zero value, i.e. false.
			*slot = new(bool)
		}
	}

	return usageContext
}

// aliChatRequestHasInputMedia reports whether any message content in the
// request looks like image, audio or video input.
//
// A content entry counts as media when it is an object that either carries one
// of the known media fields or whose "type" string contains "image", "audio"
// or "video" (compared case-insensitively).
func aliChatRequestHasInputMedia(node *ast.Node) bool {
	mediaFields := [...]string{"image_url", "input_audio", "audio", "video", "video_url"}
	mediaKinds := [...]string{"image", "audio", "video"}

	return aliChatRequestContentMatches(node, func(part *ast.Node) bool {
		for _, field := range mediaFields {
			if aliChatContentHasField(part, field) {
				return true
			}
		}

		for _, kind := range mediaKinds {
			if aliChatContentTypeContains(part, kind) {
				return true
			}
		}

		return false
	})
}

// aliChatRequestHasInputVideo reports whether any message content in the
// request looks like video input.
//
// A content entry counts as video when it is an object that has a "video" or
// "video_url" field, or whose "type" string contains "video" (compared
// case-insensitively).
func aliChatRequestHasInputVideo(node *ast.Node) bool {
	videoFields := [...]string{"video", "video_url"}

	return aliChatRequestContentMatches(node, func(part *ast.Node) bool {
		for _, field := range videoFields {
			if aliChatContentHasField(part, field) {
				return true
			}
		}

		return aliChatContentTypeContains(part, "video")
	})
}

// aliChatRequestWantsOutputAudio reports whether the request asks for audio
// output.
//
// It returns true when the top-level "modalities" array has an entry equal to
// "audio" (case-insensitive), or otherwise when a top-level "audio" field is
// present in the request.
func aliChatRequestWantsOutputAudio(node *ast.Node) bool {
	modalities := node.Get("modalities")
	if !modalities.Exists() || modalities.TypeSafe() != ast.V_ARRAY {
		// NOTE(review): presence is the only test here, so an explicit
		// `"audio": null` presumably still counts — confirm that is intended.
		return node.Get("audio").Exists()
	}

	requested := false
	// The iteration error is deliberately dropped: detection is best-effort
	// and an unreadable entry is simply not treated as "audio".
	_ = modalities.ForEach(func(_ ast.Sequence, entry *ast.Node) bool {
		if name, err := entry.String(); err == nil && strings.EqualFold(name, "audio") {
			requested = true
		}

		// Keep scanning only while nothing has matched yet.
		return !requested
	})

	return requested || node.Get("audio").Exists()
}

// aliChatRequestContentMatches reports whether match accepts any piece of
// message content in the request.
//
// It walks the top-level "messages" array. For each message that has a
// "content" value, match is called on every element when the content is an
// array, or on the content node itself otherwise. The walk stops at the first
// hit. A missing or non-array "messages" yields false.
func aliChatRequestContentMatches(node *ast.Node, match func(*ast.Node) bool) bool {
	messages := node.Get("messages")
	if !messages.Exists() || messages.TypeSafe() != ast.V_ARRAY {
		return false
	}

	found := false

	// probe records a hit and tells the caller whether to keep iterating.
	probe := func(part *ast.Node) bool {
		if match(part) {
			found = true
		}

		return !found
	}

	// Iteration errors are deliberately dropped: detection is best-effort and
	// content that cannot be walked simply produces no match.
	_ = messages.ForEach(func(_ ast.Sequence, msg *ast.Node) bool {
		content := msg.Get("content")

		switch {
		case !content.Exists():
			// No content on this message; move on to the next one.
		case content.TypeSafe() == ast.V_ARRAY:
			_ = content.ForEach(func(_ ast.Sequence, part *ast.Node) bool {
				return probe(part)
			})
		default:
			probe(content)
		}

		return !found
	})

	return found
}

// aliChatContentHasField reports whether content is a JSON object that has a
// member named field. Any non-object content yields false.
func aliChatContentHasField(content *ast.Node, field string) bool {
	if content.TypeSafe() != ast.V_OBJECT {
		return false
	}

	return content.Get(field).Exists()
}

// aliChatContentTypeContains reports whether content is a JSON object whose
// "type" value, lower-cased, contains pattern as a substring.
//
// The pattern itself is not lower-cased, so callers are expected to pass it in
// lower case. Non-object content, or a "type" that cannot be read as a string,
// yields false.
func aliChatContentTypeContains(content *ast.Node, pattern string) bool {
	if content.TypeSafe() == ast.V_OBJECT {
		if kind, err := content.Get("type").String(); err == nil {
			return strings.Contains(strings.ToLower(kind), pattern)
		}
	}

	return false
}
Loading
Loading