-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_client.py
executable file
·110 lines (92 loc) · 3.11 KB
/
test_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
"""
Test client for Yarn API
"""
import argparse
import json
import sys

import requests
def search_videos(
    query,
    max_frames=5,
    top_k=3,
    frame_mode="independent",
    image_model="sd",
    api_url="http://localhost:8000",
    timeout=300,
):
    """
    Search for videos using the Yarn API.

    Args:
        query: Text query to search for.
        max_frames: Maximum number of frames to generate.
        top_k: Number of results to return.
        frame_mode: Frame generation mode ("independent" or "continuous").
        image_model: Image generation model ("sd" for Stable Diffusion).
        api_url: Base URL of the Yarn API. A trailing slash is tolerated.
        timeout: Seconds to wait for the server, passed straight to
            ``requests.post``. Pass ``None`` to wait indefinitely.

    Returns:
        The decoded JSON response body when the server answers with HTTP 200.
        Returns ``None`` in three cases: the request could not be completed
        (connection error, timeout, ...), the status code was not 200, or the
        body was not valid JSON. Error details are printed to stderr.
    """
    # Strip trailing slashes so "http://host:8000/" does not yield "//api/search/".
    endpoint = f"{api_url.rstrip('/')}/api/search/"
    payload = {
        "query": query,
        "max_frames": max_frames,
        "top_k": top_k,
        "frame_mode": frame_mode,
        "image_model": image_model,
        "embedding_model": "clip",
    }

    try:
        # json= serialises the payload and sets "Content-Type: application/json".
        response = requests.post(endpoint, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Report transport failures the same way as HTTP errors (return None)
        # instead of letting a traceback escape to the caller.
        print(f"Error: request to {endpoint} failed: {exc}", file=sys.stderr)
        return None

    if response.status_code != 200:
        print(f"Error: {response.status_code}", file=sys.stderr)
        print(response.text, file=sys.stderr)
        return None

    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        print("Error: response body is not valid JSON", file=sys.stderr)
        print(response.text, file=sys.stderr)
        return None
def main():
    """Parse command-line arguments, run one search, and report the results.

    Prints a numbered summary line (``video_id`` and ``score``) for each
    result. When ``--output`` is given and the search succeeded, the raw JSON
    results are written to that file, including an empty result list. Exits
    with status 1 if the search request failed.
    """
    parser = argparse.ArgumentParser(description="Test client for Yarn API")
    parser.add_argument("--query", type=str, required=True, help="Text query to search for")
    parser.add_argument("--max-frames", type=int, default=5, help="Maximum number of frames to generate")
    parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
    parser.add_argument(
        "--frame-mode",
        type=str,
        choices=["independent", "continuous"],
        default="independent",
        help="Frame generation mode (independent or continuous)",
    )
    parser.add_argument(
        "--image-model",
        type=str,
        choices=["sd"],
        default="sd",
        help="Image generation model (sd for Stable Diffusion)",
    )
    parser.add_argument("--api-url", type=str, default="http://localhost:8000", help="Base URL of the Yarn API")
    parser.add_argument("--output", type=str, help="Output file to save results (JSON format)")
    args = parser.parse_args()

    print(f"Searching for: {args.query}")
    print(f"Frame generation mode: {args.frame_mode}")
    print(f"Image generation model: {args.image_model}")

    results = search_videos(
        query=args.query,
        max_frames=args.max_frames,
        top_k=args.top_k,
        frame_mode=args.frame_mode,
        image_model=args.image_model,
        api_url=args.api_url,
    )

    # search_videos returns None on failure. Keep that distinct from an empty
    # result list so scripts can detect errors through the exit status.
    if results is None:
        print("Search failed", file=sys.stderr)
        sys.exit(1)

    if results:
        print("\nSearch Results:")
        for rank, result in enumerate(results, start=1):
            print(f"\n{rank}. Video: {result['video_id']} (Score: {result['score']:.4f})")
    else:
        print("No results found")

    # Save to output file if specified, even when no results were found.
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()