-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
140 lines (121 loc) · 5.17 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import requests
import io
from PIL import Image, UnidentifiedImageError
from flask import Flask, render_template, request, send_file, jsonify, url_for
import os
import tempfile
from datetime import datetime
import sqlite3
# Flask application instance. The route handlers in this module register on it
# via ``@app.route``, and generated images are written under its static folder.
app = Flask(__name__)
# Configuration
#
# Inference endpoint for the image model. Override with HF_API_URL.
# NOTE(review): the default routes requests through a third-party proxy
# (ghostunblocker.vercel.app), which receives the Authorization header below.
# Prefer pointing HF_API_URL directly at api-inference.huggingface.co.
API_URL = os.environ.get(
    "HF_API_URL",
    "https://ghostunblocker.vercel.app//proxy/https://api-inference.huggingface.co/models/Artples/LAI-ImageGeneration-vSDXL-2",
)
# API tokens must never be committed to source control. The token that used to
# be hard-coded here is exposed and should be revoked. Supply it via the
# HF_API_TOKEN environment variable; when it is unset the header carries an
# empty token and the API is expected to reject the request as unauthorized.
HEADERS = {"Authorization": f"Bearer {os.environ.get('HF_API_TOKEN', '')}"}
# Initialize database
def init_db(db_path="images.db"):
    """Create the ``images`` table if it does not already exist.

    The table records one row per generated image: an autoincrement ``id``,
    the ``prompt`` text, the image ``filepath`` (relative to the static
    folder) and a ``created_at`` timestamp defaulting to the insert time.
    Calling this repeatedly is safe because of ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_path: SQLite database file. Defaults to ``images.db`` in the
            current working directory, the same path the request handlers use.
    """
    conn = sqlite3.connect(db_path)
    # sqlite3.Connection's own context manager only commits/rolls back; it
    # does not close, so close explicitly even if the DDL fails.
    try:
        conn.execute(
            """CREATE TABLE IF NOT EXISTS images
            (id INTEGER PRIMARY KEY AUTOINCREMENT,
            prompt TEXT,
            filepath TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"""
        )
        conn.commit()
    finally:
        conn.close()
def query(payload, timeout=180):
    """POST *payload* as JSON to the inference API and return the raw body.

    Args:
        payload: JSON-serialisable request body sent to ``API_URL``.
        timeout: Seconds to wait for the API. Image generation can be slow,
            so the default is 180 seconds (3 minutes).

    Returns:
        The response body as ``bytes``. The HTTP status code is not checked,
        so for a failed request this is the API's error body, not image data.

    Raises:
        Exception: on a timeout or any other ``requests`` network failure.
            The underlying ``requests`` exception is chained as ``__cause__``.
    """
    try:
        response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
    except requests.exceptions.Timeout as err:
        raise Exception("Image generation timed out. The server is probably busy, please try again.") from err
    except requests.exceptions.RequestException as err:
        raise Exception(f"Network error: {err}") from err
    return response.content
@app.route("/", methods=["GET"])
def home():
    """Render ``index.html`` with the eight most recently generated images.

    Each entry of ``recent_images`` is a row tuple in table column order:
    ``(id, prompt, filepath, created_at)``.
    """
    conn = sqlite3.connect('images.db')
    try:
        # created_at has one-second resolution, so break ties on id to keep
        # the ordering deterministic for images created in the same second.
        recent_images = conn.execute(
            "SELECT * FROM images ORDER BY created_at DESC, id DESC LIMIT 8"
        ).fetchall()
    finally:
        conn.close()
    return render_template("index.html", recent_images=recent_images)
def _parse_size(size):
    """Parse a ``"WIDTHxHEIGHT"`` string into a ``(width, height)`` pair of positive ints.

    Raises:
        ValueError: if *size* is not two ``x``-separated positive integers.
    """
    try:
        width, height = (int(part) for part in size.split("x"))
    except ValueError:
        raise ValueError(f"Invalid size {size!r}; expected WIDTHxHEIGHT, e.g. 512x512") from None
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid size {size!r}; width and height must be positive")
    return width, height


def _form_number(field, default, cast):
    """Return form field *field* converted with *cast*, or *default* if absent or blank.

    Raises:
        ValueError: if a non-blank submitted value cannot be converted.
    """
    raw = request.form.get(field)
    if raw is None or not raw.strip():
        return default
    try:
        return cast(raw)
    except ValueError:
        raise ValueError(f"Invalid value for {field}: {raw!r}") from None


