Skip to content

Commit c4f0eb2

Browse files
authored
fix: price and usage typed (#55)
* refactor: price and usage model and balance check once * fix: ci lint * fix: meta set input tokens * chore: test no need new relay controller * fix: dynamic check billing enabled * chore: model config price field * fix: token search * feat: usage and price record donot use pointer * fix: ci lint
1 parent 197225a commit c4f0eb2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2879
-2348
lines changed

common/consume/consume.go

+16-93
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,9 @@ func Wait() {
2323
func AsyncConsume(
2424
postGroupConsumer balance.PostGroupConsumer,
2525
code int,
26-
usage *relaymodel.Usage,
2726
meta *meta.Meta,
28-
inputPrice,
29-
outputPrice float64,
30-
cachedPrice float64,
31-
cacheCreationPrice float64,
27+
usage relaymodel.Usage,
28+
modelPrice model.Price,
3229
content string,
3330
ip string,
3431
retryTimes int,
@@ -47,12 +44,9 @@ func AsyncConsume(
4744
context.Background(),
4845
postGroupConsumer,
4946
code,
50-
usage,
5147
meta,
52-
inputPrice,
53-
outputPrice,
54-
cachedPrice,
55-
cacheCreationPrice,
48+
usage,
49+
modelPrice,
5650
content,
5751
ip,
5852
retryTimes,
@@ -65,29 +59,23 @@ func Consume(
6559
ctx context.Context,
6660
postGroupConsumer balance.PostGroupConsumer,
6761
code int,
68-
usage *relaymodel.Usage,
6962
meta *meta.Meta,
70-
inputPrice,
71-
outputPrice float64,
72-
cachedPrice float64,
73-
cacheCreationPrice float64,
63+
usage relaymodel.Usage,
64+
modelPrice model.Price,
7465
content string,
7566
ip string,
7667
retryTimes int,
7768
requestDetail *model.RequestDetail,
7869
downstreamResult bool,
7970
) {
80-
amount := CalculateAmount(usage, inputPrice, outputPrice, cachedPrice, cacheCreationPrice)
71+
amount := CalculateAmount(usage, modelPrice)
8172

8273
amount = consumeAmount(ctx, amount, postGroupConsumer, meta)
8374

8475
err := recordConsume(meta,
8576
code,
8677
usage,
87-
inputPrice,
88-
outputPrice,
89-
cachedPrice,
90-
cacheCreationPrice,
78+
modelPrice,
9179
content,
9280
ip,
9381
requestDetail,
@@ -114,13 +102,9 @@ func consumeAmount(
114102
}
115103

116104
func CalculateAmount(
117-
usage *relaymodel.Usage,
118-
inputPrice, outputPrice, cachedPrice, cacheCreationPrice float64,
105+
usage relaymodel.Usage,
106+
modelPrice model.Price,
119107
) float64 {
120-
if usage == nil {
121-
return 0
122-
}
123-
124108
promptTokens := usage.PromptTokens
125109
completionTokens := usage.CompletionTokens
126110
var cachedTokens int
@@ -130,24 +114,24 @@ func CalculateAmount(
130114
cacheCreationTokens = usage.PromptTokensDetails.CacheCreationTokens
131115
}
132116

133-
if cachedPrice > 0 {
117+
if modelPrice.CachedPrice > 0 {
134118
promptTokens -= cachedTokens
135119
}
136-
if cacheCreationPrice > 0 {
120+
if modelPrice.CacheCreationPrice > 0 {
137121
promptTokens -= cacheCreationTokens
138122
}
139123

140124
promptAmount := decimal.NewFromInt(int64(promptTokens)).
141-
Mul(decimal.NewFromFloat(inputPrice)).
125+
Mul(decimal.NewFromFloat(modelPrice.InputPrice)).
142126
Div(decimal.NewFromInt(model.PriceUnit))
143127
completionAmount := decimal.NewFromInt(int64(completionTokens)).
144-
Mul(decimal.NewFromFloat(outputPrice)).
128+
Mul(decimal.NewFromFloat(modelPrice.OutputPrice)).
145129
Div(decimal.NewFromInt(model.PriceUnit))
146130
cachedAmount := decimal.NewFromInt(int64(cachedTokens)).
147-
Mul(decimal.NewFromFloat(cachedPrice)).
131+
Mul(decimal.NewFromFloat(modelPrice.CachedPrice)).
148132
Div(decimal.NewFromInt(model.PriceUnit))
149133
cacheCreationAmount := decimal.NewFromInt(int64(cacheCreationTokens)).
150-
Mul(decimal.NewFromFloat(cacheCreationPrice)).
134+
Mul(decimal.NewFromFloat(modelPrice.CacheCreationPrice)).
151135
Div(decimal.NewFromInt(model.PriceUnit))
152136

153137
return promptAmount.
@@ -182,64 +166,3 @@ func processGroupConsume(
182166
}
183167
return consumedAmount
184168
}
185-
186-
func recordConsume(
187-
meta *meta.Meta,
188-
code int,
189-
usage *relaymodel.Usage,
190-
inputPrice,
191-
outputPrice float64,
192-
cachedPrice float64,
193-
cacheCreationPrice float64,
194-
content string,
195-
ip string,
196-
requestDetail *model.RequestDetail,
197-
amount float64,
198-
retryTimes int,
199-
downstreamResult bool,
200-
) error {
201-
promptTokens := 0
202-
completionTokens := 0
203-
cachedTokens := 0
204-
cacheCreationTokens := 0
205-
if usage != nil {
206-
promptTokens = usage.PromptTokens
207-
completionTokens = usage.CompletionTokens
208-
if usage.PromptTokensDetails != nil {
209-
cachedTokens = usage.PromptTokensDetails.CachedTokens
210-
cacheCreationTokens = usage.PromptTokensDetails.CacheCreationTokens
211-
}
212-
}
213-
214-
var channelID int
215-
if meta.Channel != nil {
216-
channelID = meta.Channel.ID
217-
}
218-
219-
return model.BatchRecordConsume(
220-
meta.RequestID,
221-
meta.RequestAt,
222-
meta.Group.ID,
223-
code,
224-
channelID,
225-
promptTokens,
226-
completionTokens,
227-
cachedTokens,
228-
cacheCreationTokens,
229-
meta.OriginModel,
230-
meta.Token.ID,
231-
meta.Token.Name,
232-
amount,
233-
inputPrice,
234-
outputPrice,
235-
cachedPrice,
236-
cacheCreationPrice,
237-
meta.Endpoint,
238-
content,
239-
int(meta.Mode),
240-
ip,
241-
retryTimes,
242-
requestDetail,
243-
downstreamResult,
244-
)
245-
}

common/consume/record.go

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package consume
2+
3+
import (
4+
"github.com/labring/aiproxy/model"
5+
"github.com/labring/aiproxy/relay/meta"
6+
relaymodel "github.com/labring/aiproxy/relay/model"
7+
)
8+
9+
func recordConsume(
10+
meta *meta.Meta,
11+
code int,
12+
usage relaymodel.Usage,
13+
modelPrice model.Price,
14+
content string,
15+
ip string,
16+
requestDetail *model.RequestDetail,
17+
amount float64,
18+
retryTimes int,
19+
downstreamResult bool,
20+
) error {
21+
us := model.Usage{
22+
InputTokens: usage.PromptTokens,
23+
OutputTokens: usage.CompletionTokens,
24+
TotalTokens: usage.PromptTokens + usage.CompletionTokens,
25+
}
26+
if usage.PromptTokensDetails != nil {
27+
us.CachedTokens = usage.PromptTokensDetails.CachedTokens
28+
us.CacheCreationTokens = usage.PromptTokensDetails.CacheCreationTokens
29+
}
30+
31+
var channelID int
32+
if meta.Channel != nil {
33+
channelID = meta.Channel.ID
34+
}
35+
36+
return model.BatchRecordConsume(
37+
meta.RequestID,
38+
meta.RequestAt,
39+
meta.Group.ID,
40+
code,
41+
channelID,
42+
meta.OriginModel,
43+
meta.Token.ID,
44+
meta.Token.Name,
45+
meta.Endpoint,
46+
content,
47+
int(meta.Mode),
48+
ip,
49+
retryTimes,
50+
requestDetail,
51+
downstreamResult,
52+
us,
53+
modelPrice,
54+
amount,
55+
)
56+
}

controller/channel-test.go

+1-5
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,7 @@ func testSingleModel(mc *model.ModelCaches, channel *model.Channel, modelName st
102102
modelConfig,
103103
meta.WithRequestID(channelTestRequestID),
104104
)
105-
relayController, ok := relayController(m)
106-
if !ok {
107-
return nil, fmt.Errorf("relay mode %d not implemented", m)
108-
}
109-
result := relayController(meta, newc)
105+
result := relayHandler(meta, newc)
110106
success := result.Error == nil
111107
var respStr string
112108
var code int

controller/modelconfig.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// @Produce json
1717
// @Security ApiKeyAuth
1818
// @Success 200 {object} middleware.APIResponse{data=map[string]any{configs=[]model.ModelConfig,total=int}}
19-
// @Router /api/modelconfigs [get]
19+
// @Router /api/model_configs [get]
2020
func GetModelConfigs(c *gin.Context) {
2121
page, perPage := parsePageParams(c)
2222
_model := c.Query("model")
@@ -39,7 +39,7 @@ func GetModelConfigs(c *gin.Context) {
3939
// @Produce json
4040
// @Security ApiKeyAuth
4141
// @Success 200 {object} middleware.APIResponse{data=[]model.ModelConfig}
42-
// @Router /api/modelconfigs/all [get]
42+
// @Router /api/model_configs/all [get]
4343
func GetAllModelConfigs(c *gin.Context) {
4444
configs, err := model.GetAllModelConfigs()
4545
if err != nil {
@@ -61,7 +61,7 @@ type GetModelConfigsByModelsContainsRequest struct {
6161
// @Produce json
6262
// @Security ApiKeyAuth
6363
// @Success 200 {object} middleware.APIResponse{data=[]model.ModelConfig}
64-
// @Router /api/modelconfigs/contains [post]
64+
// @Router /api/model_configs/contains [post]
6565
func GetModelConfigsByModelsContains(c *gin.Context) {
6666
request := GetModelConfigsByModelsContainsRequest{}
6767
err := c.ShouldBindJSON(&request)
@@ -85,7 +85,7 @@ func GetModelConfigsByModelsContains(c *gin.Context) {
8585
// @Produce json
8686
// @Security ApiKeyAuth
8787
// @Success 200 {object} middleware.APIResponse{data=map[string]any{configs=[]model.ModelConfig,total=int}}
88-
// @Router /api/modelconfigs/search [get]
88+
// @Router /api/model_configs/search [get]
8989
func SearchModelConfigs(c *gin.Context) {
9090
keyword := c.Query("keyword")
9191
page, perPage := parsePageParams(c)

0 commit comments

Comments
 (0)