Allow split T5 & CLIP prompts for flux & add a separate T5 token counter #1906

Open · wants to merge 1 commit into base: main
7 changes: 6 additions & 1 deletion backend/diffusion_engine/base.py
@@ -51,7 +51,12 @@ def decode_first_stage(self, x):
         pass
 
     def get_prompt_lengths_on_ui(self, prompt):
-        return 0, 75
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': 0,
+            'clip_max': 75,
+        }
 
     def is_webui_legacy_model(self):
         return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3
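The base-class default above sets the contract for every engine: `get_prompt_lengths_on_ui` now returns a dict with `t5`/`t5_max`/`clip`/`clip_max` keys instead of the old `(token_count, max_length)` tuple. A minimal sketch of a consumer of that shape (the `render_counters` helper is hypothetical, for illustration only):

```python
def render_counters(lengths: dict) -> tuple[str, str]:
    # Format "current/max" strings for the T5 and CLIP counters from the
    # dict returned by get_prompt_lengths_on_ui().
    t5 = f"{lengths['t5']}/{lengths['t5_max']}"
    clip = f"{lengths['clip']}/{lengths['clip_max']}"
    return t5, clip


# The base-class default describes an empty prompt against both limits:
assert render_counters({'t5': 0, 't5_max': 255, 'clip': 0, 'clip_max': 75}) == ('0/255', '0/75')
```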
38 changes: 34 additions & 4 deletions backend/diffusion_engine/flux.py
@@ -74,9 +74,22 @@ def set_clip_skip(self, clip_skip):
 
     @torch.inference_mode()
     def get_learned_conditioning(self, prompt: list[str]):
+        prompt_t5 = []
+        prompt_l = []
+
+        for p in prompt:
+            if 'SPLIT' in p:
+                before_split, after_split = p.split('SPLIT', 1)
+                prompt_t5.append(before_split.strip())
+                prompt_l.append(after_split.strip())
+            else:
+                prompt_t5.append(p)
+                prompt_l.append(p)
+
         memory_management.load_model_gpu(self.forge_objects.clip.patcher)
-        cond_l, pooled_l = self.text_processing_engine_l(prompt)
-        cond_t5 = self.text_processing_engine_t5(prompt)
+
+        cond_l, pooled_l = self.text_processing_engine_l(prompt_l)
+        cond_t5 = self.text_processing_engine_t5(prompt_t5)
         cond = dict(crossattn=cond_t5, vector=pooled_l)
 
         if self.use_distilled_cfg_scale:
@@ -90,8 +103,25 @@ def get_learned_conditioning(self, prompt: list[str]):
 
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
-        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
-        return token_count, max(255, token_count)
+        if 'SPLIT' in prompt:
+            prompt_t5, prompt_l = prompt.split('SPLIT', 1)
+            prompt_t5 = prompt_t5.strip()
+            prompt_l = prompt_l.strip()
+        else:
+            prompt_t5 = prompt_l = prompt
+
+        t5_token_count = len(self.text_processing_engine_t5.tokenize([prompt_t5])[0])
+        t5_max_length = max(255, t5_token_count)
+
+        _, clip_token_count = self.text_processing_engine_l.process_texts([prompt_l])
+        clip_max_length = self.text_processing_engine_l.get_target_prompt_token_count(clip_token_count)
+
+        return {
+            't5': t5_token_count,
+            't5_max': t5_max_length,
+            'clip': clip_token_count,
+            'clip_max': clip_max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
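This file is the heart of the PR: a literal `SPLIT` keyword in a Flux prompt routes the text before it to the T5 encoder and the text after it to CLIP-L; prompts without the keyword feed both encoders identically, as before. A standalone sketch of that routing rule (the `split_prompt` name is hypothetical; the logic mirrors the diff above):

```python
def split_prompt(p: str) -> tuple[str, str]:
    # Only the first occurrence of SPLIT is significant, matching
    # p.split('SPLIT', 1) in the engine code above.
    if 'SPLIT' in p:
        before_split, after_split = p.split('SPLIT', 1)
        return before_split.strip(), after_split.strip()
    return p, p


assert split_prompt('a quiet valley at dawn SPLIT oil painting') == ('a quiet valley at dawn', 'oil painting')
assert split_prompt('a quiet valley at dawn') == ('a quiet valley at dawn', 'a quiet valley at dawn')
```

Note that `get_prompt_lengths_on_ui` applies the same split, so the two UI counters measure exactly what each encoder will receive.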
9 changes: 8 additions & 1 deletion backend/diffusion_engine/sd15.py
@@ -66,7 +66,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine.process_texts([prompt])
-        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
9 changes: 8 additions & 1 deletion backend/diffusion_engine/sd20.py
@@ -66,7 +66,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine.process_texts([prompt])
-        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
9 changes: 8 additions & 1 deletion backend/diffusion_engine/sdxl.py
@@ -118,7 +118,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine_l.process_texts([prompt])
-        return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
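The sd15, sd20, and sdxl engines receive the same mechanical change: their dicts carry fixed `t5`/`t5_max` placeholders so that every engine reports one schema and the UI never has to special-case the model family. An illustrative comparison (the token counts are made up for the example):

```python
# Same ten-token prompt reported by a legacy engine and by flux:
sdxl_lengths = {'t5': 0, 't5_max': 255, 'clip': 10, 'clip_max': 75}   # T5 fields are inert placeholders
flux_lengths = {'t5': 12, 't5_max': 255, 'clip': 10, 'clip_max': 75}  # both encoders are live
```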
61 changes: 51 additions & 10 deletions javascript/token-counters.js
@@ -43,12 +43,12 @@ function recalculate_prompts_img2img() {
     return Array.from(arguments);
 }
 
-function setupTokenCounting(id, id_counter, id_button) {
+function setupSingleTokenCounting(id, id_counter, id_button) {
     var prompt = gradioApp().getElementById(id);
     var counter = gradioApp().getElementById(id_counter);
     var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
 
-    if (counter.parentElement == prompt.parentElement) {
+    if (counter.parentElement === prompt.parentElement) {
         return;
     }
 
@@ -64,24 +64,65 @@ function setupTokenCounting(id, id_counter, id_button) {
     promptTokenCountUpdateFunctions[id_button] = func;
 }
 
-function toggleTokenCountingVisibility(id, id_counter, id_button) {
+function setupDualTokenCounting(id, id_t5_counter, id_clip_counter, id_button) {
+    var prompt = gradioApp().getElementById(id);
+    var t5_counter = gradioApp().getElementById(id_t5_counter);
+    var clip_counter = gradioApp().getElementById(id_clip_counter);
+    var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
+
+    if (t5_counter.parentElement === prompt.parentElement && clip_counter.parentElement === prompt.parentElement) {
+        return;
+    }
+
+    prompt.parentElement.insertBefore(t5_counter, prompt);
+    prompt.parentElement.insertBefore(clip_counter, prompt);
+    prompt.parentElement.style.position = "relative";
+
+    var func = onEdit(id, textarea, 800, function() {
+        if (t5_counter.classList.contains("token-counter-visible") || clip_counter.classList.contains("token-counter-visible")) {
+            gradioApp().getElementById(id_button)?.click();
+        }
+    });
+    promptTokenCountUpdateFunctions[id] = func;
+    promptTokenCountUpdateFunctions[id_button] = func;
+}
+
+function toggleSingleTokenCountingVisibility(id, id_counter, id_button) {
     var counter = gradioApp().getElementById(id_counter);
+    var shouldDisplay = !opts.disable_token_counters;
 
-    counter.style.display = opts.disable_token_counters ? "none" : "block";
-    counter.classList.toggle("token-counter-visible", !opts.disable_token_counters);
+    counter.style.display = shouldDisplay ? "block" : "none";
+    counter.classList.toggle("token-counter-visible", shouldDisplay);
 }
 
-function runCodeForTokenCounters(fun) {
-    fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+function toggleDualTokenCountingVisibility(id, id_t5_counter, id_clip_counter, id_button) {
+    var t5_counter = gradioApp().getElementById(id_t5_counter);
+    var clip_counter = gradioApp().getElementById(id_clip_counter);
+    var shouldDisplay = !opts.disable_token_counters;
+
+    t5_counter.style.display = shouldDisplay ? "block" : "none";
+    clip_counter.style.display = shouldDisplay ? "block" : "none";
+
+    t5_counter.classList.toggle("token-counter-visible", shouldDisplay);
+    clip_counter.classList.toggle("token-counter-visible", shouldDisplay);
+}
+
+function runCodeForSingleTokenCounters(fun) {
     fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
-    fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
     fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
 }
 
+function runCodeForDualTokenCounters(fun) {
+    fun('txt2img_prompt', 'txt2img_t5_token_counter', 'txt2img_token_counter', 'txt2img_token_button');
+    fun('img2img_prompt', 'img2img_t5_token_counter', 'img2img_token_counter', 'img2img_token_button');
+}
+
 onUiLoaded(function() {
-    runCodeForTokenCounters(setupTokenCounting);
+    runCodeForSingleTokenCounters(setupSingleTokenCounting);
+    runCodeForDualTokenCounters(setupDualTokenCounting);
 });
 
 onOptionsChanged(function() {
-    runCodeForTokenCounters(toggleTokenCountingVisibility);
+    runCodeForSingleTokenCounters(toggleSingleTokenCountingVisibility);
+    runCodeForDualTokenCounters(toggleDualTokenCountingVisibility);
 });
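In summary, the wiring above gives each positive prompt a T5 + CLIP counter pair driven by one shared refresh button, while negative prompts keep their single CLIP counter. The same element-ID mapping, rendered as a Python reference table (IDs copied from the calls above; the dicts themselves are illustrative, not part of the code):

```python
DUAL_COUNTERS = {    # prompt box -> (T5 counter, CLIP counter, refresh button)
    'txt2img_prompt': ('txt2img_t5_token_counter', 'txt2img_token_counter', 'txt2img_token_button'),
    'img2img_prompt': ('img2img_t5_token_counter', 'img2img_token_counter', 'img2img_token_button'),
}
SINGLE_COUNTERS = {  # prompt box -> (CLIP counter, refresh button)
    'txt2img_neg_prompt': ('txt2img_negative_token_counter', 'txt2img_negative_token_button'),
    'img2img_neg_prompt': ('img2img_negative_token_counter', 'img2img_negative_token_button'),
}
```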
74 changes: 59 additions & 15 deletions modules/ui.py
@@ -162,6 +162,10 @@ def connect_clear_prompt(button):
 )
 
 
+def wrap_counter_value(value):
+    return f"<span class='gr-box gr-text-input'>{value}</span>"
+
+
 def update_token_counter(text, steps, styles, *, is_positive=True):
     params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
     script_callbacks.before_token_counter_callback(params)
@@ -171,7 +175,11 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
     is_positive = params.is_positive
 
     if shared.opts.include_styles_into_token_counters:
-        apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
+        apply_styles = (
+            shared.prompt_styles.apply_styles_to_prompt
+            if is_positive
+            else shared.prompt_styles.apply_negative_styles_to_prompt
+        )
         text = apply_styles(text, styles)
 
     try:
@@ -184,21 +192,57 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
 
         prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
 
     except Exception:
         # a parsing error can happen here during typing, and we don't want to bother the user with
         # messages related to it in console
         prompt_schedules = [[[steps, text]]]
 
     try:
         get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
         assert get_prompt_lengths_on_ui is not None
 
     except Exception:
-        return f"<span class='gr-box gr-text-input'>?/?</span>"
+        counter_value = wrap_counter_value('?/?')
+        if is_positive:
+            counter_value = [counter_value] * 2
+
+        return counter_value
 
-    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+    flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
-    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
+    counts = [get_prompt_lengths_on_ui(prompt) for prompt in prompts]
+
+    clip_token_count = clip_max_length = t5_token_count = t5_max_length = None
+
+    if is_positive:
+        for count in counts:
+            if 'clip' in count and count['clip'] is not None:
+                clip_token_count = max(clip_token_count or 0, count['clip'])
+                clip_max_length = count['clip_max']
+            if 't5' in count and count['t5'] is not None:
+                t5_token_count = max(t5_token_count or 0, count['t5'])
+                t5_max_length = count['t5_max']
+
+        clip_counter_text = (
+            wrap_counter_value(f"{clip_token_count}/{clip_max_length}")
+            if clip_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+        t5_counter_text = (
+            wrap_counter_value(f"{t5_token_count}/{t5_max_length}")
+            if t5_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+        counter_value = [t5_counter_text, clip_counter_text]
+    else:
+        for count in counts:
+            if 'clip' in count and count['clip'] is not None:
+                clip_token_count = max(clip_token_count or 0, count['clip'])
+                clip_max_length = count['clip_max']
+
+        counter_value = (
+            wrap_counter_value(f"{clip_token_count}/{clip_max_length}")
+            if clip_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+    return counter_value
 
 
 def update_negative_prompt_token_counter(*args):
@@ -504,9 +548,9 @@ def create_ui():
 
         steps = scripts.scripts_txt2img.script('Sampler').steps
 
-        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
-        toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
 
         extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
@@ -851,9 +895,9 @@ def select_img2img_tab(tab):
 
         steps = scripts.scripts_img2img.script('Sampler').steps
 
-        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
-        toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
 
         img2img_paste_fields = [
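The Gradio wiring above relies on a return-shape contract: for positive prompts, `update_token_counter` yields two HTML snippets (T5 first, then CLIP) to fill the two-element `outputs` lists, while negatives get a single snippet. A condensed, self-contained sketch of that contract (simplified from the diff; `format_counters` is a hypothetical distillation, and `counts` is the list of dicts the engines return):

```python
def wrap_counter_value(value):
    return f"<span class='gr-box gr-text-input'>{value}</span>"


def format_counters(counts, is_positive):
    def worst_case(key):
        # Pick the highest count across prompt-schedule variants, as the
        # loops in update_token_counter do.
        pairs = [(c[key], c[key + '_max']) for c in counts if c.get(key) is not None]
        return max(pairs) if pairs else None

    clip = worst_case('clip')
    clip_text = wrap_counter_value(f"{clip[0]}/{clip[1]}" if clip else '-/-')
    if not is_positive:
        return clip_text                      # one output: the CLIP counter

    t5 = worst_case('t5')
    t5_text = wrap_counter_value(f"{t5[0]}/{t5[1]}" if t5 else '-/-')
    return [t5_text, clip_text]               # two outputs: T5, then CLIP


# A flux-style positive prompt fills both counters:
assert format_counters([{'t5': 12, 't5_max': 255, 'clip': 10, 'clip_max': 75}], True) == [
    "<span class='gr-box gr-text-input'>12/255</span>",
    "<span class='gr-box gr-text-input'>10/75</span>",
]
```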
4 changes: 3 additions & 1 deletion modules/ui_toprow.py
@@ -26,6 +26,7 @@ class Toprow:
     apply_styles = None
     restore_progress_button = None
 
+    t5_token_counter = None
     token_counter = None
     token_button = None
     negative_token_counter = None
@@ -127,9 +128,10 @@ def create_tools_row(self):
 
         self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
 
+        self.t5_token_counter = gr.HTML(value="<span>0/255</span>", elem_id=f"{self.id_part}_t5_token_counter", elem_classes=["t5-token-counter"], visible=False)
         self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"], visible=False)
         self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
-        self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"], visible=False)
+        self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["negative-token-counter"], visible=False)
        self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
 
        self.clear_prompt_button.click(