-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
165 lines (139 loc) · 4.78 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Package main provides an example of using the TabbyAPI client to stream model loading progress.
package main
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"time"

	"github.com/pixelsquared/go-tabbyapi/tabby"
)
func main() {
	// Resolve the API endpoint and admin key from the environment.
	endpoint := getEnvOrDefault("TABBY_API_ENDPOINT", "http://localhost:8080")
	adminKey := os.Getenv("TABBY_ADMIN_KEY") // Admin key required for model management

	// Create a TabbyAPI client with admin privileges and a generous
	// timeout, since model loading can take several minutes.
	client := tabby.NewClient(
		tabby.WithBaseURL(endpoint),
		tabby.WithAdminKey(adminKey), // Note: Model management requires admin access
		tabby.WithTimeout(300*time.Second), // Longer timeout for model loading
	)
	// NOTE(review): log.Fatalf below exits without running deferred calls;
	// acceptable for an example, where process teardown releases resources.
	defer client.Close()

	// Bound the whole loading operation with a context timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
	defer cancel()

	// Model to load; override via TABBY_MODEL_NAME. Must be a model name
	// available in the server's model directory.
	modelName := getEnvOrDefault("TABBY_MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
	fmt.Printf("Preparing to load model: %s\n", modelName)
	fmt.Println("Press Ctrl+C to cancel at any time")

	// Build the load request. Fields besides ModelName are optional tuning knobs.
	loadReq := &tabby.ModelLoadRequest{
		ModelName: modelName,
		MaxSeqLen: 4096, // Optional: Context length
		RopeScale: 1.0,  // Optional: RoPE scaling factor
		CacheSize: 2000, // Optional: KV cache size in MB
	}

	// Initialize the streaming model load.
	fmt.Println("Starting model loading process with streaming progress...")
	stream, err := client.Models().LoadStream(ctx, loadReq)
	if err != nil {
		log.Fatalf("Error initializing model load stream: %v", err)
	}
	// Ensure the stream is closed properly.
	defer stream.Close()

	// Progress tracking state.
	var lastModule int
	var totalModules int
	startTime := time.Now()

	// Consume progress updates until the stream ends or the server reports
	// the model has finished loading.
	for {
		response, err := stream.Recv()
		if err != nil {
			// errors.Is also matches wrapped errors, unlike ==.
			if errors.Is(err, io.EOF) {
				break // normal end of stream
			}
			if errors.Is(err, tabby.ErrStreamClosed) {
				fmt.Println("\nStream was closed")
				break
			}
			log.Fatalf("Error receiving from stream: %v", err)
		}

		// Remember the total module count from the first update that carries it.
		if totalModules == 0 && response.Modules > 0 {
			totalModules = response.Modules
		}

		// Print a progress line only when the module counter advances.
		currentModule := response.Module
		if currentModule != lastModule {
			elapsedTime := time.Since(startTime)
			// Guard the division: totalModules may still be zero if no
			// update has reported it yet (avoids printing NaN).
			progress := 0.0
			if totalModules > 0 {
				progress = float64(currentModule) / float64(totalModules) * 100.0
			}
			fmt.Printf("Loading module %d of %d (%.1f%%) - Model type: %s - Status: %s - Elapsed: %s\n",
				currentModule, totalModules, progress, response.ModelType, response.Status, elapsedTime.Round(time.Second))
			lastModule = currentModule
		}

		// The server signals completion via status; stop once all modules loaded.
		if response.Status == "loaded" && currentModule >= totalModules {
			break
		}
	}

	// Final confirmation.
	totalTime := time.Since(startTime).Round(time.Second)
	fmt.Printf("\nModel loading completed in %s\n", totalTime)

	// Verify the load by asking the server which model is currently active.
	currentModel, err := client.Models().Get(ctx)
	if err != nil {
		fmt.Printf("Error verifying loaded model: %v\n", err)
	} else if currentModel != nil {
		fmt.Printf("\nSuccessfully loaded model: %s\n", currentModel.ID)

		// Report the model's context window from the server properties.
		props, err := client.Models().GetProps(ctx)
		if err != nil {
			fmt.Printf("Error getting model properties: %v\n", err)
		} else {
			fmt.Printf("Model context length: %d tokens\n", props.DefaultGenerationSettings.NCtx)
		}
	}

	// Smoke-test the freshly loaded model with a short completion.
	fmt.Println("\nTesting the model with a simple completion...")
	testCompletion(client, ctx)
}
// testCompletion sends a short completion request to verify that the
// loaded model responds; results (or errors) are printed to stdout.
// NOTE(review): Go convention puts ctx as the first parameter; the order
// is kept as-is for compatibility with existing callers.
func testCompletion(client tabby.Client, ctx context.Context) {
	// A minimal request: short prompt, few tokens, moderate temperature.
	req := &tabby.CompletionRequest{
		Prompt:      "Hello, world!",
		MaxTokens:   20,
		Temperature: 0.7,
	}

	// Issue the request; on failure report and bail out early.
	resp, err := client.Completions().Create(ctx, req)
	if err != nil {
		fmt.Printf("Error testing model with completion: %v\n", err)
		return
	}

	// Guard clause: nothing came back.
	if len(resp.Choices) == 0 {
		fmt.Println("No completion text was generated")
		return
	}

	// Show the prompt/response pair from the first choice.
	fmt.Println("\nModel test response:")
	fmt.Printf("Prompt: \"Hello, world!\"\n")
	fmt.Printf("Response: \"%s\"\n", resp.Choices[0].Text)
}
// getEnvOrDefault returns the value of the environment variable or a default value
func getEnvOrDefault(key, defaultValue string) string {
value := os.Getenv(key)
if value == "" {
return defaultValue
}
return value
}