From be3c51856ab9d71d98f9d2cd6cdae19ec6fb0cc0 Mon Sep 17 00:00:00 2001
From: William Bradford Clark
Date: Tue, 24 Sep 2024 10:42:47 -0400
Subject: [PATCH] Allow split T5 & CLIP prompts for flux & add a separate T5
 token counter

---
 backend/diffusion_engine/base.py |  7 ++-
 backend/diffusion_engine/flux.py | 38 ++++++++++++++--
 backend/diffusion_engine/sd15.py |  9 +++-
 backend/diffusion_engine/sd20.py |  9 +++-
 backend/diffusion_engine/sdxl.py |  9 +++-
 javascript/token-counters.js     | 61 +++++++++++++++++++++-----
 modules/ui.py                    | 74 +++++++++++++++++++++++++-------
 modules/ui_toprow.py             |  4 +-
 style.css                        | 68 +++++++++++++++++++++++++++++
 9 files changed, 245 insertions(+), 34 deletions(-)

diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
index 89a6055e8..d867544a3 100644
--- a/backend/diffusion_engine/base.py
+++ b/backend/diffusion_engine/base.py
@@ -51,7 +51,12 @@ def decode_first_stage(self, x):
         pass
 
     def get_prompt_lengths_on_ui(self, prompt):
-        return 0, 75
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': 0,
+            'clip_max': 75,
+        }
 
     def is_webui_legacy_model(self):
         return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3
diff --git a/backend/diffusion_engine/flux.py b/backend/diffusion_engine/flux.py
index 8d4589c20..94b88dc67 100644
--- a/backend/diffusion_engine/flux.py
+++ b/backend/diffusion_engine/flux.py
@@ -74,9 +74,22 @@ def set_clip_skip(self, clip_skip):
 
     @torch.inference_mode()
     def get_learned_conditioning(self, prompt: list[str]):
+        prompt_t5 = []
+        prompt_l = []
+
+        for p in prompt:
+            if 'SPLIT' in p:
+                before_split, after_split = p.split('SPLIT', 1)
+                prompt_t5.append(before_split.strip())
+                prompt_l.append(after_split.strip())
+            else:
+                prompt_t5.append(p)
+                prompt_l.append(p)
+
         memory_management.load_model_gpu(self.forge_objects.clip.patcher)
-        cond_l, pooled_l = self.text_processing_engine_l(prompt)
-        cond_t5 = self.text_processing_engine_t5(prompt)
+
+        cond_l, pooled_l = self.text_processing_engine_l(prompt_l)
+        cond_t5 = self.text_processing_engine_t5(prompt_t5)
         cond = dict(crossattn=cond_t5, vector=pooled_l)
 
         if self.use_distilled_cfg_scale:
@@ -90,8 +103,25 @@ def get_learned_conditioning(self, prompt: list[str]):
 
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
-        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
-        return token_count, max(255, token_count)
+        if 'SPLIT' in prompt:
+            prompt_t5, prompt_l = prompt.split('SPLIT', 1)
+            prompt_t5 = prompt_t5.strip()
+            prompt_l = prompt_l.strip()
+        else:
+            prompt_t5 = prompt_l = prompt
+
+        t5_token_count = len(self.text_processing_engine_t5.tokenize([prompt_t5])[0])
+        t5_max_length = max(255, t5_token_count)
+
+        _, clip_token_count = self.text_processing_engine_l.process_texts([prompt_l])
+        clip_max_length = self.text_processing_engine_l.get_target_prompt_token_count(clip_token_count)
+
+        return {
+            't5': t5_token_count,
+            't5_max': t5_max_length,
+            'clip': clip_token_count,
+            'clip_max': clip_max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py
index af47eb53c..d8832f4a2 100644
--- a/backend/diffusion_engine/sd15.py
+++ b/backend/diffusion_engine/sd15.py
@@ -66,7 +66,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine.process_texts([prompt])
-        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
diff --git a/backend/diffusion_engine/sd20.py b/backend/diffusion_engine/sd20.py
index adb69528c..3df31405f 100644
--- a/backend/diffusion_engine/sd20.py
+++ b/backend/diffusion_engine/sd20.py
@@ -66,7 +66,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine.process_texts([prompt])
-        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index 0873da189..8a7895ec0 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -118,7 +118,14 @@ def get_learned_conditioning(self, prompt: list[str]):
     @torch.inference_mode()
     def get_prompt_lengths_on_ui(self, prompt):
         _, token_count = self.text_processing_engine_l.process_texts([prompt])
-        return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+        max_length = self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+
+        return {
+            't5': 0,
+            't5_max': 255,
+            'clip': token_count,
+            'clip_max': max_length,
+        }
 
     @torch.inference_mode()
     def encode_first_stage(self, x):
diff --git a/javascript/token-counters.js b/javascript/token-counters.js
index eeea7a5d2..aafe20664 100644
--- a/javascript/token-counters.js
+++ b/javascript/token-counters.js
@@ -43,12 +43,12 @@ function recalculate_prompts_img2img() {
     return Array.from(arguments);
 }
 
-function setupTokenCounting(id, id_counter, id_button) {
+function setupSingleTokenCounting(id, id_counter, id_button) {
    var prompt = gradioApp().getElementById(id);
    var counter = gradioApp().getElementById(id_counter);
    var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
 
-    if (counter.parentElement == prompt.parentElement) {
+    if (counter.parentElement === prompt.parentElement) {
        return;
    }
 
@@ -64,24 +64,65 @@ function setupTokenCounting(id, id_counter, id_button) {
     promptTokenCountUpdateFunctions[id_button] = func;
 }
 
-function toggleTokenCountingVisibility(id, id_counter, id_button) {
+function setupDualTokenCounting(id, id_t5_counter, id_clip_counter, id_button) {
+    var prompt = gradioApp().getElementById(id);
+    var t5_counter = gradioApp().getElementById(id_t5_counter);
+    var clip_counter = gradioApp().getElementById(id_clip_counter);
+    var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
+
+    if (t5_counter.parentElement === prompt.parentElement && clip_counter.parentElement === prompt.parentElement) {
+        return;
+    }
+
+    prompt.parentElement.insertBefore(t5_counter, prompt);
+    prompt.parentElement.insertBefore(clip_counter, prompt);
+    prompt.parentElement.style.position = "relative";
+
+    var func = onEdit(id, textarea, 800, function() {
+        if (t5_counter.classList.contains("token-counter-visible") || clip_counter.classList.contains("token-counter-visible")) {
+            gradioApp().getElementById(id_button)?.click();
+        }
+    });
+    promptTokenCountUpdateFunctions[id] = func;
+    promptTokenCountUpdateFunctions[id_button] = func;
+}
+
+function toggleSingleTokenCountingVisibility(id, id_counter, id_button) {
     var counter = gradioApp().getElementById(id_counter);
+    var shouldDisplay = !opts.disable_token_counters;
 
-    counter.style.display = opts.disable_token_counters ? "none" : "block";
-    counter.classList.toggle("token-counter-visible", !opts.disable_token_counters);
+    counter.style.display = shouldDisplay ? "block" : "none";
+    counter.classList.toggle("token-counter-visible", shouldDisplay);
 }
 
-function runCodeForTokenCounters(fun) {
-    fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+function toggleDualTokenCountingVisibility(id, id_t5_counter, id_clip_counter, id_button) {
+    var t5_counter = gradioApp().getElementById(id_t5_counter);
+    var clip_counter = gradioApp().getElementById(id_clip_counter);
+    var shouldDisplay = !opts.disable_token_counters;
+
+    t5_counter.style.display = shouldDisplay ? "block" : "none";
+    clip_counter.style.display = shouldDisplay ? "block" : "none";
+
+    t5_counter.classList.toggle("token-counter-visible", shouldDisplay);
+    clip_counter.classList.toggle("token-counter-visible", shouldDisplay);
+}
+
+function runCodeForSingleTokenCounters(fun) {
     fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
-    fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
     fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
 }
 
+function runCodeForDualTokenCounters(fun) {
+    fun('txt2img_prompt', 'txt2img_t5_token_counter', 'txt2img_token_counter', 'txt2img_token_button');
+    fun('img2img_prompt', 'img2img_t5_token_counter', 'img2img_token_counter', 'img2img_token_button');
+}
+
 onUiLoaded(function() {
-    runCodeForTokenCounters(setupTokenCounting);
+    runCodeForSingleTokenCounters(setupSingleTokenCounting);
+    runCodeForDualTokenCounters(setupDualTokenCounting);
 });
 
 onOptionsChanged(function() {
-    runCodeForTokenCounters(toggleTokenCountingVisibility);
+    runCodeForSingleTokenCounters(toggleSingleTokenCountingVisibility);
+    runCodeForDualTokenCounters(toggleDualTokenCountingVisibility);
 });
diff --git a/modules/ui.py b/modules/ui.py
index f9c7f493c..6ed3191d3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -162,6 +162,10 @@ def connect_clear_prompt(button):
     )
 
 
+def wrap_counter_value(value):
+    return f"{value}"
+
+
 def update_token_counter(text, steps, styles, *, is_positive=True):
     params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
     script_callbacks.before_token_counter_callback(params)
@@ -171,7 +175,11 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
     is_positive = params.is_positive
 
     if shared.opts.include_styles_into_token_counters:
-        apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
+        apply_styles = (
+            shared.prompt_styles.apply_styles_to_prompt
+            if is_positive
+            else shared.prompt_styles.apply_negative_styles_to_prompt
+        )
         text = apply_styles(text, styles)
 
     try:
@@ -184,21 +192,57 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
         _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
         prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
 
-    except Exception:
-        # a parsing error can happen here during typing, and we don't want to bother the user with
-        # messages related to it in console
-        prompt_schedules = [[[steps, text]]]
-
-    try:
         get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
         assert get_prompt_lengths_on_ui is not None
+
     except Exception:
-        return f"?/?"
+        counter_value = wrap_counter_value('?/?')
+        if is_positive:
+            counter_value = [counter_value] * 2
+
+        return counter_value
 
-    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+    flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
-    return f"{token_count}/{max_length}"
+    counts = [get_prompt_lengths_on_ui(prompt) for prompt in prompts]
+
+    clip_token_count = clip_max_length = t5_token_count = t5_max_length = None
+
+    if is_positive:
+        for count in counts:
+            if 'clip' in count and count['clip'] is not None and count['clip'] >= (clip_token_count or 0):
+                clip_token_count = count['clip']
+                clip_max_length = count['clip_max']
+            if 't5' in count and count['t5'] is not None and count['t5'] >= (t5_token_count or 0):
+                t5_token_count = count['t5']
+                t5_max_length = count['t5_max']
+
+        clip_counter_text = (
+            wrap_counter_value(f"{clip_token_count}/{clip_max_length}")
+            if clip_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+        t5_counter_text = (
+            wrap_counter_value(f"{t5_token_count}/{t5_max_length}")
+            if t5_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+        counter_value = [t5_counter_text, clip_counter_text]
+    else:
+        for count in counts:
+            if 'clip' in count and count['clip'] is not None and count['clip'] >= (clip_token_count or 0):
+                clip_token_count = count['clip']
+                clip_max_length = count['clip_max']
+
+        counter_value = (
+            wrap_counter_value(f"{clip_token_count}/{clip_max_length}")
+            if clip_token_count is not None
+            else wrap_counter_value('-/-')
+        )
+
+    return counter_value
 
 
 def update_negative_prompt_token_counter(*args):
@@ -504,9 +548,9 @@ def create_ui():
 
         steps = scripts.scripts_txt2img.script('Sampler').steps
 
-        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
-        toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
 
     extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
@@ -851,9 +895,9 @@ def select_img2img_tab(tab):
 
         steps = scripts.scripts_img2img.script('Sampler').steps
 
-        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
-        toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+        toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.t5_token_counter, toprow.token_counter])
         toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
 
     img2img_paste_fields = [
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
index 622ed5870..b1854443d 100644
--- a/modules/ui_toprow.py
+++ b/modules/ui_toprow.py
@@ -26,6 +26,7 @@ class Toprow:
     apply_styles = None
     restore_progress_button = None
 
+    t5_token_counter = None
     token_counter = None
     token_button = None
     negative_token_counter = None
@@ -127,9 +128,10 @@ def create_tools_row(self):
 
             self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
 
+            self.t5_token_counter = gr.HTML(value="0/255", elem_id=f"{self.id_part}_t5_token_counter", elem_classes=["t5-token-counter"], visible=False)
             self.token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"], visible=False)
             self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
-            self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"], visible=False)
+            self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["negative-token-counter"], visible=False)
             self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
 
         self.clear_prompt_button.click(
diff --git a/style.css b/style.css
index 4bcaa110c..16ab4d956 100644
--- a/style.css
+++ b/style.css
@@ -250,6 +250,74 @@ input[type="checkbox"].input-accordion-checkbox{
     padding: 0.1em 0.75em;
 }
 
+.block.t5-token-counter{
+    position: absolute;
+    display: inline-block;
+    right: 6em;
+    min-width: 0 !important;
+    width: auto;
+    z-index: 100;
+    top: -0.75em;
+}
+
+.block.t5-token-counter-visible{
+    display: block !important;
+}
+
+.block.t5-token-counter span{
+    background: var(--input-background-fill) !important;
+    box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075);
+    border: 2px solid rgba(192,192,192,0.4) !important;
+    border-radius: 0.4em;
+}
+
+.block.t5-token-counter.error span{
+    box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075);
+    border: 2px solid rgba(255,0,0,0.4) !important;
+}
+
+.block.t5-token-counter div{
+    display: inline;
+}
+
+.block.t5-token-counter span{
+    padding: 0.1em 0.75em;
+}
+
+.block.negative-token-counter{
+    position: absolute;
+    display: inline-block;
+    right: 1em;
+    min-width: 0 !important;
+    width: auto;
+    z-index: 100;
+    top: -0.75em;
+}
+
+.block.negative-token-counter-visible{
+    display: block !important;
+}
+
+.block.negative-token-counter span{
+    background: var(--input-background-fill) !important;
+    box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075);
+    border: 2px solid rgba(192,192,192,0.4) !important;
+    border-radius: 0.4em;
+}
+
+.block.negative-token-counter.error span{
+    box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075);
+    border: 2px solid rgba(255,0,0,0.4) !important;
+}
+
+.block.negative-token-counter div{
+    display: inline;
+}
+
+.block.negative-token-counter span{
+    padding: 0.1em 0.75em;
+}
+
 [id$=_subseed_show]{
     min-width: auto !important;
     flex-grow: 0 !important;
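
Notes for reviewers follow. The Python snippets below are illustrative
sketches, not part of the diff above.

How SPLIT routes a prompt: everything before the first 'SPLIT' keyword
goes to the T5 encoder, everything after it goes to CLIP-L, and a prompt
without the keyword is fed to both encoders unchanged. A minimal
standalone sketch of that rule (the helper name split_prompt is
hypothetical):

    def split_prompt(prompt: str) -> tuple[str, str]:
        # Mirrors the routing in flux.py's get_learned_conditioning and
        # get_prompt_lengths_on_ui: text before the first 'SPLIT' goes
        # to T5, text after it goes to CLIP-L, and whitespace around the
        # keyword is stripped from both halves.
        if 'SPLIT' in prompt:
            t5_part, clip_part = prompt.split('SPLIT', 1)
            return t5_part.strip(), clip_part.strip()
        return prompt, prompt

    assert split_prompt("a red fox SPLIT fox, red fur") == ("a red fox", "fox, red fur")
    assert split_prompt("a red fox") == ("a red fox", "a red fox")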
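Counter contract: every engine's get_prompt_lengths_on_ui now returns
the same four-key dict, so the UI no longer needs to know which text
encoders a model has; single-encoder models simply report a zero T5
count against the default 255 limit. A sketch of the shared shape (the
TypedDict is only for illustration; the patch itself uses plain dicts):

    from typing import TypedDict

    class PromptLengths(TypedDict):
        t5: int        # T5 token count; 0 for models without a T5 encoder
        t5_max: int    # limit shown on the T5 counter, 255 by default
        clip: int      # CLIP token count
        clip_max: int  # limit shown on the CLIP counter (75-token chunks)

    def format_counters(lengths: PromptLengths) -> tuple[str, str]:
        # Render the "used/limit" strings the way the two HTML counters
        # display them.
        return (
            f"{lengths['t5']}/{lengths['t5_max']}",
            f"{lengths['clip']}/{lengths['clip_max']}",
        )

    print(format_counters({'t5': 42, 't5_max': 255, 'clip': 12, 'clip_max': 75}))
    # -> ('42/255', '12/75')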
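Return shape of update_token_counter: the positive-prompt events now
wire two Gradio outputs (toprow.t5_token_counter, then
toprow.token_counter), so the function must return a two-element list
in that order, while the negative-prompt path keeps a single output and
a single string. A condensed sketch of the rule (counter_outputs is a
hypothetical helper, not a function in the patch):

    def counter_outputs(t5_text: str, clip_text: str, is_positive: bool):
        # Positive prompts feed two components, so Gradio expects one
        # value per output, in output order; negative prompts feed only
        # the CLIP counter.
        if is_positive:
            return [t5_text, clip_text]
        return clip_text

    print(counter_outputs('42/255', '12/75', is_positive=True))   # ['42/255', '12/75']
    print(counter_outputs('42/255', '12/75', is_positive=False))  # 12/75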
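Picking what to display: prompt editing can produce several scheduled
sub-prompts per counter refresh, and the counter shows the largest
token count together with the limit belonging to that same sub-prompt,
which is why the loops in update_token_counter only overwrite the
*_max_length values when they accept a new maximum. A self-contained
sketch of that selection (pick_display_count is illustrative, not a
function in the patch):

    def pick_display_count(counts: list[dict]) -> str:
        # Keep the sub-prompt with the most CLIP tokens, paired with its
        # own limit rather than the limit of whichever count came last.
        best = None
        for count in counts:
            if count.get('clip') is None:
                continue
            if best is None or count['clip'] > best['clip']:
                best = count
        return f"{best['clip']}/{best['clip_max']}" if best else '-/-'

    counts = [
        {'clip': 80, 'clip_max': 150},  # sub-prompt used for early steps
        {'clip': 10, 'clip_max': 75},   # sub-prompt used after a schedule change
    ]
    print(pick_display_count(counts))  # -> 80/150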