-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprompt_blending.py
183 lines (143 loc) · 5.24 KB
/
prompt_blending.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import modules.scripts as scripts
import modules.prompt_parser as prompt_parser
import itertools
import torch
def hijacked_get_learned_conditioning(model, prompts, steps):
    """Drop-in replacement for ``prompt_parser.get_learned_conditioning`` adding prompt blending.

    The first time a given ``model`` is seen, its ``get_learned_conditioning`` method is
    wrapped: every text is expanded by ``get_weighted_prompt`` into weighted variants,
    all variants are encoded in one call, and each text's conditioning becomes the
    weighted sum of its variants' conditionings. The model is tagged with a
    ``__hacked`` attribute so the wrapping happens only once.

    Each prompt is then passed through ``switch_syntax`` (rewriting the ``:`` weight
    marker inside ``{...}`` to ``@``) and the call is forwarded to the saved original
    ``real_get_learned_conditioning``.
    """
    if not hasattr(model, '__hacked'):
        original_encode = model.get_learned_conditioning

        def blended_encode(texts):
            expansions = [get_weighted_prompt((text, 1)) for text in texts]
            flat_texts = [variant for expansion in expansions for variant, _ in expansion]

            # No text expanded into more than one variant: encode the inputs unchanged.
            if len(flat_texts) <= len(texts):
                return original_encode(texts)

            encoded = original_encode(flat_texts)
            blended = []
            cursor = 0  # index of the current text's first variant in `encoded`
            for expansion in expansions:
                total = torch.zeros_like(encoded[cursor])
                for j, (_, variant_weight) in enumerate(expansion):
                    # total += variant_weight * encoded[cursor + j]
                    total = torch.add(total, encoded[cursor + j], alpha=variant_weight)
                blended.append(total)
                cursor += len(expansion)
            # NOTE(review): this path returns a Python list of tensors, while the
            # unblended path returns whatever `original_encode` returns — confirm the
            # caller accepts both.
            return blended

        model.get_learned_conditioning = blended_encode
        model.__hacked = True

    rewritten = [switch_syntax(prompt) for prompt in prompts]
    return real_get_learned_conditioning(model, rewritten, steps)
# Handle to the original ``prompt_parser.get_learned_conditioning``.
# It initially points at the hook itself, which acts as a "not yet installed" sentinel:
# ``Script.show`` compares against this value, saves the real parser function here and
# installs the hook in its place. ``hijacked_get_learned_conditioning`` forwards to it.
real_get_learned_conditioning = hijacked_get_learned_conditioning  # no really, overridden below
class Script(scripts.Script):
    """Webui script whose only job is to install the prompt-blending hook.

    ``show`` doubles as the installation point. The first time it runs it saves
    ``prompt_parser.get_learned_conditioning`` into ``real_get_learned_conditioning``
    and replaces it with ``hijacked_get_learned_conditioning``. It always returns False
    (presumably hiding the script from the UI — confirm against the webui Script API).
    """

    def title(self):
        return "Prompt Blending"

    def show(self, is_img2img):
        global real_get_learned_conditioning
        # Until the hook is installed, the saved handle still points at the hook itself.
        hook_pending = real_get_learned_conditioning is hijacked_get_learned_conditioning
        if hook_pending:
            real_get_learned_conditioning = prompt_parser.get_learned_conditioning
            prompt_parser.get_learned_conditioning = hijacked_get_learned_conditioning
        return False

    def ui(self, is_img2img):
        # No UI controls.
        return []

    def run(self, p, seeds):
        # Intentionally a no-op: blending happens inside the hooked parser function.
        return None
# Blending grammar parsed by get_weighted_prompt:  text {alt1&alt2@weight&...} text
OPEN = '{'
CLOSE = '}'
SEPARATE = '&'
MARK = '@'
# Weight marker as typed by users; switch_syntax rewrites it to MARK inside {...}.
REAL_MARK = ':'


def combine(left, right):
    """Return the cartesian product of two ``(text, weight)`` lists.

    Each pair ``(lt, lw)``, ``(rt, rw)`` becomes ``(lt + rt, lw * rw)``, in
    ``itertools.product(left, right)`` order. The result is a lazy iterator.
    """
    return ((lt + rt, lw * rw) for (lt, lw), (rt, rw) in itertools.product(left, right))


