Skip to content

Commit 0ba7d01

Browse files
Merge pull request #110 from williamberman/will/inpainting-fixes
add fixes to inpainting pipeline
2 parents ec965ed + 0f4171c commit 0ba7d01

File tree

1 file changed

+29
-1
lines changed

muse/pipeline_muse.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ def __call__(
346346
generator: Optional[torch.Generator] = None,
347347
use_fp16: bool = False,
348348
image_size: int = 256,
349+
orig_size=(256, 256),
350+
crop_coords=(0, 0),
351+
aesthetic_score=6.0,
349352
):
350353
from torchvision import transforms
351354

@@ -366,7 +369,9 @@ def __call__(
366369
pixel_values = encode_transform(image).unsqueeze(0).to(self.device)
367370
_, image_tokens = self.vae.encode(pixel_values)
368371
mask_token_id = self.transformer.config.mask_token_id
372+
369373
image_tokens[mask[None]] = mask_token_id
374+
370375
image_tokens = image_tokens.repeat(num_images_per_prompt, 1)
371376
if class_ids is not None:
372377
if isinstance(class_ids, int):
@@ -388,7 +393,13 @@ def __call__(
388393
max_length=self.tokenizer.model_max_length,
389394
).input_ids # TODO: remove hardcode
390395
input_ids = input_ids.to(self.device)
391-
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
396+
397+
if self.transformer.config.add_cond_embeds:
398+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
399+
pooled_embeds, encoder_hidden_states = outputs.text_embeds, outputs.hidden_states[-2]
400+
else:
401+
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
402+
pooled_embeds = None
392403

393404
if negative_text is not None:
394405
if isinstance(negative_text, str):
@@ -417,10 +428,27 @@ def __call__(
417428
bs_embed * num_images_per_prompt, seq_len, -1
418429
)
419430

431+
empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
432+
self.text_encoder.device
433+
)
434+
outputs = self.text_encoder(empty_input, output_hidden_states=True)
435+
empty_embeds = outputs.hidden_states[-2]
436+
empty_cond_embeds = outputs[0]
437+
420438
model_inputs = {
421439
"encoder_hidden_states": encoder_hidden_states,
422440
"negative_embeds": negative_encoder_hidden_states,
441+
"empty_embeds": empty_embeds,
442+
"empty_cond_embeds": empty_cond_embeds,
443+
"cond_embeds": pooled_embeds,
423444
}
445+
446+
if self.transformer.config.add_micro_cond_embeds:
447+
micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
448+
micro_conds = torch.tensor(micro_conds, device=self.device, dtype=encoder_hidden_states.dtype)
449+
micro_conds = micro_conds.unsqueeze(0)
450+
model_inputs["micro_conds"] = micro_conds
451+
424452
generate = self.transformer.generate2
425453
with torch.autocast("cuda", enabled=use_fp16):
426454
generated_tokens = generate(

0 commit comments

Comments (0)