def _api_error_message(response, default):
    """Return the ``error`` field of a JSON-object response body, else *default*."""
    try:
        data = response.json()
    except ValueError:  # body is not JSON; requests' JSONDecodeError subclasses ValueError
        return default
    if isinstance(data, dict):
        return data.get("error", default)
    return default


def _save_generated_image(image):
    """Save *image* as PNG under ``<static>/generated``; return its static-relative path."""
    import uuid  # local import: only used to make filenames unique

    img_dir = os.path.join(app.static_folder, "generated")
    os.makedirs(img_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The random suffix stops two requests finishing in the same second from
    # overwriting each other's file (the timestamp alone has 1 s resolution).
    filename = f"image_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    image.save(os.path.join(img_dir, filename), format="PNG")
    return f"generated/{filename}"


def _record_generated_image(prompt, relpath):
    """Insert a row for a generated image into the ``images`` table of ``images.db``."""
    conn = sqlite3.connect("images.db")
    try:
        conn.execute("INSERT INTO images (prompt, filepath) VALUES (?, ?)", (prompt, relpath))
        conn.commit()
    finally:
        conn.close()


@app.route("/generate", methods=["POST"])
def generate():
    """Generate an image from the submitted form and return its URL as JSON.

    Form fields:
        prompt: required text prompt.
        size: ``WIDTHxHEIGHT`` (default ``512x512``).
        num_images: clamped to 1..4 (default 1).
        num_inference_steps: int (default 30).
        guidance_scale: float (default 7.5).
        negative_prompt: text (default a generic quality negative prompt).

    Returns:
        ``{"success": True, "image_url": ..., "prompt": ...}`` on success.
        On failure, ``({"error": message}, status)``: 400 for invalid input,
        the mapped or forwarded upstream status for API failures, 504 on a
        network timeout, and 500 otherwise.
    """
    input_prompt = request.form.get("prompt")
    if not input_prompt:
        return jsonify({"error": "No prompt provided"}), 400
    negative_prompt = request.form.get("negative_prompt", "blurry, bad quality, worst quality, low quality")

    # Validate numeric inputs up front so malformed values yield a JSON 400
    # instead of an unhandled exception.
    try:
        width, height = _parse_size(request.form.get("size", "512x512"))
        num_images = max(1, min(_form_number("num_images", 1, int), 4))
        num_inference_steps = _form_number("num_inference_steps", 30, int)
        guidance_scale = _form_number("guidance_scale", 7.5, float)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    payload = {
        "inputs": input_prompt,
        "parameters": {
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "negative_prompt": negative_prompt,
            # NOTE(review): only a single image is decoded from the response
            # body below, so num_outputs > 1 is not actually surfaced.
            "num_outputs": num_images,
            "return_full_object": False,
        },
    }

    try:
        try:
            response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=180)
        except requests.exceptions.Timeout:
            return jsonify({"error": "The request timed out. The server might be busy, please try again."}), 504

        # Map well-known upstream failures to friendly messages.
        if response.status_code == 401:
            return jsonify({"error": "API Authorization failed. Please check your API token."}), 401
        elif response.status_code == 403:
            return jsonify({"error": "API access forbidden. Please verify your API token has correct permissions."}), 403
        elif response.status_code == 504:
            return jsonify({"error": "The request timed out. The server might be busy, please try again."}), 504
        elif response.status_code != 200:
            error_msg = _api_error_message(response, f"API Error: {response.status_code}")
            return jsonify({"error": error_msg}), response.status_code

        image_bytes = response.content
        if not image_bytes:
            return jsonify({"error": "Received empty response from API"}), 500

        try:
            image = Image.open(io.BytesIO(image_bytes))
        except UnidentifiedImageError:
            # A 200 body that is not an image may still be a JSON error object.
            return jsonify({"error": _api_error_message(response, "Invalid image data received")}), 500

        relpath = _save_generated_image(image)
        _record_generated_image(input_prompt, relpath)

        return jsonify({
            "success": True,
            "image_url": url_for("static", filename=relpath),
            "prompt": input_prompt,
        })
    except Exception as e:
        # Top-level boundary for this request: log with traceback, report as JSON.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # NOTE(review): init_db() only runs here, so when the app is served by a
    # separate WSGI server the images table must be created some other way.
    init_db()
    # Werkzeug's interactive debugger can execute arbitrary code, so it must
    # never be on by default for a server listening on all interfaces.
    # Opt in for local development with FLASK_DEBUG=1.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug, host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))