def get_weighted_prompt(prompt_weight):
    """Expand ``{a&b@w&...}`` groups in a prompt into a list of weighted plain prompts.

    Each top-level group is replaced in turn by each of its alternatives; the cartesian
    product over all groups gives the expanded prompts, and a prompt's weight is the
    product of the weights of the alternatives it uses, normalized afterwards.

    - An alternative may end in ``@<number>`` to set its weight (default 1). A weight
      that is not a number prints a warning and stays in the text verbatim.
    - Alternatives containing groups of their own are expanded recursively.
    - A group yielding fewer than two alternatives is kept verbatim, braces included.
    - A group still open at the end of the prompt is terminated there; its last
      character is kept as part of the final alternative.

    Args:
        prompt_weight: ``(prompt, full_weight)`` tuple.

    Returns:
        List of ``(text, weight)`` tuples whose weights sum to ``full_weight``. If the
        raw weights sum to zero (e.g. ``{a@1&b@-1}``, or ``full_weight`` is 0) they
        cannot be normalized and are returned unnormalized (already scaled by
        ``full_weight``) instead of raising ``ZeroDivisionError``.
    """
    (prompt, full_weight) = prompt_weight
    results = [('', full_weight)]  # expansions of everything parsed so far
    alts = []                      # alternatives of the top-level group being parsed
    start = 0                      # index where the pending text/alternative begins
    mark = -1                      # index of the top-level MARK in this alternative
    open_count = 0                 # current brace depth
    first_open = 0                 # index of the OPEN starting the current group
    nested = False                 # current alternative contains a nested group
    last = len(prompt) - 1

    for i, c in enumerate(prompt):
        add_alt = False
        do_combine = False
        is_separator = False
        if c == OPEN:
            open_count += 1
            if open_count == 1:
                # Entering a top-level group: flush the literal text before it.
                first_open = i
                results = list(combine(results, [(prompt[start:i], 1)]))
                start = i + 1
            else:
                nested = True

        if c == MARK and open_count == 1:
            mark = i

        if c == SEPARATE and open_count == 1:
            add_alt = True
            is_separator = True

        if c == CLOSE:
            open_count -= 1
            if open_count == 0:
                add_alt = True
                do_combine = True

        # A group still open at the last character is terminated here.
        unclosed_end = i == last and open_count > 0
        if unclosed_end:
            add_alt = True
            do_combine = True

        if add_alt:
            # Normally the current character (separator or closing brace) ends the
            # alternative and is excluded. At an unclosed end it is ordinary text
            # (unless it is itself a top-level separator), so it must be included.
            end = i + 1 if unclosed_end and not is_separator else i
            weight = 1
            if mark != -1:
                weight_str = prompt[mark + 1:end]
                try:
                    weight = float(weight_str)
                    end = mark
                except ValueError:
                    print("warning, not a number:", weight_str)

            alt = (prompt[start:end], weight)
            alts += get_weighted_prompt(alt) if nested else [alt]
            nested = False
            mark = -1
            start = i + 1

        if do_combine:
            if len(alts) <= 1:
                # Not a real choice: keep the whole group as literal text.
                alts = [(prompt[first_open:i + 1], 1)]
            results = list(combine(results, alts))
            alts = []

    # Append the literal text after the last group.
    results = list(combine(results, [(prompt[start:], 1)]))
    weight_sum = sum(w for _, w in results)
    if weight_sum == 0:
        # Normalization is undefined. The weights already include full_weight
        # (results started from ('', full_weight)), so return them as they are.
        if full_weight != 0:
            print("warning, prompt weights sum to zero, not normalizing:", prompt)
        return results
    return [(text, w / weight_sum * full_weight) for text, w in results]
def switch_syntax(prompt):
    """Rewrite ``REAL_MARK`` to ``MARK`` wherever the innermost open bracket is ``{``.

    Nesting of ``{``, ``[`` and ``(`` is tracked with a stack, so a ``REAL_MARK``
    inside ``(...)`` or ``[...]`` is left untouched. A closing bracket pops the
    innermost opener without checking that the two match, and unmatched closers
    are ignored.
    """
    chars = list(prompt)
    open_brackets = []
    for idx, ch in enumerate(chars):
        if ch in '{[(':
            open_brackets.append(ch)
        if not open_brackets:
            continue
        if ch in '}])':
            open_brackets.pop()
        elif ch == REAL_MARK and open_brackets[-1] == '{':
            chars[idx] = MARK
    return "".join(chars)
# def test(p, w=1):
# print('')
# print(p)
# result = get_weighted_prompt((p, w))
# print(result)
# print(sum(map(lambda x: x[1], result)))
#
#
# test("fantasy landscape")
# test("fantasy {landscape|city}, dark")
# test("fantasy {landscape|city}, {fire|ice} ")
# test("fantasy {landscape|city}, {fire|ice}, {dark|light} ")
# test("fantasy landscape, {{fire|lava}|ice}")
# test("fantasy landscape, {{fire@4|lava@1}|ice@2}")
# test("fantasy landscape, {{fire@error|lava@1}|ice@2}")
# test("fantasy landscape, {{fire|lava}|ice@2")
# test("fantasy landscape, {fire|lava} {cool} {ice,water}")
# test("fantasy landscape, {fire|lava} {cool} {ice,water")
# test("{lava|ice|water@5}")
# test("{fire@4|lava@1}", 5)
# test("{{fire@4|lava@1}|ice@2|water@5}")
# test("{fire|[email protected]}")